├── enn_zoo └── README.md ├── rogue_net └── README.md ├── enn_ppo └── README.md ├── entity_gym └── README.md ├── .gitignore ├── configs ├── entity-gym │ ├── rock_paper_scissors.ron │ ├── not_hotdog.ron │ ├── count.ron │ ├── move_to_origin.ron │ ├── xor.ron │ ├── multi_armed_bandit.ron │ ├── cherry_pick.ron │ ├── floor_is_lava.ron │ ├── pick_matching_balls.ron │ ├── minefield.ron │ ├── multisnake_1snakes11len.ron │ ├── multisnake_2snakes11len.ron │ ├── minefield-relattn1.ron │ └── minefield-relattn2.ron ├── procgen │ └── bigfish.ron ├── codecraft │ ├── allied_wealth.ron │ ├── allied_wealth-relattn2.ron │ ├── allied_wealth-relattn1.ron │ ├── arena_tiny_2v2.ron │ └── arena_medium.ron └── xprun │ ├── trainbc.ron │ ├── train.ron │ └── traincc.ron ├── pyproject.toml ├── .github └── workflows │ ├── checks.yaml │ └── publish.yaml ├── .readthedocs.yaml ├── README.md ├── LICENSE-MIT ├── mypy.ini ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE-APACHE └── ARCHITECTURE.md /enn_zoo/README.md: -------------------------------------------------------------------------------- 1 | Moved to [https://github.com/entity-neural-network/enn-zoo](https://github.com/entity-neural-network/enn-zoo). 2 | -------------------------------------------------------------------------------- /rogue_net/README.md: -------------------------------------------------------------------------------- 1 | Moved to [https://github.com/entity-neural-network/rogue-net](https://github.com/entity-neural-network/rogue-net). 2 | -------------------------------------------------------------------------------- /enn_ppo/README.md: -------------------------------------------------------------------------------- 1 | Moved to [https://github.com/entity-neural-network/enn-trainer](https://github.com/entity-neural-network/enn-trainer). 
2 | -------------------------------------------------------------------------------- /entity_gym/README.md: -------------------------------------------------------------------------------- 1 | Moved to [https://github.com/entity-neural-network/entity-gym](https://github.com/entity-neural-network/entity-gym). 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | *.egg-info 3 | *.pyc 4 | wandb/ 5 | runs/ 6 | .dmypy.json 7 | __pycache__ 8 | 9 | # Pycharm 10 | .idea/ 11 | 12 | # Meld 13 | *.orig 14 | 15 | # Docs 16 | entity_gym/docs/build 17 | entity_gym/docs/source/generated 18 | entity_gym/docs/source/entity_gym 19 | -------------------------------------------------------------------------------- /configs/entity-gym/rock_paper_scissors.ron: -------------------------------------------------------------------------------- 1 | ExperimentConfig( 2 | version: 0, 3 | env: ( 4 | id: "RockPaperScissors", 5 | ), 6 | rollout: ( 7 | num_envs: 256, 8 | steps: 1, 9 | ), 10 | total_timesteps: 4000, 11 | net: ( 12 | d_model: 16, 13 | n_layer: 2, 14 | ), 15 | optim: ( 16 | bs: 256, 17 | lr: 0.03, 18 | ), 19 | ) -------------------------------------------------------------------------------- /configs/entity-gym/not_hotdog.ron: -------------------------------------------------------------------------------- 1 | // Achieves episodic return of 1.0: https://wandb.ai/entity-neural-network/enn-ppo/reports/NotHotdog--VmlldzoxNjI3MDcz 2 | ExperimentConfig( 3 | version: 0, 4 | env: ( 5 | id: "NotHotdog", 6 | ), 7 | rollout: ( 8 | num_envs: 256, 9 | steps: 1, 10 | ), 11 | total_timesteps: 32768, 12 | net: ( 13 | d_model: 16, 14 | n_layer: 1 15 | ), 16 | optim: ( 17 | bs: 256, 18 | lr: 0.003, 19 | ), 20 | ) -------------------------------------------------------------------------------- /configs/entity-gym/count.ron: 
-------------------------------------------------------------------------------- 1 | // Achieves episodic return of 1.0: https://wandb.ai/entity-neural-network/enn-ppo/reports/Count--VmlldzoxNjI2OTI4 2 | ExperimentConfig( 3 | version: 0, 4 | env: ( 5 | id: "Count", 6 | kwargs: "{\"masked_choices\": 2}" 7 | ), 8 | rollout: ( 9 | num_envs: 16, 10 | steps: 1, 11 | ), 12 | total_timesteps: 2000, 13 | net: ( 14 | d_model: 16, 15 | n_layer: 1 16 | ), 17 | optim: ( 18 | bs: 16, 19 | lr: 0.01 20 | ), 21 | ) -------------------------------------------------------------------------------- /configs/entity-gym/move_to_origin.ron: -------------------------------------------------------------------------------- 1 | // Achieves 0.99 mean episodic return most of the time: https://wandb.ai/entity-neural-network/enn-ppo/reports/MoveToOrigin--VmlldzoxNjI3MzA5 2 | ExperimentConfig( 3 | version: 0, 4 | env: ( 5 | id: "MoveToOrigin", 6 | ), 7 | rollout: ( 8 | num_envs: 64, 9 | processes: 16, 10 | steps: 32, 11 | ), 12 | total_timesteps: 1000000, 13 | net: ( 14 | d_model: 16, 15 | ), 16 | optim: ( 17 | lr: 0.003, 18 | bs: 2048, 19 | ), 20 | ) -------------------------------------------------------------------------------- /configs/entity-gym/xor.ron: -------------------------------------------------------------------------------- 1 | // Achieves 1.0 episodic return most of the time: https://wandb.ai/entity-neural-network/enn-ppo/reports/Xor--VmlldzoxOTI5NTQ0 2 | ExperimentConfig( 3 | version: 0, 4 | env: ( 5 | id: "Xor", 6 | ), 7 | rollout: ( 8 | num_envs: 2048, 9 | steps: 1, 10 | processes: 16, 11 | ), 12 | total_timesteps: 500000, 13 | net: ( 14 | n_layer: 2, 15 | d_model: 16 16 | ), 17 | optim: ( 18 | bs: 2048, 19 | lr: 0.003, 20 | ), 21 | ppo: ( 22 | ent_coef: 0.3, 23 | ), 24 | ) -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = 
"incubator" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Clemens Winter "] 6 | 7 | [tool.poetry.dependencies] 8 | python = ">=3.8,<3.10" 9 | enn_zoo = {path = "enn_zoo", develop = true} 10 | griddly = {version = "1.3.5"} 11 | gym-microrts = {version = "0.6.0"} 12 | numba = "^0.55.1" 13 | 14 | [tool.poetry.dev-dependencies] 15 | black = "22.3.0" 16 | pytest = "^6.2.5" 17 | mypy = "^0.910" 18 | ipdb = "^0.13.9" 19 | pre-commit = "^2.17.0" 20 | 21 | [build-system] 22 | requires = ["poetry-core>=1.0.0"] 23 | build-backend = "poetry.core.masonry.api" 24 | -------------------------------------------------------------------------------- /configs/entity-gym/multi_armed_bandit.ron: -------------------------------------------------------------------------------- 1 | // Achieves episodic return of 1.0: https://wandb.ai/entity-neural-network/enn-ppo/reports/MultiArmedBandit--VmlldzoxNjI3MDY4 https://wandb.ai/entity-neural-network/enn-ppo/reports/MultiArmedBandit--VmlldzoxNjIxMDYw 2 | ExperimentConfig( 3 | version: 0, 4 | env: ( 5 | id: "MultiArmedBandit", 6 | ), 7 | rollout: ( 8 | processes: 16, 9 | num_envs: 256, 10 | steps: 1, 11 | ), 12 | total_timesteps: 32768, 13 | net: ( 14 | d_model: 16, 15 | n_layer: 1 16 | ), 17 | optim: ( 18 | bs: 256, 19 | lr: 0.003, 20 | ), 21 | ) -------------------------------------------------------------------------------- /configs/entity-gym/cherry_pick.ron: -------------------------------------------------------------------------------- 1 | // Achieves > 0.99 mean episodic return most of the time: https://wandb.ai/entity-neural-network/enn-ppo/reports/CherryPick--VmlldzoxNjI3MzA0 2 | ExperimentConfig( 3 | version: 0, 4 | env: ( 5 | id: "CherryPick", 6 | kwargs: "{\"num_cherries\": 32}", 7 | ), 8 | rollout: ( 9 | num_envs: 128, 10 | steps: 16, 11 | processes: 16, 12 | ), 13 | total_timesteps: 250000, 14 | net: ( 15 | d_model: 128, 16 | n_layer: 2, 17 | ), 18 | optim: ( 19 | bs: 2048, 20 | lr: 0.001, 21 | ), 22 | ppo: ( 23 | 
ent_coef: 0.1, 24 | gamma: 0.99, 25 | vf_coef: 0.25, 26 | anneal_entropy: true, 27 | ), 28 | ) -------------------------------------------------------------------------------- /.github/workflows/checks.yaml: -------------------------------------------------------------------------------- 1 | name: Checks 2 | 3 | on: 4 | push: 5 | branches: [ '*' ] 6 | pull_request: 7 | branches: [ main ] 8 | jobs: 9 | pre-commit: 10 | name: pre-commit 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: [3.9, 3.8] 15 | 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | #---------------------------------------------- 23 | # ----- install & configure poetry ----- 24 | #---------------------------------------------- 25 | - name: Run tests 26 | run: | 27 | echo Success 28 | -------------------------------------------------------------------------------- /configs/entity-gym/floor_is_lava.ron: -------------------------------------------------------------------------------- 1 | // Achieves > 0.99 episodic return: https://wandb.ai/entity-neural-network/enn-ppo/reports/FloorIsLava--VmlldzoxNjM3OTU3 2 | ExperimentConfig( 3 | version: 0, 4 | env: ( 5 | id: "FloorIsLava", 6 | ), 7 | rollout: ( 8 | num_envs: 256, 9 | steps: 1, 10 | ), 11 | total_timesteps: 32768, 12 | net: ( 13 | d_model: 16, 14 | n_layer: 2, 15 | relpos_encoding: ( 16 | extent: [1, 1], 17 | position_features: ["x", "y"], 18 | per_entity_values: true, 19 | ), 20 | ), 21 | optim: ( 22 | bs: 256, 23 | lr: 0.01, 24 | ), 25 | ppo: PPOConfig( 26 | ent_coef: 1.0, 27 | anneal_entropy: true, 28 | ), 29 | ) -------------------------------------------------------------------------------- /configs/entity-gym/pick_matching_balls.ron: -------------------------------------------------------------------------------- 1 | // Achieves > 0.995 mean episodic return: 
https://wandb.ai/entity-neural-network/enn-ppo/reports/PickMatchingBalls--VmlldzoxNjIxMjQ5 2 | ExperimentConfig( 3 | version: 0, 4 | env: ( 5 | id: "PickMatchingBalls", 6 | kwargs: "{\"max_balls\": 32, \"one_hot\": true, \"randomize\": true}", 7 | ), 8 | rollout: ( 9 | num_envs: 64, 10 | steps: 16, 11 | processes: 16, 12 | ), 13 | total_timesteps: 500000, 14 | net: ( 15 | d_model: 128, 16 | n_layer: 4, 17 | d_qk: 32, 18 | ), 19 | optim: ( 20 | bs: 1024, 21 | lr: 0.001, 22 | ), 23 | ppo: ( 24 | ent_coef: 0.5, 25 | gamma: 0.99, 26 | anneal_entropy: true, 27 | ), 28 | ) -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-20.04 11 | tools: 12 | python: "3.9" 13 | # You can also specify other tool versions: 14 | # nodejs: "16" 15 | # rust: "1.55" 16 | # golang: "1.17" 17 | apt_packages: 18 | - graphviz 19 | 20 | # Build documentation in the docs/ directory with Sphinx 21 | sphinx: 22 | configuration: entity_gym/docs/source/conf.py 23 | 24 | python: 25 | # Install our python package before building the docs 26 | install: 27 | - requirements: entity_gym/docs/requirements.txt 28 | - method: pip 29 | path: entity_gym 30 | -------------------------------------------------------------------------------- /configs/entity-gym/minefield.ron: -------------------------------------------------------------------------------- 1 | // Achieves > 0.99 episodic return: https://wandb.ai/entity-neural-network/enn-ppo/reports/Minefield-Baselines--VmlldzoxNzAyOTE2 2 | ExperimentConfig( 3 | version: 0, 4 | env: ( 5 | id: "Minefield", 6 | kwargs: "{\"max_mines\": 10}" 7 | ), 8 | rollout: ( 9 | 
num_envs: 512, 10 | steps: 128, 11 | processes: 16, 12 | ), 13 | total_timesteps: 3000000, 14 | net: ( 15 | d_model: 32, 16 | n_layer: 2, 17 | translation: ( 18 | position_features: ["x_pos", "y_pos"], 19 | rotation_angle_feature: "direction", 20 | reference_entity: "Vehicle", 21 | ), 22 | ), 23 | optim: ( 24 | bs: 1024, 25 | lr: 0.001, 26 | ), 27 | ppo: ( 28 | ent_coef: 0.0001, 29 | gamma: 0.999, 30 | ), 31 | ) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ENN Incubator 2 | 3 | [![Discord](https://img.shields.io/discord/913497968701747270?style=flat-square)](https://discord.gg/SjVqhSW4Qf) 4 | 5 | The enn-incubator repo was used to develop a number of different projects which have since been split out into their own repos: 6 | 7 | - [entity-gym](https://github.com/entity-neural-network/entity-gym): Abstraction over reinforcement learning environments that represent observations as lists of structured objects. 8 | - [enn-trainer](https://github.com/entity-neural-network/enn-trainer): PPO training loop for entity-gym environments. 9 | - [rogue-net](https://github.com/entity-neural-network/rogue-net): Ragged batch transformer implementation that accepts variable length lists of structured objects as inputs. 10 | - [enn-zoo](https://github.com/entity-neural-network/enn-zoo): Collection of entity gym bindings for different environments. 
11 | 12 | -------------------------------------------------------------------------------- /configs/entity-gym/multisnake_1snakes11len.ron: -------------------------------------------------------------------------------- 1 | // Achieves > 0.99 episodic return: https://wandb.ai/entity-neural-network/enn-ppo/reports/MultiSnake--VmlldzoxNjM3OTYz 2 | ExperimentConfig( 3 | version: 0, 4 | env: ( 5 | id: "MultiSnake", 6 | kwargs: "{\"num_snakes\": 1, \"max_snake_length\": 11}", 7 | ), 8 | rollout: ( 9 | num_envs: 512, 10 | steps: 64, 11 | processes: 16, 12 | ), 13 | total_timesteps: 10000000, 14 | net: ( 15 | d_model: 32, 16 | n_layer: 2, 17 | relpos_encoding: ( 18 | extent: [10, 10], 19 | position_features: ["x", "y"], 20 | ), 21 | ), 22 | optim: ( 23 | bs: 32768, 24 | lr: 0.018, 25 | ), 26 | ppo: ( 27 | ent_coef: 0.02, // Higher? Just use default? 28 | gamma: 0.97, 29 | anneal_entropy: true, 30 | ), 31 | ) -------------------------------------------------------------------------------- /configs/procgen/bigfish.ron: -------------------------------------------------------------------------------- 1 | // Baselines for BigFish, BossFight, StarPilot, Leaper, Plunder on both easy and hard distribution mode: https://wandb.ai/entity-neural-network/enn-ppo/reports/Procgen-Baselines--VmlldzoxNzUxNDcy 2 | ExperimentConfig( 3 | version: 0, 4 | env: ( 5 | id: "Procgen:BigFish", 6 | ), 7 | rollout: ( 8 | num_envs: 64, 9 | steps: 256, 10 | processes: 16, 11 | ), 12 | total_timesteps: 25000000, 13 | net: ( 14 | d_model: 64, 15 | n_layer: 2, 16 | translation: ( 17 | position_features: ["x", "y"], 18 | reference_entity: "Player", 19 | add_dist_feature: true, 20 | ), 21 | ), 22 | optim: ( 23 | bs: 8192, 24 | micro_bs: 1024, 25 | lr: 0.01, 26 | ), 27 | ppo: ( 28 | ent_coef: 0.1, 29 | anneal_entropy: true, 30 | gamma: 0.999, 31 | ), 32 | ) -------------------------------------------------------------------------------- /configs/entity-gym/multisnake_2snakes11len.ron: 
-------------------------------------------------------------------------------- 1 | // Achieves ~0.99 episodic return: https://wandb.ai/entity-neural-network/enn-ppo/reports/MultiSnake-snakes-2-len-11--VmlldzoxODE1OTA0 2 | ExperimentConfig( 3 | version: 0, 4 | env: ( 5 | id: "MultiSnake", 6 | kwargs: "{\"num_snakes\": 2, \"max_snake_length\": 11}", 7 | ), 8 | rollout: ( 9 | num_envs: 256, 10 | steps: 128, 11 | processes: 16, 12 | ), 13 | total_timesteps: 100000000, 14 | net: ( 15 | d_model: 128, 16 | n_layer: 2, 17 | relpos_encoding: ( 18 | extent: [10, 10], 19 | position_features: ["x", "y"], 20 | ), 21 | ), 22 | optim: ( 23 | bs: 32768, 24 | lr: 0.005, 25 | max_grad_norm: 10, 26 | micro_bs: 4096, 27 | ), 28 | ppo: ( 29 | ent_coef: 0.15, 30 | gamma: 0.997, 31 | anneal_entropy: true, 32 | ), 33 | ) -------------------------------------------------------------------------------- /configs/entity-gym/minefield-relattn1.ron: -------------------------------------------------------------------------------- 1 | // Achieves > 0.99 episodic return: https://wandb.ai/entity-neural-network/enn-ppo/reports/Minefield-Baselines--VmlldzoxNzAyOTE2 2 | ExperimentConfig( 3 | version: 0, 4 | env: ( 5 | id: "Minefield", 6 | kwargs: "{\"max_mines\": 10}" 7 | ), 8 | rollout: ( 9 | num_envs: 512, 10 | steps: 128, 11 | processes: 16, 12 | ), 13 | total_timesteps: 3000000, 14 | net: ( 15 | d_model: 32, 16 | n_layer: 2, 17 | relpos_encoding: ( 18 | position_features: ["x_pos", "y_pos"], 19 | rotation_angle_feature: "direction", 20 | extent: [8, 2], 21 | interpolate: true, 22 | radial: true, 23 | distance: true, 24 | scale: 300.0, 25 | ), 26 | ), 27 | optim: ( 28 | bs: 1024, 29 | lr: 0.005, 30 | ), 31 | ppo: ( 32 | ent_coef: 0.0001, 33 | gamma: 0.999, 34 | ), 35 | ) -------------------------------------------------------------------------------- /configs/codecraft/allied_wealth.ron: -------------------------------------------------------------------------------- 1 | // 
https://wandb.ai/entity-neural-network/enn-ppo/reports/Allied-Wealth-baselines--VmlldzoxNzgyOTk1 2 | ExperimentConfig( 3 | version: 0, 4 | env: ( 5 | id: "CodeCraft", 6 | ), 7 | rollout: ( 8 | num_envs: 256, 9 | steps: 64, 10 | ), 11 | optim: ( 12 | max_grad_norm: 10, 13 | update_epochs: 3, 14 | lr: 0.00044, 15 | bs: 1024, 16 | anneal_lr: false, 17 | ), 18 | ppo: ( 19 | anneal_entropy: true, 20 | gamma: 0.9890351625500452, 21 | ent_coef: 1e-05, 22 | vf_coef: 3.7, 23 | ), 24 | net: ( 25 | translation: ( 26 | reference_entity: "ally", 27 | position_features: ["x", "y"], 28 | rotation_vec_features: ["orientation_x", "orientation_y"], 29 | add_dist_feature: true, 30 | ), 31 | d_model: 256, 32 | n_layer: 1, 33 | n_head: 2, 34 | ), 35 | total_timesteps: 2000000, 36 | ) -------------------------------------------------------------------------------- /configs/entity-gym/minefield-relattn2.ron: -------------------------------------------------------------------------------- 1 | // Achieves > 0.99 episodic return: https://wandb.ai/entity-neural-network/enn-ppo/reports/Minefield-Baselines--VmlldzoxNzAyOTE2 2 | ExperimentConfig( 3 | version: 0, 4 | env: ( 5 | id: "Minefield", 6 | kwargs: "{\"max_mines\": 10}" 7 | ), 8 | rollout: ( 9 | num_envs: 512, 10 | steps: 128, 11 | processes: 16, 12 | ), 13 | total_timesteps: 3000000, 14 | net: ( 15 | d_model: 32, 16 | n_layer: 2, 17 | relpos_encoding: ( 18 | position_features: ["x_pos", "y_pos"], 19 | rotation_angle_feature: "direction", 20 | extent: [8], 21 | interpolate: true, 22 | radial: true, 23 | key_relpos_projection: true, 24 | value_relpos_projection: true, 25 | ), 26 | ), 27 | optim: ( 28 | bs: 1024, 29 | lr: 0.005, 30 | ), 31 | ppo: ( 32 | ent_coef: 0.0001, 33 | gamma: 0.999, 34 | ), 35 | ) -------------------------------------------------------------------------------- /configs/codecraft/allied_wealth-relattn2.ron: -------------------------------------------------------------------------------- 1 | // 
https://wandb.ai/entity-neural-network/enn-ppo/reports/Allied-Wealth-baselines--VmlldzoxNzgyOTk1 2 | ExperimentConfig( 3 | version: 0, 4 | env: ( 5 | id: "CodeCraft", 6 | ), 7 | rollout: ( 8 | num_envs: 256, 9 | steps: 64, 10 | ), 11 | optim: ( 12 | max_grad_norm: 10, 13 | update_epochs: 3, 14 | lr: 0.00044, 15 | bs: 1024, 16 | anneal_lr: false, 17 | ), 18 | ppo: ( 19 | anneal_entropy: true, 20 | gamma: 0.9890351625500452, 21 | ent_coef: 1e-05, 22 | vf_coef: 3.7, 23 | ), 24 | net: ( 25 | d_model: 256, 26 | n_layer: 1, 27 | n_head: 2, 28 | relpos_encoding: ( 29 | extent: [8], 30 | position_features: ["x", "y"], 31 | key_relpos_projection: true, 32 | value_relpos_projection: true, 33 | rotation_vec_features: ["orientation_x", "orientation_y"], 34 | radial: true, 35 | ), 36 | ), 37 | total_timesteps: 2000000, 38 | ) -------------------------------------------------------------------------------- /configs/codecraft/allied_wealth-relattn1.ron: -------------------------------------------------------------------------------- 1 | // https://wandb.ai/entity-neural-network/enn-ppo/reports/Allied-Wealth-baselines--VmlldzoxNzgyOTk1 2 | ExperimentConfig( 3 | version: 0, 4 | env: ( 5 | id: "CodeCraft", 6 | ), 7 | rollout: ( 8 | num_envs: 256, 9 | steps: 64, 10 | ), 11 | optim: ( 12 | max_grad_norm: 10, 13 | update_epochs: 3, 14 | lr: 0.00044, 15 | bs: 1024, 16 | anneal_lr: false, 17 | ), 18 | ppo: ( 19 | anneal_entropy: true, 20 | gamma: 0.9890351625500452, 21 | ent_coef: 1e-05, 22 | vf_coef: 3.7, 23 | ), 24 | net: ( 25 | d_model: 256, 26 | n_layer: 1, 27 | n_head: 2, 28 | relpos_encoding: ( 29 | extent: [3, 2], 30 | position_features: ["x", "y"], 31 | radial: true, 32 | distance: true, 33 | rotation_vec_features: ["orientation_x", "orientation_y"], 34 | interpolate: true, 35 | scale: 1000.0, 36 | ), 37 | ), 38 | total_timesteps: 2000000, 39 | ) -------------------------------------------------------------------------------- /LICENSE-MIT: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Entity Neural Network developers 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | disallow_untyped_defs = True 3 | disallow_any_unimported = True 4 | no_implicit_optional = True 5 | check_untyped_defs = True 6 | warn_return_any = True 7 | show_error_codes = True 8 | warn_unused_ignores = True 9 | 10 | exclude = wandb 11 | 12 | [mypy-torch_scatter.*] 13 | ignore_missing_imports = True 14 | [mypy-wandb.*] 15 | ignore_missing_imports = True 16 | [mypy-msgpack.*] 17 | ignore_missing_imports = True 18 | [mypy-msgpack_numpy.*] 19 | ignore_missing_imports = True 20 | [mypy-tqdm.*] 21 | ignore_missing_imports = True 22 | [mypy-griddly.*] 23 | ignore_missing_imports = True 24 | [mypy-cloudpickle.*] 25 | ignore_missing_imports = True 26 | [mypy-optuna.*] 27 | ignore_missing_imports = True 28 | [mypy-orjson.*] 29 | ignore_missing_imports = True 30 | [mypy-gym_microrts.*] 31 | ignore_missing_imports = True 32 | [mypy-jpype.*] 33 | ignore_missing_imports = True 34 | [mypy-rts.*] 35 | ignore_missing_imports = True 36 | [mypy-ai.*] 37 | ignore_missing_imports = True 38 | [mypy-ts.*] 39 | ignore_missing_imports = True 40 | [mypy-PIL.*] 41 | ignore_missing_imports = True 42 | [mypy-web_pdb.*] 43 | ignore_missing_imports = True 44 | [mypy-procgen.*] 45 | ignore_missing_imports = True 46 | [mypy-opensimplex.*] 47 | ignore_missing_imports = True 48 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/asottile/pyupgrade 3 | rev: v2.31.1 4 | hooks: 5 | - id: pyupgrade 6 | args: 7 | - --py38-plus 8 | - repo: https://github.com/PyCQA/isort 9 | rev: 5.10.1 10 | hooks: 11 | - id: isort 12 | args: 13 | - --profile=black 14 | - --skip-glob=wandb/**/* 15 | - --thirdparty=wandb 16 | - repo: 
https://github.com/myint/autoflake 17 | rev: v1.4 18 | hooks: 19 | - id: autoflake 20 | args: 21 | - -r 22 | - --exclude=wandb 23 | - --in-place 24 | - --remove-unused-variables 25 | - --remove-all-unused-imports 26 | - repo: https://github.com/python/black 27 | rev: 22.3.0 28 | hooks: 29 | - id: black 30 | args: 31 | - --exclude=wandb 32 | - repo: https://github.com/codespell-project/codespell 33 | rev: v2.1.0 34 | hooks: 35 | - id: codespell 36 | args: 37 | - --ignore-words-list=nd,reacher,thist,ths,magent,crate 38 | - --skip=docs/css/termynal.css,docs/js/termynal.js 39 | - repo: local 40 | hooks: 41 | - id: mypy 42 | name: mypy 43 | entry: mypy 44 | language: system 45 | types: [python] 46 | -------------------------------------------------------------------------------- /configs/codecraft/arena_tiny_2v2.ron: -------------------------------------------------------------------------------- 1 | // Achieves ~0.2 eval/episodic_reward.mean: https://wandb.ai/entity-neural-network/enn-ppo/reports/Arena-Tiny-2v2-baseline--VmlldzoxNzgwMTQ3 2 | ExperimentConfig( 3 | env: ( 4 | id: "CodeCraft", 5 | kwargs: "{\"objective\": \"ARENA_TINY_2V2\"}", 6 | ), 7 | rollout: ( 8 | num_envs: 256, 9 | steps: 64, 10 | ), 11 | optim: ( 12 | max_grad_norm: 10, 13 | update_epochs: 3, 14 | lr: 0.005, 15 | bs: 1024, 16 | ), 17 | ppo: ( 18 | anneal_entropy: true, 19 | gamma: 0.99, 20 | ent_coef: 0.2, 21 | vf_coef: 2.0, 22 | ), 23 | net: ( 24 | d_model: 256, 25 | n_layer: 2, 26 | n_head: 2, 27 | relpos_encoding: ( 28 | extent: [8], 29 | position_features: ["x", "y"], 30 | key_relpos_projection: true, 31 | value_relpos_projection: true, 32 | rotation_vec_features: ["orientation_x", "orientation_y"], 33 | radial: true, 34 | interpolate: true, 35 | ), 36 | // Alternative: 37 | // relpos_encoding: ( 38 | // extent: [8, 2], 39 | // position_features: ["x", "y"], 40 | // rotation_vec_features: ["orientation_x", "orientation_y"], 41 | // radial: true, 42 | // distance: true, 43 | // interpolate: 
true, 44 | // scale: 1000.0, 45 | // ), 46 | ), 47 | total_timesteps: 10000000, 48 | eval: ( 49 | interval: 1000000, 50 | num_envs: 256, 51 | steps: 360, 52 | opponent: "/xprun/data/common/DeepCodeCraft/golden-models/arena_tiny_2v2/arena_tiny_2v2-e58ceea-0-25m", 53 | env: ( 54 | id: "CodeCraft", 55 | kwargs: "{\"objective\": \"ARENA_TINY_2V2\", \"stagger\": false}", 56 | ), 57 | ) 58 | ) -------------------------------------------------------------------------------- /configs/codecraft/arena_medium.ron: -------------------------------------------------------------------------------- 1 | // Achieves ~0.4 against eval opponent (old baseline: ~0.8-0.95): https://wandb.ai/entity-neural-network/enn-ppo/reports/Arena-Medium-baseline--VmlldzoxNzgwMTM1 2 | ExperimentConfig( 3 | version: 0, 4 | env: ( 5 | id: "CodeCraft", 6 | kwargs: "{\"objective\": \"ARENA_MEDIUM\", \"hardness\": 1.0, \"win_bonus\": 2.0, \"hidden_obs\": true}", 7 | ), 8 | rollout: ( 9 | num_envs: 128, 10 | steps: 64, 11 | ), 12 | optim: ( 13 | max_grad_norm: 10, 14 | update_epochs: 3, 15 | lr: 0.0005, 16 | bs: 4096, 17 | micro_bs: 1024, 18 | ), 19 | ppo: ( 20 | anneal_entropy: true, 21 | gamma: 0.999, 22 | ent_coef: 0.03, 23 | vf_coef: 2.0, 24 | ), 25 | net: ( 26 | d_model: 256, 27 | n_layer: 2, 28 | n_head: 2, 29 | relpos_encoding: ( 30 | extent: [8, 2], 31 | position_features: ["x", "y"], 32 | rotation_vec_features: ["orientation_x", "orientation_y"], 33 | radial: true, 34 | distance: true, 35 | interpolate: true, 36 | scale: 1000.0, 37 | per_entity_values: false, 38 | value_gate: "relu", 39 | ), 40 | // Alternative 41 | // relpos_encoding: ( 42 | // extent: [8], 43 | // position_features: ["x", "y"], 44 | // key_relpos_projection: true, 45 | // value_relpos_projection: true, 46 | // rotation_vec_features: ["orientation_x", "orientation_y"], 47 | // radial: true, 48 | // interpolate: true, 49 | // ), 50 | ), 51 | total_timesteps: 25000000, 52 | eval: ( 53 | interval: 500000, 54 | num_envs: 128, 55 | 
steps: 1000, 56 | opponent: "/xprun/data/common/DeepCodeCraft/golden-models/arena_medium/arena_medium-5f06842-0-10m", 57 | env: ( 58 | id: "CodeCraft", 59 | kwargs: "{\"objective\": \"ARENA_MEDIUM\", \"hardness\": 1.0, \"stagger\": false, \"symmetric\": 1.0, \"fair\": false}", 60 | ), 61 | ) 62 | ) 63 | -------------------------------------------------------------------------------- /.github/workflows/publish.yaml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | push: 5 | tags: 6 | - 'v*' 7 | workflow_dispatch: 8 | 9 | jobs: 10 | test: 11 | runs-on: ubuntu-latest 12 | defaults: 13 | run: 14 | working-directory: entity_gym 15 | strategy: 16 | matrix: 17 | python-version: [3.8] 18 | fail-fast: false 19 | 20 | environment: PyPI 21 | steps: 22 | - uses: actions/checkout@v2 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | #---------------------------------------------- 28 | # ----- install & configure poetry ----- 29 | #---------------------------------------------- 30 | - name: Install Poetry 31 | uses: snok/install-poetry@v1 32 | with: 33 | virtualenvs-create: true 34 | virtualenvs-in-project: true 35 | installer-parallel: true 36 | #---------------------------------------------- 37 | # load cached venv if cache exists 38 | #---------------------------------------------- 39 | - name: Load cached venv 40 | id: cached-poetry-dependencies 41 | uses: actions/cache@v2 42 | with: 43 | path: .venv 44 | key: venv-${{ runner.os }}-${{ hashFiles('poetry.lock') }} 45 | #---------------------------------------------- 46 | # install dependencies if cache does not exist 47 | #---------------------------------------------- 48 | - name: Install dependencies 49 | if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' 50 | continue-on-error: true 51 | run: poetry install --no-interaction 
--no-root 52 | #---------------------------------------------- 53 | # install your root project, if required 54 | #---------------------------------------------- 55 | - name: Install library 56 | run: poetry install --no-interaction 57 | 58 | - name: Publish entity-gym to PyPI 59 | run: | 60 | poetry build 61 | poetry publish --username __token__ --password ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- /configs/xprun/trainbc.ron: -------------------------------------------------------------------------------- 1 | XpV0( 2 | project: "enn", 3 | containers: { 4 | "main": ( 5 | command: ["poetry", "run", "python", "-u", "enn_zoo/enn_zoo/supervised.py"], 6 | build: [ 7 | From("nvcr.io/nvidia/pytorch:21.03-py3"), 8 | 9 | // Install Poetry 10 | Run("curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python -"), 11 | Env("PATH", "/root/.poetry/bin:${PATH}"), 12 | 13 | // Cache dependencies by installing them at fixed commit (to avoid long rebuilds when changing dependencies) 14 | Repo( 15 | paths: [ 16 | "pyproject.toml", 17 | "poetry.lock", 18 | "rogue_net/pyproject.toml", 19 | "rogue_net/poetry.lock", 20 | "rogue_net/rogue_net/__init__.py", 21 | "enn_ppo/pyproject.toml", 22 | "enn_ppo/poetry.lock", 23 | "enn_ppo/enn_ppo/__init__.py", 24 | "entity_gym/pyproject.toml", 25 | "entity_gym/poetry.lock", 26 | "entity_gym/entity_gym/__init__.py", 27 | "enn_zoo/pyproject.toml", 28 | "enn_zoo/poetry.lock", 29 | "enn_zoo/enn_zoo/__init__.py", 30 | ], 31 | target_dir: "/root/enn-incubator", 32 | cd: true, 33 | rev: "cf16b20", 34 | ), 35 | 36 | Run("poetry install"), 37 | Run("poetry run pip install setuptools==59.5.0"), 38 | Run("poetry run pip install torch==1.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html"), 39 | Run("poetry run pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html"), 40 | 41 | Repo( 42 | paths: [ 43 | 
"pyproject.toml", 44 | "poetry.lock", 45 | "rogue_net/pyproject.toml", 46 | "rogue_net/poetry.lock", 47 | "rogue_net/rogue_net/__init__.py", 48 | "enn_ppo/pyproject.toml", 49 | "enn_ppo/poetry.lock", 50 | "enn_ppo/enn_ppo/__init__.py", 51 | "entity_gym/pyproject.toml", 52 | "entity_gym/poetry.lock", 53 | "entity_gym/entity_gym/__init__.py", 54 | "enn_zoo/pyproject.toml", 55 | "enn_zoo/poetry.lock", 56 | "enn_zoo/enn_zoo/__init__.py", 57 | ], 58 | target_dir: "/root/enn-incubator", 59 | cd: true, 60 | ), 61 | Run("poetry install"), 62 | 63 | Repo(cd: true), 64 | 65 | ], 66 | gpu: 1, 67 | gpu_mem: "10GB", 68 | cpu_mem: "20GiB", 69 | env_secrets: { 70 | "WANDB_API_KEY": "wandb-api-key", 71 | }, 72 | ) 73 | } 74 | ) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to Entity Neural Networks 2 | 3 | 👍🎉 Thank you for taking the time to contribute! 🎉👍 4 | 5 | To get an overview of some of this project's motivation and goals, you can take a look at [Neural Network Architectures for Structured State](https://docs.google.com/document/d/1Q87zeY7Z4u9cU0oLoH-BPQZDBQd4tHLWiEkj5YDSGw4). 6 | Feel free to open an issue or pull request if you have any questions or suggestions. 7 | You can also [join our Discord](https://discord.gg/rrwSkmCp) and ask questions there. 8 | If you plan to work on an issue, let us know in the issue thread so we can avoid duplicate work. 9 | 10 | ## Dev Setup 11 | 12 | ```bash 13 | poetry install 14 | poetry run pip install setuptools==59.5.0 # Required to work around bug in torch (https://github.com/pytorch/pytorch/pull/57040). We can remove this step once we upgrade to torch >= 1.11.0. 
15 | poetry run pip install torch==1.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html 16 | poetry run pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html 17 | ``` 18 | 19 | Then you can run the scripts under the poetry environment in two ways: `poetry run` or `poetry shell`. 20 | 21 | * `poetry run`: 22 | If you prefix a command with `poetry run`, it will run in poetry's virtual environment. For example, try running 23 | ```bash 24 | poetry run python enn_ppo/enn_ppo/train.py 25 | ``` 26 | * `poetry shell`: 27 | First, activate poetry's virtual environment by executing `poetry shell`. Then, the name of poetry's 28 | virtual environment (e.g. `(incubator-EKBuw-J_-py3.9)`) should appear on the left side of your shell prompt. 29 | Afterwards, you can directly run 30 | ```bash 31 | python enn_ppo/enn_ppo/train.py 32 | ``` 33 | 34 | ### Common Build Problems 35 | 36 | `poetry` sometimes does not play nicely with `conda`. So make sure you run a fresh shell that does not activate any conda environments. If you try to run `poetry install` while a conda environment is active, you might encounter something like the following error: 37 | 38 | ``` 39 | Cargo, the Rust package manager, is not installed or is not on PATH. 40 | ``` 41 | 42 | This can be resolved by running `conda deactivate` first. 43 | 44 | If you are running into any other build issues, try the following recommended instructions.
45 | 46 | * Ubuntu/Debian/Mint: 47 | ```bash 48 | sudo apt install python3-dev make build-essential libssl-dev zlib1g-dev \ 49 | libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm \ 50 | libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev 51 | ``` 52 | 53 | ### Install Example 54 | 55 | [![asciicast](https://asciinema.org/a/452597.svg)](https://asciinema.org/a/452597) 56 | 57 | ## Code Style 58 | 59 | We use [Pre-commit](https://pre-commit.com/) to 60 | 62 | * format code using black (via `black`) 63 | * check word spelling (via `codespell`) 64 | * check typing (via `mypy`) 65 | 66 | You can run the following command to do these automatically: 67 | 68 | ```bash 69 | poetry run pre-commit run --all-files 70 | ``` 71 | 72 | ## Running Tests 73 | 74 | ```bash 75 | poetry run pytest . 76 | ``` 77 | 78 | ## Building docs 79 | 80 | To build the documentation for entity-gym, go to the entity_gym/docs folder and run the following command: 81 | 82 | ```bash 83 | poetry run make html 84 | ``` 85 | 86 | You can use [watchexec](https://github.com/watchexec/watchexec) to automatically rebuild the documentation on changes: 87 | 88 | ``` 89 | watchexec -w ../entity_gym -w source -i source/generated -i source/entity_gym -- poetry run make html 90 | ``` 91 | 92 | You can view the generated docs by opening `entity_gym/docs/build/html` in a browser.
93 | 94 | Some files won't be automatically cleaned up after the build, so you can manually clean up the build directory by running: 95 | 96 | ``` 97 | poetry run make clean && rm -rf source/generated source/entity_gym 98 | ``` 99 | -------------------------------------------------------------------------------- /configs/xprun/train.ron: -------------------------------------------------------------------------------- 1 | XpV0( 2 | project: "enn", 3 | containers: { 4 | "main": ( 5 | command: ["poetry", "run", "python", "-u", "enn_zoo/enn_zoo/train.py"], 6 | build: [ 7 | From("nvcr.io/nvidia/pytorch:21.03-py3"), 8 | 9 | // Install Poetry 10 | Run("curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python -"), 11 | Env("PATH", "/root/.poetry/bin:${PATH}"), 12 | 13 | // Install Vulkan drivers (required by Griddly) 14 | Run("wget -qO - http://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add -"), 15 | Run("wget -qO /etc/apt/sources.list.d/lunarg-vulkan-focal.list http://packages.lunarg.com/vulkan/lunarg-vulkan-focal.list"), 16 | Run("apt-get update"), 17 | Run("DEBIAN_FRONTEND=noninteractive apt-get install -y tzdata"), 18 | Run("apt-get install vulkan-sdk -y"), 19 | 20 | // Install Java (required by MicroRTS) 21 | Run("apt-get install -y --no-install-recommends software-properties-common"), 22 | Run("add-apt-repository -y ppa:openjdk-r/ppa"), 23 | Run("apt-get update"), 24 | Run("apt-get install -y openjdk-8-jdk"), 25 | Run("apt-get install -y openjdk-8-jre"), 26 | Run("update-alternatives --config java"), 27 | Run("update-alternatives --config javac"), 28 | 29 | // Install Rust toolchain 30 | Run("apt-get update"), 31 | Run("apt-get install curl build-essential --yes"), 32 | Run("curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y"), 33 | Env("PATH", "/root/.cargo/bin:${PATH}"), 34 | 35 | // Cache dependencies by installing them at fixed commit (to avoid long rebuilds when changing dependencies) 36 | 
Repo( 37 | paths: [ 38 | "pyproject.toml", 39 | "poetry.lock", 40 | "rogue_net/pyproject.toml", 41 | "rogue_net/poetry.lock", 42 | "rogue_net/rogue_net/__init__.py", 43 | "enn_ppo/pyproject.toml", 44 | "enn_ppo/poetry.lock", 45 | "enn_ppo/enn_ppo/__init__.py", 46 | "entity_gym/pyproject.toml", 47 | "entity_gym/poetry.lock", 48 | "entity_gym/entity_gym/__init__.py", 49 | "enn_zoo/pyproject.toml", 50 | "enn_zoo/poetry.lock", 51 | "enn_zoo/enn_zoo/__init__.py", 52 | ], 53 | target_dir: "/root/enn-incubator", 54 | cd: true, 55 | rev: "cf16b20", 56 | ), 57 | 58 | // Build xprun from source 59 | Repo(url: "git@github.com:cswinter/xprun.git", rev: "d8a58d8", target_dir: "/root"), 60 | Run("poetry run pip install maturin==0.12.6"), 61 | Run("poetry run maturin build --cargo-extra-args=--features=python --manifest-path=/root/xprun/Cargo.toml"), 62 | Run("poetry run pip install /root/xprun/target/wheels/xprun-0.1.4-cp38-cp38-manylinux_2_31_x86_64.whl"), 63 | 64 | Run("poetry install"), 65 | Run("poetry run pip install setuptools==59.5.0"), 66 | Run("poetry run pip install torch==1.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html"), 67 | Run("poetry run pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html"), 68 | 69 | Repo( 70 | paths: [ 71 | "pyproject.toml", 72 | "poetry.lock", 73 | "rogue_net/pyproject.toml", 74 | "rogue_net/poetry.lock", 75 | "rogue_net/rogue_net/__init__.py", 76 | "enn_ppo/pyproject.toml", 77 | "enn_ppo/poetry.lock", 78 | "enn_ppo/enn_ppo/__init__.py", 79 | "entity_gym/pyproject.toml", 80 | "entity_gym/poetry.lock", 81 | "entity_gym/entity_gym/__init__.py", 82 | "enn_zoo/pyproject.toml", 83 | "enn_zoo/poetry.lock", 84 | "enn_zoo/enn_zoo/__init__.py", 85 | ], 86 | target_dir: "/root/enn-incubator", 87 | cd: true, 88 | ), 89 | Run("poetry install"), 90 | 91 | Repo(cd: true), 92 | ], 93 | gpu: 1, 94 | gpu_mem: "10GiB", 95 | cpu_mem: "10GiB", 96 | env_secrets: { 97 | "WANDB_API_KEY": "wandb-api-key", 98 | }, 99 
| ) 100 | } 101 | ) 102 | -------------------------------------------------------------------------------- /configs/xprun/traincc.ron: -------------------------------------------------------------------------------- 1 | XpV0( 2 | project: "enn", 3 | containers: { 4 | "main": ( 5 | command: ["poetry", "run", "python", "-u", "enn_zoo/enn_zoo/train.py"], 6 | build: [ 7 | From("nvcr.io/nvidia/pytorch:21.03-py3"), 8 | 9 | // Install Poetry 10 | Run("curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python -"), 11 | Env("PATH", "/root/.poetry/bin:${PATH}"), 12 | 13 | // Install Vulkan drivers (required by Griddly) 14 | Run("wget -qO - http://packages.lunarg.com/lunarg-signing-key-pub.asc | apt-key add -"), 15 | Run("wget -qO /etc/apt/sources.list.d/lunarg-vulkan-focal.list http://packages.lunarg.com/vulkan/lunarg-vulkan-focal.list"), 16 | Run("apt-get update"), 17 | Run("DEBIAN_FRONTEND=noninteractive apt-get install -y tzdata"), 18 | Run("apt-get install vulkan-sdk -y"), 19 | 20 | // Install Java (required by MicroRTS) 21 | Run("apt-get install -y --no-install-recommends software-properties-common"), 22 | Run("add-apt-repository -y ppa:openjdk-r/ppa"), 23 | Run("apt-get update"), 24 | Run("apt-get install -y openjdk-8-jdk"), 25 | Run("apt-get install -y openjdk-8-jre"), 26 | Run("update-alternatives --config java"), 27 | Run("update-alternatives --config javac"), 28 | 29 | // Install Rust toolchain 30 | Run("apt-get update"), 31 | Run("apt-get install curl build-essential --yes"), 32 | Run("curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y"), 33 | Env("PATH", "/root/.cargo/bin:${PATH}"), 34 | 35 | // Cache dependencies by installing them at fixed commit (to avoid long rebuilds when changing dependencies) 36 | Repo( 37 | paths: [ 38 | "pyproject.toml", 39 | "poetry.lock", 40 | "rogue_net/pyproject.toml", 41 | "rogue_net/poetry.lock", 42 | "rogue_net/rogue_net/__init__.py", 43 | "enn_ppo/pyproject.toml", 44 | 
"enn_ppo/poetry.lock", 45 | "enn_ppo/enn_ppo/__init__.py", 46 | "entity_gym/pyproject.toml", 47 | "entity_gym/poetry.lock", 48 | "entity_gym/entity_gym/__init__.py", 49 | "enn_zoo/pyproject.toml", 50 | "enn_zoo/poetry.lock", 51 | "enn_zoo/enn_zoo/__init__.py", 52 | ], 53 | target_dir: "/root/enn-incubator", 54 | cd: true, 55 | rev: "cf16b20", 56 | ), 57 | 58 | // Build xprun from source 59 | Repo(url: "git@github.com:cswinter/xprun.git", rev: "d8a58d8", target_dir: "/root"), 60 | Run("poetry run pip install maturin==0.12.6"), 61 | Run("poetry run maturin build --cargo-extra-args=--features=python --manifest-path=/root/xprun/Cargo.toml"), 62 | Run("poetry run pip install /root/xprun/target/wheels/xprun-0.1.4-cp38-cp38-manylinux_2_31_x86_64.whl"), 63 | 64 | Run("poetry install"), 65 | Run("poetry run pip install setuptools==59.5.0"), 66 | Run("poetry run pip install torch==1.10.2+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html"), 67 | Run("poetry run pip install torch-scatter -f https://data.pyg.org/whl/torch-1.10.0+cu113.html"), 68 | 69 | Repo( 70 | paths: [ 71 | "pyproject.toml", 72 | "poetry.lock", 73 | "rogue_net/pyproject.toml", 74 | "rogue_net/poetry.lock", 75 | "rogue_net/rogue_net/__init__.py", 76 | "enn_ppo/pyproject.toml", 77 | "enn_ppo/poetry.lock", 78 | "enn_ppo/enn_ppo/__init__.py", 79 | "entity_gym/pyproject.toml", 80 | "entity_gym/poetry.lock", 81 | "entity_gym/entity_gym/__init__.py", 82 | "enn_zoo/pyproject.toml", 83 | "enn_zoo/poetry.lock", 84 | "enn_zoo/enn_zoo/__init__.py", 85 | ], 86 | target_dir: "/root/enn-incubator", 87 | cd: true, 88 | ), 89 | Run("poetry install"), 90 | 91 | Repo(cd: true), 92 | ], 93 | gpu: 1, 94 | gpu_mem: "10GB", 95 | cpu_mem: "4GiB", 96 | env_secrets: { 97 | "WANDB_API_KEY": "wandb-api-key", 98 | }, 99 | ), 100 | 101 | "codecraftserver": ( 102 | command: ["server-0.1.0-SNAPSHOT/bin/server", "-Dplay.http.secret.key=ad31779d4ee49d5ad5162bf1429c32e2e9933f3b"], 103 | cpu: 4, 104 | cpu_mem: "12GiB", 105 | 
tty: true, 106 | env: { 107 | "SBT_OPTS": "-Xmx10G", 108 | }, 109 | build: [ 110 | From("hseeberger/scala-sbt:8u222_1.3.5_2.13.1"), 111 | 112 | // build fixed versions of CodeCraftGame and CodeCraftServer as a straightforward way to download sbt 0.13.16 and populate dependency cache 113 | Repo(url: "https://github.com/cswinter/CodeCraftGame.git", rev: "92304eb", cd: true, rm: true), 114 | Run("sbt publishLocal"), 115 | Repo(url: "https://github.com/cswinter/CodeCraftServer.git", rev: "df76892", cd: true, rm: true), 116 | Run("sbt compile"), 117 | 118 | // build CodeCraftGame and CodeCraftServer from source 119 | Repo(url: "https://github.com/cswinter/CodeCraftGame.git", rev: "edc5a9f2", cd: true), 120 | Run("sbt publishLocal"), 121 | Repo(url: "https://github.com/cswinter/CodeCraftServer.git", rev: "302a379", cd: true), 122 | Run("sbt dist"), 123 | Run("unzip server/target/universal/server-0.1.0-SNAPSHOT.zip"), 124 | ], 125 | ), 126 | } 127 | ) 128 | -------------------------------------------------------------------------------- /LICENSE-APACHE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. -------------------------------------------------------------------------------- /ARCHITECTURE.md: -------------------------------------------------------------------------------- 1 | # Internals 2 | 3 | This document follows a batch of observations from the MineSweeper environment through the internals of entity-gym, enn-ppo, and rogue-net in excruciating detail. 4 | 5 | 6 | ## High-level overview 7 | 8 | - The [Environment](#environment--observation) (entity-gym) provides a high-level abstraction for an environment.
9 | - The [VecEnv](#vecenv--vecobs) (entity-gym) combines multiple environments and exposes a more efficient and lower-level batched representation of observations/actions. 10 | - The [PPO training loop](#enn_ppotrainpy) (enn-ppo) keeps a sample buffer that combines observations from multiple steps. 11 | - The policy is implemented by [RogueNet](#roguenet) (rogue-net), a ragged batch transformer that takes lists of entities as input and outputs corresponding lists of actions. 12 | 13 | ## MineSweeper / State 14 | 15 | Initial state of the three environments: 16 | 17 | ![](https://user-images.githubusercontent.com/12845088/152281730-d42a9ffe-b844-48c5-b6ff-de1ceecdb2f8.png) 18 | 19 |
20 | Environment State (click to expand) 21 | 22 | ```python 23 | # Environment 1 24 | mines = [(0, 2), (0, 1), (2, 2), (0, 0), (1, 0)] 25 | robots = [(1, 1)] 26 | orbital_cannon_cooldown = 5 27 | orbital_cannon = False 28 | # Environment 2 29 | mines = [(2, 1)] 30 | robots = [(2, 0)] 31 | orbital_cannon_cooldown = 0 32 | orbital_cannon = True 33 | # Environment 3 34 | mines = [(1, 0), (0, 1), (2, 2)] 35 | robots = [(0, 0), (2, 0)] 36 | orbital_cannon_cooldown = 5 37 | orbital_cannon = False 38 | ``` 39 |
40 | 41 | ## Environment / Observation 42 | 43 | The MineSweeper class implements `Environment`, which provides a high-level abstraction for an environment. 44 | `Environment`s expose their state as an `Observation` object, which contains a dictionary with the `features` of each entity, a list of `ids` to make it possible to reference specific entities, and a dictionary of `actions` that determines which entities can perform which actions. 45 | 46 |
47 | Observation #1 (click to expand) 48 | 49 | ```python 50 | Observation( 51 | features={ 52 | "Mine": [[0, 2], [0, 1], [2, 2], [0, 0], [1, 0]], 53 | "Robot": [[1, 1]] 54 | }, 55 | ids={ 56 | "Mine": [("Mine", 0), ("Mine", 1), ("Mine", 2), ("Mine", 3), ("Mine", 4)], 57 | "Robot": [("Robot", 0)], 58 | }, 59 | actions={ 60 | "Move": CategoricalActionMask( 61 | actor_ids=None, 62 | actor_types=["Robot"], 63 | mask=[[True, True, True, True, True]], 64 | ), 65 | "Fire Orbital Cannon": SelectEntityActionMask( 66 | actor_ids=None, 67 | actor_types=[], 68 | actee_types=["Mine", "Robot"], 69 | actee_ids=None, 70 | mask=None, 71 | ), 72 | }, 73 | done=False, 74 | reward=0.0, 75 | end_of_episode_info=None, 76 | ) 77 | ``` 78 |
79 | 80 |
81 | Observation #2 (click to expand) 82 | 83 | ```python 84 | Observation( 85 | features={ 86 | "Mine": [[2, 1]], 87 | "Robot": [[2, 0]], "Orbital Cannon": [[0]] 88 | }, 89 | actions={ 90 | "Move": CategoricalActionMask( 91 | actor_ids=None, 92 | actor_types=["Robot"], 93 | mask=[[False, True, True, False, True]], 94 | ), 95 | "Fire Orbital Cannon": SelectEntityActionMask( 96 | actor_ids=None, 97 | actor_types=["Orbital Cannon"], 98 | actee_types=["Mine", "Robot"], 99 | actee_ids=None, 100 | mask=None, 101 | ), 102 | }, 103 | done=False, 104 | reward=0.0, 105 | ids={ 106 | "Mine": [("Mine", 0)], 107 | "Robot": [("Robot", 0)], 108 | "Orbital Cannon": [("Orbital Cannon", 0)], 109 | }, 110 | end_of_episode_info=None, 111 | ) 112 | ``` 113 |
114 | 115 |
116 | Observation #3 (click to expand) 117 | 118 | ```python 119 | Observation( 120 | features={ 121 | "Mine": [[1, 0], [0, 1], [2, 2]], 122 | "Robot": [[0, 0], [2, 0]] 123 | }, 124 | actions={ 125 | "Move": CategoricalActionMask( 126 | actor_ids=None, 127 | actor_types=["Robot"], 128 | mask=[ 129 | [True, False, True, False, True], 130 | [False, True, True, False, True], 131 | ], 132 | ), 133 | "Fire Orbital Cannon": SelectEntityActionMask( 134 | actor_ids=None, 135 | actor_types=[], 136 | actee_types=["Mine", "Robot"], 137 | actee_ids=None, 138 | mask=None, 139 | ), 140 | }, 141 | done=False, 142 | reward=0.0, 143 | ids={ 144 | "Mine": [("Mine", 0), ("Mine", 1), ("Mine", 2)], 145 | "Robot": [("Robot", 0), ("Robot", 1)], 146 | }, 147 | end_of_episode_info=None, 148 | ) 149 | ``` 150 |
151 | 152 | ## VecEnv / VecObs 153 | 154 | The `ListEnv` is an implementation of `VecEnv` that aggregates the observations from multiple environments into a more efficient and lower-level batched representation: 155 | - Features of each entity type from all environments are combined into a single `RaggedBufferF32` 156 | - Action masks for each action type from all environments are combined into a single `RaggedBufferBool` 157 | - Instead of specifying the `actors` and `actees` of each action using `EntityID`s, we use the corresponding integer indices. The index of an entity is defined as follows: 158 | - The `entities` field of the `ObsSpace` specified by an `Environment` defines an ordering of the entity types. 159 | - In this case, the entity types are ordered as `["Mine", "Robot", "Orbital Cannon"]`. 160 | - We now go through all entity types in this order and sequentially assign an index to each entity. 161 | - For example, if there are three entities with `ids = {"Robot": [("Robot", 0)], "Mine": [("Mine", 0), ("Mine", 1)]}`, then the index of `("Mine", 0)` is `0`, the index of `("Mine", 1)` is `1`, and the index of `("Robot", 0)` is `2`. 162 | 163 |
164 | VecObs (click to expand) 165 | 166 | ```python 167 | VecObs( 168 | features={ 169 | "Mine": RaggedBufferF32( 170 | [ 171 | [[0, 2], [0, 1], [2, 2], [0, 0], [1, 0]], 172 | [[2, 1]], 173 | [[1, 0], [0, 1], [2, 2]], 174 | ] 175 | ), 176 | "Robot": RaggedBufferF32( 177 | [ 178 | [[1, 1]], 179 | [[2, 0]], 180 | [[0, 0], [2, 0]], 181 | ] 182 | ), 183 | "Orbital Cannon": RaggedBufferF32( 184 | [ 185 | [], 186 | [[0.0]], 187 | [], 188 | ] 189 | ), 190 | }, 191 | action_masks={ 192 | "Move": VecCategoricalActionMask( 193 | actors=RaggedBufferI64( 194 | [ 195 | [[5]], 196 | [[1]], 197 | [[3], [4]], 198 | ] 199 | ), 200 | mask=RaggedBufferBool( 201 | [ 202 | [[true, true, true, true, true]], 203 | [[false, true, true, false, true]], 204 | [ 205 | [true, false, true, false, true], 206 | [false, true, true, false, true], 207 | ], 208 | ] 209 | ), 210 | ), 211 | "Fire Orbital Cannon": VecSelectEntityActionMask( 212 | actors=RaggedBufferI64( 213 | [ 214 | [], 215 | [[2]], 216 | [], 217 | ] 218 | ), 219 | actees=RaggedBufferI64( 220 | [ 221 | [], 222 | [[0], [1]], 223 | [], 224 | ] 225 | ), 226 | ), 227 | }, 228 | reward=array([0.0, 0.0, 0.0], dtype=float32), 229 | done=array([False, False, False]), 230 | end_of_episode_info={}, 231 | ) 232 | ``` 233 |
234 | 235 | ## enn_ppo/train.py 236 | 237 | The PPO implementation in `enn_ppo/train.py` accumulates the `VecObs` from multiple steps into sample buffers. 238 | These are later shuffled and split up into minibatches during the optimization phase. 239 | In this case, we are just looking at a single rollout step and the batch of observations is forwarded unmodified to the policy to sample actions. 240 | 241 | ## RogueNet 242 | 243 | The core of the policy is `RogueNet`, a ragged batch transformer implementation that takes in a ragged batch of observations and actor/actee/masks for each action, and outputs a ragged batch of actions and log-probabilities. 244 | 245 | ### Embedding 246 | 247 | The first step is to apply a projection to the features of each entity type to yield embeddings of the same size. 248 | All embeddings are then concatenated into a single tensor which is ordered first by environment and then by entity index: 249 | 250 |
251 | Embedding Tensor (click to expand) 252 | 253 | ```python 254 | tensor([ 255 | # Environment 1 256 | [ 1.5280, -0.7984, 0.8672, -0.7984, -0.7984], # Mine 0 257 | [ 0.6134, -0.7676, 1.6895, -0.7676, -0.7676], # Mine 1 258 | [ 0.1566, -0.8506, 1.8400, -0.2497, -0.8963], # Mine 2 259 | [-0.8081, -0.7904, 1.4962, 0.9104, -0.8081], # Mine 3 260 | [-0.9405, -0.5402, 1.2698, 1.1515, -0.9405], # Mine 4 261 | [ 1.8806, 0.1884, -0.6897, -0.6897, -0.6897], # Robot 0 262 | # Environment 2 263 | [-0.8848, -0.5453, 1.6356, 0.6792, -0.8848], # Mine 0 264 | [ 1.3690, 1.0691, -0.8127, -0.8127, -0.8127], # Robot 0 265 | [-0.8059, 1.5626, -0.7685, -0.8059, 0.8175], # Orbital Cannon 0 266 | # Environment 3 267 | [-0.9405, -0.5402, 1.2698, 1.1515, -0.9405], # Mine 0 268 | [ 0.6134, -0.7676, 1.6895, -0.7676, -0.7676], # Mine 1 269 | [ 0.1566, -0.8506, 1.8400, -0.2497, -0.8963], # Mine 2 270 | [ 1.4806, 0.9317, -0.8041, -0.8041, -0.8041], # Robot 0 271 | [ 1.3690, 1.0691, -0.8127, -0.8127, -0.8127], # Robot 1 272 | ], device='cuda:0') 273 | ``` 274 |
275 | 276 | ### Attention 277 | 278 | Most of the transformer layers are applied independently to each entity. 279 | However, the attention operation is applied to sequences of entities from the same timestep/environment. 280 | It is currently implemented by packing/padding the flattened embeddings into a (sequence, entity, feature) tensor that places all entities from the same timestep/environment into the same sequence. 281 | To do this, we compute three tensors: 282 | - the `index` determines which entity is placed at each position of the packed tensor 283 | - the `batch` tells us what timestep/environment each entity came from, and is used to construct a mask that prevents attention from going across separate timesteps/environments 284 | - the `inverse_index` is used to reconstruct the original flattened embedding tensor from the packed tensor 285 | 286 |
287 | Packing/padding metadata (click to expand) 288 | 289 | ```python 290 | index = [ 291 | [ 0, 1, 2, 3, 4, 5], 292 | [ 6, 7, 8, 0, 0, 0], 293 | [ 9, 10, 11, 12, 13, 0], 294 | ] 295 | batch = [ 296 | [ 0., 0., 0., 0., 0., 0.], 297 | [ 1., 1., 1., nan, nan, nan], 298 | [ 2., 2., 2., 2., 2., nan], 299 | ] 300 | inverse_index = [ 301 | 0, 1, 2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16 302 | ] 303 | ``` 304 |
305 | 306 | ![](https://user-images.githubusercontent.com/12845088/147727605-d904ffff-42b4-4c51-9088-7ab32f9d481a.png) 307 | 308 | 309 | ### Categorical Action Head 310 | 311 | Once the embeddings have passed through all layers, we can compute the action heads for each entity. 312 | Recall that we have a ragged list of indices of each actor. 313 | However, the indices are only unique per environment, and we still need to add a ragged buffer of offsets to get a set of indices that is sequential over all environments and corresponds to the flattened embedding tensor. 314 | The corrected indices are then used to index into the flattened embedding tensor to get the embedding for each actor. 315 | We project the resulting embeddings onto the number of choices for each action to get a tensor of logits, and finally sample from the logits to get the action. 316 | 317 |
318 | "Move" action actors, offsets, indices, actions (click to expand) 319 | 320 | ```python 321 | actors = RaggedBufferI64([ 322 | [[5]], 323 | [[1]], 324 | [[3], [4]], 325 | ]) 326 | offsets = RaggedBuffer([ 327 | [[0]], 328 | [[6]], 329 | [[9]], 330 | ]) 331 | actors + offsets = RaggedBufferI64([ 332 | [[5]], 333 | [[7]], 334 | [[12], [13]], 335 | ]) 336 | indices = tensor([5, 7, 12, 13], dtype=int64) 337 | # TODO: logits? 338 | actions = tensor([4, 1, 4, 2], dtype=int64) 339 | ragged_actions = RaggedBufferI64([ 340 | [[4]], 341 | [[1]], 342 | [[4], [2]], 343 | ]) 344 | ``` 345 |
346 | 347 | ### Select Entity Action Head 348 | 349 | The "Fire Orbital Cannon" action is a little more tricky. It is a SelectEntityAction, which means that it does not have a fixed number of choices, but the number of choices instead depends on the number of selectable entities in each environment. 350 | But at the end, we again get a list of indices corresponding to the entity selected by each actor. 351 | 352 | 353 | ![](https://user-images.githubusercontent.com/12845088/145058088-ae42f5f5-2782-4247-bcf5-8270a14e3510.png) 354 | 355 | 356 | ## Actions 357 | 358 | Now, the actions computed by the model travel back to the environments. 359 | The `ListEnv` receives ragged buffers for each action which represent the chosen action in the case of categorical actions, or the selected entity in the case of select entity actions. 360 | 361 |
362 | Ragged Actions (click to expand) 363 | 364 | ```python 365 | actions = { 366 | 'Fire Orbital Cannon': RaggedBuffer([ 367 | [], 368 | [[0]], 369 | [], 370 | ]), 371 | 'Move': RaggedBuffer([ 372 | [[4]], 373 | [[1]], 374 | [[4], [2]], 375 | ]), 376 | } 377 | ``` 378 |
379 | 380 | The actions are split up along the environment axis, joined with the list of actors from the initial `Observation`s, and actor indices are replaced with the corresponding `EntityID`s. 381 | The resulting `Action` objects are dispatched to the `act` methods of the individual environments. 382 | 383 |
384 | Actions (click to expand) 385 | 386 | ```python 387 | # Environment 1 388 | { 389 | 'Fire Orbital Cannon': SelectEntityAction( 390 | actors=[], 391 | actees=[], 392 | ), 393 | 'Move': CategoricalAction( 394 | actors=[('Robot', 0)], 395 | actions=array([4]), 396 | ), 397 | } 398 | # Environment 2 399 | { 400 | 'Fire Orbital Cannon': SelectEntityAction( 401 | actors=[('Orbital Cannon', 0)], 402 | actees=[('Mine', 0)], 403 | ), 404 | 'Move': CategoricalAction( 405 | actors=[('Robot', 0)], 406 | actions=array([1]), 407 | ), 408 | } 409 | # Environment 3 410 | { 411 | 'Fire Orbital Cannon': SelectEntityAction( 412 | actors=[], 413 | actees=[], 414 | ), 415 | 'Move': CategoricalAction( 416 | actors=[('Robot', 0), ('Robot', 1)], 417 | actions=array([4, 2]), 418 | ), 419 | } 420 | ``` 421 |
--------------------------------------------------------------------------------