├── .dockerignore ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── documentation.yml │ ├── feature_request.yml │ ├── others.yml │ └── question.yml ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── style_check.yml │ └── unit_test.yml ├── .gitignore ├── CONTRIBUTING.md ├── Gallery.md ├── LICENSE ├── Makefile ├── Project.md ├── README.md ├── README_zh.md ├── conda ├── LICENSE.txt ├── conda_build_config.yaml └── meta.yaml ├── docker └── Dockerfile ├── docs ├── CONTRIBUTING_zh.md └── images │ ├── cartpole.png │ ├── chat.gif │ ├── drone.png │ ├── gridworld.jpg │ ├── gym-retro.jpg │ ├── mujoco.png │ ├── openrl_text.png │ ├── pong.png │ ├── qq.png │ ├── simple_spread_trained.gif │ ├── smac.png │ ├── snakes_1v1.gif │ ├── tic-tac-toe.jpeg │ └── train_ppo_cartpole.gif ├── examples ├── __init__.py ├── arena │ ├── README.md │ ├── evaluate_more_envs.py │ ├── run_arena.py │ └── test_reproducibility.py ├── atari │ ├── README.md │ ├── atari_ppo.yaml │ └── train_ppo.py ├── behavior_cloning │ ├── README.md │ ├── cartpole_bc.yaml │ ├── test_env.py │ └── train_bc.py ├── cartpole │ ├── README.md │ ├── a2c.yaml │ ├── callbacks.yaml │ ├── dqn_cartpole.yaml │ ├── dual_clip_ppo.yaml │ ├── ppo.yaml │ ├── train_a2c.py │ ├── train_dqn_beta.py │ └── train_ppo.py ├── crafter │ ├── README.md │ ├── crafter_ppo.yaml │ ├── render_crafter.py │ └── train_crafter.py ├── custom_env │ ├── README.md │ ├── gymnasium_env.py │ ├── openai_gym_env.py │ ├── pettingzoo_env.py │ ├── rock_paper_scissors.py │ └── train_and_test.py ├── ddpg │ ├── ddpg_pendulum.yaml │ └── train_ddpg_beta.py ├── dm_control │ ├── README.md │ ├── ppo.yaml │ └── train_ppo.py ├── envpool │ ├── README.md │ ├── envpool_wrappers.py │ ├── make_env.py │ └── train_ppo.py ├── gail │ ├── README.md │ ├── cartpole_gail.yaml │ ├── cartpole_gail_without_action.yaml │ ├── gen_data.py │ ├── gen_data_v1.py │ ├── test_dataset.py │ └── train_gail.py ├── gfootball │ └── README.md ├── gridworld │ ├── README.md │ ├── dqn_gridworld.yaml │ ├── train_dqn.py │ └── train_ppo.py ├── gym_pybullet_drones │ ├── README.md │ ├── ppo.yaml │ ├── test_env.py │ └── train_ppo.py ├── isaac │ ├── README.md │ ├── cfg │ │ ├── config.yaml │ │ └── task │ │ │ └── Cartpole.yaml │ ├── isaac2openrl.py │ └── train_ppo.py ├── mpe │ ├── README.md │ ├── mpe_jrpo.yaml │ ├── mpe_mat.yaml │ ├── mpe_ppo.yaml │ ├── mpe_vdn.yaml │ ├── train_mat.py │ ├── train_ppo.py │ └── train_vdn.py ├── mujoco │ ├── README.md │ └── train_ppo.py ├── nlp │ ├── README.md │ ├── chat.py │ ├── chat_6b.py │ ├── ds_config.json │ ├── eval_ds_config.json │ ├── nlp_ppo.yaml │ ├── nlp_ppo_ds.yaml │ └── train_ppo.py ├── retro │ ├── custom_registration.py │ ├── retro_env │ │ ├── __init__.py │ │ └── retro_convert.py │ └── train_retro.py ├── sac │ ├── README.md │ ├── ddpg.yaml │ ├── sac.yaml │ ├── train_ddpg.py │ └── train_sac_beta.py ├── sb3 │ ├── README.md │ ├── ppo.yaml │ ├── test_model.py │ └── train_ppo.py ├── selfplay │ ├── README.md │ ├── human_vs_agent.py │ ├── opponent_templates │ │ ├── random_opponent │ │ │ └── opponent.py │ │ └── tictactoe_opponent │ │ │ ├── info.json │ │ │ └── opponent.py │ ├── selfplay.yaml │ ├── test_env.py │ ├── tictactoe_utils │ │ ├── __init__.py │ │ ├── game.py │ │ ├── minmax.py │ │ └── tictactoe_render.py │ └── train_selfplay.py ├── smac │ ├── README.md │ ├── custom_vecinfo.py │ ├── smac_env │ │ ├── StarCraft2_Env.py │ │ ├── __init__.py │ │ ├── multiagentenv.py │ │ ├── smac_env.py │ │ └── smac_maps.py │ ├── smac_ppo.yaml │ └── train_ppo.py ├── smacv2 │ ├── custom_vecinfo.py │ ├── smac_env │ │ ├── 
StarCraft2_Env.py │ │ ├── __init__.py │ │ ├── distributions.py │ │ ├── multiagentenv.py │ │ ├── smac_env.py │ │ ├── smac_maps.py │ │ └── wrapper.py │ ├── smacv2_ppo.yaml │ └── train_ppo.py ├── snake │ ├── README.md │ ├── jidi_eval.py │ ├── jidi_random_vs_openrl_random.py │ ├── selfplay.yaml │ ├── submissions │ │ ├── random_agent │ │ │ └── submission.py │ │ ├── rl │ │ │ ├── README.md │ │ │ └── submission.py │ │ └── rule_v1 │ │ │ └── submission.py │ ├── test_env.py │ ├── train_selfplay.py │ └── wrappers.py ├── super_mario │ └── train_super_mario.py └── toy_env │ ├── README.md │ ├── train.yaml │ ├── train_and_eval.py │ ├── train_ddpg.py │ ├── train_dqn.py │ ├── train_ppo.py │ └── train_sac.py ├── openrl ├── __init__.py ├── algorithms │ ├── __init__.py │ ├── a2c.py │ ├── base_algorithm.py │ ├── behavior_cloning.py │ ├── ddpg.py │ ├── dqn.py │ ├── gail.py │ ├── mat.py │ ├── ppo.py │ ├── sac.py │ └── vdn.py ├── arena │ ├── __init__.py │ ├── agents │ │ ├── __init__.py │ │ ├── base_agent.py │ │ ├── jidi_agent.py │ │ ├── local_agent.py │ │ └── random_agent.py │ ├── base_arena.py │ ├── games │ │ ├── __init__.py │ │ ├── base_game.py │ │ └── two_player_game.py │ ├── two_player_arena.py │ └── utils.py ├── buffers │ ├── __init__.py │ ├── normal_buffer.py │ ├── offpolicy_buffer.py │ ├── offpolicy_replay_data.py │ ├── replay_data.py │ └── utils │ │ ├── __init__.py │ │ ├── obs_data.py │ │ └── util.py ├── cli │ ├── __init__.py │ ├── cli.py │ └── train.py ├── configs │ ├── __init__.py │ ├── config.py │ └── utils.py ├── datasets │ ├── __init__.py │ └── expert_dataset.py ├── drivers │ ├── __init__.py │ ├── base_driver.py │ ├── offline_driver.py │ ├── offpolicy_driver.py │ ├── onpolicy_driver.py │ └── rl_driver.py ├── envs │ ├── PettingZoo │ │ ├── __init__.py │ │ └── registration.py │ ├── __init__.py │ ├── common │ │ ├── __init__.py │ │ ├── build_envs.py │ │ └── registration.py │ ├── connect_env │ │ ├── __init__.py │ │ ├── base_connect_env.py │ │ ├── connect3_env.py │ │ ├── connect4_env.py │ │ └── utils.py │ ├── crafter │ │ ├── __init__.py │ │ └── crafter.py │ ├── gridworld │ │ ├── __init__.py │ │ └── gridworld_env.py │ ├── gym_pybullet_drones │ │ └── __init__.py │ ├── gymnasium │ │ └── __init__.py │ ├── mpe │ │ ├── __init__.py │ │ ├── core.py │ │ ├── mpe_env.py │ │ ├── multi_discrete.py │ │ ├── multiagent_env.py │ │ ├── rendering.py │ │ ├── scenario.py │ │ └── scenarios │ │ │ ├── __init__.py │ │ │ └── simple_spread.py │ ├── nlp │ │ ├── __init__.py │ │ ├── daily_dialog_env.py │ │ ├── fake_dialog_env.py │ │ ├── nlp_env.py │ │ ├── rewards │ │ │ ├── intent.py │ │ │ ├── kl_penalty.py │ │ │ └── meteor.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── custom_text_generation_pools.py │ │ │ ├── distribution.py │ │ │ ├── evaluation_utils.py │ │ │ ├── metrics │ │ │ ├── __init__.py │ │ │ └── meteor.py │ │ │ ├── observation.py │ │ │ ├── sampler.py │ │ │ └── text_generation_pool.py │ ├── offline │ │ ├── __init__.py │ │ └── offline_env.py │ ├── snake │ │ ├── __init__.py │ │ ├── discrete.py │ │ ├── game.py │ │ ├── gridgame.py │ │ ├── observation.py │ │ ├── snake.py │ │ ├── snake_pettingzoo.py │ │ └── space.py │ ├── super_mario │ │ ├── __init__.py │ │ └── super_mario_convert.py │ ├── toy_envs │ │ ├── __init__.py │ │ ├── bit_flipping_env.py │ │ └── identity_env.py │ ├── vec_env │ │ ├── __init__.py │ │ ├── async_venv.py │ │ ├── base_venv.py │ │ ├── sync_venv.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── numpy_utils.py │ │ │ ├── share_memory.py │ │ │ └── util.py │ │ ├── vec_info │ │ │ ├── __init__.py │ │ │ ├── base_vec_info.py 
│ │ │ ├── episode_rewards_info.py │ │ │ ├── nlp_vec_info.py │ │ │ └── simple_vec_info.py │ │ └── wrappers │ │ │ ├── __init__.py │ │ │ ├── base_wrapper.py │ │ │ ├── gen_data.py │ │ │ ├── reward_wrapper.py │ │ │ ├── vec_monitor_wrapper.py │ │ │ └── zero_reward_wrapper.py │ └── wrappers │ │ ├── __init__.py │ │ ├── atari_wrappers.py │ │ ├── base_wrapper.py │ │ ├── extra_wrappers.py │ │ ├── flatten.py │ │ ├── image_wrappers.py │ │ ├── mat_wrapper.py │ │ ├── monitor.py │ │ ├── multiagent_wrapper.py │ │ ├── pettingzoo_wrappers.py │ │ └── util.py ├── modules │ ├── __init__.py │ ├── base_module.py │ ├── bc_module.py │ ├── common │ │ ├── __init__.py │ │ ├── a2c_net.py │ │ ├── base_net.py │ │ ├── bc_net.py │ │ ├── ddpg_net.py │ │ ├── dqn_net.py │ │ ├── gail_net.py │ │ ├── mat_net.py │ │ ├── ppo_net.py │ │ ├── sac_net.py │ │ └── vdn_net.py │ ├── ddpg_module.py │ ├── dqn_module.py │ ├── gail_module.py │ ├── model_config.py │ ├── networks │ │ ├── MAT_network.py │ │ ├── __init__.py │ │ ├── base_policy_network.py │ │ ├── base_value_network.py │ │ ├── base_value_policy_network.py │ │ ├── ddpg_network.py │ │ ├── gail_discriminator.py │ │ ├── policy_network.py │ │ ├── policy_network_gpt.py │ │ ├── policy_value_network.py │ │ ├── policy_value_network_gpt.py │ │ ├── policy_value_network_sb3.py │ │ ├── q_network.py │ │ ├── sac_network.py │ │ ├── utils │ │ │ ├── __init__.py │ │ │ ├── act.py │ │ │ ├── attention.py │ │ │ ├── cnn.py │ │ │ ├── distributed_utils.py │ │ │ ├── distributions.py │ │ │ ├── mix.py │ │ │ ├── mlp.py │ │ │ ├── nlp │ │ │ │ ├── __init__.py │ │ │ │ ├── base_policy.py │ │ │ │ └── causal_policy.py │ │ │ ├── popart.py │ │ │ ├── rnn.py │ │ │ ├── running_mean_std.py │ │ │ ├── transformer_act.py │ │ │ ├── util.py │ │ │ └── vdn.py │ │ ├── value_network.py │ │ ├── value_network_gpt.py │ │ └── vdn_network.py │ ├── ppo_module.py │ ├── rl_module.py │ ├── sac_module.py │ ├── utils │ │ ├── __init__.py │ │ ├── util.py │ │ └── valuenorm.py │ └── vdn_module.py ├── rewards │ ├── __init__.py │ ├── base_reward.py │ ├── gail_reward.py │ └── nlp_reward.py ├── runners │ ├── __init__.py │ └── common │ │ ├── __init__.py │ │ ├── a2c_agent.py │ │ ├── base_agent.py │ │ ├── bc_agent.py │ │ ├── chat_agent.py │ │ ├── ddpg_agent.py │ │ ├── dqn_agent.py │ │ ├── gail_agent.py │ │ ├── mat_agent.py │ │ ├── ppo_agent.py │ │ ├── rl_agent.py │ │ ├── sac_agent.py │ │ └── vdn_agent.py ├── selfplay │ ├── __init__.py │ ├── callbacks │ │ ├── __init__.py │ │ ├── base_callback.py │ │ ├── selfplay_api.py │ │ └── selfplay_callback.py │ ├── multiplayer_env.py │ ├── opponents │ │ ├── __init__.py │ │ ├── base_opponent.py │ │ ├── jidi_opponent.py │ │ ├── network_opponent.py │ │ ├── opponent_env.py │ │ ├── opponent_template.py │ │ ├── random_opponent.py │ │ └── utils.py │ ├── sample_strategy │ │ ├── __init__.py │ │ ├── base_sample_strategy.py │ │ ├── last_opponent.py │ │ └── random_opponent.py │ ├── selfplay_api │ │ ├── __init__.py │ │ ├── base_api.py │ │ ├── opponent_model.py │ │ ├── selfplay_api.py │ │ └── selfplay_client.py │ └── wrappers │ │ ├── __init__.py │ │ ├── base_multiplayer_wrapper.py │ │ ├── human_opponent_wrapper.py │ │ ├── opponent_pool_wrapper.py │ │ └── random_opponent_wrapper.py ├── supports │ ├── __init__.py │ ├── opendata │ │ ├── __init__.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ └── opendata_utils.py │ └── opengpu │ │ ├── __init__.py │ │ ├── gpu_info.py │ │ └── manager.py └── utils │ ├── __init__.py │ ├── callbacks │ ├── __init__.py │ ├── callbacks.py │ ├── callbacks_factory.py │ ├── checkpoint_callback.py │ ├── 
eval_callback.py │ ├── processbar_callback.py │ └── stop_callback.py │ ├── custom_data_structure.py │ ├── evaluation.py │ ├── file_tool.py │ ├── logger.py │ ├── type_aliases.py │ └── util.py ├── pytest.ini ├── scripts ├── build_docker.sh ├── conda_build.sh ├── conda_upload.sh ├── gen_api_docs.sh ├── modify_api_docs.py ├── pypi_build.sh ├── pypi_upload.sh └── unittest.sh ├── setup.py └── tests ├── project └── test_version.py ├── test_algorithm ├── test_a2c_algorithm.py ├── test_bc_algorithm.py ├── test_ddpg_algorithm.py ├── test_dqn_algorithm.py ├── test_mat_algorithm.py ├── test_ppo_algorithm.py ├── test_sac_algorithm.py └── test_vdn_algorithm.py ├── test_arena ├── test_new_envs.py └── test_reproducibility.py ├── test_buffer ├── test_buffer.py ├── test_generator.py └── test_offpolicy_generator.py ├── test_callbacks └── test_callbacks.py ├── test_cli └── test_cli.py ├── test_dataset └── test_expert_dataset.py ├── test_env ├── test_connect_env.py ├── test_gridworld_env.py ├── test_mpe_env.py ├── test_nlp │ └── test_DailyDialogEnv.py ├── test_offline_env.py ├── test_snake_env.py ├── test_super_mario_env.py ├── test_vec_env │ ├── test_async_env.py │ ├── test_sync_env.py │ └── test_vec_wrappers.py └── test_wrappers.py ├── test_examples ├── test_nlp.py ├── test_train_atari.py ├── test_train_cartpole.py ├── test_train_gail.py ├── test_train_mpe.py ├── test_train_mujoco.py └── test_train_super_mario.py ├── test_modules ├── test_common │ ├── test_ddpg_net.py │ ├── test_dqn_net.py │ ├── test_sac_net.py │ └── test_vdn_net.py └── test_networks │ ├── test_MAT_network.py │ ├── test_attention.py │ └── test_policy_value_network_gpt.py ├── test_rewards └── test_nlp_reward.py ├── test_selfplay └── test_train_selfplay.py └── test_supports ├── test_opendata └── test_opendata.py └── test_opengpu ├── test_gpuinfo.py └── test_manager.py /.dockerignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | logs 3 | .pytest_cache/ 4 | .coverage 5 | .coverage.* 6 | .idea/ 7 | logs/ 8 | .pytype/ 9 | htmlcov/ 10 | .vscode/ 11 | .git/ 12 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F4DA Documentation" 2 | description: Report an issue related to OpenRL documentation 3 | labels: ["documentation"] 4 | body: 5 | - type: textarea 6 | id: description 7 | attributes: 8 | label: 📚 Documentation 9 | description: A clear and concise description of what should be improved in the documentation. 10 | validations: 11 | required: true 12 | - type: checkboxes 13 | id: terms 14 | attributes: 15 | label: Checklist 16 | options: 17 | - label: I have checked that there is no similar [issues](https://github.com/OpenRL-Lab/openrl/issues) in the repo 18 | required: true 19 | - label: I have read the [documentation](https://openrl-docs.readthedocs.io/) 20 | required: true 21 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F680 Feature Request" 2 | description: How to create an issue for requesting a feature 3 | title: "[Feature Request] request title" 4 | labels: ["enhancement"] 5 | body: 6 | - type: textarea 7 | id: description 8 | attributes: 9 | label: 🚀 Feature 10 | description: A clear and concise description of the feature proposal. 
11 | validations: 12 | required: true 13 | - type: textarea 14 | id: motivation 15 | attributes: 16 | label: Motivation 17 | description: Please outline the motivation for the proposal. Is your feature request related to a problem? e.g.,"I'm always frustrated when [...]". If this is related to another GitHub issue, please link here too. 18 | - type: textarea 19 | id: additional-context 20 | attributes: 21 | label: Additional context 22 | description: Add any other context or screenshots about the feature request here. 23 | - type: checkboxes 24 | id: terms 25 | attributes: 26 | label: Checklist 27 | options: 28 | - label: I have checked that there is no similar [issues](https://github.com/OpenRL-Lab/openrl/issues) in the repo 29 | required: true 30 | - label: I have read the [documentation](https://openrl-docs.readthedocs.io/) 31 | required: true 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/others.yml: -------------------------------------------------------------------------------- 1 | name: "❓ Other Questions" 2 | description: You can ask any question here! 3 | title: "[Question] question title" 4 | labels: ["question"] 5 | body: 6 | - type: textarea 7 | id: question 8 | attributes: 9 | label: ❓ Question 10 | description: Any question here! 11 | validations: 12 | required: true 13 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/question.yml: -------------------------------------------------------------------------------- 1 | name: "❓ Question" 2 | description: How to ask a question regarding OpenRL 3 | title: "[Question] question title" 4 | labels: ["question"] 5 | body: 6 | - type: textarea 7 | id: question 8 | attributes: 9 | label: ❓ Question 10 | description: Your question. This can be e.g. questions regarding confusing or unclear behaviour of functions or a question if X can be done using OpenRL. Make sure to check out the documentation first. 11 | validations: 12 | required: true 13 | - type: checkboxes 14 | id: terms 15 | attributes: 16 | label: Checklist 17 | options: 18 | - label: I have checked that there is no similar [issues](https://github.com/OpenRL-Lab/openrl/issues) in the repo 19 | required: true 20 | - label: I have read the [documentation](https://openrl-docs.readthedocs.io/) 21 | required: true 22 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Description 2 | 3 | 4 | ## Types of changes 5 | 6 | - [ ] Bug fix (non-breaking change which fixes an issue) 7 | - [ ] New feature (non-breaking change which adds functionality) 8 | 9 | ## Checklist 10 | 11 | - [ ] I have ensured `make test` pass (**required**). 12 | - [ ] I have checked the code using `make format` (**required**). -------------------------------------------------------------------------------- /.github/workflows/style_check.yml: -------------------------------------------------------------------------------- 1 | # This workflow will check flake style 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: style_check 5 | 6 | on: [push, pull_request] 7 | 8 | jobs: 9 | style_check: 10 | # Skip CI if [ci skip] in the commit message 11 | if: "! 
contains(toJSON(github.event.commits.*.message), '[ci skip]')" 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: [3.8] 16 | 17 | steps: 18 | - uses: actions/checkout@v3 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v4 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | - name: install test packages 24 | run: | 25 | python -m pip install --upgrade pip 26 | python -m pip install "ruff" 27 | 28 | - name: Lint with ruff 29 | run: | 30 | make lint 31 | -------------------------------------------------------------------------------- /.github/workflows/unit_test.yml: -------------------------------------------------------------------------------- 1 | name: unit_test 2 | 3 | on: [ pull_request ] 4 | 5 | jobs: 6 | test_unittest: 7 | runs-on: ubuntu-latest 8 | # Skip CI if [ci skip] in the commit message 9 | if: "! contains(toJSON(github.event.commits.*.message), '[ci skip]')" 10 | strategy: 11 | matrix: 12 | python-version: [ 3.8, 3.11 ] 13 | 14 | steps: 15 | - uses: actions/checkout@v3 16 | - name: Set up Python ${{ matrix.python-version }} 17 | uses: actions/setup-python@v4 18 | with: 19 | python-version: ${{ matrix.python-version }} 20 | - name: Install system dependencies 21 | run: | 22 | sudo apt-get update 23 | sudo apt-get install -y xvfb libglu1-mesa-dev python3-opengl 24 | - name: Upgrade pip 25 | run: | 26 | python -m pip install --upgrade pip setuptools wheel 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install . 30 | python -m pip install ".[test]" --upgrade 31 | - name: do_unittest 32 | timeout-minutes: 40 33 | run: | 34 | xvfb-run -s "-screen 0 1400x900x24" python3 -m pytest tests --cov=openrl --cov-report=xml -m unittest --cov-report=term-missing --durations=0 -v --color=yes -s 35 | - name: Upload coverage reports to Codecov with GitHub Action 36 | uses: codecov/codecov-action@v3 37 | with: 38 | token: ${{ secrets.CODECOV_TOKEN }} 39 | file: ./coverage.xml 40 | flags: unittests 41 | name: codecov-umbrella 42 | fail_ci_if_error: false 43 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | SHELL=/bin/bash 2 | PROJECT_NAME=openrl 3 | PROJECT_PATH=${PROJECT_NAME}/ 4 | PYTHON_FILES = $(shell find setup.py ${PROJECT_NAME} tests examples -type f -name "*.py") 5 | 6 | check_install = python3 -c "import $(1)" || pip3 install $(1) --upgrade 7 | check_install_extra = python3 -c "import $(1)" || pip3 install $(2) --upgrade 8 | 9 | test: 10 | ./scripts/unittest.sh 11 | 12 | lint: 13 | $(call check_install, ruff) 14 | ruff ${PYTHON_FILES} --select=E9,F63,F7,F82 --show-source 15 | ruff ${PYTHON_FILES} --exit-zero | grep -v '501\|405\|401\|402\|403\|722' 16 | 17 | format: 18 | $(call check_install, isort) 19 | $(call check_install, black) 20 | # Sort imports 21 | isort ${PYTHON_FILES} 22 | # Reformat using black 23 | black ${PYTHON_FILES} --preview 24 | # do format agent 25 | isort ${PYTHON_FILES} 26 | black ${PYTHON_FILES} --preview 27 | 28 | commit-checks: format lint 29 | 30 | docker-cpu: 31 | RELEASE=True ./scripts/build_docker.sh 32 | 33 | docker-gpu: 34 | RELEASE=True USE_GPU=True ./scripts/build_docker.sh 35 | 36 | pypi: 37 | ./scripts/pypi_build.sh 38 | 39 | pypi-test-upload: 40 | ./scripts/pypi_upload.sh test 41 | 42 | pypi-upload: 43 | ./scripts/pypi_upload.sh 44 | 45 | conda-build: 46 | ./scripts/conda_build.sh 47 | 48 | conda-upload: 49 | 
./scripts/conda_upload.sh 50 | 51 | doc: 52 | ./scripts/gen_api_docs.sh 53 | 54 | upload-codecov: 55 | codecov --file coverage.xml -t $(CODECOV_TOKEN) -------------------------------------------------------------------------------- /conda/conda_build_config.yaml: -------------------------------------------------------------------------------- 1 | python: 2 | - 3.8 3 | -------------------------------------------------------------------------------- /conda/meta.yaml: -------------------------------------------------------------------------------- 1 | {% set data = load_setup_py_data(setup_file='../setup.py', from_recipe_dir=True) %} 2 | package: 3 | name: openrl 4 | version: {{ data.get('version') }} 5 | 6 | source: 7 | path: .. 8 | 9 | build: 10 | number: 0 11 | script: python -m pip install . -vv 12 | entry_points: 13 | - openrl = openrl.cli.cli:run 14 | 15 | requirements: 16 | build: 17 | - python 18 | - setuptools 19 | run: 20 | - python 21 | 22 | test: 23 | imports: 24 | - openrl 25 | 26 | about: 27 | home: https://github.com/OpenRL-Lab/openrl 28 | license: Apache-2.0 29 | license_file: LICENSE.txt 30 | summary: OpenRL is a reinforcement learning framework (https://github.com/OpenRL-Lab/openrl). 31 | description: Please refer to https://openrl-docs.readthedocs.io/en/latest/ 32 | dev_url: https://github.com/OpenRL-Lab/openrl 33 | doc_url: Please refer to https://openrl-docs.readthedocs.io/en/latest/ 34 | doc_source_url: https://github.com/OpenRL-Lab/openrl-docs/ 35 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | ARG PARENT_IMAGE 2 | FROM $PARENT_IMAGE 3 | 4 | WORKDIR /openrl 5 | 6 | ADD setup.py setup.py 7 | ADD openrl openrl 8 | ADD README.md README.md 9 | 10 | ENV VENV /root/venv 11 | 12 | RUN \ 13 | pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple 14 | 15 | RUN \ 16 | python3 -m pip install --upgrade pip --no-cache-dir && \ 17 | python3 -m pip install --no-cache-dir . && \ 18 | python3 -m pip install --no-cache-dir ".[nlp]" && \ 19 | rm -rf $HOME/.cache/pip 20 | 21 | ENV PATH=$VENV/bin:$PATH 22 | 23 | CMD /bin/bash 24 | -------------------------------------------------------------------------------- /docs/CONTRIBUTING_zh.md: -------------------------------------------------------------------------------- 1 | ## 如何参与OpenRL的建设 2 | 3 | [English](../CONTRIBUTING.md) 4 | 5 | OpenRL社区欢迎任何人参与到OpenRL的建设中来,无论您是开发者还是用户,您的反馈和贡献都是我们前进的动力! 6 | 您可以通过以下方式加入到OpenRL的贡献中来: 7 | 8 | - 作为OpenRL的用户,发现OpenRL中的bug,并提交[issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose)。 9 | - 作为OpenRL的用户,发现OpenRL文档中的错误,并提交[issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose)。 10 | - 写测试代码,提升OpenRL的代码测试覆盖率(大家可以从[这里](https://app.codecov.io/gh/OpenRL-Lab/openrl)查到OpenRL的代码测试覆盖情况)。 11 | 您可以选择感兴趣的代码片段进行编写代码测试, 12 | - 作为OpenRL的开发者,为OpenRL修复已有的bug。 13 | - 作为OpenRL的开发者,为OpenRL添加新的环境和样例。 14 | - 作为OpenRL的开发者,为OpenRL添加新的算法。 15 | 16 | ## 贡献者手册 17 | 18 | 欢迎更多的人参与到OpenRL的开发中来,我们非常欢迎您的贡献! 
19 | 20 | - 如果您想要贡献新的功能,请先在请先创建一个新的[issue](https://github.com/OpenRL-Lab/openrl/issues/new/choose), 21 | 以便我们讨论这个功能的实现细节。如果该功能得到了大家的认可,您可以开始进行代码实现。 22 | - 您也可以在 [Issues](https://github.com/OpenRL-Lab/openrl/issues) 中查看未被实现的功能和仍然存的在bug, 23 | 在对应的issue中进行回复,说明您想要解决该issue,然后开始进行代码实现。 24 | 25 | 在您完成了代码实现之后,您需要拉取最新的`main`分支并进行合并。 26 | 解决合并冲突后, 27 | 您可以通过提交 [Pull Request](https://github.com/OpenRL-Lab/openrl/pulls) 28 | 的方式将您的代码合并到OpenRL的main分支中。 29 | 30 | 在提交Pull Request前,您需要完成 [代码测试和代码格式化](#代码测试和代码格式化)。 31 | 32 | 然后,您的Pull Request需要通过GitHub上的自动化测试。 33 | 34 | 最后,需要得到至少一个开发人员的review和批准,才能被合并到main分支中。 35 | 36 | ## 代码测试和代码格式化 37 | 38 | 在您提交Pull Request之前,您需要确保您的代码通过了单元测试,并且符合OpenRL的代码风格。 39 | 40 | 首先,您需要安装测试相关的包:`pip install -e ".[test]"` 41 | 42 | 然后,您需要确保单元测试通过,这可以通过执行`make test`来完成。 43 | 44 | 最后,您需要执行`make format`来格式化您的代码。 45 | 46 | > 小技巧: OpenRL使用 [black](https://github.com/psf/black) 代码风格。 47 | 您可以在您的编辑器中安装black的[插件](https://black.readthedocs.io/en/stable/integrations/editors.html), 48 | 来帮助您自动格式化代码。 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /docs/images/cartpole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRL-Lab/openrl/4c92aa440c12a6834c3fd76c574a341480868436/docs/images/cartpole.png -------------------------------------------------------------------------------- /docs/images/chat.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRL-Lab/openrl/4c92aa440c12a6834c3fd76c574a341480868436/docs/images/chat.gif -------------------------------------------------------------------------------- /docs/images/drone.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRL-Lab/openrl/4c92aa440c12a6834c3fd76c574a341480868436/docs/images/drone.png -------------------------------------------------------------------------------- /docs/images/gridworld.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRL-Lab/openrl/4c92aa440c12a6834c3fd76c574a341480868436/docs/images/gridworld.jpg -------------------------------------------------------------------------------- /docs/images/gym-retro.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRL-Lab/openrl/4c92aa440c12a6834c3fd76c574a341480868436/docs/images/gym-retro.jpg -------------------------------------------------------------------------------- /docs/images/mujoco.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRL-Lab/openrl/4c92aa440c12a6834c3fd76c574a341480868436/docs/images/mujoco.png -------------------------------------------------------------------------------- /docs/images/openrl_text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRL-Lab/openrl/4c92aa440c12a6834c3fd76c574a341480868436/docs/images/openrl_text.png -------------------------------------------------------------------------------- /docs/images/pong.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRL-Lab/openrl/4c92aa440c12a6834c3fd76c574a341480868436/docs/images/pong.png -------------------------------------------------------------------------------- 
/docs/images/qq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRL-Lab/openrl/4c92aa440c12a6834c3fd76c574a341480868436/docs/images/qq.png -------------------------------------------------------------------------------- /docs/images/simple_spread_trained.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRL-Lab/openrl/4c92aa440c12a6834c3fd76c574a341480868436/docs/images/simple_spread_trained.gif -------------------------------------------------------------------------------- /docs/images/smac.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRL-Lab/openrl/4c92aa440c12a6834c3fd76c574a341480868436/docs/images/smac.png -------------------------------------------------------------------------------- /docs/images/snakes_1v1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRL-Lab/openrl/4c92aa440c12a6834c3fd76c574a341480868436/docs/images/snakes_1v1.gif -------------------------------------------------------------------------------- /docs/images/tic-tac-toe.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRL-Lab/openrl/4c92aa440c12a6834c3fd76c574a341480868436/docs/images/tic-tac-toe.jpeg -------------------------------------------------------------------------------- /docs/images/train_ppo_cartpole.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRL-Lab/openrl/4c92aa440c12a6834c3fd76c574a341480868436/docs/images/train_ppo_cartpole.gif -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /examples/arena/README.md: -------------------------------------------------------------------------------- 1 | 2 | ### Installation 3 | 4 | ```bash 5 | pip install "openrl[selfplay]" 6 | pip install "pettingzoo[mpe]" "pettingzoo[butterfly]" 7 | ``` 8 | 9 | ### Usage 10 | 11 | ```shell 12 | python run_arena.py 13 | ``` 14 | 15 | 16 | ### Evaluate Google Research Football submissions for JiDi locally 17 | 18 | If you want to evaluate your Google Research Football submissions for JiDi locally, please try to use tizero as illustrated [here](https://github.com/OpenRL-Lab/TiZero#evaluate-jidi-submissions-locally). 19 | 20 | ### Evaluate more environments 21 | 22 | We also provide a script to evaluate more environments, including MPE, Go, Texas Holdem, Butterfly.
You can run the script as follows: 23 | 24 | ```shell 25 | python evaluate_more_envs.py 26 | ``` -------------------------------------------------------------------------------- /examples/arena/test_reproducibility.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | from run_arena import run_arena 20 | 21 | 22 | def test_seed(seed: int): 23 | test_time = 5 24 | pre_result = None 25 | for parallel in [False, True]: 26 | for i in range(test_time): 27 | result = run_arena(seed=seed, parallel=parallel, total_games=20) 28 | if pre_result is not None: 29 | assert pre_result == result, f"parallel={parallel}, seed={seed}" 30 | pre_result = result 31 | 32 | 33 | if __name__ == "__main__": 34 | test_seed(0) 35 | -------------------------------------------------------------------------------- /examples/atari/README.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | `pip install "gymnasium[atari]"` 4 | 5 | Then install auto-rom via: 6 | `pip install "gymnasium[accept-rom-license]"` 7 | 8 | or: 9 | ```shell 10 | pip install autorom 11 | AutoROM --accept-license 12 | ``` 13 | 14 | or, if you can not download the ROMs, you can download them manually from [Google Drive](https://drive.google.com/file/d/1agerLX3fP2YqUCcAkMF7v_ZtABAOhlA7/view?usp=sharing). 15 | Then, you can install the ROMs via: 16 | ```shell 17 | pip install autorom 18 | AutoROM --source-file 19 | ```` 20 | 21 | 22 | ## Usage 23 | 24 | ```shell 25 | python train_ppo.py 26 | ``` -------------------------------------------------------------------------------- /examples/atari/atari_ppo.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | lr: 2.5e-4 3 | critic_lr: 2.5e-4 4 | episode_length: 128 5 | gamma: 0.99 6 | ppo_epoch: 3 7 | gain: 0.01 8 | use_linear_lr_decay: true 9 | use_share_model: true 10 | entropy_coef: 0.01 11 | hidden_size: 512 12 | num_mini_batch: 8 13 | clip_param: 0.2 14 | value_loss_coef: 0.5 15 | max_grad_norm: 10 16 | 17 | run_dir: ./run_results/ 18 | 19 | log_interval: 1 20 | use_recurrent_policy: false 21 | use_valuenorm: true 22 | use_adv_normalize: true 23 | 24 | wandb_entity: openrl-lab 25 | experiment_name: atari_ppo 26 | 27 | vec_info_class: 28 | id: "EPS_RewardInfo" -------------------------------------------------------------------------------- /examples/behavior_cloning/README.md: -------------------------------------------------------------------------------- 1 | ## Prepare Dataset 2 | 3 | Go to `examples/gail` folder. 4 | Run following command to generate dataset for behavior cloning: `python gen_data.py`, then you will get a file named `data.pkl` in current folder. 5 | Then copy the `data.pkl` to current folder. 
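
Before training, it can be worth sanity-checking the copied `data.pkl`. The snippet below is a minimal sketch that loads it with OpenRL's `ExpertDataset`, mirroring how `examples/gail/test_dataset.py` reads the same file; the batch size and the printed fields here are illustrative choices rather than part of the original example.

```python
import torch

from openrl.datasets.expert_dataset import ExpertDataset

# Load the expert trajectories generated by gen_data.py.
dataset = ExpertDataset(file_name="data.pkl", seed=0)
print("number of (observation, action) pairs:", len(dataset))

# Iterate once to confirm the batches are well formed before running train_bc.py.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset, batch_size=64, shuffle=False, drop_last=True
)
for expert_obs, expert_action in data_loader:
    print("obs batch:", expert_obs.shape, "action batch:", expert_action.shape)
    break
```
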
6 | 7 | ## Train 8 | 9 | Run following command to train behavior cloning: `python train_bc.py --config cartpole_bc.yaml` -------------------------------------------------------------------------------- /examples/behavior_cloning/cartpole_bc.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | expert_data: "./data.pkl" -------------------------------------------------------------------------------- /examples/behavior_cloning/test_env.py: -------------------------------------------------------------------------------- 1 | """""" 2 | 3 | import numpy as np 4 | 5 | from openrl.configs.config import create_config_parser 6 | from openrl.envs.common import make 7 | 8 | 9 | def test_env(): 10 | # create the neural network 11 | cfg_parser = create_config_parser() 12 | cfg = cfg_parser.parse_args() 13 | 14 | # create environment 15 | env = make("OfflineEnv", env_num=1, cfg=cfg, asynchronous=True) 16 | 17 | for ep_index in range(10): 18 | done = False 19 | step = 0 20 | env.reset() 21 | 22 | while not np.all(done): 23 | obs, reward, done, info = env.step(env.random_action()) 24 | 25 | step += 1 26 | print(ep_index, step) 27 | 28 | 29 | if __name__ == "__main__": 30 | test_env() 31 | -------------------------------------------------------------------------------- /examples/cartpole/README.md: -------------------------------------------------------------------------------- 1 | ## How to Use 2 | 3 | Users can train CartPole via: 4 | 5 | ```shell 6 | python train_ppo.py --config ppo.yaml 7 | ``` 8 | 9 | 10 | To train with [Dual-clip PPO](https://arxiv.org/abs/1912.09729): 11 | 12 | ```shell 13 | python train_ppo.py --config dual_clip_ppo.yaml 14 | ``` 15 | 16 | To train with [A2C](https://arxiv.org/abs/1602.01783) algorithm: 17 | 18 | ```shell 19 | python train_a2c.py 20 | ``` 21 | 22 | If you want to evaluate the agent during training and save the best model and save checkpoints, try to train with callbacks: 23 | 24 | ```shell 25 | python train_ppo.py --config callbacks.yaml 26 | ``` 27 | 28 | More details about callbacks can be found in [Callbacks](https://openrl-docs.readthedocs.io/en/latest/callbacks/index.html). 
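
If you prefer to see the training loop rather than only the commands above, the following is a minimal sketch of what `train_ppo.py --config ppo.yaml` roughly does with OpenRL's high-level API. It follows the same pattern as the other example scripts in this repository; the actual `train_ppo.py` may differ in details such as the number of parallel environments and the training budget.

```python
import numpy as np

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent

# Read hyper-parameters from the YAML config in this folder.
cfg_parser = create_config_parser()
cfg = cfg_parser.parse_args(["--config", "ppo.yaml"])

# Create parallel CartPole environments and a PPO agent, then train.
env = make("CartPole-v1", env_num=9)
agent = Agent(Net(env, cfg=cfg))
agent.train(total_time_steps=20000)
env.close()

# Quick evaluation with a single environment.
env = make("CartPole-v1", env_num=1)
agent.set_env(env)
obs, info = env.reset()
done = False
while not np.any(done):
    action, _ = agent.act(obs)
    obs, reward, done, info = env.step(action)
env.close()
```
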
-------------------------------------------------------------------------------- /examples/cartpole/a2c.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | run_dir: ./run_results/ 3 | wandb_entity: openrl-lab -------------------------------------------------------------------------------- /examples/cartpole/dqn_cartpole.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | lr: 7e-4 3 | gamma: 0.9 4 | episode_length: 2000 5 | epsilon_anneal_time: 20000 6 | mini_batch_size: 128 7 | train_interval: 50 8 | num_mini_batch: 50 9 | run_dir: ./run_results/ 10 | experiment_name: train_dqn 11 | log_interval: 50 12 | 13 | use_recurrent_policy: false 14 | use_joint_action_loss: false 15 | use_valuenorm: false 16 | use_adv_normalize: false 17 | wandb_entity: openrl-lab 18 | 19 | callbacks: 20 | - id: "CheckpointCallback" 21 | args: { 22 | "save_freq": 500, # how often to save the model 23 | "save_path": "./results/checkpoints/", # where to save the model 24 | "name_prefix": "ppo", # the prefix of the saved model 25 | "save_replay_buffer": True # not work yet 26 | } 27 | - id: "EvalCallback" 28 | args: { 29 | "eval_env": {"id": "CartPole-v1","env_num":1}, # how many envs to set up for evaluation 30 | "n_eval_episodes": 4, # how many episodes to run for each evaluation 31 | "eval_freq": 500, # how often to run evaluation 32 | "log_path": "./results/eval_log_path", # where to save the evaluation results 33 | "best_model_save_path": "./results/best_model/", # where to save the best model 34 | "deterministic": True, # whether to use deterministic action 35 | "render": False, # whether to render the env 36 | "asynchronous": True, # whether to run evaluation asynchronously 37 | } -------------------------------------------------------------------------------- /examples/cartpole/dual_clip_ppo.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | run_dir: ./run_results/ 3 | wandb_entity: openrl-lab 4 | dual_clip_ppo: true 5 | dual_clip_coeff: 3.0 -------------------------------------------------------------------------------- /examples/cartpole/ppo.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | run_dir: ./run_results/ 3 | wandb_entity: openrl-lab -------------------------------------------------------------------------------- /examples/cartpole/train_dqn_beta.py: -------------------------------------------------------------------------------- 1 | """""" 2 | 3 | import numpy as np 4 | 5 | from openrl.configs.config import create_config_parser 6 | from openrl.envs.common import make 7 | from openrl.modules.common import DQNNet as Net 8 | from openrl.runners.common import DQNAgent as Agent 9 | 10 | 11 | def train(): 12 | # 添加读取配置文件的代码 13 | cfg_parser = create_config_parser() 14 | cfg = cfg_parser.parse_args(["--config", "dqn_cartpole.yaml"]) 15 | 16 | # 创建 环境 17 | env = make("CartPole-v1", env_num=4) 18 | 19 | # 创建 神经网络 20 | net = Net(env, cfg=cfg) 21 | # 初始化训练器 22 | agent = Agent(net) 23 | # 开始训练 24 | 25 | agent.train(total_time_steps=20000) 26 | 27 | env.close() 28 | return agent 29 | 30 | 31 | def evaluation(agent): 32 | # 开始测试环境 33 | env = make("CartPole-v1", render_mode="group_human") 34 | agent.set_env(env) 35 | obs, info = env.reset() 36 | done = False 37 | step = 0 38 | while not np.any(done): 39 | # 智能体根据 observation 预测下一个动作 40 | action, _ = agent.act(obs) 41 | obs, r, done, info = env.step(action) 42 | 
step += 1 43 | print(f"{step}: reward:{np.mean(r)}") 44 | env.close() 45 | 46 | 47 | if __name__ == "__main__": 48 | agent = train() 49 | evaluation(agent) 50 | -------------------------------------------------------------------------------- /examples/crafter/README.md: -------------------------------------------------------------------------------- 1 | # Crafter 2 | 3 | ## Installation 4 | 5 | ```bash 6 | git clone git@github.com:danijar/crafter.git 7 | cd crafter && git fetch origin pull/25/head:latest_gym 8 | git checkout latest_gym 9 | pip install -e . 10 | ``` 11 | 12 | ## train 13 | 14 | ```bash 15 | python train_crafter.py --config crafter_ppo.yaml 16 | ``` 17 | ## render video 18 | 19 | ```bash 20 | python render_crafter.py --config crafter_ppo.yaml 21 | ``` 22 | 23 | ## render trajectory 24 | 25 | * go to `openrl/envs/crafter/crafter.py` 26 | * set `save_stats=True` 27 | 28 | ```python 29 | self.env = crafter.Recorder( 30 | self.env, "crafter_traj", 31 | save_stats=True, # set this to be True 32 | save_episode=False, 33 | save_video=False, 34 | ) 35 | ``` 36 | 37 | * run the following command 38 | 39 | ```bash 40 | python render_crafter.py --config crafter_ppo.yaml 41 | ``` 42 | 43 | * you can get the trajectory in `crafter_traj/stats.jsonl`. The following is an example of the stats file. 44 | 45 | ```json 46 | {"length": 143, "reward": 1.1, "achievement_collect_coal": 0, "achievement_collect_diamond": 0, "achievement_collect_drink": 15, "achievement_collect_iron": 0, "achievement_collect_sapling": 0, "achievement_collect_stone": 0, "achievement_collect_wood": 0, "achievement_defeat_skeleton": 0, "achievement_defeat_zombie": 0, "achievement_eat_cow": 0, "achievement_eat_plant": 0, "achievement_make_iron_pickaxe": 0, "achievement_make_iron_sword": 0, "achievement_make_stone_pickaxe": 0, "achievement_make_stone_sword": 0, "achievement_make_wood_pickaxe": 0, "achievement_make_wood_sword": 0, "achievement_place_furnace": 0, "achievement_place_plant": 0, "achievement_place_stone": 0, "achievement_place_table": 0, "achievement_wake_up": 3} 47 | ``` -------------------------------------------------------------------------------- /examples/crafter/crafter_ppo.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | lr: 7e-4 3 | critic_lr: 7e-4 4 | episode_length: 512 5 | ppo_epoch: 5 6 | data_chunk_length: 8 7 | num_mini_batch: 8 8 | run_dir: ./run_results/ 9 | experiment_name: train_crafter 10 | log_interval: 10 11 | use_recurrent_policy: true 12 | use_joint_action_loss: false 13 | use_valuenorm: true 14 | use_adv_normalize: true 15 | wandb_entity: cwz19 16 | -------------------------------------------------------------------------------- /examples/crafter/train_crafter.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License.
16 | 17 | """""" 18 | 19 | import numpy as np 20 | 21 | from openrl.configs.config import create_config_parser 22 | from openrl.envs.common import make 23 | from openrl.envs.wrappers import GIFWrapper 24 | from openrl.modules.common import PPONet as Net 25 | from openrl.runners.common import PPOAgent as Agent 26 | 27 | 28 | def train(): 29 | # create environment 30 | env = make("Crafter", env_num=32, asynchronous=True) 31 | # config 32 | cfg_parser = create_config_parser() 33 | cfg = cfg_parser.parse_args() 34 | # create the neural network 35 | net = Net(env, cfg=cfg, device="cuda") 36 | # initialize the trainer 37 | agent = Agent(net, use_wandb=True) 38 | # start training 39 | agent.train(total_time_steps=100000000) 40 | # save the trained model 41 | agent.save("crafter_agent-100M/") 42 | # close the environment 43 | env.close() 44 | return agent 45 | 46 | 47 | if __name__ == "__main__": 48 | agent = train() 49 | -------------------------------------------------------------------------------- /examples/custom_env/README.md: -------------------------------------------------------------------------------- 1 | # Integrate user-defined environments into OpenRL 2 | 3 | [[Tutorial](https://openrl-docs.readthedocs.io/en/latest/custom_env/index.html)] | [[中文教程](https://openrl-docs.readthedocs.io/zh/latest/custom_env/index.html)] 4 | 5 | Here, we provide several toy examples to show how to add user-defined environments into OpenRL. 6 | 7 | - `gymnasium_env.py`: a simple example to show how to create a Gymnasium environment and integrate it into OpenRL. 8 | - `openai_gym_env.py`: a simple example to show how to create a OpenAI Gym environment and integrate it into OpenRL. 9 | - `pettingzoo_env.py`: a simple example to show how to create a PettingZoo environment and integrate it into OpenRL. -------------------------------------------------------------------------------- /examples/custom_env/pettingzoo_env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | 20 | from rock_paper_scissors import RockPaperScissors 21 | from train_and_test import train_and_test 22 | 23 | from openrl.envs.common import make 24 | from openrl.envs.PettingZoo.registration import register 25 | from openrl.selfplay.wrappers.random_opponent_wrapper import RandomOpponentWrapper 26 | 27 | register("RockPaperScissors", RockPaperScissors) 28 | 29 | env = make( 30 | "RockPaperScissors", 31 | env_num=10, 32 | opponent_wrappers=[RandomOpponentWrapper], 33 | ) 34 | 35 | train_and_test(env) 36 | -------------------------------------------------------------------------------- /examples/custom_env/train_and_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | import numpy as np 20 | 21 | from openrl.modules.common.ppo_net import PPONet as Net 22 | from openrl.runners.common.ppo_agent import PPOAgent as Agent 23 | 24 | 25 | def train(env): 26 | agent = Agent(Net(env)) 27 | agent.train(5000) 28 | return agent 29 | 30 | 31 | def test(env, agent): 32 | obs, info = env.reset() 33 | done = False 34 | total_reward = 0 35 | total_step = 0 36 | while not np.any(done): 37 | action, _ = agent.act(obs) 38 | obs, r, done, info = env.step(action) 39 | total_step += 1 40 | total_reward += np.mean(r) 41 | print("total test step: ", total_step) 42 | print("total test reward: ", total_reward) 43 | 44 | 45 | def train_and_test(env): 46 | agent = train(env) 47 | test(env, agent) 48 | -------------------------------------------------------------------------------- /examples/ddpg/ddpg_pendulum.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | hidden_size: 32 3 | actor_lr: 0.001 4 | critic_lr: 0.002 5 | gamma: 0.9 6 | tau: 0.01 7 | var: 2 8 | episode_length: 200 9 | mini_batch_size: 64 10 | buffer_size: 10000 11 | train_interval: 10 12 | num_mini_batch: 50 13 | run_dir: ./run_results/ 14 | experiment_name: train_ddpg 15 | log_interval: 10 16 | 17 | use_recurrent_policy: false 18 | use_joint_action_loss: false 19 | use_valuenorm: false 20 | use_adv_normalize: false 21 | wandb_entity: openrl-lab -------------------------------------------------------------------------------- /examples/ddpg/train_ddpg_beta.py: -------------------------------------------------------------------------------- 1 | """""" 2 | 3 | import numpy as np 4 | 5 | from openrl.configs.config import create_config_parser 6 | 7 | # from openrl.envs.toy_envs import make 8 | from openrl.envs.common import make 9 | from openrl.modules.common import DDPGNet as Net 10 | from openrl.runners.common import DDPGAgent as Agent 11 | 12 | 13 | def train(): 14 | # 添加读取配置文件的代码 15 | cfg_parser = create_config_parser() 16 | cfg = cfg_parser.parse_args() 17 | 18 | # 创建 环境 19 | env = make("Pendulum-v1", env_num=5) 20 | # 创建 神经网络 21 | net = Net(env, cfg=cfg) 22 | # 初始化训练器 23 | agent = Agent(net) 24 | # 开始训练 25 | # agent.train(total_time_steps=100000) 26 | agent.train(total_time_steps=20000) 27 | env.close() 28 | return agent 29 | 30 | 31 | def evaluation(agent): 32 | # begin to test 33 | # Create an environment for testing and set the number of environments to interact with to 9. Set rendering mode to group_human. 34 | env = make("Pendulum-v1", render_mode="group_human", env_num=4, asynchronous=True) 35 | # The trained agent sets up the interactive environment it needs. 36 | agent.set_env(env) 37 | # Initialize the environment and get initial observations and environmental information. 38 | obs, info = env.reset() 39 | done = False 40 | step = 0 41 | while not np.any(done): 42 | # Based on environmental observation input, predict next action. 
43 | action = agent.act(obs) 44 | obs, r, done, info = env.step(action) 45 | step += 1 46 | if step % 50 == 0: 47 | print(f"{step}: reward:{np.mean(r)}") 48 | env.close() 49 | 50 | 51 | if __name__ == "__main__": 52 | agent = train() 53 | evaluation(agent) 54 | -------------------------------------------------------------------------------- /examples/dm_control/README.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | ```bash 3 | pip install "shimmy[dm-control]" 4 | ``` 5 | 6 | ## Usage 7 | ```bash 8 | python train_ppo.py 9 | ``` -------------------------------------------------------------------------------- /examples/dm_control/ppo.yaml: -------------------------------------------------------------------------------- 1 | episode_length: 25 2 | lr: 5e-4 3 | critic_lr: 5e-4 4 | gamma: 0.99 5 | ppo_epoch: 5 6 | use_valuenorm: true 7 | entropy_coef: 0.0 8 | hidden_size: 128 9 | layer_N: 4 10 | data_chunk_length: 1 -------------------------------------------------------------------------------- /examples/envpool/README.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | 4 | Install envpool with: 5 | 6 | ``` shell 7 | pip install envpool 8 | ``` 9 | 10 | Note 1: envpool only supports Linux operating system. 11 | 12 | ## Usage 13 | 14 | You can use `OpenRL` to train Cartpole (envpool) via: 15 | 16 | ``` shell 17 | PYTHON_PATH train_ppo.py 18 | ``` 19 | 20 | You can also add custom wrappers in `envpool_wrapper.py`. Currently we have `VecAdapter` and `VecMonitor` wrappers. -------------------------------------------------------------------------------- /examples/gail/README.md: -------------------------------------------------------------------------------- 1 | ## Prepare Dataset 2 | 3 | Run following command to generate dataset for GAIL: `python gen_data.py`, then you will get a file named `data.pkl` in current folder. 4 | 5 | ## Train 6 | 7 | Run following command to train GAIL: `python train_gail.py --config cartpole_gail.yaml` 8 | 9 | With GAIL, we can even train the agent without expert action! 10 | Run following command to train GAIL without expert action: `python train_gail.py --config cartpole_gail_without_action.yaml` -------------------------------------------------------------------------------- /examples/gail/cartpole_gail.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | expert_data: "./data.pkl" 3 | reward_class: 4 | id: "GAILReward" -------------------------------------------------------------------------------- /examples/gail/cartpole_gail_without_action.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | expert_data: "./data.pkl" 3 | gail_use_action: false 4 | reward_class: 5 | id: "GAILReward" -------------------------------------------------------------------------------- /examples/gail/gen_data_v1.py: -------------------------------------------------------------------------------- 1 | """ 2 | Used for generate offline data for GAIL. 
3 | """ 4 | 5 | from openrl.envs.common import make 6 | from openrl.envs.vec_env.wrappers.gen_data import GenDataWrapper_v1 as GenDataWrapper 7 | from openrl.envs.wrappers.monitor import Monitor 8 | from openrl.modules.common import PPONet as Net 9 | from openrl.runners.common import PPOAgent as Agent 10 | 11 | env_wrappers = [ 12 | Monitor, 13 | ] 14 | 15 | 16 | def gen_data(total_episode): 17 | # begin to test 18 | # Create an environment for testing and set the number of environments to interact with to 9. Set rendering mode to group_human. 19 | render_mode = None 20 | env = make( 21 | "CartPole-v1", 22 | render_mode=render_mode, 23 | env_num=9, 24 | asynchronous=True, 25 | env_wrappers=env_wrappers, 26 | ) 27 | 28 | agent = Agent(Net(env)) 29 | agent.load("ppo_agent") 30 | 31 | env = GenDataWrapper(env, data_save_path="data_v1.pkl", total_episode=total_episode) 32 | obs, info = env.reset() 33 | done = False 34 | while not done: 35 | # Based on environmental observation input, predict next action. 36 | action, _ = agent.act(obs, deterministic=True) 37 | obs, r, done, info = env.step(action) 38 | env.close() 39 | 40 | 41 | if __name__ == "__main__": 42 | gen_data(total_episode=50) 43 | -------------------------------------------------------------------------------- /examples/gail/test_dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | import torch 19 | 20 | from openrl.datasets.expert_dataset import ExpertDataset 21 | 22 | 23 | def test_dataset(): 24 | dataset = ExpertDataset(file_name="data.pkl", seed=0) 25 | print("data length:", len(dataset)) 26 | print("data[0]:", dataset[0][0]) 27 | print("data[1]:", dataset[1][0]) 28 | print("data[len(data)-1]:", dataset[len(dataset) - 1][0]) 29 | 30 | data_loader = torch.utils.data.DataLoader( 31 | dataset=dataset, batch_size=128, shuffle=False, drop_last=True 32 | ) 33 | for batch_data in data_loader: 34 | expert_obs, expert_action = batch_data 35 | 36 | 37 | if __name__ == "__main__": 38 | test_dataset() 39 | -------------------------------------------------------------------------------- /examples/gfootball/README.md: -------------------------------------------------------------------------------- 1 | This is the guidance for [Google Research Football](https://github.com/google-research/football). 2 | 3 | ### Installation 4 | 5 | - `pip install gfootball` 6 | - `pip install tizero` 7 | - test the installation by `python3 -m gfootball.play_game --action_set=full`. 8 | 9 | ### Evaluate JiDi submissions locally 10 | 11 | If you want to evaluate your JiDi submissions locally, please try to use tizero as illustrated [here](https://github.com/OpenRL-Lab/TiZero#evaluate-jidi-submissions-locally). 
12 | 13 | 14 | ### Convert dump file to video 15 | 16 | After the installation, you can use tizero to convert a dump file to a video file. 17 | The usage is `tizero dump2video <dump_file> <output_dir> --episode_length <length> --render_type <2d/3d>`. 18 | 19 | You can download an example dump file from [here](http://jidiai.cn/daily_6484285/daily_6484285.dump). 20 | And then execute `tizero dump2video daily_6484285.dump ./` in your terminal. By default, the episode length is 3000 and the render type is 2d. 21 | Wait a minute, and you will get a video file named `daily_6484285.avi` in your current directory. -------------------------------------------------------------------------------- /examples/gridworld/README.md: -------------------------------------------------------------------------------- 1 | ## How to Use 2 | 3 | Users can train Gridworld via: 4 | 5 | ```shell 6 | python train_dqn.py 7 | ``` 8 | 9 | Tip: Gridworld is a simple environment for quickly verifying custom algorithms. -------------------------------------------------------------------------------- /examples/gridworld/dqn_gridworld.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | lr: 7e-4 3 | gamma: 0.9 4 | episode_length: 100 5 | mini_batch_size: 32 6 | buffer_size: 5000 7 | train_interval: 10 8 | num_mini_batch: 50 9 | run_dir: ./run_results/ 10 | experiment_name: train_dqn 11 | log_interval: 10 12 | 13 | use_recurrent_policy: false 14 | use_joint_action_loss: false 15 | use_valuenorm: false 16 | use_adv_normalize: false 17 | wandb_entity: openrl-lab -------------------------------------------------------------------------------- /examples/gridworld/train_dqn.py: -------------------------------------------------------------------------------- 1 | """""" 2 | 3 | import numpy as np 4 | 5 | from openrl.configs.config import create_config_parser 6 | from openrl.envs.common import make 7 | from openrl.modules.common import DQNNet as Net 8 | from openrl.runners.common import DQNAgent as Agent 9 | 10 | 11 | def train(): 12 | # read the configuration file 13 | cfg_parser = create_config_parser() 14 | cfg = cfg_parser.parse_args(["--config", "dqn_gridworld.yaml"]) 15 | 16 | # create the environment 17 | env = make("GridWorldEnv", env_num=9) 18 | # create the neural network 19 | net = Net(env, cfg=cfg) 20 | # initialize the trainer 21 | agent = Agent(net) 22 | # start training 23 | agent.train(total_time_steps=10000) 24 | env.close() 25 | return agent 26 | 27 | 28 | def evaluation(agent): 29 | # create the test environment 30 | env = make("GridWorldEnv", env_num=1, asynchronous=True) 31 | agent.set_env(env) 32 | obs, info = env.reset() 33 | done = False 34 | step = 0 35 | while not np.any(done): 36 | # The agent predicts the next action based on the observation. 37 | action, _ = agent.act(obs) 38 | obs, r, done, info = env.step(action) 39 | step += 1 40 | print(f"{step}: reward:{np.mean(r)}") 41 | env.close() 42 | 43 | 44 | if __name__ == "__main__": 45 | agent = train() 46 | evaluation(agent) 47 | -------------------------------------------------------------------------------- /examples/gridworld/train_ppo.py: -------------------------------------------------------------------------------- 1 | """""" 2 | 3 | import numpy as np 4 | 5 | from openrl.configs.config import create_config_parser 6 | from openrl.envs.common import make 7 | from openrl.modules.common import PPONet as Net 8 | from openrl.runners.common import PPOAgent as Agent 9 | 10 | 11 | def train(): 12 | # create environment, set environment parallelism to 9 13 | env = make("GridWorldEnv", env_num=9) 14 | # create the neural network 15 | cfg_parser = create_config_parser() 16 | cfg =
cfg_parser.parse_args() 17 | net = Net( 18 | env, 19 | cfg=cfg, 20 | ) 21 | # initialize the trainer 22 | agent = Agent(net) 23 | # start training, set total number of training steps to 20000 24 | agent.train(total_time_steps=20000) 25 | env.close() 26 | return agent 27 | 28 | 29 | def evaluation(agent): 30 | # begin to test 31 | # Create an environment for testing and set the number of environments to interact with to 9. 32 | env = make("GridWorldEnv", env_num=9, asynchronous=True) 33 | # The trained agent sets up the interactive environment it needs. 34 | agent.set_env(env) 35 | # Initialize the environment and get initial observations and environmental information. 36 | obs, info = env.reset() 37 | done = False 38 | step = 0 39 | while not np.any(done): 40 | # Based on environmental observation input, predict next action. 41 | action, _ = agent.act(obs, deterministic=True) 42 | obs, r, done, info = env.step(action) 43 | step += 1 44 | if step % 50 == 0: 45 | print(f"{step}: reward:{np.mean(r)}") 46 | env.close() 47 | 48 | 49 | if __name__ == "__main__": 50 | agent = train() 51 | evaluation(agent) 52 | -------------------------------------------------------------------------------- /examples/gym_pybullet_drones/README.md: -------------------------------------------------------------------------------- 1 | 2 | ### Installation 3 | 4 | - Python >= 3.10 5 | - Follow the installation instructions of [gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones#installation). 6 | 7 | ### Train PPO 8 | 9 | ```shell 10 | python train_ppo.py 11 | ``` -------------------------------------------------------------------------------- /examples/gym_pybullet_drones/ppo.yaml: -------------------------------------------------------------------------------- 1 | episode_length: 500 2 | lr: 1e-3 3 | critic_lr: 1e-3 4 | gamma: 0.1 5 | ppo_epoch: 5 6 | use_valuenorm: true 7 | entropy_coef: 0.0 8 | hidden_size: 128 9 | layer_N: 4 10 | use_recurrent_policy: true -------------------------------------------------------------------------------- /examples/isaac/README.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | ### 1. Simulator 4 | 5 | Download [Nvidia Isaac Sim](https://developer.nvidia.com/isaac-sim) and install it. 6 | 7 | Note 1: If you download the container version of Nvidia Isaac Sim running in the cloud, the simulation interface cannot be visualized. 8 | 9 | Note 2: The latest version of Isaac Sim (2022.2.1) provides a built-in Python 3.7 environment that packages can use, similar to a system-level Python install. We recommend using this Python environment when running the Python scripts. 10 | 11 | ### 2. RL tasks 12 | Install [Omniverse Isaac Gym Reinforcement Learning Environments for Isaac Sim](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs). This repository contains reinforcement learning tasks that can be run with the latest release of Isaac Sim. 13 | 14 | Please make sure you follow the above [repo](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs) and successfully run the training script: 15 | 16 | ``` shell 17 | PYTHON_PATH scripts/rlgames_train.py task=Ant headless=True 18 | ``` 19 | 20 | ## Usage 21 | 22 | The `cfg` folder provides Cartpole task configs for Isaac Sim, following [Omniverse Isaac Gym Reinforcement Learning Environments for Isaac Sim](https://github.com/NVIDIA-Omniverse/OmniIsaacGymEnvs).
23 | 24 | You can use `OpenRL` to train Isaac Sim Cartpole via: 25 | 26 | ``` shell 27 | PYTHON_PATH train_ppo.py 28 | ``` -------------------------------------------------------------------------------- /examples/isaac/cfg/config.yaml: -------------------------------------------------------------------------------- 1 | 2 | # Task name - used to pick the class to load 3 | task_name: ${task.name} 4 | # experiment name. defaults to name of training config 5 | experiment: '' 6 | 7 | # if set to positive integer, overrides the default number of environments 8 | num_envs: '' 9 | 10 | # seed - set to -1 to choose random seed 11 | seed: 42 12 | # set to True for deterministic performance 13 | torch_deterministic: False 14 | 15 | # set the maximum number of learning iterations to train for. overrides default per-environment setting 16 | max_iterations: '' 17 | 18 | ## Device config 19 | physics_engine: 'physx' 20 | # whether to use cpu or gpu pipeline 21 | pipeline: 'gpu' 22 | # whether to use cpu or gpu physx 23 | sim_device: 'gpu' 24 | # used for gpu simulation only - device id for running sim and task if pipeline=gpu 25 | device_id: 0 26 | # device to run RL 27 | rl_device: 'cuda:0' 28 | # multi-GPU training 29 | multi_gpu: False 30 | 31 | ## PhysX arguments 32 | num_threads: 4 # Number of worker threads per scene used by PhysX - for CPU PhysX only. 33 | solver_type: 1 # 0: pgs, 1: tgs 34 | 35 | # RLGames Arguments 36 | # test - if set, run policy in inference mode (requires setting checkpoint to load) 37 | test: False 38 | # used to set checkpoint path 39 | checkpoint: '' 40 | 41 | # disables rendering 42 | headless: False 43 | # enables native livestream 44 | enable_livestream: False 45 | # timeout for MT script 46 | mt_timeout: 30 47 | 48 | wandb_activate: False 49 | wandb_group: '' 50 | wandb_name: Cartpole # default name 51 | wandb_entity: '' 52 | wandb_project: 'omniisaacgymenvs' 53 | 54 | # set default task and default training config based on task 55 | defaults: 56 | - task: Cartpole 57 | - hydra/job_logging: disabled 58 | 59 | # set the directory where the output files get saved 60 | hydra: 61 | output_subdir: null 62 | run: 63 | dir: . 
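The config above follows Hydra conventions (a `defaults:` list, `${task.name}` interpolation, and a disabled `hydra/job_logging` group). As a rough, minimal sketch of how such a file resolves (not a script from this repo; it assumes `hydra-core` >= 1.2 is installed and that it is run from `examples/isaac` so that `config_path="cfg"` points at the folder above), you can compose and inspect the config without launching Isaac Sim:

```python
# Minimal sketch (assumptions: hydra-core >= 1.2, run from examples/isaac).
from hydra import compose, initialize
from omegaconf import OmegaConf

with initialize(config_path="cfg", version_base=None):
    # Compose config.yaml together with the default task group (Cartpole)
    # and apply a couple of command-line-style overrides.
    cfg = compose(config_name="config", overrides=["headless=True", "seed=7"])
    print(OmegaConf.to_yaml(cfg))
```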
64 | 65 | -------------------------------------------------------------------------------- /examples/mpe/README.md: -------------------------------------------------------------------------------- 1 | ## How to Use 2 | 3 | Train MPE with [MAPPO](https://arxiv.org/abs/2103.01955) algorithm: 4 | 5 | ```shell 6 | python train_ppo.py --config mpe_ppo.yaml 7 | ``` 8 | 9 | Train MPE with [JRPO](https://arxiv.org/abs/2302.07515) algorithm: 10 | 11 | ```shell 12 | python train_ppo.py --config mpe_jrpo.yaml 13 | ``` 14 | 15 | 16 | Train MPE with [MAT](https://arxiv.org/abs/2205.14953) algorithm: 17 | 18 | ```shell 19 | python train_mat.py --config mpe_mat.yaml 20 | ``` 21 | 22 | 23 | Train MPE with [VDN](https://arxiv.org/abs/1706.05296) algorithm: 24 | 25 | ```shell 26 | python train_vdn.py --config mpe_vdn.yaml 27 | ``` -------------------------------------------------------------------------------- /examples/mpe/mpe_jrpo.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | lr: 7e-4 3 | critic_lr: 7e-4 4 | episode_length: 25 5 | run_dir: ./run_results/ 6 | experiment_name: train_mpe_jrpo 7 | log_interval: 10 8 | use_recurrent_policy: true 9 | use_joint_action_loss: true 10 | use_valuenorm: true 11 | use_adv_normalize: true 12 | wandb_entity: openrl-lab -------------------------------------------------------------------------------- /examples/mpe/mpe_mat.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | lr: 7e-4 3 | episode_length: 25 4 | run_dir: ./run_results/ 5 | experiment_name: train_mpe_mat 6 | log_interval: 10 7 | use_valuenorm: true 8 | use_adv_normalize: true 9 | wandb_entity: openrl-lab -------------------------------------------------------------------------------- /examples/mpe/mpe_ppo.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | lr: 7e-4 3 | critic_lr: 7e-4 4 | episode_length: 25 5 | run_dir: ./run_results/ 6 | experiment_name: train_mpe 7 | log_interval: 10 8 | use_recurrent_policy: true 9 | use_joint_action_loss: false 10 | use_valuenorm: true 11 | use_adv_normalize: true 12 | wandb_entity: openrl-lab -------------------------------------------------------------------------------- /examples/mpe/mpe_vdn.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | lr: 7e-4 3 | episode_length: 200 4 | num_mini_batch: 20 5 | run_dir: ./run_results/ 6 | experiment_name: train_mpe_vdn 7 | log_interval: 10 8 | use_valuenorm: true 9 | use_adv_normalize: true 10 | wandb_entity: openrl-lab 11 | callbacks: 12 | - id: "CheckpointCallback" 13 | args: { 14 | "save_freq": 500, # how often to save the model 15 | "save_path": "./results/checkpoints/", # where to save the model 16 | "name_prefix": "vdn", # the prefix of the saved model 17 | "save_replay_buffer": True # not work yet 18 | } 19 | - id: "EvalCallback" 20 | args: { 21 | "eval_env": {"id": "simple_spread","env_num":1}, # how many envs to set up for evaluation 22 | "n_eval_episodes": 4, # how many episodes to run for each evaluation 23 | "eval_freq": 500, # how often to run evaluation 24 | "log_path": "./results/eval_log_path", # where to save the evaluation results 25 | "best_model_save_path": "./results/best_model/", # where to save the best model 26 | "deterministic": True, # whether to use deterministic action 27 | "render": False, # whether to render the env 28 | "asynchronous": True, # whether to run evaluation asynchronously 29 | } 
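The `CheckpointCallback` above periodically writes snapshots under `./results/checkpoints/`, and `EvalCallback` keeps the best-scoring model in `./results/best_model/`. Below is a minimal sketch (not a script from this repo) of reloading one of those saved models for a quick rollout; the exact folder layout produced by the callbacks is an assumption here, so check `./results/` after training for the real path:

```python
# Minimal sketch: reload a model saved by the callbacks configured in
# mpe_vdn.yaml and roll it out on a single simple_spread environment.
# The model_dir below is an assumption; inspect ./results/ for the actual
# folders written on your machine.
import numpy as np

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.envs.wrappers.mat_wrapper import MATWrapper
from openrl.modules.common import VDNNet as Net
from openrl.runners.common import VDNAgent as Agent


def evaluate_saved_model(model_dir="./results/best_model/"):
    cfg = create_config_parser().parse_args(["--config", "mpe_vdn.yaml"])
    env = MATWrapper(make("simple_spread", env_num=1))
    agent = Agent(Net(env, cfg=cfg))
    agent.load(model_dir)  # assumed to accept the callback's save directory
    agent.set_env(env)
    obs, info = env.reset(seed=0)
    done = False
    total_reward = 0.0
    while not np.any(done):
        action, _ = agent.act(obs, deterministic=True)
        obs, r, done, info = env.step(action)
        total_reward += np.mean(r)
    print("episode reward:", total_reward)
    env.close()


if __name__ == "__main__":
    evaluate_saved_model()
```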
-------------------------------------------------------------------------------- /examples/mpe/train_mat.py: -------------------------------------------------------------------------------- 1 | """""" 2 | 3 | import numpy as np 4 | 5 | from openrl.configs.config import create_config_parser 6 | from openrl.envs.common import make 7 | from openrl.envs.wrappers.mat_wrapper import MATWrapper 8 | from openrl.modules.common import MATNet as Net 9 | from openrl.runners.common import MATAgent as Agent 10 | 11 | 12 | def train(): 13 | # create environment 14 | env_num = 100 15 | env = make( 16 | "simple_spread", 17 | env_num=env_num, 18 | asynchronous=True, 19 | ) 20 | env = MATWrapper(env) 21 | 22 | # create the neural network 23 | cfg_parser = create_config_parser() 24 | cfg = cfg_parser.parse_args() 25 | net = Net(env, cfg=cfg, device="cuda") 26 | 27 | # initialize the trainer 28 | agent = Agent(net, use_wandb=True) 29 | # start training 30 | agent.train(total_time_steps=5000000) 31 | env.close() 32 | agent.save("./mat_agent/") 33 | return agent 34 | 35 | 36 | def evaluation(agent): 37 | # render_model = "group_human" 38 | render_model = None 39 | env_num = 9 40 | env = make( 41 | "simple_spread", render_mode=render_model, env_num=env_num, asynchronous=False 42 | ) 43 | env = MATWrapper(env) 44 | agent.load("./mat_agent/") 45 | agent.set_env(env) 46 | obs, info = env.reset(seed=0) 47 | done = False 48 | step = 0 49 | total_reward = 0 50 | while not np.any(done): 51 | # Based on environmental observation input, predict next action. 52 | action, _ = agent.act(obs, deterministic=True) 53 | obs, r, done, info = env.step(action) 54 | step += 1 55 | total_reward += np.mean(r) 56 | print(f"total_reward: {total_reward}") 57 | env.close() 58 | 59 | 60 | if __name__ == "__main__": 61 | agent = train() 62 | evaluation(agent) 63 | -------------------------------------------------------------------------------- /examples/mpe/train_ppo.py: -------------------------------------------------------------------------------- 1 | """""" 2 | 3 | import numpy as np 4 | 5 | from openrl.configs.config import create_config_parser 6 | from openrl.envs.common import make 7 | from openrl.modules.common import PPONet as Net 8 | from openrl.runners.common import PPOAgent as Agent 9 | 10 | 11 | def train(): 12 | # create environment 13 | env_num = 100 14 | env = make( 15 | "simple_spread", 16 | env_num=env_num, 17 | asynchronous=True, 18 | ) 19 | # create the neural network 20 | cfg_parser = create_config_parser() 21 | cfg = cfg_parser.parse_args() 22 | net = Net(env, cfg=cfg, device="cuda") 23 | # initialize the trainer 24 | agent = Agent(net, use_wandb=True) 25 | # start training, set total number of training steps to 5000000 26 | agent.train(total_time_steps=5000000) 27 | env.close() 28 | agent.save("./ppo_agent/") 29 | return agent 30 | 31 | 32 | def evaluation(agent): 33 | render_model = "group_human" 34 | env_num = 9 35 | env = make( 36 | "simple_spread", render_mode=render_model, env_num=env_num, asynchronous=False 37 | ) 38 | agent.load("./ppo_agent/") 39 | agent.set_env(env) 40 | obs, info = env.reset(seed=0) 41 | done = False 42 | step = 0 43 | total_reward = 0 44 | while not np.any(done): 45 | # Based on environmental observation input, predict next action. 
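# Note: `deterministic=True` below asks the policy for its most likely action
# instead of sampling, which is the usual choice when evaluating a trained
# agent; `done` comes back with one entry per parallel environment, so the
# `np.any(done)` condition above ends the rollout as soon as any of the 9
# envs finishes its episode.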
46 | action, _ = agent.act(obs, deterministic=True) 47 | obs, r, done, info = env.step(action) 48 | step += 1 49 | total_reward += np.mean(r) 50 | print(f"total_reward: {total_reward}") 51 | env.close() 52 | 53 | 54 | if __name__ == "__main__": 55 | agent = train() 56 | evaluation(agent) 57 | -------------------------------------------------------------------------------- /examples/mpe/train_vdn.py: -------------------------------------------------------------------------------- 1 | """""" 2 | 3 | import numpy as np 4 | 5 | from openrl.configs.config import create_config_parser 6 | from openrl.envs.common import make 7 | from openrl.envs.wrappers.mat_wrapper import MATWrapper 8 | from openrl.modules.common import VDNNet as Net 9 | from openrl.runners.common import VDNAgent as Agent 10 | 11 | 12 | def train(): 13 | # create environment 14 | env_num = 100 15 | env = make( 16 | "simple_spread", 17 | env_num=env_num, 18 | asynchronous=True, 19 | ) 20 | env = MATWrapper(env) 21 | 22 | # create the neural network 23 | cfg_parser = create_config_parser() 24 | cfg = cfg_parser.parse_args() 25 | net = Net(env, cfg=cfg, device="cuda") 26 | 27 | # initialize the trainer 28 | agent = Agent(net, use_wandb=True) 29 | # start training 30 | agent.train(total_time_steps=5000000) 31 | env.close() 32 | agent.save("./vdn_agent/") 33 | return agent 34 | 35 | 36 | def evaluation(agent): 37 | # render_model = "group_human" 38 | render_model = None 39 | env_num = 9 40 | env = make( 41 | "simple_spread", render_mode=render_model, env_num=env_num, asynchronous=False 42 | ) 43 | env = MATWrapper(env) 44 | agent.load("./vdn_agent/") 45 | agent.set_env(env) 46 | obs, info = env.reset(seed=0) 47 | done = False 48 | step = 0 49 | total_reward = 0 50 | while not np.any(done): 51 | # Based on environmental observation input, predict next action. 
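# Note: VDN learns a joint Q-function as the sum of per-agent Q-values, so at
# evaluation time each agent simply acts greedily on its own value estimate,
# which is what `deterministic=True` requests below.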
52 | action, _ = agent.act(obs, deterministic=True) 53 | obs, r, done, info = env.step(action) 54 | step += 1 55 | total_reward += np.mean(r) 56 | print(f"total_reward: {total_reward}") 57 | env.close() 58 | 59 | 60 | if __name__ == "__main__": 61 | agent = train() 62 | evaluation(agent) 63 | -------------------------------------------------------------------------------- /examples/mujoco/README.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | `pip install mujoco` 4 | 5 | ## Usage 6 | 7 | ```shell 8 | python train_ppo.py 9 | ``` -------------------------------------------------------------------------------- /examples/nlp/README.md: -------------------------------------------------------------------------------- 1 | ## How to Use 2 | 3 | Users can train the dialog task via: 4 | 5 | ```shell 6 | python train_ppo.py --config nlp_ppo.yaml 7 | ``` 8 | 9 | Users can train the dialog task with deepspeed via: 10 | 11 | ```shell 12 | deepspeed train_ppo.py --config nlp_ppo_ds.yaml 13 | 14 | 15 | ``` 16 | 17 | After the training, users can chat with the agent via: 18 | 19 | ```shell 20 | python chat.py 21 | ``` 22 | 23 | 24 | ### Chat with other agents 25 | 26 | - Chat with [ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B): `python chat_6b.py` -------------------------------------------------------------------------------- /examples/nlp/chat.py: -------------------------------------------------------------------------------- 1 | from openrl.runners.common import ChatAgent as Agent 2 | 3 | 4 | def chat(): 5 | agent = Agent.load( 6 | "./ppo_agent", 7 | tokenizer="gpt2", 8 | ) 9 | history = [] 10 | print("Welcome to OpenRL!") 11 | while True: 12 | input_text = input("> User: ") 13 | if input_text in ["quit", "exit", "quit()", "exit()", "q"]: 14 | break 15 | elif input_text == "reset": 16 | history = [] 17 | print("Welcome to OpenRL!") 18 | continue 19 | response = agent.chat(input_text, history) 20 | print(f"> OpenRL Agent: {response}") 21 | history.append(input_text) 22 | history.append(response) 23 | 24 | 25 | if __name__ == "__main__": 26 | chat() 27 | -------------------------------------------------------------------------------- /examples/nlp/chat_6b.py: -------------------------------------------------------------------------------- 1 | from openrl.runners.common import Chat6BAgent as Agent 2 | 3 | 4 | def chat(): 5 | agent = Agent.load( 6 | "THUDM/chatglm-6b", 7 | device="cuda:0", 8 | ) 9 | history = [] 10 | print("Welcome to OpenRL!") 11 | while True: 12 | input_text = input("> User: ") 13 | if input_text in ["quit", "exit", "quit()", "exit()", "q"]: 14 | break 15 | elif input_text == "reset": 16 | history = [] 17 | print("Welcome to OpenRL!") 18 | continue 19 | response = agent.chat(input_text, history) 20 | print(f"> Agent: {response}") 21 | history.append(input_text) 22 | history.append(response) 23 | 24 | 25 | if __name__ == "__main__": 26 | chat() 27 | -------------------------------------------------------------------------------- /examples/nlp/ds_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": 32, 3 | "train_micro_batch_size_per_gpu": 16, 4 | "steps_per_print": 10, 5 | "zero_optimization": { 6 | "stage": 2 7 | }, 8 | "fp16": {"enabled": false, "loss_scale_window": 100} 9 | } -------------------------------------------------------------------------------- /examples/nlp/eval_ds_config.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": 32, 3 | "train_micro_batch_size_per_gpu": 16, 4 | "steps_per_print": 10, 5 | "zero_optimization": { 6 | "stage": 0, 7 | "offload_param": {"device": "cpu"} 8 | }, 9 | "fp16": {"enabled": false} 10 | } -------------------------------------------------------------------------------- /examples/nlp/nlp_ppo.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | lr: 1e-7 3 | critic_lr: 1e-7 4 | run_dir: ./run_results/ 5 | log_interval: 1 6 | use_valuenorm: true 7 | use_adv_normalize: true 8 | wandb_entity: "openrl-lab" 9 | ppo_epoch: 5 10 | episode_length: 128 11 | num_mini_batch: 20 12 | 13 | hidden_size: 1 14 | 15 | model_path: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog 16 | env: 17 | args: { 18 | 'tokenizer_path': 'gpt2', 19 | 'data_path': 'daily_dialog', 20 | } 21 | vec_info_class: 22 | id: "NLPVecInfo" 23 | reward_class: 24 | id: "NLPReward" 25 | args: { 26 | "ref_model": "rajkumarrrk/gpt2-fine-tuned-on-daily-dialog", 27 | "intent_model": "rajkumarrrk/roberta-daily-dialog-intent-classifier", 28 | } 29 | -------------------------------------------------------------------------------- /examples/nlp/nlp_ppo_ds.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | lr: 1e-7 3 | critic_lr: 1e-7 4 | run_dir: ./run_results/ 5 | log_interval: 1 6 | use_valuenorm: true 7 | use_adv_normalize: true 8 | wandb_entity: "openrl-lab" 9 | ppo_epoch: 5 10 | episode_length: 128 11 | num_mini_batch: 20 12 | 13 | hidden_size: 1 14 | 15 | use_deepspeed: true 16 | use_fp16: false 17 | use_offload: false 18 | deepspeed_config: ds_config.json 19 | 20 | model_path: rajkumarrrk/gpt2-fine-tuned-on-daily-dialog 21 | env: 22 | args: { 23 | 'tokenizer_path': 'gpt2', 24 | 'data_path': 'daily_dialog', 25 | } 26 | vec_info_class: 27 | id: "NLPVecInfo" 28 | reward_class: 29 | id: "NLPReward" 30 | args: { 31 | "use_deepspeed": true, 32 | "ref_ds_config": "eval_ds_config.json", 33 | "ref_model": "rajkumarrrk/gpt2-fine-tuned-on-daily-dialog", 34 | "intent_ds_config": "eval_ds_config.json", 35 | "intent_model": "rajkumarrrk/roberta-daily-dialog-intent-classifier", 36 | } 37 | -------------------------------------------------------------------------------- /examples/nlp/train_ppo.py: -------------------------------------------------------------------------------- 1 | """""" 2 | 3 | from openrl.configs.config import create_config_parser 4 | from openrl.envs.common import make 5 | from openrl.modules.common import PPONet as Net 6 | from openrl.modules.networks.policy_network_gpt import PolicyNetworkGPT as PolicyNetwork 7 | from openrl.modules.networks.value_network_gpt import ValueNetworkGPT as ValueNetwork 8 | from openrl.runners.common import PPOAgent as Agent 9 | 10 | 11 | def train(): 12 | # create environment 13 | cfg_parser = create_config_parser() 14 | try: 15 | import deepspeed 16 | 17 | cfg_parser = deepspeed.add_config_arguments(cfg_parser) 18 | except: 19 | print("choose not to use deepspeed in the nlp task") 20 | cfg = cfg_parser.parse_args() 21 | 22 | env_num = 5 23 | env = make( 24 | "daily_dialog", 25 | env_num=env_num, 26 | asynchronous=True, 27 | cfg=cfg, 28 | ) 29 | 30 | # create the neural network 31 | model_dict = {"policy": PolicyNetwork, "critic": ValueNetwork} 32 | net = Net(env, device="cuda", cfg=cfg, model_dict=model_dict) 33 | 34 | # initialize the trainer 35 | agent = Agent(net, use_wandb=True) 36 | 
37 | # start training 38 | agent.train(total_time_steps=100000) 39 | agent.save("./ppo_agent") 40 | 41 | env.close() 42 | return agent 43 | 44 | 45 | if __name__ == "__main__": 46 | agent = train() 47 | -------------------------------------------------------------------------------- /examples/retro/retro_env/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | from typing import Callable, List, Optional, Union 20 | 21 | import retro 22 | from gymnasium import Env 23 | 24 | from examples.retro.retro_env.retro_convert import RetroWrapper 25 | from openrl.envs.common import build_envs 26 | 27 | retro_all_envs = retro.data.list_games() 28 | 29 | 30 | def make_retro_envs( 31 | id: str, 32 | env_num: int = 1, 33 | render_mode: Optional[Union[str, List[str]]] = None, 34 | **kwargs, 35 | ) -> List[Callable[[], Env]]: 36 | from openrl.envs.wrappers import ( 37 | AutoReset, 38 | DictWrapper, 39 | RemoveTruncated, 40 | Single2MultiAgentWrapper, 41 | ) 42 | 43 | env_wrappers = [ 44 | DictWrapper, 45 | Single2MultiAgentWrapper, 46 | AutoReset, 47 | RemoveTruncated, 48 | ] 49 | 50 | env_fns = build_envs( 51 | make=RetroWrapper, 52 | id=id, 53 | env_num=env_num, 54 | render_mode=render_mode, 55 | wrappers=env_wrappers, 56 | **kwargs, 57 | ) 58 | 59 | return env_fns 60 | -------------------------------------------------------------------------------- /examples/retro/train_retro.py: -------------------------------------------------------------------------------- 1 | """""" 2 | 3 | import numpy as np 4 | from custom_registration import make 5 | 6 | from openrl.modules.common import PPONet as Net 7 | from openrl.runners.common import PPOAgent as Agent 8 | 9 | 10 | def train(): 11 | # Create an environment. If multiple environments need to be run in parallel, set the asynchronous parameter to True. 12 | # If you need to specify a level, you can set the state parameter which is specific to each game. 13 | env = make("Airstriker-Genesis", state="Level1", env_num=2, asynchronous=True) 14 | # create the neural network 15 | net = Net(env, device="cuda") 16 | # initialize the trainer 17 | agent = Agent(net) 18 | # start training 19 | agent.train(total_time_steps=2000) 20 | # close the environment 21 | env.close() 22 | return agent 23 | 24 | 25 | def game_test(agent): 26 | # begin to test 27 | env = make( 28 | "Airstriker-Genesis", 29 | state="Level1", 30 | render_mode="group_human", 31 | env_num=4, 32 | asynchronous=True, 33 | ) 34 | agent.set_env(env) 35 | obs, info = env.reset() 36 | done = False 37 | step = 0 38 | while True: 39 | # Based on environmental observation input, predict next action. 
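# Note: unlike the training phase, this test loop has no step limit; it keeps
# stepping the 4 parallel games and exits as soon as any of them reports
# `done` (see the `if any(done): break` check below).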
40 | action, _ = agent.act(obs, deterministic=True) 41 | obs, r, done, info = env.step(action) 42 | step += 1 43 | print(f"{step}: reward:{np.mean(r)}") 44 | 45 | if any(done): 46 | break 47 | 48 | env.close() 49 | 50 | 51 | if __name__ == "__main__": 52 | agent = train() 53 | game_test(agent) 54 | -------------------------------------------------------------------------------- /examples/sac/README.md: -------------------------------------------------------------------------------- 1 | ## How to Use 2 | 3 | To train with [SAC](https://arxiv.org/abs/1812.05905): 4 | 5 | ```shell 6 | python train_sac_beta.py --config sac.yaml 7 | ``` -------------------------------------------------------------------------------- /examples/sac/ddpg.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | episode_length: 300 3 | layer_N: 5 4 | log_interval: 5 5 | gamma: 0.99 6 | var: 3 7 | actor_lr: 0.0003 8 | critic_lr: 0.0003 9 | hidden_size: 256 -------------------------------------------------------------------------------- /examples/sac/sac.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | alpha_value: 0. 3 | layer_N: 6 4 | -------------------------------------------------------------------------------- /examples/sb3/README.md: -------------------------------------------------------------------------------- 1 | Load and use [stable-baseline3 models](https://huggingface.co/sb3) from huggingface. 2 | 3 | ## Installation 4 | 5 | ```bash 6 | pip install huggingface-tool 7 | pip install rl_zoo3 8 | ``` 9 | 10 | ## Download sb3 model from huggingface 11 | 12 | ```bash 13 | htool save-repo sb3/ppo-CartPole-v1 ppo-CartPole-v1 14 | ``` 15 | 16 | ## Use OpenRL to load the model trained by sb3 and then evaluate it 17 | 18 | ```bash 19 | python test_model.py 20 | ``` 21 | 22 | ## Use OpenRL to load the model trained by sb3 and then train it 23 | 24 | ```bash 25 | python train_ppo.py 26 | ``` 27 | 28 | 29 | -------------------------------------------------------------------------------- /examples/sb3/ppo.yaml: -------------------------------------------------------------------------------- 1 | use_share_model: true 2 | sb3_model_path: ppo-CartPole-v1/ppo-CartPole-v1.zip 3 | sb3_algo: ppo 4 | entropy_coef: 0.0 5 | gae_lambda: 0.8 6 | gamma: 0.98 7 | lr: 0.001 8 | episode_length: 32 9 | ppo_epoch: 20 10 | log_interval: 20 11 | log_each_episode: False 12 | 13 | callbacks: 14 | - id: "EvalCallback" 15 | args: { 16 | "eval_env": { "id": "CartPole-v1","env_num": 5 }, # how many envs to set up for evaluation 17 | "n_eval_episodes": 20, # how many episodes to run for each evaluation 18 | "eval_freq": 500, # how often to run evaluation 19 | "log_path": "./results/eval_log_path", # where to save the evaluation results 20 | "best_model_save_path": "./results/best_model/", # where to save the best model 21 | "deterministic": True, # whether to use deterministic action 22 | "render": False, # whether to render the env 23 | "asynchronous": True, # whether to run evaluation asynchronously 24 | "stop_logic": "OR", # the logic to stop training, OR means training stops when any one of the conditions is met, AND means training stops when all conditions are met 25 | } -------------------------------------------------------------------------------- /examples/selfplay/README.md: -------------------------------------------------------------------------------- 1 | ## Install 2 | 3 | ```shell 4 | pip install "openrl[selfplay]" 5 | ``` 6 | 7 | 
## How to Use 8 | 9 | Users can train Tic-Tac-Toe via: 10 | 11 | ```shell 12 | python train_selfplay.py 13 | ``` 14 | 15 | 16 | ## Play with a Trained Agent 17 | 18 | Users can play with a trained agent via: 19 | 20 | ```shell 21 | python human_vs_agent.py 22 | ``` 23 | 24 | 25 | ## Evaluate Trained Agents 26 | 27 | If you want to evaluate your trained agents, please try to use OpenRL Arena as illustrated [here](https://openrl-docs.readthedocs.io/en/latest/arena/index.html) -------------------------------------------------------------------------------- /examples/selfplay/opponent_templates/random_opponent/opponent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | from openrl.selfplay.opponents.random_opponent import RandomOpponent as Opponent 19 | 20 | if __name__ == "__main__": 21 | from pettingzoo.classic import tictactoe_v3 22 | 23 | opponent1 = Opponent() 24 | opponent2 = Opponent() 25 | env = tictactoe_v3.env(render_mode="human") 26 | opponent1.reset(env, "player_1") 27 | opponent2.reset(env, "player_2") 28 | player2opponent = {"player_1": opponent1, "player_2": opponent2} 29 | 30 | env.reset() 31 | for player_name in env.agent_iter(): 32 | observation, reward, termination, truncation, info = env.last() 33 | if termination: 34 | break 35 | action = player2opponent[player_name].act( 36 | player_name, observation, reward, termination, truncation, info 37 | ) 38 | print(player_name, action, type(action)) 39 | env.step(action) 40 | -------------------------------------------------------------------------------- /examples/selfplay/opponent_templates/tictactoe_opponent/info.json: -------------------------------------------------------------------------------- 1 | { 2 | "opponent_type": "tictactoe_opponent", 3 | "description": "used for tictactoe game, need to load a nerual network" 4 | } -------------------------------------------------------------------------------- /examples/selfplay/selfplay.yaml: -------------------------------------------------------------------------------- 1 | globals: 2 | selfplay_api_host: 127.0.0.1 3 | selfplay_api_port: 13486 4 | 5 | seed: 0 6 | selfplay_api: 7 | host: {{ selfplay_api_host }} 8 | port: {{ selfplay_api_port }} 9 | lazy_load_opponent: true # if true, when the opponents are the same opponent_type, will only load the weight. Otherwise, will load the python script. 
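# The callbacks below wire self-play together: ProgressBarCallback just shows
# training progress, SelfplayAPI starts the self-play API service on the
# host/port declared in `globals` (sampling opponents with the RandomOpponent
# strategy), and SelfplayCallback periodically saves the current agent into
# ./opponent_pool/ and registers it with that service via `api_address`.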
10 | callbacks: 11 | - id: "ProgressBarCallback" 12 | - id: "SelfplayAPI" 13 | args: { 14 | host: {{ selfplay_api_host }}, 15 | port: {{ selfplay_api_port }}, 16 | sample_strategy: "RandomOpponent", 17 | } 18 | - id: "SelfplayCallback" 19 | args: { 20 | "save_freq": 100, # how often to save the model 21 | "opponent_pool_path": "./opponent_pool/", # where to save opponents 22 | "name_prefix": "opponent", # the prefix of the saved model 23 | "api_address": "http://{{ selfplay_api_host }}:{{ selfplay_api_port }}/selfplay/", 24 | "opponent_template": "./opponent_templates/tictactoe_opponent", 25 | "clear_past_opponents": true, 26 | "copy_script_file": false, 27 | "verbose": 2, 28 | } -------------------------------------------------------------------------------- /examples/selfplay/tictactoe_utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /examples/smac/README.md: -------------------------------------------------------------------------------- 1 | ## Installation 2 | 3 | Installation guide for Linux: 4 | 5 | - `pip install pysc2` 6 | - Download `StarCraftII.zip` from [Google Drive](https://drive.google.com/drive/folders/1umnlFotrXdEnmTUqfzGoJfJ7kAKf-eKO). 
7 | - unzip `StarCraftII.zip` to `~/StarCraftII/`: `unzip StarCraftII.zip -d ~/` 8 | - If something goes wrong with protobuf, you can fix it with: `pip install protobuf==3.20.3` 9 | 10 | ## Usage 11 | 12 | Train SMAC with the [MAPPO](https://arxiv.org/abs/2103.01955) algorithm: 13 | 14 | `python train_ppo.py --config smac_ppo.yaml` 15 | 16 | ## Render replay on Mac 17 | 18 | -------------------------------------------------------------------------------- /examples/smac/smac_ppo.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | 3 | run_dir: ./run_results/ 4 | experiment_name: smac_mappo 5 | 6 | lr: 5e-4 7 | critic_lr: 5e-4 8 | 9 | episode_length: 400 10 | ppo_epoch: 5 11 | log_interval: 10 12 | attn_size: 128 13 | 14 | use_recurrent_policy: true 15 | use_joint_action_loss: false 16 | use_valuenorm: true 17 | use_adv_normalize: true 18 | use_value_active_masks: false 19 | 20 | wandb_entity: openrl-lab 21 | 22 | vec_info_class: 23 | id: "SMACInfo" -------------------------------------------------------------------------------- /examples/smacv2/smacv2_ppo.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | 3 | run_dir: ./run_results/ 4 | experiment_name: smac_mappo 5 | 6 | lr: 5e-4 7 | critic_lr: 5e-4 8 | 9 | episode_length: 400 10 | ppo_epoch: 5 11 | log_interval: 10 12 | attn_size: 128 13 | 14 | 15 | use_recurrent_policy: true 16 | use_joint_action_loss: false 17 | use_valuenorm: true 18 | use_adv_normalize: true 19 | use_value_active_masks: false 20 | 21 | wandb_entity: openrl-lab 22 | 23 | vec_info_class: 24 | id: "SMACInfo" -------------------------------------------------------------------------------- /examples/snake/README.md: -------------------------------------------------------------------------------- 1 | 2 | This is the example for the snake game. 3 | 4 | ### Installation 5 | 6 | ```bash 7 | pip install "openrl[selfplay]" 8 | ``` 9 | 10 | ### Usage 11 | 12 | ```bash 13 | python train_selfplay.py 14 | ``` 15 | 16 | ### Evaluate JiDi submissions locally 17 | 18 | ```bash 19 | python jidi_eval.py 20 | ``` 21 | 22 | ## Submit to JiDi 23 | 24 | Submission site: http://www.jidiai.cn/env_detail?envid=1. 25 | 26 | Snake scenarios: [here](https://github.com/jidiai/ai_lib/blob/7a6986f0cb543994277103dbf605e9575d59edd6/env/config.json#L94) 27 | Original Snake environment: [here](https://github.com/jidiai/ai_lib/blob/master/env/snakes.py) 28 | 29 | 30 | 31 | 32 | ### Evaluate Google Research Football submissions for JiDi locally 33 | 34 | If you want to evaluate your Google Research Football submissions for JiDi locally, please try to use tizero as illustrated [here](https://github.com/OpenRL-Lab/TiZero#evaluate-jidi-submissions-locally).
-------------------------------------------------------------------------------- /examples/snake/selfplay.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | callbacks: 3 | - id: "ProgressBarCallback" 4 | -------------------------------------------------------------------------------- /examples/snake/submissions/random_agent/submission.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | def sample_single_dim(action_space_list_each, is_act_continuous): 3 | if is_act_continuous: 4 | each = action_space_list_each.sample() 5 | else: 6 | if action_space_list_each.__class__.__name__ == "Discrete": 7 | each = [0] * action_space_list_each.n 8 | idx = action_space_list_each.sample() 9 | each[idx] = 1 10 | elif action_space_list_each.__class__.__name__ == "MultiDiscreteParticle": 11 | each = [] 12 | nvec = action_space_list_each.high - action_space_list_each.low + 1 13 | sample_indexes = action_space_list_each.sample() 14 | 15 | for i in range(len(nvec)): 16 | dim = nvec[i] 17 | new_action = [0] * dim 18 | index = sample_indexes[i] 19 | new_action[index] = 1 20 | each.extend(new_action) 21 | return each 22 | 23 | 24 | def my_controller(observation, action_space, is_act_continuous): 25 | joint_action = [] 26 | for i in range(len(action_space)): 27 | player = sample_single_dim(action_space[i], is_act_continuous) 28 | joint_action.append(player) 29 | 30 | return joint_action 31 | -------------------------------------------------------------------------------- /examples/snake/submissions/rl/README.md: -------------------------------------------------------------------------------- 1 | # Download actor weights 2 | 3 | Please download [actor_2000.pth](https://github.com/CarlossShi/Competition_3v3snakes/tree/master/agent/rl) before using this code. -------------------------------------------------------------------------------- /examples/toy_env/README.md: -------------------------------------------------------------------------------- 1 | These are examples of training RL agents on toy environments.
2 | 3 | ## How to Use 4 | 5 | To train with PPO: 6 | ```shell 7 | python train_ppo.py 8 | ``` 9 | 10 | To train with DQN: 11 | 12 | ```shell 13 | python train_dqn.py 14 | ``` 15 | 16 | To train with DDPG: 17 | 18 | ```shell 19 | python train_ddpg.py 20 | ``` 21 | 22 | To train with SAC: 23 | 24 | ```shell 25 | python train_sac.py 26 | ``` 27 | 28 | -------------------------------------------------------------------------------- /examples/toy_env/train.yaml: -------------------------------------------------------------------------------- 1 | seed: 0 2 | episode_length: 20 3 | layer_N: 5 4 | alpha_value: 0.01 5 | log_interval: 1 6 | gamma: 0.9 7 | var: 2 -------------------------------------------------------------------------------- /examples/toy_env/train_ddpg.py: -------------------------------------------------------------------------------- 1 | """""" 2 | 3 | from train_ppo import evaluation, train 4 | 5 | from openrl.modules.common import DDPGNet as Net 6 | from openrl.runners.common import DDPGAgent as Agent 7 | 8 | if __name__ == "__main__": 9 | agent = train(Agent, Net, "IdentityEnvcontinuous", 10, 20000) 10 | evaluation(agent, "IdentityEnvcontinuous") 11 | -------------------------------------------------------------------------------- /examples/toy_env/train_dqn.py: -------------------------------------------------------------------------------- 1 | """""" 2 | 3 | from train_ppo import evaluation, train 4 | 5 | from openrl.modules.common import DQNNet as Net 6 | from openrl.runners.common import DQNAgent as Agent 7 | 8 | if __name__ == "__main__": 9 | agent = train(Agent, Net, "IdentityEnv", 9, 20000) 10 | evaluation(agent, "IdentityEnv") 11 | -------------------------------------------------------------------------------- /examples/toy_env/train_ppo.py: -------------------------------------------------------------------------------- 1 | """""" 2 | 3 | from train_and_eval import evaluation, train 4 | 5 | from openrl.modules.common import PPONet as Net 6 | from openrl.runners.common import PPOAgent as Agent 7 | 8 | if __name__ == "__main__": 9 | agent = train(Agent, Net, "IdentityEnvcontinuous", 10, 1000) 10 | evaluation(agent, "IdentityEnvcontinuous") 11 | -------------------------------------------------------------------------------- /examples/toy_env/train_sac.py: -------------------------------------------------------------------------------- 1 | """""" 2 | 3 | from train_ppo import evaluation, train 4 | 5 | from openrl.modules.common import SACNet as Net 6 | from openrl.runners.common import SACAgent as Agent 7 | 8 | if __name__ == "__main__": 9 | agent = train(Agent, Net, "IdentityEnvcontinuous", 10, 5000) 10 | evaluation(agent, "IdentityEnvcontinuous") 11 | # test_env() 12 | -------------------------------------------------------------------------------- /openrl/__init__.py: -------------------------------------------------------------------------------- 1 | __TITLE__ = "openrl" 2 | __VERSION__ = "v0.2.1" 3 | __DESCRIPTION__ = "Distributed Deep RL Framework" 4 | __AUTHOR__ = "OpenRL Contributors" 5 | __EMAIL__ = "huangsy1314@163.com" 6 | __version__ = __VERSION__ 7 | 8 | import platform 9 | 10 | python_version_list = list(map(int, platform.python_version_tuple())) 11 | assert python_version_list >= [ 12 | 3, 13 | 8, 14 | 0, 15 | ], ( 16 | "OpenRL requires Python 3.8 or newer, but your Python is" 17 | f" {platform.python_version()}" 18 | ) 19 | -------------------------------------------------------------------------------- /openrl/algorithms/__init__.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/algorithms/mat.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | from openrl.algorithms.ppo import PPOAlgorithm 19 | 20 | 21 | class MATAlgorithm(PPOAlgorithm): 22 | def construct_loss_list(self, policy_loss, dist_entropy, value_loss, turn_on): 23 | loss_list = [] 24 | 25 | loss = ( 26 | policy_loss 27 | - dist_entropy * self.entropy_coef 28 | + value_loss * self.value_loss_coef 29 | ) 30 | loss_list.append(loss) 31 | 32 | return loss_list 33 | 34 | def get_data_generator(self, buffer, advantages): 35 | data_generator = buffer.feed_forward_generator_transformer( 36 | advantages, self.num_mini_batch 37 | ) 38 | return data_generator 39 | -------------------------------------------------------------------------------- /openrl/arena/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """""" 18 | from typing import Callable, Optional 19 | 20 | import openrl 21 | from openrl.arena.two_player_arena import TwoPlayerArena 22 | from openrl.envs import pettingzoo_all_envs 23 | 24 | 25 | def make_arena( 26 | env_id: str, 27 | custom_build_env: Optional[Callable] = None, 28 | render: Optional[bool] = False, 29 | use_tqdm: Optional[bool] = True, 30 | **kwargs, 31 | ): 32 | if custom_build_env is None: 33 | from openrl.envs import PettingZoo 34 | 35 | if ( 36 | env_id in pettingzoo_all_envs 37 | or env_id in PettingZoo.registration.pettingzoo_env_dict.keys() 38 | ): 39 | from openrl.envs.PettingZoo import make_PettingZoo_env 40 | 41 | render_mode = None 42 | if render: 43 | render_mode = "human" 44 | env_fn = make_PettingZoo_env(env_id, render_mode=render_mode, **kwargs) 45 | else: 46 | raise ValueError(f"Unknown env_id: {env_id}") 47 | else: 48 | env_fn = custom_build_env(env_id, render, **kwargs) 49 | 50 | return TwoPlayerArena(env_fn, use_tqdm=use_tqdm) 51 | -------------------------------------------------------------------------------- /openrl/arena/agents/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/arena/agents/base_agent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """""" 18 | 19 | from abc import ABC, abstractmethod 20 | from typing import Any, Dict 21 | 22 | from openrl.selfplay.opponents.base_opponent import BaseOpponent 23 | from openrl.selfplay.selfplay_api.opponent_model import BattleHistory, BattleResult 24 | 25 | 26 | class BaseAgent(ABC): 27 | def __init__(self): 28 | self.batch_history = BattleHistory() 29 | 30 | def new_agent(self) -> BaseOpponent: 31 | agent = self._new_agent() 32 | return agent 33 | 34 | @abstractmethod 35 | def _new_agent(self) -> BaseOpponent: 36 | raise NotImplementedError 37 | 38 | def add_battle_result(self, result: BattleResult): 39 | self.batch_history.update(result) 40 | 41 | def get_battle_info(self) -> Dict[str, Any]: 42 | return self.batch_history.get_battle_info() 43 | -------------------------------------------------------------------------------- /openrl/arena/agents/jidi_agent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | from openrl.arena.agents.base_agent import BaseAgent 19 | from openrl.selfplay.opponents.base_opponent import BaseOpponent 20 | from openrl.selfplay.opponents.utils import load_opponent_from_jidi_path 21 | 22 | 23 | class JiDiAgent(BaseAgent): 24 | def __init__(self, local_agent_path, player_num: int = 1): 25 | super().__init__() 26 | self.local_agent_path = local_agent_path 27 | self.player_num = player_num 28 | 29 | def _new_agent(self) -> BaseOpponent: 30 | return load_opponent_from_jidi_path( 31 | self.local_agent_path, player_num=self.player_num 32 | ) 33 | -------------------------------------------------------------------------------- /openrl/arena/agents/local_agent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """""" 18 | from openrl.arena.agents.base_agent import BaseAgent 19 | from openrl.selfplay.opponents.base_opponent import BaseOpponent 20 | from openrl.selfplay.opponents.utils import load_opponent_from_path 21 | 22 | 23 | class LocalAgent(BaseAgent): 24 | def __init__(self, local_agent_path): 25 | super().__init__() 26 | self.local_agent_path = local_agent_path 27 | 28 | def _new_agent(self) -> BaseOpponent: 29 | return load_opponent_from_path(self.local_agent_path) 30 | -------------------------------------------------------------------------------- /openrl/arena/agents/random_agent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | from openrl.arena.agents.base_agent import BaseAgent 19 | from openrl.selfplay.opponents.base_opponent import BaseOpponent 20 | from openrl.selfplay.opponents.random_opponent import RandomOpponent 21 | from openrl.selfplay.opponents.utils import load_opponent_from_path 22 | 23 | 24 | class RandomAgent(BaseAgent): 25 | def __init__(self): 26 | super().__init__() 27 | 28 | def _new_agent(self) -> BaseOpponent: 29 | return RandomOpponent() 30 | -------------------------------------------------------------------------------- /openrl/arena/games/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/arena/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | from pathlib import Path 19 | from typing import Callable, Union 20 | 21 | from openrl.selfplay.opponents.utils import load_opponent_from_path 22 | 23 | 24 | def load_agent(agent_path: Union[str, Path]) -> Callable: 25 | def _load_agent(): 26 | opponent = load_opponent_from_path(agent_path) 27 | return opponent 28 | 29 | return _load_agent 30 | -------------------------------------------------------------------------------- /openrl/buffers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2021 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | from .normal_buffer import NormalReplayBuffer 20 | from .offpolicy_buffer import OffPolicyReplayBuffer 21 | 22 | __all__ = [ 23 | "NormalReplayBuffer", 24 | "OffPolicyReplayBuffer", 25 | ] 26 | -------------------------------------------------------------------------------- /openrl/buffers/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/cli/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/cli/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | import numpy as np 19 | 20 | from openrl.configs.config import create_config_parser 21 | from openrl.envs.common import make 22 | from openrl.modules.common import PPONet as Net 23 | from openrl.runners.common import PPOAgent as Agent 24 | 25 | 26 | def train_agent(env: str, total_time_steps: int = 20000): 27 | render_mode = "rgb_array"  # render mode handed to the vectorized env 28 | env_num = 9 29 | env = make(env, render_mode=render_mode, env_num=env_num, asynchronous=False)  # build env_num synchronous parallel envs 30 | cfg_parser = create_config_parser() 31 | cfg = cfg_parser.parse_args([]) 32 | net = Net(env, cfg=cfg) 33 | agent = Agent(net, use_wandb=False) 34 | agent.train(total_time_steps=total_time_steps)  # PPO training 35 | 36 | agent.set_env(env) 37 | obs, info = env.reset() 38 | done = False 39 | step = 0 40 | total_reward = 0 41 | while not np.any(done):  # roll out one deterministic evaluation episode 42 | action, _ = agent.act(obs, deterministic=True) 43 | obs, r, done, info = env.step(action) 44 | total_reward += np.mean(r) 45 | step += 1 46 | print(f"Total reward: {total_reward}") 47 | 48 | env.close() 49 | -------------------------------------------------------------------------------- /openrl/configs/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/drivers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/drivers/base_driver.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class BaseDriver(ABC): 5 | @abstractmethod 6 | def __init__(self) -> None: 7 | raise NotImplementedError 8 | 9 | @abstractmethod 10 | def run(self, *args, **kwargs): 11 | raise NotImplementedError 12 | -------------------------------------------------------------------------------- /openrl/drivers/offline_driver.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
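# OfflineDriver (defined below) reuses the on-policy rollout machinery but overwrites the
# actions written to the buffer with the dataset actions: it reads "data_action" from each
# env's info dict (falling back to info["final_info"] when the episode just ended), stacks
# them, and then delegates to OnPolicyDriver.add2buffer.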
16 | 17 | """""" 18 | 19 | import numpy as np 20 | 21 | from openrl.drivers.onpolicy_driver import OnPolicyDriver 22 | 23 | 24 | class OfflineDriver(OnPolicyDriver): 25 | def add2buffer(self, data): 26 | infos = data["infos"] 27 | offline_actions = [] 28 | for i, info in enumerate(infos): 29 | if "data_action" not in info: 30 | assert np.all(data["dones"][i]) 31 | data_action = info["final_info"]["data_action"] 32 | else: 33 | data_action = info["data_action"] 34 | offline_actions.append(data_action) 35 | offline_actions = np.stack(offline_actions, axis=0) 36 | 37 | data["actions"] = offline_actions 38 | super().add2buffer(data) 39 | -------------------------------------------------------------------------------- /openrl/envs/PettingZoo/registration.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | from typing import Any, Dict, Optional 19 | 20 | pettingzoo_env_dict: Dict[str, Any] = {} 21 | 22 | 23 | def register(id: str, EnvClass): 24 | pettingzoo_env_dict[id] = EnvClass 25 | -------------------------------------------------------------------------------- /openrl/envs/__init__.py: -------------------------------------------------------------------------------- 1 | mpe_all_envs = [ 2 | "simple_spread", 3 | ] 4 | nlp_all_envs = [ 5 | "daily_dialog", 6 | "fake_dialog_data", 7 | ] 8 | super_mario_all_envs = [ 9 | "SuperMarioBros", 10 | ] 11 | 12 | connect_all_envs = [ 13 | "connect3", 14 | "connect4", 15 | ] 16 | 17 | toy_all_envs = [ 18 | "BitFlippingEnv", 19 | "IdentityEnv", 20 | "IdentityEnvcontinuous", 21 | "IdentityEnvBox", 22 | "SimpleMultiObsEnv", 23 | "SimpleMultiObsEnv", 24 | ] 25 | gridworld_all_envs = ["GridWorldEnv", "GridWorldEnvRandomGoal"] 26 | 27 | offline_all_envs = ["OfflineEnv"] 28 | 29 | pettingzoo_all_envs = ["tictactoe_v3", "snakes_1v1", "snakes_3v3"] 30 | -------------------------------------------------------------------------------- /openrl/envs/common/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
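# The `make` re-exported below is the factory the examples call to build vectorized envs,
# roughly (a sketch; "CartPole-v1" is only an illustrative id, not something this file defines):
#     env = make("CartPole-v1", env_num=9, asynchronous=False)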
16 | 17 | """""" 18 | 19 | from .build_envs import build_envs 20 | from .registration import make 21 | 22 | __all__ = ["make", "build_envs"] 23 | -------------------------------------------------------------------------------- /openrl/envs/connect_env/connect3_env.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from openrl.envs.connect_env.base_connect_env import BaseConnectEnv 4 | 5 | 6 | class Connect3Env(BaseConnectEnv): 7 | def _get_board_size(self) -> Tuple[int, int]: 8 | return 3, 3 9 | 10 | def _get_num2win(self) -> int: 11 | return 3 12 | 13 | 14 | if __name__ == "__main__": 15 | env = Connect3Env(env_name="connect3") 16 | obs, info = env.reset() 17 | obs, reward, done, _, info = env.step(1, is_enemy=True) 18 | env.close() 19 | -------------------------------------------------------------------------------- /openrl/envs/connect_env/connect4_env.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | from openrl.envs.connect_env.base_connect_env import BaseConnectEnv 4 | 5 | 6 | class Connect4Env(BaseConnectEnv): 7 | def _get_board_size(self) -> Tuple[int, int]: 8 | return 4, 4 9 | 10 | def _get_num2win(self) -> int: 11 | return 4 12 | 13 | 14 | if __name__ == "__main__": 15 | env = Connect4Env(env_name="connect4") 16 | obs, info = env.reset() 17 | obs, reward, done, _, info = env.step(1, is_enemy=True) 18 | env.close() 19 | -------------------------------------------------------------------------------- /openrl/envs/crafter/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | from typing import Callable, List, Optional, Union 20 | 21 | from gymnasium import Env 22 | 23 | from openrl.envs.common import build_envs 24 | from openrl.envs.crafter.crafter import CrafterWrapper 25 | 26 | 27 | def make_crafter_envs( 28 | id: str, 29 | env_num: int = 1, 30 | render_mode: Optional[Union[str, List[str]]] = None, 31 | **kwargs, 32 | ) -> List[Callable[[], Env]]: 33 | from openrl.envs.wrappers import ( 34 | AutoReset, 35 | DictWrapper, 36 | RemoveTruncated, 37 | Single2MultiAgentWrapper, 38 | ) 39 | 40 | env_wrappers = [ 41 | DictWrapper, 42 | Single2MultiAgentWrapper, 43 | AutoReset, 44 | RemoveTruncated, 45 | ] 46 | 47 | env_fns = build_envs( 48 | make=CrafterWrapper, 49 | id=id, 50 | env_num=env_num, 51 | render_mode=render_mode, 52 | wrappers=env_wrappers, 53 | **kwargs, 54 | ) 55 | 56 | return env_fns 57 | -------------------------------------------------------------------------------- /openrl/envs/gridworld/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | from typing import Callable, List, Optional, Union 19 | 20 | from gymnasium import Env 21 | 22 | from openrl.envs.common import build_envs 23 | from openrl.envs.gridworld.gridworld_env import make 24 | 25 | 26 | def make_gridworld_envs( 27 | id: str, 28 | env_num: int = 1, 29 | render_mode: Optional[Union[str, List[str]]] = None, 30 | **kwargs, 31 | ) -> List[Callable[[], Env]]: 32 | from openrl.envs.wrappers import RemoveTruncated, Single2MultiAgentWrapper 33 | 34 | env_wrappers = [ 35 | Single2MultiAgentWrapper, 36 | RemoveTruncated, 37 | ] 38 | env_fns = build_envs( 39 | make=make, 40 | id=id, 41 | env_num=env_num, 42 | render_mode=render_mode, 43 | wrappers=env_wrappers, 44 | **kwargs, 45 | ) 46 | return env_fns 47 | -------------------------------------------------------------------------------- /openrl/envs/mpe/__init__.py: -------------------------------------------------------------------------------- 1 | # MPE env fetched from https://github.com/marlbenchmark/on-policy/tree/main/onpolicy/envs/mpe 2 | from typing import Callable, List, Optional, Union 3 | 4 | from gymnasium import Env 5 | 6 | from openrl.envs.common import build_envs 7 | from openrl.envs.mpe.mpe_env import make 8 | 9 | 10 | def make_mpe_envs( 11 | id: str, 12 | env_num: int = 1, 13 | render_mode: Optional[Union[str, List[str]]] = None, 14 | **kwargs, 15 | ) -> List[Callable[[], Env]]: 16 | env_wrappers = [] 17 | env_fns = build_envs( 18 | make=make, 19 | id=id, 20 | env_num=env_num, 21 | render_mode=render_mode, 22 | wrappers=env_wrappers, 23 | **kwargs, 24 | ) 25 | 26 | return env_fns 27 | -------------------------------------------------------------------------------- /openrl/envs/mpe/mpe_env.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | from gymnasium import Env 4 | 5 | from .multiagent_env import MultiAgentEnv 6 | from .scenarios import load 7 | 8 | 9 | def make( 10 | id: str, 11 | render_mode: Optional[str] = None, 12 | **kwargs: Any, 13 | ) -> Env: 14 | # load scenario from script 15 | scenario = load(id + ".py").Scenario() 16 | # create world 17 | 18 | world = scenario.make_world(render_mode=render_mode) 19 | # create multiagent environment 20 | env = MultiAgentEnv( 21 | world, 22 | scenario.reset_world, 23 | scenario.reward, 24 | scenario.observation, 25 | scenario.info, 26 | render_mode=render_mode, 27 | ) 28 | 29 | return env 30 | -------------------------------------------------------------------------------- /openrl/envs/mpe/scenario.py: -------------------------------------------------------------------------------- 1 | # defines scenario upon which the world is built 2 | class BaseScenario(object): 3 | # create elements of the world 4 | def make_world(self): 5 | raise NotImplementedError() 6 | 7 | # create initial conditions of the world 8 | def reset_world(self, world): 9 | raise NotImplementedError() 10 | 11 | 
def info(self, agent, world): 12 | return {} 13 | -------------------------------------------------------------------------------- /openrl/envs/mpe/scenarios/__init__.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import os.path as osp 3 | 4 | 5 | def load(name): 6 | pathname = osp.join(osp.dirname(__file__), name) 7 | return imp.load_source("", pathname) 8 | -------------------------------------------------------------------------------- /openrl/envs/nlp/__init__.py: -------------------------------------------------------------------------------- 1 | # MPE env fetched from https://github.com/marlbenchmark/on-policy/tree/main/onpolicy/envs/mpe 2 | from typing import Callable, List, Optional, Union 3 | 4 | from gymnasium import Env 5 | 6 | from openrl.envs.common import build_envs 7 | from openrl.envs.nlp.nlp_env import make 8 | 9 | 10 | def make_nlp_envs( 11 | id: str, 12 | env_num: int = 1, 13 | render_mode: Optional[Union[str, List[str]]] = None, 14 | **kwargs, 15 | ) -> List[Callable[[], Env]]: 16 | from openrl.envs.wrappers import AutoReset # DictWrapper, 17 | from openrl.envs.wrappers import RemoveTruncated, Single2MultiAgentWrapper 18 | 19 | env_wrappers = [ 20 | # DictWrapper, 21 | Single2MultiAgentWrapper, 22 | AutoReset, 23 | RemoveTruncated, 24 | ] 25 | env_fns = build_envs( 26 | make=make, 27 | id=id, 28 | env_num=env_num, 29 | render_mode=render_mode, 30 | wrappers=env_wrappers, 31 | **kwargs, 32 | ) 33 | return env_fns 34 | -------------------------------------------------------------------------------- /openrl/envs/nlp/nlp_env.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Optional 2 | 3 | from gymnasium import Env 4 | 5 | from .daily_dialog_env import DailyDialogEnv 6 | from .fake_dialog_env import FakeDialogEnv 7 | 8 | 9 | def make( 10 | id: str, 11 | render_mode: Optional[str] = None, 12 | cfg: Any = None, 13 | **kwargs: Any, 14 | ) -> Env: 15 | if id == "daily_dialog": 16 | env = DailyDialogEnv(cfg=cfg) 17 | elif id == "fake_dialog_data": 18 | env = FakeDialogEnv(cfg=cfg) 19 | else: 20 | raise NotImplementedError 21 | 22 | return env 23 | -------------------------------------------------------------------------------- /openrl/envs/nlp/rewards/meteor.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Dict 3 | 4 | import evaluate 5 | 6 | import openrl.envs.nlp as nlp 7 | 8 | 9 | class VirtualMetric: 10 | def compute(self, predictions: Any, references: Any) -> Dict[str, float]: 11 | return {"meteor": 0.0} 12 | 13 | 14 | class Meteor: 15 | def __init__(self, meteor_coeff: int, test: bool = False) -> None: 16 | super().__init__() 17 | self._meteor_coeff = meteor_coeff 18 | if test: 19 | self._metric = VirtualMetric() 20 | else: 21 | self._metric = evaluate.load( 22 | str(Path(nlp.__file__).parent / "utils/metrics/meteor.py") 23 | ) 24 | 25 | def __call__( 26 | self, 27 | data: Dict[str, Any], 28 | ): 29 | generated_texts = [data["generated_texts"]] 30 | reference_texts = [data["reference_texts"]] 31 | score = self._metric.compute( 32 | predictions=generated_texts, references=reference_texts 33 | )["meteor"] 34 | 35 | reward = score * self._meteor_coeff 36 | info = {"meteor": score} 37 | 38 | return reward, info 39 | -------------------------------------------------------------------------------- /openrl/envs/nlp/utils/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRL-Lab/openrl/4c92aa440c12a6834c3fd76c574a341480868436/openrl/envs/nlp/utils/__init__.py -------------------------------------------------------------------------------- /openrl/envs/nlp/utils/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The TARTRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/envs/nlp/utils/sampler.py: -------------------------------------------------------------------------------- 1 | from collections import deque 2 | from typing import Any, List 3 | 4 | import numpy as np 5 | 6 | 7 | class PrioritySampler: 8 | def __init__(self, max_size: int = None, priority_scale: float = 0.0): 9 | """ 10 | Creates a priority sampler 11 | 12 | Args: 13 | max_size (int): maximum size of the queue 14 | priority_scale (float): 0.0 is a pure uniform sampling, 1.0 is completely priority sampling 15 | """ 16 | self.max_size = max_size 17 | self.items = deque(maxlen=self.max_size) 18 | self.item_priorities = deque(maxlen=self.max_size) 19 | self.priority_scale = priority_scale 20 | 21 | def add(self, item: Any, priority: float): 22 | self.items.append(item) 23 | self.item_priorities.append(priority) 24 | 25 | def sample(self, size: int) -> List[Any]: 26 | min_sample_size = min(len(self.items), size) 27 | scaled_item_priorities = np.array(self.item_priorities) ** self.priority_scale 28 | sample_probs = scaled_item_priorities / np.sum(scaled_item_priorities) 29 | samples = np.random.choice(a=self.items, p=sample_probs, size=min_sample_size) 30 | return samples 31 | 32 | def update(self, item: Any, priority: float): 33 | index = self.items.index(item) 34 | del self.items[index] 35 | del self.item_priorities[index] 36 | self.add(item, priority) 37 | 38 | def get_all_samples(self) -> List[Any]: 39 | return self.items 40 | -------------------------------------------------------------------------------- /openrl/envs/nlp/utils/text_generation_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | from abc import abstractclassmethod 3 | from dataclasses import dataclass 4 | from typing import Any, Dict, List 5 | 6 | 7 | @dataclass(init=True) 8 | class Sample: 9 | id: str 10 | prompt_or_input_text: str 11 | references: List[str] 12 | meta_data: Dict[str, Any] = None 13 | 14 | 15 | class TextGenPool: 16 | def __init__(self, samples: List[Sample]): 17 | self._samples = samples 18 | 19 | def __len__(self): 20 | return len(self._samples) 21 | 22 | def __getitem__(self, ix: int) -> Sample: 23 | if ix >= len(self): 24 | raise StopIteration 25 | sample = self._samples[ix] 26 | return sample, 1.0 27 | 28 | def sample(self) -> Sample: 29 | 
random_sample = random.choice(self._samples) 30 | return random_sample 31 | 32 | @abstractclassmethod 33 | def prepare(cls, **args) -> "TextGenPool": 34 | """ 35 | A factory method to instantiate data pool 36 | """ 37 | raise NotImplementedError 38 | 39 | def split(self, split_ratios: List[float]) -> List["TextGenPool"]: 40 | start_ix = 0 41 | pools = [] 42 | for ratio in split_ratios: 43 | count = int(len(self) * ratio) 44 | end_ix = start_ix + count 45 | pools.append(type(self)(self._samples[start_ix:end_ix])) 46 | start_ix = end_ix 47 | return pools 48 | -------------------------------------------------------------------------------- /openrl/envs/offline/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | import copy 19 | from typing import Callable, List, Optional, Union 20 | 21 | from gymnasium import Env 22 | 23 | from openrl.envs.common import build_envs 24 | from openrl.envs.offline.offline_env import OfflineEnv 25 | 26 | 27 | def offline_make(dataset, render_mode, disable_env_checker, **kwargs): 28 | env_id = kwargs["env_id"] 29 | env_num = kwargs["env_num"] 30 | seed = kwargs.pop("seed", None) 31 | assert seed is not None, "seed must be set" 32 | env = OfflineEnv(dataset, env_id, env_num, seed) 33 | return env 34 | 35 | 36 | def make_offline_envs( 37 | dataset: str, 38 | env_num: int = 1, 39 | render_mode: Optional[Union[str, List[str]]] = None, 40 | **kwargs, 41 | ) -> List[Callable[[], Env]]: 42 | env_wrappers = copy.copy(kwargs.pop("env_wrappers", [])) 43 | env_wrappers += [] 44 | 45 | env_fns = build_envs( 46 | make=offline_make, 47 | id=dataset, 48 | env_num=env_num, 49 | render_mode=render_mode, 50 | wrappers=env_wrappers, 51 | need_env_id=True, 52 | **kwargs, 53 | ) 54 | return env_fns 55 | -------------------------------------------------------------------------------- /openrl/envs/snake/discrete.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from .space import Space 4 | 5 | 6 | class Discrete(Space): 7 | r"""A discrete space in :math:`\{ 0, 1, \\dots, n-1 \}`. 
8 | Example:: 9 | >>> Discrete(2) 10 | """ 11 | 12 | def __init__(self, n): 13 | assert n >= 0 14 | self.n = n 15 | super(Discrete, self).__init__((), np.int64) 16 | 17 | def sample(self): 18 | return self.np_random.randint(self.n) 19 | 20 | def contains(self, x): 21 | if isinstance(x, int): 22 | as_int = x 23 | elif isinstance(x, (np.generic, np.ndarray)) and ( 24 | x.dtype.char in np.typecodes["AllInteger"] and x.shape == () 25 | ): 26 | as_int = int(x) 27 | else: 28 | return False 29 | return as_int >= 0 and as_int < self.n 30 | 31 | def __repr__(self): 32 | return "Discrete(%d)" % self.n 33 | 34 | def __eq__(self, other): 35 | return isinstance(other, Discrete) and self.n == other.n 36 | -------------------------------------------------------------------------------- /openrl/envs/snake/game.py: -------------------------------------------------------------------------------- 1 | # -*- coding:utf-8 -*- 2 | # Author: zruizhi 3 | # Created: 2020/7/10 10:24 AM 4 | # Description: 5 | from abc import ABC, abstractmethod 6 | 7 | 8 | class Game(ABC): 9 | def __init__( 10 | self, 11 | n_player, 12 | is_obs_continuous, 13 | is_act_continuous, 14 | game_name, 15 | agent_nums, 16 | obs_type, 17 | ): 18 | self.n_player = n_player 19 | self.current_state = None 20 | self.all_observes = None 21 | self.is_obs_continuous = is_obs_continuous 22 | self.is_act_continuous = is_act_continuous 23 | self.game_name = game_name 24 | self.agent_nums = agent_nums 25 | self.obs_type = obs_type 26 | 27 | def get_config(self, player_id): 28 | raise NotImplementedError 29 | 30 | def get_render_data(self, current_state): 31 | return current_state 32 | 33 | def set_current_state(self, current_state): 34 | raise NotImplementedError 35 | 36 | @abstractmethod 37 | def is_terminal(self): 38 | raise NotImplementedError 39 | 40 | def get_next_state(self, all_action): 41 | raise NotImplementedError 42 | 43 | def get_reward(self, all_action): 44 | raise NotImplementedError 45 | 46 | @abstractmethod 47 | def step(self, all_action): 48 | raise NotImplementedError 49 | 50 | @abstractmethod 51 | def reset(self): 52 | raise NotImplementedError 53 | 54 | def set_action_space(self): 55 | raise NotImplementedError 56 | -------------------------------------------------------------------------------- /openrl/envs/super_mario/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License.
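# make_super_mario_envs (below) mirrors the crafter factory above: each SuperMarioWrapper env
# is wrapped with DictWrapper, Single2MultiAgentWrapper, AutoReset and RemoveTruncated, and a
# list of env constructors is returned for the vectorized env to instantiate.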
16 | 17 | """""" 18 | 19 | from typing import Callable, List, Optional, Union 20 | 21 | from gymnasium import Env 22 | 23 | from openrl.envs.common import build_envs 24 | from openrl.envs.super_mario.super_mario_convert import SuperMarioWrapper 25 | 26 | 27 | def make_super_mario_envs( 28 | id: str, 29 | env_num: int = 1, 30 | render_mode: Optional[Union[str, List[str]]] = None, 31 | **kwargs, 32 | ) -> List[Callable[[], Env]]: 33 | from openrl.envs.wrappers import ( 34 | AutoReset, 35 | DictWrapper, 36 | RemoveTruncated, 37 | Single2MultiAgentWrapper, 38 | ) 39 | 40 | env_wrappers = [ 41 | DictWrapper, 42 | Single2MultiAgentWrapper, 43 | AutoReset, 44 | RemoveTruncated, 45 | ] 46 | 47 | env_fns = build_envs( 48 | make=SuperMarioWrapper, 49 | id=id, 50 | env_num=env_num, 51 | render_mode=render_mode, 52 | wrappers=env_wrappers, 53 | **kwargs, 54 | ) 55 | 56 | return env_fns 57 | -------------------------------------------------------------------------------- /openrl/envs/vec_env/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Type, Union 2 | 3 | from gymnasium import Env as GymEnv 4 | 5 | from openrl.envs.vec_env.async_venv import AsyncVectorEnv 6 | from openrl.envs.vec_env.base_venv import BaseVecEnv 7 | from openrl.envs.vec_env.sync_venv import SyncVectorEnv 8 | from openrl.envs.vec_env.wrappers.base_wrapper import VecEnvWrapper 9 | from openrl.envs.vec_env.wrappers.reward_wrapper import RewardWrapper 10 | from openrl.envs.vec_env.wrappers.vec_monitor_wrapper import VecMonitorWrapper 11 | 12 | __all__ = [ 13 | "BaseVecEnv", 14 | "SyncVectorEnv", 15 | "AsyncVectorEnv", 16 | "VecMonitorWrapper", 17 | "RewardWrapper", 18 | ] 19 | 20 | 21 | def unwrap_vec_wrapper( 22 | env: Union[GymEnv, BaseVecEnv], vec_wrapper_class: Type[VecEnvWrapper] 23 | ) -> Optional[VecEnvWrapper]: 24 | """ 25 | Retrieve a ``VecEnvWrapper`` object by recursively searching. 26 | 27 | :param env: 28 | :param vec_wrapper_class: 29 | :return: 30 | """ 31 | env_tmp = env 32 | while isinstance(env_tmp, VecEnvWrapper): 33 | if isinstance(env_tmp, vec_wrapper_class): 34 | return env_tmp 35 | env_tmp = env_tmp.venv 36 | return None 37 | 38 | 39 | def is_vecenv_wrapped( 40 | env: Union[GymEnv, BaseVecEnv], vec_wrapper_class: Type[VecEnvWrapper] 41 | ) -> bool: 42 | """ 43 | Check if an environment is already wrapped by a given ``VecEnvWrapper``. 44 | 45 | :param env: 46 | :param vec_wrapper_class: 47 | :return: 48 | """ 49 | return unwrap_vec_wrapper(env, vec_wrapper_class) is not None 50 | -------------------------------------------------------------------------------- /openrl/envs/vec_env/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/envs/vec_env/vec_info/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from openrl.envs.vec_env.base_venv import BaseVecEnv 4 | from openrl.envs.vec_env.vec_info.simple_vec_info import SimpleVecInfo 5 | 6 | registed_vec_info = { 7 | "default": SimpleVecInfo, 8 | } 9 | 10 | 11 | class VecInfoFactory: 12 | @staticmethod 13 | def get_vec_info_class(vec_info_class: Any, env: BaseVecEnv): 14 | VecInfoFactory.auto_register(vec_info_class) 15 | if vec_info_class is None or vec_info_class.id is None: 16 | return registed_vec_info["default"](env.parallel_env_num, env.agent_num) 17 | return registed_vec_info[vec_info_class.id]( 18 | env.parallel_env_num, env.agent_num, **vec_info_class.args 19 | ) 20 | 21 | @staticmethod 22 | def register(name: str, vec_info: Any): 23 | registed_vec_info[name] = vec_info 24 | 25 | @staticmethod 26 | def auto_register(vec_info_class: Any): 27 | if vec_info_class is None: 28 | return 29 | elif vec_info_class.id == "NLPVecInfo": 30 | from openrl.envs.vec_env.vec_info.nlp_vec_info import NLPVecInfo 31 | 32 | VecInfoFactory.register("NLPVecInfo", NLPVecInfo) 33 | elif vec_info_class.id == "EPS_RewardInfo": 34 | from openrl.envs.vec_env.vec_info.episode_rewards_info import EPS_RewardInfo 35 | 36 | VecInfoFactory.register("EPS_RewardInfo", EPS_RewardInfo) 37 | -------------------------------------------------------------------------------- /openrl/envs/vec_env/vec_info/base_vec_info.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Dict 3 | 4 | 5 | class BaseVecInfo(ABC): 6 | def __init__(self, parallel_env_num: int, agent_num: int): 7 | super(BaseVecInfo, self).__init__() 8 | self.parallel_env_num = parallel_env_num 9 | self.agent_num = agent_num 10 | 11 | @abstractmethod 12 | def statistics(self, buffer: Any) -> Dict[str, Any]: 13 | raise NotImplementedError 14 | 15 | @abstractmethod 16 | def append(self, info: Dict[str, Any]) -> None: 17 | raise NotImplementedError 18 | 19 | @abstractmethod 20 | def reset(self) -> None: 21 | raise NotImplementedError 22 | -------------------------------------------------------------------------------- /openrl/envs/vec_env/vec_info/simple_vec_info.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Any, Dict 3 | 4 | import numpy as np 5 | 6 | from openrl.envs.vec_env.vec_info.base_vec_info import BaseVecInfo 7 | 8 | 9 | class SimpleVecInfo(BaseVecInfo): 10 | def __init__(self, parallel_env_num: int, agent_num: int): 11 | super().__init__(parallel_env_num, agent_num) 12 | 13 | self.infos = [] 14 | 15 | self.start_time = time.time() 16 | self.total_step = 0 17 | 18 | def statistics(self, buffer: Any) -> Dict[str, Any]: 19 | # this function should be called each episode 20 | rewards = buffer.data.rewards.copy() 21 | self.total_step += np.prod(rewards.shape[:2]) 22 | rewards = rewards.transpose(2, 1, 0, 3) 23 | info_dict = {} 24 | ep_rewards = [] 25 | for i in range(self.agent_num): 26 | agent_reward = rewards[i].mean(0).sum() 27 | ep_rewards.append(agent_reward) 28 | info_dict["agent_{}/rollout_episode_reward".format(i)] = agent_reward 29 | 30 | info_dict["FPS"] = int(self.total_step / (time.time() - self.start_time)) 31 | info_dict["rollout_episode_reward"] = np.mean(ep_rewards) 
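# FPS here is cumulative env steps (rollout length x parallel envs, summed over calls)
# divided by wall-clock time since this tracker was created.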
32 | return info_dict 33 | 34 | def append(self, info: Dict[str, Any]) -> None: 35 | self.infos.append(info) 36 | 37 | def reset(self) -> None: 38 | self.infos = [] 39 | self.rewards = [] 40 | -------------------------------------------------------------------------------- /openrl/envs/vec_env/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/envs/vec_env/wrappers/vec_monitor_wrapper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | from typing import Any, Dict, Optional 19 | 20 | from gymnasium.core import ActType 21 | 22 | from openrl.envs.vec_env.base_venv import BaseVecEnv 23 | from openrl.envs.vec_env.vec_info.base_vec_info import BaseVecInfo 24 | from openrl.envs.vec_env.wrappers.base_wrapper import VecEnvWrapper 25 | 26 | 27 | class VecMonitorWrapper(VecEnvWrapper): 28 | def __init__(self, vec_info: BaseVecInfo, env: BaseVecEnv): 29 | super().__init__(env) 30 | self.vec_info = vec_info 31 | 32 | @property 33 | def use_monitor(self): 34 | return True 35 | 36 | def step(self, action: ActType, extra_data: Optional[Dict[str, Any]] = None): 37 | returns = self.env.step(action, extra_data) 38 | 39 | self.vec_info.append(info=returns[-1]) 40 | 41 | return returns 42 | 43 | def statistics(self, buffer): # TODO 44 | # this function should be called each episode 45 | info_dict = self.vec_info.statistics(buffer) 46 | self.vec_info.reset() 47 | return info_dict 48 | -------------------------------------------------------------------------------- /openrl/envs/vec_env/wrappers/zero_reward_wrapper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | import numpy as np 20 | 21 | from openrl.envs.vec_env.wrappers.base_wrapper import ArrayType, VectorRewardWrapper 22 | 23 | 24 | class ZeroRewardWrapper(VectorRewardWrapper): 25 | def reward(self, reward: ArrayType) -> ArrayType: 26 | return np.zeros_like(reward) 27 | -------------------------------------------------------------------------------- /openrl/envs/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_wrapper import BaseObservationWrapper, BaseRewardWrapper, BaseWrapper 2 | from .extra_wrappers import ( 3 | AutoReset, 4 | DictWrapper, 5 | FlattenObservation, 6 | GIFWrapper, 7 | MoveActionMask2InfoWrapper, 8 | RemoveTruncated, 9 | ) 10 | from .multiagent_wrapper import Single2MultiAgentWrapper 11 | 12 | __all__ = [ 13 | "BaseWrapper", 14 | "DictWrapper", 15 | "BaseObservationWrapper", 16 | "Single2MultiAgentWrapper", 17 | "AutoReset", 18 | "RemoveTruncated", 19 | "GIFWrapper", 20 | "BaseRewardWrapper", 21 | "MoveActionMask2InfoWrapper", 22 | "FlattenObservation", 23 | ] 24 | -------------------------------------------------------------------------------- /openrl/envs/wrappers/image_wrappers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
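# TransposeImage (below) permutes image observations, by default from HWC to CHW (op=[2, 0, 1]),
# and rebuilds the observation-space Box to match, the usual layout change for feeding frames
# to convolutional encoders.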
16 | 17 | """""" 18 | 19 | from gymnasium.spaces.box import Box 20 | 21 | from openrl.envs.wrappers.base_wrapper import BaseObservationWrapper 22 | 23 | 24 | class TransposeImage(BaseObservationWrapper): 25 | def __init__(self, env=None, op=[2, 0, 1]): 26 | """ 27 | Transpose observation space for images 28 | """ 29 | super(TransposeImage, self).__init__(env) 30 | assert len(op) == 3, "Error: Operation, " + str(op) + ", must be dim3" 31 | self.op = op 32 | obs_shape = self.observation_space.shape 33 | self.observation_space = Box( 34 | self.observation_space.low[0, 0, 0], 35 | self.observation_space.high[0, 0, 0], 36 | [obs_shape[self.op[0]], obs_shape[self.op[1]], obs_shape[self.op[2]]], 37 | dtype=self.observation_space.dtype, 38 | ) 39 | 40 | def observation(self, ob): 41 | return ob.transpose(self.op[0], self.op[1], self.op[2]) 42 | -------------------------------------------------------------------------------- /openrl/modules/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/modules/base_module.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
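# BaseModule (below) keeps named networks in self.models and their optimizers in self.optimizers;
# convert_distributed_model wraps every model in SyncBatchNorm + DistributedDataParallel for
# multi-GPU training (it assumes self.device has been set by the concrete subclass).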
16 | 17 | """""" 18 | from abc import ABC, abstractmethod 19 | 20 | import torch 21 | from torch.nn.parallel import DistributedDataParallel as DDP 22 | 23 | 24 | class BaseModule(ABC): 25 | def __init__(self, cfg) -> None: 26 | self.cfg = cfg 27 | self.models = {} 28 | self.optimizers = {} 29 | 30 | @abstractmethod 31 | def lr_decay(self, episode: int, episodes: int) -> None: 32 | raise NotImplementedError 33 | 34 | @abstractmethod 35 | def restore(self, model_dir: str) -> None: 36 | raise NotImplementedError 37 | 38 | @abstractmethod 39 | def save(self, save_dir: str) -> None: 40 | raise NotImplementedError 41 | 42 | def convert_distributed_model(self) -> None: 43 | for model_name in self.models: 44 | model = self.models[model_name] 45 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 46 | model = DDP(model, device_ids=[self.device], find_unused_parameters=True) 47 | self.models[model_name] = model 48 | -------------------------------------------------------------------------------- /openrl/modules/common/__init__.py: -------------------------------------------------------------------------------- 1 | from .a2c_net import A2CNet 2 | from .base_net import BaseNet 3 | from .bc_net import BCNet 4 | from .ddpg_net import DDPGNet 5 | from .dqn_net import DQNNet 6 | from .gail_net import GAILNet 7 | from .mat_net import MATNet 8 | from .ppo_net import PPONet 9 | from .sac_net import SACNet 10 | from .vdn_net import VDNNet 11 | 12 | __all__ = [ 13 | "BaseNet", 14 | "PPONet", 15 | "DQNNet", 16 | "MATNet", 17 | "DDPGNet", 18 | "VDNNet", 19 | "GAILNet", 20 | "BCNet", 21 | "SACNet", 22 | "A2CNet", 23 | ] 24 | -------------------------------------------------------------------------------- /openrl/modules/common/a2c_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | from openrl.modules.common.ppo_net import PPONet 19 | 20 | 21 | class A2CNet(PPONet): 22 | pass 23 | -------------------------------------------------------------------------------- /openrl/modules/common/base_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """""" 18 | 19 | from abc import ABC 20 | 21 | 22 | class BaseNet(ABC): 23 | def __init__(self): 24 | self.first_reset = False 25 | -------------------------------------------------------------------------------- /openrl/modules/common/bc_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | from typing import Any, Dict, Optional, Union 20 | 21 | import torch 22 | 23 | from openrl.envs.vec_env.wrappers.reward_wrapper import RewardWrapper 24 | from openrl.modules.base_module import BaseModule 25 | from openrl.modules.bc_module import BCModule 26 | from openrl.modules.common.ppo_net import PPONet 27 | 28 | # Network for Behavior Cloning 29 | 30 | 31 | class BCNet(PPONet): 32 | def __init__( 33 | self, 34 | env: RewardWrapper, 35 | cfg=None, 36 | device: Union[torch.device, str] = "cpu", 37 | n_rollout_threads: int = 1, 38 | model_dict: Optional[Dict[str, Any]] = None, 39 | module_class: type(BaseModule) = BCModule, 40 | ) -> None: 41 | super().__init__(env, cfg, device, n_rollout_threads, model_dict, module_class) 42 | self.env = env 43 | -------------------------------------------------------------------------------- /openrl/modules/common/gail_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
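# GAILNet (below) is a PPONet whose module also builds a "gail_discriminator" network; on
# construction it checks that the env's reward_class is a GAILReward and hands the discriminator
# over, so the reward wrapper can score transitions with it (the usual GAIL setup).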
16 | 17 | """""" 18 | 19 | from typing import Any, Dict, Optional, Union 20 | 21 | import torch 22 | 23 | from openrl.envs.vec_env.wrappers.reward_wrapper import RewardWrapper 24 | from openrl.modules.base_module import BaseModule 25 | from openrl.modules.common.ppo_net import PPONet 26 | from openrl.modules.gail_module import GAILModule 27 | from openrl.rewards.gail_reward import GAILReward 28 | 29 | 30 | class GAILNet(PPONet): 31 | def __init__( 32 | self, 33 | env: RewardWrapper, 34 | cfg=None, 35 | device: Union[torch.device, str] = "cpu", 36 | n_rollout_threads: int = 1, 37 | model_dict: Optional[Dict[str, Any]] = None, 38 | module_class: type(BaseModule) = GAILModule, 39 | ) -> None: 40 | super().__init__(env, cfg, device, n_rollout_threads, model_dict, module_class) 41 | assert isinstance( 42 | env.reward_class, GAILReward 43 | ), "env.reward_class must be GAILReward, but got {}".format(env.reward_class) 44 | env.reward_class.set_discriminator( 45 | self.cfg, self.module.models["gail_discriminator"] 46 | ) 47 | 48 | self.env = env 49 | -------------------------------------------------------------------------------- /openrl/modules/common/mat_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | from typing import Any, Dict, Union 20 | 21 | import gymnasium as gym 22 | import torch 23 | 24 | from openrl.modules.common.ppo_net import PPONet 25 | from openrl.modules.networks.MAT_network import MultiAgentTransformer 26 | 27 | 28 | class MATNet(PPONet): 29 | def __init__( 30 | self, 31 | env: Union[gym.Env, str], 32 | cfg=None, 33 | device: Union[torch.device, str] = "cpu", 34 | n_rollout_threads: int = 1, 35 | model_dict: Dict[str, Any] = {"model": MultiAgentTransformer}, 36 | ) -> None: 37 | cfg.use_share_model = True 38 | super().__init__( 39 | env=env, 40 | cfg=cfg, 41 | device=device, 42 | n_rollout_threads=n_rollout_threads, 43 | model_dict=model_dict, 44 | ) 45 | -------------------------------------------------------------------------------- /openrl/modules/model_config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
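# ModelConfig / ModelTrainConfig (below) are thin dict subclasses: ModelTrainConfig records the
# network class, its input space and an optional learning rate under the keys "model",
# "input_space" and "lr".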
16 | 17 | """""" 18 | 19 | from typing import Optional 20 | 21 | import gym 22 | import torch 23 | 24 | 25 | class ModelConfig(dict): 26 | def __init__(self, *args, **kwargs) -> None: 27 | super(ModelConfig, self).__init__(*args, **kwargs) 28 | 29 | 30 | class ModelTrainConfig(ModelConfig): 31 | def __init__( 32 | self, 33 | model: torch.nn.Module, 34 | input_space: gym.spaces.Box, 35 | lr: Optional[float] = None, 36 | *args, 37 | **kwargs 38 | ) -> None: 39 | super(ModelTrainConfig, self).__init__(*args, **kwargs) 40 | self["model"] = model 41 | self["input_space"] = input_space 42 | self["lr"] = lr 43 | -------------------------------------------------------------------------------- /openrl/modules/networks/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/modules/networks/base_policy_network.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2022 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | import torch.nn as nn 19 | 20 | from openrl.modules.utils.valuenorm import ValueNorm 21 | 22 | 23 | class BasePolicyNetwork(nn.Module): 24 | def __init__(self, cfg, device): 25 | super(BasePolicyNetwork, self).__init__() 26 | self.device = device 27 | self._use_valuenorm = cfg.use_valuenorm 28 | self._use_policy_vhead = cfg.use_policy_vhead 29 | 30 | if self._use_valuenorm: 31 | if self._use_policy_vhead: 32 | self.policy_value_normalizer = ValueNorm(1, device=self.device) 33 | else: 34 | if self._use_policy_vhead: 35 | self.policy_value_normalizer = None 36 | -------------------------------------------------------------------------------- /openrl/modules/networks/base_value_network.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2022 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | from abc import ABC, abstractmethod 19 | 20 | import torch.nn as nn 21 | 22 | from openrl.modules.utils.valuenorm import ValueNorm 23 | 24 | 25 | class BaseValueNetwork(ABC, nn.Module): 26 | def __init__(self, cfg, device): 27 | super(BaseValueNetwork, self).__init__() 28 | self.device = device 29 | self._use_valuenorm = cfg.use_valuenorm 30 | 31 | if self._use_valuenorm: 32 | self.value_normalizer = ValueNorm(1, device=self.device) 33 | else: 34 | self.value_normalizer = None 35 | 36 | @abstractmethod 37 | def forward(self): 38 | raise NotImplementedError 39 | -------------------------------------------------------------------------------- /openrl/modules/networks/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRL-Lab/openrl/4c92aa440c12a6834c3fd76c574a341480868436/openrl/modules/networks/utils/__init__.py -------------------------------------------------------------------------------- /openrl/modules/networks/utils/distributed_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2022 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """""" 18 | 19 | from torch import distributed as dist 20 | 21 | 22 | def reduce_tensor(tensor, n): 23 | rt = tensor.clone() 24 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 25 | rt /= n 26 | return rt 27 | -------------------------------------------------------------------------------- /openrl/modules/networks/utils/nlp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenRL-Lab/openrl/4c92aa440c12a6834c3fd76c574a341480868436/openrl/modules/networks/utils/nlp/__init__.py -------------------------------------------------------------------------------- /openrl/modules/networks/utils/util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def init(module, weight_init, bias_init, gain=1): 7 | weight_init(module.weight.data, gain=gain) 8 | if module.bias is not None: 9 | bias_init(module.bias.data) 10 | return module 11 | 12 | 13 | def get_clones(module, N): 14 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 15 | -------------------------------------------------------------------------------- /openrl/modules/networks/utils/vdn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class VDNBase(nn.Module): 6 | def __init__(self): 7 | super(VDNBase, self).__init__() 8 | 9 | def forward(self, agent_qs): 10 | return torch.sum(agent_qs, dim=1, keepdim=True) 11 | -------------------------------------------------------------------------------- /openrl/modules/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/modules/utils/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def get_grad_norm(it): 5 | sum_grad = 0 6 | for x in it: 7 | if x.grad is None: 8 | continue 9 | sum_grad += x.grad.norm() ** 2 10 | return math.sqrt(sum_grad) 11 | 12 | 13 | def update_linear_schedule(optimizer, epoch, total_num_epochs, initial_lr): 14 | """Decreases the learning rate linearly""" 15 | lr = initial_lr - (initial_lr * (epoch / float(total_num_epochs))) 16 | for param_group in optimizer.param_groups: 17 | param_group["lr"] = lr 18 | 19 | 20 | def huber_loss(e, d): 21 | a = (abs(e) <= d).float() 22 | b = (abs(e) > d).float() 23 | return a * e**2 / 2 + b * d * (abs(e) - d / 2) 24 | 25 | 26 | def mse_loss(e): 27 | return e**2 / 2 28 | -------------------------------------------------------------------------------- /openrl/rewards/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from openrl.envs.vec_env.base_venv import BaseVecEnv 4 | from openrl.rewards.base_reward import BaseReward 5 | 6 | registed_rewards = { 7 | "default": BaseReward, 8 | } 9 | 10 | 11 | class RewardFactory: 12 | @staticmethod 13 | def get_reward_class(reward_class: Any, env: BaseVecEnv): 14 | RewardFactory.auto_register(reward_class) 15 | if reward_class is None or reward_class.id is None: 16 | return registed_rewards["default"](env) 17 | return registed_rewards[reward_class.id](env, **reward_class.args) 18 | 19 | @staticmethod 20 | def register(reward_name, reward_class): 21 | registed_rewards.update({reward_name: reward_class}) 22 | 23 | @staticmethod 24 | def auto_register(reward_class: Any): 25 | if reward_class is None: 26 | return 27 | if reward_class.id == "NLPReward": 28 | from openrl.rewards.nlp_reward import NLPReward 29 | 30 | registed_rewards.update({"NLPReward": NLPReward}) 31 | elif reward_class.id == "GAILReward": 32 | from openrl.rewards.gail_reward import GAILReward 33 | 34 | registed_rewards.update({"GAILReward": GAILReward}) 35 | -------------------------------------------------------------------------------- /openrl/rewards/base_reward.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Union 2 | 3 | import numpy as np 4 | 5 | from openrl.envs.vec_env.base_venv import BaseVecEnv 6 | 7 | 8 | class BaseReward(object): 9 | def __init__(self, env: BaseVecEnv): 10 | self.step_rew_funcs = dict() 11 | self.inner_rew_funcs = dict() 12 | self.batch_rew_funcs = dict() 13 | 14 | def step_reward( 15 | self, data: Dict[str, Any] 16 | ) -> Union[np.ndarray, List[Dict[str, Any]]]: 17 | rewards = data["rewards"].copy() 18 | infos = [dict() for _ in range(rewards.shape[0])] 19 | 20 | return rewards, infos 21 | 22 | def batch_rewards(self, buffer: Any) -> Dict[str, Any]: 23 | infos = dict() 24 | 25 | return infos 26 | -------------------------------------------------------------------------------- /openrl/runners/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/runners/common/__init__.py: -------------------------------------------------------------------------------- 1 | from openrl.runners.common.a2c_agent import A2CAgent 2 | from openrl.runners.common.bc_agent import BCAgent 3 | from openrl.runners.common.chat_agent import Chat6BAgent, ChatAgent 4 | from openrl.runners.common.ddpg_agent import DDPGAgent 5 | from openrl.runners.common.dqn_agent import DQNAgent 6 | from openrl.runners.common.gail_agent import GAILAgent 7 | from openrl.runners.common.mat_agent import MATAgent 8 | from openrl.runners.common.ppo_agent import PPOAgent 9 | from openrl.runners.common.sac_agent import SACAgent 10 | from openrl.runners.common.vdn_agent import VDNAgent 11 | 12 | __all__ = [ 13 | "PPOAgent", 14 | "ChatAgent", 15 | "Chat6BAgent", 16 | "DQNAgent", 17 | "DDPGAgent", 18 | "MATAgent", 19 | "VDNAgent", 20 | "GAILAgent", 21 | "BCAgent", 22 | "SACAgent", 23 | "A2CAgent", 24 | ] 25 | -------------------------------------------------------------------------------- /openrl/runners/common/mat_agent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """""" 18 | from typing import Type 19 | 20 | from openrl.algorithms.base_algorithm import BaseAlgorithm 21 | from openrl.algorithms.mat import MATAlgorithm 22 | from openrl.runners.common.base_agent import SelfAgent 23 | from openrl.runners.common.ppo_agent import PPOAgent 24 | from openrl.utils.logger import Logger 25 | 26 | 27 | class MATAgent(PPOAgent): 28 | def train( 29 | self: SelfAgent, 30 | total_time_steps: int, 31 | train_algo_class: Type[BaseAlgorithm] = MATAlgorithm, 32 | ) -> None: 33 | logger = Logger( 34 | cfg=self._cfg, 35 | project_name="MATAgent", 36 | scenario_name=self._env.env_name, 37 | wandb_entity=self._cfg.wandb_entity, 38 | exp_name=self.exp_name, 39 | log_path=self.run_dir, 40 | use_wandb=self._use_wandb, 41 | use_tensorboard=self._use_tensorboard, 42 | ) 43 | 44 | super(MATAgent, self).train( 45 | total_time_steps=total_time_steps, 46 | train_algo_class=train_algo_class, 47 | logger=logger, 48 | ) 49 | -------------------------------------------------------------------------------- /openrl/selfplay/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/selfplay/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/selfplay/callbacks/base_callback.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | 20 | from openrl.utils.callbacks.callbacks import BaseCallback 21 | 22 | 23 | class BaseSelfplayCallback(BaseCallback): 24 | pass 25 | -------------------------------------------------------------------------------- /openrl/selfplay/multiplayer_env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/selfplay/opponents/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/selfplay/sample_strategy/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """""" 18 | from openrl.selfplay.sample_strategy.last_opponent import LastOpponent 19 | from openrl.selfplay.sample_strategy.random_opponent import RandomOpponent 20 | 21 | sample_strategy_dict = { 22 | "LastOpponent": LastOpponent, 23 | "RandomOpponent": RandomOpponent, 24 | } 25 | 26 | 27 | class SampleStrategyFactory: 28 | def __init__(self): 29 | pass 30 | 31 | @staticmethod 32 | def register_sample_strategy(name, sample_strategy): 33 | sample_strategy_dict[name] = sample_strategy 34 | 35 | @staticmethod 36 | def get_sample_strategy(name): 37 | return sample_strategy_dict[name] 38 | -------------------------------------------------------------------------------- /openrl/selfplay/sample_strategy/base_sample_strategy.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | 20 | from abc import ABC, abstractmethod 21 | 22 | from openrl.selfplay.opponents.base_opponent import BaseOpponent 23 | from openrl.utils.custom_data_structure import ListDict 24 | 25 | 26 | class BaseSampleStrategy(ABC): 27 | def __init__(self): 28 | pass 29 | 30 | @abstractmethod 31 | def sample_opponent(self, opponents: ListDict) -> BaseOpponent: 32 | raise NotImplementedError 33 | -------------------------------------------------------------------------------- /openrl/selfplay/sample_strategy/last_opponent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """""" 18 | 19 | from openrl.selfplay.opponents.base_opponent import BaseOpponent 20 | from openrl.selfplay.sample_strategy.base_sample_strategy import BaseSampleStrategy 21 | from openrl.utils.custom_data_structure import ListDict 22 | 23 | 24 | class LastOpponent(BaseSampleStrategy): 25 | def sample_opponent(self, opponents: ListDict) -> BaseOpponent: 26 | opponent_index = -1 27 | return opponents.get_by_index(opponent_index) 28 | -------------------------------------------------------------------------------- /openrl/selfplay/sample_strategy/random_opponent.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | import random 20 | 21 | from openrl.selfplay.opponents.base_opponent import BaseOpponent 22 | from openrl.selfplay.sample_strategy.base_sample_strategy import BaseSampleStrategy 23 | 24 | 25 | class RandomOpponent(BaseSampleStrategy): 26 | def sample_opponent(self, opponents) -> BaseOpponent: 27 | opponent_index = random.randint(0, len(opponents) - 1) 28 | return opponents.get_by_index(opponent_index) 29 | -------------------------------------------------------------------------------- /openrl/selfplay/selfplay_api/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/selfplay/selfplay_api/base_api.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | import logging 20 | from abc import ABC 21 | from typing import Any, Dict, Optional 22 | 23 | from fastapi import FastAPI 24 | from pydantic import BaseModel 25 | 26 | from openrl.selfplay.selfplay_api.opponent_model import OpponentModel 27 | from openrl.utils.custom_data_structure import ListDict 28 | 29 | 30 | class OpponentData(BaseModel): 31 | opponent_id: str 32 | opponent_info: Dict[str, str] 33 | 34 | 35 | class SkillData(BaseModel): 36 | opponent_id: str 37 | other_id: str 38 | result: int 39 | 40 | 41 | class SampleStrategyData(BaseModel): 42 | sample_strategy: str 43 | 44 | 45 | class BattleData(BaseModel): 46 | battle_info: Dict[str, Any] 47 | 48 | 49 | app = FastAPI() 50 | 51 | 52 | class BaseSelfplayAPIServer(ABC): 53 | def __init__(self): 54 | logger = logging.getLogger("ray.serve") 55 | logger.setLevel(logging.ERROR) 56 | self.opponents = ListDict() 57 | self.training_agent = OpponentModel("training_agent") 58 | self.sample_strategy = None 59 | -------------------------------------------------------------------------------- /openrl/selfplay/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/selfplay/wrappers/human_opponent_wrapper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """""" 18 | import copy 19 | from typing import Optional 20 | 21 | from openrl.selfplay.wrappers.base_multiplayer_wrapper import BaseMultiPlayerWrapper 22 | 23 | 24 | class HumanOpponentWrapper(BaseMultiPlayerWrapper): 25 | def get_opponent_action( 26 | self, player_name, observation, reward, termination, truncation, info 27 | ): 28 | action = self.env.get_human_action( 29 | player_name, observation, termination, truncation, info 30 | ) 31 | action = [action] 32 | return action 33 | -------------------------------------------------------------------------------- /openrl/selfplay/wrappers/random_opponent_wrapper.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | import copy 19 | from typing import Optional, Union 20 | 21 | import numpy as np 22 | from gymnasium import spaces 23 | from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType 24 | 25 | from openrl.selfplay.wrappers.base_multiplayer_wrapper import BaseMultiPlayerWrapper 26 | 27 | 28 | class RandomOpponentWrapper(BaseMultiPlayerWrapper): 29 | def get_opponent_action( 30 | self, player_name, observation, reward, termination, truncation, info 31 | ): 32 | mask = None 33 | if isinstance(observation, dict) and "action_mask" in observation: 34 | mask = observation["action_mask"] 35 | action_space = self.env.action_space(player_name) 36 | 37 | if isinstance(action_space, list): 38 | action = [] 39 | for space in action_space: 40 | action.append(space.sample(mask)) 41 | else: 42 | action = action_space.sample(mask) 43 | 44 | return action 45 | -------------------------------------------------------------------------------- /openrl/supports/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/supports/opendata/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 
4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/supports/opendata/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/supports/opengpu/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """""" 18 | -------------------------------------------------------------------------------- /openrl/utils/callbacks/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | 20 | from openrl.utils.callbacks.callbacks_factory import CallbackFactory 21 | 22 | __all__ = ["CallbackFactory"] 23 | -------------------------------------------------------------------------------- /openrl/utils/custom_data_structure.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | from collections import OrderedDict 19 | from typing import Any 20 | 21 | 22 | class ListDict(OrderedDict): 23 | def append(self, key: str, value: Any): 24 | self[key] = value 25 | 26 | def get_by_index(self, index): 27 | key = list(self.keys())[index] 28 | return self[key] 29 | -------------------------------------------------------------------------------- /openrl/utils/file_tool.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """""" 18 | 19 | from pathlib import Path 20 | from typing import List, Union 21 | 22 | 23 | def copy_files( 24 | source_files: Union[List[str], List[Path]], target_dir: Union[str, Path] 25 | ): 26 | if isinstance(target_dir, str): 27 | target_dir = Path(target_dir) 28 | target_dir.mkdir(parents=True, exist_ok=True) 29 | for source_file in source_files: 30 | if isinstance(source_file, str): 31 | source_file = Path(source_file) 32 | target_file = target_dir / source_file.name 33 | target_file.write_text(source_file.read_text()) 34 | 35 | 36 | def link_files( 37 | source_files: Union[List[str], List[Path]], target_dir: Union[str, Path] 38 | ): 39 | if isinstance(target_dir, str): 40 | target_dir = Path(target_dir) 41 | target_dir.mkdir(parents=True, exist_ok=True) 42 | for source_file in source_files: 43 | if isinstance(source_file, str): 44 | source_file = Path(source_file) 45 | target_file = target_dir / source_file.name 46 | target_file.symlink_to(source_file) 47 | -------------------------------------------------------------------------------- /openrl/utils/type_aliases.py: -------------------------------------------------------------------------------- 1 | # Modifed from https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/type_aliases.py 2 | 3 | """Common aliases for type hints""" 4 | 5 | from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union 6 | 7 | import gym 8 | import numpy as np 9 | import torch as th 10 | 11 | from openrl.envs import vec_env 12 | from openrl.utils.callbacks import callbacks 13 | 14 | GymEnv = Union[gym.Env, vec_env.BaseVecEnv] 15 | GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] 16 | 17 | TensorDict = Dict[Union[str, int], th.Tensor] 18 | OptimizerStateDict = Dict[str, Any] 19 | MaybeCallback = Union[ 20 | None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback 21 | ] 22 | 23 | 24 | class AgentActor(Protocol): 25 | def act( 26 | self, 27 | observation: Union[np.ndarray, Dict[str, np.ndarray]], 28 | deterministic: bool = False, 29 | ) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]: 30 | """ 31 | Get the policy action from an observation (and optional hidden state). 32 | Includes sugar-coating to handle different observations (e.g. normalizing images). 33 | 34 | :param observation: the input observation 35 | :param deterministic: Whether to return deterministic actions. 
36 | :return: the model's action and the next hidden state 37 | (used in recurrent policies) 38 | """ 39 | -------------------------------------------------------------------------------- /openrl/utils/util.py: -------------------------------------------------------------------------------- 1 | import platform 2 | import random 3 | import re 4 | from typing import Dict 5 | 6 | import gymnasium as gym 7 | import numpy as np 8 | import torch 9 | 10 | import openrl 11 | 12 | 13 | def set_seed(seed): 14 | random.seed(seed) 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | 20 | def check(input): 21 | output = torch.from_numpy(input) if type(input) == np.ndarray else input 22 | return output 23 | 24 | 25 | def check_v2(input, use_half=False, tpdv=None): 26 | output = torch.from_numpy(input) if type(input) == np.ndarray else input 27 | if tpdv: 28 | output = output.to(**tpdv) 29 | if use_half: 30 | output = output.half() 31 | return output 32 | 33 | 34 | def _t2n(x): 35 | if isinstance(x, torch.Tensor): 36 | return x.detach().cpu().numpy() 37 | else: 38 | return x 39 | 40 | 41 | def get_system_info() -> Dict[str, str]: 42 | """ 43 | Retrieve system and python env info for the current system. 44 | 45 | :return: Dictionary summing up the version for each relevant package 46 | and a formatted string. 47 | """ 48 | 49 | env_info = { 50 | # In OS, a regex is used to add a space between a "#" and a number to avoid 51 | # wrongly linking to another issue on GitHub. 52 | "OS": re.sub(r"#(\d)", r"# \1", f"{platform.platform()} {platform.version()}"), 53 | "Python": platform.python_version(), 54 | "OpenRL": openrl.__version__, 55 | "PyTorch": torch.__version__, 56 | "GPU Enabled": str(torch.cuda.is_available()), 57 | "Numpy": np.__version__, 58 | "Gymnasium": gym.__version__, 59 | } 60 | return env_info 61 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | markers = 3 | unittest: for unittests 4 | -------------------------------------------------------------------------------- /scripts/build_docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CPU_PARENT=cnstark/pytorch:2.0.0-py3.9.12-ubuntu20.04 4 | GPU_PARENT=cnstark/pytorch:2.0.0-py3.9.12-cuda11.8.0-ubuntu22.04 5 | 6 | TAG=openrllab/openrl 7 | VERSION=$(python -c "from openrl.__init__ import __version__;print(__version__)") 8 | 9 | if [[ ${USE_GPU} == "True" ]]; then 10 | PARENT=${GPU_PARENT} 11 | TAG="${TAG}" 12 | else 13 | PARENT=${CPU_PARENT} 14 | TAG="${TAG}-cpu" 15 | fi 16 | 17 | echo "docker build --build-arg PARENT_IMAGE=${PARENT} -t ${TAG}:${VERSION} . -f docker/Dockerfile" 18 | docker build --build-arg PARENT_IMAGE=${PARENT} -t ${TAG}:${VERSION} . -f docker/Dockerfile 19 | docker tag ${TAG}:${VERSION} ${TAG}:latest 20 | 21 | if [[ ${RELEASE} == "True" ]]; then 22 | docker push ${TAG}:${VERSION} 23 | docker push ${TAG}:latest 24 | fi 25 | -------------------------------------------------------------------------------- /scripts/conda_build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | conda build . 
-------------------------------------------------------------------------------- /scripts/conda_upload.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export PATH="~/anaconda3/bin:$PATH" 4 | 5 | VERSION=$(python setup.py --version) 6 | echo $VERSION 7 | deactivate 8 | conda init zsh 9 | conda activate base 10 | anaconda upload --user openrl ~/anaconda3/conda-bld/osx-64/openrl-v${VERSION}-py38_0.tar.bz2 11 | -------------------------------------------------------------------------------- /scripts/gen_api_docs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | rm -r ./api_docs 4 | sphinx-apidoc -o ./api_docs openrl --force -H OpenRL -A OpenRL_Contributors 5 | python scripts/modify_api_docs.py -------------------------------------------------------------------------------- /scripts/pypi_build.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | rm -r ./dist/ 4 | rm -r ./build/ 5 | python -m build -------------------------------------------------------------------------------- /scripts/pypi_upload.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | if [[ $1 = "test" ]]; then 4 | twine upload dist/* -r testpypi 5 | else 6 | twine upload dist/* 7 | fi -------------------------------------------------------------------------------- /scripts/unittest.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | pytest tests --cov=openrl --cov-report=xml -m unittest --cov-report=term-missing --durations=0 -v --color=yes 4 | -------------------------------------------------------------------------------- /tests/project/test_version.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | import os 19 | import sys 20 | 21 | import pytest 22 | 23 | 24 | @pytest.mark.unittest 25 | def test_version(): 26 | import openrl 27 | 28 | assert hasattr(openrl, "__version__"), "openrl has no __version__ attribute" 29 | print(hasattr(openrl, "__version__"), openrl.__version__) 30 | 31 | 32 | if __name__ == "__main__": 33 | sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) 34 | -------------------------------------------------------------------------------- /tests/test_cli/test_cli.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | import os 20 | import sys 21 | 22 | import pytest 23 | 24 | 25 | @pytest.mark.unittest 26 | def test_version(): 27 | from openrl.cli.cli import print_version 28 | 29 | print_version(ctx=None, param=None, value=False) 30 | 31 | 32 | @pytest.mark.unittest 33 | def test_train(): 34 | from openrl.cli.train import train_agent 35 | 36 | train_agent(env="CartPole-v1", total_time_steps=1) 37 | 38 | 39 | if __name__ == "__main__": 40 | sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) 41 | -------------------------------------------------------------------------------- /tests/test_env/test_connect_env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | import os 19 | import sys 20 | 21 | import pytest 22 | 23 | 24 | @pytest.mark.unittest 25 | def test_connect3(): 26 | from openrl.envs.common import make 27 | 28 | env = make("connect3", env_num=6) 29 | obs, _ = env.reset() 30 | obs, reward, done, info = env.step(env.random_action()) 31 | env.close() 32 | 33 | 34 | @pytest.mark.unittest 35 | def test_connect4(): 36 | from openrl.envs.common import make 37 | 38 | env = make("connect4", env_num=6) 39 | obs, _ = env.reset() 40 | obs, reward, done, info = env.step(env.random_action()) 41 | env.close() 42 | 43 | 44 | if __name__ == "__main__": 45 | sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) 46 | -------------------------------------------------------------------------------- /tests/test_env/test_gridworld_env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | """""" 18 | import os 19 | import sys 20 | 21 | import pytest 22 | 23 | 24 | @pytest.mark.unittest 25 | def test_gridworld(): 26 | from openrl.envs.common import make 27 | 28 | env = make("GridWorldEnv", env_num=6) 29 | obs, info = env.reset() 30 | obs, reward, done, info = env.step(env.random_action()) 31 | env.close() 32 | 33 | 34 | @pytest.mark.unittest 35 | def test_gridworldrandom(): 36 | from openrl.envs.common import make 37 | 38 | env = make("GridWorldEnvRandomGoal", env_num=6) 39 | obs, info = env.reset() 40 | obs, reward, done, info = env.step(env.random_action()) 41 | env.close() 42 | 43 | 44 | if __name__ == "__main__": 45 | sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) 46 | -------------------------------------------------------------------------------- /tests/test_env/test_mpe_env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | import os 19 | import sys 20 | 21 | import numpy as np 22 | import pytest 23 | 24 | from openrl.envs.common import make 25 | 26 | 27 | @pytest.mark.unittest 28 | def test_mpe(): 29 | env_num = 3 30 | env = make("simple_spread", env_num=env_num) 31 | obs, info = env.reset() 32 | obs, reward, done, info = env.step(env.random_action()) 33 | assert env.agent_num == 3 34 | assert env.parallel_env_num == env_num 35 | env.close() 36 | 37 | 38 | @pytest.mark.unittest 39 | def test_mpe_render(): 40 | render_model = "human" 41 | env_num = 2 42 | env = make( 43 | "simple_spread", render_mode=render_model, env_num=env_num, asynchronous=False 44 | ) 45 | 46 | env.reset(seed=0) 47 | done = False 48 | step = 0 49 | total_reward = 0 50 | while not np.any(done): 51 | # Based on environmental observation input, predict next action. 52 | 53 | obs, r, done, info = env.step(env.random_action()) 54 | step += 1 55 | total_reward += np.mean(r) 56 | 57 | env.close() 58 | 59 | 60 | if __name__ == "__main__": 61 | sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) 62 | -------------------------------------------------------------------------------- /tests/test_env/test_nlp/test_DailyDialogEnv.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | 19 | import os 20 | import sys 21 | 22 | import pytest 23 | 24 | from openrl.configs.config import create_config_parser 25 | from openrl.envs.common import make 26 | 27 | 28 | @pytest.fixture( 29 | scope="module", 30 | params=["--env.args {'data_path':None,'tokenizer_path':'builtin_BPE'}"], 31 | ) 32 | def config(request): 33 | cfg_parser = create_config_parser() 34 | cfg = cfg_parser.parse_args(request.param.split()) 35 | return cfg 36 | 37 | 38 | @pytest.mark.unittest 39 | def test_DailyDialogEnv(config): 40 | env = make("daily_dialog", env_num=1, asynchronous=False, cfg=config) 41 | 42 | 43 | if __name__ == "__main__": 44 | sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) 45 | -------------------------------------------------------------------------------- /tests/test_env/test_super_mario_env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | """""" 18 | import os 19 | import sys 20 | 21 | import pytest 22 | 23 | 24 | @pytest.mark.unittest 25 | def test_super_mario(): 26 | from openrl.envs.common import make 27 | 28 | env_num = 2 29 | env = make("SuperMarioBros-1-1-v1", env_num=env_num) 30 | obs, info = env.reset() 31 | obs, reward, done, info = env.step(env.random_action()) 32 | 33 | assert obs["critic"].shape[2] == 3 34 | assert env.parallel_env_num == env_num 35 | 36 | env.close() 37 | 38 | 39 | if __name__ == "__main__": 40 | sys.exit(pytest.main(["-sv", os.path.basename(__file__)])) 41 | -------------------------------------------------------------------------------- /tests/test_env/test_vec_env/test_sync_env.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | # Copyright 2023 The OpenRL Authors. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # https://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 

""""""

import os
import sys

import pytest
from gymnasium.wrappers import EnvCompatibility

from openrl.envs.toy_envs import make_toy_envs
from openrl.envs.vec_env.sync_venv import SyncVectorEnv


class CustomEnvCompatibility(EnvCompatibility):
    def reset(self, **kwargs):
        return super().reset(**kwargs)[0]


def init_envs():
    env_wrappers = [CustomEnvCompatibility]
    env_fns = make_toy_envs(
        id="IdentityEnv",
        env_num=2,
        env_wrappers=env_wrappers,
    )
    return env_fns


def assert_env_name(env, env_name):
    assert env.metadata["name"].__name__ == env_name


@pytest.mark.unittest
def test_sync_env():
    env_name = "IdentityEnv"
    env = SyncVectorEnv(init_envs())
    env.exec_func(assert_env_name, indices=None, env_name=env_name)
    env.call("render")


if __name__ == "__main__":
    sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
--------------------------------------------------------------------------------
/tests/test_env/test_wrappers.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""
import os
import sys

import pytest


@pytest.mark.unittest
def test_atari_wrappers():
    import gymnasium

    from openrl.envs.wrappers.atari_wrappers import (
        ClipRewardEnv,
        EpisodicLifeEnv,
        FireResetEnv,
        NoopResetEnv,
        WarpFrame,
    )

    env = gymnasium.make("ALE/Breakout-v5")
    env = FireResetEnv(EpisodicLifeEnv(ClipRewardEnv(WarpFrame(NoopResetEnv(env)))))
    env.reset(seed=0)
    while True:
        obs, reward, done, truncated, info = env.step(0)
        if done:
            break
    env.close()


if __name__ == "__main__":
    sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
--------------------------------------------------------------------------------
/tests/test_examples/test_nlp.py:
--------------------------------------------------------------------------------
# #!/usr/bin/env python
# # -*- coding: utf-8 -*-
# # Copyright 2023 The OpenRL Authors.
# #
# # Licensed under the Apache License, Version 2.0 (the "License");
# # you may not use this file except in compliance with the License.
# # You may obtain a copy of the License at
# #
# # https://www.apache.org/licenses/LICENSE-2.0
# #
# # Unless required by applicable law or agreed to in writing, software
# # distributed under the License is distributed on an "AS IS" BASIS,
# # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# # See the License for the specific language governing permissions and
# # limitations under the License.
#
# """"""
#

import os
import sys

import pytest

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent


# @pytest.fixture(scope="module", params=["--env.args {'data_path':None,'tokenizer_path':'builtin_BPE'}"])
@pytest.fixture(scope="module", params=[""])
def config(request):
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(request.param.split())
    return cfg


@pytest.mark.unittest
def test_train_nlp(config):
    env = make("fake_dialog_data", env_num=3, cfg=config)
    agent = Agent(Net(env))
    agent.train(total_time_steps=1000)


if __name__ == "__main__":
    sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
--------------------------------------------------------------------------------
/tests/test_examples/test_train_mpe.py:
--------------------------------------------------------------------------------
""""""

import os
import sys

import numpy as np
import pytest

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent


@pytest.fixture(
    scope="module",
    params=[
        "--episode_length 5 --use_recurrent_policy true --use_joint_action_loss true"
        " --use_valuenorm true --use_adv_normalize true"
    ],
)
def config(request):
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(request.param.split())
    return cfg


@pytest.mark.unittest
def test_train_mpe(config):
    env_num = 2
    env = make(
        "simple_spread",
        env_num=env_num,
        asynchronous=True,
    )
    net = Net(env, cfg=config)
    agent = Agent(net)
    agent.train(total_time_steps=30)
    agent.save("./ppo_agent/")
    agent.load("./ppo_agent/")
    agent.set_env(env)
    obs, info = env.reset(seed=0)
    step = 0
    while step < 5:
        action, _ = agent.act(obs, deterministic=True)
        obs, r, done, info = env.step(action)
        if np.any(done):
            break
        step += 1
    env.close()


if __name__ == "__main__":
    sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
--------------------------------------------------------------------------------
/tests/test_examples/test_train_super_mario.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""

import os
import sys

import pytest

from openrl.envs.common import make
from openrl.modules.common import PPONet as Net
from openrl.runners.common import PPOAgent as Agent


@pytest.fixture(scope="module", params=[""])
def config(request):
    from openrl.configs.config import create_config_parser

    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(request.param.split())
    return cfg


@pytest.mark.unittest
def test_train_super_mario(config):
    env = make("SuperMarioBros-1-1-v1", env_num=2)

    agent = Agent(Net(env, cfg=config))
    agent.train(total_time_steps=30)

    env.close()


if __name__ == "__main__":
    sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
--------------------------------------------------------------------------------
/tests/test_modules/test_common/test_vdn_net.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""

import os
import sys

import pytest

from openrl.configs.config import create_config_parser
from openrl.envs.common import make
from openrl.envs.wrappers.mat_wrapper import MATWrapper
from openrl.modules.common import VDNNet
from openrl.runners.common import VDNAgent as Agent


@pytest.fixture(scope="module", params=[""])
def config(request):
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(request.param.split())
    return cfg


@pytest.mark.unittest
def test_vdn_net(config):
    env_num = 2
    env = make(
        "simple_spread",
        env_num=env_num,
        asynchronous=True,
    )
    env = MATWrapper(env)

    net = VDNNet(env, cfg=config)
    # initialize the trainer
    agent = Agent(net)
    # start training
    agent.train(total_time_steps=100)
    env.close()


if __name__ == "__main__":
    sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
--------------------------------------------------------------------------------
/tests/test_modules/test_networks/test_attention.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""

import os
import sys

import pytest
import torch

from openrl.configs.config import create_config_parser
from openrl.modules.networks.utils.attention import Encoder


@pytest.fixture(
    scope="module", params=["--use_average_pool True", "--use_average_pool False"]
)
def config(request):
    cfg_parser = create_config_parser()
    cfg = cfg_parser.parse_args(request.param.split())
    return cfg


@pytest.mark.unittest
def test_attention(config):
    for cat_self in [False, True]:
        net = Encoder(cfg=config, split_shape=[[1, 1], [1, 1]], cat_self=cat_self)
        net(torch.zeros((1, 1)))


if __name__ == "__main__":
    sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
--------------------------------------------------------------------------------
/tests/test_supports/test_opengpu/test_gpuinfo.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2023 The OpenRL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

""""""

import os
import sys

import pytest

from openrl.supports.opengpu.gpu_info import preserve_decimal


@pytest.mark.unittest
def test_preserve_decimal():
    preserve_decimal(1, 2)
    preserve_decimal(1.1, 0)
    preserve_decimal(1.1, -1)
    preserve_decimal(1.1, 4)
    preserve_decimal(-1.1, 4)
    preserve_decimal(-0.1, 0)


if __name__ == "__main__":
    sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))
--------------------------------------------------------------------------------