├── .gitignore ├── LICENSE ├── README.md ├── clusters_po.yaml ├── clusters_po_with_push.yaml ├── clusters_po_with_push_separate_colors.yaml ├── clusters_po_with_push_separate_colors_units.yaml ├── clusters_po_with_push_units.yaml ├── conditional_action_trees ├── __init__.py ├── conditional_action_exploration.py ├── conditional_action_mixin.py └── conditional_action_policy_trainer.py ├── images ├── Flat_4.gif ├── M_2.gif └── Ma_4.gif ├── plots ├── plot_baseline_results.py └── plots_training.pdf ├── requirements.txt ├── rllib_baseline.py ├── rllib_baseline_flat.py └── rllib_conditional_actions.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # Wandb 132 | wandb/ 133 | 134 | # PyCharm 135 | .idea/ 136 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Chris Bamford 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Conditional Action Trees 2 | 3 | ![M_level_2](images/M_2.gif) 4 | ![Flat_level_4](images/Flat_4.gif) 5 | ![Ma_level_4](images/Ma_4.gif) 6 | 7 | ## arXiv Paper: https://arxiv.org/abs/2104.07294 8 | 9 | ## Abstract 10 | 11 | There are relatively few conventions followed in reinforcement learning (RL) environments to structure the action spaces. As a consequence, the application of RL algorithms to tasks with large action spaces comprising multiple components requires additional effort to adapt to different formats. In this paper we introduce `Conditional Action Trees` with two main objectives: (1) as a method of structuring action spaces in RL to generalise across several action space specifications, and (2) to formalise a process to significantly reduce the action space by decomposing it into multiple sub-spaces, favouring a multi-staged decision making approach. We show several proof-of-concept experiments validating our scheme, ranging from environments with basic discrete action spaces to those with large combinatorial action spaces commonly found in RTS-style games. 12 | 13 | ### Join the Discord Community!
[https://discord.gg/xuR8Dsv](https://discord.gg/xuR8Dsv) 14 | 15 | ## Install Dependencies for the experiments 16 | 17 | First navigate to this directory, then: 18 | 19 | ``` 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | ## :warning: RLlib < 1.4.0 :warning: 24 | 25 | The current 1.3.0 release of RLlib has some bugs that are fixed on the latest RLlib master branch; install the nightly wheel instead: 26 | 27 | `pip install -U https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-2.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl` 28 | 29 | ### Install PyTorch 30 | 31 | Instructions for installing PyTorch can be found [here](https://pytorch.org/get-started/locally/) 32 | 33 | 34 | ### WandB Integration 35 | 36 | To upload the results to your own WandB account, create a `.wandb_rc` file in your user directory that contains your WandB API key. 37 | All results and videos will then be automatically uploaded. 38 | 39 | ## Running experiments 40 | 41 | You can copy any of the following lines to run any of the experiments in the paper.
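As a rough intuition for the mechanism these experiments compare, here is a hypothetical, self-contained sketch of multi-stage action masking. This is not the repository's implementation (that lives in `conditional_action_trees/` and Griddly's RLlib integration); the names `build_mask` and `tree` are invented for illustration only.

```python
# Toy two-stage conditional action tree (illustrative names only).
# At each decision stage, only the children of the node chosen at the
# previous stage are unmasked, so each sub-space stays small.

def build_mask(valid_ids, space_size):
    """Return a 0/1 mask over a flat discrete sub-space,
    unmasking only the ids valid at this stage."""
    return [1 if i in valid_ids else 0 for i in range(space_size)]

# Stage 1 chooses an action type, stage 2 an action id.
# Type 0 ("move") allows ids {0, 1, 2}; type 1 ("push") allows only {0}.
tree = {0: {0, 1, 2}, 1: {0}}

type_mask = build_mask(set(tree), space_size=2)        # -> [1, 1]
chosen_type = 1                                        # e.g. sampled "push"
id_mask = build_mask(tree[chosen_type], space_size=3)  # -> [1, 0, 0]
```

In the actual experiments, masks constructed this way at each tree depth are applied to the policy logits before sampling, as described in the Griddly + RLLib section below.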
42 | 43 | #### No Masking 44 | 45 | ``` 46 | python rllib_baseline.py --experiment-name="M" --yaml-file="clusters_po.yaml" 47 | python rllib_baseline.py --experiment-name="MP" --yaml-file="clusters_po_with_push.yaml" 48 | python rllib_baseline.py --experiment-name="MPS" --yaml-file="clusters_po_with_push_separate_colors.yaml" 49 | python rllib_baseline.py --experiment-name="Ma" --yaml-file="clusters_po_with_push_units.yaml" 50 | python rllib_baseline.py --experiment-name="MSa" --yaml-file="clusters_po_with_push_separate_colors_units.yaml" 51 | ``` 52 | 53 | #### Depth-2 54 | 55 | ``` 56 | python rllib_baseline_flat.py --experiment-name="M" --yaml-file="clusters_po.yaml" 57 | python rllib_baseline_flat.py --experiment-name="MP" --yaml-file="clusters_po_with_push.yaml" 58 | python rllib_baseline_flat.py --experiment-name="MPS" --yaml-file="clusters_po_with_push_separate_colors.yaml" 59 | python rllib_baseline_flat.py --experiment-name="Ma" --yaml-file="clusters_po_with_push_units.yaml" 60 | python rllib_baseline_flat.py --experiment-name="MSa" --yaml-file="clusters_po_with_push_separate_colors_units.yaml" 61 | ``` 62 | 63 | #### CAT_CL + CAT_CD 64 | 65 | Both runs in these experiments execute consecutively using Ray's `grid_search` method. 66 | 67 | ``` 68 | python rllib_conditional_actions.py --experiment-name="M" --yaml-file="clusters_po.yaml" 69 | python rllib_conditional_actions.py --experiment-name="MP" --yaml-file="clusters_po_with_push.yaml" 70 | python rllib_conditional_actions.py --experiment-name="MPS" --yaml-file="clusters_po_with_push_separate_colors.yaml" 71 | python rllib_conditional_actions.py --experiment-name="Ma" --yaml-file="clusters_po_with_push_units.yaml" 72 | python rllib_conditional_actions.py --experiment-name="MSa" --yaml-file="clusters_po_with_push_separate_colors_units.yaml" 73 | ``` 74 | 75 | 76 | ## Griddly + RLLib 77 | 78 | The experiments are performed using several custom RLLib classes: 79 | 80 | ### 
[ConditionalActionImpalaTrainer](https://github.com/Bam4d/Griddly/blob/develop/python/griddly/util/rllib/torch/conditional_actions/conditional_action_policy_trainer.py#L119) 81 | 82 | Contains the code for setting up the mixin and the modified VTrace policy. 83 | 84 | ### [ConditionalActionMixin](https://github.com/Bam4d/Griddly/blob/develop/python/griddly/util/rllib/torch/conditional_actions/conditional_action_mixin.py) 85 | 86 | Overrides the typical policy rollout method to use the Conditional Action Trees when sampling actions. 87 | 88 | ### [ConditionalActionVTraceTorchPolicy](https://github.com/Bam4d/Griddly/blob/develop/python/griddly/util/rllib/torch/conditional_actions/conditional_action_policy_trainer.py#L104) 89 | 90 | Applies the constructed masks to the VTrace policy. 91 | 92 | ### [TorchConditionalMaskingExploration](https://github.com/Bam4d/Griddly/blob/develop/python/griddly/util/rllib/torch/conditional_actions/conditional_action_exploration.py) 93 | 94 | Contains the tree traversal and mask creation code. 95 | 96 | ## Environments 97 | 98 | The five environments used in the paper are included in this repository, with filenames of the form `clusters_po....yaml` 99 | 100 | They are all based on the `Clusters` environment, which is fully documented [here](https://griddly.readthedocs.io/en/latest/games/Clusters/index.html) 101 | 102 | ## WandB Results 103 | 104 | View all of the experiments, training results and videos [here](https://wandb.ai/chrisbam4d/conditional_action_trees) 105 | 106 | -------------------------------------------------------------------------------- /clusters_po.yaml: -------------------------------------------------------------------------------- 1 | Version: "0.1" 2 | Environment: 3 | Name: Partially Observable Clusters 4 | Description: Cluster the coloured objects together by pushing them against the static coloured blocks.
5 | Observers: 6 | Sprite2D: 7 | TileSize: 24 8 | BackgroundTile: oryx/oryx_fantasy/floor1-2.png 9 | Variables: 10 | - Name: box_count 11 | InitialValue: 0 12 | Player: 13 | Observer: 14 | RotateWithAvatar: true 15 | TrackAvatar: true 16 | Height: 5 17 | Width: 5 18 | OffsetX: 0 19 | OffsetY: 2 20 | AvatarObject: avatar # The player can only control a single avatar in the game 21 | Termination: 22 | Win: 23 | - eq: [box_count, 0] 24 | Lose: 25 | - eq: [broken_box:count, 1] 26 | - eq: [avatar:count, 0] 27 | Levels: 28 | - | 29 | w w w w w w w w w w w w w 30 | w . . . . . . . . . . . w 31 | w . . 1 1 . . . 2 . 2 . w 32 | w . . . . 1 . . . . . . w 33 | w . . . a . . . . . 2 . w 34 | w . . . . . . . h . . . w 35 | w . . . . 1 . . . . b . w 36 | w . . . . . . 1 . . . . w 37 | w . . . . . . . . A . . w 38 | w w w w w w w w w w w w w 39 | - | 40 | w w w w w w w w w w w w w 41 | w . . . . . . . . . . . w 42 | w . . 1 . . 2 . c 3 . . w 43 | w . . . . h . . h . . . w 44 | w . . . 2 . . 3 . . 1 . w 45 | w . . . . b . . h . . . w 46 | w . . 3 . . . 2 . . 1 . w 47 | w . . h . h . . . a . . w 48 | w . . . . . A . . . . . w 49 | w w w w w w w w w w w w w 50 | - | 51 | w w w w w w w w w w w w w 52 | w . . a . . b . . c . . w 53 | w . . . . . . . . . . . w 54 | w . . . . . . . . . . . w 55 | w h h h h h . h h h h h w 56 | w . . . . h . h . . . . w 57 | w . 1 2 . h . h . 1 3 . w 58 | w . 3 . . . . . . . 2 . w 59 | w . . . . . A . . . . . w 60 | w w w w w w w w w w w w w 61 | - | 62 | w w w w w w w w w w w w w 63 | w . . . . . . . . . . . w 64 | w . . . 1 . 2 . . c . . w 65 | w . . . . . 3 . . 3 . . w 66 | w . . a . 2 . . . h . . w 67 | w . . . . h h . 3 . . . w 68 | w . . 1 . . . . . 2 . . w 69 | w . . . . . 1 . . b . . w 70 | w . . . . . A . . . . . w 71 | w w w w w w w w w w w w w 72 | - | 73 | w w w w w w w w w w w w w 74 | w . . . . . . . . . . . w 75 | w . . . . . . 1 . . . . w 76 | w . . h . . b . . h . . w 77 | w . . . . 1 . . . . . . w 78 | w . . 3 . . . . 2 . . . w 79 | w . 
. . a . h . . c . . w 80 | w . . . . 3 . . . . 2 . w 81 | w . . . . . A . . . . . w 82 | w w w w w w w w w w w w w 83 | 84 | Actions: 85 | 86 | # A simple action to count the number of boxes in the game at the start 87 | # Not currently a way to do complex things in termination conditions like combine multiple conditions 88 | - Name: box_counter 89 | InputMapping: 90 | Internal: true 91 | Inputs: 92 | 1: 93 | Description: "The only action here is to increment the box count" 94 | Behaviours: 95 | - Src: 96 | Object: [blue_box, red_box, green_box] 97 | Commands: 98 | - incr: box_count 99 | Dst: 100 | Object: [blue_box, red_box, green_box] 101 | 102 | # Define the move action 103 | - Name: move 104 | InputMapping: 105 | Inputs: 106 | 1: 107 | Description: Rotate left 108 | OrientationVector: [-1, 0] 109 | 2: 110 | Description: Move forwards 111 | OrientationVector: [0, -1] 112 | VectorToDest: [0, -1] 113 | 3: 114 | Description: Rotate right 115 | OrientationVector: [1, 0] 116 | Relative: true 117 | Behaviours: 118 | 119 | # Avatar rotates 120 | - Src: 121 | Object: avatar 122 | Commands: 123 | - rot: _dir 124 | Dst: 125 | Object: avatar 126 | 127 | # Avatar and boxes can move into empty space 128 | - Src: 129 | Object: [avatar, blue_box, green_box, red_box] 130 | Commands: 131 | - mov: _dest 132 | Dst: 133 | Object: _empty 134 | 135 | # Boxes can be pushed by the avatar 136 | - Src: 137 | Object: avatar 138 | Commands: 139 | - mov: _dest 140 | Dst: 141 | Object: [blue_box, green_box, red_box] 142 | Commands: 143 | - cascade: _dest 144 | 145 | # When boxes are pushed against the blocks they change 146 | - Src: 147 | Object: blue_box 148 | Commands: 149 | - change_to: blue_block 150 | - reward: 1 151 | - decr: box_count 152 | Dst: 153 | Object: blue_block 154 | - Src: 155 | Object: red_box 156 | Commands: 157 | - reward: 1 158 | - change_to: red_block 159 | - decr: box_count 160 | Dst: 161 | Object: red_block 162 | - Src: 163 | Object: green_box 164 | Commands: 165 | - 
reward: 1 166 | - change_to: green_block 167 | - decr: box_count 168 | Dst: 169 | Object: green_block 170 | 171 | # Boxes break if they hit the spikes 172 | - Src: 173 | Object: [blue_box, green_box, red_box] 174 | Commands: 175 | - change_to: broken_box 176 | - reward: -1 177 | Dst: 178 | Object: spike 179 | 180 | # Avatar dies if it hits the spikes 181 | - Src: 182 | Object: avatar 183 | Commands: 184 | - remove: true 185 | - reward: -1 186 | Dst: 187 | Object: spike 188 | 189 | Objects: 190 | - Name: avatar 191 | MapCharacter: A 192 | Observers: 193 | Sprite2D: 194 | - Image: gvgai/oryx/knight1.png 195 | Block2D: 196 | - Shape: triangle 197 | Color: [0.0, 1.0, 0.0] 198 | Scale: 0.8 199 | 200 | - Name: wall 201 | MapCharacter: w 202 | Observers: 203 | Sprite2D: 204 | - TilingMode: WALL_16 205 | Image: 206 | - oryx/oryx_fantasy/wall1-0.png 207 | - oryx/oryx_fantasy/wall1-1.png 208 | - oryx/oryx_fantasy/wall1-2.png 209 | - oryx/oryx_fantasy/wall1-3.png 210 | - oryx/oryx_fantasy/wall1-4.png 211 | - oryx/oryx_fantasy/wall1-5.png 212 | - oryx/oryx_fantasy/wall1-6.png 213 | - oryx/oryx_fantasy/wall1-7.png 214 | - oryx/oryx_fantasy/wall1-8.png 215 | - oryx/oryx_fantasy/wall1-9.png 216 | - oryx/oryx_fantasy/wall1-10.png 217 | - oryx/oryx_fantasy/wall1-11.png 218 | - oryx/oryx_fantasy/wall1-12.png 219 | - oryx/oryx_fantasy/wall1-13.png 220 | - oryx/oryx_fantasy/wall1-14.png 221 | - oryx/oryx_fantasy/wall1-15.png 222 | Block2D: 223 | - Shape: square 224 | Color: [0.5, 0.5, 0.5] 225 | Scale: 0.9 226 | 227 | - Name: spike 228 | MapCharacter: h 229 | Observers: 230 | Sprite2D: 231 | - Image: gvgai/oryx/spike2.png 232 | Block2D: 233 | - Shape: triangle 234 | Color: [0.9, 0.1, 0.1] 235 | Scale: 0.5 236 | 237 | - Name: red_box 238 | MapCharacter: "2" 239 | InitialActions: 240 | - Action: box_counter 241 | ActionId: 1 242 | Observers: 243 | Sprite2D: 244 | - Image: gvgai/newset/blockR.png 245 | Block2D: 246 | - Shape: square 247 | Color: [0.5, 0.2, 0.2] 248 | Scale: 0.5 249 | - 
Name: red_block 250 | MapCharacter: b 251 | Observers: 252 | Sprite2D: 253 | - Image: gvgai/newset/blockR2.png 254 | Block2D: 255 | - Shape: square 256 | Color: [1.0, 0.0, 0.0] 257 | Scale: 1.0 258 | 259 | - Name: green_box 260 | MapCharacter: "3" 261 | InitialActions: 262 | - Action: box_counter 263 | ActionId: 1 264 | Observers: 265 | Sprite2D: 266 | - Image: gvgai/newset/blockG.png 267 | Block2D: 268 | - Shape: square 269 | Color: [0.2, 0.5, 0.2] 270 | Scale: 0.5 271 | - Name: green_block 272 | MapCharacter: c 273 | Observers: 274 | Sprite2D: 275 | - Image: gvgai/newset/blockG2.png 276 | Block2D: 277 | - Shape: square 278 | Color: [0.0, 1.0, 0.0] 279 | Scale: 1.0 280 | 281 | - Name: blue_box 282 | MapCharacter: "1" 283 | InitialActions: 284 | - Action: box_counter 285 | ActionId: 1 286 | Observers: 287 | Sprite2D: 288 | - Image: gvgai/newset/blockB.png 289 | Block2D: 290 | - Shape: square 291 | Color: [0.2, 0.2, 0.5] 292 | Scale: 0.5 293 | - Name: blue_block 294 | MapCharacter: a 295 | Observers: 296 | Sprite2D: 297 | - Image: gvgai/newset/blockB2.png 298 | Block2D: 299 | - Shape: square 300 | Color: [0.0, 0.0, 1.0] 301 | Scale: 1.0 302 | 303 | - Name: broken_box 304 | Observers: 305 | Sprite2D: 306 | - Image: gvgai/newset/block3.png 307 | Block2D: 308 | - Shape: triangle 309 | Color: [1.0, 0.0, 1.0] 310 | Scale: 1.0 311 | -------------------------------------------------------------------------------- /clusters_po_with_push.yaml: -------------------------------------------------------------------------------- 1 | Version: "0.1" 2 | Environment: 3 | Name: Partially Observable Clusters 4 | Description: Cluster the coloured objects together by pushing them against the static coloured blocks. 
5 | Observers: 6 | Sprite2D: 7 | TileSize: 24 8 | BackgroundTile: oryx/oryx_fantasy/floor1-2.png 9 | Variables: 10 | - Name: box_count 11 | InitialValue: 0 12 | Player: 13 | Observer: 14 | RotateWithAvatar: true 15 | TrackAvatar: true 16 | Height: 5 17 | Width: 5 18 | OffsetX: 0 19 | OffsetY: 2 20 | AvatarObject: avatar # The player can only control a single avatar in the game 21 | Termination: 22 | Win: 23 | - eq: [box_count, 0] 24 | Lose: 25 | - eq: [broken_box:count, 1] 26 | - eq: [avatar:count, 0] 27 | Levels: 28 | - | 29 | w w w w w w w w w w w w w 30 | w . . . . . . . . . . . w 31 | w . . 1 1 . . . 2 . 2 . w 32 | w . . . . 1 . . . . . . w 33 | w . . . a . . . . . 2 . w 34 | w . . . . . . . h . . . w 35 | w . . . . 1 . . . . b . w 36 | w . . . . . . 1 . . . . w 37 | w . . . . . . . . A . . w 38 | w w w w w w w w w w w w w 39 | - | 40 | w w w w w w w w w w w w w 41 | w . . . . . . . . . . . w 42 | w . . 1 . . 2 . c 3 . . w 43 | w . . . . h . . h . . . w 44 | w . . . 2 . . 3 . . 1 . w 45 | w . . . . b . . h . . . w 46 | w . . 3 . . . 2 . . 1 . w 47 | w . . h . h . . . a . . w 48 | w . . . . . A . . . . . w 49 | w w w w w w w w w w w w w 50 | - | 51 | w w w w w w w w w w w w w 52 | w . . a . . b . . c . . w 53 | w . . . . . . . . . . . w 54 | w . . . . . . . . . . . w 55 | w h h h h h . h h h h h w 56 | w . . . . h . h . . . . w 57 | w . 1 2 . h . h . 1 3 . w 58 | w . 3 . . . . . . . 2 . w 59 | w . . . . . A . . . . . w 60 | w w w w w w w w w w w w w 61 | - | 62 | w w w w w w w w w w w w w 63 | w . . . . . . . . . . . w 64 | w . . . 1 . 2 . . c . . w 65 | w . . . . . 3 . . 3 . . w 66 | w . . a . 2 . . . h . . w 67 | w . . . . h h . 3 . . . w 68 | w . . 1 . . . . . 2 . . w 69 | w . . . . . 1 . . b . . w 70 | w . . . . . A . . . . . w 71 | w w w w w w w w w w w w w 72 | - | 73 | w w w w w w w w w w w w w 74 | w . . . . . . . . . . . w 75 | w . . . . . . 1 . . . . w 76 | w . . h . . b . . h . . w 77 | w . . . . 1 . . . . . . w 78 | w . . 3 . . . . 2 . . . w 79 | w . 
. . a . h . . c . . w 80 | w . . . . 3 . . . . 2 . w 81 | w . . . . . A . . . . . w 82 | w w w w w w w w w w w w w 83 | 84 | Actions: 85 | 86 | # A simple action to count the number of boxes in the game at the start 87 | # Not currently a way to do complex things in termination conditions like combine multiple conditions 88 | - Name: box_counter 89 | InputMapping: 90 | Internal: true 91 | Inputs: 92 | 1: 93 | Description: "The only action here is to increment the box count" 94 | Behaviours: 95 | - Src: 96 | Object: [blue_box, red_box, green_box] 97 | Commands: 98 | - incr: box_count 99 | Dst: 100 | Object: [blue_box, red_box, green_box] 101 | 102 | # Define the move action 103 | - Name: move 104 | InputMapping: 105 | Inputs: 106 | 1: 107 | Description: Rotate left 108 | OrientationVector: [-1, 0] 109 | 2: 110 | Description: Move forwards 111 | OrientationVector: [0, -1] 112 | VectorToDest: [0, -1] 113 | 3: 114 | Description: Rotate right 115 | OrientationVector: [1, 0] 116 | Relative: true 117 | Behaviours: 118 | 119 | # Avatar rotates 120 | - Src: 121 | Object: avatar 122 | Commands: 123 | - rot: _dir 124 | Dst: 125 | Object: avatar 126 | 127 | # Avatar can move into empty space 128 | - Src: 129 | Object: avatar 130 | Commands: 131 | - mov: _dest 132 | Dst: 133 | Object: _empty 134 | 135 | # Avatar dies if it hits the spikes 136 | - Src: 137 | Object: avatar 138 | Commands: 139 | - remove: true 140 | - reward: -1 141 | Dst: 142 | Object: spike 143 | 144 | 145 | - Name: push 146 | InputMapping: 147 | Inputs: 148 | 1: 149 | Description: Push Forwards 150 | OrientationVector: [ 0, -1 ] 151 | VectorToDest: [ 0, -1 ] 152 | Relative: true 153 | Behaviours: 154 | 155 | # Boxes can be pushed by the avatar 156 | - Src: 157 | Object: avatar 158 | Commands: 159 | - mov: _dest 160 | Dst: 161 | Object: [blue_box, green_box, red_box] 162 | Commands: 163 | - cascade: _dest 164 | 165 | # Boxes break if they hit the spikes 166 | - Src: 167 | Object: [ blue_box, green_box, red_box 
] 168 | Commands: 169 | - change_to: broken_box 170 | - reward: -1 171 | Dst: 172 | Object: spike 173 | 174 | # Boxes can be pushed into empty space 175 | - Src: 176 | Object: [blue_box, green_box, red_box] 177 | Commands: 178 | - mov: _dest 179 | Dst: 180 | Object: _empty 181 | 182 | # When boxes are pushed against the blocks they change 183 | - Src: 184 | Object: blue_box 185 | Commands: 186 | - change_to: blue_block 187 | - reward: 1 188 | - decr: box_count 189 | Dst: 190 | Object: blue_block 191 | - Src: 192 | Object: red_box 193 | Commands: 194 | - reward: 1 195 | - change_to: red_block 196 | - decr: box_count 197 | Dst: 198 | Object: red_block 199 | - Src: 200 | Object: green_box 201 | Commands: 202 | - reward: 1 203 | - change_to: green_block 204 | - decr: box_count 205 | Dst: 206 | Object: green_block 207 | 208 | 209 | Objects: 210 | - Name: avatar 211 | MapCharacter: A 212 | Observers: 213 | Sprite2D: 214 | - Image: gvgai/oryx/knight1.png 215 | Block2D: 216 | - Shape: triangle 217 | Color: [0.0, 1.0, 0.0] 218 | Scale: 0.8 219 | 220 | - Name: wall 221 | MapCharacter: w 222 | Observers: 223 | Sprite2D: 224 | - TilingMode: WALL_16 225 | Image: 226 | - oryx/oryx_fantasy/wall1-0.png 227 | - oryx/oryx_fantasy/wall1-1.png 228 | - oryx/oryx_fantasy/wall1-2.png 229 | - oryx/oryx_fantasy/wall1-3.png 230 | - oryx/oryx_fantasy/wall1-4.png 231 | - oryx/oryx_fantasy/wall1-5.png 232 | - oryx/oryx_fantasy/wall1-6.png 233 | - oryx/oryx_fantasy/wall1-7.png 234 | - oryx/oryx_fantasy/wall1-8.png 235 | - oryx/oryx_fantasy/wall1-9.png 236 | - oryx/oryx_fantasy/wall1-10.png 237 | - oryx/oryx_fantasy/wall1-11.png 238 | - oryx/oryx_fantasy/wall1-12.png 239 | - oryx/oryx_fantasy/wall1-13.png 240 | - oryx/oryx_fantasy/wall1-14.png 241 | - oryx/oryx_fantasy/wall1-15.png 242 | Block2D: 243 | - Shape: square 244 | Color: [0.5, 0.5, 0.5] 245 | Scale: 0.9 246 | 247 | - Name: spike 248 | MapCharacter: h 249 | Observers: 250 | Sprite2D: 251 | - Image: gvgai/oryx/spike2.png 252 | Block2D: 253 
| - Shape: triangle 254 | Color: [0.9, 0.1, 0.1] 255 | Scale: 0.5 256 | 257 | - Name: red_box 258 | MapCharacter: "2" 259 | InitialActions: 260 | - Action: box_counter 261 | ActionId: 1 262 | Observers: 263 | Sprite2D: 264 | - Image: gvgai/newset/blockR.png 265 | Block2D: 266 | - Shape: square 267 | Color: [0.5, 0.2, 0.2] 268 | Scale: 0.5 269 | - Name: red_block 270 | MapCharacter: b 271 | Observers: 272 | Sprite2D: 273 | - Image: gvgai/newset/blockR2.png 274 | Block2D: 275 | - Shape: square 276 | Color: [1.0, 0.0, 0.0] 277 | Scale: 1.0 278 | 279 | - Name: green_box 280 | MapCharacter: "3" 281 | InitialActions: 282 | - Action: box_counter 283 | ActionId: 1 284 | Observers: 285 | Sprite2D: 286 | - Image: gvgai/newset/blockG.png 287 | Block2D: 288 | - Shape: square 289 | Color: [0.2, 0.5, 0.2] 290 | Scale: 0.5 291 | - Name: green_block 292 | MapCharacter: c 293 | Observers: 294 | Sprite2D: 295 | - Image: gvgai/newset/blockG2.png 296 | Block2D: 297 | - Shape: square 298 | Color: [0.0, 1.0, 0.0] 299 | Scale: 1.0 300 | 301 | - Name: blue_box 302 | MapCharacter: "1" 303 | InitialActions: 304 | - Action: box_counter 305 | ActionId: 1 306 | Observers: 307 | Sprite2D: 308 | - Image: gvgai/newset/blockB.png 309 | Block2D: 310 | - Shape: square 311 | Color: [0.2, 0.2, 0.5] 312 | Scale: 0.5 313 | - Name: blue_block 314 | MapCharacter: a 315 | Observers: 316 | Sprite2D: 317 | - Image: gvgai/newset/blockB2.png 318 | Block2D: 319 | - Shape: square 320 | Color: [0.0, 0.0, 1.0] 321 | Scale: 1.0 322 | 323 | - Name: broken_box 324 | Observers: 325 | Sprite2D: 326 | - Image: gvgai/newset/block3.png 327 | Block2D: 328 | - Shape: triangle 329 | Color: [1.0, 0.0, 1.0] 330 | Scale: 1.0 331 | -------------------------------------------------------------------------------- /clusters_po_with_push_separate_colors.yaml: -------------------------------------------------------------------------------- 1 | Version: "0.1" 2 | Environment: 3 | Name: Partially Observable Clusters 4 | Description: 
Cluster the coloured objects together by pushing them against the static coloured blocks. 5 | Observers: 6 | Sprite2D: 7 | TileSize: 24 8 | BackgroundTile: oryx/oryx_fantasy/floor1-2.png 9 | Variables: 10 | - Name: box_count 11 | InitialValue: 0 12 | Player: 13 | Observer: 14 | RotateWithAvatar: true 15 | TrackAvatar: true 16 | Height: 5 17 | Width: 5 18 | OffsetX: 0 19 | OffsetY: 2 20 | AvatarObject: avatar # The player can only control a single avatar in the game 21 | Termination: 22 | Win: 23 | - eq: [box_count, 0] 24 | Lose: 25 | - eq: [broken_box:count, 1] 26 | - eq: [avatar:count, 0] 27 | Levels: 28 | - | 29 | w w w w w w w w w w w w w 30 | w . . . . . . . . . . . w 31 | w . . 1 1 . . . 2 . 2 . w 32 | w . . . . 1 . . . . . . w 33 | w . . . a . . . . . 2 . w 34 | w . . . . . . . h . . . w 35 | w . . . . 1 . . . . b . w 36 | w . . . . . . 1 . . . . w 37 | w . . . . . . . . A . . w 38 | w w w w w w w w w w w w w 39 | - | 40 | w w w w w w w w w w w w w 41 | w . . . . . . . . . . . w 42 | w . . 1 . . 2 . c 3 . . w 43 | w . . . . h . . h . . . w 44 | w . . . 2 . . 3 . . 1 . w 45 | w . . . . b . . h . . . w 46 | w . . 3 . . . 2 . . 1 . w 47 | w . . h . h . . . a . . w 48 | w . . . . . A . . . . . w 49 | w w w w w w w w w w w w w 50 | - | 51 | w w w w w w w w w w w w w 52 | w . . a . . b . . c . . w 53 | w . . . . . . . . . . . w 54 | w . . . . . . . . . . . w 55 | w h h h h h . h h h h h w 56 | w . . . . h . h . . . . w 57 | w . 1 2 . h . h . 1 3 . w 58 | w . 3 . . . . . . . 2 . w 59 | w . . . . . A . . . . . w 60 | w w w w w w w w w w w w w 61 | - | 62 | w w w w w w w w w w w w w 63 | w . . . . . . . . . . . w 64 | w . . . 1 . 2 . . c . . w 65 | w . . . . . 3 . . 3 . . w 66 | w . . a . 2 . . . h . . w 67 | w . . . . h h . 3 . . . w 68 | w . . 1 . . . . . 2 . . w 69 | w . . . . . 1 . . b . . w 70 | w . . . . . A . . . . . w 71 | w w w w w w w w w w w w w 72 | - | 73 | w w w w w w w w w w w w w 74 | w . . . . . . . . . . . w 75 | w . . . . . . 1 . . . . w 76 | w . . 
h . . b . . h . . w 77 | w . . . . 1 . . . . . . w 78 | w . . 3 . . . . 2 . . . w 79 | w . . . a . h . . c . . w 80 | w . . . . 3 . . . . 2 . w 81 | w . . . . . A . . . . . w 82 | w w w w w w w w w w w w w 83 | 84 | Actions: 85 | 86 | # A simple action to count the number of boxes in the game at the start 87 | # Not currently a way to do complex things in termination conditions like combine multiple conditions 88 | - Name: box_counter 89 | InputMapping: 90 | Internal: true 91 | Inputs: 92 | 1: 93 | Description: "The only action here is to increment the box count" 94 | Behaviours: 95 | - Src: 96 | Object: [blue_box, red_box, green_box] 97 | Commands: 98 | - incr: box_count 99 | Dst: 100 | Object: [blue_box, red_box, green_box] 101 | 102 | # Define the move action 103 | - Name: move 104 | InputMapping: 105 | Inputs: 106 | 1: 107 | Description: Rotate left 108 | OrientationVector: [-1, 0] 109 | 2: 110 | Description: Move forwards 111 | OrientationVector: [0, -1] 112 | VectorToDest: [0, -1] 113 | 3: 114 | Description: Rotate right 115 | OrientationVector: [1, 0] 116 | Relative: true 117 | Behaviours: 118 | 119 | # Avatar rotates 120 | - Src: 121 | Object: avatar 122 | Commands: 123 | - rot: _dir 124 | Dst: 125 | Object: avatar 126 | 127 | # Avatar can move into empty space 128 | - Src: 129 | Object: avatar 130 | Commands: 131 | - mov: _dest 132 | Dst: 133 | Object: _empty 134 | 135 | 136 | # Avatar dies if it hits the spikes 137 | - Src: 138 | Object: avatar 139 | Commands: 140 | - remove: true 141 | - reward: -1 142 | Dst: 143 | Object: spike 144 | 145 | 146 | - Name: push_blue 147 | InputMapping: 148 | Inputs: 149 | 1: 150 | Description: Push Blue 151 | OrientationVector: [ 0, -1 ] 152 | VectorToDest: [ 0, -1 ] 153 | Relative: true 154 | Behaviours: 155 | 156 | # Boxes can be pushed by the avatar 157 | - Src: 158 | Object: avatar 159 | Commands: 160 | - mov: _dest 161 | Dst: 162 | Object: blue_box 163 | Commands: 164 | - cascade: _dest 165 | 166 | # Boxes break if 
they are pushed into the spikes 167 | - Src: 168 | Object: blue_box 169 | Commands: 170 | - change_to: broken_box 171 | - reward: -1 172 | Dst: 173 | Object: spike 174 | 175 | # Boxes can be pushed into empty space 176 | - Src: 177 | Object: blue_box 178 | Commands: 179 | - mov: _dest 180 | Dst: 181 | Object: _empty 182 | 183 | # When boxes are pushed against the blocks they change 184 | - Src: 185 | Object: blue_box 186 | Commands: 187 | - change_to: blue_block 188 | - reward: 1 189 | - decr: box_count 190 | Dst: 191 | Object: blue_block 192 | 193 | - Name: push_red 194 | InputMapping: 195 | Inputs: 196 | 1: 197 | Description: Push Red 198 | OrientationVector: [ 0, -1 ] 199 | VectorToDest: [ 0, -1 ] 200 | Relative: true 201 | Behaviours: 202 | 203 | # Boxes can be pushed by the avatar 204 | - Src: 205 | Object: avatar 206 | Commands: 207 | - mov: _dest 208 | Dst: 209 | Object: red_box 210 | Commands: 211 | - cascade: _dest 212 | 213 | # Boxes break if they are pushed into the spikes 214 | - Src: 215 | Object: red_box 216 | Commands: 217 | - change_to: broken_box 218 | - reward: -1 219 | Dst: 220 | Object: spike 221 | 222 | # Boxes can be pushed into empty space 223 | - Src: 224 | Object: red_box 225 | Commands: 226 | - mov: _dest 227 | Dst: 228 | Object: _empty 229 | 230 | # When boxes are pushed against the blocks they change 231 | - Src: 232 | Object: red_box 233 | Commands: 234 | - reward: 1 235 | - change_to: red_block 236 | - decr: box_count 237 | Dst: 238 | Object: red_block 239 | 240 | - Name: push_green 241 | InputMapping: 242 | Inputs: 243 | 1: 244 | Description: Push Green 245 | OrientationVector: [ 0, -1 ] 246 | VectorToDest: [ 0, -1 ] 247 | Relative: true 248 | Behaviours: 249 | 250 | # Boxes can be pushed by the avatar 251 | - Src: 252 | Object: avatar 253 | Commands: 254 | - mov: _dest 255 | Dst: 256 | Object: green_box 257 | Commands: 258 | - cascade: _dest 259 | 260 | # Boxes break if they are pushed into the spikes 261 | - Src: 262 | Object: green_box
263 | Commands: 264 | - change_to: broken_box 265 | - reward: -1 266 | Dst: 267 | Object: spike 268 | 269 | # Boxes can be pushed into empty space 270 | - Src: 271 | Object: green_box 272 | Commands: 273 | - mov: _dest 274 | Dst: 275 | Object: _empty 276 | 277 | # When boxes are pushed against the blocks they change 278 | - Src: 279 | Object: green_box 280 | Commands: 281 | - reward: 1 282 | - change_to: green_block 283 | - decr: box_count 284 | Dst: 285 | Object: green_block 286 | 287 | 288 | Objects: 289 | - Name: avatar 290 | MapCharacter: A 291 | Observers: 292 | Sprite2D: 293 | - Image: gvgai/oryx/knight1.png 294 | Block2D: 295 | - Shape: triangle 296 | Color: [0.0, 1.0, 0.0] 297 | Scale: 0.8 298 | 299 | - Name: wall 300 | MapCharacter: w 301 | Observers: 302 | Sprite2D: 303 | - TilingMode: WALL_16 304 | Image: 305 | - oryx/oryx_fantasy/wall1-0.png 306 | - oryx/oryx_fantasy/wall1-1.png 307 | - oryx/oryx_fantasy/wall1-2.png 308 | - oryx/oryx_fantasy/wall1-3.png 309 | - oryx/oryx_fantasy/wall1-4.png 310 | - oryx/oryx_fantasy/wall1-5.png 311 | - oryx/oryx_fantasy/wall1-6.png 312 | - oryx/oryx_fantasy/wall1-7.png 313 | - oryx/oryx_fantasy/wall1-8.png 314 | - oryx/oryx_fantasy/wall1-9.png 315 | - oryx/oryx_fantasy/wall1-10.png 316 | - oryx/oryx_fantasy/wall1-11.png 317 | - oryx/oryx_fantasy/wall1-12.png 318 | - oryx/oryx_fantasy/wall1-13.png 319 | - oryx/oryx_fantasy/wall1-14.png 320 | - oryx/oryx_fantasy/wall1-15.png 321 | Block2D: 322 | - Shape: square 323 | Color: [0.5, 0.5, 0.5] 324 | Scale: 0.9 325 | 326 | - Name: spike 327 | MapCharacter: h 328 | Observers: 329 | Sprite2D: 330 | - Image: gvgai/oryx/spike2.png 331 | Block2D: 332 | - Shape: triangle 333 | Color: [0.9, 0.1, 0.1] 334 | Scale: 0.5 335 | 336 | - Name: red_box 337 | MapCharacter: "2" 338 | InitialActions: 339 | - Action: box_counter 340 | ActionId: 1 341 | Observers: 342 | Sprite2D: 343 | - Image: gvgai/newset/blockR.png 344 | Block2D: 345 | - Shape: square 346 | Color: [0.5, 0.2, 0.2] 347 | Scale: 0.5 
348 | - Name: red_block 349 | MapCharacter: b 350 | Observers: 351 | Sprite2D: 352 | - Image: gvgai/newset/blockR2.png 353 | Block2D: 354 | - Shape: square 355 | Color: [1.0, 0.0, 0.0] 356 | Scale: 1.0 357 | 358 | - Name: green_box 359 | MapCharacter: "3" 360 | InitialActions: 361 | - Action: box_counter 362 | ActionId: 1 363 | Observers: 364 | Sprite2D: 365 | - Image: gvgai/newset/blockG.png 366 | Block2D: 367 | - Shape: square 368 | Color: [0.2, 0.5, 0.2] 369 | Scale: 0.5 370 | - Name: green_block 371 | MapCharacter: c 372 | Observers: 373 | Sprite2D: 374 | - Image: gvgai/newset/blockG2.png 375 | Block2D: 376 | - Shape: square 377 | Color: [0.0, 1.0, 0.0] 378 | Scale: 1.0 379 | 380 | - Name: blue_box 381 | MapCharacter: "1" 382 | InitialActions: 383 | - Action: box_counter 384 | ActionId: 1 385 | Observers: 386 | Sprite2D: 387 | - Image: gvgai/newset/blockB.png 388 | Block2D: 389 | - Shape: square 390 | Color: [0.2, 0.2, 0.5] 391 | Scale: 0.5 392 | - Name: blue_block 393 | MapCharacter: a 394 | Observers: 395 | Sprite2D: 396 | - Image: gvgai/newset/blockB2.png 397 | Block2D: 398 | - Shape: square 399 | Color: [0.0, 0.0, 1.0] 400 | Scale: 1.0 401 | 402 | - Name: broken_box 403 | Observers: 404 | Sprite2D: 405 | - Image: gvgai/newset/block3.png 406 | Block2D: 407 | - Shape: triangle 408 | Color: [1.0, 0.0, 1.0] 409 | Scale: 1.0 410 | -------------------------------------------------------------------------------- /clusters_po_with_push_separate_colors_units.yaml: -------------------------------------------------------------------------------- 1 | Version: "0.1" 2 | Environment: 3 | Name: Partially Observable Clusters 4 | Description: Cluster the coloured objects together by pushing them against the static coloured blocks. 
5 | Observers: 6 | Sprite2D: 7 | TileSize: 24 8 | BackgroundTile: oryx/oryx_fantasy/floor1-2.png 9 | Variables: 10 | - Name: box_count 11 | InitialValue: 0 12 | PerPlayer: true 13 | - Name: broken_boxes 14 | InitialValue: 0 15 | PerPlayer: true 16 | Player: 17 | Count: 1 18 | Termination: 19 | Win: 20 | - eq: [ box_count, 0 ] 21 | Lose: 22 | - eq: [ broken_boxes, 1 ] 23 | Levels: 24 | - | 25 | w w w w w w w w w w w w w 26 | w . . . . . . . . . . . w 27 | w . . b1 b1 . . . r1 . r1 . w 28 | w . . . . b1 . . . . . . w 29 | w . . . B . . . . . r1 . w 30 | w . . . . . . . x . . . w 31 | w . . . . b1 . . . . R . w 32 | w . . . . . . b1 . . . . w 33 | w . . . . . . . . . . . w 34 | w w w w w w w w w w w w w 35 | - | 36 | w w w w w w w w w w w w w 37 | w . . . . . . . . . . . w 38 | w . . b1 . . r1 . G g1 . . w 39 | w . . . . x . . x . . . w 40 | w . . . r1 . . g1 . . b1 . w 41 | w . . . . R . . x . . . w 42 | w . . g1 . . . r1 . . b1 . w 43 | w . . x . x . . . B . . w 44 | w . . . . . . . . . . . w 45 | w w w w w w w w w w w w w 46 | - | 47 | w w w w w w w w w w w w w 48 | w . . B . . R . . G . . w 49 | w . . . . . . . . . . . w 50 | w . . . . . . . . . . . w 51 | w x x x x x . x x x x x w 52 | w . . . . x . x . . . . w 53 | w . b1 r1 . x . x . b1 g1 . w 54 | w . g1 . . . . . . . r1 . w 55 | w . . . . . . . . . . . w 56 | w w w w w w w w w w w w w 57 | - | 58 | w w w w w w w w w w w w w 59 | w . . . . . . . . . . . w 60 | w . . . b1 . r1 . . G . . w 61 | w . . . . . g1 . . g1 . . w 62 | w . . B . r1 . . . x . . w 63 | w . . . . x x . g1 . . . w 64 | w . . b1 . . . . . r1 . . w 65 | w . . . . . b1 . . R . . w 66 | w . . . . . . . . . . . w 67 | w w w w w w w w w w w w w 68 | - | 69 | w w w w w w w w w w w w w 70 | w . . . . . . . . . . . w 71 | w . . . . . . b1 . . . . w 72 | w . . x . . R . . x . . w 73 | w . . . . b1 . . . . . . w 74 | w . . g1 . . . . r1 . . . w 75 | w . . . B . x . . G . . w 76 | w . . . . g1 . . . . r1 . w 77 | w . . . . . . . . . . . 
w 78 | w w w w w w w w w w w w w 79 | 80 | Actions: 81 | 82 | # A simple action to count the number of boxes in the game at the start 83 | # There is currently no way to do complex things in termination conditions, such as combining multiple conditions 84 | - Name: box_counter 85 | InputMapping: 86 | Internal: true 87 | Inputs: 88 | 1: 89 | Description: "The only action here is to increment the box count" 90 | Behaviours: 91 | - Src: 92 | Object: [ blue_box, red_box, green_box ] 93 | Commands: 94 | - incr: box_count 95 | Dst: 96 | Object: [ blue_box, red_box, green_box ] 97 | 98 | - Name: push_blue 99 | Behaviours: 100 | 101 | # Boxes break if they are pushed into the spikes 102 | - Src: 103 | Object: blue_box 104 | Commands: 105 | - incr: broken_boxes 106 | - change_to: broken_box 107 | - reward: -1 108 | Dst: 109 | Object: spike 110 | 111 | # Boxes can be pushed into empty space 112 | - Src: 113 | Object: blue_box 114 | Commands: 115 | - mov: _dest 116 | Dst: 117 | Object: _empty 118 | 119 | # When boxes are pushed against the blocks they change 120 | - Src: 121 | Object: blue_box 122 | Commands: 123 | - change_to: blue_block 124 | - reward: 1 125 | - decr: box_count 126 | Dst: 127 | Object: blue_block 128 | 129 | - Name: push_red 130 | Behaviours: 131 | 132 | # Boxes break if they are pushed into the spikes 133 | - Src: 134 | Object: red_box 135 | Commands: 136 | - incr: broken_boxes 137 | - change_to: broken_box 138 | - reward: -1 139 | Dst: 140 | Object: spike 141 | 142 | # Boxes can be pushed into empty space 143 | - Src: 144 | Object: red_box 145 | Commands: 146 | - mov: _dest 147 | Dst: 148 | Object: _empty 149 | 150 | # When boxes are pushed against the blocks they change 151 | - Src: 152 | Object: red_box 153 | Commands: 154 | - reward: 1 155 | - change_to: red_block 156 | - decr: box_count 157 | Dst: 158 | Object: red_block 159 | 160 | - Name: push_green 161 | Behaviours: 162 | 163 | # Boxes break if they are pushed into the spikes 164 | - Src: 165 | Object: green_box 166 
| Commands: 167 | - incr: broken_boxes 168 | - change_to: broken_box 169 | - reward: -1 170 | Dst: 171 | Object: spike 172 | 173 | # Boxes can be pushed into empty space 174 | - Src: 175 | Object: green_box 176 | Commands: 177 | - mov: _dest 178 | Dst: 179 | Object: _empty 180 | 181 | # When boxes are pushed against the blocks they change 182 | - Src: 183 | Object: green_box 184 | Commands: 185 | - reward: 1 186 | - change_to: green_block 187 | - decr: box_count 188 | Dst: 189 | Object: green_block 190 | 191 | 192 | Objects: 193 | 194 | - Name: wall 195 | MapCharacter: w 196 | Observers: 197 | Sprite2D: 198 | - TilingMode: WALL_16 199 | Image: 200 | - oryx/oryx_fantasy/wall1-0.png 201 | - oryx/oryx_fantasy/wall1-1.png 202 | - oryx/oryx_fantasy/wall1-2.png 203 | - oryx/oryx_fantasy/wall1-3.png 204 | - oryx/oryx_fantasy/wall1-4.png 205 | - oryx/oryx_fantasy/wall1-5.png 206 | - oryx/oryx_fantasy/wall1-6.png 207 | - oryx/oryx_fantasy/wall1-7.png 208 | - oryx/oryx_fantasy/wall1-8.png 209 | - oryx/oryx_fantasy/wall1-9.png 210 | - oryx/oryx_fantasy/wall1-10.png 211 | - oryx/oryx_fantasy/wall1-11.png 212 | - oryx/oryx_fantasy/wall1-12.png 213 | - oryx/oryx_fantasy/wall1-13.png 214 | - oryx/oryx_fantasy/wall1-14.png 215 | - oryx/oryx_fantasy/wall1-15.png 216 | Block2D: 217 | - Shape: square 218 | Color: [ 0.5, 0.5, 0.5 ] 219 | Scale: 0.9 220 | 221 | - Name: spike 222 | MapCharacter: x 223 | Observers: 224 | Sprite2D: 225 | - Image: gvgai/oryx/spike2.png 226 | Block2D: 227 | - Shape: triangle 228 | Color: [ 0.9, 0.1, 0.1 ] 229 | Scale: 0.5 230 | 231 | - Name: red_box 232 | MapCharacter: r 233 | InitialActions: 234 | - Action: box_counter 235 | ActionId: 1 236 | Observers: 237 | Sprite2D: 238 | - Image: gvgai/newset/blockR.png 239 | Block2D: 240 | - Shape: square 241 | Color: [ 0.5, 0.2, 0.2 ] 242 | Scale: 0.5 243 | - Name: red_block 244 | MapCharacter: R 245 | Observers: 246 | Sprite2D: 247 | - Image: gvgai/newset/blockR2.png 248 | Block2D: 249 | - Shape: square 250 | Color: [ 
1.0, 0.0, 0.0 ] 251 | Scale: 1.0 252 | 253 | - Name: green_box 254 | MapCharacter: g 255 | InitialActions: 256 | - Action: box_counter 257 | ActionId: 1 258 | Observers: 259 | Sprite2D: 260 | - Image: gvgai/newset/blockG.png 261 | Block2D: 262 | - Shape: square 263 | Color: [ 0.2, 0.5, 0.2 ] 264 | Scale: 0.5 265 | - Name: green_block 266 | MapCharacter: G 267 | Observers: 268 | Sprite2D: 269 | - Image: gvgai/newset/blockG2.png 270 | Block2D: 271 | - Shape: square 272 | Color: [ 0.0, 1.0, 0.0 ] 273 | Scale: 1.0 274 | 275 | - Name: blue_box 276 | MapCharacter: b 277 | InitialActions: 278 | - Action: box_counter 279 | ActionId: 1 280 | Observers: 281 | Sprite2D: 282 | - Image: gvgai/newset/blockB.png 283 | Block2D: 284 | - Shape: square 285 | Color: [ 0.2, 0.2, 0.5 ] 286 | Scale: 0.5 287 | - Name: blue_block 288 | MapCharacter: B 289 | Observers: 290 | Sprite2D: 291 | - Image: gvgai/newset/blockB2.png 292 | Block2D: 293 | - Shape: square 294 | Color: [ 0.0, 0.0, 1.0 ] 295 | Scale: 1.0 296 | 297 | - Name: broken_box 298 | Observers: 299 | Sprite2D: 300 | - Image: gvgai/newset/block3.png 301 | Block2D: 302 | - Shape: triangle 303 | Color: [ 1.0, 0.0, 1.0 ] 304 | Scale: 1.0 305 | -------------------------------------------------------------------------------- /clusters_po_with_push_units.yaml: -------------------------------------------------------------------------------- 1 | Version: "0.1" 2 | Environment: 3 | Name: Partially Observable Clusters 4 | Description: Cluster the coloured objects together by pushing them against the static coloured blocks. 
5 | Observers: 6 | Sprite2D: 7 | TileSize: 24 8 | BackgroundTile: oryx/oryx_fantasy/floor1-2.png 9 | Variables: 10 | - Name: box_count 11 | InitialValue: 0 12 | PerPlayer: true 13 | - Name: broken_boxes 14 | InitialValue: 0 15 | PerPlayer: true 16 | Player: 17 | Count: 1 18 | Termination: 19 | Win: 20 | - eq: [ box_count, 0 ] 21 | Lose: 22 | - eq: [ broken_boxes, 1 ] 23 | Levels: 24 | - | 25 | w w w w w w w w w w w w w 26 | w . . . . . . . . . . . w 27 | w . . b1 b1 . . . r1 . r1 . w 28 | w . . . . b1 . . . . . . w 29 | w . . . B . . . . . r1 . w 30 | w . . . . . . . x . . . w 31 | w . . . . b1 . . . . R . w 32 | w . . . . . . b1 . . . . w 33 | w . . . . . . . . . . . w 34 | w w w w w w w w w w w w w 35 | - | 36 | w w w w w w w w w w w w w 37 | w . . . . . . . . . . . w 38 | w . . b1 . . r1 . G g1 . . w 39 | w . . . . x . . x . . . w 40 | w . . . r1 . . g1 . . b1 . w 41 | w . . . . R . . x . . . w 42 | w . . g1 . . . r1 . . b1 . w 43 | w . . x . x . . . B . . w 44 | w . . . . . . . . . . . w 45 | w w w w w w w w w w w w w 46 | - | 47 | w w w w w w w w w w w w w 48 | w . . B . . R . . G . . w 49 | w . . . . . . . . . . . w 50 | w . . . . . . . . . . . w 51 | w x x x x x . x x x x x w 52 | w . . . . x . x . . . . w 53 | w . b1 r1 . x . x . b1 g1 . w 54 | w . g1 . . . . . . . r1 . w 55 | w . . . . . . . . . . . w 56 | w w w w w w w w w w w w w 57 | - | 58 | w w w w w w w w w w w w w 59 | w . . . . . . . . . . . w 60 | w . . . b1 . r1 . . G . . w 61 | w . . . . . g1 . . g1 . . w 62 | w . . B . r1 . . . x . . w 63 | w . . . . x x . g1 . . . w 64 | w . . b1 . . . . . r1 . . w 65 | w . . . . . b1 . . R . . w 66 | w . . . . . . . . . . . w 67 | w w w w w w w w w w w w w 68 | - | 69 | w w w w w w w w w w w w w 70 | w . . . . . . . . . . . w 71 | w . . . . . . b1 . . . . w 72 | w . . x . . R . . x . . w 73 | w . . . . b1 . . . . . . w 74 | w . . g1 . . . . r1 . . . w 75 | w . . . B . x . . G . . w 76 | w . . . . g1 . . . . r1 . w 77 | w . . . . . . . . . . . 
w 78 | w w w w w w w w w w w w w 79 | 80 | Actions: 81 | 82 | # A simple action to count the number of boxes in the game at the start 83 | # There is currently no way to do complex things in termination conditions, such as combining multiple conditions 84 | - Name: box_counter 85 | InputMapping: 86 | Internal: true 87 | Inputs: 88 | 1: 89 | Description: "The only action here is to increment the box count" 90 | Behaviours: 91 | - Src: 92 | Object: [ blue_box, red_box, green_box ] 93 | Commands: 94 | - incr: box_count 95 | Dst: 96 | Object: [ blue_box, red_box, green_box ] 97 | 98 | - Name: push 99 | Behaviours: 100 | 101 | # Boxes break if they hit the spikes 102 | - Src: 103 | Object: [ blue_box, green_box, red_box ] 104 | Commands: 105 | - incr: broken_boxes 106 | - change_to: broken_box 107 | - reward: -1 108 | Dst: 109 | Object: spike 110 | 111 | # Boxes can be pushed into empty space 112 | - Src: 113 | Object: [ blue_box, green_box, red_box ] 114 | Commands: 115 | - mov: _dest 116 | Dst: 117 | Object: _empty 118 | 119 | # When boxes are pushed against the blocks they change 120 | - Src: 121 | Object: blue_box 122 | Commands: 123 | - change_to: blue_block 124 | - reward: 1 125 | - decr: box_count 126 | Dst: 127 | Object: blue_block 128 | - Src: 129 | Object: red_box 130 | Commands: 131 | - reward: 1 132 | - change_to: red_block 133 | - decr: box_count 134 | Dst: 135 | Object: red_block 136 | - Src: 137 | Object: green_box 138 | Commands: 139 | - reward: 1 140 | - change_to: green_block 141 | - decr: box_count 142 | Dst: 143 | Object: green_block 144 | 145 | 146 | Objects: 147 | 148 | - Name: wall 149 | MapCharacter: w 150 | Observers: 151 | Sprite2D: 152 | - TilingMode: WALL_16 153 | Image: 154 | - oryx/oryx_fantasy/wall1-0.png 155 | - oryx/oryx_fantasy/wall1-1.png 156 | - oryx/oryx_fantasy/wall1-2.png 157 | - oryx/oryx_fantasy/wall1-3.png 158 | - oryx/oryx_fantasy/wall1-4.png 159 | - oryx/oryx_fantasy/wall1-5.png 160 | - oryx/oryx_fantasy/wall1-6.png 161 | - 
oryx/oryx_fantasy/wall1-7.png 162 | - oryx/oryx_fantasy/wall1-8.png 163 | - oryx/oryx_fantasy/wall1-9.png 164 | - oryx/oryx_fantasy/wall1-10.png 165 | - oryx/oryx_fantasy/wall1-11.png 166 | - oryx/oryx_fantasy/wall1-12.png 167 | - oryx/oryx_fantasy/wall1-13.png 168 | - oryx/oryx_fantasy/wall1-14.png 169 | - oryx/oryx_fantasy/wall1-15.png 170 | Block2D: 171 | - Shape: square 172 | Color: [ 0.5, 0.5, 0.5 ] 173 | Scale: 0.9 174 | 175 | - Name: spike 176 | MapCharacter: x 177 | Observers: 178 | Sprite2D: 179 | - Image: gvgai/oryx/spike2.png 180 | Block2D: 181 | - Shape: triangle 182 | Color: [ 0.9, 0.1, 0.1 ] 183 | Scale: 0.5 184 | 185 | - Name: red_box 186 | MapCharacter: r 187 | InitialActions: 188 | - Action: box_counter 189 | ActionId: 1 190 | Observers: 191 | Sprite2D: 192 | - Image: gvgai/newset/blockR.png 193 | Block2D: 194 | - Shape: square 195 | Color: [ 0.5, 0.2, 0.2 ] 196 | Scale: 0.5 197 | - Name: red_block 198 | MapCharacter: R 199 | Observers: 200 | Sprite2D: 201 | - Image: gvgai/newset/blockR2.png 202 | Block2D: 203 | - Shape: square 204 | Color: [ 1.0, 0.0, 0.0 ] 205 | Scale: 1.0 206 | 207 | - Name: green_box 208 | MapCharacter: g 209 | InitialActions: 210 | - Action: box_counter 211 | ActionId: 1 212 | Observers: 213 | Sprite2D: 214 | - Image: gvgai/newset/blockG.png 215 | Block2D: 216 | - Shape: square 217 | Color: [ 0.2, 0.5, 0.2 ] 218 | Scale: 0.5 219 | - Name: green_block 220 | MapCharacter: G 221 | Observers: 222 | Sprite2D: 223 | - Image: gvgai/newset/blockG2.png 224 | Block2D: 225 | - Shape: square 226 | Color: [ 0.0, 1.0, 0.0 ] 227 | Scale: 1.0 228 | 229 | - Name: blue_box 230 | MapCharacter: b 231 | InitialActions: 232 | - Action: box_counter 233 | ActionId: 1 234 | Observers: 235 | Sprite2D: 236 | - Image: gvgai/newset/blockB.png 237 | Block2D: 238 | - Shape: square 239 | Color: [ 0.2, 0.2, 0.5 ] 240 | Scale: 0.5 241 | - Name: blue_block 242 | MapCharacter: B 243 | Observers: 244 | Sprite2D: 245 | - Image: gvgai/newset/blockB2.png 246 | 
Block2D: 247 | - Shape: square 248 | Color: [ 0.0, 0.0, 1.0 ] 249 | Scale: 1.0 250 | 251 | - Name: broken_box 252 | Observers: 253 | Sprite2D: 254 | - Image: gvgai/newset/block3.png 255 | Block2D: 256 | - Shape: triangle 257 | Color: [ 1.0, 0.0, 1.0 ] 258 | Scale: 1.0 259 | -------------------------------------------------------------------------------- /conditional_action_trees/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bam4d/conditional-action-trees/08e7574cd95b0a5714017e1ca78a12a32c38f645/conditional_action_trees/__init__.py -------------------------------------------------------------------------------- /conditional_action_trees/conditional_action_exploration.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import torch 4 | from gym.spaces import Discrete, MultiDiscrete 5 | from torch.distributions import Categorical 6 | import numpy as np 7 | 8 | 9 | class TorchConditionalMaskingExploration(): 10 | 11 | def __init__(self, model, dist_inputs, valid_action_trees, explore=False, invalid_action_masking='conditional', 12 | allow_nop=False): 13 | self._valid_action_trees = valid_action_trees 14 | 15 | self._num_inputs = dist_inputs.shape[0] 16 | if isinstance(model.action_space, Discrete): 17 | self._action_space_shape = [model.action_space.n] 18 | elif isinstance(model.action_space, MultiDiscrete): 19 | self._action_space_shape = model.action_space.nvec 20 | 21 | self._num_action_logits = np.sum(self._action_space_shape) 22 | self._num_action_parts = len(self._action_space_shape) 23 | 24 | self._invalid_action_masking = invalid_action_masking 25 | self._allow_nop = allow_nop 26 | 27 | self._explore = explore 28 | 29 | self._inputs_split = dist_inputs.split(tuple(self._action_space_shape), dim=1) 30 | 31 | def _mask_and_sample(self, options, logits, is_parameters=False): 32 | 33 | mask = 
torch.zeros([logits.shape[0]]).to(logits.device) 34 | mask[options] = 1 35 | 36 | if is_parameters: 37 | if not self._allow_nop: 38 | mask[0] = 0 39 | 40 | masked_logits = logits + torch.log(mask) 41 | 42 | dist = Categorical(logits=masked_logits) 43 | sampled = dist.sample() 44 | logp = dist.log_prob(sampled) 45 | out_logits = masked_logits 46 | 47 | # if not self._allow_nop and is_parameters: 48 | # assert sampled != 0 49 | 50 | return sampled, out_logits, logp, mask 51 | 52 | def _fill_node(self, keys, pos): 53 | if pos < len(keys): 54 | return {k: self._fill_node(keys, pos + 1) for k in np.arange(keys[pos])} 55 | else: 56 | return {} 57 | 58 | def _merge_all_branches(self, tree): 59 | all_nodes = {} 60 | merged_tree = {} 61 | for k, v in tree.items(): 62 | v = self._merge_all_branches(v) 63 | all_nodes.update(v) 64 | 65 | for k in tree.keys(): 66 | merged_tree[k] = all_nodes 67 | 68 | return merged_tree 69 | 70 | def _process_valid_action_tree(self, valid_action_tree): 71 | subtree = valid_action_tree 72 | subtree_options = list(subtree.keys()) 73 | 74 | # In the case there are no available actions for the player 75 | if len(subtree_options) == 0: 76 | build_tree = subtree 77 | for _ in range(self._num_action_parts): 78 | build_tree[0] = {} 79 | build_tree = build_tree[0] 80 | subtree_options = list(subtree.keys()) 81 | subtree = build_tree 82 | 83 | # If we want very basic action masking where parameterized masks are superimposed we use this 84 | if self._invalid_action_masking == 'collapsed': 85 | subtree = self._merge_all_branches(valid_action_tree) 86 | subtree_options = list(subtree.keys()) 87 | 88 | return subtree, subtree_options 89 | 90 | def get_actions_and_mask(self): 91 | 92 | actions = torch.zeros([self._num_inputs, self._num_action_parts]) 93 | masked_logits = torch.zeros([self._num_inputs, self._num_action_logits]) 94 | mask = torch.zeros([self._num_inputs, self._num_action_logits]) 95 | logp_sums = torch.zeros([self._num_inputs]) 96 | 97 | if 
self._valid_action_trees is not None: 98 | 99 | for i in range(self._num_inputs): 100 | if len(self._valid_action_trees) >= 1: 101 | 102 | subtree, subtree_options = self._process_valid_action_tree(self._valid_action_trees[i]) 103 | 104 | logp_parts = torch.zeros([self._num_action_parts]) 105 | mask_offset = 0 106 | for a in range(self._num_action_parts): 107 | dist_part = self._inputs_split[a] 108 | is_parameters = a == (self._num_action_parts - 1) 109 | sampled, masked_part_logits, logp, mask_part = self._mask_and_sample(subtree_options, 110 | dist_part[i], 111 | is_parameters) 112 | 113 | # Set the action and the mask for each part of the action 114 | actions[i, a] = sampled 115 | masked_logits[i, mask_offset:mask_offset + self._action_space_shape[a]] = masked_part_logits 116 | mask[i, mask_offset:mask_offset + self._action_space_shape[a]] = mask_part 117 | 118 | logp_parts[a] = logp 119 | 120 | mask_offset += self._action_space_shape[a] 121 | 122 | if len(subtree.keys()) > 0: 123 | subtree = subtree[int(sampled)] 124 | subtree_options = list(subtree.keys()) 125 | 126 | logp_sums[i] = torch.sum(logp_parts) 127 | 128 | # if its a discrete then flatten the space 129 | if self._num_action_parts == 1: 130 | actions = actions.flatten() 131 | 132 | return actions, masked_logits, logp_sums, mask 133 | -------------------------------------------------------------------------------- /conditional_action_trees/conditional_action_mixin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from ray.rllib import Policy, SampleBatch 4 | from ray.rllib.utils import override 5 | from ray.rllib.utils.torch_ops import convert_to_non_torch_type 6 | 7 | from conditional_action_trees.conditional_action_exploration import TorchConditionalMaskingExploration 8 | 9 | 10 | class ConditionalActionMixin: 11 | 12 | @override(Policy) 13 | def compute_actions_from_input_dict( 14 | self, 15 | input_dict, 16 | 
explore=None, 17 | timestep=None, 18 | **kwargs): 19 | 20 | explore = explore if explore is not None else self.config["explore"] 21 | timestep = timestep if timestep is not None else self.global_timestep 22 | 23 | with torch.no_grad(): 24 | # Pass lazy (torch) tensor dict to Model as `input_dict`. 25 | input_dict = self._lazy_tensor_dict(input_dict) 26 | # Pack internal state inputs into (separate) list. 27 | state_batches = [ 28 | input_dict[k] for k in input_dict.keys() if "state_in" in k[:8] 29 | ] 30 | # Calculate RNN sequence lengths. 31 | seq_lens = np.array([1] * len(input_dict["obs"])) \ 32 | if state_batches else None 33 | 34 | self._is_recurrent = state_batches is not None and state_batches != [] 35 | 36 | # Switch to eval mode. 37 | self.model.eval() 38 | 39 | dist_inputs, state_out = self.model(input_dict, state_batches, 40 | seq_lens) 41 | 42 | generate_valid_action_trees = self.config['env_config'].get('generate_valid_action_trees', False) 43 | invalid_action_masking = self.config["env_config"].get("invalid_action_masking", 'none') 44 | allow_nop = self.config["env_config"].get("allow_nop", False) 45 | 46 | extra_fetches = {} 47 | 48 | if generate_valid_action_trees: 49 | infos = input_dict[SampleBatch.INFOS] if SampleBatch.INFOS in input_dict else {} 50 | 51 | valid_action_trees = [] 52 | for info in infos: 53 | if isinstance(info, dict) and 'valid_action_tree' in info: 54 | valid_action_trees.append(info['valid_action_tree']) 55 | else: 56 | valid_action_trees.append({}) 57 | 58 | exploration = TorchConditionalMaskingExploration( 59 | self.model, 60 | dist_inputs, 61 | valid_action_trees, 62 | explore, 63 | invalid_action_masking, 64 | allow_nop 65 | ) 66 | 67 | actions, masked_logits, logp, mask = exploration.get_actions_and_mask() 68 | 69 | extra_fetches.update({ 70 | 'invalid_action_mask': mask 71 | }) 72 | else: 73 | action_dist = self.dist_class(dist_inputs, self.model) 74 | 75 | # Get the exploration action from the forward results. 
76 | actions, logp = \ 77 | self.exploration.get_exploration_action( 78 | action_distribution=action_dist, 79 | timestep=timestep, 80 | explore=explore) 81 | 82 | masked_logits = dist_inputs 83 | 84 | input_dict[SampleBatch.ACTIONS] = actions 85 | 86 | extra_fetches.update({ 87 | SampleBatch.ACTION_DIST_INPUTS: masked_logits, 88 | SampleBatch.ACTION_PROB: torch.exp(logp.float()), 89 | SampleBatch.ACTION_LOGP: logp, 90 | }) 91 | 92 | # Update our global timestep by the batch size. 93 | self.global_timestep += len(input_dict[SampleBatch.CUR_OBS]) 94 | 95 | return convert_to_non_torch_type((actions, state_out, extra_fetches)) -------------------------------------------------------------------------------- /conditional_action_trees/conditional_action_policy_trainer.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | import torch 4 | from ray.rllib import SampleBatch 5 | from ray.rllib.agents.impala import ImpalaTrainer 6 | from ray.rllib.agents.impala.vtrace_torch_policy import build_vtrace_loss 7 | from ray.rllib.agents.impala.vtrace_torch_policy import VTraceTorchPolicy, VTraceLoss, make_time_major 8 | from ray.rllib.models.torch.torch_action_dist import TorchCategorical 9 | from ray.rllib.policy.torch_policy import LearningRateSchedule, EntropyCoeffSchedule 10 | from tensorflow import sequence_mask 11 | 12 | from conditional_action_trees.conditional_action_mixin import ConditionalActionMixin 13 | 14 | 15 | def build_invalid_masking_vtrace_loss(policy, model, dist_class, train_batch): 16 | if not policy.config['env_config'].get('vtrace_masking', False): 17 | return build_vtrace_loss(policy, model, dist_class, train_batch) 18 | 19 | model_out, _ = model.from_batch(train_batch) 20 | 21 | if isinstance(policy.action_space, gym.spaces.Discrete): 22 | is_multidiscrete = False 23 | output_hidden_shape = [policy.action_space.n] 24 | elif isinstance(policy.action_space, gym.spaces.MultiDiscrete): 25 | 
is_multidiscrete = True 26 | output_hidden_shape = policy.action_space.nvec.astype(np.int32) 27 | else: 28 | is_multidiscrete = False 29 | output_hidden_shape = 1 30 | 31 | def _make_time_major(*args, **kw): 32 | return make_time_major(policy, train_batch.get("seq_lens"), *args, 33 | **kw) 34 | 35 | actions = train_batch[SampleBatch.ACTIONS] 36 | dones = train_batch[SampleBatch.DONES] 37 | rewards = train_batch[SampleBatch.REWARDS] 38 | behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP] 39 | behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS] 40 | 41 | invalid_action_mask = train_batch['invalid_action_mask'] 42 | 43 | if 'seq_lens' in train_batch: 44 | max_seq_len = policy.config['rollout_fragment_length'] 45 | mask_orig = sequence_mask(train_batch["seq_lens"], max_seq_len) 46 | mask = torch.reshape(mask_orig, [-1]) 47 | else: 48 | mask = torch.ones_like(rewards) 49 | 50 | model_out += torch.maximum(torch.tensor(torch.finfo().min), torch.log(invalid_action_mask)) 51 | action_dist = dist_class(model_out, model) 52 | 53 | if isinstance(output_hidden_shape, (list, tuple, np.ndarray)): 54 | unpacked_behaviour_logits = torch.split( 55 | behaviour_logits, list(output_hidden_shape), dim=1) 56 | unpacked_outputs = torch.split( 57 | model_out, list(output_hidden_shape), dim=1) 58 | else: 59 | unpacked_behaviour_logits = torch.chunk( 60 | behaviour_logits, output_hidden_shape, dim=1) 61 | unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1) 62 | values = model.value_function() 63 | 64 | # Prepare actions for loss. 65 | loss_actions = actions if is_multidiscrete else torch.unsqueeze( 66 | actions, dim=1) 67 | 68 | # Inputs are reshaped from [B * T] => [T - 1, B] for V-trace calc. 
69 | policy.loss = VTraceLoss( 70 | actions=_make_time_major(loss_actions, drop_last=True), 71 | actions_logp=_make_time_major( 72 | action_dist.logp(actions), drop_last=True), 73 | actions_entropy=_make_time_major( 74 | action_dist.entropy(), drop_last=True), 75 | dones=_make_time_major(dones, drop_last=True), 76 | behaviour_action_logp=_make_time_major( 77 | behaviour_action_logp, drop_last=True), 78 | behaviour_logits=_make_time_major( 79 | unpacked_behaviour_logits, drop_last=True), 80 | target_logits=_make_time_major(unpacked_outputs, drop_last=True), 81 | discount=policy.config["gamma"], 82 | rewards=_make_time_major(rewards, drop_last=True), 83 | values=_make_time_major(values, drop_last=True), 84 | bootstrap_value=_make_time_major(values)[-1], 85 | dist_class=TorchCategorical if is_multidiscrete else dist_class, 86 | model=model, 87 | valid_mask=_make_time_major(mask, drop_last=True), 88 | config=policy.config, 89 | vf_loss_coeff=policy.config["vf_loss_coeff"], 90 | entropy_coeff=policy.entropy_coeff, 91 | clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"], 92 | clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"]) 93 | 94 | return policy.loss.total_loss 95 | 96 | 97 | def setup_mixins(policy, obs_space, action_space, config): 98 | ConditionalActionMixin.__init__(policy) 99 | EntropyCoeffSchedule.__init__(policy, config["entropy_coeff"], 100 | config["entropy_coeff_schedule"]) 101 | LearningRateSchedule.__init__(policy, config["lr"], config["lr_schedule"]) 102 | 103 | 104 | ConditionalActionVTraceTorchPolicy = VTraceTorchPolicy.with_updates( 105 | name="ConditionalActionVTraceTorchPolicy", 106 | loss_fn=build_invalid_masking_vtrace_loss, 107 | before_init=setup_mixins, 108 | mixins=[LearningRateSchedule, EntropyCoeffSchedule, ConditionalActionMixin] 109 | ) 110 | 111 | 112 | def get_vtrace_policy_class(config): 113 | if config['framework'] == 'torch': 114 | return ConditionalActionVTraceTorchPolicy 115 | else: 116 | raise 
NotImplementedError('Tensorflow not supported') 117 | 118 | 119 | ConditionalActionImpalaTrainer = ImpalaTrainer.with_updates(name="ConditionalActionImpalaTrainer", 120 | default_policy=ConditionalActionVTraceTorchPolicy, 121 | get_policy_class=get_vtrace_policy_class) -------------------------------------------------------------------------------- /images/Flat_4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bam4d/conditional-action-trees/08e7574cd95b0a5714017e1ca78a12a32c38f645/images/Flat_4.gif -------------------------------------------------------------------------------- /images/M_2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bam4d/conditional-action-trees/08e7574cd95b0a5714017e1ca78a12a32c38f645/images/M_2.gif -------------------------------------------------------------------------------- /images/Ma_4.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bam4d/conditional-action-trees/08e7574cd95b0a5714017e1ca78a12a32c38f645/images/Ma_4.gif -------------------------------------------------------------------------------- /plots/plot_baseline_results.py: -------------------------------------------------------------------------------- 1 | import wandb 2 | import matplotlib.pyplot as plt 3 | import matplotlib as mpl 4 | from matplotlib.offsetbox import AnchoredText 5 | 6 | CB_color_cycle = ['#377eb8', '#ff7f00', '#4daf4a', 7 | '#f781bf', '#a65628', '#984ea3', 8 | '#999999', '#e41a1c', '#dede00'] 9 | 10 | 11 | def pull_run_data(runs, run_name): 12 | for run in runs: 13 | if run_name == run.name: 14 | print(f'Pulling data from run. 
Name: {run_name}') 15 | history = run.history(keys=['episode_reward_mean', 'timesteps_total']) 16 | history.set_axis(history['timesteps_total'], 0, inplace=True) 17 | exponential_moving_average = history['episode_reward_mean'].ewm(span=20).mean() 18 | 19 | return exponential_moving_average 20 | 21 | 22 | def get_data_for_experiments(api, experiment_type): 23 | runs = api.runs(f'chrisbam4d/conditional_action_trees') 24 | baseline_run = pull_run_data(runs, f'baseline-{experiment_type}') 25 | baseline_flat_run = pull_run_data(runs, f'baseline-flat-{experiment_type}') 26 | CAT_collapsed_run = pull_run_data(runs, f'CAT-{experiment_type}-V-collapsed') 27 | CAT_conditional_run = pull_run_data(runs, f'CAT-{experiment_type}-V-conditional') 28 | 29 | return baseline_run, baseline_flat_run, CAT_collapsed_run, CAT_conditional_run 30 | 31 | 32 | def plot_training_comparison(experiment_data): 33 | baseline_run, baseline_flat_run, CAT_collapsed_run, CAT_conditional_run = experiment_data 34 | 35 | l_b = baseline_run.plot(label='No Masking') 36 | l_bf = baseline_flat_run.plot(label='Depth 2') 37 | l_CATcl = CAT_collapsed_run.plot(label='CAT_CL') 38 | l_CATcd = CAT_conditional_run.plot(label='CAT_CD') 39 | 40 | return l_b, l_bf, l_CATcl, l_CATcd 41 | 42 | 43 | if __name__ == '__main__': 44 | api = wandb.Api() 45 | 46 | experiments_M = get_data_for_experiments(api, 'M') 47 | experiments_MP = get_data_for_experiments(api, 'MP') 48 | experiments_MPS = get_data_for_experiments(api, 'MPS') 49 | experiments_Ma = get_data_for_experiments(api, 'Ma') 50 | experiments_MSa = get_data_for_experiments(api, 'MSa') 51 | 52 | mpl.rcParams['axes.prop_cycle'] = mpl.cycler(color=CB_color_cycle) 53 | 54 | fig = plt.figure(figsize=(10, 5)) 55 | 56 | m_plot = plt.subplot(2, 3, 1) 57 | plot_training_comparison(experiments_M) 58 | mp_plot = plt.subplot(2, 3, 2) 59 | plot_training_comparison(experiments_MP) 60 | mps_plot = plt.subplot(2, 3, 3) 61 | plot_training_comparison(experiments_MPS) 62 | ma_plot = 
plt.subplot(2, 3, 4) 63 | plot_training_comparison(experiments_Ma) 64 | msa_plot = plt.subplot(2, 3, 5) 65 | plot_training_comparison(experiments_MSa) 66 | 67 | m_plot.title.set_text('M') 68 | m_plot.set_xlabel(None) 69 | m_plot.set_ylabel('Ave. Reward') 70 | mp_plot.title.set_text('MP') 71 | mp_plot.set_xlabel(None) 72 | mps_plot.title.set_text('MPS') 73 | mps_plot.set_xlabel('Steps') 74 | ma_plot.title.set_text('Ma') 75 | ma_plot.set_xlabel('Steps') 76 | ma_plot.set_ylabel('Ave. Reward') 77 | msa_plot.title.set_text('MSa') 78 | msa_plot.set_xlabel('Steps') 79 | 80 | 81 | labels = [ 82 | 'No Masking', 83 | 'Depth 2', 84 | 'CAT_CL', 85 | 'CAT_CD', 86 | ] 87 | fig.legend( 88 | labels=labels, 89 | loc="lower right", 90 | borderaxespad=0.1, 91 | prop={'size': 9}, 92 | framealpha=1.0, 93 | bbox_to_anchor=(0, 0.25, 0.9, 0) 94 | ) 95 | 96 | experiments_legend = ''' 97 | M = Move 98 | MP = Move+Push 99 | MPS = Move+Push+Separate 100 | Ma = Move-Agent 101 | MSa = Move+Separate-Agent 102 | ''' 103 | 104 | fig.text(0.73, 0.01, experiments_legend) 105 | 106 | plt.tight_layout() 107 | plt.savefig('plots_training.pdf') 108 | -------------------------------------------------------------------------------- /plots/plots_training.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bam4d/conditional-action-trees/08e7574cd95b0a5714017e1ca78a12a32c38f645/plots/plots_training.pdf -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | ray[rllib]>=1.3.0 2 | wandb>=0.10.30 3 | griddly>=1.0.2 -------------------------------------------------------------------------------- /rllib_baseline.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import ray 6 | from ray import tune 7 | from ray.rllib.models import 
ModelCatalog 8 | from ray.tune.integration.wandb import WandbLoggerCallback 9 | from ray.tune.registry import register_env 10 | 11 | from griddly import gd 12 | from griddly.util.rllib.callbacks import VideoCallback 13 | from griddly.util.rllib.environment.core import RLlibEnv 14 | from griddly.util.rllib.torch.agents.conv_agent import SimpleConvAgent 15 | 16 | from conditional_action_trees.conditional_action_policy_trainer import ConditionalActionImpalaTrainer 17 | 18 | parser = argparse.ArgumentParser(description='Run experiments') 19 | 20 | parser.add_argument('--yaml-file', help='YAML file containing GDY for the game') 21 | parser.add_argument('--experiment-name', default='unknown', help='name of the experiment') 22 | 23 | parser.add_argument('--root-directory', default=os.path.expanduser("~/ray_results"), 24 | help='root directory for all data associated with the run') 25 | parser.add_argument('--num-gpus', default=1, type=int, help='Number of GPUs to make available to ray.') 26 | parser.add_argument('--num-cpus', default=8, type=int, help='Number of CPUs to make available to ray.') 27 | 28 | parser.add_argument('--num-workers', default=7, type=int, help='Number of workers') 29 | parser.add_argument('--num-envs-per-worker', default=5, type=int, help='Number of environments per worker') 30 | parser.add_argument('--num-gpus-per-worker', default=0, type=float, help='Number of gpus per worker') 31 | parser.add_argument('--num-cpus-per-worker', default=1, type=float, help='Number of cpus per worker') 32 | parser.add_argument('--max-training-steps', default=20000000, type=int, help='Maximum number of training steps') 33 | 34 | parser.add_argument('--capture-video', action='store_true', help='enable video capture') 35 | parser.add_argument('--video-directory', default='videos', help='directory of video') 36 | parser.add_argument('--video-frequency', type=int, default=1000000, help='Frequency of videos') 37 | 38 | parser.add_argument('--seed', type=int, default=69420, help='seed for experiments') 39 
| 40 | parser.add_argument('--lr', type=float, default=0.0005, help='learning rate') 41 | 42 | if __name__ == '__main__': 43 | 44 | args = parser.parse_args() 45 | 46 | sep = os.pathsep 47 | os.environ['PYTHONPATH'] = sep.join(sys.path) 48 | 49 | ray.init(include_dashboard=False, num_gpus=args.num_gpus, num_cpus=args.num_cpus) 50 | #ray.init(include_dashboard=False, num_gpus=args.num_gpus, num_cpus=args.num_cpus, local_mode=True) 51 | 52 | env_name = "ray-griddly-env" 53 | 54 | register_env(env_name, RLlibEnv) 55 | ModelCatalog.register_custom_model("SimpleConv", SimpleConvAgent) 56 | 57 | wandbLoggerCallback = WandbLoggerCallback( 58 | project='conditional_action_trees_reproduce', 59 | api_key_file='~/.wandb_rc', 60 | dir=args.root_directory 61 | ) 62 | 63 | max_training_steps = args.max_training_steps 64 | gdy_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.yaml_file) 65 | 66 | config = { 67 | 'framework': 'torch', 68 | 'seed': args.seed, 69 | 'num_workers': args.num_workers, 70 | 'num_envs_per_worker': args.num_envs_per_worker, 71 | 'num_gpus_per_worker': float(args.num_gpus_per_worker), 72 | 'num_cpus_per_worker': args.num_cpus_per_worker, 73 | 74 | 'callbacks': VideoCallback, 75 | 76 | 'model': { 77 | 'custom_model': 'SimpleConv', 78 | 'custom_model_config': {} 79 | }, 80 | 'env': env_name, 81 | 'env_config': { 82 | 'generate_valid_action_trees': False, 83 | 'random_level_on_reset': True, 84 | 'yaml_file': gdy_file, 85 | 'global_observer_type': gd.ObserverType.SPRITE_2D, 86 | 'max_steps': 1000, 87 | }, 88 | 'entropy_coeff_schedule': [ 89 | [0, 0.01], 90 | [max_training_steps, 0.0] 91 | ], 92 | 'lr_schedule': [ 93 | [0, args.lr], 94 | [max_training_steps, 0.0] 95 | ], 96 | 97 | } 98 | if args.capture_video: 99 | real_video_frequency = int(args.video_frequency / (args.num_envs_per_worker * args.num_workers)) 100 | config['env_config']['record_video_config'] = { 101 | 'frequency': real_video_frequency, 102 | 'directory': 
os.path.join(args.root_directory, args.video_directory) 103 | } 104 | 105 | stop = { 106 | "timesteps_total": max_training_steps, 107 | } 108 | 109 | trial_name_creator = lambda trial: f'baseline-{args.experiment_name}' 110 | 111 | result = tune.run( 112 | ConditionalActionImpalaTrainer, 113 | local_dir=args.root_directory, 114 | config=config, 115 | stop=stop, 116 | callbacks=[wandbLoggerCallback], 117 | trial_name_creator=trial_name_creator 118 | ) 119 | -------------------------------------------------------------------------------- /rllib_baseline_flat.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import gym 6 | import numpy as np 7 | import ray 8 | import torch 9 | from griddly.util.rllib.callbacks import VideoCallback 10 | from ray.rllib.models.torch.torch_modelv2 import TorchModelV2 11 | from torch import nn 12 | from gym.spaces import MultiDiscrete, Dict, Box 13 | from ray import tune 14 | from ray.rllib.agents.impala import ImpalaTrainer 15 | from ray.rllib.models import ModelCatalog 16 | from ray.tune.integration.wandb import WandbLoggerCallback 17 | from ray.tune.registry import register_env 18 | 19 | from griddly import gd 20 | from griddly.util.rllib.environment.core import RLlibEnv 21 | from griddly.util.rllib.torch.agents.common import layer_init 22 | 23 | parser = argparse.ArgumentParser(description='Run experiments') 24 | 25 | parser.add_argument('--yaml-file', help='YAML file containing GDY for the game') 26 | parser.add_argument('--experiment-name', default='unknown', help='name of the experiment') 27 | 28 | parser.add_argument('--root-directory', default=os.path.expanduser("~/ray_results"), 29 | help='root directory for all data associated with the run') 30 | parser.add_argument('--num-gpus', default=1, type=int, help='Number of GPUs to make available to ray.') 31 | parser.add_argument('--num-cpus', default=8, type=int, help='Number of CPUs to make 
available to ray.') 32 | 33 | parser.add_argument('--num-workers', default=7, type=int, help='Number of workers') 34 | parser.add_argument('--num-envs-per-worker', default=5, type=int, help='Number of environments per worker') 35 | parser.add_argument('--num-gpus-per-worker', default=0, type=float, help='Number of gpus per worker') 36 | parser.add_argument('--num-cpus-per-worker', default=1, type=float, help='Number of cpus per worker') 37 | parser.add_argument('--max-training-steps', default=20000000, type=int, help='Maximum number of training steps') 38 | 39 | parser.add_argument('--capture-video', action='store_true', help='enable video capture') 40 | parser.add_argument('--video-directory', default='videos', help='directory of video') 41 | parser.add_argument('--video-frequency', type=int, default=1000000, help='Frequency of videos') 42 | 43 | parser.add_argument('--seed', type=int, default=69420, help='seed for experiments') 44 | 45 | parser.add_argument('--lr', type=float, default=0.0005, help='learning rate') 46 | 47 | 48 | class FlatActionWrapper(gym.Wrapper): 49 | 50 | def __init__(self, env): 51 | super().__init__(env) 52 | 53 | self._num_action_parts = 1 54 | self._action_params_offset = 0 55 | if not self.has_avatar: 56 | self._num_action_parts += 1 57 | self._action_params_offset = 1 58 | 59 | self._action_splits = np.zeros(self._num_action_parts) 60 | 61 | self._total_position_params = 0 62 | if not self.has_avatar: 63 | self._action_splits[0] = self.width * self.height 64 | self._total_position_params += self.width * self.height 65 | 66 | self._action_logit_offsets = {} 67 | 68 | total_action_params = 0 69 | for i, action_name in enumerate(self.env.action_names): 70 | self._action_logit_offsets[action_name] = total_action_params + self._total_position_params 71 | total_action_params += self.num_action_ids[action_name] 72 | 73 | self._action_splits[self._action_params_offset] = total_action_params 74 | 75 | self._total_actions = int(np.sum(self._action_splits)) 76 | 77 | 
self.action_space = MultiDiscrete(self._action_splits) 78 | self.observation_space = Dict({ 79 | 'obs': self.observation_space, 80 | 'mask': Box(0, 1, shape=(self._total_actions,)), 81 | }) 82 | 83 | def _get_flat_mask(self): 84 | flat_mask = np.zeros(self._total_actions) 85 | for location, action_names in self.env.game.get_available_actions(1).items(): 86 | if not self.has_avatar: 87 | flat_location = self.width * location[1] + location[0] 88 | flat_mask[flat_location] = 1 89 | for action_name, action_ids in self.env.game.get_available_action_ids(location, list(action_names)).items(): 90 | mask_offset = self._action_logit_offsets[action_name] 91 | flat_mask[mask_offset:mask_offset + self.num_action_ids[action_name]][action_ids] = 1 92 | return flat_mask 93 | 94 | def _to_griddly_action(self, action): 95 | # convert the flat action back to Griddly's tree based format 96 | 97 | griddly_action = [] 98 | action_ptr = 0 99 | if not self.has_avatar: 100 | x = action[action_ptr] % self.width 101 | griddly_action.append(x) 102 | y = int(action[action_ptr] / self.width) 103 | griddly_action.append(y) 104 | action_ptr += 1 105 | 106 | if self.action_count > 0: 107 | action_type_id = 0 108 | action_param_id = 0 109 | for action_name in self.action_names: 110 | action_offset_after_position = (self._action_logit_offsets[action_name] - self._total_position_params) 111 | next_offset = action_offset_after_position + self.num_action_ids[action_name] 112 | if next_offset > action[action_ptr]: 113 | action_param_id = action[action_ptr] - action_offset_after_position 114 | break 115 | action_type_id += 1 116 | 117 | griddly_action.append(action_type_id) 118 | griddly_action.append(action_param_id) 119 | else: 120 | griddly_action.append(action[action_ptr]) 121 | 122 | return griddly_action 123 | 124 | def reset(self, **kwargs): 125 | 126 | obs = super().reset(**kwargs) 127 | 128 | observations = { 129 | 'obs': obs, 130 | 'mask': self._get_flat_mask() 131 | } 132 | 133 | return 
observations 134 | 135 | def step(self, action): 136 | griddly_action = self._to_griddly_action(action) 137 | 138 | # gym's step contract is (obs, reward, done, info); name the values accordingly 139 | obs, reward, done, info = super().step(griddly_action) 140 | 141 | observations = { 142 | 'obs': obs, 143 | 'mask': self._get_flat_mask() 144 | } 145 | 146 | return observations, reward, done, info 147 | 148 | 149 | class SimpleConvFlatAgent(TorchModelV2, nn.Module): 150 | 151 | def __init__(self, obs_space, action_space, num_outputs, model_config, name): 152 | super().__init__(obs_space, action_space, num_outputs, model_config, name) 153 | nn.Module.__init__(self) 154 | 155 | self._num_objects = obs_space.original_space['obs'].shape[2] 156 | self._num_actions = num_outputs 157 | 158 | linear_flatten = np.prod(obs_space.original_space['obs'].shape[:2]) * 64 159 | 160 | self.network = nn.Sequential( 161 | layer_init(nn.Conv2d(self._num_objects, 32, 3, padding=1)), 162 | nn.ReLU(), 163 | layer_init(nn.Conv2d(32, 64, 3, padding=1)), 164 | nn.ReLU(), 165 | nn.Flatten(), 166 | layer_init(nn.Linear(linear_flatten, 1024)), 167 | nn.ReLU(), 168 | layer_init(nn.Linear(1024, 512)), 169 | nn.ReLU(), 170 | ) 171 | 172 | self._actor_head = nn.Sequential( 173 | layer_init(nn.Linear(512, 256), std=0.01), 174 | nn.ReLU(), 175 | layer_init(nn.Linear(256, self._num_actions), std=0.01) 176 | ) 177 | 178 | self._critic_head = nn.Sequential( 179 | layer_init(nn.Linear(512, 1), std=0.01) 180 | ) 181 | 182 | def forward(self, input_dict, state, seq_lens): 183 | obs_transformed = input_dict['obs']['obs'].permute(0, 3, 1, 2) 184 | mask = input_dict['obs']['mask'] 185 | network_output = self.network(obs_transformed) 186 | value = self._critic_head(network_output) 187 | self._value = value.reshape(-1) 188 | logits = self._actor_head(network_output) 189 | 190 | logits += torch.maximum(torch.log(mask), torch.tensor(torch.finfo().min)) 191 | 192 | return logits, state 193 | 194 | def value_function(self): 195 | return self._value 196 | 197 | 198 | if __name__ == '__main__': 199 | 
199 | args = parser.parse_args() 200 | 201 | sep = os.pathsep 202 | os.environ['PYTHONPATH'] = sep.join(sys.path) 203 | 204 | ray.init(include_dashboard=False, num_gpus=args.num_gpus, num_cpus=args.num_cpus) 205 | env_name = "ray-griddly-env" 206 | 207 | 208 | def _create_env(env_config): 209 | env = RLlibEnv(env_config) 210 | return FlatActionWrapper(env) 211 | 212 | 213 | register_env(env_name, _create_env) 214 | ModelCatalog.register_custom_model("SimpleConv", SimpleConvFlatAgent) 215 | 216 | wandbLoggerCallback = WandbLoggerCallback( 217 | project='conditional_action_trees_reproduce', 218 | api_key_file='~/.wandb_rc', 219 | dir=args.root_directory 220 | ) 221 | 222 | max_training_steps = args.max_training_steps 223 | gdy_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.yaml_file) 224 | 225 | config = { 226 | 'framework': 'torch', 227 | 'seed': args.seed, 228 | 'num_workers': args.num_workers, 229 | 'num_envs_per_worker': args.num_envs_per_worker, 230 | 'num_gpus_per_worker': float(args.num_gpus_per_worker), 231 | 'num_cpus_per_worker': args.num_cpus_per_worker, 232 | 233 | 'callbacks': VideoCallback, 234 | 235 | 'model': { 236 | 'custom_model': 'SimpleConv', 237 | 'custom_model_config': {} 238 | }, 239 | 'env': env_name, 240 | 'env_config': { 241 | 'generate_valid_action_trees': False, 242 | 'random_level_on_reset': True, 243 | 'yaml_file': gdy_file, 244 | 'global_observer_type': gd.ObserverType.SPRITE_2D, 245 | 'max_steps': 1000, 246 | }, 247 | 'entropy_coeff_schedule': [ 248 | [0, 0.01], 249 | [max_training_steps, 0.0] 250 | ], 251 | 'lr_schedule': [ 252 | [0, args.lr], 253 | [max_training_steps, 0.0] 254 | ], 255 | 256 | } 257 | if args.capture_video: 258 | real_video_frequency = int(args.video_frequency / (args.num_envs_per_worker * args.num_workers)) 259 | config['env_config']['record_video_config'] = { 260 | 'frequency': real_video_frequency, 261 | 'directory': os.path.join(args.root_directory, args.video_directory) 262 | } 263 | 264 | stop = { 
265 | "timesteps_total": max_training_steps, 266 | } 267 | 268 | trial_name_creator = lambda trial: f'baseline-flat-{args.experiment_name}' 269 | 270 | result = tune.run( 271 | ImpalaTrainer, 272 | local_dir=args.root_directory, 273 | config=config, 274 | stop=stop, 275 | callbacks=[wandbLoggerCallback], 276 | trial_name_creator=trial_name_creator 277 | ) 278 | -------------------------------------------------------------------------------- /rllib_conditional_actions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import ray 5 | from ray import tune 6 | from ray.rllib.models import ModelCatalog 7 | from ray.tune.integration.wandb import WandbLoggerCallback 8 | from ray.tune.registry import register_env 9 | 10 | from griddly import gd 11 | from griddly.util.rllib.callbacks import VideoCallback 12 | from griddly.util.rllib.environment.core import RLlibEnv 13 | from griddly.util.rllib.torch.agents.conv_agent import SimpleConvAgent 14 | 15 | import argparse 16 | 17 | from conditional_action_trees.conditional_action_policy_trainer import ConditionalActionImpalaTrainer 18 | 19 | parser = argparse.ArgumentParser(description='Run experiments') 20 | 21 | parser.add_argument('--yaml-file', help='YAML file containing GDY for the game') 22 | parser.add_argument('--experiment-name', default='unknown', help='name of the experiment') 23 | 24 | parser.add_argument('--root-directory', default=os.path.expanduser("~/ray_results"), 25 | help='root directory for all data associated with the run') 26 | parser.add_argument('--num-gpus', default=1, type=int, help='Number of GPUs to make available to ray.') 27 | parser.add_argument('--num-cpus', default=8, type=int, help='Number of CPUs to make available to ray.') 28 | 29 | parser.add_argument('--num-workers', default=7, type=int, help='Number of workers') 30 | parser.add_argument('--num-envs-per-worker', default=5, type=int, help='Number of environments per worker') 31 | 
parser.add_argument('--num-gpus-per-worker', default=0, type=float, help='Number of gpus per worker') 32 | parser.add_argument('--num-cpus-per-worker', default=1, type=float, help='Number of cpus per worker') 33 | parser.add_argument('--max-training-steps', default=20000000, type=int, help='Maximum number of training steps') 34 | 35 | parser.add_argument('--capture-video', action='store_true', help='enable video capture') 36 | parser.add_argument('--video-directory', default='videos', help='directory of video') 37 | parser.add_argument('--video-frequency', type=int, default=1000000, help='Frequency of videos') 38 | 39 | parser.add_argument('--allow-nop', action='store_true', default=True, help='allow NOP actions in action tree') 40 | parser.add_argument('--vtrace-masking', action='store_true', default=True, help='use masks in vtrace calculations') 41 | 42 | parser.add_argument('--seed', type=int, default=69420, help='seed for experiments') 43 | 44 | parser.add_argument('--lr', type=float, default=0.0005, help='learning rate') 45 | 46 | if __name__ == '__main__': 47 | 48 | args = parser.parse_args() 49 | 50 | sep = os.pathsep 51 | os.environ['PYTHONPATH'] = sep.join(sys.path) 52 | 53 | ray.init(include_dashboard=False, num_gpus=args.num_gpus, num_cpus=args.num_cpus) 54 | #ray.init(include_dashboard=False, num_gpus=1, num_cpus=args.num_cpus, local_mode=True) 55 | 56 | env_name = "ray-griddly-env" 57 | 58 | register_env(env_name, RLlibEnv) 59 | ModelCatalog.register_custom_model("SimpleConv", SimpleConvAgent) 60 | 61 | wandbLoggerCallback = WandbLoggerCallback( 62 | project='conditional_action_trees_reproduce', 63 | api_key_file='~/.wandb_rc', 64 | dir=args.root_directory 65 | ) 66 | 67 | max_training_steps = args.max_training_steps 68 | gdy_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), args.yaml_file) 69 | 70 | config = { 71 | 'framework': 'torch', 72 | 'seed': args.seed, 73 | 'num_workers': args.num_workers, 74 | 'num_envs_per_worker': args.num_envs_per_worker, 
75 | 'num_gpus_per_worker': float(args.num_gpus_per_worker), 76 | 'num_cpus_per_worker': args.num_cpus_per_worker, 77 | 78 | 'callbacks': VideoCallback, 79 | 80 | 'model': { 81 | 'custom_model': 'SimpleConv', 82 | 'custom_model_config': {} 83 | }, 84 | 'env': env_name, 85 | 'env_config': { 86 | 87 | 'allow_nop': args.allow_nop, 88 | 'invalid_action_masking': tune.grid_search(['conditional', 'collapsed']), 89 | 'vtrace_masking': args.vtrace_masking, 90 | 'generate_valid_action_trees': True, 91 | #'level': 0, 92 | 'random_level_on_reset': True, 93 | 'yaml_file': gdy_file, 94 | 'global_observer_type': gd.ObserverType.SPRITE_2D, 95 | 'max_steps': 1000, 96 | }, 97 | 'entropy_coeff_schedule': [ 98 | [0, 0.01], 99 | [max_training_steps, 0.0] 100 | ], 101 | 'lr_schedule': [ 102 | [0, args.lr], 103 | [max_training_steps, 0.0] 104 | ], 105 | 106 | } 107 | 108 | if args.capture_video: 109 | real_video_frequency = int(args.video_frequency / (args.num_envs_per_worker * args.num_workers)) 110 | config['env_config']['record_video_config'] = { 111 | 'frequency': real_video_frequency, 112 | 'directory': os.path.join(args.root_directory, args.video_directory) 113 | } 114 | 115 | stop = { 116 | "timesteps_total": max_training_steps, 117 | } 118 | 119 | trial_name_creator = lambda trial: f'CAT-{args.experiment_name}-{trial.config["env_config"]["invalid_action_masking"]}' 120 | 121 | result = tune.run( 122 | ConditionalActionImpalaTrainer, 123 | local_dir=args.root_directory, 124 | config=config, 125 | stop=stop, 126 | callbacks=[wandbLoggerCallback], 127 | trial_name_creator=trial_name_creator 128 | ) 129 | --------------------------------------------------------------------------------
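A note on the invalid-action masking used throughout these scripts: `SimpleConvFlatAgent.forward` adds `torch.maximum(torch.log(mask), torch.finfo().min)` to the raw logits, so valid actions (mask = 1, log = 0) are untouched while invalid actions (mask = 0, log = -inf, clamped to the most negative representable float) receive effectively zero probability under softmax, without introducing NaNs into the gradient path. A minimal standalone sketch of that trick, with toy logits and mask invented for illustration (not code from the repo):

```python
import torch

def mask_logits(logits, mask):
    # log(1) = 0 leaves valid logits unchanged; log(0) = -inf is clamped to
    # finfo.min so invalid actions get ~zero probability but stay finite
    return logits + torch.maximum(torch.log(mask), torch.tensor(torch.finfo(logits.dtype).min))

logits = torch.zeros(4)                    # toy logits: uniform over 4 actions
mask = torch.tensor([1.0, 0.0, 1.0, 0.0])  # actions 1 and 3 are invalid
probs = torch.softmax(mask_logits(logits, mask), dim=-1)
# all probability mass is shared by the two valid actions
```

Clamping to `finfo.min` rather than using `-inf` directly avoids `inf - inf = NaN` pitfalls in later arithmetic while still driving the masked probabilities to zero after the softmax.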