├── .gitignore ├── .idea ├── .gitignore ├── inspectionProfiles │ └── profiles_settings.xml ├── misc.xml ├── modules.xml ├── onpolicy.iml └── vcs.xml ├── LICENSE ├── README.md ├── environment.yaml ├── onpolicy ├── __init__.py ├── algorithms │ ├── __init__.py │ ├── r_mappo │ │ ├── __init__.py │ │ ├── algorithm │ │ │ ├── rMAPPOPolicy.py │ │ │ └── r_actor_critic.py │ │ └── r_mappo.py │ └── utils │ │ ├── act.py │ │ ├── attention.py │ │ ├── cnn.py │ │ ├── distributions.py │ │ ├── invariant.py │ │ ├── mix.py │ │ ├── mlp.py │ │ ├── rnn.py │ │ ├── util.py │ │ └── vit.py ├── config.py ├── docs │ ├── Makefile │ ├── make.bat │ └── source │ │ ├── _templates │ │ ├── module.rst_t │ │ ├── package.rst_t │ │ └── toc.rst_t │ │ ├── conf.py │ │ ├── full.gif │ │ ├── index.rst │ │ ├── quickstart.rst │ │ └── setup.rst ├── envs │ ├── __init__.py │ ├── env_wrappers.py │ └── gridworld │ │ ├── .travis.yml │ │ ├── GridWorld_Env.py │ │ ├── LICENSE │ │ ├── README.md │ │ ├── benchmark.py │ │ ├── figures │ │ ├── BlockedUnlockPickup.png │ │ ├── DistShift1.png │ │ ├── DistShift2.png │ │ ├── KeyCorridorS3R1.png │ │ ├── KeyCorridorS3R2.png │ │ ├── KeyCorridorS3R3.png │ │ ├── KeyCorridorS4R3.png │ │ ├── KeyCorridorS5R3.png │ │ ├── KeyCorridorS6R3.png │ │ ├── LavaCrossingS11N5.png │ │ ├── LavaCrossingS9N1.png │ │ ├── LavaCrossingS9N2.png │ │ ├── LavaCrossingS9N3.png │ │ ├── LavaGapS6.png │ │ ├── ObstructedMaze-1Dl.png │ │ ├── ObstructedMaze-1Dlh.png │ │ ├── ObstructedMaze-1Dlhb.png │ │ ├── ObstructedMaze-1Q.png │ │ ├── ObstructedMaze-2Dl.png │ │ ├── ObstructedMaze-2Dlh.png │ │ ├── ObstructedMaze-2Dlhb.png │ │ ├── ObstructedMaze-2Q.png │ │ ├── ObstructedMaze-4Q.png │ │ ├── SimpleCrossingS11N5.png │ │ ├── SimpleCrossingS9N1.png │ │ ├── SimpleCrossingS9N2.png │ │ ├── SimpleCrossingS9N3.png │ │ ├── Unlock.png │ │ ├── UnlockPickup.png │ │ ├── door-key-curriculum.gif │ │ ├── door-key-env.png │ │ ├── dynamic_obstacles.gif │ │ ├── empty-env.png │ │ ├── fetch-env.png │ │ ├── four-rooms-env.png │ │ ├── gotodoor-6x6.mp4 │ │ ├── gotodoor-6x6.png │ │ └── multi-room.gif │ │ ├── frontier │ │ ├── apf.py │ │ ├── nearest.py │ │ ├── rrt.py │ │ ├── utility.py │ │ ├── utils.py │ │ └── voronoi.py │ │ ├── gym_minigrid │ │ ├── __init__.py │ │ ├── envs │ │ │ ├── __init__.py │ │ │ ├── blockedunlockpickup.py │ │ │ ├── crossing.py │ │ │ ├── distshift.py │ │ │ ├── doorkey.py │ │ │ ├── dynamicobstacles.py │ │ │ ├── empty.py │ │ │ ├── fetch.py │ │ │ ├── fourrooms.py │ │ │ ├── gotodoor.py │ │ │ ├── gotoobject.py │ │ │ ├── human.py │ │ │ ├── irregular_room.py │ │ │ ├── keycorridor.py │ │ │ ├── lavagap.py │ │ │ ├── lockedroom.py │ │ │ ├── memory.py │ │ │ ├── multiexploration.py │ │ │ ├── multiroom.py │ │ │ ├── obstructedmaze.py │ │ │ ├── playground_v0.py │ │ │ ├── putnear.py │ │ │ ├── redbluedoors.py │ │ │ ├── unlock.py │ │ │ └── unlockpickup.py │ │ ├── minigrid.py │ │ ├── register.py │ │ ├── rendering.py │ │ ├── roomgrid.py │ │ ├── window.py │ │ └── wrappers.py │ │ ├── manual_control.py │ │ ├── run_tests.py │ │ └── setup.py ├── runner │ └── shared │ │ ├── base_runner.py │ │ └── gridworld_runner.py ├── scripts │ ├── __init__.py │ ├── render │ │ ├── __init__.py │ │ ├── render_gridworld.py │ │ └── render_gridworld_ft.py │ ├── render_gridworld.sh │ ├── render_gridworld_ft.sh │ ├── train │ │ ├── __init__.py │ │ └── train_gridworld.py │ └── train_gridworld.sh └── utils │ ├── RRT │ ├── __init__.py │ ├── rrt.py │ ├── rrt_with_pathsmoothing.py │ ├── rrt_with_sobol_sampler.py │ └── sobol │ │ ├── __init__.py │ │ └── sobol.py │ ├── __init__.py │ ├── multi_discrete.py │ 
├── shared_buffer.py │ ├── util.py │ └── valuenorm.py ├── requirements.txt └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | results 3 | *.pyc 4 | /gifs/ 5 | build 6 | win_rate 7 | *.so 8 | *.csv 9 | RODE 10 | *.egg-info 11 | wandb 12 | .vscode 13 | api 14 | logs 15 | .nfs* 16 | *.so 17 | *.out 18 | png 19 | *.log 20 | 21 | docker 22 | dataset 23 | data 24 | pretrained_models 25 | 26 | onpolicy/scripts/gifs/ 27 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | 2 | # Default ignored files 3 | /workspace.xml -------------------------------------------------------------------------------- /.idea/inspectionProfiles/profiles_settings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 6 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/onpolicy.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 14 | 15 | 16 | 19 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yang Xinyi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Asynchronous Multi-Agent Reinforcement Learning for Efficient Real-Time Multi-Robot Cooperative Exploration 2 | 3 | This is a PyTorch implementation of the paper: [Asynchronous Multi-Agent Reinforcement Learning for Efficient Real-Time Multi-Robot Cooperative Exploration](https://arxiv.org/abs/2301.03398) 4 | 5 | Project Website: https://sites.google.com/view/ace-aamas 6 | 7 | ## Training 8 | 9 | You could start training with by running `sh train_gridworld.sh` in directory [onpolicy/scripts](onpolicy/scripts). 10 | 11 | ## Evaluation 12 | 13 | Similar to training, you could run `sh render_gridworld.sh` in directory [onpolicy/scripts](onpolicy/scripts) to start evaluation. Remember to set up your path to the cooresponding model, correct hyperparameters and related evaluation parameters. 14 | 15 | We also provide our implementations of planning-based baselines. You could run `sh render_gridworld_ft.sh` to evaluate the planning-based methods. Note that `algorithm_name` determines the method to make global planning. It can be set to one of `mappo`, `ft_rrt`, `ft_apf`, `ft_nearest` and `ft_utility`. 16 | 17 | You could also visualize the result and generate gifs by adding `--use_render` and `--save_gifs` to the scripts. 18 | 19 | ## Citation 20 | If you find this repository useful, please cite our [paper](https://arxiv.org/abs/2301.03398): 21 | ``` 22 | @misc{yu2023asynchronous, 23 | title={Asynchronous Multi-Agent Reinforcement Learning for Efficient Real-Time Multi-Robot Cooperative Exploration}, 24 | author={Chao Yu and Xinyi Yang and Jiaxuan Gao and Jiayu Chen and Yunfei Li and Jijia Liu and Yunfei Xiang and Ruixin Huang and Huazhong Yang and Yi Wu and Yu Wang}, 25 | year={2023}, 26 | eprint={2301.03398}, 27 | archivePrefix={arXiv}, 28 | primaryClass={cs.RO} 29 | } 30 | ``` -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: marl 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - _tflow_select=2.1.0=gpu 7 | - absl-py=0.9.0=py36_0 8 | - astor=0.8.0=py36_0 9 | - blas=1.0=mkl 10 | - c-ares=1.15.0=h7b6447c_1001 11 | - ca-certificates=2020.1.1=0 12 | - certifi=2020.4.5.2=py36_0 13 | - cudatoolkit=10.0.130=0 14 | - cudnn=7.6.5=cuda10.0_0 15 | - cupti=10.0.130=0 16 | - gast=0.2.2=py36_0 17 | - google-pasta=0.2.0=py_0 18 | - grpcio=1.14.1=py36h9ba97e2_0 19 | - h5py=2.10.0=py36h7918eee_0 20 | - hdf5=1.10.4=hb1b8bf9_0 21 | - intel-openmp=2020.1=217 22 | - keras-applications=1.0.8=py_0 23 | - keras-preprocessing=1.1.0=py_1 24 | - libedit=3.1=heed3624_0 25 | - libffi=3.2.1=hd88cf55_4 26 | - libgcc-ng=9.1.0=hdf63c60_0 27 | - libgfortran-ng=7.3.0=hdf63c60_0 28 | - libprotobuf=3.12.3=hd408876_0 29 | - libstdcxx-ng=9.1.0=hdf63c60_0 30 | - markdown=3.1.1=py36_0 31 | - mkl=2020.1=217 32 | - mkl-service=2.3.0=py36he904b0f_0 33 | - mkl_fft=1.1.0=py36h23d657b_0 34 | - mkl_random=1.1.1=py36h0573a6f_0 35 | - ncurses=6.0=h9df7e31_2 36 | - numpy=1.18.1=py36h4f9e942_0 37 | - numpy-base=1.18.1=py36hde5b4d6_1 38 | - openssl=1.0.2u=h7b6447c_0 39 | - opt_einsum=3.1.0=py_0 40 | - pip=20.1.1=py36_1 41 | - protobuf=3.12.3=py36he6710b0_0 42 | - python=3.6.2=hca45abc_19 43 | - readline=7.0=ha6073c6_4 44 | - scipy=1.4.1=py36h0b6359f_0 45 | - setuptools=47.3.0=py36_0 46 | - 
six=1.15.0=py_0 47 | - sqlite=3.23.1=he433501_0 48 | - tensorboard=2.0.0=pyhb38c66f_1 49 | - tensorflow=2.0.0=gpu_py36h6b29c10_0 50 | - tensorflow-base=2.0.0=gpu_py36h0ec5d1f_0 51 | - tensorflow-estimator=2.0.0=pyh2649769_0 52 | - tensorflow-gpu=2.0.0=h0d30ee6_0 53 | - termcolor=1.1.0=py36_1 54 | - tk=8.6.8=hbc83047_0 55 | - werkzeug=0.16.1=py_0 56 | - wheel=0.34.2=py36_0 57 | - wrapt=1.12.1=py36h7b6447c_1 58 | - xz=5.2.5=h7b6447c_0 59 | - zlib=1.2.11=h7b6447c_3 60 | - pip: 61 | - aiohttp==3.6.2 62 | - aioredis==1.3.1 63 | - astunparse==1.6.3 64 | - async-timeout==3.0.1 65 | - atari-py==0.2.6 66 | - atomicwrites==1.2.1 67 | - attrs==18.2.0 68 | - beautifulsoup4==4.9.1 69 | - blessings==1.7 70 | - cachetools==4.1.1 71 | - cffi==1.14.1 72 | - chardet==3.0.4 73 | - click==7.1.2 74 | - cloudpickle==1.3.0 75 | - colorama==0.4.3 76 | - colorful==0.5.4 77 | - configparser==5.0.1 78 | - contextvars==2.4 79 | - cycler==0.10.0 80 | - cython==0.29.21 81 | - deepdiff==4.3.2 82 | - dill==0.3.2 83 | - docker-pycreds==0.4.0 84 | - docopt==0.6.2 85 | - fasteners==0.15 86 | - filelock==3.0.12 87 | - funcsigs==1.0.2 88 | - future==0.16.0 89 | - gin==0.1.6 90 | - gin-config==0.3.0 91 | - gitdb==4.0.5 92 | - gitpython==3.1.9 93 | - glfw==1.12.0 94 | - google==3.0.0 95 | - google-api-core==1.22.1 96 | - google-auth==1.21.0 97 | - google-auth-oauthlib==0.4.1 98 | - googleapis-common-protos==1.52.0 99 | - gpustat==0.6.0 100 | - gql==0.2.0 101 | - graphql-core==1.1 102 | - gym==0.17.2 103 | - hiredis==1.1.0 104 | - idna==2.7 105 | - idna-ssl==1.1.0 106 | - imageio==2.4.1 107 | - immutables==0.14 108 | - importlib-metadata==1.7.0 109 | - joblib==0.16.0 110 | - jsonnet==0.16.0 111 | - jsonpickle==0.9.6 112 | - jsonschema==3.2.0 113 | - kiwisolver==1.0.1 114 | - lockfile==0.12.2 115 | - mappo==0.0.1 116 | - matplotlib==3.0.0 117 | - mock==2.0.0 118 | - monotonic==1.5 119 | - more-itertools==4.3.0 120 | - mpi4py==3.0.3 121 | - mpyq==0.2.5 122 | - msgpack==1.0.0 123 | - mujoco-py==2.0.2.13 124 | - mujoco-worldgen==0.0.0 125 | - multidict==4.7.6 126 | - munch==2.3.2 127 | - nvidia-ml-py3==7.352.0 128 | - oauthlib==3.1.0 129 | - opencensus==0.7.10 130 | - opencensus-context==0.1.1 131 | - opencv-python==4.2.0.34 132 | - ordered-set==4.0.2 133 | - packaging==20.4 134 | - pandas==1.1.1 135 | - pathlib2==2.3.2 136 | - pathtools==0.1.2 137 | - pbr==4.3.0 138 | - pillow==5.3.0 139 | - pluggy==0.7.1 140 | - portpicker==1.2.0 141 | - probscale==0.2.3 142 | - progressbar2==3.53.1 143 | - prometheus-client==0.8.0 144 | - promise==2.3 145 | - psutil==5.7.2 146 | - py==1.6.0 147 | - py-spy==0.3.3 148 | - pyasn1==0.4.8 149 | - pyasn1-modules==0.2.8 150 | - pycparser==2.20 151 | - pygame==1.9.4 152 | - pyglet==1.5.0 153 | - pyopengl==3.1.5 154 | - pyopengl-accelerate==3.1.5 155 | - pyparsing==2.2.2 156 | - pyrsistent==0.16.0 157 | - pysc2==3.0.0 158 | - pytest==3.8.2 159 | - python-dateutil==2.7.3 160 | - python-utils==2.4.0 161 | - pytz==2020.1 162 | - pyyaml==3.13 163 | - pyzmq==19.0.2 164 | - ray==0.8.0 165 | - redis==3.4.1 166 | - requests==2.24.0 167 | - requests-oauthlib==1.3.0 168 | - rsa==4.6 169 | - s2clientprotocol==4.10.1.75800.0 170 | - s2protocol==4.11.4.78285.0 171 | - sacred==0.7.2 172 | - seaborn==0.10.1 173 | - sentry-sdk==0.18.0 174 | - shortuuid==1.0.1 175 | - sk-video==1.1.10 176 | - smmap==3.0.4 177 | - snakeviz==1.0.0 178 | - soupsieve==2.0.1 179 | - subprocess32==3.5.4 180 | - tabulate==0.8.7 181 | - tensorboard-logger==0.1.0 182 | - tensorboard-plugin-wit==1.7.0 183 | - tensorboardx==2.0 184 | - 
torch==1.5.1+cu101 185 | - torchvision==0.6.1+cu101 186 | - tornado==5.1.1 187 | - tqdm==4.48.2 188 | - typing-extensions==3.7.4.3 189 | - urllib3==1.23 190 | - wandb==0.10.5 191 | - watchdog==0.10.3 192 | - websocket-client==0.53.0 193 | - whichcraft==0.5.2 194 | - xmltodict==0.12.0 195 | - yarl==1.5.1 196 | - zipp==3.1.0 197 | - zmq==0.0.0 198 | prefix: /home/yuchao/anaconda3/envs/marl 199 | -------------------------------------------------------------------------------- /onpolicy/__init__.py: -------------------------------------------------------------------------------- 1 | from onpolicy import algorithms, envs, runner, scripts, utils, config 2 | 3 | 4 | __version__ = "0.1.0" 5 | 6 | __all__ = [ 7 | "algorithms", 8 | "envs", 9 | "runner", 10 | "scripts", 11 | "utils", 12 | "config", 13 | ] -------------------------------------------------------------------------------- /onpolicy/algorithms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/algorithms/__init__.py -------------------------------------------------------------------------------- /onpolicy/algorithms/r_mappo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/algorithms/r_mappo/__init__.py -------------------------------------------------------------------------------- /onpolicy/algorithms/r_mappo/algorithm/rMAPPOPolicy.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from onpolicy.algorithms.r_mappo.algorithm.r_actor_critic import R_Actor, R_Critic 4 | from onpolicy.utils.util import update_linear_schedule 5 | 6 | 7 | class R_MAPPOPolicy: 8 | def __init__(self, args, obs_space, share_obs_space, act_space, device=torch.device("cpu")): 9 | 10 | self.device = device 11 | self.lr = args.lr 12 | self.critic_lr = args.critic_lr 13 | self.opti_eps = args.opti_eps 14 | self.weight_decay = args.weight_decay 15 | 16 | self.obs_space = obs_space 17 | self.share_obs_space = share_obs_space 18 | self.act_space = act_space 19 | 20 | self.actor = R_Actor(args, self.obs_space, self.act_space, self.device) 21 | self.critic = R_Critic(args, self.share_obs_space, self.device) 22 | 23 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=self.lr, eps=self.opti_eps, weight_decay=self.weight_decay) 24 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=self.critic_lr, eps=self.opti_eps, weight_decay=self.weight_decay) 25 | 26 | def lr_decay(self, episode, episodes): 27 | update_linear_schedule(self.actor_optimizer, episode, episodes, self.lr) 28 | update_linear_schedule(self.critic_optimizer, episode, episodes, self.critic_lr) 29 | 30 | def get_actions(self, share_obs, obs, rnn_states_actor, rnn_states_critic, masks, available_actions=None, deterministic=False): 31 | actions, action_log_probs, rnn_states_actor = self.actor(obs, rnn_states_actor, masks, available_actions, deterministic) 32 | values, rnn_states_critic = self.critic(share_obs, rnn_states_critic, masks) 33 | return values, actions, action_log_probs, rnn_states_actor, rnn_states_critic 34 | 35 | def get_values(self, share_obs, rnn_states_critic, masks): 36 | values, _ = self.critic(share_obs, rnn_states_critic, masks) 37 | return values 38 | 39 | def evaluate_actions(self, 
share_obs, obs, rnn_states_actor, rnn_states_critic, action, masks, available_actions=None, active_masks=None): 40 | action_log_probs, dist_entropy, policy_values = self.actor.evaluate_actions(obs, rnn_states_actor, action, masks, available_actions, active_masks) 41 | values, _ = self.critic(share_obs, rnn_states_critic, masks) 42 | return values, action_log_probs, dist_entropy, policy_values 43 | 44 | def act(self, obs, rnn_states_actor, masks, available_actions=None, deterministic=False): 45 | actions, _, rnn_states_actor = self.actor(obs, rnn_states_actor, masks, available_actions, deterministic) 46 | return actions, rnn_states_actor 47 | -------------------------------------------------------------------------------- /onpolicy/algorithms/r_mappo/algorithm/r_actor_critic.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from onpolicy.algorithms.utils.util import init, check 9 | from onpolicy.algorithms.utils.cnn import CNNBase 10 | from onpolicy.algorithms.utils.mlp import MLPBase, MLPLayer 11 | from onpolicy.algorithms.utils.mix import MIXBase 12 | from onpolicy.algorithms.utils.rnn import RNNLayer 13 | from onpolicy.algorithms.utils.act import ACTLayer 14 | from onpolicy.utils.util import get_shape_from_obs_space 15 | 16 | class R_Actor(nn.Module): 17 | def __init__(self, args, obs_space, action_space, device=torch.device("cpu")): 18 | super(R_Actor, self).__init__() 19 | self.hidden_size = args.hidden_size 20 | 21 | self._gain = args.gain 22 | self._use_orthogonal = args.use_orthogonal 23 | self._activation_id = args.activation_id 24 | self._use_policy_active_masks = args.use_policy_active_masks 25 | self._use_naive_recurrent_policy = args.use_naive_recurrent_policy 26 | self._use_recurrent_policy = args.use_recurrent_policy 27 | self._use_policy_vhead = args.use_policy_vhead 28 | self._recurrent_N = args.recurrent_N 29 | self._grid_goal = args.grid_goal 30 | self.tpdv = dict(dtype=torch.float32, device=device) 31 | 32 | obs_shape = get_shape_from_obs_space(obs_space) 33 | 34 | if 'Dict' in obs_shape.__class__.__name__: 35 | self._mixed_obs = True 36 | self.base = MIXBase(args, obs_shape, cnn_layers_params=args.cnn_layers_params) 37 | else: 38 | self._mixed_obs = False 39 | self.base = CNNBase(args, obs_shape) if len(obs_shape)==3 else MLPBase(args, obs_shape, use_attn_internal=args.use_attn_internal, use_cat_self=True) 40 | 41 | input_size = self.base.output_size 42 | 43 | if self._use_naive_recurrent_policy or self._use_recurrent_policy: 44 | self.rnn = RNNLayer(input_size, self.hidden_size, self._recurrent_N, self._use_orthogonal) 45 | input_size = self.hidden_size 46 | 47 | self.act = ACTLayer(action_space, input_size, self._use_orthogonal, self._gain, args=args) 48 | 49 | if self._use_policy_vhead: 50 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][self._use_orthogonal] 51 | def init_(m): 52 | return init(m, init_method, lambda x: nn.init.constant_(x, 0)) 53 | self.v_out = init_(nn.Linear(input_size, 1)) 54 | 55 | self.to(device) 56 | 57 | def forward(self, obs, rnn_states, masks, available_actions=None, deterministic=False): 58 | if self._mixed_obs: 59 | for key in obs.keys(): 60 | obs[key] = check(obs[key]).to(**self.tpdv) 61 | else: 62 | obs = check(obs).to(**self.tpdv) 63 | rnn_states = check(rnn_states).to(**self.tpdv) 64 | masks = check(masks).to(**self.tpdv) 65 | 66 | if available_actions is not 
None: 67 | available_actions = check(available_actions).to(**self.tpdv) 68 | 69 | actor_features = self.base(obs) 70 | 71 | if self._use_naive_recurrent_policy or self._use_recurrent_policy: 72 | if self._grid_goal: 73 | actor_features[0], rnn_states = self.rnn(actor_features[0], rnn_states, masks) 74 | else: 75 | actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks) 76 | 77 | actions, action_log_probs = self.act(actor_features, available_actions, deterministic) 78 | 79 | return actions, action_log_probs, rnn_states 80 | 81 | def evaluate_actions(self, obs, rnn_states, action, masks, available_actions=None, active_masks=None): 82 | if self._mixed_obs: 83 | for key in obs.keys(): 84 | obs[key] = check(obs[key]).to(**self.tpdv) 85 | else: 86 | obs = check(obs).to(**self.tpdv) 87 | 88 | rnn_states = check(rnn_states).to(**self.tpdv) 89 | action = check(action).to(**self.tpdv) 90 | masks = check(masks).to(**self.tpdv) 91 | 92 | if available_actions is not None: 93 | available_actions = check(available_actions).to(**self.tpdv) 94 | 95 | if active_masks is not None: 96 | active_masks = check(active_masks).to(**self.tpdv) 97 | 98 | actor_features = self.base(obs) 99 | 100 | if self._use_naive_recurrent_policy or self._use_recurrent_policy: 101 | if self._grid_goal: 102 | actor_features[0], rnn_states = self.rnn(actor_features[0], rnn_states, masks) 103 | else: 104 | actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks) 105 | 106 | action_log_probs, dist_entropy = self.act.evaluate_actions(actor_features, action, available_actions, active_masks = active_masks if self._use_policy_active_masks else None) 107 | 108 | values = self.v_out(actor_features) if self._use_policy_vhead else None 109 | 110 | return action_log_probs, dist_entropy, values 111 | 112 | def get_policy_values(self, obs, rnn_states, masks): 113 | if self._mixed_obs: 114 | for key in obs.keys(): 115 | obs[key] = check(obs[key]).to(**self.tpdv) 116 | else: 117 | obs = check(obs).to(**self.tpdv) 118 | rnn_states = check(rnn_states).to(**self.tpdv) 119 | masks = check(masks).to(**self.tpdv) 120 | 121 | actor_features = self.base(obs) 122 | 123 | if self._use_naive_recurrent_policy or self._use_recurrent_policy: 124 | actor_features, rnn_states = self.rnn(actor_features, rnn_states, masks) 125 | 126 | values = self.v_out(actor_features) 127 | 128 | return values 129 | 130 | class R_Critic(nn.Module): 131 | def __init__(self, args, share_obs_space, device=torch.device("cpu")): 132 | super(R_Critic, self).__init__() 133 | self.hidden_size = args.hidden_size 134 | self._use_orthogonal = args.use_orthogonal 135 | self._activation_id = args.activation_id 136 | self._use_naive_recurrent_policy = args.use_naive_recurrent_policy 137 | self._use_recurrent_policy = args.use_recurrent_policy 138 | self._recurrent_N = args.recurrent_N 139 | self._grid_goal = args.grid_goal 140 | self.tpdv = dict(dtype=torch.float32, device=device) 141 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][self._use_orthogonal] 142 | 143 | share_obs_shape = get_shape_from_obs_space(share_obs_space) 144 | 145 | if 'Dict' in share_obs_shape.__class__.__name__: 146 | self._mixed_obs = True 147 | self.base = MIXBase(args, share_obs_shape, cnn_layers_params=args.cnn_layers_params) 148 | else: 149 | self._mixed_obs = False 150 | self.base = CNNBase(args, share_obs_shape) if len(share_obs_shape)==3 else MLPBase(args, share_obs_shape, use_attn_internal=True, use_cat_self=args.use_cat_self) 151 | 152 | input_size = 
self.base.output_size 153 | 154 | if self._use_naive_recurrent_policy or self._use_recurrent_policy: 155 | self.rnn = RNNLayer(input_size, self.hidden_size, self._recurrent_N, self._use_orthogonal) 156 | input_size = self.hidden_size 157 | 158 | def init_(m): 159 | return init(m, init_method, lambda x: nn.init.constant_(x, 0)) 160 | 161 | self.v_out = init_(nn.Linear(input_size, 1)) 162 | 163 | self.to(device) 164 | 165 | def forward(self, share_obs, rnn_states, masks): 166 | if self._mixed_obs: 167 | for key in share_obs.keys(): 168 | share_obs[key] = check(share_obs[key]).to(**self.tpdv) 169 | else: 170 | share_obs = check(share_obs).to(**self.tpdv) 171 | rnn_states = check(rnn_states).to(**self.tpdv) 172 | masks = check(masks).to(**self.tpdv) 173 | 174 | critic_features = self.base(share_obs) 175 | 176 | if self._grid_goal: 177 | critic_features = critic_features[0] 178 | 179 | if self._use_naive_recurrent_policy or self._use_recurrent_policy: 180 | critic_features, rnn_states = self.rnn(critic_features, rnn_states, masks) 181 | 182 | values = self.v_out(critic_features) 183 | 184 | return values, rnn_states 185 | -------------------------------------------------------------------------------- /onpolicy/algorithms/utils/cnn.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .util import init 8 | 9 | class Flatten(nn.Module): 10 | def forward(self, x): 11 | return x.view(x.size(0), -1) 12 | 13 | class CNNLayer(nn.Module): 14 | def __init__(self, obs_shape, hidden_size, use_orthogonal, activation_id, kernel_size=3, stride=1): 15 | super(CNNLayer, self).__init__() 16 | 17 | active_func = [nn.Tanh(), nn.ReLU(), nn.LeakyReLU(), nn.ELU()][activation_id] 18 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 19 | gain = nn.init.calculate_gain(['tanh', 'relu', 'leaky_relu', 'leaky_relu'][activation_id]) 20 | 21 | def init_(m): 22 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain) 23 | 24 | input_channel = obs_shape[0] 25 | input_width = obs_shape[1] 26 | input_height = obs_shape[2] 27 | 28 | self.cnn = nn.Sequential( 29 | init_(nn.Conv2d(in_channels=input_channel, out_channels=hidden_size//2, kernel_size=kernel_size, stride=stride)), active_func, 30 | Flatten(), 31 | init_(nn.Linear(hidden_size//2 * (input_width-kernel_size+stride) * (input_height-kernel_size+stride), hidden_size)), active_func, 32 | init_(nn.Linear(hidden_size, hidden_size)), active_func) 33 | 34 | def forward(self, x): 35 | x = x / 255.0 36 | x = self.cnn(x) 37 | 38 | return x 39 | 40 | class CNNBase(nn.Module): 41 | def __init__(self, args, obs_shape): 42 | super(CNNBase, self).__init__() 43 | 44 | self._use_orthogonal = args.use_orthogonal 45 | self._activation_id = args.activation_id 46 | self.hidden_size = args.hidden_size 47 | 48 | self.cnn = CNNLayer(obs_shape, self.hidden_size, self._use_orthogonal, self._activation_id) 49 | 50 | def forward(self, x): 51 | x = self.cnn(x) 52 | return x 53 | 54 | @property 55 | def output_size(self): 56 | return self.hidden_size 57 | -------------------------------------------------------------------------------- /onpolicy/algorithms/utils/distributions.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .util import init 8 | 9 | """ 10 | 
Modify standard PyTorch distributions so they are compatible with this code. 11 | """ 12 | 13 | # 14 | # Standardize distribution interfaces 15 | # 16 | 17 | # Categorical 18 | class FixedCategorical(torch.distributions.Categorical): 19 | def sample(self): 20 | return super().sample().unsqueeze(-1) 21 | 22 | def log_probs(self, actions): 23 | return ( 24 | super() 25 | .log_prob(actions.squeeze(-1)) 26 | .view(actions.size(0), -1) 27 | .sum(-1) 28 | .unsqueeze(-1) 29 | ) 30 | 31 | def mode(self): 32 | return self.probs.argmax(dim=-1, keepdim=True) 33 | 34 | 35 | # Normal 36 | class FixedNormal(torch.distributions.Normal): 37 | def log_probs(self, actions): 38 | return super().log_prob(actions).sum(-1, keepdim=True) 39 | 40 | def entropy(self): 41 | return super().entropy().sum(-1) 42 | 43 | def mode(self): 44 | return self.mean 45 | 46 | 47 | # Bernoulli 48 | class FixedBernoulli(torch.distributions.Bernoulli): 49 | def log_probs(self, actions): 50 | return super().log_prob(actions).view(actions.size(0), -1).sum(-1).unsqueeze(-1) 51 | 52 | def entropy(self): 53 | return super().entropy().sum(-1) 54 | 55 | def mode(self): 56 | return torch.gt(self.probs, 0.5).float() 57 | 58 | 59 | class Categorical(nn.Module): 60 | def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01): 61 | super(Categorical, self).__init__() 62 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 63 | def init_(m): 64 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain) 65 | 66 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 67 | 68 | def forward(self, x, available_actions=None, trans=True): 69 | if trans: 70 | x = self.linear(x) 71 | if available_actions is not None: 72 | x[available_actions == 0] = -1e10 73 | return FixedCategorical(logits=x) 74 | 75 | 76 | class DiagGaussian(nn.Module): 77 | def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01): 78 | super(DiagGaussian, self).__init__() 79 | 80 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 81 | def init_(m): 82 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain) 83 | 84 | self.fc_mean = init_(nn.Linear(num_inputs, num_outputs)) 85 | self.logstd = AddBias(torch.zeros(num_outputs)) 86 | 87 | def forward(self, x, trans=True): 88 | if trans: 89 | action_mean = self.fc_mean(x) 90 | else: 91 | action_mean = x 92 | 93 | # An ugly hack for my KFAC implementation.
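# (added note) `AddBias` simply adds a learned parameter to its input, so applying `self.logstd`
# to a zeros tensor shaped like `action_mean` yields a state-independent log standard deviation
# broadcast over the batch; the resulting distribution is Normal(action_mean, exp(logstd)) with a
# diagonal covariance.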
94 | zeros = torch.zeros(action_mean.size()) 95 | if x.is_cuda: 96 | zeros = zeros.cuda() 97 | 98 | action_logstd = self.logstd(zeros) 99 | return FixedNormal(action_mean, action_logstd.exp()) 100 | 101 | 102 | class Bernoulli(nn.Module): 103 | def __init__(self, num_inputs, num_outputs, use_orthogonal=True, gain=0.01): 104 | super(Bernoulli, self).__init__() 105 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 106 | def init_(m): 107 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain) 108 | 109 | self.linear = init_(nn.Linear(num_inputs, num_outputs)) 110 | 111 | def forward(self, x): 112 | x = self.linear(x) 113 | return FixedBernoulli(logits=x) 114 | 115 | class AddBias(nn.Module): 116 | def __init__(self, bias): 117 | super(AddBias, self).__init__() 118 | self._bias = nn.Parameter(bias.unsqueeze(1)) 119 | 120 | def forward(self, x): 121 | if x.dim() == 2: 122 | bias = self._bias.t().view(1, -1) 123 | else: 124 | bias = self._bias.t().view(1, -1, 1, 1) 125 | 126 | return x + bias 127 | -------------------------------------------------------------------------------- /onpolicy/algorithms/utils/invariant.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from onpolicy.algorithms.utils.vit import ViT, Attention, PreNorm, Transformer, CrossAttention, FeedForward 6 | from einops.layers.torch import Rearrange 7 | from einops import rearrange, repeat 8 | import random 9 | 10 | def get_position_embedding(pos, hidden_dim, device = torch.device("cpu")): 11 | scaled_time = 2 * torch.arange(hidden_dim / 2) / hidden_dim 12 | scaled_time = 10000 ** scaled_time 13 | scaled_time = pos / scaled_time 14 | return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=0).to(device) 15 | 16 | def get_explicit_position_embedding(n_embed_input=10, n_embed_output=8 , device = torch.device("cpu")): 17 | 18 | return nn.Embedding(n_embed_input, n_embed_output).to(device) 19 | 20 | class AlterEncoder(nn.Module): 21 | def __init__(self, num_grids, input_dim, depth = 2, hidden_dim = 128, heads = 4, dim_head = 32, mlp_dim = 128, dropout = 0.): 22 | super().__init__() 23 | self.num_grids = num_grids 24 | self.hidden_dim = hidden_dim 25 | self.depth = depth 26 | self.encode_actor = nn.Linear(input_dim, hidden_dim) 27 | self.encode_other = nn.Linear(input_dim, hidden_dim) 28 | self.last_cross_attn = nn.ModuleList([ 29 | nn.LayerNorm(hidden_dim), 30 | nn.Linear(hidden_dim, 2 * heads * dim_head, bias = False), 31 | CrossAttention(hidden_dim, heads = heads, dim_head = dim_head, dropout = dropout), 32 | PreNorm(hidden_dim, FeedForward(hidden_dim, mlp_dim, dropout = dropout)) 33 | ]) 34 | 35 | def forward(self, data): 36 | x, others = data 37 | B = x.shape[0] 38 | # print("alter_attn", x.shape) 39 | x = self.encode_actor(x) 40 | all = [x,] 41 | for i, y in enumerate(others): 42 | y = self.encode_other(y) 43 | all.append(y) 44 | num_agents = len(all) 45 | out = torch.stack(all, dim = 1) # B x num_agents x 64 x D 46 | out = rearrange(out, "b n g d -> (b g) n d", b = B, n = num_agents, g = self.num_grids) 47 | norm, to_kv, cross_attn, ff= self.last_cross_attn 48 | out = norm(out) 49 | x = out[:, :1, :] # 64B x 1 x D 50 | others = out[:, 1:, :] # 64B x (n-1) x D 51 | if num_agents > 1: 52 | k, v = to_kv(others).chunk(2, dim=-1) 53 | out = cross_attn(x, k, v) + x # # 64B x 1 x D 54 | else: 55 | out = x 56 | out = ff(out) + out 57 | out 
= rearrange(out, " (b g) n d -> n b g d", b = B, n = 1, g = self.num_grids)[0] 58 | return out 59 | -------------------------------------------------------------------------------- /onpolicy/algorithms/utils/mlp.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .util import init, get_clones 8 | from .attention import Encoder 9 | 10 | class MLPLayer(nn.Module): 11 | def __init__(self, input_dim, hidden_size, layer_N, use_orthogonal, activation_id): 12 | super(MLPLayer, self).__init__() 13 | self._layer_N = layer_N 14 | 15 | active_func = [nn.Tanh(), nn.ReLU(), nn.LeakyReLU(), nn.ELU()][activation_id] 16 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 17 | gain = nn.init.calculate_gain(['tanh', 'relu', 'leaky_relu', 'leaky_relu'][activation_id]) 18 | 19 | def init_(m): 20 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain) 21 | 22 | self.fc1 = nn.Sequential( 23 | init_(nn.Linear(input_dim, hidden_size)), active_func, nn.LayerNorm(hidden_size)) 24 | self.fc_h = nn.Sequential(init_( 25 | nn.Linear(hidden_size, hidden_size)), active_func, nn.LayerNorm(hidden_size)) 26 | self.fc2 = get_clones(self.fc_h, self._layer_N) 27 | 28 | def forward(self, x): 29 | x = self.fc1(x) 30 | for i in range(self._layer_N): 31 | x = self.fc2[i](x) 32 | return x 33 | 34 | class CONVLayer(nn.Module): 35 | def __init__(self, input_dim, hidden_size, use_orthogonal, activation_id): 36 | super(CONVLayer, self).__init__() 37 | 38 | active_func = [nn.Tanh(), nn.ReLU(), nn.LeakyReLU(), nn.ELU()][activation_id] 39 | init_method = [nn.init.xavier_uniform_, nn.init.orthogonal_][use_orthogonal] 40 | gain = nn.init.calculate_gain(['tanh', 'relu', 'leaky_relu', 'leaky_relu'][activation_id]) 41 | 42 | def init_(m): 43 | return init(m, init_method, lambda x: nn.init.constant_(x, 0), gain=gain) 44 | 45 | self.conv = nn.Sequential( 46 | init_(nn.Conv1d(in_channels=input_dim, out_channels=hidden_size//4, kernel_size=3, stride=2, padding=0)), active_func, #nn.BatchNorm1d(hidden_size//4), 47 | init_(nn.Conv1d(in_channels=hidden_size//4, out_channels=hidden_size//2, kernel_size=3, stride=1, padding=1)), active_func, #nn.BatchNorm1d(hidden_size//2), 48 | init_(nn.Conv1d(in_channels=hidden_size//2, out_channels=hidden_size, kernel_size=3, stride=1, padding=1)), active_func) #, nn.BatchNorm1d(hidden_size)) 49 | 50 | def forward(self, x): 51 | x = self.conv(x) 52 | return x 53 | 54 | 55 | class MLPBase(nn.Module): 56 | def __init__(self, args, obs_shape, use_attn_internal=False, use_cat_self=True): 57 | super(MLPBase, self).__init__() 58 | 59 | self._use_feature_normalization = args.use_feature_normalization 60 | self._use_orthogonal = args.use_orthogonal 61 | self._activation_id = args.activation_id 62 | self._use_attn = args.use_attn 63 | self._use_attn_internal = use_attn_internal 64 | self._use_average_pool = args.use_average_pool 65 | self._use_conv1d = args.use_conv1d 66 | self._stacked_frames = args.stacked_frames 67 | self._layer_N = 0 if args.use_single_network else args.layer_N 68 | self._attn_size = args.attn_size 69 | self.hidden_size = args.hidden_size 70 | 71 | obs_dim = obs_shape[0] 72 | 73 | if self._use_feature_normalization: 74 | self.feature_norm = nn.LayerNorm(obs_dim) 75 | 76 | if self._use_attn and self._use_attn_internal: 77 | 78 | if self._use_average_pool: 79 | if use_cat_self: 80 | inputs_dim = self._attn_size 
+ obs_shape[-1][1] 81 | else: 82 | inputs_dim = self._attn_size 83 | else: 84 | split_inputs_dim = 0 85 | split_shape = obs_shape[1:] 86 | for i in range(len(split_shape)): 87 | split_inputs_dim += split_shape[i][0] 88 | inputs_dim = split_inputs_dim * self._attn_size 89 | self.attn = Encoder(args, obs_shape, use_cat_self) 90 | self.attn_norm = nn.LayerNorm(inputs_dim) 91 | else: 92 | inputs_dim = obs_dim 93 | 94 | if self._use_conv1d: 95 | self.conv = CONVLayer(self._stacked_frames, self.hidden_size, self._use_orthogonal, self._activation_id) 96 | random_x = torch.FloatTensor(1, self._stacked_frames, inputs_dim//self._stacked_frames) 97 | random_out = self.conv(random_x) 98 | assert len(random_out.shape)==3 99 | inputs_dim = random_out.size(-1) * random_out.size(-2) 100 | 101 | self.mlp = MLPLayer(inputs_dim, self.hidden_size, 102 | self._layer_N, self._use_orthogonal, self._activation_id) 103 | 104 | def forward(self, x): 105 | if self._use_feature_normalization: 106 | x = self.feature_norm(x) 107 | 108 | if self._use_attn and self._use_attn_internal: 109 | x = self.attn(x, self_idx=-1) 110 | x = self.attn_norm(x) 111 | 112 | if self._use_conv1d: 113 | batch_size = x.size(0) 114 | x = x.view(batch_size, self._stacked_frames, -1) 115 | x = self.conv(x) 116 | x = x.view(batch_size, -1) 117 | 118 | x = self.mlp(x) 119 | 120 | return x 121 | 122 | @property 123 | def output_size(self): 124 | return self.hidden_size -------------------------------------------------------------------------------- /onpolicy/algorithms/utils/rnn.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | class RNNLayer(nn.Module): 9 | def __init__(self, inputs_dim, outputs_dim, recurrent_N, use_orthogonal): 10 | super(RNNLayer, self).__init__() 11 | self._recurrent_N = recurrent_N 12 | self._use_orthogonal = use_orthogonal 13 | 14 | self.rnn = nn.GRU(inputs_dim, outputs_dim, num_layers=self._recurrent_N) 15 | for name, param in self.rnn.named_parameters(): 16 | if 'bias' in name: 17 | nn.init.constant_(param, 0) 18 | elif 'weight' in name: 19 | if self._use_orthogonal: 20 | nn.init.orthogonal_(param) 21 | else: 22 | nn.init.xavier_uniform_(param) 23 | self.norm = nn.LayerNorm(outputs_dim) 24 | 25 | def forward(self, x, hxs, masks): 26 | if x.size(0) == hxs.size(0): 27 | x, hxs = self.rnn(x.unsqueeze(0), (hxs * masks.repeat(1, self._recurrent_N).unsqueeze(-1)).transpose(0, 1).contiguous()) 28 | #x= self.gru(x.unsqueeze(0)) 29 | x = x.squeeze(0) 30 | hxs = hxs.transpose(0, 1) 31 | else: 32 | # x is a (T, N, -1) tensor that has been flatten to (T * N, -1) 33 | N = hxs.size(0) 34 | T = int(x.size(0) / N) 35 | 36 | # unflatten 37 | x = x.view(T, N, x.size(1)) 38 | 39 | # Same deal with masks 40 | masks = masks.view(T, N) 41 | 42 | # Let's figure out which steps in the sequence have a zero for any agent 43 | # We will always assume t=0 has a zero in it as that makes the logic cleaner 44 | has_zeros = ((masks[1:] == 0.0) 45 | .any(dim=-1) 46 | .nonzero() 47 | .squeeze() 48 | .cpu()) 49 | 50 | # +1 to correct the masks[1:] 51 | if has_zeros.dim() == 0: 52 | # Deal with scalar 53 | has_zeros = [has_zeros.item() + 1] 54 | else: 55 | has_zeros = (has_zeros + 1).numpy().tolist() 56 | 57 | # add t=0 and t=T to the list 58 | has_zeros = [0] + has_zeros + [T] 59 | 60 | hxs = hxs.transpose(0, 1) 61 | 62 | outputs = [] 63 | for i in range(len(has_zeros) - 1): 64 | # We can now process 
steps that don't have any zeros in masks together! 65 | # This is much faster 66 | start_idx = has_zeros[i] 67 | end_idx = has_zeros[i + 1] 68 | temp = (hxs * masks[start_idx].view(1, -1, 1).repeat(self._recurrent_N, 1, 1)).contiguous() 69 | rnn_scores, hxs = self.rnn(x[start_idx:end_idx], temp) 70 | outputs.append(rnn_scores) 71 | 72 | # assert len(outputs) == T 73 | # x is a (T, N, -1) tensor 74 | x = torch.cat(outputs, dim=0) 75 | 76 | # flatten 77 | x = x.reshape(T * N, -1) 78 | hxs = hxs.transpose(0, 1) 79 | 80 | x = self.norm(x) 81 | return x, hxs 82 | -------------------------------------------------------------------------------- /onpolicy/algorithms/utils/util.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import copy 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | def init(module, weight_init, bias_init, gain=1): 10 | weight_init(module.weight.data, gain=gain) 11 | bias_init(module.bias.data) 12 | return module 13 | 14 | def get_clones(module, N): 15 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 16 | 17 | def check(input): 18 | output = torch.from_numpy(input) if type(input) == np.ndarray else input 19 | return output 20 | -------------------------------------------------------------------------------- /onpolicy/algorithms/utils/vit.py: -------------------------------------------------------------------------------- 1 | from os import XATTR_SIZE_MAX 2 | import torch 3 | from torch import nn, einsum 4 | 5 | from einops import rearrange, repeat 6 | from einops.layers.torch import Rearrange 7 | 8 | # https://github.com/lucidrains/vit-pytorch 9 | 10 | # helpers 11 | 12 | def pair(t): 13 | return t if isinstance(t, tuple) else (t, t) 14 | 15 | # classes 16 | 17 | class PreNorm(nn.Module): 18 | def __init__(self, dim, fn): 19 | super().__init__() 20 | self.norm = nn.LayerNorm(dim) 21 | self.fn = fn 22 | def forward(self, x, **kwargs): 23 | return self.fn(self.norm(x), **kwargs) 24 | 25 | class FeedForward(nn.Module): 26 | def __init__(self, dim, hidden_dim, dropout = 0.): 27 | super().__init__() 28 | self.net = nn.Sequential( 29 | nn.Linear(dim, hidden_dim), 30 | nn.GELU(), 31 | nn.Dropout(dropout), 32 | nn.Linear(hidden_dim, dim), 33 | nn.Dropout(dropout) 34 | ) 35 | def forward(self, x): 36 | return self.net(x) 37 | 38 | class Attention(nn.Module): 39 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 40 | super().__init__() 41 | inner_dim = dim_head * heads 42 | project_out = not (heads == 1 and dim_head == dim) 43 | 44 | self.heads = heads 45 | self.scale = dim_head ** -0.5 46 | 47 | self.attend = nn.Softmax(dim = -1) 48 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 49 | 50 | self.to_out = nn.Sequential( 51 | nn.Linear(inner_dim, dim), 52 | nn.Dropout(dropout) 53 | ) if project_out else nn.Identity() 54 | 55 | def forward(self, x): 56 | b, n, _, h = *x.shape, self.heads 57 | qkv = self.to_qkv(x).chunk(3, dim = -1) 58 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv) 59 | 60 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 61 | 62 | attn = self.attend(dots) 63 | 64 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 65 | out = rearrange(out, 'b h n d -> b n (h d)') 66 | return self.to_out(out) 67 | 68 | class CrossAttention(nn.Module): 69 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 70 | super().__init__() 71 | 72 | inner_dim = dim_head * heads 73 | project_out = not 
(heads == 1 and dim_head == dim) 74 | 75 | self.heads = heads 76 | self.scale = dim_head ** -0.5 77 | 78 | self.attend = nn.Softmax(dim = -1) 79 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 80 | 81 | self.to_out = nn.Sequential( 82 | nn.Linear(inner_dim, dim), 83 | nn.Dropout(dropout) 84 | ) if project_out else nn.Identity() 85 | 86 | def forward(self, x, k, v): 87 | b, n, _, h = *x.shape, self.heads 88 | q = self.to_q(x) 89 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), [q, k, v]) 90 | dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale 91 | attn = self.attend(dots) 92 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 93 | out = rearrange(out, 'b h n d -> b n (h d)') 94 | return self.to_out(out) 95 | 96 | class Transformer(nn.Module): 97 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 98 | super().__init__() 99 | self.layers = nn.ModuleList([]) 100 | for _ in range(depth): 101 | self.layers.append(nn.ModuleList([ 102 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 103 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 104 | ])) 105 | def forward(self, x): 106 | for attn, ff in self.layers: 107 | x = attn(x) + x 108 | x = ff(x) + x 109 | return x 110 | 111 | class ViT(nn.Module): 112 | def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.): 113 | super().__init__() 114 | image_height, image_width = pair(image_size) 115 | patch_height, patch_width = pair(patch_size) 116 | self.r, self.c = image_height // patch_height, image_width // patch_width 117 | 118 | assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' 
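# (added note) The lines below build the standard ViT patch embedding: the image is split into an
# (image_height / patch_height) x (image_width / patch_width) grid, each patch is flattened to a
# vector of length channels * patch_height * patch_width, and a linear layer projects it to `dim`;
# a cls token is then prepended and a learnable positional embedding is added before the Transformer.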
119 | 120 | num_patches = (image_height // patch_height) * (image_width // patch_width) 121 | patch_dim = channels * patch_height * patch_width 122 | assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' 123 | 124 | self.to_patch_embedding = nn.Sequential( 125 | Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width), 126 | nn.Linear(patch_dim, dim), 127 | ) 128 | 129 | self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) 130 | self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) 131 | self.dropout = nn.Dropout(emb_dropout) 132 | 133 | self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) 134 | 135 | self.pool = pool 136 | self.to_latent = nn.Identity() 137 | 138 | self.mlp_head = nn.Sequential( 139 | nn.LayerNorm(dim), 140 | nn.Linear(dim, num_classes) 141 | ) 142 | 143 | self.recover = Rearrange('b (h w) c -> b c h w', h = self.r, w = self.c) 144 | 145 | def forward(self, img): 146 | x = self.to_patch_embedding(img) 147 | b, n, _ = x.shape 148 | 149 | cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) 150 | x = torch.cat((cls_tokens, x), dim=1) 151 | x += self.pos_embedding[:, :(n + 1)] 152 | x = self.dropout(x) 153 | 154 | x = self.transformer(x) 155 | 156 | # x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0] 157 | 158 | # x = self.to_latent(x) 159 | x = self.recover(x[:, 1:]) 160 | 161 | return x -------------------------------------------------------------------------------- /onpolicy/docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = onpolicy 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /onpolicy/docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | set SPHINXPROJ=onpolicy 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 
24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /onpolicy/docs/source/_templates/module.rst_t: -------------------------------------------------------------------------------- 1 | {%- if show_headings %} 2 | {{- [basename, "module"] | join(' ') | e | heading }} 3 | 4 | {% endif -%} 5 | .. automodule:: {{ qualname }} 6 | :members: 7 | :undoc-members: -------------------------------------------------------------------------------- /onpolicy/docs/source/_templates/package.rst_t: -------------------------------------------------------------------------------- 1 | {%- macro automodule(modname, options) -%} 2 | .. automodule:: {{ modname }} 3 | {%- for option in options %} 4 | :{{ option }}: 5 | {%- endfor %} 6 | {%- endmacro %} 7 | 8 | {%- macro toctree(docnames) -%} 9 | .. toctree:: 10 | :maxdepth: {{ maxdepth }} 11 | {% for docname in docnames %} 12 | {{ docname }} 13 | {%- endfor %} 14 | {%- endmacro %} 15 | 16 | {%- if is_namespace %} 17 | {{- [pkgname, "namespace"] | join(" ") | e | heading }} 18 | {% else %} 19 | {{- [pkgname, "package"] | join(" ") | e | heading }} 20 | {% endif %} 21 | 22 | {%- if modulefirst and not is_namespace %} 23 | {{ automodule(pkgname, automodule_options) }} 24 | {% endif %} 25 | 26 | {%- if subpackages %} 27 | {{ toctree(subpackages) }} 28 | {% endif %} 29 | 30 | {%- if submodules %} 31 | {% if separatemodules %} 32 | {{ toctree(submodules) }} 33 | {%- else %} 34 | {%- for submodule in submodules %} 35 | {% if show_headings %} 36 | {{- [submodule, "module"] | join(" ") | e | heading(2) }} 37 | {% endif %} 38 | {{ automodule(submodule, automodule_options) }} 39 | {% endfor %} 40 | {%- endif %} 41 | {% endif %} 42 | 43 | {%- if not modulefirst and not is_namespace %} 44 | Module contents 45 | --------------- 46 | 47 | {{ automodule(pkgname, automodule_options) }} 48 | {% endif %} -------------------------------------------------------------------------------- /onpolicy/docs/source/_templates/toc.rst_t: -------------------------------------------------------------------------------- 1 | {{ header | heading }} 2 | 3 | .. toctree:: 4 | :maxdepth: {{ maxdepth }} 5 | {% for docname in docnames %} 6 | {{ docname }} 7 | {%- endfor %} -------------------------------------------------------------------------------- /onpolicy/docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
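# (added note) The sys.path.insert call below makes the project sources importable, so that
# sphinx.ext.autodoc and sphinxcontrib.apidoc can locate the modules when building the API docs
# (apidoc_module_dir later points to the same directory).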
14 | # 15 | import os 16 | import sys 17 | sys.path.insert(0, os.path.abspath("../..")) 18 | 19 | import sphinx_rtd_theme 20 | import sphinxcontrib 21 | from recommonmark.parser import CommonMarkParser 22 | 23 | 24 | # -- Project information ----------------------------------------------------- 25 | 26 | project = 'onpolicy' 27 | copyright = '2020, Chao Yu' 28 | author = 'Chao Yu' 29 | 30 | # The short X.Y version 31 | version = '' 32 | # The full version, including alpha/beta/rc tags 33 | release = '0.1.0' 34 | 35 | 36 | # -- General configuration --------------------------------------------------- 37 | 38 | # If your documentation needs a minimal Sphinx version, state it here. 39 | # 40 | # needs_sphinx = '1.0' 41 | 42 | # Add any Sphinx extension module names here, as strings. They can be 43 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 44 | # ones. 45 | extensions = [ 46 | "sphinx_rtd_theme", # Read The Docs theme 47 | 'sphinx.ext.autodoc', # Automatically extract docs from docstrings 48 | 'sphinx.ext.doctest', 49 | 'sphinx.ext.intersphinx', 50 | 'sphinx.ext.todo', 51 | 'sphinx.ext.coverage', # make coverage generates documentation coverage reports 52 | 'sphinx.ext.mathjax', 53 | 'sphinx.ext.ifconfig', 54 | 'sphinx.ext.viewcode', # link to sourcecode from docs 55 | 'sphinx.ext.githubpages', 56 | "sphinx.ext.napoleon", # support Numpy and Google doc style 57 | "sphinxcontrib.apidoc", 58 | ] 59 | 60 | # configuring automated generation of api documentation 61 | # See: https://github.com/sphinx-contrib/apidoc 62 | 63 | apidoc_module_dir = "../.." 64 | apidoc_excluded_paths = ["envs", "setup.py"] 65 | apidoc_module_first = True 66 | apidoc_extra_args = [ 67 | "--force", 68 | "--separate", 69 | "--ext-viewcode", 70 | "--doc-project=onpolicy", 71 | "--maxdepth=2", 72 | "--templatedir=_templates/apidoc", 73 | ] 74 | 75 | 76 | # Add any paths that contain templates here, relative to this directory. 77 | templates_path = ['_templates'] 78 | 79 | # The suffix(es) of source filenames. 80 | # You can specify multiple suffix as a list of string: 81 | # 82 | # source_suffix = ['.rst', '.md'] 83 | source_suffix = ['.rst', '.md', '.MD'] 84 | 85 | 86 | # The master toctree document. 87 | master_doc = 'index' 88 | 89 | # The language for content autogenerated by Sphinx. Refer to documentation 90 | # for a list of supported languages. 91 | # 92 | # This is also used if you do content translation via gettext catalogs. 93 | # Usually you set "language" from the command line for these cases. 94 | language = None 95 | 96 | # List of patterns, relative to source directory, that match files and 97 | # directories to ignore when looking for source files. 98 | # This pattern also affects html_static_path and html_extra_path . 99 | exclude_patterns = [] 100 | 101 | # The name of the Pygments (syntax highlighting) style to use. 102 | pygments_style = 'sphinx' 103 | 104 | 105 | # -- Options for HTML output ------------------------------------------------- 106 | 107 | # The theme to use for HTML and HTML Help pages. See the documentation for 108 | # a list of builtin themes. 109 | # 110 | html_theme = 'sphinx_rtd_theme' 111 | 112 | # Theme options are theme-specific and customize the look and feel of a theme 113 | # further. For a list of options available for each theme, see the 114 | # documentation. 115 | # 116 | # html_theme_options = {} 117 | 118 | # Add any paths that contain custom static files (such as style sheets) here, 119 | # relative to this directory. 
They are copied after the builtin static files, 120 | # so a file named "default.css" will overwrite the builtin "default.css". 121 | html_static_path = ['_static'] 122 | 123 | # Custom sidebar templates, must be a dictionary that maps document names 124 | # to template names. 125 | # 126 | # The default sidebars (for documents that don't match any pattern) are 127 | # defined by theme itself. Builtin themes are using these templates by 128 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 129 | # 'searchbox.html']``. 130 | # 131 | # html_sidebars = {} 132 | 133 | 134 | # -- Options for HTMLHelp output --------------------------------------------- 135 | 136 | # Output file base name for HTML help builder. 137 | htmlhelp_basename = 'onpolicydoc' 138 | 139 | 140 | # -- Options for LaTeX output ------------------------------------------------ 141 | 142 | latex_elements = { 143 | # The paper size ('letterpaper' or 'a4paper'). 144 | # 145 | # 'papersize': 'letterpaper', 146 | 147 | # The font size ('10pt', '11pt' or '12pt'). 148 | # 149 | # 'pointsize': '10pt', 150 | 151 | # Additional stuff for the LaTeX preamble. 152 | # 153 | # 'preamble': '', 154 | 155 | # Latex figure (float) alignment 156 | # 157 | # 'figure_align': 'htbp', 158 | } 159 | 160 | # Grouping the document tree into LaTeX files. List of tuples 161 | # (source start file, target name, title, 162 | # author, documentclass [howto, manual, or own class]). 163 | latex_documents = [ 164 | (master_doc, 'onpolicy.tex', 'onpolicy Documentation', 165 | 'Chao Yu', 'manual'), 166 | ] 167 | 168 | 169 | # -- Options for manual page output ------------------------------------------ 170 | 171 | # One entry per manual page. List of tuples 172 | # (source start file, name, description, authors, manual section). 173 | man_pages = [ 174 | (master_doc, 'onpolicy', 'onpolicy Documentation', 175 | [author], 1) 176 | ] 177 | 178 | 179 | # -- Options for Texinfo output ---------------------------------------------- 180 | 181 | # Grouping the document tree into Texinfo files. List of tuples 182 | # (source start file, target name, title, author, 183 | # dir menu entry, description, category) 184 | texinfo_documents = [ 185 | (master_doc, 'onpolicy', 'onpolicy Documentation', 186 | author, 'onpolicy', 'One line description of project.', 187 | 'Miscellaneous'), 188 | ] 189 | 190 | 191 | # -- Extension configuration ------------------------------------------------- 192 | 193 | # -- Options for intersphinx extension --------------------------------------- 194 | 195 | # Example configuration for intersphinx: refer to the Python standard library. 196 | intersphinx_mapping = {'https://docs.python.org/': None} 197 | 198 | # -- Options for todo extension ---------------------------------------------- 199 | 200 | # If true, `todo` and `todoList` produce output, else they produce nothing. 201 | todo_include_todos = True 202 | 203 | source_parsers = { 204 | '.md': CommonMarkParser, 205 | '.MD': CommonMarkParser, 206 | } -------------------------------------------------------------------------------- /onpolicy/docs/source/full.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/docs/source/full.gif -------------------------------------------------------------------------------- /onpolicy/docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. 
onpolicy documentation master file, created by 2 | sphinx-quickstart on Sun Dec 6 23:54:05 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to onpolicy's documentation! 7 | ==================================== 8 | 9 | .. image:: full.gif 10 | 11 | .. include:: setup.rst 12 | 13 | .. include:: quickstart.rst 14 | 15 | .. toctree:: 16 | :maxdepth: 2 17 | :caption: API Reference: 18 | 19 | api/modules.rst 20 | 21 | 22 | 23 | Indices and tables 24 | ================== 25 | 26 | * :ref:`genindex` 27 | * :ref:`modindex` 28 | * :ref:`search` 29 | -------------------------------------------------------------------------------- /onpolicy/docs/source/quickstart.rst: -------------------------------------------------------------------------------- 1 | Quickstart: 2 | ========================================= 3 | 4 | 5 | Training the Highway Environment with MAPPO Algorithm 6 | -------------------- 7 | 8 | One line is enough to start training the agent in the Highway Environment with the MAPPO algorithm: 9 | 1. change the working directory to `scripts/train/` 10 | 2. run the following command 11 | 12 | .. code-block:: bash 13 | 14 | ./train_highway.sh 15 | 16 | 17 | 18 | Hyperparameters 19 | -------------------- 20 | 21 | Within the `train_highway.sh` file, `train_highway.py` under `scripts/train/` is called with part of the hyperparameters specified. 22 | 23 | .. code-block:: bash 24 | 25 | CUDA_VISIBLE_DEVICES=0 python train/train_highway.py --env_name ${env} --algorithm_name ${algo} --experiment_name ${exp} --scenario_name ${scenario} --task_type ${task} --n_attackers ${n_attackers} --n_defenders ${n_defenders} --n_dummies ${n_dummies} --seed ${seed} --n_training_threads 2 --n_rollout_threads 2 --n_eval_rollout_threads 2 --horizon 40 --episode_length 40 --log_interval 1 --use_wandb 26 | 27 | 28 | Hyperparameters come in two types: 29 | 30 | - common hyperparameters used in all environments. Such parameters are parsed in ./config.py 31 | - private hyperparameters used by a specific environment itself, which are parsed by the corresponding ./scripts/train/train_<env>.py 32 | 33 | Take the highway env as an example: 34 | - the common hyperparameters are as follows: 35 | 36 | .. automodule:: config 37 | :members: 38 | 39 | - the private hyperparameters are as follows: 40 | 41 | Take the highway environment as an example: 42 | 43 | .. automodule:: scripts.train.train_highway.make_train_env 44 | :members: 45 | 46 | Hierarchical structure 47 | -------------------- -------------------------------------------------------------------------------- /onpolicy/docs/source/setup.rst: -------------------------------------------------------------------------------- 1 | Setup: 2 | ========================================= 3 | .. code-block:: bash 4 | 5 | python3.7 -m venv ~/venv/python37_onpolicy 6 | source ~/venv/python37_onpolicy/bin/activate 7 | python3 -m pip install --upgrade pip 8 | pip install setuptools==51.0.0 9 | pip install wheel 10 | pip install torch==1.5.1+cu101 torchvision==0.6.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html 11 | pip install wandb setproctitle absl-py gym matplotlib pandas pygame imageio tensorboardX numba 12 | # change to the root directory of the onpolicy library 13 | pip install -e .
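 # optional sanity check (illustrative, not part of the original instructions): the editable install should make the package importable
 python3 -c "import onpolicy; print(onpolicy.__file__)"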
14 | 15 | -------------------------------------------------------------------------------- /onpolicy/envs/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | import socket 3 | from absl import flags 4 | FLAGS = flags.FLAGS 5 | FLAGS(['train_sc.py']) 6 | 7 | 8 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/.travis.yml: -------------------------------------------------------------------------------- 1 | language: python 2 | python: 3 | - "3.5" 4 | 5 | # command to install dependencies 6 | install: 7 | - pip3 install -e . 8 | 9 | # command to run tests 10 | script: ./run_tests.py 11 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/GridWorld_Env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | from .gym_minigrid.envs.human import HumanEnv 3 | from onpolicy.envs.gridworld.gym_minigrid.register import register 4 | import numpy as np 5 | from icecream import ic 6 | from onpolicy.utils.multi_discrete import MultiDiscrete 7 | 8 | class GridWorldEnv(object): 9 | def __init__(self, args): 10 | 11 | self.num_agents = args.num_agents 12 | self.scenario_name = args.scenario_name 13 | self.use_random_pos = args.use_random_pos 14 | self.agent_pos = None if self.use_random_pos else args.agent_pos 15 | self.num_obstacles = args.num_obstacles 16 | self.use_single_reward = args.use_single_reward 17 | self.use_discrect = args.use_discrect 18 | 19 | register( 20 | id=self.scenario_name, 21 | grid_size=args.grid_size, 22 | max_steps=args.max_steps, 23 | local_step_num=args.local_step_num, 24 | agent_view_size=args.agent_view_size, 25 | num_agents=self.num_agents, 26 | num_obstacles=self.num_obstacles, 27 | agent_pos=self.agent_pos, 28 | use_merge_plan=args.use_merge_plan, 29 | use_merge=args.use_merge, 30 | use_constrict_map=args.use_constrict_map, 31 | use_fc_net=args.use_fc_net, 32 | use_agent_id=args.use_agent_id, 33 | use_stack=args.use_stack, 34 | use_orientation=args.use_orientation, 35 | use_same_location=args.use_same_location, 36 | use_complete_reward=args.use_complete_reward, 37 | use_agent_obstacle=args.use_agent_obstacle, 38 | use_multiroom=args.use_multiroom, 39 | use_irregular_room=args.use_irregular_room, 40 | use_time_penalty=args.use_time_penalty, 41 | use_overlap_penalty=args.use_overlap_penalty, 42 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:MultiExplorationEnv', 43 | astar_cost_mode=args.astar_cost_mode 44 | ) 45 | 46 | self.env = gym.make(self.scenario_name) 47 | self.max_steps = self.env.max_steps 48 | # print("max step is {}".format(self.max_steps)) 49 | 50 | self.observation_space = self.env.observation_space 51 | self.share_observation_space = self.env.observation_space 52 | 53 | if self.use_discrect: 54 | self.action_space = [ 55 | MultiDiscrete([[0, args.grid_size - 1],[0, args.grid_size - 1]]) 56 | for _ in range(self.num_agents) 57 | ] 58 | else: 59 | self.action_space = [ 60 | gym.spaces.Box(low=0.0, high=1.0, shape=(2,), dtype=np.float32) 61 | for _ in range(self.num_agents) 62 | ] 63 | 64 | def seed(self, seed=None): 65 | if seed is None: 66 | self.env.seed(1) 67 | else: 68 | self.env.seed(seed) 69 | 70 | def reset(self, choose=True): 71 | if choose: 72 | obs, info = self.env.reset() 73 | else: 74 | obs = [ 75 | { 76 | 'image': np.zeros((self.env.width, self.env.height, 3), dtype='uint8'), 77 | 'direction': 0, 78 | 'mission': " " 79 | } for 
agent_id in range(self.num_agents) 80 | ] 81 | info = {} 82 | return obs, info 83 | 84 | def step(self, actions): 85 | if not np.all(actions == np.ones((self.num_agents, 1)).astype(np.int) * (-1.0)): 86 | obs, rewards, done, infos = self.env.step(actions) 87 | dones = np.array([done for agent_id in range(self.num_agents)]) 88 | if self.use_single_reward: 89 | rewards = 0.3 * np.expand_dims(infos['agent_explored_reward'], axis=1) + 0.7 * np.expand_dims( 90 | np.array([infos['merge_explored_reward'] for _ in range(self.num_agents)]), axis=1) 91 | else: 92 | rewards = np.expand_dims( 93 | np.array([infos['merge_explored_reward'] for _ in range(self.num_agents)]), axis=1) 94 | else: 95 | obs = [ 96 | { 97 | 'image': np.zeros((self.env.width, self.env.height, 3), dtype='uint8'), 98 | 'direction': 0, 99 | 'mission': " " 100 | } for agent_id in range(self.num_agents) 101 | ] 102 | rewards = np.zeros((self.num_agents, 1)) 103 | dones = np.array([None for agent_id in range(self.num_agents)]) 104 | infos = {} 105 | 106 | return obs, rewards, dones, infos 107 | 108 | def close(self): 109 | self.env.close() 110 | 111 | def get_short_term_action(self, input): 112 | outputs = self.env.get_short_term_action(input) 113 | return outputs 114 | 115 | def render(self, mode="human", short_goal_pos=None): 116 | if mode == "human": 117 | self.env.render(mode=mode, short_goal_pos=short_goal_pos) 118 | else: 119 | return self.env.render(mode=mode, short_goal_pos=short_goal_pos) 120 | 121 | def ft_get_short_term_goals(self, args, mode=""): 122 | mode_list = ['apf', 'utility', 'nearest', 'rrt', 'voronoi'] 123 | assert mode in mode_list, (f"frontier global mode should be in {mode_list}") 124 | return self.env.ft_get_short_term_goals(args, mode=mode) 125 | 126 | def ft_get_short_term_actions(self, *args): 127 | return self.env.ft_get_short_term_actions(*args) 128 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/benchmark.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import time 4 | import argparse 5 | import gym_minigrid 6 | import gym 7 | from gym_minigrid.wrappers import * 8 | 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument( 11 | "--env-name", 12 | dest="env_name", 13 | help="gym environment to load", 14 | default='MiniGrid-LavaGapS7-v0' 15 | ) 16 | parser.add_argument("--num_resets", default=200) 17 | parser.add_argument("--num_frames", default=5000) 18 | args = parser.parse_args() 19 | 20 | env = gym.make(args.env_name) 21 | 22 | # Benchmark env.reset 23 | t0 = time.time() 24 | for i in range(args.num_resets): 25 | env.reset() 26 | t1 = time.time() 27 | dt = t1 - t0 28 | reset_time = (1000 * dt) / args.num_resets 29 | 30 | # Benchmark rendering 31 | t0 = time.time() 32 | for i in range(args.num_frames): 33 | env.render('rgb_array') 34 | t1 = time.time() 35 | dt = t1 - t0 36 | frames_per_sec = args.num_frames / dt 37 | 38 | # Create an environment with an RGB agent observation 39 | env = gym.make(args.env_name) 40 | env = RGBImgPartialObsWrapper(env) 41 | env = ImgObsWrapper(env) 42 | 43 | # Benchmark rendering 44 | t0 = time.time() 45 | for i in range(args.num_frames): 46 | obs, reward, done, info = env.step(0) 47 | t1 = time.time() 48 | dt = t1 - t0 49 | agent_view_fps = args.num_frames / dt 50 | 51 | print('Env reset time: {:.1f} ms'.format(reset_time)) 52 | print('Rendering FPS : {:.0f}'.format(frames_per_sec)) 53 | print('Agent view FPS: 
{:.0f}'.format(agent_view_fps)) 54 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/BlockedUnlockPickup.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/BlockedUnlockPickup.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/DistShift1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/DistShift1.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/DistShift2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/DistShift2.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/KeyCorridorS3R1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/KeyCorridorS3R1.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/KeyCorridorS3R2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/KeyCorridorS3R2.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/KeyCorridorS3R3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/KeyCorridorS3R3.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/KeyCorridorS4R3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/KeyCorridorS4R3.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/KeyCorridorS5R3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/KeyCorridorS5R3.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/KeyCorridorS6R3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/KeyCorridorS6R3.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/LavaCrossingS11N5.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/LavaCrossingS11N5.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/LavaCrossingS9N1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/LavaCrossingS9N1.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/LavaCrossingS9N2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/LavaCrossingS9N2.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/LavaCrossingS9N3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/LavaCrossingS9N3.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/LavaGapS6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/LavaGapS6.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/ObstructedMaze-1Dl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/ObstructedMaze-1Dl.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/ObstructedMaze-1Dlh.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/ObstructedMaze-1Dlh.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/ObstructedMaze-1Dlhb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/ObstructedMaze-1Dlhb.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/ObstructedMaze-1Q.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/ObstructedMaze-1Q.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/ObstructedMaze-2Dl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/ObstructedMaze-2Dl.png -------------------------------------------------------------------------------- 
/onpolicy/envs/gridworld/figures/ObstructedMaze-2Dlh.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/ObstructedMaze-2Dlh.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/ObstructedMaze-2Dlhb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/ObstructedMaze-2Dlhb.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/ObstructedMaze-2Q.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/ObstructedMaze-2Q.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/ObstructedMaze-4Q.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/ObstructedMaze-4Q.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/SimpleCrossingS11N5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/SimpleCrossingS11N5.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/SimpleCrossingS9N1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/SimpleCrossingS9N1.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/SimpleCrossingS9N2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/SimpleCrossingS9N2.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/SimpleCrossingS9N3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/SimpleCrossingS9N3.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/Unlock.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/Unlock.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/UnlockPickup.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/UnlockPickup.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/door-key-curriculum.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/door-key-curriculum.gif -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/door-key-env.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/door-key-env.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/dynamic_obstacles.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/dynamic_obstacles.gif -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/empty-env.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/empty-env.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/fetch-env.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/fetch-env.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/four-rooms-env.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/four-rooms-env.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/gotodoor-6x6.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/gotodoor-6x6.mp4 -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/gotodoor-6x6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/gotodoor-6x6.png -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/figures/multi-room.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/envs/gridworld/figures/multi-room.gif -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/frontier/apf.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | from queue import deque 3 | import pyastar2d 4 | 5 | # class of APF(Artificial Potential Field) 6 | 7 | 8 | class APF(object): 9 | def __init__(self, args): 10 | self.args = args 11 | 12 | self.cluster_radius = args.apf_cluster_radius 13 | self.k_attract = args.apf_k_attract 14 | self.k_agents = args.apf_k_agents 15 | self.AGENT_INFERENCE_RADIUS = args.apf_AGENT_INFERENCE_RADIUS 16 | self.num_iters = args.apf_num_iters 17 | self.repeat_penalty = args.apf_repeat_penalty 18 | self.dis_type = args.apf_dis_type 19 | 20 | self.num_agents = args.num_agents 21 | 22 | def distance(self, a, b): 23 | a = np.array(a) 24 | b = np.array(b) 25 | if self.dis_type == "l2": 26 | return np.sqrt(((a-b)**2).sum()) 27 | elif self.dis_type == "l1": 28 | return abs(a-b).sum() 29 | 30 | def schedule(self, map, locations, steps, agent_id, penalty, full_path=True): 31 | ''' 32 | APF to schedule path for agent agent_id 33 | map: H x W 34 | - 0 for explored & available cell 35 | - 1 for obstacle 36 | - 2 for target (frontier) 37 | locations: num_agents x 2 38 | steps: available actions 39 | penalty: repeat penalty 40 | full_path: default True, False for single step (i.e., next cell) 41 | ''' 42 | H, W = map.shape 43 | 44 | # find available targets 45 | vis = np.zeros((H, W), dtype=np.uint8) 46 | que = deque([]) 47 | x, y = locations[agent_id] 48 | vis[x, y] = 1 49 | que.append((x, y)) 50 | while len(que) > 0: 51 | x, y = que.popleft() 52 | for dx, dy in steps: 53 | x1 = x + dx 54 | y1 = y + dy 55 | if vis[x1, y1] == 0 and map[x1, y1] in [0, 2]: 56 | vis[x1, y1] = 1 57 | que.append((x1, y1)) 58 | 59 | targets = [] 60 | for i in range(H): 61 | for j in range(W): 62 | if map[i, j] == 2 and vis[i, j] == 1: 63 | targets.append((i, j)) 64 | # clustering 65 | clusters = [] 66 | num_targets = len(targets) 67 | valid = [True for _ in range(num_targets)] 68 | for i in range(num_targets): 69 | if valid[i]: 70 | # not clustered 71 | chosen_targets = [] 72 | for j in range(num_targets): 73 | if valid[j] and self.distance(targets[i], targets[j]) <= self.cluster_radius: 74 | valid[j] = False 75 | chosen_targets.append(targets[j]) 76 | min_r = 1e6 77 | center = None 78 | for a in chosen_targets: 79 | max_d = max([self.distance(a, b) for b in chosen_targets]) 80 | if max_d < min_r: 81 | min_r = max_d 82 | center = a 83 | clusters.append({"center": center, "weight": len(chosen_targets)}) 84 | 85 | # potential 86 | num_clusters = len(clusters) 87 | potential = np.zeros((H, W)) 88 | potential[map == 1] = 1e6 89 | 90 | # potential of targets & obstacles (wave-front dist) 91 | for cluster in clusters: 92 | sx, sy = cluster["center"] 93 | w = cluster["weight"] 94 | dis = np.ones((H, W), dtype=np.int64) * 1e6 95 | dis[sx, sy] = 0 96 | que = deque([(sx, sy)]) 97 | while len(que) > 0: 98 | (x, y) = que.popleft() 99 | for dx, dy in steps: 100 | x1 = x + dx 101 | y1 = y + dy 102 | if dis[x1, y1] == 1e6 and map[x1, y1] in [0, 2]: 103 | dis[x1, y1] = dis[x, y]+1 104 | que.append((x1, y1)) 105 | dis[sx, sy] = 1e6 106 | dis = 1 / dis 107 | dis[sx, sy] = 0 108 | potential[map != 1] -= dis[map != 1] * self.k_attract * w 109 | 110 | # potential of agents 111 | for x in range(H): 112 | for y in range(W): 113 | for agent_loc in locations: 114 | d = self.distance(agent_loc, (x, y)) 115 | if d <= self.AGENT_INFERENCE_RADIUS: 116 | potential[x, y] += self.k_agents * (self.AGENT_INFERENCE_RADIUS - d) 117 | 118 | # repeat penalty 119 | potential += 
penalty 120 | 121 | # schedule path 122 | it = 1 123 | current_loc = locations[agent_id] 124 | current_potential = 1e4 125 | minDis2Target = 1e6 126 | path = [(current_loc[0], current_loc[1])] 127 | while it <= self.num_iters and minDis2Target > 1: 128 | it = it + 1 129 | potential[current_loc[0], current_loc[1]] += self.repeat_penalty 130 | best_neigh = None 131 | min_potential = 1e6 132 | for dx, dy in steps: 133 | neighbor_loc = (current_loc[0] + dx, current_loc[1] + dy) 134 | if map[neighbor_loc[0], neighbor_loc[1]] == 1: 135 | continue 136 | if min_potential > potential[neighbor_loc[0], neighbor_loc[1]]: 137 | min_potential = potential[neighbor_loc[0], neighbor_loc[1]] 138 | best_neigh = neighbor_loc 139 | if current_potential > min_potential: 140 | current_potential = min_potential 141 | current_loc = best_neigh 142 | path.append(best_neigh) 143 | for tar in targets: 144 | l = self.distance(current_loc, tar) 145 | if l == 0: 146 | continue 147 | minDis2Target = min(minDis2Target, l) 148 | if l <= 1: 149 | path.append((tar[0], tar[1])) 150 | break 151 | if not full_path and len(path) > 1: 152 | return path[1] # next grid 153 | random_plan = False 154 | if minDis2Target > 1: 155 | random_plan = True 156 | for i in range(agent_id): 157 | if locations[i][0] == locations[agent_id][0] and locations[i][1] == locations[agent_id][1]: 158 | random_plan = True # two agents are at the same location, replan 159 | if random_plan: 160 | # if not reaching a frontier, randomly pick a target as goal 161 | if num_targets == 0: 162 | targets = [(np.random.randint(0, H), np.random.randint(0, W))] 163 | num_targets = 1 164 | w = np.random.randint(0, num_targets) 165 | goal = targets[w] 166 | temp_map = np.ones((H, W), dtype=np.float32) 167 | temp_map[map == 1] = 1000000 # obstacles get a prohibitive A* cost 168 | temp_map[map == 2] = 1.0 # frontier cells keep unit cost 169 | path = pyastar2d.astar_path(temp_map, locations[agent_id], goal, allow_diagonal=False) 170 | if len(path) == 1: 171 | path = (locations[agent_id], goal) 172 | return path 173 | return path 174 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/frontier/nearest.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from .utils import * 3 | import random 4 | 5 | 6 | def nearest_goal(map, loc, steps): 7 | dis, vis = bfs(map, loc, steps) 8 | H, W = map.shape 9 | frontiers = [] 10 | for x in range(H): 11 | for y in range(W): 12 | if map[x, y] == 2 and vis[x, y]: 13 | frontiers.append((x, y)) 14 | if len(frontiers) == 0: 15 | goal = random.randint(0, H-1), random.randint(0, W-1) 16 | return goal 17 | dist = [dis[x, y] for x, y in frontiers] 18 | mi = min(dist) 19 | candidates = [(x, y) for i, (x, y) in enumerate(frontiers) if dist[i] == mi] 20 | goal = random.choice(candidates) 21 | return goal 22 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/frontier/rrt.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.frontier.utils import generate_square 2 | from onpolicy.envs.gridworld.frontier.utils import find_rectangle_obstacles 3 | from onpolicy.utils.RRT.rrt import RRT 4 | import numpy as np 5 | from .utils import * 6 | import random 7 | 8 | 9 | class rrt_config(object): 10 | cluster_radius = 0 11 | rrt_max_iter = 10000 12 | rrt_num_targets = 500 13 | utility_edge_len = 7 14 | 15 | 16 | def rrt_goal(map, unexplored, loc): 17 | H, W = map.shape 18 | map =
map.astype(np.int32) 19 | 20 | obstacles = find_rectangle_obstacles(map) 21 | 22 | rrt = RRT(start=(loc[0] + 0.5, loc[1] + 0.5), 23 | goals=[], 24 | rand_area=((0, H), (0, W)), 25 | obstacle_list=obstacles, 26 | expand_dis=0.5, 27 | goal_sample_rate=-1, 28 | max_iter=rrt_config.rrt_max_iter) # maybe more iterations? 29 | 30 | rrt_map = unexplored.copy().astype(np.int32) 31 | targets = rrt.select_frontiers(rrt_map, num_targets=rrt_config.rrt_num_targets) 32 | # print("targets: ", len(targets)) 33 | 34 | clusters = get_frontier_cluster(targets, cluster_radius=rrt_config.cluster_radius) 35 | # print("clusters: ",len(clusters)) 36 | 37 | if len(clusters) == 0: 38 | x, y = random.randint(0, H-1), random.randint(0, W-1) 39 | while map[x, y] == 1: 40 | x, y = random.randint(0, H-1), random.randint(0, W-1) 41 | goal = (x, y) 42 | return goal 43 | for cluster in clusters: 44 | center = cluster['center'] 45 | # navigation cost 46 | nav_cost = l1distance(center, loc) 47 | # information gain 48 | mat = generate_square(H, W, center, rrt_config.utility_edge_len) 49 | area = mat.sum() 50 | info_gain = rrt_map[mat == 1].sum() 51 | info_gain /= area 52 | cluster['info_gain'] = info_gain 53 | cluster['nav_cost'] = nav_cost 54 | D = max([cluster['nav_cost'] for cluster in clusters]) 55 | goal = None 56 | mx = -1e9 57 | for cluster in clusters: 58 | cluster['nav_cost'] /= D 59 | cluster['utility'] = cluster['info_gain'] - 1.0 * cluster['nav_cost'] 60 | if mx < cluster['utility']: 61 | mx = cluster['utility'] 62 | goal = cluster['center'] 63 | 64 | return goal 65 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/frontier/utility.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.frontier.utils import generate_square 2 | import numpy as np 3 | from .utils import * 4 | import random 5 | 6 | 7 | def utility_goal(map, unexplored, loc, steps, edge_len=7): 8 | _, vis = bfs(map, loc, steps) 9 | H, W = map.shape 10 | utility = np.zeros((H, W), dtype=np.int32) 11 | frontiers = [] 12 | for x in range(H): 13 | for y in range(W): 14 | if map[x, y] == 2 and vis[x, y]: 15 | mat = generate_square(H, W, (x, y), edge_len) 16 | utility[x, y] = unexplored[mat == 1].sum() 17 | frontiers.append((x, y)) 18 | mx = utility.max() 19 | value = [utility[x, y] for x, y in frontiers] 20 | candidates = [(x, y) for i, (x, y) in enumerate(frontiers) if value[i] == mx] 21 | if len(candidates) > 0: 22 | goal = random.choice(candidates) 23 | else: 24 | goal = random.randint(0, H-1), random.randint(0, W-1) 25 | return goal 26 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/frontier/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from queue import deque 3 | import math 4 | 5 | 6 | def l1distance(x, y): 7 | return abs(x[0] - y[0]) + abs(x[1] - y[1]) 8 | 9 | 10 | def l2distance(x, y): 11 | return math.hypot(x[0]-y[0], x[1]-y[1]) 12 | 13 | 14 | def bfs(map, start, steps): 15 | sx, sy = start[0], start[1] 16 | assert map[sx, sy] != 1, ('start position should not be obstacle') 17 | que = deque([(sx, sy)]) 18 | H, W = map.shape 19 | dis = np.ones((H, W), dtype=np.int32) * 1e9 20 | dis[sx, sy] = 0 21 | while len(que) > 0: 22 | x, y = que.popleft() 23 | neigh = [(x+dx, y+dy) for dx, dy in steps] 24 | neigh = [(x, y) for x, y in neigh if x >= 0 and x < H and y > 25 | 0 and y < W and map[x, y] != 1 and 
dis[x, y] == 1e9] 26 | for u, v in neigh: 27 | dis[u, v] = dis[x, y] + 1 28 | que.append((u, v)) 29 | vis = (dis < 1e9) 30 | return dis, vis 31 | 32 | 33 | def generate_square(H, W, loc, d): 34 | mat = np.zeros((H, W), dtype=np.int32) 35 | for x in range(H): 36 | for y in range(W): 37 | if max(abs(x-loc[0]), abs(y-loc[1])) <= d: 38 | mat[x, y] = 1 39 | return mat 40 | 41 | 42 | def get_frontier_cluster(frontiers, cluster_radius=5.0): 43 | num_frontier = len(frontiers) 44 | clusters = [] 45 | valid = [True for _ in range(num_frontier)] 46 | for i in range(num_frontier): 47 | if valid[i]: 48 | neigh = [] 49 | for j in range(num_frontier): 50 | if valid[j] and l2distance(frontiers[i], frontiers[j]) <= cluster_radius: 51 | valid[j] = False 52 | neigh.append(frontiers[j]) 53 | center = None 54 | min_r = 1e9 55 | for p in neigh: 56 | r = max([l2distance(p, q) for q in neigh]) 57 | if r < min_r: 58 | min_r = r 59 | center = p 60 | if len(neigh) >= 5: 61 | clusters.append({'center': center, 'weight': len(neigh)}) 62 | return clusters 63 | 64 | 65 | def find_rectangle_obstacles(map): 66 | map = map.copy().astype(np.int32) 67 | map[map == 2] = 0 68 | H, W = map.shape 69 | obstacles = [] 70 | covered = np.zeros((H, W), dtype=np.int32) 71 | pad = 0.01 72 | for x in range(H): 73 | for y in range(W): 74 | if map[x, y] == 1 and covered[x, y] == 0: 75 | x1 = x 76 | x2 = x 77 | while x2 < H-1 and map[x2 + 1, y] == 1: 78 | x2 = x2 + 1 79 | y1 = y 80 | y2 = y 81 | while y2 < W-1 and map[x1: x2+1, y2 + 1].sum() == x2-x1+1: 82 | y2 = y2 + 1 83 | covered[x1: x2 + 1, y1: y2 + 1] = 1 84 | obstacles.append((x1-pad, y1-pad, x2 + 1 + pad, y2 + 1 + pad)) 85 | return obstacles 86 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/frontier/voronoi.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.frontier.utils import generate_square 2 | import numpy as np 3 | from .utils import * 4 | import random 5 | 6 | 7 | def voronoi_goal(map, unexplored, locs, agent_id, steps, edge_len=7): 8 | num_agents = len(locs) 9 | 10 | dis = [] 11 | vis = [] 12 | for a, loc in enumerate(locs): 13 | _dis, _vis = bfs(map, loc, steps) 14 | dis.append(_dis) 15 | vis.append(_vis) 16 | 17 | H, W = map.shape 18 | my_grids = np.ones((H, W), dtype=np.int32) 19 | for a in range(num_agents): 20 | my_grids[dis[agent_id] > dis[a]] = 0 21 | 22 | utility = np.zeros((H, W), dtype=np.int32) 23 | frontiers = [] 24 | for x in range(H): 25 | for y in range(W): 26 | if map[x, y] == 2 and vis[agent_id][x, y] and my_grids[x, y]: 27 | mat = generate_square(H, W, (x, y), edge_len) 28 | utility[x, y] = unexplored[mat == 1].sum() 29 | frontiers.append((x, y)) 30 | mx = utility.max() 31 | value = [utility[x, y] for x, y in frontiers] 32 | candidates = [(x, y) for i, (x, y) in enumerate(frontiers) if value[i] == mx] 33 | if len(candidates) > 0: 34 | goal = random.choice(candidates) 35 | else: 36 | goal = random.randint(0, H-1), random.randint(0, W-1) 37 | return goal 38 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/__init__.py: -------------------------------------------------------------------------------- 1 | # Import the envs module so that envs register themselves 2 | import onpolicy.envs.gridworld.gym_minigrid.envs 3 | 4 | # Import wrappers so it's accessible when installing with pip 5 | import onpolicy.envs.gridworld.gym_minigrid.wrappers 6 | 
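# Note: the `envs` import above is what triggers the register() calls inside the
# imported env modules, making their environment ids available to gym.make();
# GridWorld_Env.py follows the same pattern, registering a scenario id and then
# instantiating it with gym.make(scenario_name).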
-------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/__init__.py: -------------------------------------------------------------------------------- 1 | # from onpolicy.envs.gridworld.gym_minigrid.envs.empty import * 2 | # from onpolicy.envs.gridworld.gym_minigrid.envs.doorkey import * 3 | # from onpolicy.envs.gridworld.gym_minigrid.envs.multiroom import * 4 | # from onpolicy.envs.gridworld.gym_minigrid.envs.fetch import * 5 | # from onpolicy.envs.gridworld.gym_minigrid.envs.gotoobject import * 6 | # from onpolicy.envs.gridworld.gym_minigrid.envs.gotodoor import * 7 | # from onpolicy.envs.gridworld.gym_minigrid.envs.putnear import * 8 | # from onpolicy.envs.gridworld.gym_minigrid.envs.lockedroom import * 9 | # from onpolicy.envs.gridworld.gym_minigrid.envs.keycorridor import * 10 | # from onpolicy.envs.gridworld.gym_minigrid.envs.unlock import * 11 | # from onpolicy.envs.gridworld.gym_minigrid.envs.unlockpickup import * 12 | # from onpolicy.envs.gridworld.gym_minigrid.envs.blockedunlockpickup import * 13 | # from onpolicy.envs.gridworld.gym_minigrid.envs.playground_v0 import * 14 | # from onpolicy.envs.gridworld.gym_minigrid.envs.redbluedoors import * 15 | # from onpolicy.envs.gridworld.gym_minigrid.envs.obstructedmaze import * 16 | # from onpolicy.envs.gridworld.gym_minigrid.envs.memory import * 17 | # from onpolicy.envs.gridworld.gym_minigrid.envs.fourrooms import * 18 | # from onpolicy.envs.gridworld.gym_minigrid.envs.crossing import * 19 | # from onpolicy.envs.gridworld.gym_minigrid.envs.lavagap import * 20 | # from onpolicy.envs.gridworld.gym_minigrid.envs.dynamicobstacles import * 21 | # from onpolicy.envs.gridworld.gym_minigrid.envs.distshift import * 22 | from onpolicy.envs.gridworld.gym_minigrid.envs.human import * 23 | from onpolicy.envs.gridworld.gym_minigrid.envs.multiexploration import * 24 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/blockedunlockpickup.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import Ball 2 | from onpolicy.envs.gridworld.gym_minigrid.roomgrid import RoomGrid 3 | from onpolicy.envs.gridworld.gym_minigrid.register import register 4 | 5 | class BlockedUnlockPickup(RoomGrid): 6 | """ 7 | Unlock a door blocked by a ball, then pick up a box 8 | in another room 9 | """ 10 | 11 | def __init__(self, seed=None): 12 | room_size = 6 13 | super().__init__( 14 | num_rows=1, 15 | num_cols=2, 16 | room_size=room_size, 17 | max_steps=16*room_size**2, 18 | seed=seed 19 | ) 20 | 21 | def _gen_grid(self, width, height): 22 | super()._gen_grid(width, height) 23 | 24 | # Add a box to the room on the right 25 | obj, _ = self.add_object(1, 0, kind="box") 26 | # Make sure the two rooms are directly connected by a locked door 27 | door, pos = self.add_door(0, 0, 0, locked=True) 28 | # Block the door with a ball 29 | color = self._rand_color() 30 | self.grid.set(pos[0]-1, pos[1], Ball(color)) 31 | # Add a key to unlock the door 32 | self.add_object(0, 0, 'key', door.color) 33 | 34 | self.place_agent(0, 0) 35 | 36 | self.obj = obj 37 | self.mission = "pick up the %s %s" % (obj.color, obj.type) 38 | 39 | def step(self, action): 40 | obs, reward, done, info = super().step(action) 41 | 42 | if action == self.actions.pickup: 43 | if self.carrying and self.carrying == self.obj: 44 | reward = self._reward() 45 | done = True 46 | 47 | return obs, 
reward, done, info 48 | 49 | register( 50 | id='MiniGrid-BlockedUnlockPickup-v0', 51 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:BlockedUnlockPickup' 52 | ) 53 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/crossing.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import * 2 | from onpolicy.envs.gridworld.gym_minigrid.register import register 3 | 4 | import itertools as itt 5 | 6 | 7 | class CrossingEnv(MiniGridEnv): 8 | """ 9 | Environment with wall or lava obstacles, sparse reward. 10 | """ 11 | 12 | def __init__(self, size=9, num_crossings=1, obstacle_type=Lava, seed=None): 13 | self.num_crossings = num_crossings 14 | self.obstacle_type = obstacle_type 15 | super().__init__( 16 | grid_size=size, 17 | max_steps=4*size*size, 18 | # Set this to True for maximum speed 19 | see_through_walls=False, 20 | seed=None 21 | ) 22 | 23 | def _gen_grid(self, width, height): 24 | assert width % 2 == 1 and height % 2 == 1 # odd size 25 | 26 | # Create an empty grid 27 | self.grid = Grid(width, height) 28 | 29 | # Generate the surrounding walls 30 | self.grid.wall_rect(0, 0, width, height) 31 | 32 | # Place the agent in the top-left corner 33 | self.agent_pos = (1, 1) 34 | self.agent_dir = 0 35 | 36 | # Place a goal square in the bottom-right corner 37 | self.put_obj(Goal(), width - 2, height - 2) 38 | 39 | # Place obstacles (lava or walls) 40 | v, h = object(), object() # singleton `vertical` and `horizontal` objects 41 | 42 | # Lava rivers or walls specified by direction and position in grid 43 | rivers = [(v, i) for i in range(2, height - 2, 2)] 44 | rivers += [(h, j) for j in range(2, width - 2, 2)] 45 | self.np_random.shuffle(rivers) 46 | rivers = rivers[:self.num_crossings] # sample random rivers 47 | rivers_v = sorted([pos for direction, pos in rivers if direction is v]) 48 | rivers_h = sorted([pos for direction, pos in rivers if direction is h]) 49 | obstacle_pos = itt.chain( 50 | itt.product(range(1, width - 1), rivers_h), 51 | itt.product(rivers_v, range(1, height - 1)), 52 | ) 53 | for i, j in obstacle_pos: 54 | self.put_obj(self.obstacle_type(), i, j) 55 | 56 | # Sample path to goal 57 | path = [h] * len(rivers_v) + [v] * len(rivers_h) 58 | self.np_random.shuffle(path) 59 | 60 | # Create openings 61 | limits_v = [0] + rivers_v + [height - 1] 62 | limits_h = [0] + rivers_h + [width - 1] 63 | room_i, room_j = 0, 0 64 | for direction in path: 65 | if direction is h: 66 | i = limits_v[room_i + 1] 67 | j = self.np_random.choice( 68 | range(limits_h[room_j] + 1, limits_h[room_j + 1])) 69 | room_i += 1 70 | elif direction is v: 71 | i = self.np_random.choice( 72 | range(limits_v[room_i] + 1, limits_v[room_i + 1])) 73 | j = limits_h[room_j + 1] 74 | room_j += 1 75 | else: 76 | assert False 77 | self.grid.set(i, j, None) 78 | 79 | self.mission = ( 80 | "avoid the lava and get to the green goal square" 81 | if self.obstacle_type == Lava 82 | else "find the opening and get to the green goal square" 83 | ) 84 | 85 | class LavaCrossingEnv(CrossingEnv): 86 | def __init__(self): 87 | super().__init__(size=9, num_crossings=1) 88 | 89 | class LavaCrossingS9N2Env(CrossingEnv): 90 | def __init__(self): 91 | super().__init__(size=9, num_crossings=2) 92 | 93 | class LavaCrossingS9N3Env(CrossingEnv): 94 | def __init__(self): 95 | super().__init__(size=9, num_crossings=3) 96 | 97 | class LavaCrossingS11N5Env(CrossingEnv): 98 | def 
__init__(self): 99 | super().__init__(size=11, num_crossings=5) 100 | 101 | register( 102 | id='MiniGrid-LavaCrossingS9N1-v0', 103 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:LavaCrossingEnv' 104 | ) 105 | 106 | register( 107 | id='MiniGrid-LavaCrossingS9N2-v0', 108 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:LavaCrossingS9N2Env' 109 | ) 110 | 111 | register( 112 | id='MiniGrid-LavaCrossingS9N3-v0', 113 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:LavaCrossingS9N3Env' 114 | ) 115 | 116 | register( 117 | id='MiniGrid-LavaCrossingS11N5-v0', 118 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:LavaCrossingS11N5Env' 119 | ) 120 | 121 | class SimpleCrossingEnv(CrossingEnv): 122 | def __init__(self): 123 | super().__init__(size=9, num_crossings=1, obstacle_type=Wall) 124 | 125 | class SimpleCrossingS9N2Env(CrossingEnv): 126 | def __init__(self): 127 | super().__init__(size=9, num_crossings=2, obstacle_type=Wall) 128 | 129 | class SimpleCrossingS9N3Env(CrossingEnv): 130 | def __init__(self): 131 | super().__init__(size=9, num_crossings=3, obstacle_type=Wall) 132 | 133 | class SimpleCrossingS11N5Env(CrossingEnv): 134 | def __init__(self): 135 | super().__init__(size=11, num_crossings=5, obstacle_type=Wall) 136 | 137 | register( 138 | id='MiniGrid-SimpleCrossingS9N1-v0', 139 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:SimpleCrossingEnv' 140 | ) 141 | 142 | register( 143 | id='MiniGrid-SimpleCrossingS9N2-v0', 144 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:SimpleCrossingS9N2Env' 145 | ) 146 | 147 | register( 148 | id='MiniGrid-SimpleCrossingS9N3-v0', 149 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:SimpleCrossingS9N3Env' 150 | ) 151 | 152 | register( 153 | id='MiniGrid-SimpleCrossingS11N5-v0', 154 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:SimpleCrossingS11N5Env' 155 | ) 156 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/distshift.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import * 2 | from onpolicy.envs.gridworld.gym_minigrid.register import register 3 | 4 | class DistShiftEnv(MiniGridEnv): 5 | """ 6 | Distributional shift environment. 
7 | """ 8 | 9 | def __init__( 10 | self, 11 | width=9, 12 | height=7, 13 | agent_start_pos=(1,1), 14 | agent_start_dir=0, 15 | strip2_row=2 16 | ): 17 | self.agent_start_pos = agent_start_pos 18 | self.agent_start_dir = agent_start_dir 19 | self.goal_pos = (width-2, 1) 20 | self.strip2_row = strip2_row 21 | 22 | super().__init__( 23 | width=width, 24 | height=height, 25 | max_steps=4*width*height, 26 | # Set this to True for maximum speed 27 | see_through_walls=True 28 | ) 29 | 30 | def _gen_grid(self, width, height): 31 | # Create an empty grid 32 | self.grid = Grid(width, height) 33 | 34 | # Generate the surrounding walls 35 | self.grid.wall_rect(0, 0, width, height) 36 | 37 | # Place a goal square in the bottom-right corner 38 | self.put_obj(Goal(), *self.goal_pos) 39 | 40 | # Place the lava rows 41 | for i in range(self.width - 6): 42 | self.grid.set(3+i, 1, Lava()) 43 | self.grid.set(3+i, self.strip2_row, Lava()) 44 | 45 | # Place the agent 46 | if self.agent_start_pos is not None: 47 | self.agent_pos = self.agent_start_pos 48 | self.agent_dir = self.agent_start_dir 49 | else: 50 | self.place_agent() 51 | 52 | self.mission = "get to the green goal square" 53 | 54 | class DistShift1(DistShiftEnv): 55 | def __init__(self): 56 | super().__init__(strip2_row=2) 57 | 58 | class DistShift2(DistShiftEnv): 59 | def __init__(self): 60 | super().__init__(strip2_row=5) 61 | 62 | register( 63 | id='MiniGrid-DistShift1-v0', 64 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DistShift1' 65 | ) 66 | 67 | register( 68 | id='MiniGrid-DistShift2-v0', 69 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DistShift2' 70 | ) 71 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/doorkey.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import * 2 | from onpolicy.envs.gridworld.gym_minigrid.register import register 3 | 4 | class DoorKeyEnv(MiniGridEnv): 5 | """ 6 | Environment with a door and key, sparse reward 7 | """ 8 | 9 | def __init__(self, size=8): 10 | super().__init__( 11 | grid_size=size, 12 | max_steps=10*size*size 13 | ) 14 | 15 | def _gen_grid(self, width, height): 16 | # Create an empty grid 17 | self.grid = Grid(width, height) 18 | 19 | # Generate the surrounding walls 20 | self.grid.wall_rect(0, 0, width, height) 21 | 22 | # Place a goal in the bottom-right corner 23 | self.put_obj(Goal(), width - 2, height - 2) 24 | 25 | # Create a vertical splitting wall 26 | splitIdx = self._rand_int(2, width-2) 27 | self.grid.vert_wall(splitIdx, 0) 28 | 29 | # Place the agent at a random position and orientation 30 | # on the left side of the splitting wall 31 | self.place_agent(size=(splitIdx, height)) 32 | 33 | # Place a door in the wall 34 | doorIdx = self._rand_int(1, width-2) 35 | self.put_obj(Door('yellow', is_locked=True), splitIdx, doorIdx) 36 | 37 | # Place a yellow key on the left side 38 | self.place_obj( 39 | obj=Key('yellow'), 40 | top=(0, 0), 41 | size=(splitIdx, height) 42 | ) 43 | 44 | self.mission = "use the key to open the door and then get to the goal" 45 | 46 | class DoorKeyEnv5x5(DoorKeyEnv): 47 | def __init__(self): 48 | super().__init__(size=5) 49 | 50 | class DoorKeyEnv6x6(DoorKeyEnv): 51 | def __init__(self): 52 | super().__init__(size=6) 53 | 54 | class DoorKeyEnv16x16(DoorKeyEnv): 55 | def __init__(self): 56 | super().__init__(size=16) 57 | 58 | register( 59 | id='MiniGrid-DoorKey-5x5-v0', 60 | 
entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DoorKeyEnv5x5' 61 | ) 62 | 63 | register( 64 | id='MiniGrid-DoorKey-6x6-v0', 65 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DoorKeyEnv6x6' 66 | ) 67 | 68 | register( 69 | id='MiniGrid-DoorKey-8x8-v0', 70 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DoorKeyEnv' 71 | ) 72 | 73 | register( 74 | id='MiniGrid-DoorKey-16x16-v0', 75 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DoorKeyEnv16x16' 76 | ) 77 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/dynamicobstacles.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import * 2 | from onpolicy.envs.gridworld.gym_minigrid.register import register 3 | from operator import add 4 | 5 | class DynamicObstaclesEnv(MiniGridEnv): 6 | """ 7 | Single-room square grid environment with moving obstacles 8 | """ 9 | 10 | def __init__( 11 | self, 12 | size=8, 13 | agent_start_pos=(1, 1), 14 | agent_start_dir=0, 15 | n_obstacles=4 16 | ): 17 | self.agent_start_pos = agent_start_pos 18 | self.agent_start_dir = agent_start_dir 19 | 20 | # Reduce obstacles if there are too many 21 | if n_obstacles <= size/2 + 1: 22 | self.n_obstacles = int(n_obstacles) 23 | else: 24 | self.n_obstacles = int(size/2) 25 | super().__init__( 26 | grid_size=size, 27 | max_steps=4 * size * size, 28 | # Set this to True for maximum speed 29 | see_through_walls=True, 30 | ) 31 | # Allow only 3 actions permitted: left, right, forward 32 | self.action_space = spaces.Discrete(self.actions.forward + 1) 33 | self.reward_range = (-1, 1) 34 | 35 | def _gen_grid(self, width, height): 36 | # Create an empty grid 37 | self.grid = Grid(width, height) 38 | 39 | # Generate the surrounding walls 40 | self.grid.wall_rect(0, 0, width, height) 41 | 42 | # Place a goal square in the bottom-right corner 43 | self.grid.set(width - 2, height - 2, Goal()) 44 | 45 | # Place the agent 46 | if self.agent_start_pos is not None: 47 | self.agent_pos = self.agent_start_pos 48 | self.agent_dir = self.agent_start_dir 49 | else: 50 | self.place_agent() 51 | 52 | # Place obstacles 53 | self.obstacles = [] 54 | for i_obst in range(self.n_obstacles): 55 | self.obstacles.append(Ball()) 56 | self.place_obj(self.obstacles[i_obst], max_tries=100) 57 | 58 | self.mission = "get to the green goal square" 59 | 60 | def step(self, action): 61 | # Invalid action 62 | if action >= self.action_space.n: 63 | action = 0 64 | 65 | # Check if there is an obstacle in front of the agent 66 | front_cell = self.grid.get(*self.front_pos) 67 | not_clear = front_cell and front_cell.type != 'goal' 68 | 69 | # Update obstacle positions 70 | for i_obst in range(len(self.obstacles)): 71 | old_pos = self.obstacles[i_obst].cur_pos 72 | top = tuple(map(add, old_pos, (-1, -1))) 73 | 74 | try: 75 | self.place_obj(self.obstacles[i_obst], top=top, size=(3,3), max_tries=100) 76 | self.grid.set(*old_pos, None) 77 | except: 78 | pass 79 | 80 | # Update the agent's position/direction 81 | obs, reward, done, info = MiniGridEnv.step(self, action) 82 | 83 | # If the agent tried to walk over an obstacle or wall 84 | if action == self.actions.forward and not_clear: 85 | reward = -1 86 | done = True 87 | return obs, reward, done, info 88 | 89 | return obs, reward, done, info 90 | 91 | class DynamicObstaclesEnv5x5(DynamicObstaclesEnv): 92 | def __init__(self): 93 | super().__init__(size=5, n_obstacles=2) 94 | 95 | 
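# The variants below only differ in grid size, number of obstacles, and whether
# the agent start position is randomized (agent_start_pos=None).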
class DynamicObstaclesRandomEnv5x5(DynamicObstaclesEnv): 96 | def __init__(self): 97 | super().__init__(size=5, agent_start_pos=None, n_obstacles=2) 98 | 99 | class DynamicObstaclesEnv6x6(DynamicObstaclesEnv): 100 | def __init__(self): 101 | super().__init__(size=6, n_obstacles=3) 102 | 103 | class DynamicObstaclesRandomEnv6x6(DynamicObstaclesEnv): 104 | def __init__(self): 105 | super().__init__(size=6, agent_start_pos=None, n_obstacles=3) 106 | 107 | class DynamicObstaclesEnv16x16(DynamicObstaclesEnv): 108 | def __init__(self): 109 | super().__init__(size=16, n_obstacles=8) 110 | 111 | register( 112 | id='MiniGrid-Dynamic-Obstacles-5x5-v0', 113 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DynamicObstaclesEnv5x5' 114 | ) 115 | 116 | register( 117 | id='MiniGrid-Dynamic-Obstacles-Random-5x5-v0', 118 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DynamicObstaclesRandomEnv5x5' 119 | ) 120 | 121 | register( 122 | id='MiniGrid-Dynamic-Obstacles-6x6-v0', 123 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DynamicObstaclesEnv6x6' 124 | ) 125 | 126 | register( 127 | id='MiniGrid-Dynamic-Obstacles-Random-6x6-v0', 128 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DynamicObstaclesRandomEnv6x6' 129 | ) 130 | 131 | register( 132 | id='MiniGrid-Dynamic-Obstacles-8x8-v0', 133 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DynamicObstaclesEnv' 134 | ) 135 | 136 | register( 137 | id='MiniGrid-Dynamic-Obstacles-16x16-v0', 138 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:DynamicObstaclesEnv16x16' 139 | ) 140 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/empty.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import * 2 | from onpolicy.envs.gridworld.gym_minigrid.register import register 3 | 4 | class EmptyEnv(MiniGridEnv): 5 | """ 6 | Empty grid environment, no obstacles, sparse reward 7 | """ 8 | 9 | def __init__( 10 | self, 11 | size=8, 12 | agent_start_pos=(1,1), 13 | agent_start_dir=0, 14 | ): 15 | self.agent_start_pos = agent_start_pos 16 | self.agent_start_dir = agent_start_dir 17 | 18 | super().__init__( 19 | grid_size=size, 20 | max_steps=4*size*size, 21 | # Set this to True for maximum speed 22 | see_through_walls=True 23 | ) 24 | 25 | def _gen_grid(self, width, height): 26 | # Create an empty grid 27 | self.grid = Grid(width, height) 28 | 29 | # Generate the surrounding walls 30 | self.grid.wall_rect(0, 0, width, height) 31 | 32 | # Place a goal square in the bottom-right corner 33 | self.put_obj(Goal(), width - 2, height - 2) 34 | 35 | # Place the agent 36 | if self.agent_start_pos is not None: 37 | self.agent_pos = self.agent_start_pos 38 | self.agent_dir = self.agent_start_dir 39 | else: 40 | self.place_agent() 41 | 42 | self.mission = "get to the green goal square" 43 | 44 | class EmptyEnv5x5(EmptyEnv): 45 | def __init__(self, **kwargs): 46 | super().__init__(size=5, **kwargs) 47 | 48 | class EmptyRandomEnv5x5(EmptyEnv): 49 | def __init__(self): 50 | super().__init__(size=5, agent_start_pos=None) 51 | 52 | class EmptyEnv6x6(EmptyEnv): 53 | def __init__(self, **kwargs): 54 | super().__init__(size=6, **kwargs) 55 | 56 | class EmptyRandomEnv6x6(EmptyEnv): 57 | def __init__(self): 58 | super().__init__(size=6, agent_start_pos=None) 59 | 60 | class EmptyEnv16x16(EmptyEnv): 61 | def __init__(self, **kwargs): 62 | super().__init__(size=16, **kwargs) 63 | 64 | register( 
65 | id='MiniGrid-Empty-5x5-v0', 66 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:EmptyEnv5x5' 67 | ) 68 | 69 | register( 70 | id='MiniGrid-Empty-Random-5x5-v0', 71 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:EmptyRandomEnv5x5' 72 | ) 73 | 74 | register( 75 | id='MiniGrid-Empty-6x6-v0', 76 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:EmptyEnv6x6' 77 | ) 78 | 79 | register( 80 | id='MiniGrid-Empty-Random-6x6-v0', 81 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:EmptyRandomEnv6x6' 82 | ) 83 | 84 | register( 85 | id='MiniGrid-Empty-8x8-v0', 86 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:EmptyEnv' 87 | ) 88 | 89 | register( 90 | id='MiniGrid-Empty-16x16-v0', 91 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:EmptyEnv16x16' 92 | ) 93 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/fetch.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import * 2 | from onpolicy.envs.gridworld.gym_minigrid.register import register 3 | 4 | class FetchEnv(MiniGridEnv): 5 | """ 6 | Environment in which the agent has to fetch a random object 7 | named using English text strings 8 | """ 9 | 10 | def __init__( 11 | self, 12 | size=8, 13 | numObjs=3 14 | ): 15 | self.numObjs = numObjs 16 | 17 | super().__init__( 18 | grid_size=size, 19 | max_steps=5*size**2, 20 | # Set this to True for maximum speed 21 | see_through_walls=True 22 | ) 23 | 24 | def _gen_grid(self, width, height): 25 | self.grid = Grid(width, height) 26 | 27 | # Generate the surrounding walls 28 | self.grid.horz_wall(0, 0) 29 | self.grid.horz_wall(0, height-1) 30 | self.grid.vert_wall(0, 0) 31 | self.grid.vert_wall(width-1, 0) 32 | 33 | types = ['key', 'ball'] 34 | 35 | objs = [] 36 | 37 | # For each object to be generated 38 | while len(objs) < self.numObjs: 39 | objType = self._rand_elem(types) 40 | objColor = self._rand_elem(COLOR_NAMES) 41 | 42 | if objType == 'key': 43 | obj = Key(objColor) 44 | elif objType == 'ball': 45 | obj = Ball(objColor) 46 | 47 | self.place_obj(obj) 48 | objs.append(obj) 49 | 50 | # Randomize the player start position and orientation 51 | self.place_agent() 52 | 53 | # Choose a random object to be picked up 54 | target = objs[self._rand_int(0, len(objs))] 55 | self.targetType = target.type 56 | self.targetColor = target.color 57 | 58 | descStr = '%s %s' % (self.targetColor, self.targetType) 59 | 60 | # Generate the mission string 61 | idx = self._rand_int(0, 5) 62 | if idx == 0: 63 | self.mission = 'get a %s' % descStr 64 | elif idx == 1: 65 | self.mission = 'go get a %s' % descStr 66 | elif idx == 2: 67 | self.mission = 'fetch a %s' % descStr 68 | elif idx == 3: 69 | self.mission = 'go fetch a %s' % descStr 70 | elif idx == 4: 71 | self.mission = 'you must fetch a %s' % descStr 72 | assert hasattr(self, 'mission') 73 | 74 | def step(self, action): 75 | obs, reward, done, info = MiniGridEnv.step(self, action) 76 | 77 | if self.carrying: 78 | if self.carrying.color == self.targetColor and \ 79 | self.carrying.type == self.targetType: 80 | reward = self._reward() 81 | done = True 82 | else: 83 | reward = 0 84 | done = True 85 | 86 | return obs, reward, done, info 87 | 88 | class FetchEnv5x5N2(FetchEnv): 89 | def __init__(self): 90 | super().__init__(size=5, numObjs=2) 91 | 92 | class FetchEnv6x6N2(FetchEnv): 93 | def __init__(self): 94 | super().__init__(size=6, numObjs=2) 95 | 96 | register( 
97 | id='MiniGrid-Fetch-5x5-N2-v0', 98 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:FetchEnv5x5N2' 99 | ) 100 | 101 | register( 102 | id='MiniGrid-Fetch-6x6-N2-v0', 103 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:FetchEnv6x6N2' 104 | ) 105 | 106 | register( 107 | id='MiniGrid-Fetch-8x8-N3-v0', 108 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:FetchEnv' 109 | ) 110 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/fourrooms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import * 5 | from onpolicy.envs.gridworld.gym_minigrid.register import register 6 | 7 | 8 | class FourRoomsEnv(MiniGridEnv): 9 | """ 10 | Classic 4 rooms gridworld environment. 11 | Can specify the agent and goal positions; if not given, they are set at random. 12 | """ 13 | 14 | def __init__(self, agent_pos=None, goal_pos=None): 15 | self._agent_default_pos = agent_pos 16 | self._goal_default_pos = goal_pos 17 | super().__init__(grid_size=19, max_steps=100) 18 | 19 | def _gen_grid(self, width, height): 20 | # Create the grid 21 | self.grid = Grid(width, height) 22 | 23 | # Generate the surrounding walls 24 | self.grid.horz_wall(0, 0) 25 | self.grid.horz_wall(0, height - 1) 26 | self.grid.vert_wall(0, 0) 27 | self.grid.vert_wall(width - 1, 0) 28 | 29 | room_w = width // 2 30 | room_h = height // 2 31 | 32 | # For each row of rooms 33 | for j in range(0, 2): 34 | 35 | # For each column 36 | for i in range(0, 2): 37 | xL = i * room_w 38 | yT = j * room_h 39 | xR = xL + room_w 40 | yB = yT + room_h 41 | 42 | # Right wall and door 43 | if i + 1 < 2: 44 | self.grid.vert_wall(xR, yT, room_h) 45 | pos = (xR, self._rand_int(yT + 1, yB)) 46 | self.grid.set(*pos, None) 47 | 48 | # Bottom wall and door 49 | if j + 1 < 2: 50 | self.grid.horz_wall(xL, yB, room_w) 51 | pos = (self._rand_int(xL + 1, xR), yB) 52 | self.grid.set(*pos, None) 53 | 54 | # Randomize the player start position and orientation 55 | if self._agent_default_pos is not None: 56 | self.agent_pos = self._agent_default_pos 57 | self.grid.set(*self._agent_default_pos, None) 58 | self.agent_dir = self._rand_int(0, 4) # assuming random start direction 59 | else: 60 | self.place_agent() 61 | 62 | if self._goal_default_pos is not None: 63 | goal = Goal() 64 | self.put_obj(goal, *self._goal_default_pos) 65 | goal.init_pos, goal.cur_pos = self._goal_default_pos 66 | else: 67 | self.place_obj(Goal()) 68 | 69 | self.mission = 'Reach the goal' 70 | 71 | def step(self, action): 72 | obs, reward, done, info = MiniGridEnv.step(self, action) 73 | return obs, reward, done, info 74 | 75 | register( 76 | id='MiniGrid-FourRooms-v0', 77 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:FourRoomsEnv' 78 | ) 79 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/gotodoor.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import * 2 | from onpolicy.envs.gridworld.gym_minigrid.register import register 3 | 4 | class GoToDoorEnv(MiniGridEnv): 5 | """ 6 | Environment in which the agent is instructed to go to a given door 7 | named using an English text string 8 | """ 9 | 10 | def __init__( 11 | self, 12 | size=5 13 | ): 14 | assert size >= 5 15 | 16 | super().__init__( 17 | 
grid_size=size, 18 | max_steps=5*size**2, 19 | # Set this to True for maximum speed 20 | see_through_walls=True 21 | ) 22 | 23 | def _gen_grid(self, width, height): 24 | # Create the grid 25 | self.grid = Grid(width, height) 26 | 27 | # Randomly vary the room width and height 28 | width = self._rand_int(5, width+1) 29 | height = self._rand_int(5, height+1) 30 | 31 | # Generate the surrounding walls 32 | self.grid.wall_rect(0, 0, width, height) 33 | 34 | # Generate the 4 doors at random positions 35 | doorPos = [] 36 | doorPos.append((self._rand_int(2, width-2), 0)) 37 | doorPos.append((self._rand_int(2, width-2), height-1)) 38 | doorPos.append((0, self._rand_int(2, height-2))) 39 | doorPos.append((width-1, self._rand_int(2, height-2))) 40 | 41 | # Generate the door colors 42 | doorColors = [] 43 | while len(doorColors) < len(doorPos): 44 | color = self._rand_elem(COLOR_NAMES) 45 | if color in doorColors: 46 | continue 47 | doorColors.append(color) 48 | 49 | # Place the doors in the grid 50 | for idx, pos in enumerate(doorPos): 51 | color = doorColors[idx] 52 | self.grid.set(*pos, Door(color)) 53 | 54 | # Randomize the agent start position and orientation 55 | self.place_agent(size=(width, height)) 56 | 57 | # Select a random target door 58 | doorIdx = self._rand_int(0, len(doorPos)) 59 | self.target_pos = doorPos[doorIdx] 60 | self.target_color = doorColors[doorIdx] 61 | 62 | # Generate the mission string 63 | self.mission = 'go to the %s door' % self.target_color 64 | 65 | def step(self, action): 66 | obs, reward, done, info = super().step(action) 67 | 68 | ax, ay = self.agent_pos 69 | tx, ty = self.target_pos 70 | 71 | # Don't let the agent open any of the doors 72 | if action == self.actions.toggle: 73 | done = True 74 | 75 | # Reward performing done action in front of the target door 76 | if action == self.actions.done: 77 | if (ax == tx and abs(ay - ty) == 1) or (ay == ty and abs(ax - tx) == 1): 78 | reward = self._reward() 79 | done = True 80 | 81 | return obs, reward, done, info 82 | 83 | class GoToDoor8x8Env(GoToDoorEnv): 84 | def __init__(self): 85 | super().__init__(size=8) 86 | 87 | class GoToDoor6x6Env(GoToDoorEnv): 88 | def __init__(self): 89 | super().__init__(size=6) 90 | 91 | register( 92 | id='MiniGrid-GoToDoor-5x5-v0', 93 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:GoToDoorEnv' 94 | ) 95 | 96 | register( 97 | id='MiniGrid-GoToDoor-6x6-v0', 98 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:GoToDoor6x6Env' 99 | ) 100 | 101 | register( 102 | id='MiniGrid-GoToDoor-8x8-v0', 103 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:GoToDoor8x8Env' 104 | ) 105 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/gotoobject.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import * 2 | from onpolicy.envs.gridworld.gym_minigrid.register import register 3 | 4 | class GoToObjectEnv(MiniGridEnv): 5 | """ 6 | Environment in which the agent is instructed to go to a given object 7 | named using an English text string 8 | """ 9 | 10 | def __init__( 11 | self, 12 | size=6, 13 | numObjs=2 14 | ): 15 | self.numObjs = numObjs 16 | 17 | super().__init__( 18 | grid_size=size, 19 | max_steps=5*size**2, 20 | # Set this to True for maximum speed 21 | see_through_walls=True 22 | ) 23 | 24 | def _gen_grid(self, width, height): 25 | self.grid = Grid(width, height) 26 | 27 | # Generate the surrounding walls 28 | 
self.grid.wall_rect(0, 0, width, height) 29 | 30 | # Types and colors of objects we can generate 31 | types = ['key', 'ball', 'box'] 32 | 33 | objs = [] 34 | objPos = [] 35 | 36 | # Until we have generated all the objects 37 | while len(objs) < self.numObjs: 38 | objType = self._rand_elem(types) 39 | objColor = self._rand_elem(COLOR_NAMES) 40 | 41 | # If this object already exists, try again 42 | if (objType, objColor) in objs: 43 | continue 44 | 45 | if objType == 'key': 46 | obj = Key(objColor) 47 | elif objType == 'ball': 48 | obj = Ball(objColor) 49 | elif objType == 'box': 50 | obj = Box(objColor) 51 | 52 | pos = self.place_obj(obj) 53 | objs.append((objType, objColor)) 54 | objPos.append(pos) 55 | 56 | # Randomize the agent start position and orientation 57 | self.place_agent() 58 | 59 | # Choose a random object to be picked up 60 | objIdx = self._rand_int(0, len(objs)) 61 | self.targetType, self.target_color = objs[objIdx] 62 | self.target_pos = objPos[objIdx] 63 | 64 | descStr = '%s %s' % (self.target_color, self.targetType) 65 | self.mission = 'go to the %s' % descStr 66 | #print(self.mission) 67 | 68 | def step(self, action): 69 | obs, reward, done, info = MiniGridEnv.step(self, action) 70 | 71 | ax, ay = self.agent_pos 72 | tx, ty = self.target_pos 73 | 74 | # Toggle/pickup action terminates the episode 75 | if action == self.actions.toggle: 76 | done = True 77 | 78 | # Reward performing the done action next to the target object 79 | if action == self.actions.done: 80 | if abs(ax - tx) <= 1 and abs(ay - ty) <= 1: 81 | reward = self._reward() 82 | done = True 83 | 84 | return obs, reward, done, info 85 | 86 | class GotoEnv8x8N2(GoToObjectEnv): 87 | def __init__(self): 88 | super().__init__(size=8, numObjs=2) 89 | 90 | register( 91 | id='MiniGrid-GoToObject-6x6-N2-v0', 92 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:GoToObjectEnv' 93 | ) 94 | 95 | register( 96 | id='MiniGrid-GoToObject-8x8-N2-v0', 97 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:GotoEnv8x8N2' 98 | ) 99 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/human.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import * 2 | from icecream import ic 3 | import collections 4 | import math 5 | from copy import deepcopy 6 | 7 | class HumanEnv(MiniGridEnv): 8 | """ 9 | Environment in which the agent is instructed to go to a given object 10 | named using an English text string 11 | """ 12 | 13 | def __init__( 14 | self, 15 | num_agents=2, 16 | num_preies=2, 17 | num_obstacles=4, 18 | direction_alpha=0.5, 19 | use_direction_reward = False, 20 | use_human_command=False, 21 | coverage_discounter=0.1, 22 | size=19 23 | ): 24 | self.size = size 25 | self.num_preies = num_preies 26 | self.use_direction_reward = use_direction_reward 27 | self.direction_alpha = direction_alpha 28 | self.use_human_command = use_human_command 29 | self.coverage_discounter = coverage_discounter 30 | # initial the covering rate 31 | self.covering_rate = 0 32 | # Reduce obstacles if there are too many 33 | if num_obstacles <= size/2 + 1: 34 | self.num_obstacles = int(num_obstacles) 35 | else: 36 | self.num_obstacles = int(size/2) 37 | 38 | super().__init__( 39 | num_agents=num_agents, 40 | grid_size=size, 41 | max_steps=math.floor(((size-2)**2) / num_agents * 2), 42 | # Set this to True for maximum speed 43 | see_through_walls=True 44 | ) 45 | 46 | def _gen_grid(self, 
width, height): 47 | # Create the grid 48 | self.grid = Grid(width, height) 49 | 50 | # Generate the surrounding walls 51 | self.grid.horz_wall(0, 0) 52 | self.grid.horz_wall(0, height - 1) 53 | self.grid.vert_wall(0, 0) 54 | self.grid.vert_wall(width - 1, 0) 55 | 56 | room_w = width // 2 57 | room_h = height // 2 58 | 59 | # For each row of rooms 60 | for j in range(0, 2): 61 | 62 | # For each column 63 | for i in range(0, 2): 64 | xL = i * room_w 65 | yT = j * room_h 66 | xR = xL + room_w 67 | yB = yT + room_h 68 | 69 | # Bottom wall and door 70 | if i + 1 < 2: 71 | self.grid.vert_wall(xR, yT, room_h) 72 | pos = (xR, self._rand_int(yT + 1, yB)) 73 | self.grid.set(*pos, None) 74 | 75 | # Bottom wall and door 76 | if j + 1 < 2: 77 | self.grid.horz_wall(xL, yB, room_w) 78 | pos = (self._rand_int(xL + 1, xR), yB) 79 | self.grid.set(*pos, None) 80 | 81 | # initial the cover_grid 82 | self.cover_grid = np.zeros([width,height]) 83 | for j in range(0, height): 84 | for i in range(0, width): 85 | if self.grid.get(i,j) != None and self.grid.get(i,j).type == 'wall': 86 | self.cover_grid[j,i] = 1.0 87 | self.cover_grid_initial = self.cover_grid.copy() 88 | self.num_none = collections.Counter(self.cover_grid_initial.flatten())[0.] 89 | # import pdb; pdb.set_trace() 90 | 91 | # Types and colors of objects we can generate 92 | types = ['key'] 93 | 94 | objs = [] 95 | objPos = [] 96 | 97 | # Until we have generated all the objects 98 | while len(objs) < self.num_preies: 99 | objType = self._rand_elem(types) 100 | objColor = self._rand_elem(COLOR_NAMES) 101 | 102 | # If this object already exists, try again 103 | if (objType, objColor) in objs: 104 | continue 105 | 106 | if objType == 'key': 107 | obj = Key(objColor) 108 | elif objType == 'box': 109 | obj = Box(objColor) 110 | elif objType == 'ball': 111 | obj = Ball(objColor) 112 | 113 | pos = self.place_obj(obj) 114 | objs.append((objType, objColor)) 115 | objPos.append(pos) 116 | 117 | # Place obstacles 118 | self.obstacles = [] 119 | for i_obst in range(self.num_obstacles): 120 | self.obstacles.append(Obstacle()) 121 | pos = self.place_obj(self.obstacles[i_obst], max_tries=100) 122 | 123 | self.occupy_grid = self.grid.copy() 124 | # Randomize the agent start position and orientation 125 | self.place_agent() 126 | 127 | # Choose a random object to be picked up 128 | objIdx = self._rand_int(0, len(objs)) 129 | self.targetType, self.target_color = objs[objIdx] 130 | self.target_pos = objPos[objIdx] 131 | 132 | # direction 133 | array_direction = np.array([[0,1], [0,-1], [1,0], [-1,0], [1,1], [1,-1], [-1,1], [-1,-1]]) 134 | self.direction = [] 135 | self.direction_encoder = [] 136 | self.direction_index = [] 137 | for agent_id in range(self.num_agents): 138 | center_pos = np.array([int((self.size-1)/2),int((self.size-1)/2)]) 139 | direction = np.sign(center_pos - self.agent_pos[agent_id]) 140 | direction_index = np.argmax(np.all(np.where(array_direction == direction, True, False), axis=1)) 141 | direction_encoder = np.eye(8)[direction_index] 142 | self.direction_index.append(direction_index) 143 | self.direction.append(direction) 144 | self.direction_encoder.append(direction_encoder) 145 | 146 | # text 147 | descStr = '%s %s' % (self.target_color, self.targetType) 148 | self.mission = 'go to the %s' % descStr 149 | # print(self.mission) 150 | 151 | def step(self, action): 152 | obs, reward, done, info = MiniGridEnv.step(self, action) 153 | 154 | rewards = [] 155 | 156 | for agent_id in range(self.num_agents): 157 | ax, ay = self.agent_pos[agent_id] 158 | 
tx, ty = self.target_pos 159 | if self.cover_grid[ay,ax] == 0: 160 | reward += self.coverage_discounter 161 | self.cover_grid[ay, ax] = 1.0 162 | self.covering_rate = collections.Counter((self.cover_grid - self.cover_grid_initial).flatten())[1] / self.num_none 163 | 164 | # if abs(ax - tx) < 1 and abs(ay - ty) < 1: 165 | # reward += 1.0 166 | # self.num_reach_goal += 1 167 | # # done = True 168 | 169 | rewards.append(reward) 170 | 171 | rewards = [[np.sum(rewards)]] * self.num_agents 172 | 173 | dones = [done for agent_id in range(self.num_agents)] 174 | 175 | info['num_reach_goal'] = self.num_reach_goal 176 | info['covering_rate'] = self.covering_rate 177 | info['num_same_direction'] = self.num_same_direction 178 | 179 | return obs, rewards, dones, info 180 | 181 | 182 | 183 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/keycorridor.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.roomgrid import RoomGrid 2 | from onpolicy.envs.gridworld.gym_minigrid.register import register 3 | 4 | class KeyCorridor(RoomGrid): 5 | """ 6 | A ball is behind a locked door, the key is placed in a 7 | random room. 8 | """ 9 | 10 | def __init__( 11 | self, 12 | num_rows=3, 13 | obj_type="ball", 14 | room_size=6, 15 | seed=None 16 | ): 17 | self.obj_type = obj_type 18 | 19 | super().__init__( 20 | room_size=room_size, 21 | num_rows=num_rows, 22 | max_steps=30*room_size**2, 23 | seed=seed, 24 | ) 25 | 26 | def _gen_grid(self, width, height): 27 | super()._gen_grid(width, height) 28 | 29 | # Connect the middle column rooms into a hallway 30 | for j in range(1, self.num_rows): 31 | self.remove_wall(1, j, 3) 32 | 33 | # Add a locked door on the bottom right 34 | # Add an object behind the locked door 35 | room_idx = self._rand_int(0, self.num_rows) 36 | door, _ = self.add_door(2, room_idx, 2, locked=True) 37 | obj, _ = self.add_object(2, room_idx, kind=self.obj_type) 38 | 39 | # Add a key in a random room on the left side 40 | self.add_object(0, self._rand_int(0, self.num_rows), 'key', door.color) 41 | 42 | # Place the agent in the middle 43 | self.place_agent(1, self.num_rows // 2) 44 | 45 | # Make sure all rooms are accessible 46 | self.connect_all() 47 | 48 | self.obj = obj 49 | self.mission = "pick up the %s %s" % (obj.color, obj.type) 50 | 51 | def step(self, action): 52 | obs, reward, done, info = super().step(action) 53 | 54 | if action == self.actions.pickup: 55 | if self.carrying and self.carrying == self.obj: 56 | reward = self._reward() 57 | done = True 58 | 59 | return obs, reward, done, info 60 | 61 | class KeyCorridorS3R1(KeyCorridor): 62 | def __init__(self, seed=None): 63 | super().__init__( 64 | room_size=3, 65 | num_rows=1, 66 | seed=seed 67 | ) 68 | 69 | class KeyCorridorS3R2(KeyCorridor): 70 | def __init__(self, seed=None): 71 | super().__init__( 72 | room_size=3, 73 | num_rows=2, 74 | seed=seed 75 | ) 76 | 77 | class KeyCorridorS3R3(KeyCorridor): 78 | def __init__(self, seed=None): 79 | super().__init__( 80 | room_size=3, 81 | num_rows=3, 82 | seed=seed 83 | ) 84 | 85 | class KeyCorridorS4R3(KeyCorridor): 86 | def __init__(self, seed=None): 87 | super().__init__( 88 | room_size=4, 89 | num_rows=3, 90 | seed=seed 91 | ) 92 | 93 | class KeyCorridorS5R3(KeyCorridor): 94 | def __init__(self, seed=None): 95 | super().__init__( 96 | room_size=5, 97 | num_rows=3, 98 | seed=seed 99 | ) 100 | 101 | class KeyCorridorS6R3(KeyCorridor): 102 | def 
__init__(self, seed=None): 103 | super().__init__( 104 | room_size=6, 105 | num_rows=3, 106 | seed=seed 107 | ) 108 | 109 | register( 110 | id='MiniGrid-KeyCorridorS3R1-v0', 111 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:KeyCorridorS3R1' 112 | ) 113 | 114 | register( 115 | id='MiniGrid-KeyCorridorS3R2-v0', 116 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:KeyCorridorS3R2' 117 | ) 118 | 119 | register( 120 | id='MiniGrid-KeyCorridorS3R3-v0', 121 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:KeyCorridorS3R3' 122 | ) 123 | 124 | register( 125 | id='MiniGrid-KeyCorridorS4R3-v0', 126 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:KeyCorridorS4R3' 127 | ) 128 | 129 | register( 130 | id='MiniGrid-KeyCorridorS5R3-v0', 131 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:KeyCorridorS5R3' 132 | ) 133 | 134 | register( 135 | id='MiniGrid-KeyCorridorS6R3-v0', 136 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:KeyCorridorS6R3' 137 | ) 138 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/lavagap.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import * 2 | from onpolicy.envs.gridworld.gym_minigrid.register import register 3 | 4 | class LavaGapEnv(MiniGridEnv): 5 | """ 6 | Environment with one wall of lava with a small gap to cross through 7 | This environment is similar to LavaCrossing but simpler in structure. 8 | """ 9 | 10 | def __init__(self, size, obstacle_type=Lava, seed=None): 11 | self.obstacle_type = obstacle_type 12 | super().__init__( 13 | grid_size=size, 14 | max_steps=4*size*size, 15 | # Set this to True for maximum speed 16 | see_through_walls=False, 17 | seed=None 18 | ) 19 | 20 | def _gen_grid(self, width, height): 21 | assert width >= 5 and height >= 5 22 | 23 | # Create an empty grid 24 | self.grid = Grid(width, height) 25 | 26 | # Generate the surrounding walls 27 | self.grid.wall_rect(0, 0, width, height) 28 | 29 | # Place the agent in the top-left corner 30 | self.agent_pos = (1, 1) 31 | self.agent_dir = 0 32 | 33 | # Place a goal square in the bottom-right corner 34 | self.goal_pos = np.array((width - 2, height - 2)) 35 | self.put_obj(Goal(), *self.goal_pos) 36 | 37 | # Generate and store random gap position 38 | self.gap_pos = np.array(( 39 | self._rand_int(2, width - 2), 40 | self._rand_int(1, height - 1), 41 | )) 42 | 43 | # Place the obstacle wall 44 | self.grid.vert_wall(self.gap_pos[0], 1, height - 2, self.obstacle_type) 45 | 46 | # Put a hole in the wall 47 | self.grid.set(*self.gap_pos, None) 48 | 49 | self.mission = ( 50 | "avoid the lava and get to the green goal square" 51 | if self.obstacle_type == Lava 52 | else "find the opening and get to the green goal square" 53 | ) 54 | 55 | class LavaGapS5Env(LavaGapEnv): 56 | def __init__(self): 57 | super().__init__(size=5) 58 | 59 | class LavaGapS6Env(LavaGapEnv): 60 | def __init__(self): 61 | super().__init__(size=6) 62 | 63 | class LavaGapS7Env(LavaGapEnv): 64 | def __init__(self): 65 | super().__init__(size=7) 66 | 67 | register( 68 | id='MiniGrid-LavaGapS5-v0', 69 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:LavaGapS5Env' 70 | ) 71 | 72 | register( 73 | id='MiniGrid-LavaGapS6-v0', 74 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:LavaGapS6Env' 75 | ) 76 | 77 | register( 78 | id='MiniGrid-LavaGapS7-v0', 79 | 
entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:LavaGapS7Env' 80 | ) 81 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/lockedroom.py: -------------------------------------------------------------------------------- 1 | from gym import spaces 2 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import * 3 | from onpolicy.envs.gridworld.gym_minigrid.register import register 4 | 5 | class Room: 6 | def __init__(self, 7 | top, 8 | size, 9 | doorPos 10 | ): 11 | self.top = top 12 | self.size = size 13 | self.doorPos = doorPos 14 | self.color = None 15 | self.locked = False 16 | 17 | def rand_pos(self, env): 18 | topX, topY = self.top 19 | sizeX, sizeY = self.size 20 | return env._rand_pos( 21 | topX + 1, topX + sizeX - 1, 22 | topY + 1, topY + sizeY - 1 23 | ) 24 | 25 | class LockedRoom(MiniGridEnv): 26 | """ 27 | Environment in which the agent is instructed to go to a given object 28 | named using an English text string 29 | """ 30 | 31 | def __init__( 32 | self, 33 | size=19 34 | ): 35 | super().__init__(grid_size=size, max_steps=10*size) 36 | 37 | def _gen_grid(self, width, height): 38 | # Create the grid 39 | self.grid = Grid(width, height) 40 | 41 | # Generate the surrounding walls 42 | for i in range(0, width): 43 | self.grid.set(i, 0, Wall()) 44 | self.grid.set(i, height-1, Wall()) 45 | for j in range(0, height): 46 | self.grid.set(0, j, Wall()) 47 | self.grid.set(width-1, j, Wall()) 48 | 49 | # Hallway walls 50 | lWallIdx = width // 2 - 2 51 | rWallIdx = width // 2 + 2 52 | for j in range(0, height): 53 | self.grid.set(lWallIdx, j, Wall()) 54 | self.grid.set(rWallIdx, j, Wall()) 55 | 56 | self.rooms = [] 57 | 58 | # Room splitting walls 59 | for n in range(0, 3): 60 | j = n * (height // 3) 61 | for i in range(0, lWallIdx): 62 | self.grid.set(i, j, Wall()) 63 | for i in range(rWallIdx, width): 64 | self.grid.set(i, j, Wall()) 65 | 66 | roomW = lWallIdx + 1 67 | roomH = height // 3 + 1 68 | self.rooms.append(Room( 69 | (0, j), 70 | (roomW, roomH), 71 | (lWallIdx, j + 3) 72 | )) 73 | self.rooms.append(Room( 74 | (rWallIdx, j), 75 | (roomW, roomH), 76 | (rWallIdx, j + 3) 77 | )) 78 | 79 | # Choose one random room to be locked 80 | lockedRoom = self._rand_elem(self.rooms) 81 | lockedRoom.locked = True 82 | goalPos = lockedRoom.rand_pos(self) 83 | self.grid.set(*goalPos, Goal()) 84 | 85 | # Assign the door colors 86 | colors = set(COLOR_NAMES) 87 | for room in self.rooms: 88 | color = self._rand_elem(sorted(colors)) 89 | colors.remove(color) 90 | room.color = color 91 | if room.locked: 92 | self.grid.set(*room.doorPos, Door(color, is_locked=True)) 93 | else: 94 | self.grid.set(*room.doorPos, Door(color)) 95 | 96 | # Select a random room to contain the key 97 | while True: 98 | keyRoom = self._rand_elem(self.rooms) 99 | if keyRoom != lockedRoom: 100 | break 101 | keyPos = keyRoom.rand_pos(self) 102 | self.grid.set(*keyPos, Key(lockedRoom.color)) 103 | 104 | # Randomize the player start position and orientation 105 | self.agent_pos = self.place_agent( 106 | top=(lWallIdx, 0), 107 | size=(rWallIdx-lWallIdx, height) 108 | ) 109 | 110 | # Generate the mission string 111 | self.mission = ( 112 | 'get the %s key from the %s room, ' 113 | 'unlock the %s door and ' 114 | 'go to the goal' 115 | ) % (lockedRoom.color, keyRoom.color, lockedRoom.color) 116 | 117 | def step(self, action): 118 | obs, reward, done, info = MiniGridEnv.step(self, action) 119 | return obs, reward, done, info 120 | 121 | register( 122 | 
id='MiniGrid-LockedRoom-v0', 123 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:LockedRoom' 124 | ) 125 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/memory.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import * 2 | from onpolicy.envs.gridworld.gym_minigrid.register import register 3 | 4 | class MemoryEnv(MiniGridEnv): 5 | """ 6 | This environment is a memory test. The agent starts in a small room 7 | where it sees an object. It then has to go through a narrow hallway 8 | which ends in a split. At each end of the split there is an object, 9 | one of which is the same as the object in the starting room. The 10 | agent has to remember the initial object, and go to the matching 11 | object at split. 12 | """ 13 | 14 | def __init__( 15 | self, 16 | seed, 17 | size=8, 18 | random_length=False, 19 | ): 20 | self.random_length = random_length 21 | super().__init__( 22 | seed=seed, 23 | grid_size=size, 24 | max_steps=5*size**2, 25 | # Set this to True for maximum speed 26 | see_through_walls=False, 27 | ) 28 | 29 | def _gen_grid(self, width, height): 30 | self.grid = Grid(width, height) 31 | 32 | # Generate the surrounding walls 33 | self.grid.horz_wall(0, 0) 34 | self.grid.horz_wall(0, height-1) 35 | self.grid.vert_wall(0, 0) 36 | self.grid.vert_wall(width - 1, 0) 37 | 38 | assert height % 2 == 1 39 | upper_room_wall = height // 2 - 2 40 | lower_room_wall = height // 2 + 2 41 | if self.random_length: 42 | hallway_end = self._rand_int(4, width - 2) 43 | else: 44 | hallway_end = width - 3 45 | 46 | # Start room 47 | for i in range(1, 5): 48 | self.grid.set(i, upper_room_wall, Wall()) 49 | self.grid.set(i, lower_room_wall, Wall()) 50 | self.grid.set(4, upper_room_wall + 1, Wall()) 51 | self.grid.set(4, lower_room_wall - 1, Wall()) 52 | 53 | # Horizontal hallway 54 | for i in range(5, hallway_end): 55 | self.grid.set(i, upper_room_wall + 1, Wall()) 56 | self.grid.set(i, lower_room_wall - 1, Wall()) 57 | 58 | # Vertical hallway 59 | for j in range(0, height): 60 | if j != height // 2: 61 | self.grid.set(hallway_end, j, Wall()) 62 | self.grid.set(hallway_end + 2, j, Wall()) 63 | 64 | # Fix the player's start position and orientation 65 | self.agent_pos = (self._rand_int(1, hallway_end + 1), height // 2) 66 | self.agent_dir = 0 67 | 68 | # Place objects 69 | start_room_obj = self._rand_elem([Key, Ball]) 70 | self.grid.set(1, height // 2 - 1, start_room_obj('green')) 71 | 72 | other_objs = self._rand_elem([[Ball, Key], [Key, Ball]]) 73 | pos0 = (hallway_end + 1, height // 2 - 2) 74 | pos1 = (hallway_end + 1, height // 2 + 2) 75 | self.grid.set(*pos0, other_objs[0]('green')) 76 | self.grid.set(*pos1, other_objs[1]('green')) 77 | 78 | # Choose the target objects 79 | if start_room_obj == other_objs[0]: 80 | self.success_pos = (pos0[0], pos0[1] + 1) 81 | self.failure_pos = (pos1[0], pos1[1] - 1) 82 | else: 83 | self.success_pos = (pos1[0], pos1[1] - 1) 84 | self.failure_pos = (pos0[0], pos0[1] + 1) 85 | 86 | self.mission = 'go to the matching object at the end of the hallway' 87 | 88 | def step(self, action): 89 | if action == MiniGridEnv.Actions.pickup: 90 | action = MiniGridEnv.Actions.toggle 91 | obs, reward, done, info = MiniGridEnv.step(self, action) 92 | 93 | if tuple(self.agent_pos) == self.success_pos: 94 | reward = self._reward() 95 | done = True 96 | if tuple(self.agent_pos) == self.failure_pos: 97 | reward = 0 98 | 
done = True 99 | 100 | return obs, reward, done, info 101 | 102 | class MemoryS17Random(MemoryEnv): 103 | def __init__(self, seed=None): 104 | super().__init__(seed=seed, size=17, random_length=True) 105 | 106 | register( 107 | id='MiniGrid-MemoryS17Random-v0', 108 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:MemoryS17Random', 109 | ) 110 | 111 | class MemoryS13Random(MemoryEnv): 112 | def __init__(self, seed=None): 113 | super().__init__(seed=seed, size=13, random_length=True) 114 | 115 | register( 116 | id='MiniGrid-MemoryS13Random-v0', 117 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:MemoryS13Random', 118 | ) 119 | 120 | class MemoryS13(MemoryEnv): 121 | def __init__(self, seed=None): 122 | super().__init__(seed=seed, size=13) 123 | 124 | register( 125 | id='MiniGrid-MemoryS13-v0', 126 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:MemoryS13', 127 | ) 128 | 129 | class MemoryS11(MemoryEnv): 130 | def __init__(self, seed=None): 131 | super().__init__(seed=seed, size=11) 132 | 133 | register( 134 | id='MiniGrid-MemoryS11-v0', 135 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:MemoryS11', 136 | ) 137 | 138 | class MemoryS9(MemoryEnv): 139 | def __init__(self, seed=None): 140 | super().__init__(seed=seed, size=9) 141 | 142 | register( 143 | id='MiniGrid-MemoryS9-v0', 144 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:MemoryS9', 145 | ) 146 | 147 | class MemoryS7(MemoryEnv): 148 | def __init__(self, seed=None): 149 | super().__init__(seed=seed, size=7) 150 | 151 | register( 152 | id='MiniGrid-MemoryS7-v0', 153 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:MemoryS7', 154 | ) 155 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/obstructedmaze.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import * 2 | from onpolicy.envs.gridworld.gym_minigrid.roomgrid import RoomGrid 3 | from onpolicy.envs.gridworld.gym_minigrid.register import register 4 | 5 | class ObstructedMazeEnv(RoomGrid): 6 | """ 7 | A blue ball is hidden in the maze. Doors may be locked, 8 | doors may be obstructed by a ball and keys may be hidden in boxes. 
9 | """ 10 | 11 | def __init__(self, 12 | num_rows, 13 | num_cols, 14 | num_rooms_visited, 15 | seed=None 16 | ): 17 | room_size = 6 18 | max_steps = 4*num_rooms_visited*room_size**2 19 | 20 | super().__init__( 21 | room_size=room_size, 22 | num_rows=num_rows, 23 | num_cols=num_cols, 24 | max_steps=max_steps, 25 | seed=seed 26 | ) 27 | 28 | def _gen_grid(self, width, height): 29 | super()._gen_grid(width, height) 30 | 31 | # Define all possible colors for doors 32 | self.door_colors = self._rand_subset(COLOR_NAMES, len(COLOR_NAMES)) 33 | # Define the color of the ball to pick up 34 | self.ball_to_find_color = COLOR_NAMES[0] 35 | # Define the color of the balls that obstruct doors 36 | self.blocking_ball_color = COLOR_NAMES[1] 37 | # Define the color of boxes in which keys are hidden 38 | self.box_color = COLOR_NAMES[2] 39 | 40 | self.mission = "pick up the %s ball" % self.ball_to_find_color 41 | 42 | def step(self, action): 43 | obs, reward, done, info = super().step(action) 44 | 45 | if action == self.actions.pickup: 46 | if self.carrying and self.carrying == self.obj: 47 | reward = self._reward() 48 | done = True 49 | 50 | return obs, reward, done, info 51 | 52 | def add_door(self, i, j, door_idx=0, color=None, locked=False, key_in_box=False, blocked=False): 53 | """ 54 | Add a door. If the door must be locked, it also adds the key. 55 | If the key must be hidden, it is put in a box. If the door must 56 | be obstructed, it adds a ball in front of the door. 57 | """ 58 | 59 | door, door_pos = super().add_door(i, j, door_idx, color, locked=locked) 60 | 61 | if blocked: 62 | vec = DIR_TO_VEC[door_idx] 63 | blocking_ball = Ball(self.blocking_ball_color) if blocked else None 64 | self.grid.set(door_pos[0]-vec[0], door_pos[1]-vec[1], blocking_ball) 65 | 66 | if locked: 67 | obj = Key(door.color) 68 | if key_in_box: 69 | box = Box(self.box_color) if key_in_box else None 70 | box.contains = obj 71 | obj = box 72 | self.place_in_room(i, j, obj) 73 | 74 | return door, door_pos 75 | 76 | class ObstructedMaze_1Dlhb(ObstructedMazeEnv): 77 | """ 78 | A blue ball is hidden in a 2x1 maze. A locked door separates 79 | rooms. Doors are obstructed by a ball and keys are hidden in boxes. 80 | """ 81 | 82 | def __init__(self, key_in_box=True, blocked=True, seed=None): 83 | self.key_in_box = key_in_box 84 | self.blocked = blocked 85 | 86 | super().__init__( 87 | num_rows=1, 88 | num_cols=2, 89 | num_rooms_visited=2, 90 | seed=seed 91 | ) 92 | 93 | def _gen_grid(self, width, height): 94 | super()._gen_grid(width, height) 95 | 96 | self.add_door(0, 0, door_idx=0, color=self.door_colors[0], 97 | locked=True, 98 | key_in_box=self.key_in_box, 99 | blocked=self.blocked) 100 | 101 | self.obj, _ = self.add_object(1, 0, "ball", color=self.ball_to_find_color) 102 | self.place_agent(0, 0) 103 | 104 | class ObstructedMaze_1Dl(ObstructedMaze_1Dlhb): 105 | def __init__(self, seed=None): 106 | super().__init__(False, False, seed) 107 | 108 | class ObstructedMaze_1Dlh(ObstructedMaze_1Dlhb): 109 | def __init__(self, seed=None): 110 | super().__init__(True, False, seed) 111 | 112 | class ObstructedMaze_Full(ObstructedMazeEnv): 113 | """ 114 | A blue ball is hidden in one of the 4 corners of a 3x3 maze. Doors 115 | are locked, doors are obstructed by a ball and keys are hidden in 116 | boxes. 
117 | """ 118 | 119 | def __init__(self, agent_room=(1, 1), key_in_box=True, blocked=True, 120 | num_quarters=4, num_rooms_visited=25, seed=None): 121 | self.agent_room = agent_room 122 | self.key_in_box = key_in_box 123 | self.blocked = blocked 124 | self.num_quarters = num_quarters 125 | 126 | super().__init__( 127 | num_rows=3, 128 | num_cols=3, 129 | num_rooms_visited=num_rooms_visited, 130 | seed=seed 131 | ) 132 | 133 | def _gen_grid(self, width, height): 134 | super()._gen_grid(width, height) 135 | 136 | middle_room = (1, 1) 137 | # Define positions of "side rooms" i.e. rooms that are neither 138 | # corners nor the center. 139 | side_rooms = [(2, 1), (1, 2), (0, 1), (1, 0)][:self.num_quarters] 140 | for i in range(len(side_rooms)): 141 | side_room = side_rooms[i] 142 | 143 | # Add a door between the center room and the side room 144 | self.add_door(*middle_room, door_idx=i, color=self.door_colors[i], locked=False) 145 | 146 | for k in [-1, 1]: 147 | # Add a door to each side of the side room 148 | self.add_door(*side_room, locked=True, 149 | door_idx=(i+k)%4, 150 | color=self.door_colors[(i+k)%len(self.door_colors)], 151 | key_in_box=self.key_in_box, 152 | blocked=self.blocked) 153 | 154 | corners = [(2, 0), (2, 2), (0, 2), (0, 0)][:self.num_quarters] 155 | ball_room = self._rand_elem(corners) 156 | 157 | self.obj, _ = self.add_object(*ball_room, "ball", color=self.ball_to_find_color) 158 | self.place_agent(*self.agent_room) 159 | 160 | class ObstructedMaze_2Dl(ObstructedMaze_Full): 161 | def __init__(self, seed=None): 162 | super().__init__((2, 1), False, False, 1, 4, seed) 163 | 164 | class ObstructedMaze_2Dlh(ObstructedMaze_Full): 165 | def __init__(self, seed=None): 166 | super().__init__((2, 1), True, False, 1, 4, seed) 167 | 168 | 169 | class ObstructedMaze_2Dlhb(ObstructedMaze_Full): 170 | def __init__(self, seed=None): 171 | super().__init__((2, 1), True, True, 1, 4, seed) 172 | 173 | class ObstructedMaze_1Q(ObstructedMaze_Full): 174 | def __init__(self, seed=None): 175 | super().__init__((1, 1), True, True, 1, 5, seed) 176 | 177 | class ObstructedMaze_2Q(ObstructedMaze_Full): 178 | def __init__(self, seed=None): 179 | super().__init__((1, 1), True, True, 2, 11, seed) 180 | 181 | register( 182 | id="MiniGrid-ObstructedMaze-1Dl-v0", 183 | entry_point="onpolicy.envs.gridworld.gym_minigrid.envs:ObstructedMaze_1Dl" 184 | ) 185 | 186 | register( 187 | id="MiniGrid-ObstructedMaze-1Dlh-v0", 188 | entry_point="onpolicy.envs.gridworld.gym_minigrid.envs:ObstructedMaze_1Dlh" 189 | ) 190 | 191 | register( 192 | id="MiniGrid-ObstructedMaze-1Dlhb-v0", 193 | entry_point="onpolicy.envs.gridworld.gym_minigrid.envs:ObstructedMaze_1Dlhb" 194 | ) 195 | 196 | register( 197 | id="MiniGrid-ObstructedMaze-2Dl-v0", 198 | entry_point="onpolicy.envs.gridworld.gym_minigrid.envs:ObstructedMaze_2Dl" 199 | ) 200 | 201 | register( 202 | id="MiniGrid-ObstructedMaze-2Dlh-v0", 203 | entry_point="onpolicy.envs.gridworld.gym_minigrid.envs:ObstructedMaze_2Dlh" 204 | ) 205 | 206 | register( 207 | id="MiniGrid-ObstructedMaze-2Dlhb-v0", 208 | entry_point="onpolicy.envs.gridworld.gym_minigrid.envs:ObstructedMaze_2Dlhb" 209 | ) 210 | 211 | register( 212 | id="MiniGrid-ObstructedMaze-1Q-v0", 213 | entry_point="onpolicy.envs.gridworld.gym_minigrid.envs:ObstructedMaze_1Q" 214 | ) 215 | 216 | register( 217 | id="MiniGrid-ObstructedMaze-2Q-v0", 218 | entry_point="onpolicy.envs.gridworld.gym_minigrid.envs:ObstructedMaze_2Q" 219 | ) 220 | 221 | register( 222 | id="MiniGrid-ObstructedMaze-Full-v0", 223 | 
entry_point="onpolicy.envs.gridworld.gym_minigrid.envs:ObstructedMaze_Full" 224 | ) -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/playground_v0.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import * 2 | from onpolicy.envs.gridworld.gym_minigrid.register import register 3 | 4 | class PlaygroundV0(MiniGridEnv): 5 | """ 6 | Environment with multiple rooms and random objects. 7 | This environment has no specific goals or rewards. 8 | """ 9 | 10 | def __init__(self): 11 | super().__init__(grid_size=19, max_steps=100) 12 | 13 | def _gen_grid(self, width, height): 14 | # Create the grid 15 | self.grid = Grid(width, height) 16 | 17 | # Generate the surrounding walls 18 | self.grid.horz_wall(0, 0) 19 | self.grid.horz_wall(0, height-1) 20 | self.grid.vert_wall(0, 0) 21 | self.grid.vert_wall(width-1, 0) 22 | 23 | roomW = width // 3 24 | roomH = height // 3 25 | 26 | # For each row of rooms 27 | for j in range(0, 3): 28 | 29 | # For each column 30 | for i in range(0, 3): 31 | xL = i * roomW 32 | yT = j * roomH 33 | xR = xL + roomW 34 | yB = yT + roomH 35 | 36 | # Bottom wall and door 37 | if i+1 < 3: 38 | self.grid.vert_wall(xR, yT, roomH) 39 | pos = (xR, self._rand_int(yT+1, yB-1)) 40 | color = self._rand_elem(COLOR_NAMES) 41 | self.grid.set(*pos, Door(color)) 42 | 43 | # Bottom wall and door 44 | if j+1 < 3: 45 | self.grid.horz_wall(xL, yB, roomW) 46 | pos = (self._rand_int(xL+1, xR-1), yB) 47 | color = self._rand_elem(COLOR_NAMES) 48 | self.grid.set(*pos, Door(color)) 49 | 50 | # Randomize the player start position and orientation 51 | self.place_agent() 52 | 53 | # Place random objects in the world 54 | types = ['key', 'ball', 'box'] 55 | for i in range(0, 12): 56 | objType = self._rand_elem(types) 57 | objColor = self._rand_elem(COLOR_NAMES) 58 | if objType == 'key': 59 | obj = Key(objColor) 60 | elif objType == 'ball': 61 | obj = Ball(objColor) 62 | elif objType == 'box': 63 | obj = Box(objColor) 64 | self.place_obj(obj) 65 | 66 | # No explicit mission in this environment 67 | self.mission = '' 68 | 69 | def step(self, action): 70 | obs, reward, done, info = MiniGridEnv.step(self, action) 71 | return obs, reward, done, info 72 | 73 | register( 74 | id='MiniGrid-Playground-v0', 75 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:PlaygroundV0' 76 | ) 77 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/putnear.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import * 2 | from onpolicy.envs.gridworld.gym_minigrid.register import register 3 | 4 | class PutNearEnv(MiniGridEnv): 5 | """ 6 | Environment in which the agent is instructed to place an object near 7 | another object through a natural language string. 
8 | """ 9 | 10 | def __init__( 11 | self, 12 | size=6, 13 | numObjs=2 14 | ): 15 | self.numObjs = numObjs 16 | 17 | super().__init__( 18 | grid_size=size, 19 | max_steps=5*size, 20 | # Set this to True for maximum speed 21 | see_through_walls=True 22 | ) 23 | 24 | def _gen_grid(self, width, height): 25 | self.grid = Grid(width, height) 26 | 27 | # Generate the surrounding walls 28 | self.grid.horz_wall(0, 0) 29 | self.grid.horz_wall(0, height-1) 30 | self.grid.vert_wall(0, 0) 31 | self.grid.vert_wall(width-1, 0) 32 | 33 | # Types and colors of objects we can generate 34 | types = ['key', 'ball', 'box'] 35 | 36 | objs = [] 37 | objPos = [] 38 | 39 | def near_obj(env, p1): 40 | for p2 in objPos: 41 | dx = p1[0] - p2[0] 42 | dy = p1[1] - p2[1] 43 | if abs(dx) <= 1 and abs(dy) <= 1: 44 | return True 45 | return False 46 | 47 | # Until we have generated all the objects 48 | while len(objs) < self.numObjs: 49 | objType = self._rand_elem(types) 50 | objColor = self._rand_elem(COLOR_NAMES) 51 | 52 | # If this object already exists, try again 53 | if (objType, objColor) in objs: 54 | continue 55 | 56 | if objType == 'key': 57 | obj = Key(objColor) 58 | elif objType == 'ball': 59 | obj = Ball(objColor) 60 | elif objType == 'box': 61 | obj = Box(objColor) 62 | 63 | pos = self.place_obj(obj, reject_fn=near_obj) 64 | 65 | objs.append((objType, objColor)) 66 | objPos.append(pos) 67 | 68 | # Randomize the agent start position and orientation 69 | self.place_agent() 70 | 71 | # Choose a random object to be moved 72 | objIdx = self._rand_int(0, len(objs)) 73 | self.move_type, self.moveColor = objs[objIdx] 74 | self.move_pos = objPos[objIdx] 75 | 76 | # Choose a target object (to put the first object next to) 77 | while True: 78 | targetIdx = self._rand_int(0, len(objs)) 79 | if targetIdx != objIdx: 80 | break 81 | self.target_type, self.target_color = objs[targetIdx] 82 | self.target_pos = objPos[targetIdx] 83 | 84 | self.mission = 'put the %s %s near the %s %s' % ( 85 | self.moveColor, 86 | self.move_type, 87 | self.target_color, 88 | self.target_type 89 | ) 90 | 91 | def step(self, action): 92 | preCarrying = self.carrying 93 | 94 | obs, reward, done, info = super().step(action) 95 | 96 | u, v = self.dir_vec 97 | ox, oy = (self.agent_pos[0] + u, self.agent_pos[1] + v) 98 | tx, ty = self.target_pos 99 | 100 | # If we picked up the wrong object, terminate the episode 101 | if action == self.actions.pickup and self.carrying: 102 | if self.carrying.type != self.move_type or self.carrying.color != self.moveColor: 103 | done = True 104 | 105 | # If successfully dropping an object near the target 106 | if action == self.actions.drop and preCarrying: 107 | if self.grid.get(ox, oy) is preCarrying: 108 | if abs(ox - tx) <= 1 and abs(oy - ty) <= 1: 109 | reward = self._reward() 110 | done = True 111 | 112 | return obs, reward, done, info 113 | 114 | class PutNear8x8N3(PutNearEnv): 115 | def __init__(self): 116 | super().__init__(size=8, numObjs=3) 117 | 118 | register( 119 | id='MiniGrid-PutNear-6x6-N2-v0', 120 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:PutNearEnv' 121 | ) 122 | 123 | register( 124 | id='MiniGrid-PutNear-8x8-N3-v0', 125 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:PutNear8x8N3' 126 | ) 127 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/redbluedoors.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import * 
2 | from onpolicy.envs.gridworld.gym_minigrid.register import register 3 | 4 | class RedBlueDoorEnv(MiniGridEnv): 5 | """ 6 | Single room with red and blue doors on opposite sides. 7 | The red door must be opened before the blue door to 8 | obtain a reward. 9 | """ 10 | 11 | def __init__(self, size=8): 12 | self.size = size 13 | 14 | super().__init__( 15 | width=2*size, 16 | height=size, 17 | max_steps=20*size*size 18 | ) 19 | 20 | def _gen_grid(self, width, height): 21 | # Create an empty grid 22 | self.grid = Grid(width, height) 23 | 24 | # Generate the grid walls 25 | self.grid.wall_rect(0, 0, 2*self.size, self.size) 26 | self.grid.wall_rect(self.size//2, 0, self.size, self.size) 27 | 28 | # Place the agent in the top-left corner 29 | self.place_agent(top=(self.size//2, 0), size=(self.size, self.size)) 30 | 31 | # Add a red door at a random position in the left wall 32 | pos = self._rand_int(1, self.size - 1) 33 | self.red_door = Door("red") 34 | self.grid.set(self.size//2, pos, self.red_door) 35 | 36 | # Add a blue door at a random position in the right wall 37 | pos = self._rand_int(1, self.size - 1) 38 | self.blue_door = Door("blue") 39 | self.grid.set(self.size//2 + self.size - 1, pos, self.blue_door) 40 | 41 | # Generate the mission string 42 | self.mission = "open the red door then the blue door" 43 | 44 | def step(self, action): 45 | red_door_opened_before = self.red_door.is_open 46 | blue_door_opened_before = self.blue_door.is_open 47 | 48 | obs, reward, done, info = MiniGridEnv.step(self, action) 49 | 50 | red_door_opened_after = self.red_door.is_open 51 | blue_door_opened_after = self.blue_door.is_open 52 | 53 | if blue_door_opened_after: 54 | if red_door_opened_before: 55 | reward = self._reward() 56 | done = True 57 | else: 58 | reward = 0 59 | done = True 60 | 61 | elif red_door_opened_after: 62 | if blue_door_opened_before: 63 | reward = 0 64 | done = True 65 | 66 | return obs, reward, done, info 67 | 68 | class RedBlueDoorEnv6x6(RedBlueDoorEnv): 69 | def __init__(self): 70 | super().__init__(size=6) 71 | 72 | register( 73 | id='MiniGrid-RedBlueDoors-6x6-v0', 74 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:RedBlueDoorEnv6x6' 75 | ) 76 | 77 | register( 78 | id='MiniGrid-RedBlueDoors-8x8-v0', 79 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:RedBlueDoorEnv' 80 | ) 81 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/unlock.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import Ball 2 | from onpolicy.envs.gridworld.gym_minigrid.roomgrid import RoomGrid 3 | from onpolicy.envs.gridworld.gym_minigrid.register import register 4 | 5 | class Unlock(RoomGrid): 6 | """ 7 | Unlock a door 8 | """ 9 | 10 | def __init__(self, seed=None): 11 | room_size = 6 12 | super().__init__( 13 | num_rows=1, 14 | num_cols=2, 15 | room_size=room_size, 16 | max_steps=8*room_size**2, 17 | seed=seed 18 | ) 19 | 20 | def _gen_grid(self, width, height): 21 | super()._gen_grid(width, height) 22 | 23 | # Make sure the two rooms are directly connected by a locked door 24 | door, _ = self.add_door(0, 0, 0, locked=True) 25 | # Add a key to unlock the door 26 | self.add_object(0, 0, 'key', door.color) 27 | 28 | self.place_agent(0, 0) 29 | 30 | self.door = door 31 | self.mission = "open the door" 32 | 33 | def step(self, action): 34 | obs, reward, done, info = super().step(action) 35 | 36 | if action == self.actions.toggle: 37 
| if self.door.is_open: 38 | reward = self._reward() 39 | done = True 40 | 41 | return obs, reward, done, info 42 | 43 | register( 44 | id='MiniGrid-Unlock-v0', 45 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:Unlock' 46 | ) 47 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/envs/unlockpickup.py: -------------------------------------------------------------------------------- 1 | from onpolicy.envs.gridworld.gym_minigrid.minigrid import Ball 2 | from onpolicy.envs.gridworld.gym_minigrid.roomgrid import RoomGrid 3 | from onpolicy.envs.gridworld.gym_minigrid.register import register 4 | 5 | class UnlockPickup(RoomGrid): 6 | """ 7 | Unlock a door, then pick up a box in another room 8 | """ 9 | 10 | def __init__(self, seed=None): 11 | room_size = 6 12 | super().__init__( 13 | num_rows=1, 14 | num_cols=2, 15 | room_size=room_size, 16 | max_steps=8*room_size**2, 17 | seed=seed 18 | ) 19 | 20 | def _gen_grid(self, width, height): 21 | super()._gen_grid(width, height) 22 | 23 | # Add a box to the room on the right 24 | obj, _ = self.add_object(1, 0, kind="box") 25 | # Make sure the two rooms are directly connected by a locked door 26 | door, _ = self.add_door(0, 0, 0, locked=True) 27 | # Add a key to unlock the door 28 | self.add_object(0, 0, 'key', door.color) 29 | 30 | self.place_agent(0, 0) 31 | 32 | self.obj = obj 33 | self.mission = "pick up the %s %s" % (obj.color, obj.type) 34 | 35 | def step(self, action): 36 | obs, reward, done, info = super().step(action) 37 | 38 | if action == self.actions.pickup: 39 | if self.carrying and self.carrying == self.obj: 40 | reward = self._reward() 41 | done = True 42 | 43 | return obs, reward, done, info 44 | 45 | register( 46 | id='MiniGrid-UnlockPickup-v0', 47 | entry_point='onpolicy.envs.gridworld.gym_minigrid.envs:UnlockPickup' 48 | ) 49 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/register.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register as gym_register 2 | 3 | env_list = [] 4 | 5 | def register( 6 | id, 7 | grid_size, 8 | max_steps, 9 | local_step_num, 10 | agent_view_size, 11 | num_obstacles, 12 | num_agents, 13 | agent_pos, 14 | entry_point, 15 | reward_threshold=0.95, 16 | use_merge = True, 17 | use_merge_plan = True, 18 | use_constrict_map = True, 19 | use_fc_net = False, 20 | use_agent_id = False, 21 | use_stack = False, 22 | use_orientation = False, 23 | use_same_location = True, 24 | use_complete_reward = True, 25 | use_agent_obstacle = False, 26 | use_multiroom = False, 27 | use_irregular_room = False, 28 | use_time_penalty = False, 29 | use_overlap_penalty = False, 30 | astar_cost_mode = 'normal' 31 | ): 32 | assert id.startswith("MiniGrid-") 33 | assert id not in env_list 34 | 35 | # Register the environment with OpenAI gym 36 | gym_register( 37 | id=id, 38 | entry_point=entry_point, 39 | kwargs={ 40 | 'grid_size': grid_size, 41 | 'max_steps': max_steps, 42 | 'local_step_num': local_step_num, 43 | 'agent_view_size': agent_view_size, 44 | 'num_obstacles': num_obstacles, 45 | 'num_agents': num_agents, 46 | 'agent_pos': agent_pos, 47 | 'use_merge': use_merge, 48 | 'use_merge_plan': use_merge_plan, 49 | 'use_constrict_map': use_constrict_map, 50 | 'use_fc_net':use_fc_net, 51 | 'use_agent_id':use_agent_id, 52 | 'use_stack':use_stack, 53 | 'use_orientation':use_orientation, 54 | 'use_same_location': 
use_same_location, 55 | 'use_complete_reward': use_complete_reward, 56 | 'use_agent_obstacle': use_agent_obstacle, 57 | 'use_multiroom': use_multiroom, 58 | 'use_irregular_room': use_irregular_room, 59 | 'use_time_penalty': use_time_penalty, 60 | 'use_overlap_penalty': use_overlap_penalty, 61 | 'astar_cost_mode': astar_cost_mode 62 | }, 63 | reward_threshold=reward_threshold 64 | ) 65 | 66 | # Add the environment to the set 67 | env_list.append(id) 68 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/rendering.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | 4 | def downsample(img, factor): 5 | """ 6 | Downsample an image along both dimensions by some factor 7 | """ 8 | 9 | assert img.shape[0] % factor == 0 10 | assert img.shape[1] % factor == 0 11 | 12 | img = img.reshape([img.shape[0]//factor, factor, img.shape[1]//factor, factor, 3]) 13 | img = img.mean(axis=3) 14 | img = img.mean(axis=1) 15 | 16 | return img 17 | 18 | def fill_coords(img, fn, color): 19 | """ 20 | Fill pixels of an image with coordinates matching a filter function 21 | """ 22 | 23 | for y in range(img.shape[0]): 24 | for x in range(img.shape[1]): 25 | yf = (y + 0.5) / img.shape[0] 26 | xf = (x + 0.5) / img.shape[1] 27 | if fn(xf, yf): 28 | img[y, x] = color 29 | 30 | return img 31 | 32 | def rotate_fn(fin, cx, cy, theta): 33 | def fout(x, y): 34 | x = x - cx 35 | y = y - cy 36 | 37 | x2 = cx + x * math.cos(-theta) - y * math.sin(-theta) 38 | y2 = cy + y * math.cos(-theta) + x * math.sin(-theta) 39 | 40 | return fin(x2, y2) 41 | 42 | return fout 43 | 44 | def point_in_line(x0, y0, x1, y1, r): #rounded rectangle 45 | p0 = np.array([x0, y0]) 46 | p1 = np.array([x1, y1]) 47 | dir = p1 - p0 48 | dist = np.linalg.norm(dir) 49 | dir = dir / dist 50 | 51 | xmin = min(x0, x1) - r 52 | xmax = max(x0, x1) + r 53 | ymin = min(y0, y1) - r 54 | ymax = max(y0, y1) + r 55 | 56 | def fn(x, y): 57 | # Fast, early escape test 58 | if x < xmin or x > xmax or y < ymin or y > ymax: 59 | return False 60 | 61 | q = np.array([x, y]) 62 | pq = q - p0 63 | 64 | # Closest point on line 65 | a = np.dot(pq, dir) 66 | a = np.clip(a, 0, dist) 67 | p = p0 + a * dir 68 | 69 | dist_to_line = np.linalg.norm(q - p) 70 | return dist_to_line <= r 71 | 72 | return fn 73 | 74 | def point_in_circle(cx, cy, r): 75 | def fn(x, y): 76 | return (x-cx)*(x-cx) + (y-cy)*(y-cy) <= r * r 77 | return fn 78 | 79 | def point_in_rect(xmin, xmax, ymin, ymax): 80 | def fn(x, y): 81 | return x >= xmin and x <= xmax and y >= ymin and y <= ymax 82 | return fn 83 | 84 | def point_in_triangle(a, b, c): 85 | a = np.array(a) 86 | b = np.array(b) 87 | c = np.array(c) 88 | 89 | def fn(x, y): 90 | v0 = c - a 91 | v1 = b - a 92 | v2 = np.array((x, y)) - a 93 | 94 | # Compute dot products 95 | dot00 = np.dot(v0, v0) 96 | dot01 = np.dot(v0, v1) 97 | dot02 = np.dot(v0, v2) 98 | dot11 = np.dot(v1, v1) 99 | dot12 = np.dot(v1, v2) 100 | 101 | # Compute barycentric coordinates 102 | inv_denom = 1 / (dot00 * dot11 - dot01 * dot01) 103 | u = (dot11 * dot02 - dot01 * dot12) * inv_denom 104 | v = (dot00 * dot12 - dot01 * dot02) * inv_denom 105 | 106 | # Check if point is in triangle 107 | return (u >= 0) and (v >= 0) and (u + v) < 1 108 | 109 | return fn 110 | 111 | def highlight_img(img, color=(255, 255, 255), alpha=0.30): 112 | """ 113 | Add highlighting to an image 114 | """ 115 | 116 | blend_img = img + alpha * (np.array(color, dtype=np.uint8) - 
img) 117 | blend_img = blend_img.clip(0, 255).astype(np.uint8) 118 | img[:, :, :] = blend_img 119 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/gym_minigrid/window.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import numpy as np 3 | 4 | # Only ask users to install matplotlib if they actually need it 5 | try: 6 | import matplotlib.pyplot as plt 7 | except ImportError: 8 | print('To display the environment in a window, please install matplotlib, e.g.:') 9 | print('pip3 install --user matplotlib') 10 | sys.exit(-1) 11 | 12 | class Window: 13 | """ 14 | Window to draw a gridworld instance using Matplotlib 15 | """ 16 | 17 | def __init__(self, title): 18 | self.fig = None 19 | 20 | self.imshow_obj = None 21 | self.local_imshow_obj = None 22 | 23 | # Create the figure and axes 24 | self.fig, self.ax = plt.subplots(1,2) 25 | 26 | # Show the env name in the window title 27 | self.fig.canvas.set_window_title(title) 28 | 29 | # Turn off x/y axis numbering/ticks 30 | for ax in self.ax: 31 | ax.xaxis.set_ticks_position('none') 32 | ax.yaxis.set_ticks_position('none') 33 | _ = ax.set_xticklabels([]) 34 | _ = ax.set_yticklabels([]) 35 | 36 | # Flag indicating the window was closed 37 | self.closed = False 38 | 39 | def close_handler(evt): 40 | self.closed = True 41 | 42 | self.fig.canvas.mpl_connect('close_event', close_handler) 43 | 44 | def show_img(self, img, local_img): 45 | """ 46 | Show an image or update the image being shown 47 | """ 48 | 49 | # Show the first image of the environment 50 | if self.imshow_obj is None: 51 | self.imshow_obj = self.ax[0].imshow(img, interpolation='bilinear') 52 | if self.local_imshow_obj is None: 53 | self.local_imshow_obj = self.ax[1].imshow(local_img, interpolation='bilinear') 54 | 55 | self.imshow_obj.set_data(img) 56 | self.local_imshow_obj.set_data(local_img) 57 | 58 | self.fig.canvas.draw() 59 | 60 | # Let matplotlib process UI events 61 | # This is needed for interactive mode to work properly 62 | plt.pause(0.001) 63 | 64 | def set_caption(self, text): 65 | """ 66 | Set/update the caption text below the image 67 | """ 68 | 69 | plt.xlabel(text) 70 | 71 | def reg_key_handler(self, key_handler): 72 | """ 73 | Register a keyboard event handler 74 | """ 75 | 76 | # Keyboard handler 77 | self.fig.canvas.mpl_connect('key_press_event', key_handler) 78 | 79 | def show(self, block=True): 80 | """ 81 | Show the window, and start an event loop 82 | """ 83 | 84 | # If not blocking, trigger interactive mode 85 | if not block: 86 | plt.ion() 87 | 88 | # Show the plot 89 | # In non-interactive mode, this enters the matplotlib event loop 90 | # In interactive mode, this call does not block 91 | plt.show() 92 | 93 | def close(self): 94 | """ 95 | Close the window 96 | """ 97 | 98 | plt.close() 99 | self.closed = True 100 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/manual_control.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import time 4 | import argparse 5 | import numpy as np 6 | import gym 7 | import gym_minigrid 8 | from gym_minigrid.wrappers import * 9 | from gym_minigrid.window import Window 10 | 11 | def redraw(img): 12 | if not args.agent_view: 13 | img = env.render('rgb_array', tile_size=args.tile_size) 14 | 15 | window.show_img(img) 16 | 17 | def reset(): 18 | if args.seed != -1: 19 | env.seed(args.seed) 20 | 21 | obs = 
env.reset() 22 | 23 | if hasattr(env, 'mission'): 24 | print('Mission: %s' % env.mission) 25 | window.set_caption(env.mission) 26 | 27 | redraw(obs) 28 | 29 | def step(action): 30 | obs, reward, done, info = env.step(action) 31 | print('step=%s, reward=%.2f' % (env.step_count, reward)) 32 | 33 | if done: 34 | print('done!') 35 | reset() 36 | else: 37 | redraw(obs) 38 | 39 | def key_handler(event): 40 | print('pressed', event.key) 41 | 42 | if event.key == 'escape': 43 | window.close() 44 | return 45 | 46 | if event.key == 'backspace': 47 | reset() 48 | return 49 | 50 | if event.key == 'left': 51 | step(env.actions.left) 52 | return 53 | if event.key == 'right': 54 | step(env.actions.right) 55 | return 56 | if event.key == 'up': 57 | step(env.actions.forward) 58 | return 59 | 60 | # Spacebar 61 | if event.key == ' ': 62 | step(env.actions.toggle) 63 | return 64 | if event.key == 'pageup': 65 | step(env.actions.pickup) 66 | return 67 | if event.key == 'pagedown': 68 | step(env.actions.drop) 69 | return 70 | 71 | if event.key == 'enter': 72 | step(env.actions.done) 73 | return 74 | 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument( 77 | "--env", 78 | help="gym environment to load", 79 | default='MiniGrid-MultiRoom-N6-v0' 80 | ) 81 | parser.add_argument( 82 | "--seed", 83 | type=int, 84 | help="random seed to generate the environment with", 85 | default=-1 86 | ) 87 | parser.add_argument( 88 | "--tile_size", 89 | type=int, 90 | help="size at which to render tiles", 91 | default=32 92 | ) 93 | parser.add_argument( 94 | '--agent_view', 95 | default=False, 96 | help="draw what the agent sees (partially observable view)", 97 | action='store_true' 98 | ) 99 | 100 | args = parser.parse_args() 101 | 102 | env = gym.make(args.env) 103 | 104 | if args.agent_view: 105 | env = RGBImgPartialObsWrapper(env) 106 | env = ImgObsWrapper(env) 107 | 108 | window = Window('gym_minigrid - ' + args.env) 109 | window.reg_key_handler(key_handler) 110 | 111 | reset() 112 | 113 | # Blocking event loop 114 | window.show(block=True) 115 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/run_tests.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import random 4 | import numpy as np 5 | import gym 6 | from gym_minigrid.register import env_list 7 | from gym_minigrid.minigrid import Grid, OBJECT_TO_IDX 8 | 9 | # Test importing a specific environment directly 10 | from gym_minigrid.envs import DoorKeyEnv 11 | 12 | # Test importing wrappers 13 | from gym_minigrid.wrappers import * 14 | 15 | ############################################################################## 16 | 17 | print('%d environments registered' % len(env_list)) 18 | 19 | for env_idx, env_name in enumerate(env_list): 20 | print('testing {} ({}/{})'.format(env_name, env_idx+1, len(env_list))) 21 | 22 | # Load the gym environment 23 | env = gym.make(env_name) 24 | env.max_steps = min(env.max_steps, 200) 25 | env.reset() 26 | env.render('rgb_array') 27 | 28 | # Verify that the same seed always produces the same environment 29 | for i in range(0, 5): 30 | seed = 1337 + i 31 | env.seed(seed) 32 | grid1 = env.grid 33 | env.seed(seed) 34 | grid2 = env.grid 35 | assert grid1 == grid2 36 | 37 | env.reset() 38 | 39 | # Run for a few episodes 40 | num_episodes = 0 41 | while num_episodes < 5: 42 | # Pick a random action 43 | action = random.randint(0, env.action_space.n - 1) 44 | 45 | obs, reward, done, info = env.step(action) 46 | 
47 | # Validate the agent position 48 | assert env.agent_pos[0] < env.width 49 | assert env.agent_pos[1] < env.height 50 | 51 | # Test observation encode/decode roundtrip 52 | img = obs['image'] 53 | grid, vis_mask = Grid.decode(img) 54 | img2 = grid.encode(vis_mask=vis_mask) 55 | assert np.array_equal(img, img2) 56 | 57 | # Test the env to string function 58 | str(env) 59 | 60 | # Check that the reward is within the specified range 61 | assert reward >= env.reward_range[0], reward 62 | assert reward <= env.reward_range[1], reward 63 | 64 | if done: 65 | num_episodes += 1 66 | env.reset() 67 | 68 | env.render('rgb_array') 69 | 70 | # Test the close method 71 | env.close() 72 | 73 | env = gym.make(env_name) 74 | env = ReseedWrapper(env) 75 | for _ in range(10): 76 | env.reset() 77 | env.step(0) 78 | env.close() 79 | 80 | env = gym.make(env_name) 81 | env = ImgObsWrapper(env) 82 | env.reset() 83 | env.step(0) 84 | env.close() 85 | 86 | # Test the fully observable wrapper 87 | env = gym.make(env_name) 88 | env = FullyObsWrapper(env) 89 | env.reset() 90 | obs, _, _, _ = env.step(0) 91 | assert obs['image'].shape == env.observation_space.spaces['image'].shape 92 | env.close() 93 | 94 | # RGB image observation wrapper 95 | env = gym.make(env_name) 96 | env = RGBImgPartialObsWrapper(env) 97 | env.reset() 98 | obs, _, _, _ = env.step(0) 99 | assert obs['image'].mean() > 0 100 | env.close() 101 | 102 | env = gym.make(env_name) 103 | env = FlatObsWrapper(env) 104 | env.reset() 105 | env.step(0) 106 | env.close() 107 | 108 | env = gym.make(env_name) 109 | env = ViewSizeWrapper(env, 5) 110 | env.reset() 111 | env.step(0) 112 | env.close() 113 | 114 | # Test the wrappers return proper observation spaces. 115 | wrappers = [ 116 | RGBImgObsWrapper, 117 | RGBImgPartialObsWrapper, 118 | OneHotPartialObsWrapper 119 | ] 120 | for wrapper in wrappers: 121 | env = wrapper(gym.make(env_name)) 122 | obs_space, wrapper_name = env.observation_space, wrapper.__name__ 123 | assert isinstance( 124 | obs_space, spaces.Dict 125 | ), "Observation space for {0} is not a Dict: {1}.".format( 126 | wrapper_name, obs_space 127 | ) 128 | # This should not fail either 129 | ImgObsWrapper(env) 130 | env.reset() 131 | env.step(0) 132 | env.close() 133 | 134 | ############################################################################## 135 | 136 | print('testing agent_sees method') 137 | env = gym.make('MiniGrid-DoorKey-6x6-v0') 138 | goal_pos = (env.grid.width - 2, env.grid.height - 2) 139 | 140 | # Test the "in" operator on grid objects 141 | assert ('green', 'goal') in env.grid 142 | assert ('blue', 'key') not in env.grid 143 | 144 | # Test the env.agent_sees() function 145 | env.reset() 146 | for i in range(0, 500): 147 | action = random.randint(0, env.action_space.n - 1) 148 | obs, reward, done, info = env.step(action) 149 | 150 | grid, _ = Grid.decode(obs['image']) 151 | goal_visible = ('green', 'goal') in grid 152 | 153 | agent_sees_goal = env.agent_sees(*goal_pos) 154 | assert agent_sees_goal == goal_visible 155 | if done: 156 | env.reset() 157 | -------------------------------------------------------------------------------- /onpolicy/envs/gridworld/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='gym_minigrid', 5 | version='1.0.2', 6 | keywords='memory, environment, agent, rl, openaigym, openai-gym, gym', 7 | url='https://github.com/maximecb/gym-minigrid', 8 | description='Minimalistic gridworld package for OpenAI Gym', 
9 | packages=['gym_minigrid', 'gym_minigrid.envs'], 10 | install_requires=[ 11 | 'gym>=0.9.6', 12 | 'numpy>=1.15.0' 13 | ] 14 | ) 15 | -------------------------------------------------------------------------------- /onpolicy/scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/scripts/__init__.py -------------------------------------------------------------------------------- /onpolicy/scripts/render/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/scripts/render/__init__.py -------------------------------------------------------------------------------- /onpolicy/scripts/render_gridworld.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | env="GridWorld" 3 | scenario="MiniGrid-MultiExploration-v0" 4 | num_agents=2 5 | num_obstacles=0 6 | algo="mappo" 7 | exp="async_global_new_attn_para_single_to_async(1-5)" 8 | seed_max=3 9 | 10 | echo "env is ${env}" 11 | for seed in `seq ${seed_max}` 12 | do 13 | CUDA_VISIBLE_DEVICES=2 python render/render_gridworld.py\ 14 | --env_name ${env} --algorithm_name ${algo} --experiment_name ${exp} --scenario_name ${scenario} \ 15 | --num_agents ${num_agents} --num_obstacles ${num_obstacles} \ 16 | --seed ${seed} --n_training_threads 1 --n_rollout_threads 1 --render_episodes 100 \ 17 | --cnn_layers_params '16,7,2,1 32,5,2,1 16,3,1,1' \ 18 | --model_dir "./results/GridWorld/MiniGrid-MultiExploration-v0/mappo/async_global_new_attn_para_no_agent_id_overlap_penalty_single_normal50/wandb/run-20220224_140614-3vhpd2ge/files/" \ 19 | --max_steps 200 --use_complete_reward --agent_view_size 7 --local_step_num 1 --use_random_pos \ 20 | --astar_cost_mode utility --grid_goal --goal_grid_size 5 --cnn_trans_layer 1,3,1,1 \ 21 | --use_stack --grid_size 25 --use_recurrent_policy --use_stack --use_global_goal --use_overlap_penalty --use_eval --wandb_name "mapping" --user_name "yang-xy20" --asynch & 22 | done 23 | -------------------------------------------------------------------------------- /onpolicy/scripts/render_gridworld_ft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | env="GridWorld" 3 | scenario="MiniGrid-MultiExploration-v0" 4 | num_agents=1 5 | grid_size=25 6 | num_obstacles=0 7 | local_step_num=1 8 | seed_max=3 9 | algo='ft_rrt' 10 | 11 | echo "env is ${env}" 12 | for seed in `seq ${seed_max}` 13 | do 14 | echo "seed is ${seed}" 15 | exp=new_async_${algo}_grid${grid_size}_stepgoal_${local_step_num}_merge_normal 16 | CUDA_VISIBLE_DEVICES=3 python render/render_gridworld_ft.py\ 17 | --env_name ${env} --algorithm_name ${algo} --experiment_name ${exp} --scenario_name ${scenario} \ 18 | --num_agents ${num_agents} --num_obstacles ${num_obstacles} \ 19 | --seed ${seed} --n_training_threads 1 --n_rollout_threads 1 --render_episodes 100 \ 20 | --cnn_layers_params '16,3,1,1 32,3,1,1 16,3,1,1' \ 21 | --ifi 0.5 --max_steps 300 --grid_size ${grid_size} --local_step_num ${local_step_num} --use_random_pos \ 22 | --agent_view_size 7 --use_merge --use_merge_plan --use_eval \ 23 | --astar_cost_mode "normal" --wandb_name "mapping" --user_name "yang-xy20" --asynch 24 | done 25 | 26 | -------------------------------------------------------------------------------- 
/onpolicy/scripts/train/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/scripts/train/__init__.py -------------------------------------------------------------------------------- /onpolicy/scripts/train_gridworld.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | env="GridWorld" 3 | scenario="MiniGrid-MultiExploration-v0" 4 | num_agents=3 5 | num_obstacles=0 6 | algo="mappo" 7 | exp="async_global_new_attn_para_disc_single_overlap_normal50" 8 | seed_max=1 9 | 10 | echo "env is ${env}, scenario is ${scenario}, algo is ${algo}, exp is ${exp}, max seed is ${seed_max}" 11 | for seed in `seq ${seed_max}`; 12 | do 13 | echo "seed is ${seed}:" 14 | CUDA_VISIBLE_DEVICES=2 python train/train_gridworld.py \ 15 | --env_name ${env} --algorithm_name ${algo} --experiment_name ${exp} --scenario_name ${scenario} \ 16 | --log_interval 1 --wandb_name "mapping" --user_name "yang-xy20" --num_agents ${num_agents} \ 17 | --num_obstacles ${num_obstacles} --cnn_layers_params '16,7,2,1 32,5,2,1 16,3,1,1' --hidden_size 64 --seed 1 --n_training_threads 1 \ 18 | --n_rollout_threads 50 --num_mini_batch 1 --num_env_steps 80000000 --ppo_epoch 3 --gain 0.01 \ 19 | --lr 5e-4 --critic_lr 5e-4 --max_steps 150 --use_complete_reward --agent_view_size 7 --local_step_num 5 --use_random_pos \ 20 | --astar_cost_mode normal --cnn_trans_layer 1,3,1,1 --grid_size 25 --use_recurrent_policy \ 21 | --use_global_goal --use_overlap_penalty --use_stack --goal_grid_size 5 --use_discrect --asynch 22 | done -------------------------------------------------------------------------------- /onpolicy/utils/RRT/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/utils/RRT/__init__.py -------------------------------------------------------------------------------- /onpolicy/utils/RRT/rrt_with_pathsmoothing.py: -------------------------------------------------------------------------------- 1 | """ 2 | 3 | Path planning Sample Code with RRT with path smoothing 4 | 5 | @author: AtsushiSakai(@Atsushi_twi) 6 | 7 | """ 8 | 9 | import math 10 | import os 11 | import random 12 | import sys 13 | 14 | import matplotlib.pyplot as plt 15 | 16 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 17 | 18 | try: 19 | from rrt import RRT 20 | except ImportError: 21 | raise 22 | 23 | show_animation = True 24 | 25 | 26 | def get_path_length(path): 27 | le = 0 28 | for i in range(len(path) - 1): 29 | dx = path[i + 1][0] - path[i][0] 30 | dy = path[i + 1][1] - path[i][1] 31 | d = math.sqrt(dx * dx + dy * dy) 32 | le += d 33 | 34 | return le 35 | 36 | 37 | def get_target_point(path, targetL): 38 | le = 0 39 | ti = 0 40 | lastPairLen = 0 41 | for i in range(len(path) - 1): 42 | dx = path[i + 1][0] - path[i][0] 43 | dy = path[i + 1][1] - path[i][1] 44 | d = math.sqrt(dx * dx + dy * dy) 45 | le += d 46 | if le >= targetL: 47 | ti = i - 1 48 | lastPairLen = d 49 | break 50 | 51 | partRatio = (le - targetL) / lastPairLen 52 | 53 | x = path[ti][0] + (path[ti + 1][0] - path[ti][0]) * partRatio 54 | y = path[ti][1] + (path[ti + 1][1] - path[ti][1]) * partRatio 55 | 56 | return [x, y, ti] 57 | 58 | 59 | def line_collision_check(first, second, obstacleList): 60 | # Line Equation 61 | 62 | x1 = first[0] 63 | y1 = first[1] 64 
| x2 = second[0] 65 | y2 = second[1] 66 | 67 | try: 68 | a = y2 - y1 69 | b = -(x2 - x1) 70 | c = y2 * (x2 - x1) - x2 * (y2 - y1) 71 | except ZeroDivisionError: 72 | return False 73 | 74 | for (ox, oy, size) in obstacleList: 75 | d = abs(a * ox + b * oy + c) / (math.sqrt(a * a + b * b)) 76 | if d <= size: 77 | return False 78 | 79 | return True # OK 80 | 81 | 82 | def path_smoothing(path, max_iter, obstacle_list): 83 | le = get_path_length(path) 84 | 85 | for i in range(max_iter): 86 | # Sample two points 87 | pickPoints = [random.uniform(0, le), random.uniform(0, le)] 88 | pickPoints.sort() 89 | first = get_target_point(path, pickPoints[0]) 90 | second = get_target_point(path, pickPoints[1]) 91 | 92 | if first[2] <= 0 or second[2] <= 0: 93 | continue 94 | 95 | if (second[2] + 1) > len(path): 96 | continue 97 | 98 | if second[2] == first[2]: 99 | continue 100 | 101 | # collision check 102 | if not line_collision_check(first, second, obstacle_list): 103 | continue 104 | 105 | # Create New path 106 | newPath = [] 107 | newPath.extend(path[:first[2] + 1]) 108 | newPath.append([first[0], first[1]]) 109 | newPath.append([second[0], second[1]]) 110 | newPath.extend(path[second[2] + 1:]) 111 | path = newPath 112 | le = get_path_length(path) 113 | 114 | return path 115 | 116 | 117 | def main(): 118 | # ====Search Path with RRT==== 119 | # Parameter 120 | obstacleList = [ 121 | (5, 5, 1), 122 | (3, 6, 2), 123 | (3, 8, 2), 124 | (3, 10, 2), 125 | (7, 5, 2), 126 | (9, 5, 2) 127 | ] # [x,y,size] 128 | rrt = RRT(start=[0, 0], goal=[6, 10], 129 | rand_area=[-2, 15], obstacle_list=obstacleList) 130 | path = rrt.planning(animation=show_animation) 131 | 132 | # Path smoothing 133 | maxIter = 1000 134 | smoothedPath = path_smoothing(path, maxIter, obstacleList) 135 | 136 | # Draw final path 137 | if show_animation: 138 | rrt.draw_graph() 139 | plt.plot([x for (x, y) in path], [y for (x, y) in path], '-r') 140 | 141 | plt.plot([x for (x, y) in smoothedPath], [ 142 | y for (x, y) in smoothedPath], '-c') 143 | 144 | plt.grid(True) 145 | plt.pause(0.01) # Need for Mac 146 | plt.show() 147 | 148 | 149 | if __name__ == '__main__': 150 | main() 151 | -------------------------------------------------------------------------------- /onpolicy/utils/RRT/sobol/__init__.py: -------------------------------------------------------------------------------- 1 | from .sobol import i4_sobol as sobol_quasirand 2 | -------------------------------------------------------------------------------- /onpolicy/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yang-xy20/async_mappo/038058ba6f4dcd99204b44994db768b711d9d4cd/onpolicy/utils/__init__.py -------------------------------------------------------------------------------- /onpolicy/utils/multi_discrete.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | # An old version of OpenAI Gym's multi_discrete.py. 
(Was getting affected by Gym updates) 5 | # (https://github.com/openai/gym/blob/1fb81d4e3fb780ccf77fec731287ba07da35eb84/gym/spaces/multi_discrete.py) 6 | class MultiDiscrete(gym.Space): 7 | """ 8 | - The multi-discrete action space consists of a series of discrete action spaces with different parameters 9 | - It can be adapted to both a Discrete action space or a continuous (Box) action space 10 | - It is useful to represent game controllers or keyboards where each key can be represented as a discrete action space 11 | - It is parametrized by passing an array of arrays containing [min, max] for each discrete action space where the discrete action space can take any integers from `min` to `max` (both inclusive) 12 | Note: A value of 0 always needs to represent the NOOP action. 13 | e.g. Nintendo Game Controller 14 | - Can be conceptualized as 3 discrete action spaces: 15 | 1) Arrow Keys: Discrete 5 - NOOP[0], UP[1], RIGHT[2], DOWN[3], LEFT[4] - params: min: 0, max: 4 16 | 2) Button A: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1 17 | 3) Button B: Discrete 2 - NOOP[0], Pressed[1] - params: min: 0, max: 1 18 | - Can be initialized as 19 | MultiDiscrete([ [0,4], [0,1], [0,1] ]) 20 | """ 21 | 22 | def __init__(self, array_of_param_array): 23 | self.low = np.array([x[0] for x in array_of_param_array]) 24 | self.high = np.array([x[1] for x in array_of_param_array]) 25 | self.num_discrete_space = self.low.shape[0] 26 | self.n = np.sum(self.high) + 2 27 | 28 | def sample(self): 29 | """ Returns an array with one sample from each discrete action space """ 30 | # For each row: round(random .* (max - min) + min, 0) 31 | random_array = np.random.rand(self.num_discrete_space) 32 | return [int(x) for x in np.floor(np.multiply((self.high - self.low + 1.), random_array) + self.low)] 33 | 34 | def contains(self, x): 35 | return len(x) == self.num_discrete_space and (np.array(x) >= self.low).all() and (np.array(x) <= self.high).all() 36 | 37 | @property 38 | def shape(self): 39 | return self.num_discrete_space 40 | 41 | def __repr__(self): 42 | return "MultiDiscrete" + str(self.num_discrete_space) 43 | 44 | def __eq__(self, other): 45 | return np.array_equal(self.low, other.low) and np.array_equal(self.high, other.high) 46 | -------------------------------------------------------------------------------- /onpolicy/utils/util.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import numpy as np 4 | import math 5 | import torch 6 | import torch.nn as nn 7 | 8 | class AsynchControl: 9 | def __init__(self, num_envs, num_agents, limit, random_fn, min_length, max_length): 10 | self.num_envs = num_envs 11 | self.num_agents = num_agents 12 | self.limit = limit 13 | self.random_fn = random_fn 14 | self.min_length = min_length 15 | self.max_length = max_length 16 | 17 | self.reset() 18 | 19 | def reset(self): 20 | self.cnt = np.zeros((self.num_envs, self.num_agents), dtype=np.int32) 21 | self.rest = np.zeros((self.num_envs, self.num_agents), dtype=np.int32) 22 | self.active = np.ones((self.num_envs, self.num_agents), dtype=np.int32) 23 | for e in range(self.num_envs): 24 | for a in range(self.num_agents): 25 | self.rest[e, a] = self.random_fn() # the first step is unlimited 26 | 27 | def step(self): 28 | for e in range(self.num_envs): 29 | for a in range(self.num_agents): 30 | self.rest[e, a] -= 1 31 | self.active[e, a] = 0 32 | if self.rest[e, a] <= 0: 33 | if self.cnt[e, a] < self.limit: 34 | self.cnt[e, a] += 1 35 | self.active[e, a] = 
1 36 | self.rest[e, a] = min(max(self.random_fn(), self.min_length), self.max_length) 37 | 38 | def active_agents(self): 39 | ret = [] 40 | for e in range(self.num_envs): 41 | for a in range(self.num_agents): 42 | if self.active[e, a]: 43 | ret.append((e, a, self.cnt[e, a])) 44 | return ret 45 | 46 | def check(input): 47 | if type(input) == np.ndarray: 48 | return torch.from_numpy(input) 49 | 50 | def get_gard_norm(it): 51 | sum_grad = 0 52 | for x in it: 53 | if x.grad is None: 54 | continue 55 | sum_grad += x.grad.norm() ** 2 56 | return math.sqrt(sum_grad) 57 | 58 | def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr): 59 | """Decreases the learning rate linearly""" 60 | lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs))) 61 | for param_group in optimizer.param_groups: 62 | param_group['lr'] = lr 63 | 64 | def huber_loss(e, d): 65 | a = (abs(e) <= d).float() 66 | b = (e > d).float() 67 | return a*e**2/2 + b*d*(abs(e)-d/2) 68 | 69 | def mse_loss(e): 70 | return e**2/2 71 | 72 | def get_shape_from_obs_space(obs_space): 73 | if obs_space.__class__.__name__ == 'Box': 74 | obs_shape = obs_space.shape 75 | elif obs_space.__class__.__name__ == 'list': 76 | obs_shape = obs_space 77 | elif obs_space.__class__.__name__ == 'Dict': 78 | obs_shape = obs_space.spaces 79 | else: 80 | raise NotImplementedError 81 | return obs_shape 82 | 83 | def get_shape_from_act_space(act_space): 84 | if act_space.__class__.__name__ == 'Discrete': 85 | act_shape = 1 86 | elif act_space.__class__.__name__ == "MultiDiscrete": 87 | act_shape = act_space.shape 88 | elif act_space.__class__.__name__ == "Box": 89 | act_shape = act_space.shape[0] 90 | elif act_space.__class__.__name__ == "MultiBinary": 91 | act_shape = act_space.shape[0] 92 | else: # agar 93 | act_shape = act_space[0].shape[0] + 1 94 | return act_shape 95 | 96 | 97 | def tile_images(img_nhwc): 98 | """ 99 | Tile N images into one big PxQ image 100 | (P,Q) are chosen to be as close as possible, and if N 101 | is square, then P=Q. 
102 | input: img_nhwc, list or array of images, ndim=4 once turned into array 103 | n = batch index, h = height, w = width, c = channel 104 | returns: 105 | bigim_HWc, ndarray with ndim=3 106 | """ 107 | img_nhwc = np.asarray(img_nhwc) 108 | N, h, w, c = img_nhwc.shape 109 | H = int(np.ceil(np.sqrt(N))) 110 | W = int(np.ceil(float(N)/H)) 111 | img_nhwc = np.array(list(img_nhwc) + [img_nhwc[0]*0 for _ in range(N, H*W)]) 112 | img_HWhwc = img_nhwc.reshape(H, W, h, w, c) 113 | img_HhWwc = img_HWhwc.transpose(0, 2, 1, 3, 4) 114 | img_Hh_Ww_c = img_HhWwc.reshape(H*h, W*w, c) 115 | return img_Hh_Ww_c -------------------------------------------------------------------------------- /onpolicy/utils/valuenorm.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class ValueNorm(nn.Module): 9 | """ Normalize a vector of observations - across the first norm_axes dimensions""" 10 | 11 | def __init__(self, input_shape, norm_axes=1, beta=0.99999, per_element_update=False, epsilon=1e-5, device=torch.device("cpu")): 12 | super(ValueNorm, self).__init__() 13 | 14 | self.input_shape = input_shape 15 | self.norm_axes = norm_axes 16 | self.epsilon = epsilon 17 | self.beta = beta 18 | self.per_element_update = per_element_update 19 | self.tpdv = dict(dtype=torch.float32, device=device) 20 | 21 | self.running_mean = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv) 22 | self.running_mean_sq = nn.Parameter(torch.zeros(input_shape), requires_grad=False).to(**self.tpdv) 23 | self.debiasing_term = nn.Parameter(torch.tensor(0.0), requires_grad=False).to(**self.tpdv) 24 | 25 | self.reset_parameters() 26 | 27 | def reset_parameters(self): 28 | self.running_mean.zero_() 29 | self.running_mean_sq.zero_() 30 | self.debiasing_term.zero_() 31 | 32 | def running_mean_var(self): 33 | debiased_mean = self.running_mean / self.debiasing_term.clamp(min=self.epsilon) 34 | debiased_mean_sq = self.running_mean_sq / self.debiasing_term.clamp(min=self.epsilon) 35 | debiased_var = (debiased_mean_sq - debiased_mean ** 2).clamp(min=1e-2) 36 | return debiased_mean, debiased_var 37 | 38 | @torch.no_grad() 39 | def update(self, input_vector): 40 | if type(input_vector) == np.ndarray: 41 | input_vector = torch.from_numpy(input_vector) 42 | input_vector = input_vector.to(**self.tpdv) 43 | 44 | batch_mean = input_vector.mean(dim=tuple(range(self.norm_axes))) 45 | batch_sq_mean = (input_vector ** 2).mean(dim=tuple(range(self.norm_axes))) 46 | 47 | if self.per_element_update: 48 | batch_size = np.prod(input_vector.size()[:self.norm_axes]) 49 | weight = self.beta ** batch_size 50 | else: 51 | weight = self.beta 52 | 53 | self.running_mean.mul_(weight).add_(batch_mean * (1.0 - weight)) 54 | self.running_mean_sq.mul_(weight).add_(batch_sq_mean * (1.0 - weight)) 55 | self.debiasing_term.mul_(weight).add_(1.0 * (1.0 - weight)) 56 | 57 | def normalize(self, input_vector): 58 | # Make sure input is float32 59 | if type(input_vector) == np.ndarray: 60 | input_vector = torch.from_numpy(input_vector) 61 | input_vector = input_vector.to(**self.tpdv) 62 | 63 | mean, var = self.running_mean_var() 64 | out = (input_vector - mean[(None,) * self.norm_axes]) / torch.sqrt(var)[(None,) * self.norm_axes] 65 | 66 | return out 67 | 68 | def denormalize(self, input_vector): 69 | """ Transform normalized data back into original distribution """ 70 | if type(input_vector) == np.ndarray: 71 | input_vector = 
torch.from_numpy(input_vector) 72 | input_vector = input_vector.to(**self.tpdv) 73 | 74 | mean, var = self.running_mean_var() 75 | out = input_vector * torch.sqrt(var)[(None,) * self.norm_axes] + mean[(None,) * self.norm_axes] 76 | 77 | out = out.cpu().numpy() 78 | 79 | return out 80 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.9.0 2 | aiohttp==3.6.2 3 | aioredis==1.3.1 4 | astor==0.8.0 5 | astunparse==1.6.3 6 | async-timeout==3.0.1 7 | atari-py==0.2.6 8 | atomicwrites==1.2.1 9 | attrs==18.2.0 10 | beautifulsoup4==4.9.1 11 | blessings==1.7 12 | cachetools==4.1.1 13 | certifi==2020.4.5.2 14 | cffi==1.14.1 15 | chardet==3.0.4 16 | click==7.1.2 17 | cloudpickle==1.3.0 18 | colorama==0.4.3 19 | colorful==0.5.4 20 | configparser==5.0.1 21 | contextvars==2.4 22 | cycler==0.10.0 23 | Cython==0.29.21 24 | deepdiff==4.3.2 25 | dill==0.3.2 26 | docker-pycreds==0.4.0 27 | docopt==0.6.2 28 | fasteners==0.15 29 | filelock==3.0.12 30 | funcsigs==1.0.2 31 | future==0.16.0 32 | gast==0.2.2 33 | gin==0.1.6 34 | gin-config==0.3.0 35 | gitdb==4.0.5 36 | GitPython==3.1.9 37 | glfw==1.12.0 38 | google==3.0.0 39 | google-api-core==1.22.1 40 | google-auth==1.21.0 41 | google-auth-oauthlib==0.4.1 42 | google-pasta==0.2.0 43 | googleapis-common-protos==1.52.0 44 | gpustat==0.6.0 45 | gql==0.2.0 46 | graphql-core==1.1 47 | grpcio==1.31.0 48 | gym==0.17.2 49 | h5py==2.10.0 50 | hiredis==1.1.0 51 | idna==2.7 52 | idna-ssl==1.1.0 53 | imageio==2.4.1 54 | immutables==0.14 55 | importlib-metadata==1.7.0 56 | joblib==0.16.0 57 | jsonnet==0.16.0 58 | jsonpickle==0.9.6 59 | jsonschema==3.2.0 60 | Keras-Applications==1.0.8 61 | Keras-Preprocessing==1.1.2 62 | kiwisolver==1.0.1 63 | lockfile==0.12.2 64 | Markdown==3.1.1 65 | matplotlib==3.0.0 66 | mkl-fft==1.1.0 67 | mkl-random==1.1.1 68 | mkl-service==2.3.0 69 | mock==2.0.0 70 | monotonic==1.5 71 | more-itertools==4.3.0 72 | mpi4py==3.0.3 73 | mpyq==0.2.5 74 | msgpack==1.0.0 75 | mujoco-py==2.0.2.13 76 | multidict==4.7.6 77 | munch==2.3.2 78 | numpy==1.18.5 79 | nvidia-ml-py3==7.352.0 80 | oauthlib==3.1.0 81 | opencensus==0.7.10 82 | opencensus-context==0.1.1 83 | opencv-python==4.2.0.34 84 | opt-einsum==3.1.0 85 | ordered-set==4.0.2 86 | packaging==20.4 87 | pandas==1.1.1 88 | pathlib2==2.3.2 89 | pathtools==0.1.2 90 | pbr==4.3.0 91 | Pillow==5.3.0 92 | pluggy==0.7.1 93 | portpicker==1.2.0 94 | probscale==0.2.3 95 | progressbar2==3.53.1 96 | prometheus-client==0.8.0 97 | promise==2.3 98 | protobuf==3.12.3 99 | psutil==5.7.2 100 | py==1.6.0 101 | py-spy==0.3.3 102 | pyasn1==0.4.8 103 | pyasn1-modules==0.2.8 104 | pycparser==2.20 105 | pygame==1.9.4 106 | pyglet==1.5.0 107 | PyOpenGL==3.1.5 108 | PyOpenGL-accelerate==3.1.5 109 | pyparsing==2.2.2 110 | pyrsistent==0.16.0 111 | PySC2==3.0.0 112 | pytest==3.8.2 113 | python-dateutil==2.7.3 114 | python-utils==2.4.0 115 | pytz==2020.1 116 | PyYAML==3.13 117 | pyzmq==19.0.2 118 | ray==0.8.0 119 | redis==3.4.1 120 | requests==2.24.0 121 | requests-oauthlib==1.3.0 122 | rsa==4.6 123 | s2clientprotocol==4.10.1.75800.0 124 | s2protocol==4.11.4.78285.0 125 | sacred==0.7.2 126 | scipy==1.4.1 127 | seaborn==0.10.1 128 | sentry-sdk==0.18.0 129 | setproctitle==1.1.10 130 | shortuuid==1.0.1 131 | six==1.15.0 132 | sk-video==1.1.10 133 | smmap==3.0.4 134 | snakeviz==1.0.0 135 | soupsieve==2.0.1 136 | subprocess32==3.5.4 137 | tabulate==0.8.7 138 | tensorboard==2.0.2 139 | 
tensorboard-logger==0.1.0 140 | tensorboard-plugin-wit==1.7.0 141 | tensorboardX==2.0 142 | tensorflow==2.0.0 143 | tensorflow-estimator==2.0.0 144 | termcolor==1.1.0 145 | torch==1.5.1+cu101 146 | torchvision==0.6.1+cu101 147 | tornado==5.1.1 148 | tqdm==4.48.2 149 | typing-extensions==3.7.4.3 150 | urllib3==1.23 151 | wandb==0.10.5 152 | watchdog==0.10.3 153 | websocket-client==0.53.0 154 | Werkzeug==0.16.1 155 | whichcraft==0.5.2 156 | wrapt==1.12.1 157 | xmltodict==0.12.0 158 | yarl==1.5.1 159 | zipp==3.1.0 160 | zmq==0.0.0 161 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | import os 5 | from setuptools import setup, find_packages 6 | import setuptools 7 | 8 | def get_version() -> str: 9 | # https://packaging.python.org/guides/single-sourcing-package-version/ 10 | init = open(os.path.join("onpolicy", "__init__.py"), "r").read().split() 11 | return init[init.index("__version__") + 2][1:-1] 12 | 13 | setup( 14 | name="onpolicy", # Replace with your own username 15 | version=get_version(), 16 | description="on-policy algorithms of marlbenchmark", 17 | long_description=open("README.md", encoding="utf8").read(), 18 | long_description_content_type="text/markdown", 19 | author="yuchao", 20 | author_email="zoeyuchao@gmail.com", 21 | packages=setuptools.find_packages(), 22 | classifiers=[ 23 | "Development Status :: 3 - Alpha", 24 | "Intended Audience :: Science/Research", 25 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 26 | "Topic :: Software Development :: Libraries :: Python Modules", 27 | "Programming Language :: Python :: 3", 28 | "License :: OSI Approved :: MIT License", 29 | "Operating System :: OS Independent", 30 | ], 31 | keywords="multi-agent reinforcement learning platform pytorch", 32 | python_requires='>=3.6', 33 | ) 34 | --------------------------------------------------------------------------------