├── .github └── workflows │ └── python-package.yml ├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── README.md ├── docs ├── ATARI_ENVPOOL.md ├── BRAX.md ├── CONFIG_PARAMS.md ├── DEEPMIND_ENVPOOL.md ├── HOW_TO_RL_GAMES.md ├── ISAAC_GYM.md ├── MUJOCO_ENVPOOL.md ├── OTHER.md ├── SMAC.md └── pictures │ ├── atari_envpool │ ├── breakout.jpg │ ├── breakout_envpool.png │ ├── pong.jpg │ └── pong_envpool.png │ ├── brax │ ├── brax_ant.jpg │ ├── brax_humanoid.jpg │ ├── brax_ur5e.jpg │ ├── humanoid.gif │ └── ur5e.gif │ ├── dqn_vs_dddqn.png │ ├── mario_random_stages.png │ ├── mujoco │ ├── half_cheetah.jpg │ ├── hopper.jpg │ ├── humanoid.jpg │ ├── mujoco_ant_envpool.png │ ├── mujoco_halfcheetah_envpool.png │ ├── mujoco_hopper_envpool.png │ ├── mujoco_humanoid_envpool.png │ ├── mujoco_walker2d_envpool.png │ └── walker.jpg │ ├── pong_dqn.png │ ├── rainbow_dqn_breakout.jpg │ └── smac │ ├── 2m_vs_1z.png │ ├── 3s5z_vs_3s6z.png │ ├── 3s_vs_5z.png │ ├── 5m_vs_6m.png │ ├── MMM2.png │ ├── corridor.png │ └── mmm2.gif ├── notebooks ├── brax_training.ipynb ├── brax_visualization.ipynb ├── mujoco_envpool_training.ipynb ├── train_and_export_onnx_example_continuous.ipynb ├── train_and_export_onnx_example_discrete.ipynb └── train_and_export_onnx_example_lstm_continuous.ipynb ├── poetry.lock ├── pyproject.toml ├── rl_games ├── __init__.py ├── algos_torch │ ├── __init__.py │ ├── a2c_continuous.py │ ├── a2c_discrete.py │ ├── central_value.py │ ├── d2rl.py │ ├── flatten.py │ ├── layers.py │ ├── model_builder.py │ ├── models.py │ ├── moving_mean_std.py │ ├── network_builder.py │ ├── players.py │ ├── running_mean_std.py │ ├── sac_agent.py │ ├── sac_helper.py │ ├── self_play_manager.py │ ├── spatial_softmax.py │ └── torch_ext.py ├── common │ ├── __init__.py │ ├── a2c_common.py │ ├── algo_observer.py │ ├── categorical.py │ ├── common_losses.py │ ├── datasets.py │ ├── diagnostics.py │ ├── divergence.py │ ├── env_configurations.py │ ├── experience.py │ ├── experiment.py │ ├── extensions │ │ ├── __init__.py │ │ └── distributions.py │ ├── interval_summary_writer.py │ ├── ivecenv.py │ ├── layers │ │ ├── __init__.py │ │ ├── action.py │ │ ├── recurrent.py │ │ └── value.py │ ├── object_factory.py │ ├── player.py │ ├── rollouts.py │ ├── schedulers.py │ ├── segment_tree.py │ ├── tr_helpers.py │ ├── transforms │ │ ├── __init__.py │ │ ├── soft_augmentation.py │ │ └── transforms.py │ ├── vecenv.py │ └── wrappers.py ├── configs │ ├── atari │ │ ├── ppo_breakout.yaml │ │ ├── ppo_breakout_cule.yaml │ │ ├── ppo_breakout_envpool.yaml │ │ ├── ppo_breakout_envpool_resnet.yaml │ │ ├── ppo_breakout_torch_impala.yaml │ │ ├── ppo_gopher.yaml │ │ ├── ppo_invaders_envpool.yaml │ │ ├── ppo_invaders_envpool_rnn.yaml │ │ ├── ppo_pacman_envpool.yaml │ │ ├── ppo_pacman_envpool_resnet.yaml │ │ ├── ppo_pacman_envpool_rnn.yaml │ │ ├── ppo_pacman_torch.yaml │ │ ├── ppo_pacman_torch_rnn.yaml │ │ ├── ppo_pong.yaml │ │ ├── ppo_pong_cule.yaml │ │ ├── ppo_pong_envpool.yaml │ │ ├── ppo_pong_envpool_resnet.yaml │ │ ├── ppo_space_invaders_resnet.yaml │ │ └── ppo_space_invaders_torch.yaml │ ├── brax │ │ ├── ppo_ant.yaml │ │ ├── ppo_ant_tcnn.yaml │ │ ├── ppo_grasp.yaml │ │ ├── ppo_halfcheetah.yaml │ │ ├── ppo_humanoid.yaml │ │ ├── ppo_ur5e.yaml │ │ ├── sac_ant.yaml │ │ └── sac_humanoid.yaml │ ├── carracing_ppo.yaml │ ├── dm_control │ │ ├── acrobot_swingup.yaml │ │ ├── ball_in_cup.yaml │ │ ├── cartpole.yaml │ │ ├── cheetah_walk.yaml │ │ ├── fish_swim.yaml │ │ ├── hopper_hop.yaml │ │ ├── hopper_stand.yaml │ │ ├── humanoid_run.yaml │ │ ├── humanoid_stand.yaml │ │ 
├── humanoid_walk.yaml │ │ ├── manipulator_bringball.yaml │ │ ├── pendulum_swingup.yaml │ │ ├── walker_run.yaml │ │ ├── walker_stand.yaml │ │ └── walker_walk.yaml │ ├── ma │ │ ├── ppo_connect4_self_play.yaml │ │ ├── ppo_connect4_self_play_resnet.yaml │ │ ├── ppo_slime_self_play.yaml │ │ └── ppo_slime_v0.yaml │ ├── maniskill │ │ ├── ppo_ant.yaml │ │ ├── ppo_pick_cube_rgbd_NOT_WORKING_YET.yaml │ │ └── ppo_pick_cube_state.yaml │ ├── minigrid │ │ ├── lava_rnn_img.yaml │ │ └── minigrid_rnn_img.yaml │ ├── mujoco │ │ ├── ant.yaml │ │ ├── ant_envpool.yaml │ │ ├── halfcheetah.yaml │ │ ├── halfcheetah_envpool.yaml │ │ ├── hopper.yaml │ │ ├── hopper_envpool.yaml │ │ ├── humanoid.yaml │ │ ├── humanoid_envpool.yaml │ │ ├── sac_ant_envpool.yaml │ │ ├── sac_halfcheetah_envpool.yaml │ │ ├── walker2d.yaml │ │ └── walker2d_envpool.yaml │ ├── openai │ │ ├── ppo_gym_ant.yaml │ │ ├── ppo_gym_hand.yaml │ │ └── ppo_gym_humanoid.yaml │ ├── ppo_cartpole.yaml │ ├── ppo_cartpole_masked_velocity_rnn.yaml │ ├── ppo_continuous.yaml │ ├── ppo_continuous_lstm.yaml │ ├── ppo_lunar.yaml │ ├── ppo_lunar_continiuos_torch.yaml │ ├── ppo_lunar_discrete.yaml │ ├── ppo_multiwalker.yaml │ ├── ppo_myo.yaml │ ├── ppo_pendulum.yaml │ ├── ppo_pendulum_torch.yaml │ ├── ppo_reacher.yaml │ ├── ppo_smac.yaml │ ├── ppo_walker.yaml │ ├── ppo_walker_hardcore.yaml │ ├── ppo_walker_rnn.yaml │ ├── ppo_walker_tcnn.yaml │ ├── procgen │ │ └── ppo_coinrun.yaml │ ├── smac │ │ ├── v1 │ │ │ ├── 10m_vs_11m_torch.yaml │ │ │ ├── 27m_vs_30m_cv.yaml │ │ │ ├── 27m_vs_30m_torch.yaml │ │ │ ├── 2m_vs_1z.yaml │ │ │ ├── 2m_vs_1z_torch.yaml │ │ │ ├── 2s_vs_1c.yaml │ │ │ ├── 3m_cnn_torch.yaml │ │ │ ├── 3m_torch.yaml │ │ │ ├── 3m_torch_cv.yaml │ │ │ ├── 3m_torch_cv_joint.yaml │ │ │ ├── 3m_torch_cv_rnn.yaml │ │ │ ├── 3m_torch_rnn.yaml │ │ │ ├── 3m_torch_sa.yaml │ │ │ ├── 3m_torch_sparse.yaml │ │ │ ├── 3s5z_vs_3s6z_torch.yaml │ │ │ ├── 3s5z_vs_3s6z_torch_cv.yaml │ │ │ ├── 3s_vs_4z.yaml │ │ │ ├── 3s_vs_5z.yaml │ │ │ ├── 3s_vs_5z_cv.yaml │ │ │ ├── 3s_vs_5z_cv_rnn.yaml │ │ │ ├── 3s_vs_5z_torch_lstm.yaml │ │ │ ├── 3s_vs_5z_torch_lstm2.yaml │ │ │ ├── 5m_vs_6m_rnn.yaml │ │ │ ├── 5m_vs_6m_rnn_cv.yaml │ │ │ ├── 5m_vs_6m_sa.yaml │ │ │ ├── 5m_vs_6m_torch.yaml │ │ │ ├── 6h_vs_8z_torch.yaml │ │ │ ├── 6h_vs_8z_torch_cv.yaml │ │ │ ├── 8m_torch.yaml │ │ │ ├── 8m_torch_cv.yaml │ │ │ ├── MMM2_torch.yaml │ │ │ ├── corridor_torch.yaml │ │ │ ├── corridor_torch_cv.yaml │ │ │ └── runs │ │ │ │ ├── 2c_vs_64zg.yaml │ │ │ │ ├── 2c_vs_64zg_neg.yaml │ │ │ │ ├── 2s3z.yaml │ │ │ │ ├── 2s3z_neg.yaml │ │ │ │ ├── 2s_vs_1c.yaml │ │ │ │ ├── 2s_vs_1c_neg.yaml │ │ │ │ ├── 3s5z.yaml │ │ │ │ ├── 3s5z_neg.yaml │ │ │ │ ├── 3s_vs_5z.yaml │ │ │ │ ├── 3s_vs_5z_neg.yaml │ │ │ │ ├── 3s_vs_5z_neg_joint.yaml │ │ │ │ ├── 6h_vs_8z.yaml │ │ │ │ ├── 6h_vs_8z_neg.yaml │ │ │ │ ├── 6h_vs_8z_rnn.yaml │ │ │ │ ├── MMM2.yaml │ │ │ │ ├── MMM2_conv1d.yaml │ │ │ │ ├── MMM2_neg.yaml │ │ │ │ ├── MMM2_rnn.yaml │ │ │ │ ├── bane_vs_bane.yaml │ │ │ │ ├── bane_vs_bane_neg.yaml │ │ │ │ ├── corridor_cv.yaml │ │ │ │ └── corridor_cv_neg.yaml │ │ └── v2 │ │ │ ├── env_configs │ │ │ ├── sc2_gen_protoss.yaml │ │ │ ├── sc2_gen_protoss_epo.yaml │ │ │ ├── sc2_gen_terran.yaml │ │ │ ├── sc2_gen_terran_epo.yaml │ │ │ ├── sc2_gen_zerg.yaml │ │ │ └── sc2_gen_zerg_epo.yaml │ │ │ ├── protos_5_v_5.yaml │ │ │ ├── terran_5_v_5.yaml │ │ │ └── zerg_5_v_5.yaml │ └── test │ │ ├── test_asymmetric_continuous.yaml │ │ ├── test_asymmetric_discrete.yaml │ │ ├── test_asymmetric_discrete_mhv.yaml │ │ ├── test_asymmetric_discrete_mhv_mops.yaml │ │ ├── 
test_discrete.yaml │ │ ├── test_discrete_multidiscrete_mhv.yaml │ │ ├── test_discrite_testnet_aux_loss.yaml │ │ ├── test_ppo_walker_truncated_time.yaml │ │ ├── test_rnn.yaml │ │ ├── test_rnn_multidiscrete.yaml │ │ └── test_rnn_multidiscrete_mhv.yaml ├── envs │ ├── __init__.py │ ├── brax.py │ ├── cule.py │ ├── diambra │ │ └── diambra.py │ ├── envpool.py │ ├── maniskill.py │ ├── multiwalker.py │ ├── slimevolley_selfplay.py │ ├── smac_env.py │ ├── smac_v2_env.py │ ├── test │ │ ├── __init__.py │ │ ├── example_env.py │ │ ├── rnn_env.py │ │ └── test_asymmetric_env.py │ └── test_network.py ├── interfaces │ ├── __init__.py │ └── base_algorithm.py ├── networks │ ├── __init__.py │ └── tcnn_mlp.py └── torch_runner.py ├── runner.py ├── setup.py └── tests ├── __init__.py └── simple_test.py /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | # This workflow will install Python dependencies, run tests and lint with a variety of Python versions 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | ''' 4 | name: Python package 5 | 6 | on: 7 | push: 8 | branches: [ master ] 9 | pull_request: 10 | branches: [ master ] 11 | 12 | jobs: 13 | build: 14 | 15 | runs-on: ubuntu-latest 16 | strategy: 17 | matrix: 18 | python-version: [3.6, 3.7] 19 | 20 | steps: 21 | - uses: actions/checkout@v2 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | - name: Install dependencies 27 | run: | 28 | python -m pip install --upgrade pip 29 | pip install pytest 30 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 31 | - name: Test with pytest 32 | run: | 33 | pytest 34 | ''' -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.pythonPath": "/home/trrrrr/anaconda3/envs/torch/bin/python", 3 | "python.linting.pylintEnabled": true, 4 | "python.linting.enabled": true, 5 | "python.linting.mypyEnabled": false, 6 | "python.linting.banditEnabled": false 7 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Denys88 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/ATARI_ENVPOOL.md: -------------------------------------------------------------------------------- 1 | # Atari with Envpool (https://envpool.readthedocs.io/en/latest/) 2 | 3 | ## How to run: 4 | * **Pong** 5 | 6 | ``` 7 | poetry install -E envpool 8 | poetry run pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html 9 | poetry run python runner.py --train --file rl_games/configs/atari/ppo_pong_envpool.yaml 10 | ``` 11 | 12 | ## Results: 13 | * **Pong-v5** 2 minutes of training time to achieve a 20+ score. 14 | ![Pong](pictures/atari_envpool/pong_envpool.png) 15 | * **Breakout-v3** 15 minutes of training time to achieve a 400+ score. 16 | ![Breakout](pictures/atari_envpool/breakout_envpool.png) 17 | 18 | 19 | -------------------------------------------------------------------------------- /docs/BRAX.md: -------------------------------------------------------------------------------- 1 | # Brax (https://github.com/google/brax) 2 | 3 | ## How to run: 4 | 5 | * **Setup** 6 | 7 | ```bash 8 | poetry install -E brax 9 | poetry run pip install --upgrade "jax[cuda]==0.3.13" -f https://storage.googleapis.com/jax-releases/jax_releases.html 10 | poetry run pip install torch==1.10.2+cu113 torchvision==0.11.3+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html 11 | ``` 12 | 13 | * **Ant** ```poetry run python runner.py --train --file rl_games/configs/brax/ppo_ant.yaml``` 14 | * **Humanoid** ```poetry run python runner.py --train --file rl_games/configs/brax/ppo_humanoid.yaml``` 15 | 16 | ## Visualization of the trained policy: 17 | * **brax_visualization.ipynb** 18 | 19 | ## Results: 20 | * **Ant** fps step: 1692066.6 fps total: 885603.1 21 | ![Ant](pictures/brax/brax_ant.jpg) 22 | * **Humanoid** fps step: 1244450.3 fps total: 661064.5 23 | ![Humanoid](pictures/brax/brax_humanoid.jpg) 24 | * **ur5e** fps step: 1116872.3 fps total: 627117.0 25 | ![ur5e](pictures/brax/brax_ur5e.jpg) 26 | 27 | 28 | ![Alt Text](pictures/brax/humanoid.gif) 29 | ![Alt Text](pictures/brax/ur5e.gif) -------------------------------------------------------------------------------- /docs/CONFIG_PARAMS.md: -------------------------------------------------------------------------------- 1 | # Yaml Config Description 2 | 3 | Coming. 4 | -------------------------------------------------------------------------------- /docs/DEEPMIND_ENVPOOL.md: -------------------------------------------------------------------------------- 1 | # Deepmind Control (https://github.com/deepmind/dm_control) 2 | 3 | * I could not find any PPO deepmind_control benchmarks. This is a first version only and will be updated later. 4 | 5 | ## How to run: 6 | * **Humanoid (Stand, Walk or Run)** 7 | ``` 8 | poetry install -E envpool 9 | poetry run pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html 10 | poetry run python runner.py --train --file rl_games/configs/dm_control/humanoid_walk.yaml 11 | ``` 12 | 13 | ## Results: 14 | 15 | * No tuning. I just ran it on a couple of envs. 16 | * I used 4000 epochs, which is ~32M steps, for almost all envs except HumanoidRun.
But a few million steps were enough for most of the envs. 17 | * Deepmind used pretty strange reward and training rules. A simple reward transformation, log(reward + 1), achieves the best scores faster. 18 | 19 | | Env | Rewards | 20 | | ------------- | ------------- | 21 | | Ball In Cup Catch | 938 | 22 | | Cartpole Balance | 988 | 23 | | Cheetah Run | 685 | 24 | | Fish Swim | 600 | 25 | | Hopper Stand | 557 | 26 | | Humanoid Stand | 653 | 27 | | Humanoid Walk | 621 | 28 | | Humanoid Run | 200 | 29 | | Pendulum Swingup | 706 | 30 | | Walker Stand | 907 | 31 | | Walker Walk | 917 | 32 | | Walker Run | 702 | 33 | -------------------------------------------------------------------------------- /docs/MUJOCO_ENVPOOL.md: -------------------------------------------------------------------------------- 1 | # Mujoco (https://github.com/deepmind/mujoco) 2 | 3 | ## How to run: 4 | * **Humanoid** 5 | ``` 6 | poetry install -E envpool 7 | poetry run pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html 8 | poetry run python runner.py --train --file rl_games/configs/mujoco/humanoid_envpool.yaml 9 | ``` 10 | 11 | ## Results: 12 | * **HalfCheetah-v4** 13 | ![HalfCheetah](pictures/mujoco/mujoco_halfcheetah_envpool.png) 14 | * **Hopper-v4** 15 | ![Hopper](pictures/mujoco/mujoco_hopper_envpool.png) 16 | * **Walker2d-v4** 17 | ![Walker2d](pictures/mujoco/mujoco_walker2d_envpool.png) 18 | * **Ant-v4** 19 | ![Ant](pictures/mujoco/mujoco_ant_envpool.png) 20 | * **Humanoid-v4** 21 | ![Humanoid](pictures/mujoco/mujoco_humanoid_envpool.png) 22 | -------------------------------------------------------------------------------- /docs/OTHER.md: -------------------------------------------------------------------------------- 1 | # Random Games Results 2 | 3 | ## MiniGrid-MemoryS13Random-v0 (https://github.com/maximecb/gym-minigrid), used to test the LSTM implementation 4 | ```python runner.py --train --file rl_games/configs/minigrid/minigrid_rnn.yaml``` 5 | 6 | https://user-images.githubusercontent.com/1936835/147401083-ccd4020a-a3d6-4bcd-b283-50f5ab3b20b4.mp4 7 | 8 | ```python runner.py --train --file rl_games/configs/minigrid/lava_rnn.yaml``` 9 | 10 | https://user-images.githubusercontent.com/1936835/147620589-fa8b275d-e991-4bee-b225-9fb0f28b0f97.mp4 11 | 12 | 13 | 14 | ## BipedalWalkerHardcore-v3 15 | ```python runner.py --train --file rl_games/configs/ppo_walker_hardcore.yaml``` 16 | 17 | https://user-images.githubusercontent.com/1936835/147401481-7121c5bf-3ddd-4151-b200-270f29bfad32.mp4 18 | 19 | -------------------------------------------------------------------------------- /docs/pictures/atari_envpool/breakout.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/atari_envpool/breakout.jpg -------------------------------------------------------------------------------- /docs/pictures/atari_envpool/breakout_envpool.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/atari_envpool/breakout_envpool.png -------------------------------------------------------------------------------- /docs/pictures/atari_envpool/pong.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/atari_envpool/pong.jpg -------------------------------------------------------------------------------- /docs/pictures/atari_envpool/pong_envpool.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/atari_envpool/pong_envpool.png -------------------------------------------------------------------------------- /docs/pictures/brax/brax_ant.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/brax/brax_ant.jpg -------------------------------------------------------------------------------- /docs/pictures/brax/brax_humanoid.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/brax/brax_humanoid.jpg -------------------------------------------------------------------------------- /docs/pictures/brax/brax_ur5e.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/brax/brax_ur5e.jpg -------------------------------------------------------------------------------- /docs/pictures/brax/humanoid.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/brax/humanoid.gif -------------------------------------------------------------------------------- /docs/pictures/brax/ur5e.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/brax/ur5e.gif -------------------------------------------------------------------------------- /docs/pictures/dqn_vs_dddqn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/dqn_vs_dddqn.png -------------------------------------------------------------------------------- /docs/pictures/mario_random_stages.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/mario_random_stages.png -------------------------------------------------------------------------------- /docs/pictures/mujoco/half_cheetah.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/mujoco/half_cheetah.jpg -------------------------------------------------------------------------------- /docs/pictures/mujoco/hopper.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/mujoco/hopper.jpg -------------------------------------------------------------------------------- /docs/pictures/mujoco/humanoid.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/mujoco/humanoid.jpg -------------------------------------------------------------------------------- /docs/pictures/mujoco/mujoco_ant_envpool.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/mujoco/mujoco_ant_envpool.png -------------------------------------------------------------------------------- /docs/pictures/mujoco/mujoco_halfcheetah_envpool.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/mujoco/mujoco_halfcheetah_envpool.png -------------------------------------------------------------------------------- /docs/pictures/mujoco/mujoco_hopper_envpool.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/mujoco/mujoco_hopper_envpool.png -------------------------------------------------------------------------------- /docs/pictures/mujoco/mujoco_humanoid_envpool.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/mujoco/mujoco_humanoid_envpool.png -------------------------------------------------------------------------------- /docs/pictures/mujoco/mujoco_walker2d_envpool.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/mujoco/mujoco_walker2d_envpool.png -------------------------------------------------------------------------------- /docs/pictures/mujoco/walker.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/mujoco/walker.jpg -------------------------------------------------------------------------------- /docs/pictures/pong_dqn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/pong_dqn.png -------------------------------------------------------------------------------- /docs/pictures/rainbow_dqn_breakout.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/rainbow_dqn_breakout.jpg -------------------------------------------------------------------------------- /docs/pictures/smac/2m_vs_1z.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/smac/2m_vs_1z.png -------------------------------------------------------------------------------- /docs/pictures/smac/3s5z_vs_3s6z.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/smac/3s5z_vs_3s6z.png -------------------------------------------------------------------------------- /docs/pictures/smac/3s_vs_5z.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/smac/3s_vs_5z.png -------------------------------------------------------------------------------- /docs/pictures/smac/5m_vs_6m.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/smac/5m_vs_6m.png -------------------------------------------------------------------------------- /docs/pictures/smac/MMM2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/smac/MMM2.png -------------------------------------------------------------------------------- /docs/pictures/smac/corridor.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/smac/corridor.png -------------------------------------------------------------------------------- /docs/pictures/smac/mmm2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/docs/pictures/smac/mmm2.gif -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "rl_games" 3 | version = "1.6.1" 4 | description = "" 5 | readme = "README.md" 6 | authors = [ 7 | "Denys Makoviichuk ", 8 | "Viktor Makoviichuk " 9 | ] 10 | 11 | [tool.poetry.dependencies] 12 | python = ">=3.7.1,<3.11" 13 | gym = {version = "^0.23.0", extras = ["classic_control"]} 14 | tensorboard = "^2.8.0" 15 | tensorboardX = "^2.5" 16 | PyYAML = "^6.0" 17 | psutil = "^5.9.0" 18 | setproctitle = "^1.2.2" 19 | opencv-python = "^4.5.5" 20 | wandb = "^0.12.11" 21 | 22 | ale-py = {version = "^0.7", optional = true} 23 | AutoROM = {version = "^0.4.2", optional = true, extras = ["accept-rom-license"]} 24 | brax = {version = "^0.0.13", optional = true} 25 | jax = {version = "^0.3.13", optional = true} 26 | mujoco-py = {version = "^2.1.2", optional = true} 27 | envpool = {version = "^0.6.1", optional = true} 28 | 29 | [build-system] 30 | requires = ["poetry-core>=1.0.0"] 31 | build-backend = "poetry.core.masonry.api" 32 | 33 | [tool.poetry.extras] 34 | atari = ["ale-py", "AutoROM"] 35 | brax = ["brax", "jax"] 36 | mujoco = ["mujoco-py"] 37 | envpool = ["envpool"] 38 | -------------------------------------------------------------------------------- /rl_games/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/rl_games/__init__.py -------------------------------------------------------------------------------- /rl_games/algos_torch/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/rl_games/algos_torch/__init__.py -------------------------------------------------------------------------------- /rl_games/algos_torch/d2rl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class D2RLNet(torch.nn.Module): 4 | def __init__(self, input_size, 5 | units, 6 | activations, 7 | norm_func_name = None): 8 | torch.nn.Module.__init__(self) 9 | self.activations = torch.nn.ModuleList(activations) 10 | self.linears = torch.nn.ModuleList([]) 11 | self.norm_layers = torch.nn.ModuleList([]) 12 | self.num_layers = len(units) 13 | last_size = input_size 14 | for i in range(self.num_layers): 15 | self.linears.append(torch.nn.Linear(last_size, units[i])) 16 | last_size = units[i] + input_size 17 | if norm_func_name == 'layer_norm': 18 | self.norm_layers.append(torch.nn.LayerNorm(units[i])) 19 | elif norm_func_name == 'batch_norm': 20 | self.norm_layers.append(torch.nn.BatchNorm1d(units[i])) 21 | else: 22 | self.norm_layers.append(torch.nn.Identity()) 23 | 24 | def forward(self, input): 25 | x = self.linears[0](input) 26 | x = self.activations[0](x) 27 | x = self.norm_layers[0](x) 28 | for i in range(1,self.num_layers): 29 | x = torch.cat([x,input], dim=1) 30 | x = self.linears[i](x) 31 | x = self.norm_layers[i](x) 32 | x = self.activations[i](x) 33 | return x -------------------------------------------------------------------------------- /rl_games/algos_torch/sac_helper.py: -------------------------------------------------------------------------------- 1 | from torch import distributions as pyd 2 | import math 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class TanhTransform(pyd.transforms.Transform): 8 | domain = pyd.constraints.real 9 | codomain = pyd.constraints.interval(-1.0, 1.0) 10 | bijective = True 11 | sign = +1 12 | 13 | def __init__(self, cache_size=1): 14 | super().__init__(cache_size=cache_size) 15 | 16 | @staticmethod 17 | def atanh(x): 18 | return 0.5 * (x.log1p() - (-x).log1p()) 19 | 20 | def __eq__(self, other): 21 | return isinstance(other, TanhTransform) 22 | 23 | def _call(self, x): 24 | return x.tanh() 25 | 26 | def _inverse(self, y): 27 | # We do not clamp to the boundary here as it may degrade the performance of certain algorithms. 28 | # one should use `cache_size=1` instead 29 | return self.atanh(y) 30 | 31 | def log_abs_det_jacobian(self, x, y): 32 | # We use a formula that is more numerically stable, see details in the following link 33 | # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7 34 | return 2. * (math.log(2.) - x - F.softplus(-2. 
* x)) 35 | 36 | 37 | class SquashedNormal(pyd.transformed_distribution.TransformedDistribution): 38 | def __init__(self, loc, scale): 39 | self.loc = loc 40 | self.scale = scale 41 | 42 | self.base_dist = pyd.Normal(loc, scale) 43 | transforms = [TanhTransform()] 44 | super().__init__(self.base_dist, transforms) 45 | 46 | @property 47 | def mean(self): 48 | mu = self.loc 49 | for tr in self.transforms: 50 | mu = tr(mu) 51 | return mu 52 | 53 | def entropy(self): 54 | return self.base_dist.entropy() 55 | -------------------------------------------------------------------------------- /rl_games/algos_torch/self_play_manager.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class SelfPlayManager: 4 | def __init__(self, config, writter): 5 | self.config = config 6 | self.writter = writter 7 | self.update_score = self.config['update_score'] 8 | self.games_to_check = self.config['games_to_check'] 9 | self.check_scores = self.config.get('check_scores', False) 10 | self.env_update_num = self.config.get('env_update_num', 1) 11 | self.env_indexes = np.arange(start=0, stop=self.env_update_num) 12 | self.updates_num = 0 13 | 14 | def update(self, algo): 15 | self.updates_num += 1 16 | if self.check_scores: 17 | data = algo.game_scores 18 | else: 19 | data = algo.game_rewards 20 | 21 | if len(data) >= self.games_to_check: 22 | mean_scores = data.get_mean() 23 | mean_rewards = algo.game_rewards.get_mean() 24 | if mean_scores > self.update_score: 25 | print('Mean scores: ', mean_scores, ' mean rewards: ', mean_rewards, ' updating weights') 26 | 27 | algo.clear_stats() 28 | self.writter.add_scalar('selfplay/iters_update_weigths', self.updates_num, algo.frame) 29 | algo.vec_env.set_weights(self.env_indexes, algo.get_weights()) 30 | self.env_indexes = (self.env_indexes + 1) % (algo.num_actors) 31 | self.updates_num = 0 32 | -------------------------------------------------------------------------------- /rl_games/common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/rl_games/common/__init__.py -------------------------------------------------------------------------------- /rl_games/common/divergence.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributions as dist 3 | 4 | 5 | 6 | def d_kl_discrete(p, q): 7 | # p = target, q = online 8 | # categorical distribution parametrized by logits 9 | logits_diff = p - q 10 | p_probs = torch.exp(p) 11 | d_kl = (p_probs * logits_diff).sum(-1) 12 | return d_kl 13 | 14 | 15 | def d_kl_discrete_list(p, q): 16 | d_kl = 0 17 | for pi, qi in zip(p,q): 18 | d_kl += d_kl_discrete(pi, qi) 19 | return d_kl 20 | 21 | def d_kl_normal(p, q): 22 | # p = target, q = online 23 | p_mean, p_sigma = p 24 | q_mean, q_sigma = q 25 | mean_diff = ((q_mean - p_mean) / q_sigma).pow(2) 26 | var_ratio = (p_sigma / q_sigma).pow(2) 27 | 28 | d_kl = 0.5 * (var_ratio + mean_diff - 1 - var_ratio.log()) 29 | return d_kl.sum(-1) -------------------------------------------------------------------------------- /rl_games/common/extensions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/rl_games/common/extensions/__init__.py 
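The KL helpers in rl_games/common/divergence.py above follow the convention p = target distribution, q = online distribution, and return a per-sample KL summed over the action dimensions (d_kl_discrete additionally expects normalised log-probabilities for both arguments, since it computes sum(exp(p) * (p - q))). Below is a minimal sanity-check sketch, not part of the repository, assuming rl_games and PyTorch are importable; it compares d_kl_normal with PyTorch's own KL for diagonal Gaussians.

```python
# Hedged sketch: cross-check d_kl_normal against torch.distributions.kl_divergence.
# Assumes rl_games and PyTorch are installed; tensor shapes here are illustrative only.
import torch
import torch.distributions as dist

from rl_games.common.divergence import d_kl_normal

batch_size, num_actions = 4, 6
p_mean, p_sigma = torch.randn(batch_size, num_actions), torch.rand(batch_size, num_actions) + 0.1  # target
q_mean, q_sigma = torch.randn(batch_size, num_actions), torch.rand(batch_size, num_actions) + 0.1  # online

kl = d_kl_normal((p_mean, p_sigma), (q_mean, q_sigma))           # per-sample KL(p || q), summed over action dims
ref = dist.kl_divergence(dist.Normal(p_mean, p_sigma),
                         dist.Normal(q_mean, q_sigma)).sum(-1)   # same convention, computed by PyTorch

print(torch.allclose(kl, ref, atol=1e-5))  # expected: True
```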
-------------------------------------------------------------------------------- /rl_games/common/ivecenv.py: -------------------------------------------------------------------------------- 1 | class IVecEnv: 2 | def step(self, actions): 3 | raise NotImplementedError 4 | 5 | def reset(self): 6 | raise NotImplementedError 7 | 8 | def has_action_masks(self): 9 | return False 10 | 11 | def get_number_of_agents(self): 12 | return 1 13 | 14 | def get_env_info(self): 15 | pass 16 | 17 | def seed(self, seed): 18 | pass 19 | 20 | def set_train_info(self, env_frames, *args, **kwargs): 21 | """ 22 | Send the information in the direction algo->environment. 23 | Most common use case: tell the environment how far along we are in the training process. This is useful 24 | for implementing curriculums and things such as that. 25 | """ 26 | pass 27 | 28 | def get_env_state(self): 29 | """ 30 | Return serializable environment state to be saved to checkpoint. 31 | Can be used for stateful training sessions, i.e. with adaptive curriculums. 32 | """ 33 | return None 34 | 35 | def set_env_state(self, env_state): 36 | pass 37 | -------------------------------------------------------------------------------- /rl_games/common/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/rl_games/common/layers/__init__.py -------------------------------------------------------------------------------- /rl_games/common/layers/action.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from rl_games.common import common_losses 4 | from rl_games.algos_torch.layers import symexp, symlog 5 | from rl_games.common.extensions.distributions import TwoHotDist 6 | 7 | 8 | class OneHotEncodedAction(nn.Module): 9 | def __init__(self, in_size, num_actions): 10 | nn.Module.__init__(self) 11 | self.value_linear = nn.Linear(in_size, out_size) 12 | 13 | def loss(self, value_preds_batch, values, curr_e_clip, return_batch, clip_value): 14 | value_preds_batch = symlog(value_preds_batch) 15 | values = symlog(values) 16 | return_batch = symlog(return_batch) 17 | return common_losses.default_critic_loss(value_preds_batch, values, curr_e_clip, return_batch, clip_value) 18 | 19 | def forward(self, input): 20 | out = self.value_linear(input) 21 | out = symexp(out) 22 | return out 23 | 24 | 25 | class TwoHotEncodedAction(nn.Module): 26 | def __init__(self, in_size, num_actions, backets=32, min_space=-1.0, max_space=1.0): 27 | nn.Module.__init__(self) 28 | assert(out_size==1) 29 | self.value_linear = nn.Linear(in_size, backets * num_actions) 30 | torch.nn.init.xavier_uniform_(self.value_linear.weight, gain=0.05) 31 | 32 | def loss(self, **kwargs): 33 | targets = kwargs.get('return_batch') 34 | neglog_prob = -self.distr.log_prob(targets) 35 | return neglog_prob 36 | 37 | def forward(self, input): 38 | out = self.value_linear(input) 39 | self.distr = TwoHotDist(logits=out, min_space=-1.0, max_space=1.0) 40 | out = self.distr.mode() 41 | return out 42 | -------------------------------------------------------------------------------- /rl_games/common/layers/value.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from rl_games.common import common_losses 4 | from rl_games.algos_torch.layers import symexp, symlog 5 | from rl_games.common.extensions.distributions 
import TwoHotDist 6 | 7 | 8 | class DefaultValue(nn.Module): 9 | def __init__(self, in_size, out_size): 10 | nn.Module.__init__(self) 11 | self.value_linear = nn.Linear(in_size, out_size) 12 | #torch.nn.init.xavier_uniform_(self.value_linear.weight, gain=0.05) 13 | 14 | def loss(self, value_preds_batch, values, curr_e_clip, return_batch, clip_value): 15 | #value_preds_batch = symlog(value_preds_batch) 16 | #values = symlog(values) 17 | #return_batch = symlog(return_batch) 18 | return common_losses.default_critic_loss(value_preds_batch, values, curr_e_clip, return_batch, clip_value) 19 | 20 | def forward(self, input): 21 | out = self.value_linear(input) 22 | #out = symexp(out) 23 | return out 24 | 25 | 26 | class TwoHotEncodedValue(nn.Module): 27 | def __init__(self, in_size, out_size): 28 | nn.Module.__init__(self) 29 | assert(out_size==1) 30 | self.value_linear = nn.Linear(in_size, 255) 31 | torch.nn.init.xavier_uniform_(self.value_linear.weight, gain=0.05) 32 | 33 | def loss(self, **kwargs): 34 | targets = kwargs.get('return_batch') 35 | targets = symlog(targets) 36 | 37 | neglog_prob = -self.distr.log_prob(targets) 38 | return neglog_prob 39 | 40 | def forward(self, input): 41 | out = self.value_linear(input) 42 | self.distr = TwoHotDist(logits=out) 43 | out = self.distr.mode() 44 | out = symexp(out) 45 | return out 46 | -------------------------------------------------------------------------------- /rl_games/common/object_factory.py: -------------------------------------------------------------------------------- 1 | class ObjectFactory: 2 | """General-purpose class to instantiate some other base class from rl_games. Usual use case it to instantiate algos, players etc. 3 | 4 | The ObjectFactory class is used to dynamically create any other object using a builder function (typically a lambda function). 5 | 6 | """ 7 | 8 | def __init__(self): 9 | """Initialise a dictionary of builders with keys as `str` and values as functions. 10 | 11 | """ 12 | self._builders = {} 13 | 14 | def register_builder(self, name, builder): 15 | """Register a passed builder by adding to the builders dict. 16 | 17 | Initialises runners and players for all algorithms available in the library using `rl_games.common.object_factory.ObjectFactory` 18 | 19 | Args: 20 | name (:obj:`str`): Key of the added builder. 21 | builder (:obj `func`): Function to return the requested object 22 | 23 | """ 24 | self._builders[name] = builder 25 | 26 | def set_builders(self, builders): 27 | self._builders = builders 28 | 29 | def create(self, name, **kwargs): 30 | """Create the requested object by calling a registered builder function. 31 | 32 | Args: 33 | name (:obj:`str`): Key of the requested builder. 
34 | **kwargs: Arbitrary kwargs needed for the builder function 35 | 36 | """ 37 | builder = self._builders.get(name) 38 | if not builder: 39 | raise ValueError(name) 40 | return builder(**kwargs) -------------------------------------------------------------------------------- /rl_games/common/rollouts.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | ''' 4 | TODO: move play_steps here 5 | ''' 6 | class Rollout: 7 | def __init__(self, gamma): 8 | self.gamma = gamma 9 | 10 | def play_steps(self, env, max_steps_count = 1): 11 | pass 12 | 13 | 14 | class DiscretePpoRollout(Rollout): 15 | def __init__(self, gamma, lam): 16 | super(Rollout, self).__init__(gamma) 17 | self.lam = lam 18 | 19 | def play_steps(self, env, max_steps_count = 1): 20 | pass -------------------------------------------------------------------------------- /rl_games/common/transforms/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/rl_games/common/transforms/__init__.py -------------------------------------------------------------------------------- /rl_games/common/transforms/soft_augmentation.py: -------------------------------------------------------------------------------- 1 | from rl_games.common.transforms import transforms 2 | import torch 3 | 4 | class SoftAugmentation(): 5 | def __init__(self, **kwargs): 6 | self.transform_config = kwargs.pop('transform') 7 | self.aug_coef = kwargs.pop('aug_coef', 0.001) 8 | print('aug coef:', self.aug_coef) 9 | self.name = self.transform_config['name'] 10 | 11 | #TODO: remove hardcode 12 | self.transform = transforms.ImageDatasetTransform(**self.transform_config) 13 | 14 | def get_coef(self): 15 | return self.aug_coef 16 | 17 | def get_loss(self, p_dict, model, input_dict, loss_type = 'both'): 18 | ''' 19 | loss_type: 'critic', 'policy', 'both' 20 | ''' 21 | if self.transform: 22 | input_dict = self.transform(input_dict) 23 | loss = 0 24 | q_dict = model(input_dict) 25 | if loss_type == 'policy' or loss_type == 'both': 26 | p_dict['logits'] = p_dict['logits'].detach() 27 | loss = model.kl(p_dict, q_dict) 28 | if loss_type == 'critic' or loss_type == 'both': 29 | p_value = p_dict['value'].detach() 30 | q_value = q_dict['value'] 31 | loss = loss + (0.5 * (p_value - q_value)**2).sum(dim=-1) 32 | 33 | return loss -------------------------------------------------------------------------------- /rl_games/common/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class DatasetTransform(nn.Module): 5 | def __init__(self): 6 | super().__init__() 7 | 8 | def forward(self, dataset): 9 | return dataset 10 | 11 | 12 | class ImageDatasetTransform(DatasetTransform): 13 | def __init__(self, **kwargs): 14 | super().__init__() 15 | import kornia 16 | self.transform = torch.nn.Sequential( 17 | nn.ReplicationPad2d(4), 18 | kornia.augmentation.RandomCrop((84,84)) 19 | #kornia.augmentation.RandomErasing(p=0.2), 20 | #kornia.augmentation.RandomAffine(degrees=0, translate=(2.0/84,2.0/84), p=1), 21 | #kornia.augmentation.RandomCrop((84,84)) 22 | ) 23 | 24 | def forward(self, dataset): 25 | dataset['obs'] = self.transform(dataset['obs']) 26 | return dataset -------------------------------------------------------------------------------- /rl_games/configs/atari/ppo_breakout_envpool_resnet.yaml: 
-------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: resnet_actor_critic 10 | require_rewards: True 11 | require_last_actions: True 12 | separate: False 13 | value_shape: 1 14 | space: 15 | discrete: 16 | 17 | cnn: 18 | permute_input: False 19 | conv_depths: [16, 32, 32] 20 | activation: relu 21 | initializer: 22 | name: default 23 | regularizer: 24 | name: 'None' 25 | 26 | mlp: 27 | units: [512] 28 | activation: relu 29 | regularizer: 30 | name: 'None' 31 | initializer: 32 | name: default 33 | rnn: 34 | name: lstm 35 | units: 256 36 | layers: 1 37 | config: 38 | reward_shaper: 39 | min_val: -1 40 | max_val: 1 41 | 42 | normalize_advantage: True 43 | gamma: 0.995 44 | tau: 0.95 45 | learning_rate: 3e-4 46 | name: breakout_resnet 47 | score_to_win: 100000 48 | grad_norm: 1.5 49 | entropy_coef: 0.01 50 | truncate_grads: True 51 | env_name: envpool #'openai_gym' #'PongNoFrameskip-v4' # 52 | e_clip: 0.2 53 | clip_value: True 54 | num_actors: 64 55 | horizon_length: 128 56 | minibatch_size: 2048 57 | mini_epochs: 2 58 | critic_coef: 1 59 | lr_schedule: None 60 | kl_threshold: 0.01 61 | normalize_input: False 62 | use_diagnostics: True 63 | seq_length: 32 64 | max_epochs: 200000 65 | 66 | env_config: 67 | env_name: Breakout-v5 68 | episodic_life: True 69 | has_lives: True 70 | use_dict_obs_space: True 71 | 72 | player: 73 | render: False 74 | games_num: 20 75 | n_game_life: 5 76 | deterministic: True 77 | 78 | -------------------------------------------------------------------------------- /rl_games/configs/atari/ppo_breakout_torch_impala.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: resnet_actor_critic 10 | require_rewards: True 11 | require_last_actions: True 12 | separate: False 13 | value_shape: 1 14 | space: 15 | discrete: 16 | 17 | cnn: 18 | permute_input: True 19 | conv_depths: [16, 32, 32] 20 | activation: relu 21 | initializer: 22 | name: default 23 | regularizer: 24 | name: 'None' 25 | 26 | mlp: 27 | units: [512] 28 | activation: relu 29 | regularizer: 30 | name: 'None' 31 | initializer: 32 | name: default 33 | rnn: 34 | name: lstm 35 | units: 256 36 | layers: 1 37 | 38 | config: 39 | env_name: atari_gym 40 | reward_shaper: 41 | min_val: -1 42 | max_val: 1 43 | 44 | normalize_advantage: True 45 | gamma: 0.99 46 | tau: 0.95 47 | learning_rate: 5e-4 48 | name: breakout_impala_lstm 49 | score_to_win: 900 50 | grad_norm: 0.5 51 | entropy_coef: 0.01 52 | truncate_grads: True 53 | 54 | e_clip: 0.2 55 | clip_value: True 56 | num_actors: 16 57 | horizon_length: 256 58 | minibatch_size: 512 59 | mini_epochs: 3 60 | critic_coef: 1 61 | lr_schedule: None 62 | kl_threshold: 0.01 63 | normalize_input: False 64 | seq_length: 8 65 | 66 | # max_epochs: 5000 67 | env_config: 68 | skip: 4 69 | name: 'BreakoutNoFrameskip-v4' 70 | episode_life: True 71 | wrap_impala: True 72 | player: 73 | render: False 74 | games_num: 100 75 | n_game_life: 5 76 | deterministic: False 77 | -------------------------------------------------------------------------------- /rl_games/configs/atari/ppo_pong_envpool_resnet.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: resnet_actor_critic 10 | 
require_rewards: True 11 | require_last_actions: True 12 | separate: False 13 | value_shape: 1 14 | space: 15 | discrete: 16 | 17 | cnn: 18 | permute_input: False 19 | conv_depths: [16, 32, 32] 20 | activation: relu 21 | initializer: 22 | name: default 23 | regularizer: 24 | name: 'None' 25 | 26 | mlp: 27 | units: [512] 28 | activation: relu 29 | regularizer: 30 | name: 'None' 31 | initializer: 32 | name: default 33 | rnn: 34 | name: lstm 35 | units: 256 36 | layers: 1 37 | config: 38 | reward_shaper: 39 | min_val: -1 40 | max_val: 1 41 | 42 | normalize_advantage: True 43 | gamma: 0.995 44 | tau: 0.95 45 | learning_rate: 3e-4 46 | name: pong_resnet 47 | score_to_win: 100000 48 | grad_norm: 1.5 49 | entropy_coef: 0.01 50 | truncate_grads: True 51 | env_name: envpool #'openai_gym' #'PongNoFrameskip-v4' # 52 | e_clip: 0.2 53 | clip_value: True 54 | num_actors: 64 55 | horizon_length: 128 56 | minibatch_size: 2048 57 | mini_epochs: 2 58 | critic_coef: 1 59 | lr_schedule: None 60 | kl_threshold: 0.01 61 | normalize_input: False 62 | use_diagnostics: True 63 | seq_length: 32 64 | max_epochs: 200000 65 | 66 | env_config: 67 | env_name: Pong-v5 68 | has_lives: False 69 | use_dict_obs_space: True 70 | 71 | player: 72 | render: True 73 | games_num: 10 74 | n_game_life: 1 75 | deterministic: True 76 | 77 | -------------------------------------------------------------------------------- /rl_games/configs/atari/ppo_space_invaders_resnet.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: resnet_actor_critic 10 | separate: False 11 | value_shape: 1 12 | space: 13 | discrete: 14 | 15 | cnn: 16 | conv_depths: [16, 32, 32] 17 | activation: relu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: 'None' 22 | 23 | mlp: 24 | units: [512] 25 | activation: relu 26 | regularizer: 27 | name: 'None' 28 | initializer: 29 | name: default 30 | rnn: 31 | name: lstm 32 | units: 256 33 | layers: 1 34 | config: 35 | reward_shaper: 36 | min_val: -1 37 | max_val: 1 38 | 39 | normalize_advantage: True 40 | gamma: 0.995 41 | tau: 0.95 42 | learning_rate: 3e-4 43 | name: invaders_resnet 44 | score_to_win: 100000 45 | grad_norm: 1.5 46 | entropy_coef: 0.001 47 | truncate_grads: True 48 | env_name: atari_gym #'openai_gym' #'PongNoFrameskip-v4' # 49 | e_clip: 0.2 50 | clip_value: True 51 | num_actors: 16 52 | horizon_length: 256 53 | minibatch_size: 2048 54 | mini_epochs: 4 55 | critic_coef: 1 56 | lr_schedule: None 57 | kl_threshold: 0.01 58 | normalize_input: False 59 | seq_length: 4 60 | max_epochs: 200000 61 | 62 | env_config: 63 | skip: 3 64 | name: 'SpaceInvadersNoFrameskip-v4' 65 | episode_life: False 66 | 67 | player: 68 | render: True 69 | games_num: 10 70 | n_game_life: 1 71 | deterministic: True 72 | 73 | -------------------------------------------------------------------------------- /rl_games/configs/brax/ppo_ant.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 7 3 | 4 | #devices: [0, 0] 5 | 6 | algo: 7 | name: a2c_continuous 8 | 9 | model: 10 | name: continuous_a2c_logstd 11 | 12 | network: 13 | name: actor_critic 14 | separate: False 15 | space: 16 | continuous: 17 | mu_activation: None 18 | sigma_activation: None 19 | 20 | mu_init: 21 | name: default 22 | sigma_init: 23 | name: const_initializer 24 | val: 0 25 | fixed_sigma: True 26 | mlp: 27 | units: [256, 128, 64] 28 | activation: elu 29 | d2rl: 
False 30 | 31 | initializer: 32 | name: default 33 | regularizer: 34 | name: None 35 | 36 | config: 37 | name: Ant_brax 38 | full_experiment_name: Ant_brax 39 | env_name: brax 40 | multi_gpu: False 41 | mixed_precision: True 42 | normalize_input: True 43 | normalize_value: True 44 | normalize_advantage: True 45 | use_smooth_clamp: True 46 | reward_shaper: 47 | scale_value: 1.0 48 | gamma: 0.99 49 | tau: 0.95 50 | learning_rate: 3e-4 51 | lr_schedule: adaptive 52 | kl_threshold: 0.008 53 | score_to_win: 20000 54 | max_epochs: 1000 55 | save_best_after: 100 56 | save_frequency: 50 57 | grad_norm: 1.0 58 | entropy_coef: 0.0 59 | truncate_grads: True 60 | e_clip: 0.2 61 | horizon_length: 8 62 | num_actors: 4096 63 | minibatch_size: 32768 64 | mini_epochs: 5 65 | critic_coef: 2 66 | clip_value: False 67 | bounds_loss_coef: 0.0001 68 | 69 | env_config: 70 | env_name: ant 71 | -------------------------------------------------------------------------------- /rl_games/configs/brax/ppo_ant_tcnn.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 7 3 | 4 | #devices: [0, 0] 5 | 6 | algo: 7 | name: a2c_continuous 8 | 9 | model: 10 | name: continuous_a2c_logstd 11 | 12 | network: 13 | name: tcnnnet 14 | 15 | encoding: 16 | otype: "Identity" 17 | 18 | network: 19 | type: "FullyFusedMLP" 20 | activation: "ReLU" 21 | output_activation: "None" 22 | n_neurons: 128 23 | n_hidden_layers: 4 24 | 25 | config: 26 | name: Ant_brax_tcnn 27 | env_name: brax 28 | multi_gpu: False 29 | mixed_precision: True 30 | normalize_input: True 31 | normalize_value: True 32 | reward_shaper: 33 | scale_value: 1.0 34 | normalize_advantage: True 35 | gamma: 0.99 36 | tau: 0.95 37 | learning_rate: 3e-4 38 | lr_schedule: adaptive 39 | kl_threshold: 0.008 40 | score_to_win: 20000 41 | max_epochs: 1000 42 | save_best_after: 100 43 | save_frequency: 50 44 | grad_norm: 1.0 45 | entropy_coef: 0.0 46 | truncate_grads: True 47 | e_clip: 0.2 48 | horizon_length: 8 49 | num_actors: 4096 50 | minibatch_size: 32768 51 | mini_epochs: 5 52 | critic_coef: 2 53 | clip_value: False 54 | bounds_loss_coef: 0.0001 55 | 56 | env_config: 57 | env_name: 'ant' 58 | -------------------------------------------------------------------------------- /rl_games/configs/brax/ppo_grasp.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 7 3 | 4 | #devices: [0, 0] 5 | 6 | algo: 7 | name: a2c_continuous 8 | 9 | model: 10 | name: continuous_a2c_logstd 11 | 12 | network: 13 | name: actor_critic 14 | separate: False 15 | space: 16 | continuous: 17 | mu_activation: None 18 | sigma_activation: None 19 | 20 | mu_init: 21 | name: default 22 | sigma_init: 23 | name: const_initializer 24 | val: 0 25 | fixed_sigma: True 26 | mlp: 27 | units: [512, 256, 128] 28 | activation: elu 29 | d2rl: False 30 | 31 | initializer: 32 | name: default 33 | regularizer: 34 | name: None 35 | 36 | config: 37 | name: 'Grasp_brax' 38 | env_name: brax 39 | multi_gpu: False 40 | mixed_precision: True 41 | normalize_input: True 42 | normalize_value: True 43 | reward_shaper: 44 | scale_value: 1.0 45 | normalize_advantage: True 46 | gamma: 0.99 47 | tau: 0.95 48 | learning_rate: 3e-4 49 | lr_schedule: adaptive 50 | kl_threshold: 0.008 51 | score_to_win: 20000 52 | max_epochs: 2000 53 | save_best_after: 100 54 | save_frequency: 50 55 | grad_norm: 1.0 56 | entropy_coef: 0.00 57 | truncate_grads: True 58 | e_clip: 0.2 59 | horizon_length: 16 60 | num_actors: 8192 61 | minibatch_size: 32768 62 
| mini_epochs: 5 63 | critic_coef: 2 64 | clip_value: False 65 | bounds_loss_coef: 0.0004 66 | 67 | env_config: 68 | env_name: 'grasp' 69 | -------------------------------------------------------------------------------- /rl_games/configs/brax/ppo_halfcheetah.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 7 3 | 4 | #devices: [0, 0] 5 | 6 | algo: 7 | name: a2c_continuous 8 | 9 | model: 10 | name: continuous_a2c_logstd 11 | 12 | network: 13 | name: actor_critic 14 | separate: False 15 | space: 16 | continuous: 17 | mu_activation: None 18 | sigma_activation: None 19 | 20 | mu_init: 21 | name: default 22 | sigma_init: 23 | name: const_initializer 24 | val: 0 25 | fixed_sigma: True 26 | mlp: 27 | units: [512, 256, 128] 28 | activation: elu 29 | d2rl: False 30 | 31 | initializer: 32 | name: default 33 | regularizer: 34 | name: None 35 | 36 | config: 37 | name: Halfcheetah_brax 38 | env_name: brax 39 | multi_gpu: False 40 | mixed_precision: True 41 | normalize_input: True 42 | normalize_value: True 43 | reward_shaper: 44 | scale_value: 1.0 45 | normalize_advantage: True 46 | gamma: 0.99 47 | tau: 0.95 48 | learning_rate: 3e-4 49 | lr_schedule: adaptive 50 | kl_threshold: 0.008 51 | score_to_win: 20000 52 | max_epochs: 2000 53 | save_best_after: 100 54 | save_frequency: 50 55 | grad_norm: 1.0 56 | entropy_coef: 0.0 57 | truncate_grads: True 58 | e_clip: 0.2 59 | horizon_length: 16 60 | num_actors: 8192 61 | minibatch_size: 32768 62 | mini_epochs: 5 63 | critic_coef: 2 64 | clip_value: False 65 | bounds_loss_coef: 0.0004 66 | 67 | env_config: 68 | env_name: 'halfcheetah' 69 | -------------------------------------------------------------------------------- /rl_games/configs/brax/ppo_humanoid.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 7 3 | 4 | algo: 5 | name: a2c_continuous 6 | 7 | model: 8 | name: continuous_a2c_logstd 9 | 10 | network: 11 | name: actor_critic 12 | separate: False 13 | space: 14 | continuous: 15 | mu_activation: None 16 | sigma_activation: None 17 | 18 | mu_init: 19 | name: default 20 | sigma_init: 21 | name: const_initializer 22 | val: 0 23 | fixed_sigma: True 24 | mlp: 25 | units: [512, 256, 128] 26 | activation: elu 27 | d2rl: False 28 | 29 | initializer: 30 | name: default 31 | regularizer: 32 | name: None 33 | 34 | config: 35 | name: Humanoid_brax 36 | full_experiment_name: Humanoid_brax 37 | env_name: brax 38 | multi_gpu: False 39 | mixed_precision: True 40 | normalize_input: True 41 | normalize_value: True 42 | normalize_advantage: True 43 | use_smooth_clamp: True 44 | reward_shaper: 45 | scale_value: 1.0 46 | gamma: 0.99 47 | tau: 0.95 48 | learning_rate: 3e-4 49 | lr_schedule: adaptive 50 | kl_threshold: 0.008 51 | score_to_win: 20000 52 | max_epochs: 1000 53 | save_best_after: 100 54 | save_frequency: 50 55 | grad_norm: 1.0 56 | entropy_coef: 0.0 57 | truncate_grads: True 58 | e_clip: 0.2 59 | horizon_length: 16 60 | num_actors: 4096 61 | minibatch_size: 32768 62 | mini_epochs: 5 63 | critic_coef: 2 64 | 65 | clip_value: True 66 | bound_loss_type: regularisation 67 | bounds_loss_coef: 0.0 68 | 69 | env_config: 70 | env_name: humanoid -------------------------------------------------------------------------------- /rl_games/configs/brax/ppo_ur5e.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 7 3 | 4 | #devices: [0, 0] 5 | 6 | algo: 7 | name: a2c_continuous 8 | 9 | model: 10 | name: 
continuous_a2c_logstd 11 | 12 | network: 13 | name: actor_critic 14 | separate: False 15 | space: 16 | continuous: 17 | mu_activation: None 18 | sigma_activation: None 19 | 20 | mu_init: 21 | name: default 22 | sigma_init: 23 | name: const_initializer 24 | val: 0 25 | fixed_sigma: True 26 | mlp: 27 | units: [512, 256, 128] 28 | activation: elu 29 | d2rl: False 30 | 31 | initializer: 32 | name: default 33 | regularizer: 34 | name: None 35 | 36 | config: 37 | name: Ur5e_brax 38 | env_name: brax 39 | multi_gpu: False 40 | mixed_precision: True 41 | normalize_input: True 42 | normalize_value: True 43 | reward_shaper: 44 | scale_value: 1.0 45 | normalize_advantage: True 46 | gamma: 0.99 47 | tau: 0.95 48 | learning_rate: 3e-4 49 | lr_schedule: adaptive 50 | kl_threshold: 0.008 51 | score_to_win: 20000 52 | max_epochs: 2000 53 | save_best_after: 100 54 | save_frequency: 50 55 | grad_norm: 1.0 56 | entropy_coef: 0.00 57 | truncate_grads: True 58 | e_clip: 0.2 59 | horizon_length: 16 60 | num_actors: 8192 61 | minibatch_size: 32768 62 | mini_epochs: 5 63 | critic_coef: 2 64 | clip_value: False 65 | bounds_loss_coef: 0.0004 66 | 67 | env_config: 68 | env_name: 'ur5e' 69 | -------------------------------------------------------------------------------- /rl_games/configs/brax/sac_ant.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: sac 4 | 5 | model: 6 | name: soft_actor_critic 7 | 8 | network: 9 | name: soft_actor_critic 10 | separate: True 11 | space: 12 | continuous: 13 | mlp: 14 | units: [256, 128, 64] 15 | activation: relu 16 | 17 | initializer: 18 | name: default 19 | log_std_bounds: [-5, 2] 20 | 21 | config: 22 | name: Ant_brax_SAC 23 | env_name: brax 24 | normalize_input: True 25 | reward_shaper: 26 | scale_value: 1 27 | device: cuda 28 | max_epochs: 10000 29 | num_steps_per_episode: 16 30 | save_best_after: 100 31 | save_frequency: 10000 32 | gamma: 0.99 33 | init_alpha: 1 34 | alpha_lr: 0.005 35 | actor_lr: 0.0005 36 | critic_lr: 0.0005 37 | critic_tau: 0.005 38 | batch_size: 4096 39 | learnable_temperature: True 40 | num_warmup_steps: 10 # total number of warmup steps: num_actors * num_steps_per_episode * num_warmup_steps 41 | replay_buffer_size: 1000000 42 | num_actors: 128 43 | 44 | env_config: 45 | env_name: ant -------------------------------------------------------------------------------- /rl_games/configs/brax/sac_humanoid.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: sac 4 | 5 | model: 6 | name: soft_actor_critic 7 | 8 | network: 9 | name: soft_actor_critic 10 | separate: True 11 | space: 12 | continuous: 13 | 14 | mlp: 15 | units: [512, 256] 16 | activation: relu 17 | initializer: 18 | name: default 19 | 20 | log_std_bounds: [-5, 2] 21 | 22 | config: 23 | name: Humanoid_brax_SAC 24 | env_name: brax 25 | normalize_input: True 26 | reward_shaper: 27 | scale_value: 1 28 | device: cuda 29 | max_epochs: 2000000 30 | num_steps_per_episode: 16 31 | save_best_after: 100 32 | save_frequency: 10000 33 | gamma: 0.99 34 | init_alpha: 1 35 | alpha_lr: 0.0002 36 | actor_lr: 0.0003 37 | critic_lr: 0.0003 38 | critic_tau: 0.005 39 | batch_size: 2048 40 | learnable_temperature: True 41 | num_warmup_steps: 5 # total number of warmup steps: num_actors * num_steps_per_episode * num_warmup_steps 42 | replay_buffer_size: 1000000 43 | num_actors: 64 44 | 45 | env_config: 46 | env_name: humanoid 
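
A note on the brax configs above: two derived quantities are easy to miss when editing them. For the PPO files, the rollout batch collected per epoch is horizon_length * num_actors (8 * 4096 = 32768 in ppo_ant.yaml; 16 * 8192 = 131072 in ppo_grasp.yaml and ppo_ur5e.yaml), and minibatch_size is kept at an exact divisor of that batch. For the SAC files, the inline comment on num_warmup_steps defines the warmup budget as num_actors * num_steps_per_episode * num_warmup_steps, i.e. 128 * 16 * 10 = 20480 environment steps for sac_ant.yaml and 64 * 16 * 5 = 5120 for sac_humanoid.yaml. The sketch below is not a file from this repository; it is a minimal check of both calculations, assuming only PyYAML and the config paths shown in this listing.

    import yaml  # PyYAML

    def load_cfg(path):
        # Every config in this tree nests its runtime options under params.config.
        with open(path) as f:
            return yaml.safe_load(f)["params"]["config"]

    ppo = load_cfg("rl_games/configs/brax/ppo_ant.yaml")
    batch_size = ppo["horizon_length"] * ppo["num_actors"]   # 8 * 4096 = 32768
    # The configs in this repo keep minibatch_size an exact divisor of the batch,
    # which is the safe choice when changing num_actors or horizon_length.
    assert batch_size % ppo["minibatch_size"] == 0

    sac = load_cfg("rl_games/configs/brax/sac_ant.yaml")
    # Formula taken from the comment on num_warmup_steps in the SAC configs above.
    warmup_env_steps = (sac["num_actors"]
                        * sac["num_steps_per_episode"]
                        * sac["num_warmup_steps"])           # 128 * 16 * 10 = 20480
    print(batch_size, warmup_env_steps)

To actually train with one of these files, see README.md and docs/HOW_TO_RL_GAMES.md for the runner entry point; the documented invocation is along the lines of python runner.py --train --file rl_games/configs/brax/ppo_ant.yaml.
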
-------------------------------------------------------------------------------- /rl_games/configs/dm_control/acrobot_swingup.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [64, 64] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: AcrobotSwingup_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: False 35 | reward_shaper: 36 | scale_value: 1 37 | shift_value: 1 38 | log_val: True 39 | normalize_advantage: True 40 | gamma: 0.99 41 | tau: 0.95 42 | 43 | learning_rate: 3e-4 44 | lr_schedule: adaptive 45 | kl_threshold: 0.001 46 | grad_norm: 1.0 47 | entropy_coef: 0.0 48 | truncate_grads: True 49 | e_clip: 0.2 50 | clip_value: False 51 | use_smooth_clamp: False 52 | bound_loss_type: regularisation 53 | bounds_loss_coef: 0.002 54 | max_epochs: 4000 55 | num_actors: 64 56 | horizon_length: 128 57 | minibatch_size: 2048 58 | mini_epochs: 5 59 | critic_coef: 4 60 | use_diagnostics: True 61 | env_config: 62 | env_name: AcrobotSwingup-v1 63 | flatten_obs: True 64 | 65 | player: 66 | render: False 67 | -------------------------------------------------------------------------------- /rl_games/configs/dm_control/ball_in_cup.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [64, 64] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: BallInCupCatch_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: False 35 | reward_shaper: 36 | scale_value: 1 37 | shift_value: 1 38 | log_val: True 39 | normalize_advantage: True 40 | gamma: 0.99 41 | tau: 0.95 42 | 43 | learning_rate: 3e-4 44 | lr_schedule: adaptive 45 | kl_threshold: 0.002 46 | grad_norm: 1.0 47 | entropy_coef: 0.0 48 | truncate_grads: True 49 | e_clip: 0.2 50 | clip_value: False 51 | use_smooth_clamp: False 52 | bound_loss_type: regularisation 53 | bounds_loss_coef: 0.002 54 | max_epochs: 4000 55 | num_actors: 64 56 | horizon_length: 128 57 | minibatch_size: 2048 58 | mini_epochs: 5 59 | critic_coef: 4 60 | use_diagnostics: True 61 | env_config: 62 | env_name: BallInCupCatch-v1 63 | flatten_obs: True 64 | 65 | player: 66 | render: False 67 | -------------------------------------------------------------------------------- /rl_games/configs/dm_control/cartpole.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 
16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [256, 128, 64] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: CartpoleBalance_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: False 35 | reward_shaper: 36 | scale_value: 1 37 | shift_value: 1 38 | log_val: True 39 | normalize_advantage: True 40 | gamma: 0.99 41 | tau: 0.95 42 | 43 | learning_rate: 3e-4 44 | lr_schedule: adaptive 45 | kl_threshold: 0.008 46 | grad_norm: 1.0 47 | entropy_coef: 0.0 48 | truncate_grads: True 49 | e_clip: 0.2 50 | clip_value: False 51 | use_smooth_clamp: False 52 | bound_loss_type: regularisation 53 | bounds_loss_coef: 0.001 54 | max_epochs: 4000 55 | num_actors: 64 56 | horizon_length: 128 57 | minibatch_size: 2048 58 | mini_epochs: 5 59 | critic_coef: 4 60 | use_diagnostics: True 61 | env_config: 62 | env_name: CartpoleBalance-v1 63 | flatten_obs: True 64 | 65 | player: 66 | render: False 67 | -------------------------------------------------------------------------------- /rl_games/configs/dm_control/cheetah_walk.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [256, 128, 64] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: CheetahRun_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: False 35 | reward_shaper: 36 | scale_value: 1 37 | shift_value: 1 38 | log_val: True 39 | normalize_advantage: True 40 | gamma: 0.99 41 | tau: 0.95 42 | 43 | learning_rate: 3e-4 44 | lr_schedule: adaptive 45 | kl_threshold: 0.008 46 | grad_norm: 1.0 47 | entropy_coef: 0.0 48 | truncate_grads: True 49 | e_clip: 0.2 50 | clip_value: False 51 | use_smooth_clamp: False 52 | bound_loss_type: regularisation 53 | bounds_loss_coef: 0.001 54 | max_epochs: 4000 55 | num_actors: 64 56 | horizon_length: 128 57 | minibatch_size: 2048 58 | mini_epochs: 5 59 | critic_coef: 4 60 | use_diagnostics: True 61 | env_config: 62 | env_name: CheetahRun-v1 63 | flatten_obs: True 64 | 65 | player: 66 | render: False 67 | -------------------------------------------------------------------------------- /rl_games/configs/dm_control/fish_swim.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [512, 256, 128] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: FishSwim_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: False 35 | reward_shaper: 36 | scale_value: 1 37 | shift_value: 1 38 | 
log_val: True 39 | normalize_advantage: True 40 | gamma: 0.99 41 | tau: 0.95 42 | 43 | learning_rate: 3e-4 44 | lr_schedule: adaptive 45 | kl_threshold: 0.008 46 | grad_norm: 1.0 47 | entropy_coef: 0.0 48 | truncate_grads: True 49 | e_clip: 0.2 50 | clip_value: False 51 | use_smooth_clamp: False 52 | bound_loss_type: regularisation 53 | bounds_loss_coef: 0.001 54 | max_epochs: 4000 55 | num_actors: 64 56 | horizon_length: 128 57 | minibatch_size: 2048 58 | mini_epochs: 5 59 | critic_coef: 4 60 | use_diagnostics: True 61 | env_config: 62 | env_name: FishSwim-v1 63 | flatten_obs: True 64 | 65 | player: 66 | render: False 67 | -------------------------------------------------------------------------------- /rl_games/configs/dm_control/hopper_hop.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [128, 64] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: HopperHop_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: False 35 | reward_shaper: 36 | scale_value: 1 37 | shift_value: 1 38 | log_val: True 39 | normalize_advantage: True 40 | gamma: 0.99 41 | tau: 0.95 42 | 43 | learning_rate: 3e-4 44 | lr_schedule: adaptive 45 | kl_threshold: 0.008 46 | grad_norm: 1.0 47 | entropy_coef: 0.0 48 | truncate_grads: True 49 | e_clip: 0.2 50 | clip_value: False 51 | use_smooth_clamp: False 52 | bound_loss_type: regularisation 53 | bounds_loss_coef: 0.001 54 | max_epochs: 4000 55 | num_actors: 64 56 | horizon_length: 128 57 | minibatch_size: 2048 58 | mini_epochs: 5 59 | critic_coef: 4 60 | use_diagnostics: True 61 | env_config: 62 | env_name: HopperHop-v1 63 | flatten_obs: True 64 | 65 | player: 66 | render: False 67 | -------------------------------------------------------------------------------- /rl_games/configs/dm_control/hopper_stand.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [256, 128, 64] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: HopperStand_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: False 35 | reward_shaper: 36 | scale_value: 1 37 | shift_value: 1 38 | log_val: True 39 | normalize_advantage: True 40 | gamma: 0.99 41 | tau: 0.95 42 | 43 | learning_rate: 3e-4 44 | lr_schedule: adaptive 45 | kl_threshold: 0.008 46 | grad_norm: 1.0 47 | entropy_coef: 0.0 48 | truncate_grads: True 49 | e_clip: 0.2 50 | clip_value: False 51 | use_smooth_clamp: False 52 | bound_loss_type: regularisation 53 | bounds_loss_coef: 0.001 54 | max_epochs: 4000 55 | num_actors: 64 56 | horizon_length: 128 57 | minibatch_size: 2048 58 | 
mini_epochs: 5 59 | critic_coef: 4 60 | use_diagnostics: True 61 | env_config: 62 | env_name: HopperStand-v1 63 | flatten_obs: True 64 | 65 | player: 66 | render: False 67 | -------------------------------------------------------------------------------- /rl_games/configs/dm_control/humanoid_run.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [512, 256, 128] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: HumanoidRun_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: False 35 | reward_shaper: 36 | scale_value: 1 37 | shift_value: 1 38 | log_val: True 39 | normalize_advantage: True 40 | gamma: 0.99 41 | tau: 0.95 42 | 43 | learning_rate: 3e-4 44 | lr_schedule: adaptive 45 | kl_threshold: 0.008 46 | grad_norm: 1.0 47 | entropy_coef: 0.0 48 | truncate_grads: True 49 | e_clip: 0.2 50 | clip_value: True 51 | use_smooth_clamp: False 52 | bound_loss_type: regularisation 53 | bounds_loss_coef: 0.001 54 | #max_epochs: 10000 55 | max_frames: 100_000_000 56 | num_actors: 64 57 | horizon_length: 128 58 | minibatch_size: 2048 59 | mini_epochs: 5 60 | critic_coef: 4 61 | use_diagnostics: True 62 | env_config: 63 | env_name: HumanoidRun-v1 64 | flatten_obs: True 65 | 66 | player: 67 | render: False 68 | -------------------------------------------------------------------------------- /rl_games/configs/dm_control/humanoid_stand.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [512, 256, 128] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: HumanoidStand_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: False 35 | reward_shaper: 36 | scale_value: 1 37 | shift_value: 1 38 | log_val: True 39 | normalize_advantage: True 40 | gamma: 0.99 41 | tau: 0.95 42 | 43 | learning_rate: 3e-4 44 | lr_schedule: adaptive 45 | kl_threshold: 0.008 46 | grad_norm: 1.0 47 | entropy_coef: 0.0 48 | truncate_grads: True 49 | e_clip: 0.2 50 | clip_value: False 51 | use_smooth_clamp: False 52 | bound_loss_type: regularisation 53 | bounds_loss_coef: 0.001 54 | #max_epochs: 5000 55 | max_frames: 50_000_000 56 | num_actors: 64 57 | horizon_length: 128 58 | minibatch_size: 2048 59 | mini_epochs: 5 60 | critic_coef: 4 61 | use_diagnostics: True 62 | env_config: 63 | env_name: HumanoidStand-v1 64 | flatten_obs: True 65 | 66 | player: 67 | render: False 68 | -------------------------------------------------------------------------------- /rl_games/configs/dm_control/humanoid_walk.yaml: -------------------------------------------------------------------------------- 1 | 
params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [512, 256, 128] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: HumanoidWalk_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: False 35 | reward_shaper: 36 | scale_value: 1 37 | shift_value: 1 38 | log_val: True 39 | normalize_advantage: True 40 | gamma: 0.99 41 | tau: 0.95 42 | 43 | learning_rate: 3e-4 44 | lr_schedule: adaptive 45 | kl_threshold: 0.008 46 | grad_norm: 1.0 47 | entropy_coef: 0.0 48 | truncate_grads: True 49 | e_clip: 0.2 50 | clip_value: False 51 | use_smooth_clamp: False 52 | bound_loss_type: regularisation 53 | bounds_loss_coef: 0.001 54 | max_epochs: 5000 55 | num_actors: 64 56 | horizon_length: 128 57 | minibatch_size: 2048 58 | mini_epochs: 5 59 | critic_coef: 4 60 | use_diagnostics: True 61 | env_config: 62 | env_name: HumanoidWalk-v1 63 | flatten_obs: True 64 | 65 | player: 66 | render: False 67 | -------------------------------------------------------------------------------- /rl_games/configs/dm_control/manipulator_bringball.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [64, 64] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: ManipulatorBringBall_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: False 35 | reward_shaper: 36 | scale_value: 1 37 | shift_value: 1 38 | log_val: True 39 | normalize_advantage: True 40 | gamma: 0.99 41 | tau: 0.95 42 | 43 | learning_rate: 3e-4 44 | lr_schedule: adaptive 45 | kl_threshold: 0.001 46 | grad_norm: 1.0 47 | entropy_coef: 0.0 48 | truncate_grads: True 49 | e_clip: 0.2 50 | clip_value: False 51 | use_smooth_clamp: False 52 | bound_loss_type: regularisation 53 | bounds_loss_coef: 0.001 54 | max_epochs: 4000 55 | num_actors: 64 56 | horizon_length: 128 57 | minibatch_size: 2048 58 | mini_epochs: 5 59 | critic_coef: 4 60 | use_diagnostics: True 61 | env_config: 62 | env_name: ManipulatorBringBall-v1 63 | flatten_obs: True 64 | 65 | player: 66 | render: False 67 | -------------------------------------------------------------------------------- /rl_games/configs/dm_control/pendulum_swingup.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [64, 64] 24 | activation: elu 25 | initializer: 
26 | name: default 27 | 28 | config: 29 | name: PendulumSwingup_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: False 35 | reward_shaper: 36 | scale_value: 1 37 | shift_value: 1 38 | log_val: True 39 | normalize_advantage: True 40 | gamma: 0.99 41 | tau: 0.95 42 | 43 | learning_rate: 3e-4 44 | lr_schedule: adaptive 45 | kl_threshold: 0.001 46 | grad_norm: 1.0 47 | entropy_coef: 0.0 48 | truncate_grads: True 49 | e_clip: 0.2 50 | clip_value: False 51 | use_smooth_clamp: False 52 | bound_loss_type: regularisation 53 | bounds_loss_coef: 0.002 54 | max_epochs: 4000 55 | num_actors: 64 56 | horizon_length: 128 57 | minibatch_size: 2048 58 | mini_epochs: 5 59 | critic_coef: 4 60 | use_diagnostics: True 61 | env_config: 62 | env_name: PendulumSwingup-v1 63 | flatten_obs: True 64 | 65 | player: 66 | render: False 67 | -------------------------------------------------------------------------------- /rl_games/configs/dm_control/walker_run.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [256, 128, 64] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: WalkerRun_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: False 35 | reward_shaper: 36 | scale_value: 1 37 | shift_value: 1 38 | log_val: True 39 | normalize_advantage: True 40 | gamma: 0.99 41 | tau: 0.95 42 | 43 | learning_rate: 3e-4 44 | lr_schedule: adaptive 45 | kl_threshold: 0.008 46 | grad_norm: 1.0 47 | entropy_coef: 0.0 48 | truncate_grads: True 49 | e_clip: 0.2 50 | clip_value: False 51 | use_smooth_clamp: False 52 | bound_loss_type: regularisation 53 | bounds_loss_coef: 0.001 54 | max_epochs: 5000 55 | num_actors: 64 56 | horizon_length: 128 57 | minibatch_size: 2048 58 | mini_epochs: 5 59 | critic_coef: 4 60 | use_diagnostics: True 61 | env_config: 62 | env_name: WalkerRun-v1 63 | flatten_obs: True 64 | 65 | player: 66 | render: False 67 | -------------------------------------------------------------------------------- /rl_games/configs/dm_control/walker_stand.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [256, 128, 64] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: WalkerStand_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: False 35 | reward_shaper: 36 | scale_value: 1 37 | shift_value: 1 38 | log_val: True 39 | normalize_advantage: True 40 | gamma: 0.99 41 | tau: 0.95 42 | 43 | learning_rate: 3e-4 44 | lr_schedule: adaptive 45 | kl_threshold: 0.008 46 | grad_norm: 1.0 47 | 
entropy_coef: 0.0 48 | truncate_grads: True 49 | e_clip: 0.2 50 | clip_value: False 51 | use_smooth_clamp: False 52 | bound_loss_type: regularisation 53 | bounds_loss_coef: 0.001 54 | max_epochs: 4000 55 | num_actors: 64 56 | horizon_length: 128 57 | minibatch_size: 2048 58 | mini_epochs: 5 59 | critic_coef: 4 60 | use_diagnostics: True 61 | env_config: 62 | env_name: WalkerStand-v1 63 | flatten_obs: True 64 | 65 | player: 66 | render: False 67 | -------------------------------------------------------------------------------- /rl_games/configs/dm_control/walker_walk.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [256, 128, 64] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: WalkerWalk_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: False 35 | reward_shaper: 36 | scale_value: 1 37 | shift_value: 1 38 | log_val: True 39 | normalize_advantage: True 40 | gamma: 0.99 41 | tau: 0.95 42 | 43 | learning_rate: 3e-4 44 | lr_schedule: adaptive 45 | kl_threshold: 0.008 46 | grad_norm: 1.0 47 | entropy_coef: 0.0 48 | truncate_grads: True 49 | e_clip: 0.2 50 | clip_value: False 51 | use_smooth_clamp: False 52 | bound_loss_type: regularisation 53 | bounds_loss_coef: 0.001 54 | max_epochs: 4000 55 | num_actors: 64 56 | horizon_length: 128 57 | minibatch_size: 2048 58 | mini_epochs: 5 59 | critic_coef: 4 60 | use_diagnostics: True 61 | env_config: 62 | env_name: WalkerWalk-v1 63 | flatten_obs: True 64 | 65 | player: 66 | render: False 67 | -------------------------------------------------------------------------------- /rl_games/configs/ma/ppo_connect4_self_play_resnet.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: connect4net 10 | blocks: 5 11 | 12 | config: 13 | name: connect4_rn 14 | reward_shaper: 15 | scale_value: 1 16 | normalize_advantage: True 17 | gamma: 0.995 18 | tau: 0.95 19 | learning_rate: 2e-4 20 | score_to_win: 100 21 | grad_norm: 0.5 22 | entropy_coef: 0.005 23 | truncate_grads: True 24 | env_name: connect4_env 25 | e_clip: 0.2 26 | clip_value: True 27 | num_actors: 4 28 | horizon_length: 128 29 | minibatch_size: 512 30 | mini_epochs: 4 31 | critic_coef: 1 32 | lr_schedule: None 33 | kl_threshold: 0.05 34 | normalize_input: False 35 | games_to_track: 1000 36 | use_action_masks: True 37 | weight_decay: 0.001 38 | self_play_config: 39 | update_score: 0.1 40 | games_to_check: 100 41 | env_update_num: 4 42 | 43 | env_config: 44 | name: connect_four_v0 45 | self_play: True 46 | is_human: True 47 | random_agent: False 48 | config_path: 'rl_games/configs/ma/ppo_connect4_self_play_resnet.yaml' -------------------------------------------------------------------------------- /rl_games/configs/ma/ppo_slime_self_play.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: 
actor_critic 10 | separate: True 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | 15 | mlp: 16 | units: [128,64] 17 | activation: elu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: 'None' 22 | config: 23 | name: slime_pvp2 24 | reward_shaper: 25 | scale_value: 1 26 | normalize_advantage: True 27 | gamma: 0.995 28 | tau: 0.95 29 | learning_rate: 2e-4 30 | score_to_win: 100 31 | grad_norm: 0.5 32 | entropy_coef: 0.01 33 | truncate_grads: True 34 | env_name: slime_gym 35 | e_clip: 0.2 36 | clip_value: True 37 | num_actors: 8 38 | horizon_length: 512 39 | minibatch_size: 2048 40 | mini_epochs: 4 41 | critic_coef: 1 42 | lr_schedule: None 43 | kl_threshold: 0.05 44 | normalize_input: False 45 | games_to_track: 500 46 | 47 | self_play_config: 48 | update_score: 1 49 | games_to_check: 200 50 | check_scores : False 51 | 52 | env_config: 53 | name: SlimeVolleyDiscrete-v0 54 | #neg_scale: 1 #0.5 55 | self_play: True 56 | config_path: 'rl_games/configs/ma/ppo_slime_self_play.yaml' 57 | 58 | player: 59 | render: True 60 | games_num: 200 61 | n_game_life: 1 62 | deterministic: True 63 | device_name: 'cpu' -------------------------------------------------------------------------------- /rl_games/configs/ma/ppo_slime_v0.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | 15 | mlp: 16 | units: [128,64] 17 | activation: elu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: None 22 | 23 | config: 24 | name: slime 25 | reward_shaper: 26 | scale_value: 1 27 | normalize_advantage: True 28 | gamma: 0.99 29 | tau: 0.95 30 | learning_rate: 1e-4 31 | score_to_win: 20 32 | grad_norm: 0.5 33 | entropy_coef: 0.005 34 | truncate_grads: True 35 | env_name: slime_gym 36 | e_clip: 0.2 37 | clip_value: True 38 | num_actors: 8 39 | horizon_length: 128 40 | minibatch_size: 512 41 | mini_epochs: 4 42 | critic_coef: 1 43 | lr_schedule: None 44 | kl_threshold: 0.05 45 | normalize_input: False 46 | seq_length: 4 47 | use_action_masks: False 48 | ignore_dead_batches : False 49 | 50 | env_config: 51 | name: SlimeVolleyDiscrete-v0 52 | 53 | player: 54 | render: True 55 | games_num: 200 56 | n_game_life: 1 57 | deterministic: True -------------------------------------------------------------------------------- /rl_games/configs/maniskill/ppo_ant.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 7 3 | 4 | #devices: [0, 0] 5 | 6 | algo: 7 | name: a2c_continuous 8 | 9 | model: 10 | name: continuous_a2c_logstd 11 | 12 | network: 13 | name: actor_critic 14 | separate: False 15 | space: 16 | continuous: 17 | mu_activation: None 18 | sigma_activation: None 19 | 20 | mu_init: 21 | name: default 22 | sigma_init: 23 | name: const_initializer 24 | val: 0 25 | fixed_sigma: True 26 | mlp: 27 | units: [256, 128, 64] 28 | activation: elu 29 | d2rl: False 30 | 31 | initializer: 32 | name: default 33 | regularizer: 34 | name: None 35 | 36 | config: 37 | name: Ant_Maniskill 38 | full_experiment_name: Ant_Maniskill 39 | env_name: maniskill 40 | multi_gpu: False 41 | mixed_precision: True 42 | normalize_input: True 43 | normalize_value: True 44 | normalize_advantage: True 45 | use_smooth_clamp: False 46 | reward_shaper: 47 | scale_value: 1.0 48 | gamma: 0.99 49 | tau: 0.95 50 | learning_rate: 3e-4 51 | 
lr_schedule: adaptive 52 | kl_threshold: 0.008 53 | score_to_win: 20000 54 | max_epochs: 1000 55 | save_best_after: 100 56 | save_frequency: 50 57 | grad_norm: 1.0 58 | entropy_coef: 0.0 59 | truncate_grads: True 60 | e_clip: 0.2 61 | horizon_length: 8 62 | num_actors: 1024 63 | minibatch_size: 4096 64 | mini_epochs: 5 65 | critic_coef: 2 66 | clip_value: True 67 | bounds_loss_coef: 0.0001 68 | 69 | env_config: 70 | env_name: MS-AntRun-v1 71 | -------------------------------------------------------------------------------- /rl_games/configs/maniskill/ppo_pick_cube_state.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 7 3 | 4 | #devices: [0, 0] 5 | 6 | algo: 7 | name: a2c_continuous 8 | 9 | model: 10 | name: continuous_a2c_logstd 11 | 12 | network: 13 | name: actor_critic 14 | separate: False 15 | space: 16 | continuous: 17 | mu_activation: None 18 | sigma_activation: None 19 | 20 | mu_init: 21 | name: default 22 | sigma_init: 23 | name: const_initializer 24 | val: 0 25 | fixed_sigma: True 26 | mlp: 27 | units: [256, 128, 64] 28 | activation: elu 29 | d2rl: False 30 | 31 | initializer: 32 | name: default 33 | regularizer: 34 | name: None 35 | 36 | config: 37 | name: PickCube_State_Maniskill 38 | full_experiment_name: PickCube_State_Maniskill 39 | env_name: maniskill 40 | multi_gpu: False 41 | mixed_precision: True 42 | normalize_input: True 43 | normalize_value: True 44 | normalize_advantage: True 45 | use_smooth_clamp: False 46 | reward_shaper: 47 | scale_value: 1.0 48 | gamma: 0.99 49 | tau: 0.95 50 | learning_rate: 3e-4 51 | lr_schedule: adaptive 52 | kl_threshold: 0.008 53 | score_to_win: 20000 54 | max_epochs: 1000 55 | save_best_after: 100 56 | save_frequency: 50 57 | grad_norm: 1.0 58 | entropy_coef: 0.0 59 | truncate_grads: True 60 | e_clip: 0.2 61 | horizon_length: 8 62 | num_actors: 1024 63 | minibatch_size: 4096 64 | mini_epochs: 5 65 | critic_coef: 2 66 | clip_value: True 67 | bounds_loss_coef: 0.0001 68 | 69 | env_config: 70 | env_name: PickCube-v1 71 | obs_mode: "state" # there is also "state_dict", "rgbd", ... 72 | control_mode: "pd_ee_delta_pose" # there is also "pd_joint_delta_pos", .. 
73 | -------------------------------------------------------------------------------- /rl_games/configs/mujoco/ant.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [256, 128, 64] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: Ant-v3_ray 30 | env_name: openai_gym 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: True 35 | reward_shaper: 36 | scale_value: 0.1 37 | normalize_advantage: True 38 | gamma: 0.99 39 | tau: 0.95 40 | 41 | learning_rate: 3e-4 42 | lr_schedule: adaptive 43 | kl_threshold: 0.008 44 | grad_norm: 1.0 45 | entropy_coef: 0.0 46 | truncate_grads: True 47 | e_clip: 0.2 48 | max_epochs: 2000 49 | num_actors: 8 #64 50 | horizon_length: 256 #64 51 | minibatch_size: 2048 52 | mini_epochs: 4 53 | critic_coef: 2 54 | clip_value: True 55 | use_smooth_clamp: True 56 | bound_loss_type: regularisation 57 | bounds_loss_coef: 0.0 58 | 59 | env_config: 60 | name: Ant-v3 61 | seed: 5 62 | 63 | player: 64 | render: True -------------------------------------------------------------------------------- /rl_games/configs/mujoco/ant_envpool.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [256, 128, 64] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: Ant-v4_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: True 35 | normalize_advantage: True 36 | reward_shaper: 37 | scale_value: 1 38 | 39 | gamma: 0.99 40 | tau: 0.95 41 | learning_rate: 3e-4 42 | lr_schedule: adaptive 43 | kl_threshold: 0.008 44 | grad_norm: 1.0 45 | entropy_coef: 0.0 46 | truncate_grads: True 47 | e_clip: 0.2 48 | clip_value: True 49 | use_smooth_clamp: True 50 | bound_loss_type: regularisation 51 | bounds_loss_coef: 0.0 52 | max_epochs: 2000 53 | num_actors: 64 54 | horizon_length: 64 55 | minibatch_size: 2048 56 | mini_epochs: 4 57 | critic_coef: 2 58 | 59 | env_config: 60 | env_name: Ant-v4 61 | seed: 5 62 | #flat_observation: True 63 | 64 | player: 65 | render: False -------------------------------------------------------------------------------- /rl_games/configs/mujoco/halfcheetah.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [128, 64, 32] 24 | activation: elu 25 
| initializer: 26 | name: variance_scaling_initializer 27 | scale: 2.0 28 | 29 | config: 30 | name: HalfCheetah-v4_ray 31 | env_name: openai_gym 32 | score_to_win: 20000 33 | normalize_input: True 34 | normalize_value: True 35 | value_bootstrap: True 36 | reward_shaper: 37 | scale_value: 0.1 38 | normalize_advantage: True 39 | use_smooth_clamp: True 40 | gamma: 0.99 41 | tau: 0.95 42 | 43 | learning_rate: 5e-4 44 | lr_schedule: adaptive 45 | kl_threshold: 0.008 46 | grad_norm: 1.0 47 | entropy_coef: 0.0 48 | truncate_grads: True 49 | e_clip: 0.2 50 | clip_value: False 51 | num_actors: 64 52 | horizon_length: 256 53 | minibatch_size: 2048 54 | mini_epochs: 5 55 | critic_coef: 4 56 | bounds_loss_coef: 0.0 57 | max_epochs: 1000 58 | env_config: 59 | name: HalfCheetah-v4 60 | seed: 5 61 | 62 | player: 63 | render: True 64 | deterministic: True 65 | games_num: 100 -------------------------------------------------------------------------------- /rl_games/configs/mujoco/halfcheetah_envpool.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | # name: variance_scaling_initializer 19 | # scale: 1.0 20 | sigma_init: 21 | name: const_initializer 22 | val: 0 23 | fixed_sigma: True 24 | mlp: 25 | units: [128, 64, 32] 26 | activation: elu 27 | initializer: 28 | name: variance_scaling_initializer 29 | scale: 2.0 30 | 31 | config: 32 | name: HalfCheetah-v4_envpool 33 | env_name: envpool 34 | score_to_win: 20000 35 | normalize_input: True 36 | normalize_value: True 37 | value_bootstrap: True 38 | reward_shaper: 39 | scale_value: 1 40 | normalize_advantage: True 41 | use_smooth_clamp: True 42 | gamma: 0.99 43 | tau: 0.95 44 | 45 | learning_rate: 5e-4 46 | lr_schedule: adaptive 47 | kl_threshold: 0.008 48 | grad_norm: 1.0 49 | entropy_coef: 0.0 50 | truncate_grads: True 51 | e_clip: 0.2 52 | clip_value: False 53 | num_actors: 64 54 | horizon_length: 256 55 | minibatch_size: 2048 56 | mini_epochs: 5 57 | critic_coef: 4 58 | bounds_loss_coef: 0.0 59 | max_epochs: 1000 60 | env_config: 61 | env_name: HalfCheetah-v4 62 | seed: 5 63 | 64 | player: 65 | render: True 66 | deterministic: True 67 | games_num: 100 -------------------------------------------------------------------------------- /rl_games/configs/mujoco/hopper.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [256, 128, 64] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: Hopper-v4_ray 30 | env_name: openai_gym 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: True 35 | reward_shaper: 36 | scale_value: 0.1 37 | normalize_advantage: True 38 | gamma: 0.99 39 | tau: 0.95 40 | 41 | learning_rate: 5e-4 42 | lr_schedule: adaptive 43 | kl_threshold: 0.008 44 | grad_norm: 1.0 45 | entropy_coef: 0.0 46 | truncate_grads: True 47 | e_clip: 0.2 48 
| clip_value: False 49 | num_actors: 64 50 | horizon_length: 64 51 | minibatch_size: 2048 52 | mini_epochs: 5 53 | critic_coef: 2 54 | use_smooth_clamp: True 55 | bound_loss_type: regularisation 56 | bounds_loss_coef: 0.0 57 | max_epochs: 1000 58 | 59 | env_config: 60 | name: Hopper-v4 61 | seed: 5 62 | 63 | player: 64 | render: True 65 | deterministic: True -------------------------------------------------------------------------------- /rl_games/configs/mujoco/hopper_envpool.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [256, 128, 64] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: Hopper-v4_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: True 35 | reward_shaper: 36 | scale_value: 0.1 37 | normalize_advantage: True 38 | use_smooth_clamp: True 39 | gamma: 0.99 40 | tau: 0.95 41 | 42 | learning_rate: 5e-4 43 | lr_schedule: adaptive 44 | kl_threshold: 0.008 45 | grad_norm: 1.0 46 | entropy_coef: 0.0 47 | truncate_grads: True 48 | e_clip: 0.2 49 | clip_value: False 50 | num_actors: 64 51 | horizon_length: 64 52 | minibatch_size: 2048 53 | mini_epochs: 5 54 | critic_coef: 2 55 | bound_loss_type: regularisation 56 | bounds_loss_coef: 0.0 57 | max_epochs: 1000 58 | 59 | env_config: 60 | env_name: Hopper-v4 61 | seed: 5 62 | 63 | player: 64 | render: True 65 | deterministic: True 66 | games_num: 100 -------------------------------------------------------------------------------- /rl_games/configs/mujoco/humanoid.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 7 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [512, 256, 128] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: Humanoid-v4_ray 30 | env_name: openai_gym 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: True 35 | reward_shaper: 36 | scale_value: 0.1 37 | normalize_advantage: True 38 | gamma: 0.99 39 | tau: 0.95 40 | 41 | learning_rate: 3e-4 42 | lr_schedule: adaptive 43 | kl_threshold: 0.008 44 | grad_norm: 1.0 45 | entropy_coef: 0.0 46 | truncate_grads: True 47 | e_clip: 0.2 48 | clip_value: True 49 | use_smooth_clamp: True 50 | bound_loss_type: regularisation 51 | bounds_loss_coef: 0.0005 52 | max_epochs: 2000 53 | num_actors: 64 54 | horizon_length: 128 55 | minibatch_size: 2048 56 | mini_epochs: 5 57 | critic_coef: 4 58 | 59 | env_config: 60 | name: Humanoid-v4 61 | seed: 5 62 | 63 | player: 64 | render: True -------------------------------------------------------------------------------- /rl_games/configs/mujoco/humanoid_envpool.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 
3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [512, 256, 128] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: Humanoid-v4_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: True 35 | reward_shaper: 36 | scale_value: 0.1 37 | normalize_advantage: True 38 | gamma: 0.99 39 | tau: 0.95 40 | 41 | learning_rate: 3e-4 42 | lr_schedule: adaptive 43 | kl_threshold: 0.008 44 | grad_norm: 1.0 45 | entropy_coef: 0.0 46 | truncate_grads: True 47 | e_clip: 0.2 48 | clip_value: True 49 | use_smooth_clamp: True 50 | bound_loss_type: regularisation 51 | bounds_loss_coef: 0.0005 52 | max_epochs: 2000 53 | num_actors: 64 54 | horizon_length: 128 55 | minibatch_size: 2048 56 | mini_epochs: 5 57 | critic_coef: 4 58 | 59 | env_config: 60 | env_name: Humanoid-v4 61 | 62 | player: 63 | render: True -------------------------------------------------------------------------------- /rl_games/configs/mujoco/sac_ant_envpool.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: sac 5 | 6 | model: 7 | name: soft_actor_critic 8 | 9 | network: 10 | name: soft_actor_critic 11 | separate: True 12 | space: 13 | continuous: 14 | mlp: 15 | units: [256, 128, 64] 16 | activation: relu 17 | 18 | initializer: 19 | name: default 20 | log_std_bounds: [-5, 2] 21 | 22 | config: 23 | name: Ant-v4_SAC 24 | env_name: envpool 25 | normalize_input: True 26 | reward_shaper: 27 | scale_value: 1.0 28 | 29 | max_epochs: 10000 30 | num_steps_per_episode: 8 31 | save_best_after: 500 32 | save_frequency: 1000 33 | gamma: 0.99 34 | init_alpha: 1 35 | alpha_lr: 5e-3 36 | actor_lr: 5e-4 37 | critic_lr: 5e-4 38 | critic_tau: 5e-3 39 | batch_size: 2048 40 | learnable_temperature: True 41 | num_warmup_steps: 10 # total number of warmup steps: num_actors * num_steps_per_episode * num_warmup_steps 42 | replay_buffer_size: 1000000 43 | num_actors: 64 44 | 45 | env_config: 46 | env_name: Ant-v4 47 | seed: 5 48 | -------------------------------------------------------------------------------- /rl_games/configs/mujoco/sac_halfcheetah_envpool.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: sac 5 | 6 | model: 7 | name: soft_actor_critic 8 | 9 | network: 10 | name: soft_actor_critic 11 | separate: True 12 | space: 13 | continuous: 14 | mlp: 15 | units: [256, 128, 64] 16 | activation: relu 17 | 18 | initializer: 19 | name: default 20 | log_std_bounds: [-5, 2] 21 | 22 | config: 23 | name: HalfCheetah-v4_SAC 24 | env_name: envpool 25 | normalize_input: True 26 | reward_shaper: 27 | scale_value: 1.0 28 | 29 | max_epochs: 40000 30 | num_steps_per_episode: 2 31 | save_best_after: 500 32 | save_frequency: 1000 33 | gamma: 0.99 34 | init_alpha: 1.0 35 | alpha_lr: 5e-3 36 | actor_lr: 5e-4 37 | critic_lr: 5e-4 38 | critic_tau: 0.005 39 | batch_size: 2048 40 | learnable_temperature: True 41 | num_warmup_steps: 50 # total number of warmup steps: num_actors * num_steps_per_episode * num_warmup_steps 42 | replay_buffer_size: 1000000 43 | num_actors: 32 44 | 45 | env_config: 46 | env_name: 
HalfCheetah-v4 47 | seed: 5 48 | 49 | player: 50 | render: True 51 | deterministic: True 52 | games_num: 100 -------------------------------------------------------------------------------- /rl_games/configs/mujoco/walker2d.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [256, 128, 64] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: Walker2d-v4_ray 30 | env_name: openai_gym 31 | normalize_input: True 32 | normalize_value: True 33 | value_bootstrap: True 34 | reward_shaper: 35 | scale_value: 0.1 36 | normalize_advantage: True 37 | gamma: 0.99 38 | tau: 0.95 39 | 40 | learning_rate: 3e-4 41 | lr_schedule: adaptive 42 | kl_threshold: 0.008 43 | grad_norm: 1.0 44 | entropy_coef: 0.0 45 | truncate_grads: True 46 | e_clip: 0.2 47 | clip_value: False 48 | num_actors: 64 49 | horizon_length: 128 50 | minibatch_size: 2048 51 | mini_epochs: 5 52 | critic_coef: 2 53 | use_smooth_clamp: True 54 | bound_loss_type: regularisation 55 | bounds_loss_coef: 0.0 56 | max_epochs: 1000 57 | env_config: 58 | name: Walker2d-v4 59 | seed: 5 60 | 61 | player: 62 | render: True -------------------------------------------------------------------------------- /rl_games/configs/mujoco/walker2d_envpool.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 5 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [256, 128, 64] 24 | activation: elu 25 | initializer: 26 | name: default 27 | 28 | config: 29 | name: Walker2d-v4_envpool 30 | env_name: envpool 31 | score_to_win: 20000 32 | normalize_input: True 33 | normalize_value: True 34 | value_bootstrap: True 35 | reward_shaper: 36 | scale_value: 0.1 37 | normalize_advantage: True 38 | gamma: 0.99 39 | tau: 0.95 40 | 41 | learning_rate: 3e-4 42 | lr_schedule: adaptive 43 | kl_threshold: 0.008 44 | grad_norm: 1.0 45 | entropy_coef: 0.0 46 | truncate_grads: True 47 | e_clip: 0.2 48 | clip_value: False 49 | num_actors: 64 50 | horizon_length: 128 51 | minibatch_size: 2048 52 | mini_epochs: 5 53 | critic_coef: 2 54 | use_smooth_clamp: True 55 | bound_loss_type: regularisation 56 | bounds_loss_coef: 0.0 57 | max_epochs: 1000 58 | env_config: 59 | env_name: Walker2d-v4 60 | seed: 5 61 | 62 | player: 63 | render: True 64 | 65 | -------------------------------------------------------------------------------- /rl_games/configs/openai/ppo_gym_ant.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_continuous 4 | 5 | model: 6 | name: continuous_a2c_logstd 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | continuous: 13 | mu_activation: None 14 | sigma_activation: None 15 | mu_init: 16 | name: default 17 | scale: 0.02 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 
| fixed_sigma: True 22 | mlp: 23 | units: [256, 128, 64] 24 | activation: elu 25 | initializer: 26 | name: default 27 | regularizer: 28 | name: 'None' #'l2_regularizer' 29 | #scale: 0.001 30 | 31 | 32 | config: 33 | reward_shaper: 34 | scale_value: 0.1 35 | 36 | normalize_advantage: True 37 | gamma: 0.99 38 | tau: 0.9 39 | learning_rate: 3e-4 40 | name: Hand_block 41 | score_to_win: 100080 42 | grad_norm: 1.0 43 | entropy_coef: 0.0 44 | truncate_grads: True 45 | env_name: openai_gym 46 | e_clip: 0.2 47 | clip_value: True 48 | num_actors: 16 49 | horizon_length: 128 50 | minibatch_size: 2048 51 | mini_epochs: 12 52 | critic_coef: 2 53 | lr_schedule: adaptive 54 | kl_threshold: 0.008 55 | normalize_input: False 56 | seq_length: 4 57 | bounds_loss_coef: 0.0001 58 | max_epochs: 10000 59 | 60 | env_config: 61 | name: Ant-v3 62 | -------------------------------------------------------------------------------- /rl_games/configs/openai/ppo_gym_hand.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_continuous 4 | 5 | model: 6 | name: continuous_a2c_logstd 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | continuous: 13 | mu_activation: None 14 | sigma_activation: None 15 | mu_init: 16 | name: default 17 | scale: 0.02 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [400, 200, 100] 24 | activation: elu 25 | initializer: 26 | name: default 27 | regularizer: 28 | name: 'None' #'l2_regularizer' 29 | #scale: 0.001 30 | 31 | config: 32 | reward_shaper: 33 | scale_value: 1.0 34 | 35 | normalize_advantage: True 36 | gamma: 0.99 37 | tau: 0.9 38 | learning_rate: 3e-4 39 | name: HandBlockDenseXYZ 40 | score_to_win: 10000 41 | grad_norm: 1.0 42 | entropy_coef: 0.0 43 | truncate_grads: True 44 | env_name: openai_robot_gym 45 | e_clip: 0.2 46 | clip_value: True 47 | num_actors: 16 48 | horizon_length: 256 49 | minibatch_size: 2048 50 | mini_epochs: 12 51 | critic_coef: 2 52 | lr_schedule: adaptive 53 | kl_threshold: 0.008 54 | normalize_input: True 55 | seq_length: 4 56 | bounds_loss_coef: 0.0001 57 | max_epochs: 10000 58 | 59 | env_config: 60 | name: HandVMManipulateBlockRotateXYZDense-v0 -------------------------------------------------------------------------------- /rl_games/configs/openai/ppo_gym_humanoid.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_continuous 4 | 5 | model: 6 | name: continuous_a2c_logstd 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | continuous: 13 | mu_activation: None 14 | sigma_activation: None 15 | mu_init: 16 | name: default 17 | scale: 0.02 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [400, 200, 100] 24 | activation: elu 25 | initializer: 26 | name: default 27 | regularizer: 28 | name: 'None' #'l2_regularizer' 29 | #scale: 0.001 30 | 31 | config: 32 | reward_shaper: 33 | scale_value: 0.1 34 | 35 | normalize_advantage: True 36 | gamma: 0.99 37 | tau: 0.9 38 | learning_rate: 3e-4 39 | name: Humanoid 40 | score_to_win: 100080 41 | grad_norm: 1.0 42 | entropy_coef: 0.0 43 | truncate_grads: True 44 | env_name: openai_gym 45 | e_clip: 0.2 46 | clip_value: True 47 | num_actors: 16 48 | horizon_length: 256 49 | minibatch_size: 2048 50 | mini_epochs: 12 51 | critic_coef: 2 52 | lr_schedule: adaptive 53 | kl_threshold: 0.008 54 | normalize_input: False 55 | seq_length: 4 56 | 
bounds_loss_coef: 0.0001 57 | max_epochs: 10000 58 | 59 | env_config: 60 | name: Humanoid-v3 61 | -------------------------------------------------------------------------------- /rl_games/configs/ppo_cartpole.yaml: -------------------------------------------------------------------------------- 1 | 2 | #Cartpole MLP 3 | 4 | params: 5 | algo: 6 | name: a2c_discrete 7 | 8 | model: 9 | name: discrete_a2c 10 | 11 | load_checkpoint: False 12 | load_path: path 13 | 14 | network: 15 | name: actor_critic 16 | separate: True 17 | space: 18 | discrete: 19 | mlp: 20 | units: [32, 32] 21 | activation: relu 22 | initializer: 23 | name: default 24 | regularizer: 25 | name: None 26 | 27 | config: 28 | reward_shaper: 29 | scale_value: 0.1 30 | normalize_advantage: True 31 | gamma: 0.99 32 | tau: 0.9 33 | learning_rate: 2e-4 34 | name: cartpole_vel_info 35 | score_to_win: 400 36 | grad_norm: 1.0 37 | entropy_coef: 0.01 38 | truncate_grads: True 39 | env_name: CartPole-v1 40 | e_clip: 0.2 41 | clip_value: True 42 | num_actors: 16 43 | horizon_length: 32 44 | minibatch_size: 64 45 | mini_epochs: 4 46 | critic_coef: 1 47 | lr_schedule: None 48 | kl_threshold: 0.008 49 | normalize_input: False 50 | save_best_after: 10 51 | device: 'cuda' 52 | multi_gpu: False 53 | 54 | -------------------------------------------------------------------------------- /rl_games/configs/ppo_cartpole_masked_velocity_rnn.yaml: -------------------------------------------------------------------------------- 1 | 2 | #Cartpole without velocities lstm test 3 | 4 | params: 5 | algo: 6 | name: a2c_discrete 7 | 8 | model: 9 | name: discrete_a2c 10 | 11 | load_checkpoint: False 12 | load_path: path 13 | 14 | network: 15 | name: actor_critic 16 | separate: True 17 | space: 18 | discrete: 19 | 20 | mlp: 21 | units: [64, 64] 22 | activation: relu 23 | normalization: 'layer_norm' 24 | norm_only_first_layer: True 25 | initializer: 26 | name: default 27 | regularizer: 28 | name: None 29 | rnn: 30 | name: 'lstm' 31 | units: 64 32 | layers: 2 33 | before_mlp: False 34 | concat_input: True 35 | layer_norm: True 36 | 37 | config: 38 | env_name: CartPoleMaskedVelocity-v1 39 | reward_shaper: 40 | scale_value: 0.1 41 | normalize_advantage: True 42 | gamma: 0.99 43 | tau: 0.9 44 | learning_rate: 1e-4 45 | name: cartpole_vel_info 46 | score_to_win: 500 47 | grad_norm: 0.5 48 | entropy_coef: 0.01 49 | truncate_grads: True 50 | e_clip: 0.2 51 | clip_value: True 52 | num_actors: 16 53 | horizon_length: 256 54 | minibatch_size: 2048 55 | mini_epochs: 4 56 | critic_coef: 1 57 | lr_schedule: None 58 | kl_threshold: 0.008 59 | normalize_input: False 60 | seq_length: 4 -------------------------------------------------------------------------------- /rl_games/configs/ppo_continuous.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_continuous 4 | 5 | model: 6 | name: continuous_a2c_logstd 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | continuous: 13 | mu_activation: None 14 | sigma_activation: None 15 | mu_init: 16 | name: default 17 | scale: 0.02 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [256, 128, 64] 24 | activation: elu 25 | initializer: 26 | name: default 27 | regularizer: 28 | name: None #'l2_regularizer' 29 | #scale: 0.001 30 | 31 | load_checkpoint: False 32 | load_path: path 33 | 34 | config: 35 | reward_shaper: 36 | scale_value: 0.1 37 | normalize_advantage: True 38 | gamma: 0.99 39 | tau: 0.9 
40 | 41 | learning_rate: 3e-4 42 | name: walker 43 | score_to_win: 300 44 | 45 | grad_norm: 0.5 46 | entropy_coef: 0.0 47 | truncate_grads: True 48 | env_name: openai_gym 49 | e_clip: 0.2 50 | clip_value: True 51 | num_actors: 16 52 | horizon_length: 256 53 | minibatch_size: 1024 54 | mini_epochs: 8 55 | critic_coef: 1 56 | lr_schedule: adaptive 57 | kl_threshold: 0.008 58 | 59 | normalize_input: False 60 | seq_length: 8 61 | bounds_loss_coef: 0.001 62 | env_config: 63 | name: BipedalWalkerHardcore-v3 64 | 65 | -------------------------------------------------------------------------------- /rl_games/configs/ppo_continuous_lstm.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_continuous 4 | 5 | model: 6 | name: continuous_a2c_lstm_logstd 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | continuous: 13 | mu_activation: None 14 | sigma_activation: None 15 | mu_init: 16 | name: normc_initializer 17 | std: 0.01 18 | sigma_init: 19 | name: const_initializer 20 | value: 0.0 21 | fixed_sigma: True 22 | 23 | mlp: 24 | units: [256, 256, 128] 25 | activation: relu 26 | initializer: 27 | name: normc_initializer 28 | std: 1 29 | regularizer: 30 | name: None 31 | lstm: 32 | units: 128 33 | concated: False 34 | 35 | config: 36 | env_name: BipedalWalkerHardcore-v2 37 | reward_shaper: 38 | scale_value: 0.1 39 | 40 | normalize_advantage: True 41 | gamma: 0.99 42 | tau: 0.9 43 | learning_rate: 1e-4 44 | name: walker_lstm 45 | score_to_win: 300 46 | grad_norm: 0.5 47 | entropy_coef: 0.000 48 | truncate_grads: True 49 | e_clip: 0.2 50 | clip_value: True 51 | num_actors: 16 52 | horizon_length: 512 53 | minibatch_size: 2048 54 | mini_epochs: 8 55 | critic_coef: 1 56 | lr_schedule: None 57 | kl_threshold: 0.008 58 | normalize_input: False 59 | seq_length: 8 60 | bounds_loss_coef: 0.5 61 | max_epochs: 5000 62 | -------------------------------------------------------------------------------- /rl_games/configs/ppo_lunar.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_continuous 4 | 5 | model: 6 | name: continuous_a2c_logstd 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | continuous: 13 | mu_activation: None 14 | sigma_activation: None 15 | mu_init: 16 | name: glorot_normal_initializer 17 | #scal: 0.01 18 | sigma_init: 19 | name: const_initializer 20 | value: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [64, 64] 24 | activation: relu 25 | initializer: 26 | name: glorot_normal_initializer 27 | #gain: 2 28 | regularizer: 29 | name: 'None' #'l2_regularizer' 30 | #scale: 0.001 31 | 32 | load_checkpoint: False 33 | load_path: path 34 | 35 | config: 36 | reward_shaper: 37 | scale_value: 0.1 38 | normalize_advantage: True 39 | gamma: 0.99 40 | tau: 0.9 41 | 42 | learning_rate: 1e-4 43 | name: test 44 | score_to_win: 300 45 | 46 | grad_norm: 0.5 47 | entropy_coef: 0.0 48 | truncate_grads: True 49 | env_name: LunarLanderContinuous-v2 50 | e_clip: 0.2 51 | clip_value: True 52 | num_actors: 16 53 | horizon_length: 128 54 | minibatch_size: 1024 55 | mini_epochs: 4 56 | critic_coef: 1 57 | lr_schedule: adaptive 58 | kl_threshold: 0.008 59 | normalize_input: False 60 | bounds_loss_coef: 0 61 | -------------------------------------------------------------------------------- /rl_games/configs/ppo_lunar_continiuos_torch.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 
| name: a2c_continuous 4 | 5 | model: 6 | name: continuous_a2c_logstd 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | continuous: 13 | mu_activation: None 14 | sigma_activation: None 15 | mu_init: 16 | name: default 17 | scale: 0.02 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [64] 24 | activation: relu 25 | initializer: 26 | name: default 27 | scale: 2 28 | rnn: 29 | name: 'lstm' 30 | units: 64 31 | layers: 1 32 | 33 | load_checkpoint: False 34 | load_path: path 35 | 36 | config: 37 | env_name: LunarLanderContinuous-v2 38 | reward_shaper: 39 | scale_value: 0.1 40 | normalize_advantage: True 41 | gamma: 0.99 42 | tau: 0.9 43 | 44 | learning_rate: 1e-3 45 | name: test 46 | score_to_win: 300 47 | 48 | grad_norm: 0.5 49 | entropy_coef: 0.0 50 | truncate_grads: True 51 | e_clip: 0.2 52 | clip_value: True 53 | num_actors: 16 54 | horizon_length: 128 55 | minibatch_size: 1024 56 | mini_epochs: 4 57 | critic_coef: 1 58 | lr_schedule: adaptive 59 | kl_threshold: 0.008 60 | normalize_input: True 61 | seq_length: 4 62 | bounds_loss_coef: 0 63 | 64 | player: 65 | render: True 66 | -------------------------------------------------------------------------------- /rl_games/configs/ppo_lunar_discrete.yaml: -------------------------------------------------------------------------------- 1 | 2 | #Cartpole MLP 3 | 4 | params: 5 | algo: 6 | name: a2c_discrete 7 | 8 | model: 9 | name: discrete_a2c 10 | 11 | network: 12 | name: actor_critic 13 | separate: True 14 | space: 15 | discrete: 16 | mlp: 17 | units: [64, 64] 18 | activation: relu 19 | initializer: 20 | name: default 21 | regularizer: 22 | name: None 23 | 24 | config: 25 | env_name: LunarLander-v2 26 | reward_shaper: 27 | scale_value: 0.1 28 | normalize_advantage: True 29 | gamma: 0.99 30 | tau: 0.9 31 | learning_rate: 8e-4 32 | name: LunarLander-discrete 33 | score_to_win: 500 34 | grad_norm: 1.0 35 | entropy_coef: 0.01 36 | truncate_grads: True 37 | e_clip: 0.2 38 | clip_value: True 39 | num_actors: 16 40 | horizon_length: 32 41 | minibatch_size: 64 42 | mini_epochs: 4 43 | critic_coef: 1 44 | lr_schedule: None 45 | kl_threshold: 0.008 46 | normalize_input: False 47 | device: cuda 48 | multi_gpu: False 49 | use_diagnostics: True -------------------------------------------------------------------------------- /rl_games/configs/ppo_myo.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 8 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [256,128,64] 24 | d2rl: False 25 | activation: elu 26 | initializer: 27 | name: default 28 | scale: 2 29 | config: 30 | env_name: myo_gym 31 | name: myo 32 | reward_shaper: 33 | min_val: -1 34 | scale_value: 0.1 35 | 36 | normalize_advantage: True 37 | gamma: 0.995 38 | tau: 0.95 39 | learning_rate: 3e-4 40 | lr_schedule: adaptive 41 | kl_threshold: 0.008 42 | save_best_after: 10 43 | score_to_win: 10000 44 | grad_norm: 1.5 45 | entropy_coef: 0 46 | truncate_grads: True 47 | e_clip: 0.2 48 | clip_value: False 49 | num_actors: 16 50 | horizon_length: 128 51 | minibatch_size: 1024 52 | mini_epochs: 4 53 | critic_coef: 2 54 | normalize_input: True 55 | bounds_loss_coef: 0.00 
56 | max_epochs: 10000 57 | normalize_value: True 58 | use_diagnostics: True 59 | value_bootstrap: True 60 | #weight_decay: 0.0001 61 | use_smooth_clamp: True 62 | env_config: 63 | name: 'myoElbowPose1D6MRandom-v0' 64 | player: 65 | 66 | render: True 67 | deterministic: True 68 | games_num: 200 69 | -------------------------------------------------------------------------------- /rl_games/configs/ppo_pendulum.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_continuous 4 | 5 | model: 6 | name: continuous_a2c_logstd 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | continuous: 13 | mu_activation: None 14 | sigma_activation: None 15 | mu_init: 16 | name: default 17 | scale: 0.01 18 | sigma_init: 19 | name: const_initializer 20 | value: 0 21 | fixed_sigma: False 22 | mlp: 23 | units: [32, 32] 24 | activation: elu 25 | initializer: 26 | name: default 27 | scale: 1 28 | regularizer: 29 | name: 'None' #'l2_regularizer' 30 | #scale: 0.001 31 | 32 | load_checkpoint: False 33 | load_path: path 34 | 35 | config: 36 | env_name: Pendulum-v0 37 | reward_shaper: 38 | scale_value: 0.01 39 | normalize_advantage: True 40 | gamma: 0.99 41 | tau: 0.9 42 | 43 | learning_rate: 1e-4 44 | lr_schedule: adaptive 45 | kl_threshold: 0.008 46 | name: test 47 | score_to_win: 300 48 | 49 | grad_norm: 0.5 50 | entropy_coef: 0.0 51 | truncate_grads: True 52 | e_clip: 0.2 53 | clip_value: True 54 | num_actors: 16 55 | horizon_length: 128 56 | minibatch_size: 1024 57 | mini_epochs: 4 58 | critic_coef: 1 59 | 60 | normalize_input: False 61 | seq_length: 8 62 | bounds_loss_coef: 0 63 | -------------------------------------------------------------------------------- /rl_games/configs/ppo_pendulum_torch.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_continuous 4 | 5 | model: 6 | name: continuous_a2c_logstd 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | continuous: 13 | mu_activation: None 14 | sigma_activation: None 15 | mu_init: 16 | name: glorot_normal_initializer 17 | gain: 0.01 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [32, 32] 24 | activation: elu 25 | initializer: 26 | name: glorot_normal_initializer 27 | gain: 2 28 | regularizer: 29 | name: None #'l2_regularizer' 30 | #scale: 0.001 31 | 32 | config: 33 | env_name: openai_gym 34 | reward_shaper: 35 | scale_value: 0.01 36 | normalize_advantage: True 37 | gamma: 0.99 38 | tau: 0.9 39 | 40 | learning_rate: 1e-3 41 | name: pendulum 42 | score_to_win: 300 43 | 44 | grad_norm: 0.5 45 | entropy_coef: 0.0 46 | truncate_grads: True 47 | e_clip: 0.2 48 | clip_value: False 49 | num_actors: 16 50 | horizon_length: 128 51 | minibatch_size: 1024 52 | mini_epochs: 4 53 | critic_coef: 1 54 | lr_schedule: adaptive 55 | kl_threshold: 0.016 56 | 57 | normalize_input: False 58 | bounds_loss_coef: 0 59 | 60 | env_config: 61 | name: Pendulum-v1 -------------------------------------------------------------------------------- /rl_games/configs/ppo_reacher.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_continuous 4 | 5 | model: 6 | name: continuous_a2c_logstd 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | continuous: 13 | mu_activation: None 14 | sigma_activation: None 15 | mu_init: 16 | name: default 17 | scale: 0.02 18 | 
sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [256,128] 24 | activation: relu 25 | initializer: 26 | name: default 27 | regularizer: 28 | name: 'None' #'l2_regularizer' 29 | #scale: 0.001 30 | rnn1: 31 | name: lstm 32 | units: 64 33 | layers: 1 34 | load_checkpoint: False 35 | load_path: './nn/last_walkerep=10001rew=108.35405.pth' 36 | 37 | config: 38 | env_name: ReacherPyBulletEnv-v0 39 | name: walker 40 | reward_shaper: 41 | min_val: -1 42 | scale_value: 0.1 43 | 44 | normalize_advantage: True 45 | gamma: 0.995 46 | tau: 0.95 47 | learning_rate: 3e-4 48 | score_to_win: 300 49 | grad_norm: 0.5 50 | entropy_coef: 0 51 | truncate_grads: True 52 | 53 | e_clip: 0.2 54 | clip_value: False 55 | num_actors: 16 56 | horizon_length: 256 57 | minibatch_size: 1024 58 | mini_epochs: 4 59 | critic_coef: 1 60 | lr_schedule: none 61 | kl_threshold: 0.008 62 | normalize_input: True 63 | seq_length: 16 64 | bounds_loss_coef: 0.00 65 | max_epochs: 10000 66 | weight_decay: 0.0001 67 | 68 | player: 69 | render: True 70 | games_num: 200 71 | 72 | experiment_config1: 73 | start_exp: 0 74 | start_sub_exp: 0 75 | experiments: 76 | - exp: 77 | - path: config.bounds_loss_coef 78 | value: [0.5] -------------------------------------------------------------------------------- /rl_games/configs/ppo_smac.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | discrete: 13 | 14 | mlp: 15 | units: [256, 128] 16 | activation: relu 17 | initializer: 18 | name: default 19 | regularizer: 20 | name: None 21 | 22 | config: 23 | name: 6h_vs_8z 24 | env_name: smac 25 | reward_shaper: 26 | scale_value: 1 27 | normalize_advantage: True 28 | gamma: 0.99 29 | tau: 0.95 30 | learning_rate: 1e-4 31 | lr_schedule: None 32 | kl_threshold: 0.05 33 | score_to_win: 1000 34 | grad_norm: 0.5 35 | entropy_coef: 0.001 36 | truncate_grads: True 37 | 38 | e_clip: 0.2 39 | clip_value: True 40 | num_actors: 8 41 | horizon_length: 128 42 | minibatch_size: 3072 43 | mini_epochs: 4 44 | critic_coef: 1 45 | normalize_input: False 46 | seq_length: 4 47 | use_action_masks: True 48 | 49 | env_config: 50 | name: 6h_vs_8z 51 | frames: 2 52 | random_invalid_step: False -------------------------------------------------------------------------------- /rl_games/configs/ppo_walker.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 8 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [256,128,64] 24 | d2rl: False 25 | activation: elu 26 | initializer: 27 | name: default 28 | scale: 2 29 | config: 30 | env_name: BipedalWalker-v3 31 | name: walker 32 | reward_shaper: 33 | min_val: -1 34 | scale_value: 0.1 35 | 36 | normalize_advantage: True 37 | gamma: 0.995 38 | tau: 0.95 39 | learning_rate: 3e-4 40 | lr_schedule: adaptive 41 | kl_threshold: 0.008 42 | save_best_after: 10 43 | score_to_win: 300 44 | grad_norm: 1.5 45 | entropy_coef: 0 46 | truncate_grads: True 47 | e_clip: 0.2 48 | clip_value: False 49 | num_actors: 16 50 | horizon_length: 4096 51 | 
minibatch_size: 8192 52 | mini_epochs: 4 53 | critic_coef: 2 54 | normalize_input: True 55 | bounds_loss_coef: 0.00 56 | max_epochs: 10000 57 | normalize_value: True 58 | use_diagnostics: True 59 | value_bootstrap: True 60 | #weight_decay: 0.0001 61 | use_smooth_clamp: True 62 | 63 | player: 64 | render: True 65 | deterministic: True 66 | games_num: 200 67 | -------------------------------------------------------------------------------- /rl_games/configs/ppo_walker_hardcore.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_continuous 4 | 5 | model: 6 | name: continuous_a2c_logstd 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | continuous: 13 | mu_activation: None 14 | sigma_activation: None 15 | mu_init: 16 | name: default 17 | sigma_init: 18 | name: const_initializer 19 | val: 0 20 | fixed_sigma: True 21 | mlp: 22 | units: [256,128, 64] 23 | d2rl: False 24 | activation: elu 25 | initializer: 26 | name: default 27 | load_checkpoint: False 28 | load_path: './nn/walker_hc.pth' 29 | 30 | config: 31 | env_name: BipedalWalkerHardcore-v3 32 | name: walker_hc 33 | reward_shaper: 34 | min_val: -1 35 | scale_value: 0.1 36 | 37 | normalize_advantage: True 38 | gamma: 0.995 39 | tau: 0.95 40 | learning_rate: 5e-4 41 | lr_schedule: adaptive 42 | kl_threshold: 0.008 43 | score_to_win: 300 44 | grad_norm: 1.5 45 | save_best_after: 10 46 | entropy_coef: 0 47 | truncate_grads: True 48 | e_clip: 0.2 49 | clip_value: False 50 | num_actors: 16 51 | horizon_length: 4096 52 | minibatch_size: 8192 53 | mini_epochs: 4 54 | critic_coef: 1 55 | normalize_input: True 56 | seq_length: 4 57 | bounds_loss_coef: 0.0 58 | max_epochs: 100000 59 | weight_decay: 0 60 | 61 | player: 62 | render: False 63 | games_num: 200 64 | deterministic: True 65 | 66 | -------------------------------------------------------------------------------- /rl_games/configs/ppo_walker_rnn.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 8 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [256,128,64] 24 | d2rl: False 25 | activation: elu 26 | initializer: 27 | name: default 28 | rnn: 29 | name: 'gru' 30 | units: 64 31 | layers: 1 32 | before_mlp: False 33 | #concat_input: True 34 | #layer_norm: True 35 | 36 | config: 37 | env_name: BipedalWalker-v3 38 | name: walker_rnn 39 | reward_shaper: 40 | min_val: -1 41 | scale_value: 0.1 42 | 43 | normalize_advantage: True 44 | gamma: 0.995 45 | tau: 0.95 46 | learning_rate: 3e-4 47 | lr_schedule: adaptive 48 | kl_threshold: 0.004 49 | save_best_after: 10 50 | score_to_win: 300 51 | grad_norm: 1.5 52 | entropy_coef: -0.003 53 | truncate_grads: True 54 | e_clip: 0.2 55 | clip_value: False 56 | num_actors: 16 57 | horizon_length: 256 58 | minibatch_size: 2048 59 | mini_epochs: 4 60 | critic_coef: 2 61 | normalize_input: True 62 | bound_loss_type: regularisation 63 | bounds_loss_coef: 0.001 64 | max_epochs: 10000 65 | seq_length: 32 66 | normalize_value: True 67 | use_diagnostics: True 68 | value_bootstrap: True 69 | weight_decay: 0.0001 70 | use_smooth_clamp: True 71 | 72 | player: 73 | render: True 74 | deterministic: True 75 | games_num: 
200 76 | -------------------------------------------------------------------------------- /rl_games/configs/ppo_walker_tcnn.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 8 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: tcnnnet 11 | 12 | encoding: 13 | otype: "Identity" 14 | 15 | network: 16 | type: "FullyFusedMLP" 17 | activation: "ReLU" 18 | output_activation: "None" 19 | n_neurons: 64 20 | n_hidden_layers: 3 21 | 22 | config: 23 | env_name: BipedalWalker-v3 24 | name: walker_tcnn 25 | reward_shaper: 26 | min_val: -1 27 | scale_value: 0.1 28 | 29 | normalize_advantage: True 30 | gamma: 0.995 31 | tau: 0.95 32 | learning_rate: 3e-4 33 | lr_schedule: adaptive 34 | kl_threshold: 0.008 35 | save_best_after: 10 36 | score_to_win: 300 37 | grad_norm: 1.5 38 | entropy_coef: 0 39 | truncate_grads: True 40 | e_clip: 0.2 41 | clip_value: False 42 | num_actors: 16 43 | horizon_length: 4096 44 | minibatch_size: 8192 45 | mini_epochs: 4 46 | critic_coef: 2 47 | normalize_input: True 48 | bounds_loss_coef: 0.00 49 | max_epochs: 10000 50 | normalize_value: True 51 | use_diagnostics: True 52 | value_bootstrap: True 53 | #weight_decay: 0.0001 54 | 55 | player: 56 | render: True 57 | deterministic: True 58 | games_num: 200 59 | -------------------------------------------------------------------------------- /rl_games/configs/procgen/ppo_coinrun.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: resnet_actor_critic 10 | separate: False 11 | value_shape: 1 12 | space: 13 | discrete: 14 | 15 | cnn: 16 | conv_depths: [16, 32, 32] 17 | activation: elu 18 | initializer: 19 | name: default 20 | 21 | mlp: 22 | units: [512] 23 | activation: elu 24 | initializer: 25 | name: default 26 | rnn1: 27 | name: lstm 28 | units: 256 29 | layers: 1 30 | config: 31 | reward_shaper: 32 | max_val: 10 33 | 34 | normalize_advantage: True 35 | gamma: 0.999 36 | tau: 0.95 37 | learning_rate: 1e-4 38 | name: atari 39 | score_to_win: 900 40 | grad_norm: 0.5 41 | entropy_coef: 0.001 42 | truncate_grads: True 43 | env_name: openai_gym #'PongNoFrameskip-v4' 44 | e_clip: 0.2 45 | clip_value: True 46 | num_actors: 16 47 | horizon_length: 256 48 | minibatch_size: 1024 49 | mini_epochs: 3 50 | critic_coef: 1 51 | lr_schedule: polynom_decay 52 | kl_threshold: 0.01 53 | normalize_input: False 54 | seq_length: 4 55 | max_epochs: 2000 56 | env_config: 57 | name: "procgen:procgen-coinrun-v0" 58 | procgen: True 59 | frames: 4 60 | num_levels: 1000 61 | start_level: 323 62 | limit_steps: True 63 | distribution_mode: 'easy' 64 | -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/10m_vs_11m_torch.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | 15 | cnn: 16 | type: conv1d 17 | activation: relu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: 'None' 22 | convs: 23 | - filters: 256 24 | kernel_size: 3 25 | strides: 1 26 | padding: 1 27 | - filters: 512 28 | kernel_size: 3 29 | strides: 1 30 | padding: 1 31 | - filters: 1024 32 | kernel_size: 3 33 | strides: 1 34 | padding: 
1 35 | mlp: 36 | units: [256, 128] 37 | activation: relu 38 | initializer: 39 | name: default 40 | regularizer: 41 | name: None 42 | config: 43 | name: 10m 44 | reward_shaper: 45 | scale_value: 1 46 | normalize_advantage: True 47 | gamma: 0.99 48 | tau: 0.95 49 | learning_rate: 1e-4 50 | score_to_win: 20 51 | grad_norm: 0.5 52 | entropy_coef: 0.005 53 | truncate_grads: True 54 | env_name: smac_cnn 55 | e_clip: 0.2 56 | clip_value: True 57 | num_actors: 8 58 | horizon_length: 128 59 | minibatch_size: 2560 60 | mini_epochs: 4 61 | critic_coef: 2 62 | lr_schedule: None 63 | kl_threshold: 0.05 64 | normalize_input: True 65 | seq_length: 2 66 | use_action_masks: True 67 | env_config: 68 | name: 10m_vs_11m 69 | frames: 14 70 | transpose: False 71 | random_invalid_step: False -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/27m_vs_30m_torch.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | 15 | cnn: 16 | type: conv1d 17 | activation: relu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: None 22 | convs: 23 | - filters: 256 24 | kernel_size: 3 25 | strides: 1 26 | padding: 1 27 | - filters: 512 28 | kernel_size: 3 29 | strides: 1 30 | padding: 1 31 | - filters: 1024 32 | kernel_size: 3 33 | strides: 1 34 | padding: 1 35 | mlp: 36 | units: [256, 128] 37 | activation: relu 38 | initializer: 39 | name: default 40 | regularizer: 41 | name: 'None' 42 | config: 43 | name: 27m 44 | reward_shaper: 45 | scale_value: 1 46 | normalize_advantage: True 47 | gamma: 0.99 48 | tau: 0.95 49 | learning_rate: 1e-4 50 | score_to_win: 20 51 | grad_norm: 0.5 52 | entropy_coef: 0.005 53 | truncate_grads: True 54 | env_name: smac_cnn 55 | e_clip: 0.2 56 | clip_value: True 57 | num_actors: 8 58 | horizon_length: 128 59 | minibatch_size: 3456 60 | mini_epochs: 4 61 | critic_coef: 2 62 | lr_schedule: None 63 | kl_threshold: 0.05 64 | normalize_input: True 65 | seq_length: 2 66 | use_action_masks: True 67 | 68 | env_config: 69 | name: 27m_vs_30m 70 | frames: 4 71 | transpose: False 72 | random_invalid_step: False -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/2m_vs_1z.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | discrete: 13 | mlp: 14 | units: [256, 128] 15 | activation: relu 16 | initializer: 17 | name: default 18 | regularizer: 19 | name: 'None' 20 | config: 21 | name: 2s_vs_1z 22 | reward_shaper: 23 | scale_value: 1 24 | normalize_advantage: True 25 | gamma: 0.99 26 | tau: 0.95 27 | learning_rate: 5e-4 28 | score_to_win: 1000 29 | grad_norm: 0.5 30 | entropy_coef: 0.005 31 | truncate_grads: True 32 | env_name: smac 33 | e_clip: 0.2 34 | clip_value: True 35 | num_actors: 8 36 | horizon_length: 128 37 | minibatch_size: 1024 38 | mini_epochs: 4 39 | critic_coef: 1 40 | lr_schedule: None 41 | kl_threshold: 0.05 42 | normalize_input: True 43 | seq_length: 4 44 | use_action_masks: True 45 | 46 | env_config: 47 | name: 2m_vs_1z 48 | frames: 1 49 | random_invalid_step: False -------------------------------------------------------------------------------- 
/rl_games/configs/smac/v1/2m_vs_1z_torch.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | discrete: 13 | 14 | mlp: 15 | units: [256, 128] 16 | activation: relu 17 | initializer: 18 | name: default 19 | regularizer: 20 | name: 'None' 21 | config: 22 | name: 2m_vs_1z 23 | reward_shaper: 24 | scale_value: 1 25 | normalize_advantage: True 26 | gamma: 0.99 27 | tau: 0.95 28 | learning_rate: 5e-4 29 | score_to_win: 1000 30 | grad_norm: 0.5 31 | entropy_coef: 0.005 32 | truncate_grads: True 33 | env_name: smac 34 | e_clip: 0.2 35 | clip_value: True 36 | num_actors: 8 37 | horizon_length: 128 38 | minibatch_size: 1024 39 | mini_epochs: 4 40 | critic_coef: 1 41 | lr_schedule: None 42 | kl_threshold: 0.05 43 | normalize_input: True 44 | seq_length: 4 45 | use_action_masks: True 46 | 47 | env_config: 48 | name: 2m_vs_1z 49 | frames: 1 50 | random_invalid_step: False -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/2s_vs_1c.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c_lstm 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | discrete: 13 | 14 | mlp: 15 | units: [256, 128] 16 | activation: relu 17 | initializer: 18 | name: default 19 | regularizer: 20 | name: 'None' 21 | lstm: 22 | units: 128 23 | concated: False 24 | config: 25 | name: 2m_vs_1z 26 | reward_shaper: 27 | scale_value: 1 28 | normalize_advantage: True 29 | gamma: 0.99 30 | tau: 0.95 31 | learning_rate: 1e-4 32 | score_to_win: 1000 33 | grad_norm: 0.5 34 | entropy_coef: 0.005 35 | truncate_grads: True 36 | env_name: smac 37 | e_clip: 0.2 38 | clip_value: True 39 | num_actors: 8 40 | horizon_length: 128 41 | minibatch_size: 1024 42 | mini_epochs: 4 43 | critic_coef: 1 44 | lr_schedule: None 45 | kl_threshold: 0.05 46 | normalize_input: False 47 | seq_length: 4 48 | use_action_masks: True 49 | 50 | env_config: 51 | name: 2m_vs_1z 52 | frames: 1 53 | random_invalid_step: False -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/3m_cnn_torch.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | 15 | cnn: 16 | type: conv1d 17 | 18 | activation: relu 19 | initializer: 20 | name: glorot_uniform_initializer 21 | gain: 1 22 | regularizer: 23 | name: 'None' 24 | convs: 25 | - filters: 64 26 | kernel_size: 3 27 | strides: 2 28 | padding: 1 29 | - filters: 128 30 | kernel_size: 3 31 | strides: 1 32 | padding: 0 33 | - filters: 256 34 | kernel_size: 3 35 | strides: 1 36 | padding: 0 37 | mlp: 38 | units: [256, 128] 39 | activation: relu 40 | initializer: 41 | name: glorot_uniform_initializer 42 | gain: 1 43 | regularizer: 44 | name: 'None' 45 | config: 46 | name: 3m 47 | reward_shaper: 48 | scale_value: 1 49 | normalize_advantage: True 50 | gamma: 0.99 51 | tau: 0.95 52 | learning_rate: 5e-4 53 | score_to_win: 20 54 | grad_norm: 0.5 55 | entropy_coef: 0.005 56 | truncate_grads: True 57 | env_name: smac_cnn 58 | e_clip: 0.2 59 | clip_value: True 60 | num_actors: 8 61 
| horizon_length: 128 62 | minibatch_size: 1536 63 | mini_epochs: 1 64 | critic_coef: 1 65 | lr_schedule: None 66 | kl_threshold: 0.05 67 | normalize_input: True 68 | seq_length: 2 69 | use_action_masks: True 70 | 71 | env_config: 72 | name: 3m 73 | frames: 4 74 | transpose: True 75 | random_invalid_step: True 76 | 77 | -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/3m_torch.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | 15 | mlp: 16 | units: [256, 128] 17 | activation: relu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: None 22 | config: 23 | name: 3m 24 | reward_shaper: 25 | scale_value: 1 26 | normalize_advantage: True 27 | gamma: 0.99 28 | tau: 0.95 29 | learning_rate: 5e-4 30 | score_to_win: 20 31 | grad_norm: 0.5 32 | entropy_coef: 0.001 33 | truncate_grads: True 34 | env_name: smac 35 | e_clip: 0.2 36 | clip_value: True 37 | num_actors: 8 38 | horizon_length: 128 39 | minibatch_size: 1536 40 | mini_epochs: 4 41 | critic_coef: 1 42 | lr_schedule: None 43 | kl_threshold: 0.05 44 | normalize_input: True 45 | #normalize_value: True 46 | use_action_masks: True 47 | ignore_dead_batches : False 48 | 49 | env_config: 50 | name: 3m 51 | frames: 1 52 | transpose: False 53 | random_invalid_step: False 54 | obs_last_action: True -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/3m_torch_cv.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: False 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | 15 | mlp: 16 | units: [256, 128] 17 | activation: relu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: 'None' 22 | 23 | config: 24 | name: 3m_cv 25 | reward_shaper: 26 | scale_value: 1 27 | normalize_advantage: True 28 | gamma: 0.99 29 | tau: 0.95 30 | learning_rate: 5e-4 31 | score_to_win: 20 32 | grad_norm: 0.5 33 | entropy_coef: 0.001 34 | truncate_grads: True 35 | env_name: smac 36 | e_clip: 0.2 37 | clip_value: True 38 | num_actors: 8 39 | horizon_length: 128 40 | minibatch_size: 1536 # 3 * 512 41 | mini_epochs: 4 42 | critic_coef: 1 43 | lr_schedule: None 44 | kl_threshold: 0.05 45 | normalize_input: True 46 | normalize_value: False 47 | use_action_masks: True 48 | ignore_dead_batches : False 49 | 50 | env_config: 51 | name: 3m 52 | frames: 1 53 | transpose: False 54 | random_invalid_step: False 55 | central_value: True 56 | reward_only_positive: True 57 | central_value_config: 58 | minibatch_size: 512 59 | mini_epochs: 4 60 | learning_rate: 5e-4 61 | clip_value: False 62 | normalize_input: True 63 | network: 64 | name: actor_critic 65 | central_value: True 66 | mlp: 67 | units: [256, 128] 68 | activation: relu 69 | initializer: 70 | name: default 71 | scale: 2 72 | regularizer: 73 | name: 'None' -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/3m_torch_rnn.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 
10 | separate: True 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | 15 | mlp: 16 | units: [256, 128] 17 | activation: relu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: 'None' 22 | rnn: 23 | name: lstm 24 | units: 128 25 | layers: 1 26 | config: 27 | name: 3m 28 | reward_shaper: 29 | scale_value: 1 30 | normalize_advantage: True 31 | gamma: 0.99 32 | tau: 0.95 33 | learning_rate: 5e-4 34 | score_to_win: 20 35 | grad_norm: 0.5 36 | entropy_coef: 0.001 37 | truncate_grads: True 38 | env_name: smac 39 | e_clip: 0.2 40 | clip_value: True 41 | num_actors: 8 42 | horizon_length: 128 43 | minibatch_size: 1536 44 | mini_epochs: 4 45 | critic_coef: 1 46 | lr_schedule: None 47 | kl_threshold: 0.05 48 | normalize_input: True 49 | seq_length: 4 50 | use_action_masks: True 51 | ignore_dead_batches : False 52 | 53 | env_config: 54 | name: 3m 55 | frames: 1 56 | transpose: False 57 | random_invalid_step: False -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/3m_torch_sa.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: multi_discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | #normalization: layer_norm 12 | space: 13 | multi_discrete: 14 | 15 | mlp: 16 | units: [256, 128] 17 | activation: relu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: None 22 | config: 23 | name: 3m_sa 24 | reward_shaper: 25 | scale_value: 1 26 | normalize_advantage: True 27 | gamma: 0.99 28 | tau: 0.95 29 | learning_rate: 5e-4 30 | score_to_win: 20 31 | grad_norm: 0.5 32 | entropy_coef: 0.001 33 | truncate_grads: True 34 | env_name: smac 35 | e_clip: 0.2 36 | clip_value: True 37 | num_actors: 8 38 | horizon_length: 128 39 | minibatch_size: 512 40 | mini_epochs: 4 41 | critic_coef: 1 42 | lr_schedule: None 43 | kl_threshold: 0.05 44 | normalize_input: True 45 | use_action_masks: True 46 | ignore_dead_batches : False 47 | 48 | env_config: 49 | name: 3m 50 | frames: 1 51 | transpose: False 52 | random_invalid_step: False 53 | as_single_agent: True 54 | central_value: True -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/3s5z_vs_3s6z_torch_cv.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: False 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | mlp: 15 | units: [1024, 512] 16 | activation: relu 17 | initializer: 18 | name: default 19 | regularizer: 20 | name: 'None' 21 | 22 | config: 23 | name: 3s5z_vs_3s6z_cv 24 | reward_shaper: 25 | scale_value: 1 26 | normalize_advantage: True 27 | gamma: 0.995 28 | tau: 0.95 29 | learning_rate: 5e-4 30 | score_to_win: 20 31 | grad_norm: 0.5 32 | entropy_coef: 0.001 33 | truncate_grads: True 34 | env_name: smac 35 | e_clip: 0.2 36 | clip_value: True 37 | num_actors: 8 38 | horizon_length: 128 39 | minibatch_size: 4096 # 8 * 512 40 | mini_epochs: 4 41 | critic_coef: 1 42 | lr_schedule: None 43 | kl_threshold: 0.05 44 | normalize_input: True 45 | use_action_masks: True 46 | ignore_dead_batches : False 47 | 48 | env_config: 49 | name: 3s5z_vs_3s6z 50 | central_value: True 51 | reward_only_positive: False 52 | obs_last_action: True 53 | frames: 1 54 | #reward_negative_scale: 0.9 55 | #apply_agent_ids: True 56 | 
#flatten: False 57 | 58 | central_value_config: 59 | minibatch_size: 512 60 | mini_epochs: 4 61 | learning_rate: 5e-4 62 | clip_value: True 63 | normalize_input: True 64 | network: 65 | name: actor_critic 66 | central_value: True 67 | mlp: 68 | units: [1024, 512] 69 | activation: relu 70 | initializer: 71 | name: default 72 | scale: 2 73 | regularizer: 74 | name: 'None' -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/3s_vs_4z.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c_lstm 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | discrete: 13 | 14 | mlp: 15 | units: [256, 128] 16 | activation: relu 17 | initializer: 18 | name: default 19 | regularizer: 20 | name: 'None' 21 | lstm: 22 | units: 128 23 | concated: False 24 | config: 25 | name: sc2_fc 26 | reward_shaper: 27 | scale_value: 1 28 | normalize_advantage: True 29 | gamma: 0.99 30 | tau: 0.95 31 | learning_rate: 1e-4 32 | score_to_win: 1000 33 | grad_norm: 0.5 34 | entropy_coef: 0.005 35 | truncate_grads: True 36 | env_name: smac 37 | e_clip: 0.2 38 | clip_value: True 39 | num_actors: 8 40 | horizon_length: 64 41 | minibatch_size: 1536 42 | mini_epochs: 8 43 | critic_coef: 1 44 | lr_schedule: None 45 | kl_threshold: 0.05 46 | normalize_input: False 47 | seq_length: 4 48 | use_action_masks: True 49 | 50 | env_config: 51 | name: 3s_vs_4z 52 | frames: 1 53 | random_invalid_step: False -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/3s_vs_5z.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c_lstm 7 | 8 | 9 | network: 10 | name: actor_critic 11 | separate: True 12 | space: 13 | discrete: 14 | 15 | mlp: 16 | units: [256, 128] 17 | activation: relu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: 'None' 22 | lstm: 23 | units: 128 24 | concated: False 25 | config: 26 | name: 3s_vs_5z 27 | reward_shaper: 28 | scale_value: 1 29 | normalize_advantage: True 30 | gamma: 0.99 31 | tau: 0.95 32 | learning_rate: 1e-4 33 | score_to_win: 1000 34 | grad_norm: 0.5 35 | entropy_coef: 0.001 36 | truncate_grads: True 37 | env_name: smac 38 | e_clip: 0.2 39 | clip_value: True 40 | num_actors: 8 41 | horizon_length: 128 42 | minibatch_size: 1536 #1024 43 | mini_epochs: 4 44 | critic_coef: 1 45 | lr_schedule: None 46 | kl_threshold: 0.05 47 | normalize_input: False 48 | seq_length: 4 49 | use_action_masks: True 50 | 51 | env_config: 52 | name: 3s_vs_5z 53 | frames: 1 54 | random_invalid_step: False -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/3s_vs_5z_cv.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: False 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | 15 | mlp: 16 | units: [256, 128] 17 | activation: relu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: 'None' 22 | 23 | config: 24 | name: 3s_vs_5z_cv 25 | reward_shaper: 26 | scale_value: 1 27 | normalize_advantage: True 28 | gamma: 0.99 29 | tau: 0.95 30 | learning_rate: 5e-4 31 | score_to_win: 24 32 | grad_norm: 0.5 33 | entropy_coef: 0.01 34 | 
truncate_grads: True 35 | env_name: smac 36 | e_clip: 0.2 37 | clip_value: True 38 | num_actors: 8 39 | horizon_length: 128 40 | minibatch_size: 1536 # 3 * 512 41 | mini_epochs: 4 42 | critic_coef: 1 43 | lr_schedule: None 44 | kl_threshold: 0.05 45 | normalize_input: True 46 | use_action_masks: True 47 | max_epochs: 50000 48 | 49 | central_value_config: 50 | minibatch_size: 512 51 | mini_epochs: 4 52 | learning_rate: 5e-4 53 | clip_value: False 54 | normalize_input: True 55 | network: 56 | name: actor_critic 57 | central_value: True 58 | mlp: 59 | units: [512, 256,128] 60 | activation: relu 61 | initializer: 62 | name: default 63 | scale: 2 64 | regularizer: 65 | name: None 66 | 67 | env_config: 68 | name: 3s_vs_5z 69 | frames: 1 70 | transpose: False 71 | random_invalid_step: False 72 | central_value: True 73 | reward_only_positive: True 74 | obs_last_action: True 75 | -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/3s_vs_5z_torch_lstm.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | normalization: layer_norm 12 | space: 13 | discrete: 14 | 15 | mlp: 16 | units: [256, 128] 17 | activation: relu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: None 22 | rnn: 23 | name: lstm 24 | units: 64 25 | layers: 1 26 | before_mlp: False 27 | 28 | config: 29 | name: 3s_vs_5z 30 | reward_shaper: 31 | scale_value: 1 32 | normalize_advantage: True 33 | gamma: 0.99 34 | tau: 0.95 35 | learning_rate: 1e-4 36 | score_to_win: 1000 37 | grad_norm: 0.5 38 | entropy_coef: 0.01 39 | truncate_grads: True 40 | env_name: smac 41 | e_clip: 0.2 42 | clip_value: True 43 | num_actors: 8 44 | horizon_length: 256 45 | minibatch_size: 1536 #1024 46 | mini_epochs: 4 47 | critic_coef: 1 48 | lr_schedule: None 49 | kl_threshold: 0.05 50 | normalize_input: True 51 | seq_length: 32 52 | use_action_masks: True 53 | max_epochs: 20000 54 | env_config: 55 | name: 3s_vs_5z 56 | frames: 1 57 | random_invalid_step: False -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/3s_vs_5z_torch_lstm2.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | discrete: 13 | 14 | mlp: 15 | units: [256, 128] 16 | activation: relu 17 | initializer: 18 | name: default 19 | regularizer: 20 | name: 'None' 21 | rnn: 22 | name: lstm 23 | units: 128 24 | layers: 1 25 | before_mlp: False 26 | config: 27 | name: 3s_vs_5z2 28 | reward_shaper: 29 | scale_value: 1 30 | normalize_advantage: True 31 | gamma: 0.99 32 | tau: 0.95 33 | learning_rate: 1e-4 34 | score_to_win: 1000 35 | grad_norm: 0.5 36 | entropy_coef: 0.005 37 | truncate_grads: True 38 | env_name: smac 39 | e_clip: 0.2 40 | clip_value: True 41 | num_actors: 8 42 | horizon_length: 128 43 | minibatch_size: 1536 #1024 44 | mini_epochs: 4 45 | critic_coef: 1 46 | lr_schedule: None 47 | kl_threshold: 0.05 48 | normalize_input: False 49 | seq_length: 4 50 | use_action_masks: True 51 | max_epochs: 20000 52 | env_config: 53 | name: 3s_vs_5z 54 | frames: 1 55 | random_invalid_step: False -------------------------------------------------------------------------------- 
/rl_games/configs/smac/v1/5m_vs_6m_rnn.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | 15 | mlp: 16 | units: [512, 256] 17 | activation: relu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: 'None' 22 | rnn: 23 | name: lstm 24 | units: 128 25 | layers: 1 26 | layer_norm: True 27 | config: 28 | name: 5m_vs_6m_rnn 29 | reward_shaper: 30 | scale_value: 1 31 | normalize_advantage: True 32 | gamma: 0.99 33 | tau: 0.95 34 | learning_rate: 1e-4 35 | score_to_win: 20 36 | entropy_coef: 0.005 37 | truncate_grads: True 38 | grad_norm: 1.5 39 | env_name: smac 40 | e_clip: 0.2 41 | clip_value: True 42 | num_actors: 8 43 | horizon_length: 128 44 | minibatch_size: 2560 # 5 * 512 45 | mini_epochs: 4 46 | critic_coef: 1 47 | lr_schedule: None 48 | kl_threshold: 0.05 49 | normalize_input: True 50 | normalize_value: False 51 | use_action_masks: True 52 | seq_length: 8 53 | #max_epochs: 10000 54 | env_config: 55 | name: 5m_vs_6m 56 | central_value: False 57 | reward_only_positive: True 58 | obs_last_action: True 59 | apply_agent_ids: False 60 | 61 | player: 62 | render: False 63 | games_num: 200 64 | n_game_life: 1 65 | deterministic: True 66 | 67 | #reward_negative_scale: 0.1 -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/5m_vs_6m_sa.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: multi_discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | multi_discrete: 13 | 14 | mlp: 15 | units: [512, 256, 128] 16 | activation: relu 17 | initializer: 18 | name: default 19 | 20 | config: 21 | name: 5m_vs_6m_sa 22 | reward_shaper: 23 | scale_value: 1 24 | normalize_advantage: True 25 | gamma: 0.99 26 | tau: 0.95 27 | learning_rate: 3e-4 28 | score_to_win: 20 29 | entropy_coef: 0.02 30 | truncate_grads: True 31 | grad_norm: 1 32 | env_name: smac 33 | e_clip: 0.2 34 | clip_value: False 35 | num_actors: 8 36 | horizon_length: 256 37 | minibatch_size: 1024 38 | mini_epochs: 4 39 | critic_coef: 2 40 | lr_schedule: None 41 | kl_threshold: 0.05 42 | normalize_input: True 43 | normalize_value: False 44 | use_action_masks: True 45 | use_diagnostics: True 46 | seq_length: 8 47 | max_epochs: 10000 48 | env_config: 49 | name: 5m_vs_6m 50 | central_value: True 51 | reward_only_positive: True 52 | obs_last_action: False 53 | apply_agent_ids: False 54 | as_single_agent: True 55 | 56 | player: 57 | render: False 58 | games_num: 200 59 | n_game_life: 1 60 | determenistic: True 61 | -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/5m_vs_6m_torch.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | load_checkpoint: False 9 | load_path: 'nn/5msmac_cnn.pth' 10 | 11 | network: 12 | name: actor_critic 13 | separate: True 14 | #normalization: layer_norm 15 | space: 16 | discrete: 17 | 18 | cnn: 19 | type: conv1d 20 | activation: relu 21 | initializer: 22 | name: default 23 | regularizer: 24 | name: 'None' 25 | convs: 26 | - filters: 256 27 | kernel_size: 3 28 | strides: 1 29 | padding: 1 30 | - filters: 512 
31 | kernel_size: 3 32 | strides: 1 33 | padding: 1 34 | - filters: 1024 35 | kernel_size: 3 36 | strides: 1 37 | padding: 1 38 | mlp: 39 | units: [256, 128] 40 | activation: relu 41 | initializer: 42 | name: default 43 | regularizer: 44 | name: 'None' 45 | config: 46 | name: 5m 47 | reward_shaper: 48 | scale_value: 1 49 | normalize_advantage: True 50 | gamma: 0.99 51 | tau: 0.95 52 | learning_rate: 1e-4 53 | score_to_win: 20 54 | grad_norm: 0.5 55 | entropy_coef: 0.005 56 | truncate_grads: True 57 | env_name: smac_cnn 58 | e_clip: 0.2 59 | clip_value: True 60 | num_actors: 8 61 | horizon_length: 128 62 | minibatch_size: 2560 63 | mini_epochs: 4 64 | critic_coef: 2 65 | lr_schedule: None 66 | kl_threshold: 0.05 67 | normalize_input: True 68 | seq_length: 2 69 | use_action_masks: True 70 | env_config: 71 | name: 5m_vs_6m 72 | frames: 4 73 | transpose: False 74 | random_invalid_step: False -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/6h_vs_8z_torch.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | mlp: 15 | units: [256, 256] 16 | activation: relu 17 | initializer: 18 | name: default 19 | regularizer: 20 | name: 'None' 21 | 22 | config: 23 | name: 6h_vs_8z_separate 24 | reward_shaper: 25 | scale_value: 1 26 | normalize_advantage: True 27 | gamma: 0.99 28 | tau: 0.95 29 | learning_rate: 5e-4 30 | score_to_win: 20 31 | grad_norm: 0.5 32 | entropy_coef: 0.002 33 | truncate_grads: True 34 | env_name: smac 35 | e_clip: 0.2 36 | clip_value: True 37 | num_actors: 8 38 | horizon_length: 128 39 | minibatch_size: 3072 # 6 * 512 40 | mini_epochs: 2 41 | critic_coef: 1 42 | lr_schedule: None 43 | kl_threshold: 0.05 44 | normalize_input: True 45 | use_action_masks: True 46 | ignore_dead_batches : False 47 | 48 | env_config: 49 | name: 6h_vs_8z 50 | central_value: False 51 | reward_only_positive: False 52 | obs_last_action: True 53 | frames: 1 54 | #flatten: False -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/8m_torch.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | 15 | mlp: 16 | units: [256, 128] 17 | activation: relu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: None 22 | config: 23 | name: 8m 24 | reward_shaper: 25 | scale_value: 1 26 | normalize_advantage: True 27 | gamma: 0.99 28 | tau: 0.95 29 | learning_rate: 5e-4 30 | score_to_win: 20 31 | grad_norm: 0.5 32 | entropy_coef: 0.001 33 | truncate_grads: True 34 | env_name: smac 35 | e_clip: 0.2 36 | clip_value: True 37 | num_actors: 8 38 | horizon_length: 128 39 | minibatch_size: 4096 40 | mini_epochs: 4 41 | critic_coef: 1 42 | lr_schedule: None 43 | kl_threshold: 0.05 44 | normalize_input: True 45 | seq_length: 2 46 | use_action_masks: True 47 | ignore_dead_batches : False 48 | max_epochs: 10000 49 | env_config: 50 | name: 8m 51 | frames: 1 52 | transpose: False 53 | random_invalid_step: False -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/8m_torch_cv.yaml: 
-------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: False 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | 15 | mlp: 16 | units: [256, 128] 17 | activation: relu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: None 22 | 23 | config: 24 | name: 8m_cv 25 | reward_shaper: 26 | scale_value: 1 27 | normalize_advantage: True 28 | gamma: 0.99 29 | tau: 0.95 30 | learning_rate: 5e-4 31 | score_to_win: 20 32 | grad_norm: 0.5 33 | entropy_coef: 0.001 34 | truncate_grads: True 35 | env_name: smac 36 | e_clip: 0.2 37 | clip_value: True 38 | num_actors: 8 39 | horizon_length: 128 40 | minibatch_size: 4096 # 3 * 512 41 | mini_epochs: 4 42 | critic_coef: 1 43 | lr_schedule: None 44 | kl_threshold: 0.05 45 | normalize_input: True 46 | seq_length: 2 47 | use_action_masks: True 48 | ignore_dead_batches : False 49 | max_epochs: 10000 50 | 51 | central_value_config: 52 | minibatch_size: 512 53 | mini_epochs: 4 54 | learning_rate: 5e-4 55 | clip_value: False 56 | normalize_input: True 57 | network: 58 | name: actor_critic 59 | central_value: True 60 | mlp: 61 | units: [512, 256,128] 62 | activation: relu 63 | initializer: 64 | name: default 65 | scale: 2 66 | regularizer: 67 | name: None 68 | 69 | env_config: 70 | name: 8m 71 | frames: 1 72 | transpose: False 73 | random_invalid_step: False 74 | central_value: True 75 | reward_only_positive: False 76 | obs_last_action: True 77 | -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/MMM2_torch.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | 15 | cnn: 16 | type: conv1d 17 | activation: relu 18 | initializer: 19 | name: default 20 | scale: 1.3 21 | regularizer: 22 | name: 'None' 23 | convs: 24 | - filters: 64 25 | kernel_size: 3 26 | strides: 2 27 | padding: 0 28 | - filters: 128 29 | kernel_size: 3 30 | strides: 1 31 | padding: 0 32 | - filters: 256 33 | kernel_size: 3 34 | strides: 1 35 | padding: 0 36 | mlp: 37 | units: [256, 128] 38 | activation: relu 39 | initializer: 40 | name: default 41 | regularizer: 42 | name: 'None' 43 | config: 44 | name: MMM2_cnn 45 | reward_shaper: 46 | scale_value: 1.3 47 | normalize_advantage: True 48 | gamma: 0.99 49 | tau: 0.95 50 | learning_rate: 1e-4 51 | score_to_win: 20 52 | grad_norm: 0.5 53 | entropy_coef: 0.005 54 | truncate_grads: True 55 | env_name: smac_cnn 56 | e_clip: 0.2 57 | clip_value: True 58 | num_actors: 8 59 | horizon_length: 64 60 | minibatch_size: 2560 61 | mini_epochs: 1 62 | critic_coef: 2 63 | lr_schedule: None 64 | kl_threshold: 0.05 65 | normalize_input: False 66 | use_action_masks: True 67 | 68 | env_config: 69 | name: MMM2 70 | frames: 4 71 | transpose: False # for pytorch transpose == not Transpose in tf 72 | random_invalid_step: False 73 | replay_save_freq: 100 -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/corridor_torch.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 
| #normalization: layer_norm 12 | space: 13 | discrete: 14 | 15 | cnn: 16 | type: conv1d 17 | activation: relu 18 | initializer: 19 | name: glorot_uniform_initializer 20 | gain: 1.4241 21 | regularizer: 22 | name: None 23 | convs: 24 | - filters: 64 25 | kernel_size: 3 26 | strides: 2 27 | padding: 1 28 | - filters: 128 29 | kernel_size: 3 30 | strides: 1 31 | padding: 0 32 | - filters: 256 33 | kernel_size: 3 34 | strides: 1 35 | padding: 0 36 | mlp: 37 | units: [256, 128] 38 | activation: relu 39 | initializer: 40 | name: default 41 | regularizer: 42 | name: None 43 | 44 | config: 45 | name: corridor_cnn 46 | reward_shaper: 47 | scale_value: 1 48 | normalize_advantage: True 49 | gamma: 0.99 50 | tau: 0.95 51 | learning_rate: 1e-4 52 | score_to_win: 20 53 | grad_norm: 0.5 54 | entropy_coef: 0.005 55 | truncate_grads: True 56 | env_name: smac_cnn 57 | e_clip: 0.2 58 | clip_value: True 59 | num_actors: 8 60 | horizon_length: 128 61 | minibatch_size: 3072 62 | mini_epochs: 1 63 | critic_coef: 2 64 | lr_schedule: None 65 | kl_threshold: 0.05 66 | normalize_input: False 67 | seq_length: 2 68 | use_action_masks: True 69 | ignore_dead_batches: False 70 | 71 | env_config: 72 | name: corridor 73 | frames: 4 74 | transpose: False 75 | random_invalid_step: False -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/corridor_torch_cv.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: False 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | mlp: 15 | units: [512, 256, 128] 16 | activation: relu 17 | initializer: 18 | name: default 19 | regularizer: 20 | name: None 21 | 22 | config: 23 | name: corridor_cv 24 | reward_shaper: 25 | scale_value: 1 26 | normalize_advantage: True 27 | gamma: 0.995 28 | tau: 0.95 29 | learning_rate: 3e-4 30 | score_to_win: 20 31 | grad_norm: 0.5 32 | entropy_coef: 0.005 33 | truncate_grads: True 34 | env_name: smac 35 | e_clip: 0.2 36 | clip_value: True 37 | num_actors: 8 38 | horizon_length: 128 39 | minibatch_size: 3072 # 6 * 512 40 | mini_epochs: 4 41 | critic_coef: 1 42 | lr_schedule: None 43 | kl_threshold: 0.05 44 | normalize_input: True 45 | use_action_masks: True 46 | ignore_dead_batches : False 47 | 48 | env_config: 49 | name: corridor 50 | central_value: True 51 | reward_only_positive: False 52 | obs_last_action: True 53 | frames: 1 54 | reward_negative_scale: 0.05 55 | #apply_agent_ids: True 56 | #flatten: False 57 | 58 | central_value_config: 59 | minibatch_size: 512 60 | mini_epochs: 4 61 | learning_rate: 3e-4 62 | clip_value: False 63 | normalize_input: True 64 | network: 65 | name: actor_critic 66 | central_value: True 67 | mlp: 68 | units: [512, 256, 128] 69 | activation: relu 70 | initializer: 71 | name: default 72 | scale: 2 73 | regularizer: 74 | name: None -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/runs/2c_vs_64zg.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | load_checkpoint: False 9 | load_path: 'nn/2c_vs_64zg' 10 | 11 | network: 12 | name: actor_critic 13 | separate: True 14 | space: 15 | discrete: 16 | 17 | mlp: 18 | units: [512, 256, 128] 19 | activation: relu 20 | initializer: 21 | name: default 22 | regularizer: 23 | 
name: 'None' 24 | config: 25 | name: 2c_vs_64zg 26 | reward_shaper: 27 | scale_value: 1 28 | normalize_advantage: True 29 | gamma: 0.99 30 | tau: 0.95 31 | learning_rate: 5e-4 32 | score_to_win: 1000 33 | grad_norm: 0.5 34 | entropy_coef: 0.005 35 | truncate_grads: True 36 | env_name: smac 37 | e_clip: 0.2 38 | clip_value: True 39 | num_actors: 8 40 | horizon_length: 128 41 | minibatch_size: 1024 42 | mini_epochs: 4 43 | critic_coef: 1 44 | lr_schedule: None 45 | kl_threshold: 0.05 46 | normalize_input: True 47 | use_action_masks: True 48 | 49 | env_config: 50 | name: 2c_vs_64zg 51 | frames: 1 52 | random_invalid_step: False 53 | central_value: True 54 | reward_only_positive: True 55 | state_last_action: True 56 | 57 | central_value_config: 58 | minibatch_size: 512 59 | mini_epochs: 4 60 | learning_rate: 5e-4 61 | clip_value: False 62 | normalize_input: True 63 | network: 64 | name: actor_critic 65 | central_value: True 66 | mlp: 67 | units: [512, 256, 128] 68 | activation: relu 69 | initializer: 70 | name: default 71 | scale: 2 72 | regularizer: 73 | name: 'None' 74 | -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/runs/2c_vs_64zg_neg.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | discrete: 13 | 14 | mlp: 15 | units: [512, 256, 128] 16 | activation: relu 17 | initializer: 18 | name: default 19 | regularizer: 20 | name: 'None' 21 | config: 22 | name: 2c_vs_64zg_neg 23 | reward_shaper: 24 | scale_value: 1 25 | normalize_advantage: True 26 | gamma: 0.99 27 | tau: 0.95 28 | learning_rate: 5e-4 29 | score_to_win: 1000 30 | grad_norm: 0.5 31 | entropy_coef: 0.005 32 | truncate_grads: True 33 | env_name: smac 34 | e_clip: 0.2 35 | clip_value: True 36 | num_actors: 8 37 | horizon_length: 128 38 | minibatch_size: 1024 39 | mini_epochs: 4 40 | critic_coef: 1 41 | lr_schedule: None 42 | kl_threshold: 0.05 43 | normalize_input: True 44 | use_action_masks: True 45 | 46 | env_config: 47 | name: 2c_vs_64zg 48 | frames: 1 49 | random_invalid_step: False 50 | central_value: True 51 | reward_only_positive: False 52 | state_last_action: True 53 | 54 | central_value_config: 55 | minibatch_size: 512 56 | mini_epochs: 4 57 | learning_rate: 5e-4 58 | clip_value: False 59 | normalize_input: True 60 | network: 61 | name: actor_critic 62 | central_value: True 63 | mlp: 64 | units: [512, 256, 128] 65 | activation: relu 66 | initializer: 67 | name: default 68 | scale: 2 69 | regularizer: 70 | name: 'None' 71 | -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/runs/2s_vs_1c.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | 9 | network: 10 | name: actor_critic 11 | separate: True 12 | space: 13 | discrete: 14 | 15 | mlp: 16 | units: [256, 128] 17 | activation: relu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: 'None' 22 | config: 23 | name: 2s_vs_1sc_cv_neg 24 | reward_shaper: 25 | scale_value: 1 26 | normalize_advantage: True 27 | gamma: 0.99 28 | tau: 0.95 29 | learning_rate: 5e-4 30 | score_to_win: 1000 31 | grad_norm: 0.5 32 | entropy_coef: 0.005 33 | truncate_grads: True 34 | env_name: smac 35 | e_clip: 0.2 36 | clip_value: True 37 | num_actors: 8 38 | 
horizon_length: 128 39 | minibatch_size: 1024 40 | mini_epochs: 4 41 | critic_coef: 1 42 | lr_schedule: None 43 | kl_threshold: 0.05 44 | normalize_input: True 45 | use_action_masks: True 46 | 47 | env_config: 48 | name: 2s_vs_1sc 49 | frames: 1 50 | random_invalid_step: False 51 | central_value: True 52 | reward_only_positive: True 53 | state_last_action: True 54 | 55 | central_value_config: 56 | minibatch_size: 512 57 | mini_epochs: 4 58 | learning_rate: 5e-4 59 | clip_value: False 60 | normalize_input: True 61 | network: 62 | name: actor_critic 63 | central_value: True 64 | mlp: 65 | units: [256, 128] 66 | activation: relu 67 | initializer: 68 | name: default 69 | scale: 2 70 | regularizer: 71 | name: 'None' 72 | -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/runs/2s_vs_1c_neg.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | space: 12 | discrete: 13 | 14 | mlp: 15 | units: [256, 128] 16 | activation: relu 17 | initializer: 18 | name: default 19 | regularizer: 20 | name: None 21 | 22 | config: 23 | name: 2s_vs_1sc_cv_neg 24 | reward_shaper: 25 | scale_value: 1 26 | normalize_advantage: True 27 | gamma: 0.99 28 | tau: 0.95 29 | learning_rate: 5e-4 30 | score_to_win: 1000 31 | grad_norm: 0.5 32 | entropy_coef: 0.005 33 | truncate_grads: True 34 | env_name: smac 35 | e_clip: 0.2 36 | clip_value: True 37 | num_actors: 8 38 | horizon_length: 128 39 | minibatch_size: 1024 40 | mini_epochs: 4 41 | critic_coef: 1 42 | lr_schedule: None 43 | kl_threshold: 0.05 44 | normalize_input: True 45 | use_action_masks: True 46 | 47 | env_config: 48 | name: 2s_vs_1sc 49 | frames: 1 50 | random_invalid_step: False 51 | central_value: True 52 | reward_only_positive: False 53 | state_last_action: True 54 | 55 | central_value_config: 56 | minibatch_size: 512 57 | mini_epochs: 4 58 | learning_rate: 5e-4 59 | clip_value: False 60 | normalize_input: True 61 | network: 62 | name: actor_critic 63 | central_value: True 64 | mlp: 65 | units: [256, 128] 66 | activation: relu 67 | initializer: 68 | name: default 69 | scale: 2 70 | regularizer: 71 | name: None -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/runs/3s_vs_5z.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: False 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | 15 | mlp: 16 | units: [512, 256, 128] 17 | activation: relu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: 'None' 22 | 23 | config: 24 | name: 3s_vs_5z_cv 25 | reward_shaper: 26 | scale_value: 1 27 | normalize_advantage: True 28 | gamma: 0.99 29 | tau: 0.95 30 | learning_rate: 1e-4 31 | score_to_win: 20 32 | grad_norm: 0.5 33 | entropy_coef: 0.005 34 | truncate_grads: True 35 | env_name: smac 36 | e_clip: 0.2 37 | clip_value: True 38 | num_actors: 8 39 | horizon_length: 128 40 | minibatch_size: 1536 # 3 * 512 41 | mini_epochs: 4 42 | critic_coef: 1 43 | lr_schedule: None 44 | kl_threshold: 0.05 45 | normalize_input: True 46 | use_action_masks: True 47 | max_epochs: 50000 48 | 49 | 50 | central_value_config: 51 | minibatch_size: 512 52 | mini_epochs: 4 53 | learning_rate: 1e-4 54 | clip_value: 
False 55 | normalize_input: True 56 | network: 57 | name: actor_critic 58 | central_value: True 59 | mlp: 60 | units: [512, 256,128] 61 | activation: relu 62 | initializer: 63 | name: default 64 | scale: 2 65 | regularizer: 66 | name: 'None' 67 | 68 | env_config: 69 | name: 3s_vs_5z 70 | frames: 1 71 | transpose: False 72 | random_invalid_step: False 73 | central_value: True 74 | reward_only_positive: True 75 | obs_last_action: True 76 | -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/runs/3s_vs_5z_neg.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | #normalization: layer_norm 13 | space: 14 | discrete: 15 | 16 | mlp: 17 | units: [256, 128] 18 | activation: relu 19 | initializer: 20 | name: default 21 | regularizer: 22 | name: None 23 | 24 | config: 25 | name: 3s_vs_5z_neg 26 | reward_shaper: 27 | scale_value: 1 28 | normalize_advantage: True 29 | gamma: 0.99 30 | tau: 0.95 31 | learning_rate: 5e-4 32 | score_to_win: 20 33 | grad_norm: 0.5 34 | entropy_coef: 0.005 35 | truncate_grads: True 36 | env_name: smac 37 | e_clip: 0.2 38 | clip_value: True 39 | num_actors: 8 40 | horizon_length: 128 41 | minibatch_size: 1536 # 3 * 512 42 | mini_epochs: 4 43 | critic_coef: 1 44 | lr_schedule: None 45 | kl_threshold: 0.05 46 | normalize_input: True 47 | use_action_masks: True 48 | max_epochs: 50000 49 | 50 | central_value_config: 51 | minibatch_size: 512 52 | mini_epochs: 4 53 | learning_rate: 5e-4 54 | clip_value: False 55 | normalize_input: True 56 | network: 57 | name: actor_critic 58 | central_value: True 59 | mlp: 60 | units: [256,128] 61 | activation: relu 62 | initializer: 63 | name: default 64 | scale: 2 65 | regularizer: 66 | name: None 67 | 68 | env_config: 69 | name: 3s_vs_5z 70 | frames: 1 71 | transpose: False 72 | random_invalid_step: False 73 | central_value: True 74 | reward_only_positive: False 75 | obs_last_action: True 76 | -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/runs/6h_vs_8z.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: False 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | mlp: 15 | units: [512, 256] 16 | activation: relu 17 | initializer: 18 | name: default 19 | regularizer: 20 | name: None 21 | 22 | config: 23 | name: 6h_vs_8z 24 | reward_shaper: 25 | scale_value: 1 26 | normalize_advantage: True 27 | gamma: 0.99 28 | tau: 0.95 29 | learning_rate: 5e-4 30 | score_to_win: 20 31 | grad_norm: 0.5 32 | entropy_coef: 0.002 33 | truncate_grads: True 34 | env_name: smac 35 | e_clip: 0.2 36 | clip_value: True 37 | num_actors: 8 38 | horizon_length: 128 39 | minibatch_size: 3072 # 6 * 512 40 | mini_epochs: 4 41 | critic_coef: 1 42 | lr_schedule: None 43 | kl_threshold: 0.05 44 | normalize_input: True 45 | use_action_masks: True 46 | 47 | env_config: 48 | name: 6h_vs_8z 49 | central_value: True 50 | reward_only_positive: True 51 | obs_last_action: True 52 | frames: 1 53 | #reward_negative_scale: 0.9 54 | #apply_agent_ids: True 55 | #flatten: False 56 | 57 | central_value_config: 58 | minibatch_size: 512 59 | mini_epochs: 4 60 | learning_rate: 5e-4 61 | clip_value: False 62 | 
normalize_input: True 63 | network: 64 | name: actor_critic 65 | central_value: True 66 | mlp: 67 | units: [512, 256, 128] 68 | activation: relu 69 | initializer: 70 | name: default 71 | scale: 2 72 | regularizer: 73 | name: None -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/runs/6h_vs_8z_neg.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: False 11 | #normalization: layer_norm 12 | space: 13 | discrete: 14 | mlp: 15 | units: [512, 256, 128] 16 | activation: relu 17 | initializer: 18 | name: default 19 | regularizer: 20 | name: None 21 | 22 | config: 23 | name: 6h_vs_8z_neg 24 | reward_shaper: 25 | scale_value: 1 26 | normalize_advantage: True 27 | gamma: 0.99 28 | tau: 0.95 29 | learning_rate: 1e-4 30 | score_to_win: 20 31 | grad_norm: 0.5 32 | entropy_coef: 0.005 33 | truncate_grads: True 34 | env_name: smac 35 | e_clip: 0.2 36 | clip_value: True 37 | num_actors: 8 38 | horizon_length: 128 39 | minibatch_size: 3072 # 6 * 512 40 | mini_epochs: 4 41 | critic_coef: 1 42 | lr_schedule: None 43 | kl_threshold: 0.05 44 | normalize_input: True 45 | use_action_masks: True 46 | 47 | env_config: 48 | name: 6h_vs_8z 49 | central_value: True 50 | reward_only_positive: False 51 | obs_last_action: True 52 | frames: 1 53 | #reward_negative_scale: 0.9 54 | #apply_agent_ids: True 55 | #flatten: False 56 | 57 | central_value_config: 58 | minibatch_size: 512 59 | mini_epochs: 4 60 | learning_rate: 1e-4 61 | clip_value: False 62 | normalize_input: True 63 | network: 64 | name: actor_critic 65 | central_value: True 66 | mlp: 67 | units: [512, 256, 128] 68 | activation: relu 69 | initializer: 70 | name: default 71 | scale: 2 72 | regularizer: 73 | name: None -------------------------------------------------------------------------------- /rl_games/configs/smac/v1/runs/corridor_cv.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | load_checkpoint: False 9 | load_path: 'nn/corridor_cv.pth' 10 | 11 | network: 12 | name: actor_critic 13 | separate: False 14 | #normalization: layer_norm 15 | space: 16 | discrete: 17 | mlp: 18 | units: [512, 256, 128] 19 | activation: relu 20 | initializer: 21 | name: default 22 | regularizer: 23 | name: None 24 | 25 | config: 26 | name: corridor_cv 27 | reward_shaper: 28 | scale_value: 1 29 | normalize_advantage: True 30 | gamma: 0.995 31 | tau: 0.95 32 | learning_rate: 3e-4 33 | score_to_win: 20 34 | grad_norm: 0.5 35 | entropy_coef: 0.005 36 | truncate_grads: True 37 | env_name: smac 38 | e_clip: 0.2 39 | clip_value: True 40 | num_actors: 8 41 | horizon_length: 128 42 | minibatch_size: 3072 # 6 * 512 43 | mini_epochs: 4 44 | critic_coef: 1 45 | lr_schedule: None 46 | kl_threshold: 0.05 47 | normalize_input: True 48 | use_action_masks: True 49 | 50 | env_config: 51 | name: corridor 52 | central_value: True 53 | reward_only_positive: True 54 | obs_last_action: True 55 | frames: 1 56 | #apply_agent_ids: True 57 | #flatten: False 58 | 59 | central_value_config: 60 | minibatch_size: 512 61 | mini_epochs: 4 62 | learning_rate: 3e-4 63 | clip_value: False 64 | normalize_input: True 65 | network: 66 | name: actor_critic 67 | central_value: True 68 | mlp: 69 | units: [512, 256, 128] 70 | activation: relu 71 | initializer: 72 | 
name: default 73 | scale: 2 74 | regularizer: 75 | name: None 76 | -------------------------------------------------------------------------------- /rl_games/configs/smac/v2/env_configs/sc2_gen_protoss.yaml: -------------------------------------------------------------------------------- 1 | env: sc2wrapped 2 | 3 | env_args: 4 | continuing_episode: False 5 | difficulty: "7" 6 | game_version: null 7 | map_name: "10gen_protoss" 8 | move_amount: 2 9 | obs_all_health: True 10 | obs_instead_of_state: False 11 | obs_last_action: False 12 | obs_own_health: True 13 | obs_pathing_grid: False 14 | obs_terrain_height: False 15 | obs_timestep_number: False 16 | reward_death_value: 10 17 | reward_defeat: 0 18 | reward_negative_scale: 0.5 19 | reward_only_positive: True 20 | reward_scale: True 21 | reward_scale_rate: 20 22 | reward_sparse: False 23 | reward_win: 200 24 | replay_dir: "" 25 | replay_prefix: "" 26 | conic_fov: False 27 | use_unit_ranges: True 28 | min_attack_range: 2 29 | obs_own_pos: True 30 | num_fov_actions: 12 31 | capability_config: 32 | n_units: 5 33 | n_enemies: 5 34 | team_gen: 35 | dist_type: "weighted_teams" 36 | unit_types: 37 | - "stalker" 38 | - "zealot" 39 | - "colossus" 40 | weights: 41 | - 0.45 42 | - 0.45 43 | - 0.1 44 | observe: True 45 | start_positions: 46 | dist_type: "surrounded_and_reflect" 47 | p: 0.5 48 | map_x: 32 49 | map_y: 32 50 | 51 | # enemy_mask: 52 | # dist_type: "mask" 53 | # mask_probability: 0.5 54 | # n_enemies: 5 55 | state_last_action: True 56 | state_timestep_number: False 57 | step_mul: 8 58 | heuristic_ai: False 59 | # heuristic_rest: False 60 | debug: False 61 | prob_obs_enemy: 1.0 62 | action_mask: True 63 | 64 | test_nepisode: 32 65 | test_interval: 10000 66 | log_interval: 2000 67 | runner_log_interval: 2000 68 | learner_log_interval: 2000 69 | t_max: 10050000 70 | -------------------------------------------------------------------------------- /rl_games/configs/smac/v2/env_configs/sc2_gen_protoss_epo.yaml: -------------------------------------------------------------------------------- 1 | env: sc2wrapped 2 | 3 | env_args: 4 | continuing_episode: False 5 | difficulty: "7" 6 | game_version: null 7 | map_name: "10gen_protoss" 8 | move_amount: 2 9 | obs_all_health: True 10 | obs_instead_of_state: False 11 | obs_last_action: False 12 | obs_own_health: True 13 | obs_pathing_grid: False 14 | obs_terrain_height: False 15 | obs_timestep_number: False 16 | reward_death_value: 10 17 | reward_defeat: 0 18 | reward_negative_scale: 0.5 19 | reward_only_positive: True 20 | reward_scale: True 21 | reward_scale_rate: 20 22 | reward_sparse: False 23 | reward_win: 200 24 | replay_dir: "" 25 | replay_prefix: "" 26 | conic_fov: False 27 | use_unit_ranges: True 28 | min_attack_range: 2 29 | obs_own_pos: True 30 | num_fov_actions: 12 31 | capability_config: 32 | n_units: 5 33 | n_enemies: 5 34 | team_gen: 35 | dist_type: "weighted_teams" 36 | unit_types: 37 | - "stalker" 38 | - "zealot" 39 | - "colossus" 40 | weights: 41 | - 0.45 42 | - 0.45 43 | - 0.1 44 | observe: True 45 | start_positions: 46 | dist_type: "surrounded_and_reflect" 47 | p: 0.5 48 | map_x: 32 49 | map_y: 32 50 | 51 | # enemy_mask: 52 | # dist_type: "mask" 53 | # mask_probability: 0.5 54 | # n_enemies: 5 55 | state_last_action: True 56 | state_timestep_number: False 57 | step_mul: 8 58 | heuristic_ai: False 59 | # heuristic_rest: False 60 | debug: False 61 | # Most severe partial obs setting: 62 | prob_obs_enemy: 0.0 63 | action_mask: False 64 | 65 | test_nepisode: 32 66 | test_interval: 10000 67 | 
log_interval: 2000 68 | runner_log_interval: 2000 69 | learner_log_interval: 2000 70 | t_max: 10050000 -------------------------------------------------------------------------------- /rl_games/configs/smac/v2/env_configs/sc2_gen_terran.yaml: -------------------------------------------------------------------------------- 1 | env: sc2wrapped 2 | 3 | env_args: 4 | continuing_episode: False 5 | difficulty: "7" 6 | game_version: null 7 | map_name: "10gen_terran" 8 | move_amount: 2 9 | obs_all_health: True 10 | obs_instead_of_state: False 11 | obs_last_action: False 12 | obs_own_health: True 13 | obs_pathing_grid: False 14 | obs_terrain_height: False 15 | obs_timestep_number: False 16 | reward_death_value: 10 17 | reward_defeat: 0 18 | reward_negative_scale: 0.5 19 | reward_only_positive: True 20 | reward_scale: True 21 | reward_scale_rate: 20 22 | reward_sparse: False 23 | reward_win: 200 24 | replay_dir: "" 25 | replay_prefix: "" 26 | conic_fov: False 27 | obs_own_pos: True 28 | use_unit_ranges: True 29 | min_attack_range: 2 30 | num_fov_actions: 12 31 | capability_config: 32 | n_units: 5 33 | n_enemies: 5 34 | team_gen: 35 | dist_type: "weighted_teams" 36 | unit_types: 37 | - "marine" 38 | - "marauder" 39 | - "medivac" 40 | weights: 41 | - 0.45 42 | - 0.45 43 | - 0.1 44 | exception_unit_types: 45 | - "medivac" 46 | observe: True 47 | 48 | start_positions: 49 | dist_type: "surrounded_and_reflect" 50 | p: 0.5 51 | map_x: 32 52 | map_y: 32 53 | # enemy_mask: 54 | # dist_type: "mask" 55 | # mask_probability: 0.5 56 | # n_enemies: 5 57 | state_last_action: True 58 | state_timestep_number: False 59 | step_mul: 8 60 | heuristic_ai: False 61 | # heuristic_rest: False 62 | debug: False 63 | prob_obs_enemy: 1.0 64 | action_mask: True 65 | 66 | test_nepisode: 32 67 | test_interval: 10000 68 | log_interval: 2000 69 | runner_log_interval: 2000 70 | learner_log_interval: 2000 71 | t_max: 10050000 72 | -------------------------------------------------------------------------------- /rl_games/configs/smac/v2/env_configs/sc2_gen_terran_epo.yaml: -------------------------------------------------------------------------------- 1 | env: sc2wrapped 2 | 3 | env_args: 4 | continuing_episode: False 5 | difficulty: "7" 6 | game_version: null 7 | map_name: "10gen_terran" 8 | move_amount: 2 9 | obs_all_health: True 10 | obs_instead_of_state: False 11 | obs_last_action: False 12 | obs_own_health: True 13 | obs_pathing_grid: False 14 | obs_terrain_height: False 15 | obs_timestep_number: False 16 | reward_death_value: 10 17 | reward_defeat: 0 18 | reward_negative_scale: 0.5 19 | reward_only_positive: True 20 | reward_scale: True 21 | reward_scale_rate: 20 22 | reward_sparse: False 23 | reward_win: 200 24 | replay_dir: "" 25 | replay_prefix: "" 26 | conic_fov: False 27 | obs_own_pos: True 28 | use_unit_ranges: True 29 | min_attack_range: 2 30 | num_fov_actions: 12 31 | capability_config: 32 | n_units: 5 33 | n_enemies: 5 34 | team_gen: 35 | dist_type: "weighted_teams" 36 | unit_types: 37 | - "marine" 38 | - "marauder" 39 | - "medivac" 40 | weights: 41 | - 0.45 42 | - 0.45 43 | - 0.1 44 | exception_unit_types: 45 | - "medivac" 46 | observe: True 47 | 48 | start_positions: 49 | dist_type: "surrounded_and_reflect" 50 | p: 0.5 51 | map_x: 32 52 | map_y: 32 53 | # enemy_mask: 54 | # dist_type: "mask" 55 | # mask_probability: 0.5 56 | # n_enemies: 5 57 | state_last_action: True 58 | state_timestep_number: False 59 | step_mul: 8 60 | heuristic_ai: False 61 | # heuristic_rest: False 62 | debug: False 63 | # Most severe partial 
obs setting: 64 | prob_obs_enemy: 0.0 65 | action_mask: False 66 | 67 | test_nepisode: 32 68 | test_interval: 10000 69 | log_interval: 2000 70 | runner_log_interval: 2000 71 | learner_log_interval: 2000 72 | t_max: 10050000 73 | -------------------------------------------------------------------------------- /rl_games/configs/smac/v2/env_configs/sc2_gen_zerg.yaml: -------------------------------------------------------------------------------- 1 | env: sc2wrapped 2 | 3 | env_args: 4 | continuing_episode: False 5 | difficulty: "7" 6 | game_version: null 7 | map_name: "10gen_zerg" 8 | move_amount: 2 9 | obs_all_health: True 10 | obs_instead_of_state: False 11 | obs_last_action: False 12 | obs_own_health: True 13 | obs_pathing_grid: False 14 | obs_terrain_height: False 15 | obs_timestep_number: False 16 | reward_death_value: 10 17 | reward_defeat: 0 18 | reward_negative_scale: 0.5 19 | reward_only_positive: True 20 | reward_scale: True 21 | reward_scale_rate: 20 22 | reward_sparse: False 23 | reward_win: 200 24 | replay_dir: "" 25 | replay_prefix: "" 26 | conic_fov: False 27 | use_unit_ranges: True 28 | min_attack_range: 2 29 | num_fov_actions: 12 30 | obs_own_pos: True 31 | capability_config: 32 | n_units: 5 33 | n_enemies: 5 34 | team_gen: 35 | dist_type: "weighted_teams" 36 | unit_types: 37 | - "zergling" 38 | - "baneling" 39 | - "hydralisk" 40 | weights: 41 | - 0.45 42 | - 0.1 43 | - 0.45 44 | exception_unit_types: 45 | - "baneling" 46 | observe: True 47 | 48 | start_positions: 49 | dist_type: "surrounded_and_reflect" 50 | p: 0.5 51 | map_x: 32 52 | map_y: 32 53 | # enemy_mask: 54 | # dist_type: "mask" 55 | # mask_probability: 0.5 56 | # n_enemies: 5 57 | state_last_action: True 58 | state_timestep_number: False 59 | step_mul: 8 60 | heuristic_ai: False 61 | # heuristic_rest: False 62 | debug: False 63 | prob_obs_enemy: 1.0 64 | action_mask: True 65 | 66 | test_nepisode: 32 67 | test_interval: 10000 68 | log_interval: 2000 69 | runner_log_interval: 2000 70 | learner_log_interval: 2000 71 | t_max: 10050000 72 | -------------------------------------------------------------------------------- /rl_games/configs/smac/v2/env_configs/sc2_gen_zerg_epo.yaml: -------------------------------------------------------------------------------- 1 | env: sc2wrapped 2 | 3 | env_args: 4 | continuing_episode: False 5 | difficulty: "7" 6 | game_version: null 7 | map_name: "10gen_zerg" 8 | move_amount: 2 9 | obs_all_health: True 10 | obs_instead_of_state: False 11 | obs_last_action: False 12 | obs_own_health: True 13 | obs_pathing_grid: False 14 | obs_terrain_height: False 15 | obs_timestep_number: False 16 | reward_death_value: 10 17 | reward_defeat: 0 18 | reward_negative_scale: 0.5 19 | reward_only_positive: True 20 | reward_scale: True 21 | reward_scale_rate: 20 22 | reward_sparse: False 23 | reward_win: 200 24 | replay_dir: "" 25 | replay_prefix: "" 26 | conic_fov: False 27 | use_unit_ranges: True 28 | min_attack_range: 2 29 | num_fov_actions: 12 30 | obs_own_pos: True 31 | capability_config: 32 | n_units: 5 33 | n_enemies: 5 34 | team_gen: 35 | dist_type: "weighted_teams" 36 | unit_types: 37 | - "zergling" 38 | - "baneling" 39 | - "hydralisk" 40 | weights: 41 | - 0.45 42 | - 0.1 43 | - 0.45 44 | exception_unit_types: 45 | - "baneling" 46 | observe: True 47 | 48 | start_positions: 49 | dist_type: "surrounded_and_reflect" 50 | p: 0.5 51 | map_x: 32 52 | map_y: 32 53 | # enemy_mask: 54 | # dist_type: "mask" 55 | # mask_probability: 0.5 56 | # n_enemies: 5 57 | state_last_action: True 58 | 
state_timestep_number: False 59 | step_mul: 8 60 | heuristic_ai: False 61 | # heuristic_rest: False 62 | debug: False 63 | # most severe partial obs setting: 64 | prob_obs_enemy: 0.0 65 | action_mask: False 66 | 67 | test_nepisode: 32 68 | test_interval: 10000 69 | log_interval: 2000 70 | runner_log_interval: 2000 71 | learner_log_interval: 2000 72 | t_max: 10050000 -------------------------------------------------------------------------------- /rl_games/configs/test/test_asymmetric_discrete_mhv_mops.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: testnet 10 | config: 11 | reward_shaper: 12 | scale_value: 1 13 | normalize_advantage: True 14 | gamma: 0.99 15 | tau: 0.9 16 | learning_rate: 2e-4 17 | name: test_md_multi_obs 18 | score_to_win: 0.95 19 | grad_norm: 10.5 20 | entropy_coef: 0.005 21 | truncate_grads: True 22 | env_name: test_env 23 | e_clip: 0.2 24 | clip_value: False 25 | num_actors: 16 26 | horizon_length: 256 27 | minibatch_size: 2048 28 | mini_epochs: 4 29 | critic_coef: 1 30 | lr_schedule: None 31 | kl_threshold: 0.008 32 | normalize_input: False 33 | normalize_value: False 34 | weight_decay: 0.0000 35 | max_epochs: 10000 36 | seq_length: 16 37 | save_best_after: 10 38 | save_frequency: 20 39 | 40 | env_config: 41 | name: TestRnnEnv-v0 42 | hide_object: False 43 | apply_dist_reward: False 44 | min_dist: 2 45 | max_dist: 8 46 | use_central_value: True 47 | multi_obs_space: True 48 | multi_head_value: False 49 | player: 50 | games_num: 100 51 | deterministic: True 52 | 53 | central_value_config: 54 | minibatch_size: 512 55 | mini_epochs: 4 56 | learning_rate: 5e-4 57 | clip_value: False 58 | normalize_input: False 59 | truncate_grads: True 60 | grad_norm: 10 61 | network: 62 | name: testnet 63 | central_value: True 64 | mlp: 65 | units: [64,32] 66 | activation: relu 67 | initializer: 68 | name: default -------------------------------------------------------------------------------- /rl_games/configs/test/test_discrete.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | #normalization: 'layer_norm' 12 | space: 13 | discrete: 14 | 15 | mlp: 16 | units: [32,32] 17 | activation: relu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: None 22 | 23 | config: 24 | reward_shaper: 25 | scale_value: 1 26 | normalize_advantage: True 27 | gamma: 0.99 28 | tau: 0.9 29 | learning_rate: 2e-4 30 | name: test_md 31 | score_to_win: 0.95 32 | grad_norm: 10.5 33 | entropy_coef: 0.005 34 | truncate_grads: True 35 | env_name: test_env 36 | e_clip: 0.2 37 | clip_value: False 38 | num_actors: 16 39 | horizon_length: 512 40 | minibatch_size: 2048 41 | mini_epochs: 4 42 | critic_coef: 1 43 | lr_schedule: None 44 | kl_threshold: 0.008 45 | normalize_input: True 46 | weight_decay: 0.0000 47 | max_epochs: 10000 48 | 49 | env_config: 50 | name: TestRnnEnv-v0 51 | hide_object: False 52 | apply_dist_reward: True 53 | min_dist: 2 54 | max_dist: 8 55 | use_central_value: True 56 | multi_discrete_space: False 57 | multi_head_value: False 58 | player: 59 | games_num: 100 60 | deterministic: True 61 | 62 | -------------------------------------------------------------------------------- /rl_games/configs/test/test_discrete_multidiscrete_mhv.yaml: 
-------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: multi_discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | #normalization: 'layer_norm' 12 | space: 13 | multi_discrete: 14 | 15 | mlp: 16 | units: [32,32] 17 | activation: relu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: None 22 | 23 | config: 24 | reward_shaper: 25 | scale_value: 1 26 | normalize_advantage: True 27 | gamma: 0.99 28 | tau: 0.9 29 | learning_rate: 2e-4 30 | name: test_md_mhv 31 | score_to_win: 0.95 32 | grad_norm: 10.5 33 | entropy_coef: 0.005 34 | truncate_grads: True 35 | env_name: test_env 36 | e_clip: 0.2 37 | clip_value: False 38 | num_actors: 16 39 | horizon_length: 512 40 | minibatch_size: 2048 41 | mini_epochs: 4 42 | critic_coef: 1 43 | lr_schedule: None 44 | kl_threshold: 0.008 45 | normalize_input: False 46 | weight_decay: 0.0000 47 | max_epochs: 10000 48 | 49 | env_config: 50 | name: TestRnnEnv-v0 51 | hide_object: False 52 | apply_dist_reward: False 53 | min_dist: 2 54 | max_dist: 8 55 | use_central_value: False 56 | multi_discrete_space: True 57 | multi_head_value: False 58 | player: 59 | games_num: 100 60 | deterministic: True 61 | 62 | -------------------------------------------------------------------------------- /rl_games/configs/test/test_discrite_testnet_aux_loss.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: testnet_aux_loss 10 | config: 11 | reward_shaper: 12 | scale_value: 1 13 | normalize_advantage: True 14 | gamma: 0.99 15 | tau: 0.9 16 | learning_rate: 2e-4 17 | name: test_md_multi_obs 18 | score_to_win: 0.95 19 | grad_norm: 10.5 20 | entropy_coef: 0.005 21 | truncate_grads: True 22 | env_name: test_env 23 | e_clip: 0.2 24 | clip_value: False 25 | num_actors: 16 26 | horizon_length: 256 27 | minibatch_size: 2048 28 | mini_epochs: 4 29 | critic_coef: 1 30 | lr_schedule: None 31 | kl_threshold: 0.008 32 | normalize_input: False 33 | normalize_value: False 34 | weight_decay: 0.0000 35 | max_epochs: 10000 36 | seq_length: 16 37 | save_best_after: 10 38 | save_frequency: 20 39 | 40 | env_config: 41 | name: TestRnnEnv-v0 42 | hide_object: False 43 | apply_dist_reward: False 44 | min_dist: 2 45 | max_dist: 8 46 | use_central_value: True 47 | multi_obs_space: True 48 | multi_head_value: False 49 | aux_loss: True 50 | player: 51 | games_num: 100 52 | deterministic: True -------------------------------------------------------------------------------- /rl_games/configs/test/test_ppo_walker_truncated_time.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 8 3 | algo: 4 | name: a2c_continuous 5 | 6 | model: 7 | name: continuous_a2c_logstd 8 | 9 | network: 10 | name: actor_critic 11 | separate: False 12 | space: 13 | continuous: 14 | mu_activation: None 15 | sigma_activation: None 16 | mu_init: 17 | name: default 18 | sigma_init: 19 | name: const_initializer 20 | val: 0 21 | fixed_sigma: True 22 | mlp: 23 | units: [256, 128, 64] 24 | d2rl: False 25 | activation: relu 26 | initializer: 27 | name: default 28 | scale: 2 29 | 30 | config: 31 | name: walker_truncated_step_1000 32 | reward_shaper: 33 | min_val: -1 34 | scale_value: 0.1 35 | 36 | normalize_input: True 37 | normalize_advantage: True 38 | normalize_value: True 39 | value_bootstrap: True 40 | gamma: 0.995 
41 | tau: 0.95 42 | 43 | learning_rate: 3e-4 44 | lr_schedule: adaptive 45 | kl_threshold: 0.005 46 | 47 | score_to_win: 300 48 | grad_norm: 0.5 49 | entropy_coef: 0 50 | truncate_grads: True 51 | env_name: BipedalWalker-v3 52 | e_clip: 0.2 53 | clip_value: False 54 | num_actors: 16 55 | horizon_length: 256 56 | minibatch_size: 256 57 | mini_epochs: 4 58 | critic_coef: 2 59 | 60 | bounds_loss_coef: 0.00 61 | max_epochs: 10000 62 | #weight_decay: 0.0001 63 | 64 | env_config: 65 | steps_limit: 1000 66 | 67 | player: 68 | render: True 69 | deterministic: True 70 | games_num: 200 71 | -------------------------------------------------------------------------------- /rl_games/configs/test/test_rnn.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | algo: 3 | name: a2c_discrete 4 | 5 | model: 6 | name: discrete_a2c 7 | 8 | network: 9 | name: actor_critic 10 | separate: True 11 | #normalization: 'layer_norm' 12 | space: 13 | discrete: 14 | 15 | mlp: 16 | units: [64] 17 | activation: relu 18 | initializer: 19 | name: default 20 | regularizer: 21 | name: None 22 | rnn: 23 | name: lstm 24 | #layer_norm: True 25 | units: 64 26 | layers: 1 27 | before_mlp: False 28 | config: 29 | reward_shaper: 30 | scale_value: 1 31 | normalize_advantage: True 32 | gamma: 0.99 33 | tau: 0.9 34 | learning_rate: 2e-4 35 | name: test_rnn 36 | score_to_win: 0.95 37 | grad_norm: 10.5 38 | entropy_coef: 0.005 39 | truncate_grads: True 40 | env_name: test_env 41 | e_clip: 0.2 42 | clip_value: False 43 | num_actors: 16 44 | horizon_length: 512 45 | minibatch_size: 2048 46 | mini_epochs: 4 47 | critic_coef: 1 48 | lr_schedule: None 49 | kl_threshold: 0.008 50 | normalize_input: False 51 | seq_length: 32 52 | weight_decay: 0.0000 53 | max_epochs: 10000 54 | 55 | env_config: 56 | name: TestRnnEnv-v0 57 | hide_object: True 58 | apply_dist_reward: False 59 | min_dist: 2 60 | max_dist: 8 61 | use_central_value: False 62 | 63 | player: 64 | games_num: 100 65 | deterministic: True 66 | 67 | -------------------------------------------------------------------------------- /rl_games/configs/test/test_rnn_multidiscrete_mhv.yaml: -------------------------------------------------------------------------------- 1 | params: 2 | seed: 322 3 | algo: 4 | name: a2c_discrete 5 | 6 | model: 7 | name: multi_discrete_a2c 8 | 9 | network: 10 | name: actor_critic 11 | separate: True 12 | #normalization: 'layer_norm' 13 | space: 14 | multi_discrete: 15 | 16 | mlp: 17 | units: [64] 18 | activation: relu 19 | initializer: 20 | name: default 21 | regularizer: 22 | name: None 23 | rnn: 24 | name: lstm 25 | #layer_norm: True 26 | units: 64 27 | layers: 1 28 | before_mlp: False 29 | 30 | config: 31 | reward_shaper: 32 | scale_value: 1 33 | normalize_advantage: True 34 | gamma: 0.99 35 | tau: 0.9 36 | learning_rate: 2e-4 37 | name: test_rnn_md_mhv 38 | score_to_win: 0.99 39 | grad_norm: 10.5 40 | entropy_coef: 0.005 41 | truncate_grads: True 42 | env_name: test_env 43 | e_clip: 0.2 44 | clip_value: False 45 | num_actors: 16 46 | horizon_length: 512 47 | minibatch_size: 2048 48 | mini_epochs: 4 49 | critic_coef: 1 50 | lr_schedule: None 51 | kl_threshold: 0.008 52 | normalize_input: False 53 | normalize_value: True 54 | seq_length: 16 55 | weight_decay: 0.0000 56 | max_epochs: 10000 57 | 58 | env_config: 59 | name: TestRnnEnv-v0 60 | hide_object: True 61 | apply_dist_reward: True 62 | min_dist: 2 63 | max_dist: 8 64 | use_central_value: False 65 | multi_discrete_space: True 66 | multi_head_value: True 67 | 
player: 68 | games_num: 100 69 | deterministic: True 70 | 71 | -------------------------------------------------------------------------------- /rl_games/envs/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from rl_games.envs.test_network import TestNetBuilder, TestNetAuxLossBuilder, SimpleNetBuilder 4 | from rl_games.algos_torch import model_builder 5 | 6 | model_builder.register_network('testnet', TestNetBuilder) 7 | model_builder.register_network('simplenet', SimpleNetBuilder) 8 | model_builder.register_network('testnet_aux_loss', TestNetAuxLossBuilder) -------------------------------------------------------------------------------- /rl_games/envs/test/__init__.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | gym.envs.register( 4 | id='TestRnnEnv-v0', 5 | entry_point='rl_games.envs.test.rnn_env:TestRNNEnv', 6 | max_episode_steps=100500, 7 | ) 8 | 9 | gym.envs.register( 10 | id='TestAsymmetricEnv-v0', 11 | entry_point='rl_games.envs.test.test_asymmetric_env:TestAsymmetricCritic' 12 | ) -------------------------------------------------------------------------------- /rl_games/envs/test/example_env.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | 5 | class ExampleEnv(gym.Env): 6 | ''' 7 | Just example empty env which demonstrates additional features compared to the default openai gym 8 | ''' 9 | def __init__(self, **kwargs): 10 | gym.Env.__init__(self) 11 | 12 | self.use_central_value = True 13 | self.value_size = 2 14 | self.concat_infos = False 15 | self.action_space = gym.spaces.Tuple([gym.spaces.Discrete(2),gym.spaces.Discrete(3)]) # gym.spaces.Discrete(3), gym.spaces.Box(low=0, high=1, shape=(3, ), dtype=np.float32) 16 | self.observation_space = gym.spaces.Box(low=0, high=1, shape=(6, ), dtype=np.float32) # or Dict 17 | 18 | def get_number_of_agents(self): 19 | return 1 20 | 21 | def has_action_mask(self): 22 | return False 23 | 24 | def get_action_mask(self): 25 | pass -------------------------------------------------------------------------------- /rl_games/interfaces/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/rl_games/interfaces/__init__.py -------------------------------------------------------------------------------- /rl_games/interfaces/base_algorithm.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | from abc import abstractmethod, abstractproperty 3 | 4 | 5 | class BaseAlgorithm(ABC): 6 | def __init__(self, base_name, config): 7 | pass 8 | 9 | @abstractproperty 10 | def device(self): 11 | pass 12 | 13 | @abstractmethod 14 | def clear_stats(self): 15 | pass 16 | 17 | @abstractmethod 18 | def train(self): 19 | pass 20 | 21 | @abstractmethod 22 | def train_epoch(self): 23 | pass 24 | 25 | @abstractmethod 26 | def get_full_state_weights(self): 27 | pass 28 | 29 | @abstractmethod 30 | def set_full_state_weights(self, weights, set_epoch): 31 | pass 32 | 33 | @abstractmethod 34 | def get_weights(self): 35 | pass 36 | 37 | @abstractmethod 38 | def set_weights(self, weights): 39 | pass 40 | 41 | # Get algo training parameters 42 | @abstractmethod 43 | def get_param(self, param_name): 44 | pass 45 | 46 | # Set algo training parameters 47 | @abstractmethod 48 | def set_param(self, 
param_name, param_value): 49 | pass 50 | 51 | 52 | -------------------------------------------------------------------------------- /rl_games/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from rl_games.networks.tcnn_mlp import TcnnNetBuilder 2 | from rl_games.algos_torch import model_builder 3 | 4 | model_builder.register_network('tcnnnet', TcnnNetBuilder) -------------------------------------------------------------------------------- /rl_games/networks/tcnn_mlp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class TcnnNet(nn.Module): 7 | def __init__(self, params, **kwargs): 8 | import tinycudann as tcnn 9 | nn.Module.__init__(self) 10 | self.actions_num = actions_num = kwargs.pop('actions_num') 11 | input_shape = kwargs.pop('input_shape') 12 | num_inputs = input_shape[0] 13 | self.central_value = params.get('central_value', False) 14 | self.sigma = torch.nn.Parameter(torch.zeros(actions_num, requires_grad=True, dtype=torch.float32), 15 | requires_grad=True) 16 | self.model = tcnn.NetworkWithInputEncoding(n_input_dims=num_inputs, n_output_dims=actions_num+1, 17 | encoding_config=params["encoding"], network_config=params["network"]) 18 | def is_rnn(self): 19 | return False 20 | 21 | def forward(self, obs): 22 | obs = obs['obs'] 23 | mu_val = self.model(obs) 24 | mu, value = torch.split(mu_val, [self.actions_num, 1], dim=1) 25 | return mu, mu * 0.0 + self.sigma, value, None 26 | 27 | 28 | from rl_games.algos_torch.network_builder import NetworkBuilder 29 | 30 | 31 | class TcnnNetBuilder(NetworkBuilder): 32 | def __init__(self, **kwargs): 33 | NetworkBuilder.__init__(self) 34 | 35 | def load(self, params): 36 | self.params = params 37 | 38 | def build(self, name, **kwargs): 39 | return TcnnNet(self.params, **kwargs) 40 | 41 | def __call__(self, name, **kwargs): 42 | return self.build(name, **kwargs) 43 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """Setup script for rl_games""" 2 | 3 | import sys 4 | import os 5 | import pathlib 6 | 7 | from setuptools import setup, find_packages 8 | # The directory containing this file 9 | HERE = pathlib.Path(__file__).parent 10 | 11 | # The text of the README file 12 | README = (HERE / "README.md").read_text() 13 | print(find_packages()) 14 | 15 | setup(name='rl-games', 16 | long_description=README, 17 | long_description_content_type="text/markdown", 18 | url="https://github.com/Denys88/rl_games", 19 | #packages=[package for package in find_packages() if package.startswith('rl_games')], 20 | packages = ['.','rl_games','docs'], 21 | package_data={'rl_games':['*','*/*','*/*/*'],'docs':['*','*/*','*/*/*'],}, 22 | version='1.6.1', 23 | author='Denys Makoviichuk, Viktor Makoviichuk', 24 | author_email='trrrrr97@gmail.com, victor.makoviychuk@gmail.com', 25 | license="MIT", 26 | classifiers=[ 27 | "License :: OSI Approved :: MIT License", 28 | "Programming Language :: Python :: 3", 29 | "Programming Language :: Python :: 3.7", 30 | "Programming Language :: Python :: 3.8", 31 | "Programming Language :: Python :: 3.9", 32 | "Programming Language :: Python :: 3.10" 33 | ], 34 | #packages=["rlg"], 35 | include_package_data=True, 36 | install_requires=[ 37 | # this setup is only for pytorch 38 | # 39 | 'gym>=0.17.2', 40 | 'torch>=1.7.0', 41 | 
'numpy>=1.16.0', 42 | 'tensorboard>=1.14.0', 43 | 'tensorboardX>=1.6', 44 | 'setproctitle', 45 | 'psutil', 46 | 'pyyaml', 47 | 'watchdog>=2.1.9,<3.0.0', # for evaluation process 48 | ], 49 | ) 50 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Denys88/rl_games/51ac9aa2981ba3204ea513104a1da46e6b5a39c9/tests/__init__.py -------------------------------------------------------------------------------- /tests/simple_test.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def test_true(): 5 | assert True --------------------------------------------------------------------------------
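The YAML files listed above are consumed by the rl_games runner. A minimal launch sketch follows; it assumes the standard rl_games.torch_runner.Runner entry point and an installed SMAC environment, and the config path is simply one of the files shown above:

import yaml

from rl_games.torch_runner import Runner

# Load one of the SMAC configs listed above and start training.
with open('rl_games/configs/smac/v1/runs/2c_vs_64zg.yaml', 'r') as f:
    config = yaml.safe_load(f)

runner = Runner()
runner.load(config)                          # builds algo/model/network from the 'params' section
runner.run({'train': True, 'play': False})   # roughly what "python runner.py --train --file <config>" does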
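The inline comments on minibatch_size in the SMAC configs above (for example "3072 # 6 * 512" for corridor and "1536 # 3 * 512" for 3s_vs_5z) reflect the number of controlled agents on each map times a base of 512. Assuming the usual rl_games requirement that the minibatch size evenly divide the per-update rollout batch of horizon_length * num_actors * agents, the listed values check out:

# Illustrative arithmetic only: rollout batch = horizon_length * num_actors * n_agents,
# and minibatch_size is chosen so that it divides this batch evenly.
def rollout_batch(horizon_length, num_actors, n_agents):
    return horizon_length * num_actors * n_agents

assert rollout_batch(128, 8, 6) % 3072 == 0   # corridor: 6 agents -> 6144
assert rollout_batch(128, 8, 3) % 1536 == 0   # 3s_vs_5z: 3 agents -> 3072
assert rollout_batch(128, 8, 2) % 1024 == 0   # 2c_vs_64zg: 2 agents -> 2048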
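In the SMACv2 env configs above, the *_epo variants differ from their base files only in the partial-observability settings flagged by the "most severe partial obs setting" comments: prob_obs_enemy drops from 1.0 to 0.0 and action_mask flips from True to False. The team_gen weights are parallel to unit_types and sum to 1.0:

# Illustrative check of the weighted_teams distributions restated from the configs above.
protoss_weights = {'stalker': 0.45, 'zealot': 0.45, 'colossus': 0.1}
terran_weights  = {'marine': 0.45, 'marauder': 0.45, 'medivac': 0.1}
zerg_weights    = {'zergling': 0.45, 'baneling': 0.1, 'hydralisk': 0.45}
for dist in (protoss_weights, terran_weights, zerg_weights):
    assert abs(sum(dist.values()) - 1.0) < 1e-9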
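The testnet, simplenet, testnet_aux_loss and tcnnnet builders above are made available through model_builder.register_network, and TcnnNetBuilder shows the minimal builder surface (load, build, __call__). Below is a rough sketch of a custom network registered the same way; MyNet, MyNetBuilder, the 'mynet' key and the 128-unit MLP are illustrative placeholders, not part of the repository, and the forward return mirrors the (mu, sigma, value, rnn_state) contract used by TcnnNet for continuous-action models:

import torch
from torch import nn

from rl_games.algos_torch import model_builder
from rl_games.algos_torch.network_builder import NetworkBuilder


class MyNet(nn.Module):
    # Plain-MLP stand-in for TcnnNet: same constructor kwargs, same outputs.
    def __init__(self, params, **kwargs):
        nn.Module.__init__(self)
        self.actions_num = actions_num = kwargs.pop('actions_num')
        input_shape = kwargs.pop('input_shape')
        self.central_value = params.get('central_value', False)
        self.sigma = nn.Parameter(torch.zeros(actions_num, dtype=torch.float32), requires_grad=True)
        self.model = nn.Sequential(
            nn.Linear(input_shape[0], 128),
            nn.ReLU(),
            nn.Linear(128, actions_num + 1),
        )

    def is_rnn(self):
        return False

    def forward(self, obs):
        obs = obs['obs']
        mu_val = self.model(obs)
        mu, value = torch.split(mu_val, [self.actions_num, 1], dim=1)
        return mu, mu * 0.0 + self.sigma, value, None


class MyNetBuilder(NetworkBuilder):
    def __init__(self, **kwargs):
        NetworkBuilder.__init__(self)

    def load(self, params):
        self.params = params

    def build(self, name, **kwargs):
        return MyNet(self.params, **kwargs)

    def __call__(self, name, **kwargs):
        return self.build(name, **kwargs)


# After this call a config can reference the builder via network: name: mynet
model_builder.register_network('mynet', MyNetBuilder)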