├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ ├── feature_request.md │ └── other_issue.md ├── PULL_REQUEST_TEMPLATE.md ├── dependabot.yml └── workflows │ ├── doc_dev.yml │ ├── doc_stable.yml │ ├── post_merge.yml │ ├── preview.yml │ └── ready_for_CI.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── CODE_OF_CONDUCT.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── assets ├── logo_avatar.png ├── logo_square.svg └── logo_wide.svg ├── azure-pipelines.yml ├── codecov.yml ├── docs ├── Makefile ├── _video │ ├── example_plot_atari_atlantis_vectorized_ppo.mp4 │ ├── example_plot_atari_breakout_vectorized_ppo.mp4 │ ├── user_guide_video │ │ ├── _agent_page_CartPole.mp4 │ │ ├── _agent_page_chain1.mp4 │ │ ├── _agent_page_chain2.mp4 │ │ ├── _agent_page_frozenLake.mp4 │ │ ├── _env_page_Breakout.mp4 │ │ ├── _env_page_MountainCar.mp4 │ │ ├── _env_page_chain.mp4 │ │ └── _experimentManager_page_CartPole.mp4 │ ├── video_chain_quickstart.mp4 │ ├── video_plot_a2c.mp4 │ ├── video_plot_acrobot.mp4 │ ├── video_plot_apple_gold.mp4 │ ├── video_plot_atari_freeway.mp4 │ ├── video_plot_chain.mp4 │ ├── video_plot_dqn.mp4 │ ├── video_plot_gridworld.mp4 │ ├── video_plot_mbqvi.mp4 │ ├── video_plot_mdqn.mp4 │ ├── video_plot_montain_car.mp4 │ ├── video_plot_old_gym_acrobot.mp4 │ ├── video_plot_pball.mp4 │ ├── video_plot_ppo.mp4 │ ├── video_plot_rooms.mp4 │ ├── video_plot_rs_kernel_ucbvi.mp4 │ ├── video_plot_rsucbvi.mp4 │ ├── video_plot_springcartpole.mp4 │ ├── video_plot_twinrooms.mp4 │ └── video_plot_vi.mp4 ├── about.rst ├── api.rst ├── basics │ ├── DeepRLTutorial │ │ ├── TutorialDeepRL.md │ │ ├── output_10_3.png │ │ ├── output_5_3.png │ │ ├── output_6_3.png │ │ └── output_9_3.png │ ├── agent_manager_diagram.png │ ├── comparison.md │ ├── create_agent.rst │ ├── evaluate_agent.rst │ ├── experiment_setup.rst │ ├── multiprocess.rst │ ├── quick_start_rl │ │ ├── Figure_1.png │ │ ├── Figure_2.png │ │ ├── Figure_3.png │ │ ├── experiment_manager_diagram.png │ │ ├── gif_chain.gif │ │ ├── quickstart.md │ │ └── video_chain.mp4 │ ├── rlberry how to.rst │ ├── seeding.rst │ └── userguide │ │ ├── adastop.md │ │ ├── adastop_boxplots.png │ │ ├── agent.md │ │ ├── entropy_loss.png │ │ ├── environment.md │ │ ├── example_eval.png │ │ ├── expManager_multieval.png │ │ ├── experimentManager.md │ │ ├── export_training_data.md │ │ ├── external_lib.md │ │ ├── gif_chain.gif │ │ ├── logging.md │ │ ├── read_writer_example.png │ │ ├── save_load.md │ │ ├── seeding.md │ │ ├── visu_gymnasium_gif.gif │ │ └── visualization.md ├── beginner_dev_guide.md ├── changelog.rst ├── conf.py ├── contributing.md ├── contributors.rst ├── index.md ├── installation.md ├── make.bat ├── markdown_to_py.sh ├── requirements.txt ├── templates │ ├── class.rst │ ├── class_with_call.rst │ ├── function.rst │ ├── nice_toc.md │ └── numpydoc_docstring.rst ├── themes │ └── scikit-learn-fork │ │ ├── README.md │ │ ├── javascript.html │ │ ├── layout.html │ │ ├── nav.html │ │ ├── search.html │ │ ├── static │ │ ├── css │ │ │ ├── theme.css │ │ │ └── vendor │ │ │ │ └── bootstrap.min.css │ │ └── js │ │ │ ├── searchtools.js │ │ │ └── vendor │ │ │ ├── bootstrap.min.js │ │ │ └── jquery-3.6.3.slim.min.js │ │ ├── theme.conf │ │ └── toc.css ├── thumbnails │ ├── adastop_boxplots.png │ ├── chain_thumb.jpg │ ├── code.png │ ├── example_plot_atari_atlantis_vectorized_ppo.jpg │ ├── example_plot_atari_breakout_vectorized_ppo.jpg │ ├── experiment_manager_diagram.png │ ├── output_9_3.png │ ├── video_plot_a2c.jpg │ ├── video_plot_acrobot.jpg │ ├── video_plot_apple_gold.jpg │ ├── 
video_plot_atari_freeway.jpg │ ├── video_plot_chain.jpg │ ├── video_plot_dqn.jpg │ ├── video_plot_gridworld.jpg │ ├── video_plot_mbqvi.jpg │ ├── video_plot_mdqn.jpg │ ├── video_plot_montain_car.jpg │ ├── video_plot_old_gym_acrobot.jpg │ ├── video_plot_pball.jpg │ ├── video_plot_ppo.jpg │ ├── video_plot_rooms.jpg │ ├── video_plot_rs_kernel_ucbvi.jpg │ ├── video_plot_rsucbvi.jpg │ ├── video_plot_springcartpole.jpg │ ├── video_plot_twinrooms.jpg │ └── video_plot_vi.jpg ├── user_guide.md ├── user_guide2.rst └── versions.rst ├── examples ├── README.md ├── adastop_example.py ├── comparison_agents.py ├── demo_agents │ ├── README.md │ ├── demo_SAC.py │ ├── gym_videos │ │ ├── openaigym.episode_batch.0.454210.stats.json │ │ ├── openaigym.manifest.0.454210.manifest.json │ │ ├── openaigym.video.0.454210.video000000.meta.json │ │ └── openaigym.video.0.454210.video000000.mp4 │ ├── video_plot_a2c.py │ ├── video_plot_dqn.py │ ├── video_plot_mbqvi.py │ ├── video_plot_mdqn.py │ ├── video_plot_ppo.py │ ├── video_plot_rs_kernel_ucbvi.py │ ├── video_plot_rsucbvi.py │ └── video_plot_vi.py ├── demo_bandits │ ├── README.md │ ├── plot_TS_bandit.py │ ├── plot_compare_index_bandits.py │ ├── plot_exp3_bandit.py │ ├── plot_mirror_bandit.py │ └── plot_ucb_bandit.py ├── demo_env │ ├── README.md │ ├── example_atari_atlantis_vectorized_ppo.py │ ├── example_atari_breakout_vectorized_ppo.py │ ├── video_plot_acrobot.py │ ├── video_plot_apple_gold.py │ ├── video_plot_atari_freeway.py │ ├── video_plot_chain.py │ ├── video_plot_gridworld.py │ ├── video_plot_mountain_car.py │ ├── video_plot_old_gym_compatibility_wrapper_old_acrobot.py │ ├── video_plot_pball.py │ ├── video_plot_rooms.py │ ├── video_plot_springcartpole.py │ └── video_plot_twinrooms.py ├── demo_experiment │ ├── params_experiment.yaml │ ├── room.yaml │ ├── rsucbvi.yaml │ ├── rsucbvi_alternative.yaml │ └── run.py ├── example_venv.py ├── plot_agent_manager.py ├── plot_checkpointing.py ├── plot_kernels.py ├── plot_smooth.py └── plot_writer_wrapper.py ├── pyproject.toml ├── rlberry ├── __init__.py ├── agents │ ├── __init__.py │ ├── agent.py │ ├── stable_baselines │ │ ├── __init__.py │ │ └── stable_baselines.py │ ├── tests │ │ ├── __init__.py │ │ ├── test_replay.py │ │ └── test_stable_baselines.py │ └── utils │ │ ├── __init__.py │ │ ├── replay.py │ │ └── replay_utils.py ├── check_packages.py ├── envs │ ├── __init__.py │ ├── basewrapper.py │ ├── finite_mdp.py │ ├── gym_make.py │ ├── interface │ │ ├── __init__.py │ │ └── model.py │ ├── pipeline.py │ ├── tests │ │ ├── __init__.py │ │ ├── test_env_seeding.py │ │ ├── test_gym_env_seeding.py │ │ └── test_gym_make.py │ └── utils.py ├── experiment │ ├── __init__.py │ ├── generator.py │ ├── load_results.py │ ├── tests │ │ ├── params_experiment.yaml │ │ ├── room.yaml │ │ ├── rsucbvi.yaml │ │ ├── rsucbvi_alternative.yaml │ │ ├── test_experiment_generator.py │ │ └── test_load_results.py │ └── yaml_utils.py ├── manager │ ├── __init__.py │ ├── comparison.py │ ├── env_tools.py │ ├── evaluation.py │ ├── experiment_manager.py │ ├── multiple_managers.py │ ├── plotting.py │ ├── tests │ │ ├── __init__.py │ │ ├── test_comparisons.py │ │ ├── test_evaluation.py │ │ ├── test_experiment_manager.py │ │ ├── test_experiment_manager_seeding.py │ │ ├── test_hyperparam_optim.py │ │ ├── test_plot.py │ │ ├── test_shared_data.py │ │ ├── test_utils.py │ │ └── test_venv.py │ └── utils.py ├── metadata_utils.py ├── rendering │ ├── __init__.py │ ├── common_shapes.py │ ├── core.py │ ├── opengl_render2d.py │ ├── pygame_render2d.py │ ├── render_interface.py │ 
├── tests │ │ ├── __init__.py │ │ └── test_rendering_interface.py │ └── utils.py ├── seeding │ ├── __init__.py │ ├── seeder.py │ ├── seeding.py │ └── tests │ │ ├── __init__.py │ │ ├── test_seeding.py │ │ ├── test_threads.py │ │ └── test_threads_torch.py ├── spaces │ ├── __init__.py │ ├── box.py │ ├── dict.py │ ├── discrete.py │ ├── from_gym.py │ ├── multi_binary.py │ ├── multi_discrete.py │ ├── tests │ │ ├── __init__.py │ │ ├── test_from_gym.py │ │ └── test_spaces.py │ └── tuple.py ├── tests │ ├── __init__.py │ ├── test_agent_extra.py │ ├── test_agents_base.py │ ├── test_envs.py │ ├── test_imports.py │ └── test_rlberry_main_agents_and_env.py ├── types.py ├── utils │ ├── __init__.py │ ├── binsearch.py │ ├── check_agent.py │ ├── check_env.py │ ├── check_gym_env.py │ ├── factory.py │ ├── loading_tools.py │ ├── logging.py │ ├── space_discretizer.py │ ├── tests │ │ ├── __init__.py │ │ ├── test_binsearch.py │ │ ├── test_check.py │ │ ├── test_loading_tools.py │ │ └── test_writer.py │ ├── torch.py │ └── writers.py └── wrappers │ ├── __init__.py │ ├── autoreset.py │ ├── discrete2onehot.py │ ├── discretize_state.py │ ├── gym_utils.py │ ├── rescale_reward.py │ ├── tests │ ├── __init__.py │ ├── old_env │ │ ├── __init__.py │ │ ├── old_acrobot.py │ │ ├── old_apple_gold.py │ │ ├── old_ball2d.py │ │ ├── old_finite_mdp.py │ │ ├── old_four_room.py │ │ ├── old_gridworld.py │ │ ├── old_mountain_car.py │ │ ├── old_nroom.py │ │ ├── old_pball.py │ │ ├── old_pendulum.py │ │ ├── old_six_room.py │ │ └── old_twinrooms.py │ ├── test_basewrapper.py │ ├── test_common_wrappers.py │ ├── test_gym_space_conversion.py │ ├── test_utils.py │ ├── test_wrapper_seeding.py │ └── test_writer_utils.py │ ├── uncertainty_estimator_wrapper.py │ ├── utils.py │ └── writer_utils.py └── scripts ├── apptainer_for_tests ├── README.md ├── monthly_test_base.sh ├── rlberry_apptainer__specific_python.def └── rlberry_apptainer_base.def ├── build_docs.sh ├── conda_env_setup.sh ├── construct_video_examples.sh ├── fetch_contributors.py ├── full_install.sh └── run_testscov.sh /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report about: Create a report to help us improve title: '' 3 | labels: '' 4 | assignees: '' 5 | 6 | --- 7 | 8 | **Describe the bug** 9 | A clear and concise description of what the bug is. 10 | 11 | **To Reproduce** 12 | Steps to reproduce the behavior. 13 | 14 | **Expected behavior** 15 | A clear and concise description of what you expected to happen. 16 | 17 | **Screenshots** 18 | If applicable, add screenshots to help explain your problem. 19 | 20 | **Desktop (please complete the following information):** 21 | 22 | - OS: \[e.g. iOS] 23 | - Version \[e.g. 22] 24 | - Python version - PyTorch version 25 | 26 | **Additional context** 27 | Add any other context about the problem here. 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request about: Suggest an idea for this project title: '' 3 | labels: '' 4 | assignees: '' 5 | 6 | --- 7 | 8 | **Is your feature request related to a problem? Please describe.** 9 | A clear and concise description of what the problem is. Ex. I'm always frustrated when \[...] 10 | 11 | **Describe the solution you'd like** 12 | A clear and concise description of what you want to happen. 
13 | 14 | **Describe alternatives you've considered** 15 | A clear and concise description of any alternative solutions or features you've considered. 16 | 17 | **Additional context** 18 | Add any other context or screenshots about the feature request here. 19 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/other_issue.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Other issue about: Create an issue that is not a bug report nor a feature request title: '' 3 | labels: '' 4 | assignees: '' 5 | 6 | --- 7 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Description 4 | 5 | Please include a summary of the change and which issue is fixed. 6 | Please also include relevant motivation and context, in particular link the relevant issue if it is appropriate. 7 | List any dependencies that are required for this change. 8 | 9 | 24 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # To get started with Dependabot version updates, you'll need to specify which 2 | # package ecosystems to update and where the package manifests are located. 3 | # Please see the documentation for all configuration options: 4 | # https://help.github.com/github/administering-a-repository/configuration-options-for-dependency-updates 5 | 6 | version: 2 7 | updates: 8 | - package-ecosystem: "pip" # See documentation for possible values 9 | directory: "/" # Location of package manifests 10 | schedule: 11 | interval: "daily" 12 | -------------------------------------------------------------------------------- /.github/workflows/doc_dev.yml: -------------------------------------------------------------------------------- 1 | name: documentation_dev 2 | on: 3 | pull_request_target: 4 | branches: 5 | - main 6 | types: [closed] 7 | push: 8 | branches: 9 | - main 10 | 11 | 12 | permissions: 13 | contents: write 14 | 15 | jobs: 16 | docs: 17 | runs-on: ubuntu-latest 18 | steps: 19 | - name: Wait for version update 20 | run: | 21 | sleep 60 22 | - name: Checkout 23 | uses: actions/checkout@v4 24 | - uses: actions/setup-python@v4 25 | with: 26 | python-version: '3.10' 27 | 28 | - name: Install and configure Poetry 29 | uses: snok/install-poetry@v1 30 | 31 | - name: Install dependencies 32 | run: | 33 | curl -C - https://raw.githubusercontent.com/rlberry-py/rlberry/main/pyproject.toml > pyproject.toml 34 | poetry sync --all-extras --with dev 35 | - name: Sphinx build 36 | run: | 37 | poetry run sphinx-build docs _build 38 | - uses: actions/checkout@v4 39 | with: 40 | # This is necessary so that we have the tags. 
41 | fetch-depth: 0 42 | ref: gh-pages 43 | path: gh_pages 44 | - name: copy stable and preview version changes 45 | run: | 46 | cp -rv gh_pages/stable _build/stable || echo "Ignoring exit status" 47 | cp -rv gh_pages/preview_pr _build/preview_pr || echo "Ignoring exit status" 48 | - name: Deploy to GitHub Pages 49 | uses: peaceiris/actions-gh-pages@v3 50 | if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }} 51 | with: 52 | publish_branch: gh-pages 53 | github_token: ${{ secrets.GITHUB_TOKEN }} 54 | publish_dir: _build/ 55 | force_orphan: true 56 | -------------------------------------------------------------------------------- /.github/workflows/doc_stable.yml: -------------------------------------------------------------------------------- 1 | name: documentation_stable 2 | on: 3 | push: 4 | # Pattern matched against refs/tags 5 | tags: 6 | - '*' # Push events to every tag not containing / 7 | 8 | permissions: 9 | contents: write 10 | 11 | jobs: 12 | docs: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | with: 17 | fetch-depth: 0 18 | fetch-tags: true 19 | path: main 20 | - name: checkout latest 21 | run: | 22 | cd main 23 | git checkout $(git describe --tags $(git rev-list --tags --max-count=1)) 24 | cd .. 25 | - uses: actions/setup-python@v4 26 | with: 27 | python-version: '3.10' 28 | - name: Install and configure Poetry 29 | uses: snok/install-poetry@v1 30 | 31 | - name: Install dependencies 32 | run: | 33 | cd main 34 | poetry sync --all-extras --with dev 35 | - name: Sphinx build 36 | run: | 37 | poetry run sphinx-build docs ../_build 38 | cd .. 39 | - uses: actions/checkout@v4 40 | with: 41 | # This is necessary so that we have the tags. 42 | fetch-depth: 0 43 | ref: gh-pages 44 | path: gh_pages 45 | - name: Commit documentation changes 46 | run: | 47 | cd gh_pages 48 | rm -r stable || echo "Ignoring exit status" 49 | mkdir stable 50 | cp -rv ../_build/* stable 51 | git config user.name github-actions 52 | git config user.email github-actions@github.com 53 | git add . 54 | git commit -m "Documentation Stable" 55 | git push 56 | -------------------------------------------------------------------------------- /.github/workflows/post_merge.yml: -------------------------------------------------------------------------------- 1 | name: Version Control 2 | on: 3 | pull_request_target: 4 | branches: 5 | - main 6 | types: [closed] 7 | 8 | jobs: 9 | version_lock_job: 10 | if: github.event.pull_request.merged == true 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/setup-python@v4 14 | with: 15 | python-version: '3.10' 16 | - uses: actions/checkout@v4 17 | with: 18 | # This is necessary so that we have the tags. 
19 | fetch-depth: 0 20 | ref: main 21 | - uses: mtkennerly/dunamai-action@v1 22 | with: 23 | env-var: MY_VERSION 24 | - run: echo $MY_VERSION 25 | - uses: snok/install-poetry@v1 26 | - run: poetry version "v$MY_VERSION " 27 | - run: poetry lock 28 | - uses: EndBug/add-and-commit@v8 29 | with: 30 | add: '["pyproject.toml", "poetry.lock"]' 31 | default_author: github_actor 32 | message: 'Writing version and lock with github action [skip ci]' 33 | -------------------------------------------------------------------------------- /.github/workflows/ready_for_CI.yml: -------------------------------------------------------------------------------- 1 | name: ready_for_CI 2 | 3 | on: 4 | pull_request: 5 | types: [labeled, opened, reopened, synchronize] 6 | 7 | jobs: 8 | build: 9 | if: contains( github.event.pull_request.labels.*.name, 'ready for CI') 10 | runs-on: ubuntu-latest 11 | 12 | steps: 13 | - uses: actions/checkout@v4 14 | - name: Run a one-line script 15 | run: echo ready for CI! 16 | - name: Save the PR number in an artifact 17 | shell: bash 18 | env: 19 | PR_NUM: ${{ github.event.number }} 20 | run: echo $PR_NUM > pr_num.txt 21 | 22 | - name: Upload the PR number 23 | uses: actions/upload-artifact@v4 24 | with: 25 | name: pr_num 26 | path: ./pr_num.txt 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # test markdown 2 | docs/python_scripts/ 3 | 4 | #test venvs 5 | rlberry_venvs 6 | 7 | # tensorboard runs 8 | runs/ 9 | 10 | # videos 11 | #*.mp4 12 | notebooks/videos/* 13 | 14 | # pickled objects and csv 15 | *.pickle 16 | *.csv 17 | 18 | # test coverage folder 19 | cov_html/* 20 | 21 | # dev and results folder 22 | dev/* 23 | results/* 24 | temp/* 25 | rlberry_data/* 26 | */rlberry_data/* 27 | 28 | 29 | # Byte-compiled / optimized / DLL files 30 | __pycache__/ 31 | *.py[cod] 32 | *$py.class 33 | 34 | # C extensions 35 | *.so 36 | 37 | # Distribution / packaging 38 | .Python 39 | build/ 40 | develop-eggs/ 41 | dist/ 42 | downloads/ 43 | eggs/ 44 | .eggs/ 45 | lib/ 46 | lib64/ 47 | parts/ 48 | sdist/ 49 | var/ 50 | wheels/ 51 | pip-wheel-metadata/ 52 | share/python-wheels/ 53 | *.egg-info/ 54 | .installed.cfg 55 | *.egg 56 | MANIFEST 57 | 58 | # PyInstaller 59 | # Usually these files are written by a python script from a template 60 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 61 | *.manifest 62 | *.spec 63 | 64 | # Installer logs 65 | pip-log.txt 66 | pip-delete-this-directory.txt 67 | 68 | # Unit test / coverage reports 69 | htmlcov/ 70 | .tox/ 71 | .nox/ 72 | .coverage 73 | .coverage.* 74 | .cache 75 | nosetests.xml 76 | coverage.xml 77 | *.cover 78 | *.py,cover 79 | .hypothesis/ 80 | .pytest_cache/ 81 | profile.prof 82 | 83 | # Translations 84 | *.mo 85 | *.pot 86 | 87 | # Django stuff: 88 | *.log 89 | local_settings.py 90 | db.sqlite3 91 | db.sqlite3-journal 92 | 93 | # Flask stuff: 94 | instance/ 95 | .webassets-cache 96 | 97 | # Scrapy stuff: 98 | .scrapy 99 | 100 | # Sphinx documentation 101 | docs/_build/ 102 | docs/generated 103 | docs/auto_examples 104 | 105 | 106 | 107 | # PyBuilder 108 | target/ 109 | 110 | # Jupyter Notebook 111 | .ipynb_checkpoints 112 | 113 | # IPython 114 | profile_default/ 115 | ipython_config.py 116 | 117 | # pyenv 118 | .python-version 119 | 120 | # pipenv 121 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
122 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 123 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 124 | # install all needed dependencies. 125 | #Pipfile.lock 126 | 127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 128 | __pypackages__/ 129 | 130 | # Celery stuff 131 | celerybeat-schedule 132 | celerybeat.pid 133 | 134 | # SageMath parsed files 135 | *.sage.py 136 | 137 | # Environments 138 | .env 139 | .venv 140 | env/ 141 | venv/ 142 | ENV/ 143 | env.bak/ 144 | venv.bak/ 145 | 146 | # Spyder project settings 147 | .spyderproject 148 | .spyproject 149 | 150 | # Rope project settings 151 | .ropeproject 152 | 153 | # mkdocs documentation 154 | /site 155 | 156 | # mypy 157 | .mypy_cache/ 158 | .dmypy.json 159 | dmypy.json 160 | 161 | # Pyre type checker 162 | .pyre/ 163 | 164 | # macOS 165 | .DS_Store 166 | 167 | # vscode 168 | .vscode 169 | 170 | # PyCharm 171 | .idea 172 | .project 173 | .pydevproject 174 | 175 | 176 | *.prof 177 | 178 | # poetry.lock 179 | poetry.lock 180 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.4.0 6 | hooks: 7 | - id: trailing-whitespace 8 | - id: end-of-file-fixer 9 | - id: check-yaml 10 | - id: check-docstring-first 11 | 12 | - repo: https://github.com/psf/black 13 | rev: 23.9.1 14 | hooks: 15 | - id: black 16 | 17 | - repo: https://github.com/asottile/blacken-docs 18 | rev: 1.16.0 19 | hooks: 20 | - id: blacken-docs 21 | additional_dependencies: [black==23.1.0] 22 | 23 | - repo: https://github.com/pycqa/flake8 24 | rev: 6.1.0 25 | hooks: 26 | - id: flake8 27 | additional_dependencies: [flake8-docstrings] 28 | types: [file, python] 29 | exclude: (.*/__init__.py|rlberry/check_packages.py) 30 | args: ['--select=F401,F405,D410,D411,D412'] 31 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 
3 | authors: 4 | - family-names: "Domingues" 5 | given-names: "Omar Darwiche" 6 | - family-names: "Flet-Berliac" 7 | given-names: "Yannis" 8 | - family-names: "Leurent" 9 | given-names: "Edouard" 10 | - family-names: "Ménard" 11 | given-names: "Pierre" 12 | - family-names: "Shang" 13 | given-names: "Xuedong" 14 | - family-names: "Valko" 15 | given-names: "Michal" 16 | 17 | title: "rlberry - A Reinforcement Learning Library for Research and Education" 18 | abbreviation: rlberry 19 | version: 0.2.2-dev 20 | doi: 10.5281/zenodo.5223307 21 | date-released: 2021-10-01 22 | url: "https://github.com/rlberry-py/rlberry" 23 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 rlberry-py 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include pyproject.toml 2 | 3 | recursive-include assets *.svg 4 | -------------------------------------------------------------------------------- /assets/logo_avatar.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/assets/logo_avatar.png -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | 2 | comment: false 3 | 4 | coverage: 5 | precision: 2 6 | round: down 7 | range: "70...100" 8 | status: 9 | project: 10 | default: 11 | # basic 12 | target: auto 13 | threshold: 1% # allow coverage to drop at most 1% 14 | 15 | parsers: 16 | gcov: 17 | branch_detection: 18 | conditional: yes 19 | loop: yes 20 | method: no 21 | macro: no 22 | 23 | ignore: 24 | - "./rlberry/wrappers/tests/old_env/*.py" 25 | - "./rlberry/utils/torch.py" 26 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 
9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | # clean example gallery files 16 | clean: 17 | rm -rf $(BUILDDIR)/* 18 | rm -rf auto_examples/ 19 | 20 | # Script used to construct the videos for the examples that output videos. 21 | # Please use a script name that begins with video_plot as with the other examples 22 | # and in the script, there should be a line to save the video in the right place 23 | # and a line to load the video in the headers. Look at existing examples for 24 | # the correct syntax. Be careful that you must remove the _build folder before 25 | # recompiling the doc when a video has been updated/added. 26 | video: 27 | # Make videos 28 | $(foreach file, $(wildcard ../examples/**/video_plot*.py), \ 29 | @echo $(basename $(notdir $(file)));\ 30 | python $(file)) ;\ 31 | # Make thumbnails 32 | $(foreach file, $(wildcard _video/*.mp4), \ 33 | ffmpeg -y -i $(file) -vframes 1 -f image2 \ 34 | thumbnails/$(basename $(notdir $(file))).jpg ;\ 35 | ) 36 | # Remove unused metadata json 37 | @rm _video/*.json 38 | 39 | thumbnail_images: 40 | $(foreach file, $(wildcard _video/*.mp4), \ 41 | ffmpeg -y -i $(file) -vframes 1 -f image2 \ 42 | thumbnails/$(basename $(notdir $(file))).jpg ;\ 43 | ) 44 | 45 | .PHONY: help Makefile 46 | 47 | # Catch-all target: route all unknown targets to Sphinx using the new 48 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 49 | %: Makefile 50 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 51 | -------------------------------------------------------------------------------- /docs/_video/example_plot_atari_atlantis_vectorized_ppo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/example_plot_atari_atlantis_vectorized_ppo.mp4 -------------------------------------------------------------------------------- /docs/_video/example_plot_atari_breakout_vectorized_ppo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/example_plot_atari_breakout_vectorized_ppo.mp4 -------------------------------------------------------------------------------- /docs/_video/user_guide_video/_agent_page_CartPole.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/user_guide_video/_agent_page_CartPole.mp4 -------------------------------------------------------------------------------- /docs/_video/user_guide_video/_agent_page_chain1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/user_guide_video/_agent_page_chain1.mp4 -------------------------------------------------------------------------------- /docs/_video/user_guide_video/_agent_page_chain2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/user_guide_video/_agent_page_chain2.mp4 
-------------------------------------------------------------------------------- /docs/_video/user_guide_video/_agent_page_frozenLake.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/user_guide_video/_agent_page_frozenLake.mp4 -------------------------------------------------------------------------------- /docs/_video/user_guide_video/_env_page_Breakout.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/user_guide_video/_env_page_Breakout.mp4 -------------------------------------------------------------------------------- /docs/_video/user_guide_video/_env_page_MountainCar.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/user_guide_video/_env_page_MountainCar.mp4 -------------------------------------------------------------------------------- /docs/_video/user_guide_video/_env_page_chain.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/user_guide_video/_env_page_chain.mp4 -------------------------------------------------------------------------------- /docs/_video/user_guide_video/_experimentManager_page_CartPole.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/user_guide_video/_experimentManager_page_CartPole.mp4 -------------------------------------------------------------------------------- /docs/_video/video_chain_quickstart.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_chain_quickstart.mp4 -------------------------------------------------------------------------------- /docs/_video/video_plot_a2c.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_a2c.mp4 -------------------------------------------------------------------------------- /docs/_video/video_plot_acrobot.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_acrobot.mp4 -------------------------------------------------------------------------------- /docs/_video/video_plot_apple_gold.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_apple_gold.mp4 -------------------------------------------------------------------------------- /docs/_video/video_plot_atari_freeway.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_atari_freeway.mp4 
-------------------------------------------------------------------------------- /docs/_video/video_plot_chain.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_chain.mp4 -------------------------------------------------------------------------------- /docs/_video/video_plot_dqn.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_dqn.mp4 -------------------------------------------------------------------------------- /docs/_video/video_plot_gridworld.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_gridworld.mp4 -------------------------------------------------------------------------------- /docs/_video/video_plot_mbqvi.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_mbqvi.mp4 -------------------------------------------------------------------------------- /docs/_video/video_plot_mdqn.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_mdqn.mp4 -------------------------------------------------------------------------------- /docs/_video/video_plot_montain_car.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_montain_car.mp4 -------------------------------------------------------------------------------- /docs/_video/video_plot_old_gym_acrobot.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_old_gym_acrobot.mp4 -------------------------------------------------------------------------------- /docs/_video/video_plot_pball.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_pball.mp4 -------------------------------------------------------------------------------- /docs/_video/video_plot_ppo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_ppo.mp4 -------------------------------------------------------------------------------- /docs/_video/video_plot_rooms.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_rooms.mp4 -------------------------------------------------------------------------------- /docs/_video/video_plot_rs_kernel_ucbvi.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_rs_kernel_ucbvi.mp4 -------------------------------------------------------------------------------- /docs/_video/video_plot_rsucbvi.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_rsucbvi.mp4 -------------------------------------------------------------------------------- /docs/_video/video_plot_springcartpole.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_springcartpole.mp4 -------------------------------------------------------------------------------- /docs/_video/video_plot_twinrooms.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_twinrooms.mp4 -------------------------------------------------------------------------------- /docs/_video/video_plot_vi.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_vi.mp4 -------------------------------------------------------------------------------- /docs/about.rst: -------------------------------------------------------------------------------- 1 | .. _about: 2 | 3 | About us 4 | ======== 5 | 6 | This project was initiated and is actively maintained by 7 | `INRIA SCOOL team `_ 8 | 9 | Contributors 10 | ------------ 11 | 12 | The following people contributed actively to rlberry. 13 | 14 | .. include:: contributors.rst 15 | 16 | 17 | .. _citing-rlberry: 18 | 19 | Citing rlberry 20 | -------------- 21 | 22 | If you use rlberry in scientific publications, we would appreciate citations using the following Bibtex entry:: 23 | 24 | @misc{rlberry, 25 | author = {Domingues, Omar Darwiche and Flet-Berliac, Yannis and Leurent, Edouard and M{\'e}nard, Pierre and Shang, Xuedong and Valko, Michal}, 26 | doi = {10.5281/zenodo.5544540}, 27 | month = {10}, 28 | title = {{rlberry - A Reinforcement Learning Library for Research and Education}}, 29 | url = {https://github.com/rlberry-py/rlberry}, 30 | year = {2021} 31 | } 32 | 33 | 34 | 35 | Funding 36 | ------- 37 | 38 | The project would like to thank the following for their participation in funding 39 | PhD's and research project that made this happen. 
40 | 41 | - Inria and in particular Inria, Scool for the working environment 42 | - Université de Lille 43 | - I-SITE ULNE 44 | - ANR 45 | - ANRT 46 | - Renault 47 | - European CHIST-ERA project DELTA 48 | - La Région Hauts-de-France and the MEL 49 | -------------------------------------------------------------------------------- /docs/basics/DeepRLTutorial/output_10_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/DeepRLTutorial/output_10_3.png -------------------------------------------------------------------------------- /docs/basics/DeepRLTutorial/output_5_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/DeepRLTutorial/output_5_3.png -------------------------------------------------------------------------------- /docs/basics/DeepRLTutorial/output_6_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/DeepRLTutorial/output_6_3.png -------------------------------------------------------------------------------- /docs/basics/DeepRLTutorial/output_9_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/DeepRLTutorial/output_9_3.png -------------------------------------------------------------------------------- /docs/basics/agent_manager_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/agent_manager_diagram.png -------------------------------------------------------------------------------- /docs/basics/create_agent.rst: -------------------------------------------------------------------------------- 1 | .. _rlberry: https://github.com/rlberry-py/rlberry 2 | 3 | .. _create_agent: 4 | 5 | 6 | Create an agent 7 | =============== 8 | 9 | rlberry_ requires you to use a **very simple interface** to write agents, with basically 10 | two methods to implement: :code:`fit()` and :code:`eval()`. 11 | 12 | The example below shows how to create an agent. 13 | 14 | 15 | .. code-block:: python 16 | 17 | import numpy as np 18 | from rlberry.agents import Agent 19 | 20 | 21 | class MyAgent(Agent): 22 | name = "MyAgent" 23 | 24 | def __init__( 25 | self, env, param1=0.99, param2=1e-5, **kwargs 26 | ): # it's important to put **kwargs to ensure compatibility with the base class 27 | # self.env is initialized in the base class 28 | # An evaluation environment is also initialized: self.eval_env 29 | Agent.__init__(self, env, **kwargs) 30 | 31 | self.param1 = param1 32 | self.param2 = param2 33 | 34 | def fit(self, budget, **kwargs): 35 | """ 36 | The parameter budget can represent the number of steps, the number of episodes etc, 37 | depending on the agent. 38 | * Interact with the environment (self.env); 39 | * Train the agent 40 | * Return useful information 41 | """ 42 | n_episodes = budget 43 | rewards = np.zeros(n_episodes) 44 | 45 | for ep in range(n_episodes): 46 | state, info = self.env.reset() 47 | done = False 48 | while not done: 49 | action = ... 
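                # Illustrative only: one hypothetical way to fill in the ``...`` above is a
                # uniform random action, e.g. ``action = self.env.action_space.sample()``,
                # or a draw from the policy your agent is learning.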
50 | observation, reward, terminated, truncated, info = self.env.step(action) 51 | done = terminated or truncated 52 | rewards[ep] += reward 53 | 54 | info = {"episode_rewards": rewards} 55 | return info 56 | 57 | def eval(self, **kwargs): 58 | """ 59 | Returns a value corresponding to the evaluation of the agent on the 60 | evaluation environment. 61 | 62 | For instance, it can be a Monte-Carlo evaluation of the policy learned in fit(). 63 | """ 64 | return 0.0 65 | 66 | 67 | .. note:: It's important that your agent accepts optional `**kwargs` and pass it to the base class as :code:`Agent.__init__(self, env, **kwargs)`. 68 | 69 | 70 | .. seealso:: 71 | Documentation of the classes :class:`~rlberry.agents.agent.Agent` 72 | and :class:`~rlberry.agents.agent.AgentWithSimplePolicy`. 73 | -------------------------------------------------------------------------------- /docs/basics/experiment_setup.rst: -------------------------------------------------------------------------------- 1 | .. _rlberry: https://github.com/rlberry-py/rlberry 2 | 3 | .. _experiment_setup: 4 | 5 | 6 | Setup and run experiments using yaml config files 7 | ================================================= 8 | 9 | 10 | To setup an experiment with rlberry, you can use yaml files. You'll need: 11 | 12 | * An **experiment.yaml** with some global parameters: seed, number of episodes, horizon, environments (for training and evaluation) and a list of agents to run. 13 | 14 | * yaml files describing the environments and the agents 15 | 16 | * A main python script that reads the files and generates :class:`~rlberry.manager.experiment_manager.ExperimentManager` instances to run each agent. 17 | 18 | 19 | This can be done very succinctly as in the example below: 20 | 21 | 22 | **experiment.yaml** 23 | 24 | .. code-block:: yaml 25 | 26 | description: 'RSUCBVI in NRoom' 27 | seed: 123 28 | train_env: 'examples/demo_experiment/room.yaml' 29 | eval_env: 'examples/demo_experiment/room.yaml' 30 | agents: 31 | - 'examples/demo_experiment/rsucbvi.yaml' 32 | - 'examples/demo_experiment/rsucbvi_alternative.yaml' 33 | 34 | 35 | **room.yaml** 36 | 37 | .. code-block:: yaml 38 | 39 | constructor: 'rlberry_research.envs.benchmarks.grid_exploration.nroom.NRoom' 40 | params: 41 | reward_free: false 42 | array_observation: true 43 | nrooms: 5 44 | 45 | **rsucbvi.yaml** 46 | 47 | .. code-block:: yaml 48 | 49 | agent_class: 'rlberry_research.agents.kernel_based.rs_ucbvi.RSUCBVIAgent' 50 | init_kwargs: 51 | gamma: 1.0 52 | lp_metric: 2 53 | min_dist: 0.0 54 | max_repr: 800 55 | bonus_scale_factor: 1.0 56 | reward_free: True 57 | horizon: 50 58 | eval_kwargs: 59 | eval_horizon: 50 60 | fit_kwargs: 61 | fit_budget: 100 62 | 63 | **rsucbvi_alternative.yaml** 64 | 65 | .. code-block:: yaml 66 | 67 | base_config: 'examples/demo_experiment/rsucbvi.yaml' 68 | init_kwargs: 69 | gamma: 0.9 70 | 71 | 72 | 73 | **run.py** 74 | 75 | .. 
code-block:: python 76 | 77 | """ 78 | To run the experiment: 79 | 80 | $ python run.py experiment.yaml 81 | 82 | To see more options: 83 | 84 | $ python run.py -h 85 | """ 86 | 87 | from rlberry.experiment import experiment_generator 88 | from rlberry.manager.multiple_managers import MultipleManagers 89 | 90 | multimanagers = MultipleManagers() 91 | 92 | for experiment_manager in experiment_generator(): 93 | multimanagers.append(experiment_manager) 94 | 95 | # Alternatively: 96 | # experiment_manager.fit() 97 | # experiment_manager.save() 98 | 99 | multimanagers.run() 100 | multimanagers.save() 101 | -------------------------------------------------------------------------------- /docs/basics/multiprocess.rst: -------------------------------------------------------------------------------- 1 | .. _multiprocess: 2 | 3 | Parallelization in rlberry 4 | ========================== 5 | 6 | rlberry use python's standard multiprocessing library to execute the fit of agents in parallel on cpus. The parallelization is done via 7 | :class:`~rlberry.manager.ExperimentManager` and via :class:`~rlberry.manager.MultipleManagers`. 8 | 9 | If a user wants to use a third-party parallelization library like joblib, the user must be aware of where the seeding is done so as not to bias the results. rlberry automatically handles seeding when the native parallelization scheme are used. 10 | 11 | Several multiprocessing scheme are implemented in rlberry. 12 | 13 | Threading 14 | --------- 15 | 16 | Thread multiprocessing "constructs higher-level threading interfaces on top of the lower level _thread module" (see the doc on `python's website `_). This is the default scheme in rlberry, most of the time it will result in 17 | having practically no parallelization except if the code executed in each thread (i.e. each fit) is executed without GIL (example: cython code or numpy code). 18 | 19 | Process: spawn or forkserver 20 | ---------------------------- 21 | 22 | To have an efficient parallelization, it is often better to use processes (see the doc on `python's website `_) using the parameter :code:`parallelization="process"` in :class:`~rlberry.manager.ExperimentManager` or :class:`~rlberry.manager.MultipleManagers`. 23 | 24 | This implies that a new process will be launched for each fit of the ExperimentManager. 25 | 26 | The advised method of parallelization is spawn (parameter :code:`mp_context="spawn"`), however spawn method has several drawbacks: 27 | 28 | - The fit code needs to be encapsulated in a :code:`if __name__ == '__main__'` directive. Example : 29 | 30 | .. code:: python 31 | 32 | from rlberry_research.agents.torch import A2CAgent 33 | from rlberry.manager import ExperimentManager 34 | from rlberry_research.envs.benchmarks.ball_exploration import PBall2D 35 | 36 | n_steps = 1e5 37 | batch_size = 256 38 | 39 | if __name__ == "__main__": 40 | manager = ExperimentManager( 41 | A2CAgent, 42 | (PBall2D, {}), 43 | init_kwargs=dict(batch_size=batch_size, gamma=0.99, learning_rate=0.001), 44 | n_fit=4, 45 | fit_budget=n_steps, 46 | parallelization="process", 47 | mp_context="spawn", 48 | ) 49 | manager.fit() 50 | 51 | - As a consequence, :code:`spawn` parallelization only works if called from the main script. 52 | - :code:`spawn` does not work when called from a notebook. To work in a notebook, use :code:`fork` instead. 53 | - :code:`forkserver` is an alternative to :code:`spawn` that performs sometimes faster than :code:`spawn`. 
:code:`forkserver` parallelization must also be encapsulated into a :code:`if __name__ == '__main__'` directive and for now it is available only on Unix systems (MacOS, Linux, ...). 54 | 55 | 56 | Process: fork 57 | ------------- 58 | 59 | Fork multiprocessing is only possible on Unix systems. 60 | It is available through the parameter :code:`mp_context="fork"` when :code:`parallelization="process"`. 61 | Remark that there could be some logging error and hanging when using :code:`fork`. The usage of fork in rlberry is still experimental and may be unstable. 62 | -------------------------------------------------------------------------------- /docs/basics/quick_start_rl/Figure_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/quick_start_rl/Figure_1.png -------------------------------------------------------------------------------- /docs/basics/quick_start_rl/Figure_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/quick_start_rl/Figure_2.png -------------------------------------------------------------------------------- /docs/basics/quick_start_rl/Figure_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/quick_start_rl/Figure_3.png -------------------------------------------------------------------------------- /docs/basics/quick_start_rl/experiment_manager_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/quick_start_rl/experiment_manager_diagram.png -------------------------------------------------------------------------------- /docs/basics/quick_start_rl/gif_chain.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/quick_start_rl/gif_chain.gif -------------------------------------------------------------------------------- /docs/basics/quick_start_rl/video_chain.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/quick_start_rl/video_chain.mp4 -------------------------------------------------------------------------------- /docs/basics/seeding.rst: -------------------------------------------------------------------------------- 1 | .. _seeding: 2 | 3 | Seeding & Reproducibility 4 | ========================== 5 | 6 | rlberry_ has a class :class:`~rlberry.seeding.seeder.Seeder` that conveniently wraps a `NumPy SeedSequence `_, 7 | and allows us to create independent random number generators for different objects and threads, using a single 8 | :class:`~rlberry.seeding.seeder.Seeder` instance. 9 | 10 | It works as follows: 11 | 12 | 13 | .. code-block:: python 14 | 15 | from rlberry.seeding import Seeder 16 | 17 | seeder = Seeder(123) 18 | 19 | # Each Seeder instance has a random number generator (rng) 20 | # See https://numpy.org/doc/stable/reference/random/generator.html to check the 21 | # methods available in rng. 
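    # (Illustrative) The generator is deterministic given the seed: building
    # Seeder(123) again reproduces exactly the same sequence of draws.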
22 | seeder.rng.integers(5) 23 | seeder.rng.normal() 24 | print(type(seeder.rng)) 25 | # etc 26 | 27 | # Environments and agents should be seeded using a single seeder, 28 | # to ensure that their random number generators are independent. 29 | from rlberry.envs import gym_make 30 | from rlberry.agents import RSUCBVIAgent 31 | 32 | env = gym_make("MountainCar-v0") 33 | env.reseed(seeder) 34 | 35 | agent = RSUCBVIAgent(env) 36 | agent.reseed(seeder) 37 | 38 | 39 | # Environments and Agents have their own seeder and rng. 40 | # When writing your own agents and inheriting from the Agent class, 41 | # you should use agent.rng whenever you need to generate random numbers; 42 | # the same applies to your environments. 43 | # This is necessary to ensure reproducibility. 44 | print("env seeder: ", env.seeder) 45 | print("random sample from env rng: ", env.rng.normal()) 46 | print("agent seeder: ", agent.seeder) 47 | print("random sample from agent rng: ", agent.rng.normal()) 48 | 49 | 50 | # A seeder can spawn other seeders that are independent from it. 51 | # This is useful to seed two different threads, using seeder1 52 | # in the first thread, and seeder2 in the second thread. 53 | seeder1, seeder2 = seeder.spawn(2) 54 | 55 | 56 | # You can also use a seeder to seed external libraries (such as torch) 57 | # using the function set_external_seed 58 | from rlberry.seeding import set_external_seed 59 | 60 | set_external_seed(seeder) 61 | 62 | 63 | .. note:: 64 | The class :class:`~rlberry.manager.experiment_manager.ExperimentManager` provides a :code:`seed` parameter in its constructor, 65 | and handles automatically the seeding of all environments and agents used by it. 66 | 67 | .. note:: 68 | 69 | The :meth:`optimize_hyperparams` method of 70 | :class:`~rlberry.manager.experiment_manager.ExperimentManager` uses the `Optuna `_ 71 | library for hyperparameter optimization and is **inherently non-deterministic** 72 | (see `Optuna FAQ `_). 
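
To make the first note above concrete, a minimal sketch could look as follows
(the agent and environment are the ones used earlier on this page; the budget,
number of fits and seed are arbitrary placeholder values):

.. code-block:: python

    from rlberry.envs import gym_make
    from rlberry.agents import RSUCBVIAgent
    from rlberry.manager import ExperimentManager

    # A single integer seed is enough: the manager handles the seeding of
    # every agent instance and environment it creates.
    manager = ExperimentManager(
        RSUCBVIAgent,
        (gym_make, dict(id="MountainCar-v0")),
        fit_budget=100,
        n_fit=2,
        seed=123,
    )
    manager.fit()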
73 | -------------------------------------------------------------------------------- /docs/basics/userguide/adastop_boxplots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/userguide/adastop_boxplots.png -------------------------------------------------------------------------------- /docs/basics/userguide/entropy_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/userguide/entropy_loss.png -------------------------------------------------------------------------------- /docs/basics/userguide/example_eval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/userguide/example_eval.png -------------------------------------------------------------------------------- /docs/basics/userguide/expManager_multieval.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/userguide/expManager_multieval.png -------------------------------------------------------------------------------- /docs/basics/userguide/gif_chain.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/userguide/gif_chain.gif -------------------------------------------------------------------------------- /docs/basics/userguide/read_writer_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/userguide/read_writer_example.png -------------------------------------------------------------------------------- /docs/basics/userguide/visu_gymnasium_gif.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/userguide/visu_gymnasium_gif.gif -------------------------------------------------------------------------------- /docs/contributors.rst: -------------------------------------------------------------------------------- 1 | .. raw :: html 2 | 3 | 4 |
5 |    [contributor avatar grid: raw HTML markup not preserved in this export;
       contributors: AleShi94, brahimdriss, Matheus M. Centa, Omar D., Rémy Degenne,
       Yannis Flet-Berliac, Hector Kohler, Edouard Leurent, Pierre Ménard, Waris Radji,
       sauxpa, Xuedong Shang, Ju T, TimotheeMathieu, Riccardo Della Vecchia, YannBerthelot]
71 | 72 |
73 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/markdown_to_py.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env -S bash 2 | 3 | 4 | mkdir python_scripts 5 | 6 | set -e 7 | 8 | shopt -s globstar 9 | list_files=$(ls $PWD/**/*.md) 10 | 11 | for script in $list_files; do 12 | echo "Processing " $script 13 | sed -n '/^```python/,/^```/ p' < $script | sed '/^```/ d' > python_scripts/${script##*/}.py 14 | done 15 | 16 | # read -p "Do you wish to execute all python scripts? (y/n)" yn 17 | # case $yn in 18 | # [Yy]* ) for f in python_scripts/*.py; do python3 $f ; done ;; 19 | # * ) exit;; 20 | # esac 21 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx<7 2 | sphinx-gallery 3 | sphinx-math-dollar 4 | numpydoc 5 | myst-parser 6 | git+https://github.com/sphinx-contrib/video 7 | matplotlib 8 | sphinx-copybutton 9 | sphinx-design 10 | -------------------------------------------------------------------------------- /docs/templates/class.rst: -------------------------------------------------------------------------------- 1 | :mod:`{{module}}`.{{objname}} 2 | {{ underline }}============== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autoclass:: {{ objname }} 7 | 8 | .. include:: {{module}}.{{objname}}.examples 9 | 10 | .. raw:: html 11 | 12 |
13 | -------------------------------------------------------------------------------- /docs/templates/class_with_call.rst: -------------------------------------------------------------------------------- 1 | :mod:`{{module}}`.{{objname}} 2 | {{ underline }}=============== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autoclass:: {{ objname }} 7 | 8 | {% block methods %} 9 | .. automethod:: __call__ 10 | {% endblock %} 11 | 12 | .. include:: {{module}}.{{objname}}.examples 13 | 14 | .. raw:: html 15 | 16 |
17 | -------------------------------------------------------------------------------- /docs/templates/function.rst: -------------------------------------------------------------------------------- 1 | :mod:`{{module}}`.{{objname}} 2 | {{ underline }}==================== 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autofunction:: {{ objname }} 7 | 8 | .. include:: {{module}}.{{objname}}.examples 9 | 10 | .. raw:: html 11 | 12 |
13 | -------------------------------------------------------------------------------- /docs/templates/numpydoc_docstring.rst: -------------------------------------------------------------------------------- 1 | {{index}} 2 | {{summary}} 3 | {{extended_summary}} 4 | {{parameters}} 5 | {{returns}} 6 | {{yields}} 7 | {{other_parameters}} 8 | {{attributes}} 9 | {{raises}} 10 | {{warns}} 11 | {{warnings}} 12 | {{see_also}} 13 | {{notes}} 14 | {{references}} 15 | {{examples}} 16 | {{methods}} 17 | -------------------------------------------------------------------------------- /docs/themes/scikit-learn-fork/README.md: -------------------------------------------------------------------------------- 1 | Theme forked from [scikit-learn theme](https://github.com/scikit-learn/scikit-learn). 2 | -------------------------------------------------------------------------------- /docs/themes/scikit-learn-fork/search.html: -------------------------------------------------------------------------------- 1 | {%- extends "basic/search.html" %} 2 | {% block extrahead %} 3 | 4 | 5 | 6 | 7 | 8 | {% endblock %} 9 | -------------------------------------------------------------------------------- /docs/themes/scikit-learn-fork/theme.conf: -------------------------------------------------------------------------------- 1 | [theme] 2 | inherit = basic 3 | pygments_style = default 4 | stylesheet = css/theme.css 5 | 6 | [options] 7 | google_analytics = true 8 | mathjax_path = 9 | -------------------------------------------------------------------------------- /docs/themes/scikit-learn-fork/toc.css: -------------------------------------------------------------------------------- 1 | .. 2 | File to ..include in a document with a big table of content, to give 3 | it 'style' 4 | 5 | .. 
raw:: html 6 | 7 | 38 | -------------------------------------------------------------------------------- /docs/thumbnails/adastop_boxplots.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/adastop_boxplots.png -------------------------------------------------------------------------------- /docs/thumbnails/chain_thumb.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/chain_thumb.jpg -------------------------------------------------------------------------------- /docs/thumbnails/code.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/code.png -------------------------------------------------------------------------------- /docs/thumbnails/example_plot_atari_atlantis_vectorized_ppo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/example_plot_atari_atlantis_vectorized_ppo.jpg -------------------------------------------------------------------------------- /docs/thumbnails/example_plot_atari_breakout_vectorized_ppo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/example_plot_atari_breakout_vectorized_ppo.jpg -------------------------------------------------------------------------------- /docs/thumbnails/experiment_manager_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/experiment_manager_diagram.png -------------------------------------------------------------------------------- /docs/thumbnails/output_9_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/output_9_3.png -------------------------------------------------------------------------------- /docs/thumbnails/video_plot_a2c.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_a2c.jpg -------------------------------------------------------------------------------- /docs/thumbnails/video_plot_acrobot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_acrobot.jpg -------------------------------------------------------------------------------- /docs/thumbnails/video_plot_apple_gold.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_apple_gold.jpg -------------------------------------------------------------------------------- 
/docs/thumbnails/video_plot_atari_freeway.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_atari_freeway.jpg -------------------------------------------------------------------------------- /docs/thumbnails/video_plot_chain.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_chain.jpg -------------------------------------------------------------------------------- /docs/thumbnails/video_plot_dqn.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_dqn.jpg -------------------------------------------------------------------------------- /docs/thumbnails/video_plot_gridworld.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_gridworld.jpg -------------------------------------------------------------------------------- /docs/thumbnails/video_plot_mbqvi.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_mbqvi.jpg -------------------------------------------------------------------------------- /docs/thumbnails/video_plot_mdqn.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_mdqn.jpg -------------------------------------------------------------------------------- /docs/thumbnails/video_plot_montain_car.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_montain_car.jpg -------------------------------------------------------------------------------- /docs/thumbnails/video_plot_old_gym_acrobot.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_old_gym_acrobot.jpg -------------------------------------------------------------------------------- /docs/thumbnails/video_plot_pball.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_pball.jpg -------------------------------------------------------------------------------- /docs/thumbnails/video_plot_ppo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_ppo.jpg -------------------------------------------------------------------------------- /docs/thumbnails/video_plot_rooms.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_rooms.jpg -------------------------------------------------------------------------------- /docs/thumbnails/video_plot_rs_kernel_ucbvi.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_rs_kernel_ucbvi.jpg -------------------------------------------------------------------------------- /docs/thumbnails/video_plot_rsucbvi.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_rsucbvi.jpg -------------------------------------------------------------------------------- /docs/thumbnails/video_plot_springcartpole.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_springcartpole.jpg -------------------------------------------------------------------------------- /docs/thumbnails/video_plot_twinrooms.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_twinrooms.jpg -------------------------------------------------------------------------------- /docs/thumbnails/video_plot_vi.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_vi.jpg -------------------------------------------------------------------------------- /docs/user_guide.md: -------------------------------------------------------------------------------- 1 | (user_guide)= 2 | 3 | 4 | # User Guide 5 | 6 | ## Introduction 7 | Welcome to rlberry. 8 | Use rlberry's [ExperimentManager](experimentManager_page) to train, evaluate and compare rl agents. 9 | Like other popular rl libraries, rlberry also provides basic tools for plotting, multiprocessing and logging . In this user guide, we take you through the core features of rlberry and illustrate them with [examples](/auto_examples/index) and [API documentation](/api) . 10 | 11 | To run all the examples, you will need to install other libraries like "[rlberry-scool](https://github.com/rlberry-py/rlberry-scool)" (and others). 12 | 13 | The easiest way to do it is : 14 | ```none 15 | pip install rlberry[torch,extras] 16 | pip install rlberry-scool 17 | ``` 18 | 19 | [rlberry-scool](https://github.com/rlberry-py/rlberry-scool) : 20 | It's the repository used for teaching purposes. These are mainly basic agents and environments, in a version that makes it easier for students to learn. 21 | 22 | You can find more details about installation [here](installation)! 
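A quick, optional sanity check of the installation (this only assumes the two packages above were installed with pip; `importlib.metadata` is part of the Python standard library):

```python
# Optional check: make sure both packages import and print their installed versions.
from importlib.metadata import version

import rlberry  # noqa: F401
import rlberry_scool  # noqa: F401

print("rlberry:", version("rlberry"))
print("rlberry-scool:", version("rlberry-scool"))
```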
23 | 24 | You can find our quick starts here : 25 | ```{toctree} 26 | :maxdepth: 2 27 | basics/quick_start_rl/quickstart.md 28 | basics/DeepRLTutorial/TutorialDeepRL.md 29 | ``` 30 | 31 | ## Set up an experiment 32 | ```{include} templates/nice_toc.md 33 | ``` 34 | 35 | ```{toctree} 36 | :maxdepth: 2 37 | basics/userguide/environment.md 38 | basics/userguide/agent.md 39 | basics/userguide/experimentManager.md 40 | basics/userguide/logging.md 41 | basics/userguide/visualization.md 42 | ``` 43 | 44 | ## Experimenting with Deep agents 45 | [(In construction)](https://github.com/rlberry-py/rlberry/issues/459) 46 | ## Reproducibility 47 | ```{toctree} 48 | :maxdepth: 2 49 | basics/userguide/seeding.md 50 | basics/userguide/save_load.md 51 | basics/userguide/export_training_data.md 52 | ``` 53 | 54 | ## Advanced Usage 55 | ```{toctree} 56 | :maxdepth: 2 57 | basics/userguide/adastop.md 58 | basics/comparison.md 59 | basics/userguide/external_lib.md 60 | ``` 61 | - Custom Agents (In construction) 62 | - Custom Environments (In construction) 63 | - Transfer Learning (In construction) 64 | 65 | # Contributing to rlberry 66 | If you want to contribute to rlberry, check out [the contribution guidelines](contributing). 67 | -------------------------------------------------------------------------------- /docs/user_guide2.rst: -------------------------------------------------------------------------------- 1 | .. title:: User guide : contents 2 | 3 | .. _user_guide2: 4 | 5 | ========== 6 | User guide 7 | ========== 8 | 9 | .. Introduction 10 | .. ============ 11 | .. Welcome to rlberry. Use rlberry's ExperimentManager (add ref) to train, evaluate and compare rl agents. In addition to 12 | .. the core ExperimentManager (add ref), rlberry provides the user with a set of bandit (add ref), tabular rl (add ref), and 13 | .. deep rl agents (add ref) as well as a wrapper for stablebaselines3 (add link, and ref) agents. 14 | .. Like other popular rl libraries, rlberry also provides basic tools for plotting, multiprocessing and logging (add refs). 15 | .. In this user guide, we take you through the core features of rlberry and illustrate them with examples (add ref) and API documentation (add ref). 16 | 17 | If you are new to rlberry, check the :ref:`Tutorials` below and the :ref:`the quickstart` documentation. 18 | In the quick start, you will learn how to set up an experiment and evaluate the 19 | efficiency of different agents. 20 | 21 | For more information see :ref:`the gallery of examples`. 22 | 23 | 24 | Tutorials 25 | ========= 26 | 27 | The tutorials below will present to you the main functionalities of ``rlberry`` in a few minutes. 28 | 29 | 30 | 31 | 32 | Quick start: setup an experiment and evaluate different agents 33 | -------------------------------------------------------------- 34 | 35 | .. toctree:: 36 | :maxdepth: 2 37 | 38 | basics/quick_start_rl/quickstart.md 39 | basics/DeepRLTutorial/TutorialDeepRL.md 40 | 41 | 42 | Agents, hyperparameter optimization and experiment setup 43 | --------------------------------------------------------- 44 | 45 | .. 
toctree:: 46 | :maxdepth: 1 47 | 48 | basics/create_agent.rst 49 | basics/evaluate_agent.rst 50 | basics/experiment_setup.rst 51 | basics/seeding.rst 52 | basics/multiprocess.rst 53 | basics/comparison.md 54 | 55 | We also provide examples to show how to use :ref:`torch checkpointing` 56 | in rlberry and :ref:`tensorboard` 57 | 58 | Compatibility with External Libraries 59 | ===================================== 60 | 61 | We provide examples to show you how to use rlberry with: 62 | 63 | - :ref:`Gymnasium `; 64 | - :ref:`Stable Baselines `. 65 | 66 | 67 | How to contribute? 68 | ================== 69 | 70 | If you want to contribute to rlberry, check out :doc:`the contribution guidelines`. 71 | -------------------------------------------------------------------------------- /docs/versions.rst: -------------------------------------------------------------------------------- 1 | Documentation versions 2 | ====================== 3 | 4 | 5 | * `Stable `_ 6 | * :ref:`Development` 7 | 8 | * For dev team: `PR preview `_ 9 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | .. _examples: 2 | 3 | =================== 4 | Gallery of examples 5 | =================== 6 | -------------------------------------------------------------------------------- /examples/adastop_example.py: -------------------------------------------------------------------------------- 1 | """ 2 | =========================================== 3 | Compare PPO and A2C on Acrobot with AdaStop 4 | =========================================== 5 | 6 | This example illustrate the use of adastop_comparator which uses adaptive multiple-testing to assess whether trained agents are 7 | statistically different or not. 8 | 9 | Remark that in the case where two agents are not deemed statistically different it can mean either that they are as efficient, 10 | or it can mean that there have not been enough fits to assess the variability of the agents. 11 | 12 | Results in 13 | 14 | .. 
code-block:: 15 | 16 | [INFO] 13:35: Test finished 17 | [INFO] 13:35: Results are 18 | Agent1 vs Agent2 mean Agent1 mean Agent2 mean diff std Agent 1 std Agent 2 decisions 19 | 0 A2C vs PPO -274.274 -85.068 -189.206 185.82553 2.71784 smaller 20 | 21 | 22 | """ 23 | 24 | from rlberry.envs import gym_make 25 | from stable_baselines3 import A2C, PPO 26 | from rlberry.agents.stable_baselines import StableBaselinesAgent 27 | from rlberry.manager import AdastopComparator 28 | 29 | env_ctor, env_kwargs = gym_make, dict(id="Acrobot-v1") 30 | 31 | managers = [ 32 | { 33 | "agent_class": StableBaselinesAgent, 34 | "train_env": (env_ctor, env_kwargs), 35 | "fit_budget": 5e4, 36 | "agent_name": "A2C", 37 | "init_kwargs": {"algo_cls": A2C, "policy": "MlpPolicy", "verbose": 1}, 38 | }, 39 | { 40 | "agent_class": StableBaselinesAgent, 41 | "train_env": (env_ctor, env_kwargs), 42 | "agent_name": "PPO", 43 | "fit_budget": 5e4, 44 | "init_kwargs": {"algo_cls": PPO, "policy": "MlpPolicy", "verbose": 1}, 45 | }, 46 | ] 47 | 48 | comparator = AdastopComparator() 49 | comparator.compare(managers) 50 | print(comparator.managers_paths) 51 | -------------------------------------------------------------------------------- /examples/comparison_agents.py: -------------------------------------------------------------------------------- 1 | """ 2 | ========================= 3 | Compare Bandit Algorithms 4 | ========================= 5 | 6 | This example illustrate the use of compare_agents, a function that uses multiple-testing to assess whether trained agents are 7 | statistically different or not. 8 | 9 | Remark that in the case where two agents are not deemed statistically different it can mean either that they are as efficient, 10 | or it can mean that there have not been enough fits to assess the variability of the agents. 
11 | 12 | """ 13 | 14 | import numpy as np 15 | 16 | from rlberry.manager.comparison import compare_agents 17 | from rlberry.manager import AgentManager 18 | from rlberry_research.envs.bandits import BernoulliBandit 19 | from rlberry_research.agents.bandits import ( 20 | IndexAgent, 21 | makeBoundedMOSSIndex, 22 | makeBoundedNPTSIndex, 23 | makeBoundedUCBIndex, 24 | makeETCIndex, 25 | ) 26 | 27 | # Parameters of the problem 28 | means = np.array([0.6, 0.6, 0.6, 0.9]) # means of the arms 29 | A = len(means) 30 | T = 2000 # Horizon 31 | N = 50 # number of fits 32 | 33 | # Construction of the experiment 34 | 35 | env_ctor = BernoulliBandit 36 | env_kwargs = {"p": means} 37 | 38 | 39 | class UCBAgent(IndexAgent): 40 | name = "UCB" 41 | 42 | def __init__(self, env, **kwargs): 43 | index, _ = makeBoundedUCBIndex() 44 | IndexAgent.__init__(self, env, index, writer_extra="reward", **kwargs) 45 | 46 | 47 | class ETCAgent(IndexAgent): 48 | name = "ETC" 49 | 50 | def __init__(self, env, m=20, **kwargs): 51 | index, _ = makeETCIndex(A, m) 52 | IndexAgent.__init__( 53 | self, env, index, writer_extra="action_and_reward", **kwargs 54 | ) 55 | 56 | 57 | class MOSSAgent(IndexAgent): 58 | name = "MOSS" 59 | 60 | def __init__(self, env, **kwargs): 61 | index, _ = makeBoundedMOSSIndex(T, A) 62 | IndexAgent.__init__( 63 | self, env, index, writer_extra="action_and_reward", **kwargs 64 | ) 65 | 66 | 67 | class NPTSAgent(IndexAgent): 68 | name = "NPTS" 69 | 70 | def __init__(self, env, **kwargs): 71 | index, tracker_params = makeBoundedNPTSIndex() 72 | IndexAgent.__init__( 73 | self, 74 | env, 75 | index, 76 | writer_extra="reward", 77 | tracker_params=tracker_params, 78 | **kwargs, 79 | ) 80 | 81 | 82 | Agents_class = [MOSSAgent, NPTSAgent, UCBAgent, ETCAgent] 83 | 84 | managers = [ 85 | AgentManager( 86 | Agent, 87 | train_env=(env_ctor, env_kwargs), 88 | fit_budget=T, 89 | parallelization="process", 90 | mp_context="fork", 91 | n_fit=N, 92 | ) 93 | for Agent in Agents_class 94 | ] 95 | 96 | 97 | for manager in managers: 98 | manager.fit() 99 | 100 | 101 | def eval_function(manager, eval_budget=None, agent_id=0): 102 | df = manager.get_writer_data()[agent_id] 103 | return T * np.max(means) - np.sum(df.loc[df["tag"] == "reward", "value"]) 104 | 105 | 106 | print( 107 | compare_agents(managers, method="tukey_hsd", eval_function=eval_function, B=10_000) 108 | ) 109 | -------------------------------------------------------------------------------- /examples/demo_agents/README.md: -------------------------------------------------------------------------------- 1 | Illustration of rlberry agents 2 | ============================== 3 | -------------------------------------------------------------------------------- /examples/demo_agents/demo_SAC.py: -------------------------------------------------------------------------------- 1 | """ 2 | ============================= 3 | SAC Soft Actor-Critic 4 | ============================= 5 | 6 | This script shows how to train a SAC agent on a Pendulum environment. 
7 | """ 8 | 9 | import time 10 | 11 | import gymnasium as gym 12 | from rlberry_research.agents.torch.sac import SACAgent 13 | from rlberry_research.envs import Pendulum 14 | from rlberry.manager import ExperimentManager 15 | 16 | 17 | def env_ctor(env, wrap_spaces=True): 18 | return env 19 | 20 | 21 | # Setup agent parameters 22 | env_name = "Pendulum" 23 | fit_budget = int(2e5) 24 | agent_name = f"{env_name}_{fit_budget}_{int(time.time())}" 25 | 26 | # Setup environment parameters 27 | env = Pendulum() 28 | env = gym.wrappers.TimeLimit(env, max_episode_steps=200) 29 | env = gym.wrappers.RecordEpisodeStatistics(env) 30 | env_kwargs = dict(env=env) 31 | 32 | # Create agent instance 33 | xp_manager = ExperimentManager( 34 | SACAgent, 35 | (env_ctor, env_kwargs), 36 | fit_budget=fit_budget, 37 | n_fit=1, 38 | enable_tensorboard=True, 39 | agent_name=agent_name, 40 | ) 41 | 42 | # Start training 43 | xp_manager.fit() 44 | -------------------------------------------------------------------------------- /examples/demo_agents/gym_videos/openaigym.episode_batch.0.454210.stats.json: -------------------------------------------------------------------------------- 1 | {"initial_reset_timestamp": 22935.230051929, "timestamps": [22938.041554318], "episode_lengths": [71], "episode_rewards": [71.0], "episode_types": ["t"]} 2 | -------------------------------------------------------------------------------- /examples/demo_agents/gym_videos/openaigym.manifest.0.454210.manifest.json: -------------------------------------------------------------------------------- 1 | {"stats": "openaigym.episode_batch.0.454210.stats.json", "videos": [["openaigym.video.0.454210.video000000.mp4", "openaigym.video.0.454210.video000000.meta.json"]], "env_info": {"gym_version": "0.21.0", "env_id": "CartPole-v0"}} 2 | -------------------------------------------------------------------------------- /examples/demo_agents/gym_videos/openaigym.video.0.454210.video000000.meta.json: -------------------------------------------------------------------------------- 1 | {"episode_id": 0, "content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version n4.4.1 Copyright (c) 2000-2021 the FFmpeg developers\\nbuilt with gcc 11.1.0 (GCC)\\nconfiguration: --prefix=/usr --disable-debug --disable-static --disable-stripping --enable-amf --enable-avisynth --enable-cuda-llvm --enable-lto --enable-fontconfig --enable-gmp --enable-gnutls --enable-gpl --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libdav1d --enable-libdrm --enable-libfreetype --enable-libfribidi --enable-libgsm --enable-libiec61883 --enable-libjack --enable-libmfx --enable-libmodplug --enable-libmp3lame --enable-libopencore_amrnb --enable-libopencore_amrwb --enable-libopenjpeg --enable-libopus --enable-libpulse --enable-librav1e --enable-librsvg --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libsvtav1 --enable-libtheora --enable-libv4l2 --enable-libvidstab --enable-libvmaf --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx264 --enable-libx265 --enable-libxcb --enable-libxml2 --enable-libxvid --enable-libzimg --enable-nvdec --enable-nvenc --enable-shared --enable-version3\\nlibavutil 56. 70.100 / 56. 70.100\\nlibavcodec 58.134.100 / 58.134.100\\nlibavformat 58. 76.100 / 58. 76.100\\nlibavdevice 58. 13.100 / 58. 13.100\\nlibavfilter 7.110.100 / 7.110.100\\nlibswscale 5. 9.100 / 5. 9.100\\nlibswresample 3. 9.100 / 3. 9.100\\nlibpostproc 55. 9.100 / 55. 
9.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "600x400", "-pix_fmt", "rgb24", "-framerate", "50", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "50", "/home/frost/tmpmath/rlberry/examples/demo_agents/gym_videos/openaigym.video.0.454210.video000000.mp4"]}} 2 | -------------------------------------------------------------------------------- /examples/demo_agents/gym_videos/openaigym.video.0.454210.video000000.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/examples/demo_agents/gym_videos/openaigym.video.0.454210.video000000.mp4 -------------------------------------------------------------------------------- /examples/demo_agents/video_plot_a2c.py: -------------------------------------------------------------------------------- 1 | """ 2 | ============================================== 3 | A demo of A2C algorithm in PBall2D environment 4 | ============================================== 5 | Illustration of how to set up an A2C algorithm in rlberry. 6 | The environment chosen here is PBALL2D environment. 7 | 8 | .. video:: ../../video_plot_a2c.mp4 9 | :width: 600 10 | 11 | """ 12 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_a2c.jpg' 13 | 14 | from rlberry_research.agents.torch import A2CAgent 15 | from rlberry_research.envs.benchmarks.ball_exploration import PBall2D 16 | from gymnasium.wrappers import TimeLimit 17 | 18 | 19 | env = PBall2D() 20 | env = TimeLimit(env, max_episode_steps=256) 21 | n_timesteps = 50_000 22 | agent = A2CAgent(env, gamma=0.99, learning_rate=0.001) 23 | agent.fit(budget=n_timesteps) 24 | 25 | env.enable_rendering() 26 | 27 | observation, info = env.reset() 28 | for tt in range(200): 29 | action = agent.policy(observation) 30 | observation, reward, terminated, truncated, info = env.step(action) 31 | done = terminated or truncated 32 | 33 | video = env.save_video("_video/video_plot_a2c.mp4") 34 | -------------------------------------------------------------------------------- /examples/demo_agents/video_plot_dqn.py: -------------------------------------------------------------------------------- 1 | """ 2 | =============================================== 3 | A demo of DQN algorithm in CartPole environment 4 | =============================================== 5 | Illustration of how to set up a DQN algorithm in rlberry. 6 | The environment chosen here is gym's cartpole environment. 7 | 8 | As DQN can be computationally intensive and hard to tune, 9 | one can use tensorboard to visualize the training of the DQN 10 | using the following command: 11 | 12 | .. code-block:: bash 13 | 14 | tensorboard --logdir {Path(agent.writer.log_dir).parent} 15 | 16 | .. 
video:: ../../video_plot_dqn.mp4 17 | :width: 600 18 | 19 | """ 20 | 21 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_dqn.jpg' 22 | 23 | from rlberry.envs import gym_make 24 | from torch.utils.tensorboard import SummaryWriter 25 | 26 | from rlberry_research.agents.torch.dqn import DQNAgent 27 | from rlberry.utils.logging import configure_logging 28 | 29 | from gymnasium.wrappers.rendering import RecordVideo 30 | import shutil 31 | import os 32 | 33 | 34 | configure_logging(level="INFO") 35 | 36 | env = gym_make("CartPole-v1", render_mode="rgb_array") 37 | agent = DQNAgent(env, epsilon_decay_interval=1000) 38 | agent.set_writer(SummaryWriter()) 39 | 40 | print(f"Running DQN on {env}") 41 | 42 | agent.fit(budget=50) 43 | env = RecordVideo(env, "_video/temp") 44 | 45 | for episode in range(3): 46 | done = False 47 | observation, info = env.reset() 48 | while not done: 49 | action = agent.policy(observation) 50 | observation, reward, terminated, truncated, info = env.step(action) 51 | done = terminated or truncated 52 | env.close() 53 | 54 | # need to move the final result inside the folder used for documentation 55 | os.rename("_video/temp/rl-video-episode-0.mp4", "_video/video_plot_dqn.mp4") 56 | shutil.rmtree("_video/temp/") 57 | -------------------------------------------------------------------------------- /examples/demo_agents/video_plot_mbqvi.py: -------------------------------------------------------------------------------- 1 | """ 2 | ================================================== 3 | A demo of MBQVI algorithm in Gridworld environment 4 | ================================================== 5 | Illustration of how to set up an MBQVI algorithm in rlberry. 6 | The environment chosen here is GridWorld environment. 7 | 8 | .. video:: ../../video_plot_mbqvi.mp4 9 | :width: 600 10 | 11 | """ 12 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_mbqvi.jpg' 13 | from rlberry_scool.agents.mbqvi import MBQVIAgent 14 | from rlberry_research.envs.finite import GridWorld 15 | 16 | params = {} 17 | params["n_samples"] = 100 # samples per state-action pair 18 | params["gamma"] = 0.99 19 | params["horizon"] = None 20 | 21 | env = GridWorld(7, 10, walls=((2, 2), (3, 3)), success_probability=0.6) 22 | agent = MBQVIAgent(env, **params) 23 | info = agent.fit() 24 | print(info) 25 | 26 | # evaluate policy in a deterministic version of the environment 27 | env_eval = GridWorld(7, 10, walls=((2, 2), (3, 3)), success_probability=1.0) 28 | env_eval.enable_rendering() 29 | state, info = env_eval.reset() 30 | for tt in range(50): 31 | action = agent.policy(state) 32 | next_s, _, _, _, _ = env_eval.step(action) 33 | state = next_s 34 | video = env_eval.save_video("_video/video_plot_mbqvi.mp4") 35 | -------------------------------------------------------------------------------- /examples/demo_agents/video_plot_mdqn.py: -------------------------------------------------------------------------------- 1 | """ 2 | =============================================== 3 | A demo of M-DQN algorithm in CartPole environment 4 | =============================================== 5 | Illustration of how to set up a M-DQN algorithm in rlberry. 6 | The environment chosen here is gym's cartpole environment. 7 | 8 | As DQN can be computationally intensive and hard to tune, 9 | one can use tensorboard to visualize the training of the DQN 10 | using the following command: 11 | 12 | .. code-block:: bash 13 | 14 | tensorboard --logdir {Path(agent.writer.log_dir).parent} 15 | 16 | .. 
video:: ../../video_plot_mdqn.mp4 17 | :width: 600 18 | 19 | """ 20 | 21 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_dqn.jpg' 22 | 23 | from rlberry.envs import gym_make 24 | from torch.utils.tensorboard import SummaryWriter 25 | 26 | from rlberry_research.agents.torch.dqn import MunchausenDQNAgent 27 | from rlberry.utils.logging import configure_logging 28 | 29 | from gymnasium.wrappers.rendering import RecordVideo 30 | import shutil 31 | import os 32 | 33 | 34 | configure_logging(level="INFO") 35 | 36 | env = gym_make("CartPole-v1", render_mode="rgb_array") 37 | agent = MunchausenDQNAgent(env, epsilon_decay_interval=1000) 38 | agent.set_writer(SummaryWriter()) 39 | 40 | print(f"Running Munchausen DQN on {env}") 41 | 42 | agent.fit(budget=10**5) 43 | env = RecordVideo(env, "_video/temp") 44 | 45 | 46 | for episode in range(3): 47 | done = False 48 | observation, info = env.reset() 49 | while not done: 50 | action = agent.policy(observation) 51 | observation, reward, terminated, truncated, info = env.step(action) 52 | done = terminated or truncated 53 | env.close() 54 | 55 | # need to move the final result inside the folder used for documentation 56 | os.rename("_video/temp/rl-video-episode-0.mp4", "_video/video_plot_mdqn.mp4") 57 | shutil.rmtree("_video/temp/") 58 | -------------------------------------------------------------------------------- /examples/demo_agents/video_plot_ppo.py: -------------------------------------------------------------------------------- 1 | """ 2 | ============================================== 3 | A demo of PPO algorithm in PBall2D environment 4 | ============================================== 5 | Illustration of how to set up an PPO algorithm in rlberry. 6 | The environment chosen here is PBALL2D environment. 7 | 8 | .. video:: ../../video_plot_ppo.mp4 9 | :width: 600 10 | 11 | """ 12 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_a2c.jpg' 13 | 14 | from rlberry_research.agents.torch import PPOAgent 15 | from rlberry_research.envs.benchmarks.ball_exploration import PBall2D 16 | 17 | 18 | env = PBall2D() 19 | n_steps = 3e3 20 | 21 | agent = PPOAgent(env) 22 | agent.fit(budget=n_steps) 23 | 24 | env.enable_rendering() 25 | observation, info = env.reset() 26 | for tt in range(200): 27 | action = agent.policy(observation) 28 | observation, reward, terminated, truncated, info = env.step(action) 29 | done = terminated or truncated 30 | 31 | video = env.save_video("_video/video_plot_ppo.mp4") 32 | -------------------------------------------------------------------------------- /examples/demo_agents/video_plot_rs_kernel_ucbvi.py: -------------------------------------------------------------------------------- 1 | """ 2 | ============================================================= 3 | A demo of RSKernelUCBVIAgent algorithm in Acrobot environment 4 | ============================================================= 5 | Illustration of how to set up a RSKernelUCBVI algorithm in rlberry. 6 | The environment chosen here is Acrobot environment. 7 | 8 | .. 
video:: ../../video_plot_rs_kernel_ucbvi.mp4 9 | :width: 600 10 | 11 | """ 12 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_rs_kernel_ucbvi.jpg' 13 | 14 | from rlberry_research.envs import Acrobot 15 | from rlberry_research.agents import RSKernelUCBVIAgent 16 | from rlberry.wrappers import RescaleRewardWrapper 17 | 18 | env = Acrobot() 19 | # rescake rewards to [0, 1] 20 | env = RescaleRewardWrapper(env, (0.0, 1.0)) 21 | 22 | agent = RSKernelUCBVIAgent( 23 | env, 24 | gamma=0.99, 25 | horizon=300, 26 | bonus_scale_factor=0.01, 27 | min_dist=0.2, 28 | bandwidth=0.05, 29 | beta=1.0, 30 | kernel_type="gaussian", 31 | ) 32 | agent.fit(budget=500) 33 | 34 | env.enable_rendering() 35 | observation, info = env.reset() 36 | 37 | time_before_done = 0 38 | ended = False 39 | for tt in range(2 * agent.horizon): 40 | action = agent.policy(observation) 41 | observation, reward, terminated, truncated, info = env.step(action) 42 | done = terminated or truncated 43 | if not done and not ended: 44 | time_before_done += 1 45 | if done: 46 | ended = True 47 | 48 | print("steps to achieve the goal for the first time = ", time_before_done) 49 | video = env.save_video("_video/video_plot_rs_kernel_ucbvi.mp4") 50 | -------------------------------------------------------------------------------- /examples/demo_agents/video_plot_rsucbvi.py: -------------------------------------------------------------------------------- 1 | """ 2 | ================================================== 3 | A demo of RSUCBVI algorithm in MountainCar environment 4 | ================================================== 5 | Illustration of how to set up an RSUCBVI algorithm in rlberry. 6 | The environment chosen here is MountainCar environment. 7 | 8 | .. video:: ../../video_plot_rsucbvi.mp4 9 | :width: 600 10 | 11 | """ 12 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_rsucbvi.jpg' 13 | 14 | from rlberry_research.agents import RSUCBVIAgent 15 | from rlberry_research.envs.classic_control import MountainCar 16 | 17 | env = MountainCar() 18 | horizon = 170 19 | print("Running RS-UCBVI on %s" % env.name) 20 | agent = RSUCBVIAgent(env, gamma=0.99, horizon=horizon, bonus_scale_factor=0.1) 21 | agent.fit(budget=500) 22 | 23 | env.enable_rendering() 24 | observation, info = env.reset() 25 | for tt in range(200): 26 | action = agent.policy(observation) 27 | observation, reward, terminated, truncated, info = env.step(action) 28 | done = terminated or truncated 29 | 30 | video = env.save_video("_video/video_plot_rsucbvi.mp4") 31 | -------------------------------------------------------------------------------- /examples/demo_agents/video_plot_vi.py: -------------------------------------------------------------------------------- 1 | """ 2 | ======================================================= 3 | A demo of ValueIteration algorithm in Chain environment 4 | ======================================================= 5 | Illustration of how to set up an ValueIteration algorithm in rlberry. 6 | The environment chosen here is Chain environment. 7 | 8 | .. 
video:: ../../video_plot_vi.mp4 9 | :width: 600 10 | 11 | """ 12 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_vi.jpg' 13 | 14 | from rlberry_scool.agents.dynprog import ValueIterationAgent 15 | from rlberry_scool.envs.finite import Chain 16 | 17 | env = Chain() 18 | agent = ValueIterationAgent(env, gamma=0.95) 19 | info = agent.fit() 20 | print(info) 21 | 22 | env.enable_rendering() 23 | observation, info = env.reset() 24 | for tt in range(50): 25 | action = agent.policy(observation) 26 | observation, reward, terminated, truncated, info = env.step(action) 27 | done = terminated or truncated 28 | if done: 29 | break 30 | video = env.save_video("_video/video_plot_vi.mp4") 31 | -------------------------------------------------------------------------------- /examples/demo_bandits/README.md: -------------------------------------------------------------------------------- 1 | Illustration of bandits in rlberry 2 | ================================== 3 | -------------------------------------------------------------------------------- /examples/demo_bandits/plot_exp3_bandit.py: -------------------------------------------------------------------------------- 1 | """ 2 | ============================= 3 | EXP3 Bandit cumulative regret 4 | ============================= 5 | 6 | This script shows how to define a bandit environment and an EXP3 7 | randomized algorithm. 8 | """ 9 | 10 | import numpy as np 11 | from rlberry_research.envs.bandits import AdversarialBandit 12 | from rlberry_research.agents.bandits import ( 13 | RandomizedAgent, 14 | TSAgent, 15 | makeEXP3Index, 16 | makeBetaPrior, 17 | ) 18 | from rlberry.manager import ExperimentManager, plot_writer_data 19 | 20 | 21 | # Agents definition 22 | 23 | 24 | class EXP3Agent(RandomizedAgent): 25 | name = "EXP3" 26 | 27 | def __init__(self, env, **kwargs): 28 | prob, tracker_params = makeEXP3Index() 29 | RandomizedAgent.__init__( 30 | self, 31 | env, 32 | prob, 33 | writer_extra="action", 34 | tracker_params=tracker_params, 35 | **kwargs 36 | ) 37 | 38 | 39 | class BernoulliTSAgent(TSAgent): 40 | """Thompson sampling for Bernoulli bandit""" 41 | 42 | name = "TS" 43 | 44 | def __init__(self, env, **kwargs): 45 | prior, _ = makeBetaPrior() 46 | TSAgent.__init__(self, env, prior, writer_extra="action", **kwargs) 47 | 48 | 49 | # Parameters of the problem 50 | T = 3000 # Horizon 51 | M = 20 # number of MC simu 52 | 53 | 54 | def switching_rewards(T, gap=0.1, rate=1.6): 55 | """Adversarially switching rewards over exponentially long phases. 56 | Inspired by Zimmert, Julian, and Yevgeny Seldin. 57 | "Tsallis-INF: An Optimal Algorithm for Stochastic and Adversarial Bandits." 58 | J. Mach. Learn. Res. 22 (2021): 28-1. 
59 | """ 60 | rewards = np.zeros((T, 2)) 61 | t = 0 62 | exp = 1 63 | high_rewards = True 64 | for t in range(T): 65 | if t > np.floor(rate**exp): 66 | high_rewards = not high_rewards 67 | exp += 1 68 | if high_rewards: 69 | rewards[t] = [1.0 - gap, 1.0] 70 | else: 71 | rewards[t] = [0.0, gap] 72 | return rewards 73 | 74 | 75 | rewards = switching_rewards(T, rate=5.0) 76 | 77 | 78 | # Construction of the experiment 79 | 80 | env_ctor = AdversarialBandit 81 | env_kwargs = {"rewards": rewards} 82 | 83 | Agents_class = [EXP3Agent, BernoulliTSAgent] 84 | 85 | agents = [ 86 | ExperimentManager( 87 | Agent, 88 | (env_ctor, env_kwargs), 89 | init_kwargs={}, 90 | fit_budget=T, 91 | n_fit=M, 92 | parallelization="process", 93 | mp_context="fork", 94 | ) 95 | for Agent in Agents_class 96 | ] 97 | 98 | # these parameters should give parallel computing even in notebooks 99 | 100 | 101 | # Agent training 102 | for agent in agents: 103 | agent.fit() 104 | 105 | 106 | # Compute and plot (pseudo-)regret 107 | def compute_pseudo_regret(actions): 108 | selected_rewards = np.array( 109 | [rewards[t, int(action)] for t, action in enumerate(actions)] 110 | ) 111 | return np.cumsum(np.max(rewards, axis=1) - selected_rewards) 112 | 113 | 114 | output = plot_writer_data( 115 | agents, 116 | tag="action", 117 | preprocess_func=compute_pseudo_regret, 118 | title="Cumulative Pseudo-Regret", 119 | ) 120 | -------------------------------------------------------------------------------- /examples/demo_bandits/plot_ucb_bandit.py: -------------------------------------------------------------------------------- 1 | """ 2 | ============================= 3 | UCB Bandit cumulative regret 4 | ============================= 5 | 6 | This script shows how to define a bandit environment and an UCB Index-based algorithm. 
7 | """ 8 | 9 | import numpy as np 10 | from rlberry_research.envs.bandits import NormalBandit 11 | from rlberry_research.agents.bandits import IndexAgent, makeSubgaussianUCBIndex 12 | from rlberry.manager import ExperimentManager, plot_writer_data 13 | import matplotlib.pyplot as plt 14 | 15 | 16 | # Agents definition 17 | 18 | 19 | class UCBAgent(IndexAgent): 20 | """UCB agent for sigma-subgaussian bandits""" 21 | 22 | name = "UCB Agent" 23 | 24 | def __init__(self, env, sigma=1, **kwargs): 25 | index, _ = makeSubgaussianUCBIndex(sigma) 26 | IndexAgent.__init__(self, env, index, writer_extra="action", **kwargs) 27 | 28 | 29 | # Parameters of the problem 30 | means = np.array([0, 0.9, 1]) # means of the arms 31 | T = 3000 # Horizon 32 | M = 20 # number of MC simu 33 | 34 | # Construction of the experiment 35 | 36 | env_ctor = NormalBandit 37 | env_kwargs = {"means": means, "stds": 2 * np.ones(len(means))} 38 | 39 | xp_manager = ExperimentManager( 40 | UCBAgent, 41 | (env_ctor, env_kwargs), 42 | fit_budget=T, 43 | init_kwargs={"sigma": 2}, 44 | n_fit=M, 45 | parallelization="process", 46 | mp_context="fork", 47 | ) 48 | # these parameters should give parallel computing even in notebooks 49 | 50 | 51 | # Agent training 52 | 53 | xp_manager.fit() 54 | 55 | 56 | # Compute and plot (pseudo-)regret 57 | def compute_pseudo_regret(actions): 58 | return np.cumsum(np.max(means) - means[actions.astype(int)]) 59 | 60 | 61 | fig = plt.figure(1, figsize=(5, 3)) 62 | ax = plt.gca() 63 | output = plot_writer_data( 64 | [xp_manager], 65 | tag="action", 66 | preprocess_func=compute_pseudo_regret, 67 | title="Cumulative Pseudo-Regret", 68 | ax=ax, 69 | ) 70 | -------------------------------------------------------------------------------- /examples/demo_env/README.md: -------------------------------------------------------------------------------- 1 | Illustration of rlberry environments 2 | ==================================== 3 | -------------------------------------------------------------------------------- /examples/demo_env/video_plot_acrobot.py: -------------------------------------------------------------------------------- 1 | """ 2 | =============================================== 3 | A demo of Acrobot environment with RSUCBVIAgent 4 | =============================================== 5 | Illustration of the training and video rendering of RSUCBVI Agent in Acrobot 6 | environment. 7 | 8 | .. 
video:: ../../video_plot_acrobot.mp4 9 | :width: 600 10 | 11 | """ 12 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_acrobot.jpg' 13 | 14 | from rlberry_research.envs import Acrobot 15 | from rlberry_research.agents import RSUCBVIAgent 16 | from rlberry.wrappers import RescaleRewardWrapper 17 | 18 | env = Acrobot() 19 | # rescale rewards to [0, 1] 20 | env = RescaleRewardWrapper(env, (0.0, 1.0)) 21 | n_episodes = 300 22 | agent = RSUCBVIAgent( 23 | env, gamma=0.99, horizon=300, bonus_scale_factor=0.01, min_dist=0.25 24 | ) 25 | agent.fit(budget=n_episodes) 26 | 27 | env.enable_rendering() 28 | observation, info = env.reset() 29 | for tt in range(2 * agent.horizon): 30 | action = agent.policy(observation) 31 | observation, reward, terminated, truncated, info = env.step(action) 32 | done = terminated or truncated 33 | 34 | # Save video 35 | video = env.save_video("_video/video_plot_acrobot.mp4") 36 | -------------------------------------------------------------------------------- /examples/demo_env/video_plot_apple_gold.py: -------------------------------------------------------------------------------- 1 | """ 2 | =============================== 3 | A demo of AppleGold environment 4 | =============================== 5 | Illustration of Applegold environment on which we train a ValueIteration 6 | algorithm. 7 | 8 | .. video:: ../../video_plot_apple_gold.mp4 9 | :width: 600 10 | 11 | """ 12 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_apple_gold.jpg' 13 | from rlberry_research.envs.benchmarks.grid_exploration.apple_gold import AppleGold 14 | from rlberry_scool.agents.dynprog import ValueIterationAgent 15 | 16 | env = AppleGold(reward_free=False, array_observation=False) 17 | 18 | agent = ValueIterationAgent(env, gamma=0.9) 19 | info = agent.fit() 20 | print(info) 21 | 22 | env.enable_rendering() 23 | 24 | observation, info = env.reset() 25 | for tt in range(5): 26 | action = agent.policy(observation) 27 | observation, reward, terminated, truncated, info = env.step(action) 28 | done = terminated or truncated 29 | if done: 30 | break 31 | env.render() 32 | video = env.save_video("_video/video_plot_apple_gold.mp4") 33 | -------------------------------------------------------------------------------- /examples/demo_env/video_plot_chain.py: -------------------------------------------------------------------------------- 1 | """ 2 | =============================== 3 | A demo of Chain environment 4 | =============================== 5 | Illustration of Chain environment 6 | 7 | .. video:: ../../video_plot_chain.mp4 8 | :width: 600 9 | 10 | """ 11 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_chain.jpg' 12 | 13 | 14 | from rlberry_scool.envs.finite import Chain 15 | 16 | env = Chain(10, 0.1) 17 | env.enable_rendering() 18 | for tt in range(5): 19 | env.step(env.action_space.sample()) 20 | env.render() 21 | env.save_video("_video/video_plot_chain.mp4") 22 | -------------------------------------------------------------------------------- /examples/demo_env/video_plot_gridworld.py: -------------------------------------------------------------------------------- 1 | """ 2 | .. _gridworld_example: 3 | 4 | ======================================================== 5 | A demo of Gridworld environment with ValueIterationAgent 6 | ======================================================== 7 | Illustration of the training and video rendering ofValueIteration Agent in 8 | Gridworld environment. 9 | 10 | .. 
video:: ../../video_plot_gridworld.mp4 11 | :width: 600 12 | """ 13 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_gridworld.jpg' 14 | 15 | from rlberry_scool.agents.dynprog import ValueIterationAgent 16 | from rlberry_scool.envs.finite import GridWorld 17 | 18 | 19 | env = GridWorld(7, 10, walls=((2, 2), (3, 3))) 20 | 21 | agent = ValueIterationAgent(env, gamma=0.95) 22 | info = agent.fit() 23 | print(info) 24 | 25 | env.enable_rendering() 26 | observation, info = env.reset() 27 | for tt in range(50): 28 | action = agent.policy(observation) 29 | observation, reward, terminated, truncated, info = env.step(action) 30 | done = terminated or truncated 31 | if done: 32 | # Warning: this will never happen in the present case because there is no terminal state. 33 | # See the doc of GridWorld for more informations on the default parameters of GridWorld. 34 | break 35 | # Save the video 36 | env.save_video("_video/video_plot_gridworld.mp4", framerate=10) 37 | -------------------------------------------------------------------------------- /examples/demo_env/video_plot_mountain_car.py: -------------------------------------------------------------------------------- 1 | """ 2 | =============================== 3 | A demo of MountainCar environment 4 | =============================== 5 | Illustration of MountainCar environment 6 | 7 | .. video:: ../../video_plot_montain_car.mp4 8 | :width: 600 9 | 10 | """ 11 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_montain_car.jpg' 12 | 13 | from rlberry_scool.agents.mbqvi import MBQVIAgent 14 | from rlberry_research.envs.classic_control import MountainCar 15 | from rlberry.wrappers import DiscretizeStateWrapper 16 | 17 | _env = MountainCar() 18 | env = DiscretizeStateWrapper(_env, 20) 19 | agent = MBQVIAgent(env, n_samples=40, gamma=0.99) 20 | agent.fit() 21 | 22 | env.enable_rendering() 23 | observation, info = env.reset() 24 | for tt in range(200): 25 | action = agent.policy(observation) 26 | observation, reward, terminated, truncated, info = env.step(action) 27 | done = terminated or truncated 28 | 29 | video = env.save_video("_video/video_plot_montain_car.mp4") 30 | -------------------------------------------------------------------------------- /examples/demo_env/video_plot_old_gym_compatibility_wrapper_old_acrobot.py: -------------------------------------------------------------------------------- 1 | """ 2 | =============================================== 3 | A demo of OldGymCompatibilityWrapper with old_Acrobot environment 4 | =============================================== 5 | Illustration of the wrapper for old environments (old Acrobot). 6 | 7 | .. 
video:: ../../video_plot_old_gym_acrobot.mp4 8 | :width: 600 9 | 10 | """ 11 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_old_gym_acrobot.jpg' 12 | 13 | 14 | from rlberry.wrappers.tests.old_env.old_acrobot import Old_Acrobot 15 | from rlberry_research.agents import RSUCBVIAgent 16 | from rlberry.wrappers import RescaleRewardWrapper 17 | from rlberry.wrappers.gym_utils import OldGymCompatibilityWrapper 18 | 19 | env = Old_Acrobot() 20 | env = OldGymCompatibilityWrapper(env) 21 | env = RescaleRewardWrapper(env, (0.0, 1.0)) 22 | n_episodes = 300 23 | agent = RSUCBVIAgent( 24 | env, gamma=0.99, horizon=300, bonus_scale_factor=0.01, min_dist=0.25 25 | ) 26 | result = env.reset(seed=42) 27 | 28 | agent.fit(budget=n_episodes) 29 | 30 | env.enable_rendering() 31 | observation, info = env.reset() 32 | for tt in range(2 * agent.horizon): 33 | action = agent.policy(observation) 34 | observation, reward, terminated, truncated, info = env.step(action) 35 | done = terminated or truncated 36 | 37 | # Save video 38 | video = env.save_video("_video/video_plot_old_gym_acrobot.mp4") 39 | -------------------------------------------------------------------------------- /examples/demo_env/video_plot_pball.py: -------------------------------------------------------------------------------- 1 | """ 2 | =============================== 3 | A demo of PBALL2D environment 4 | =============================== 5 | Illustration of PBall2D environment 6 | 7 | .. video:: ../../video_plot_pball.mp4 8 | :width: 600 9 | 10 | """ 11 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_pball.jpg' 12 | 13 | import numpy as np 14 | from rlberry_research.envs.benchmarks.ball_exploration import PBall2D 15 | 16 | p = 5 17 | A = np.array([[1.0, 0.1], [-0.1, 1.0]]) 18 | 19 | reward_amplitudes = np.array([1.0, 0.5, 0.5]) 20 | reward_smoothness = np.array([0.25, 0.25, 0.25]) 21 | 22 | reward_centers = [ 23 | np.array([0.75 * np.cos(np.pi / 2), 0.75 * np.sin(np.pi / 2)]), 24 | np.array([0.75 * np.cos(np.pi / 6), 0.75 * np.sin(np.pi / 6)]), 25 | np.array([0.75 * np.cos(5 * np.pi / 6), 0.75 * np.sin(5 * np.pi / 6)]), 26 | ] 27 | 28 | action_list = [ 29 | 0.1 * np.array([1, 0]), 30 | -0.1 * np.array([1, 0]), 31 | 0.1 * np.array([0, 1]), 32 | -0.1 * np.array([0, 1]), 33 | ] 34 | 35 | env = PBall2D( 36 | p=p, 37 | A=A, 38 | reward_amplitudes=reward_amplitudes, 39 | reward_centers=reward_centers, 40 | reward_smoothness=reward_smoothness, 41 | action_list=action_list, 42 | ) 43 | 44 | env.enable_rendering() 45 | 46 | for ii in range(5): 47 | env.step(1) 48 | env.step(3) 49 | 50 | env.render() 51 | video = env.save_video("_video/video_plot_pball.mp4") 52 | -------------------------------------------------------------------------------- /examples/demo_env/video_plot_rooms.py: -------------------------------------------------------------------------------- 1 | """ 2 | =============================== 3 | A demo of rooms environment 4 | =============================== 5 | Illustration of NRooms environment 6 | 7 | .. 
video:: ../../video_plot_rooms.mp4 8 | :width: 600 9 | 10 | """ 11 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_rooms.jpg' 12 | 13 | from rlberry_research.envs.benchmarks.grid_exploration.nroom import NRoom 14 | from rlberry_scool.agents.dynprog import ValueIterationAgent 15 | 16 | env = NRoom( 17 | nrooms=9, 18 | remove_walls=False, 19 | room_size=9, 20 | initial_state_distribution="center", 21 | include_traps=True, 22 | ) 23 | horizon = env.observation_space.n 24 | 25 | agent = ValueIterationAgent(env, gamma=0.999, horizon=horizon) 26 | print("fitting...") 27 | info = agent.fit() 28 | print(info) 29 | 30 | env.enable_rendering() 31 | 32 | for _ in range(10): 33 | observation, info = env.reset() 34 | for tt in range(horizon): 35 | # action = agent.policy(observation) 36 | action = env.action_space.sample() 37 | observation, reward, terminated, truncated, info = env.step(action) 38 | done = terminated or truncated 39 | if done: 40 | break 41 | env.render() 42 | video = env.save_video("_video/video_plot_rooms.mp4") 43 | -------------------------------------------------------------------------------- /examples/demo_env/video_plot_springcartpole.py: -------------------------------------------------------------------------------- 1 | """ 2 | =============================================== 3 | A demo of SpringCartPole environment with DQNAgent 4 | =============================================== 5 | Illustration of the training and video rendering of DQN Agent in 6 | SpringCartPole environment. 7 | 8 | Agent is slightly tuned, but not optimal. This is just for illustration purpose. 9 | 10 | .. video:: ../../video_plot_springcartpole.mp4 11 | :width: 600 12 | 13 | """ 14 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_springcartpole.jpg' 15 | 16 | from rlberry_research.envs.classic_control import SpringCartPole 17 | from rlberry_research.agents.torch import DQNAgent 18 | from gymnasium.wrappers.time_limit import TimeLimit 19 | 20 | model_configs = { 21 | "type": "MultiLayerPerceptron", 22 | "layer_sizes": (256, 256), 23 | "reshape": False, 24 | } 25 | 26 | init_kwargs = dict( 27 | q_net_constructor="rlberry_research.agents.torch.utils.training.model_factory_from_env", 28 | q_net_kwargs=model_configs, 29 | ) 30 | 31 | env = SpringCartPole(obs_trans=False, swing_up=True) 32 | env = TimeLimit(env, max_episode_steps=500) 33 | agent = DQNAgent(env, **init_kwargs) 34 | agent.fit(budget=1e5) 35 | 36 | env.enable_rendering() 37 | observation, info = env.reset() 38 | 39 | for tt in range(1000): 40 | action = agent.policy(observation) 41 | observation, reward, terminated, truncated, info = env.step(action) 42 | done = terminated or truncated 43 | if done: 44 | observation, info = env.reset() 45 | 46 | # Save video 47 | video = env.save_video("_video/video_plot_springcartpole.mp4") 48 | -------------------------------------------------------------------------------- /examples/demo_env/video_plot_twinrooms.py: -------------------------------------------------------------------------------- 1 | """ 2 | =============================== 3 | A demo of twinrooms environment 4 | =============================== 5 | Illustration of TwinRooms environment 6 | 7 | .. 
video:: ../../video_plot_twinrooms.mp4 8 | :width: 600 9 | 10 | """ 11 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_twinrooms.jpg' 12 | 13 | from rlberry_research.envs.benchmarks.generalization.twinrooms import TwinRooms 14 | from rlberry_scool.agents.mbqvi import MBQVIAgent 15 | from rlberry.wrappers.discretize_state import DiscretizeStateWrapper 16 | from rlberry.seeding import Seeder 17 | 18 | seeder = Seeder(123) 19 | 20 | env = TwinRooms() 21 | env = DiscretizeStateWrapper(env, n_bins=20) 22 | env.reseed(seeder) 23 | horizon = 20 24 | agent = MBQVIAgent(env, n_samples=10, gamma=1.0, horizon=horizon) 25 | agent.reseed(seeder) 26 | agent.fit() 27 | 28 | observation, info = env.reset() 29 | env.enable_rendering() 30 | for ii in range(10): 31 | action = agent.policy(observation) 32 | observation, reward, terminated, truncated, info = env.step(action) 33 | done = terminated or truncated 34 | 35 | if (ii + 1) % horizon == 0: 36 | observation, info = env.reset() 37 | 38 | env.render() 39 | video = env.save_video("_video/video_plot_twinrooms.mp4") 40 | -------------------------------------------------------------------------------- /examples/demo_experiment/params_experiment.yaml: -------------------------------------------------------------------------------- 1 | # """ 2 | # ===================== 3 | # Demo: params_experiment.yaml 4 | # ===================== 5 | # """ 6 | description: 'RSUCBVI in NRoom' 7 | seed: 123 8 | train_env: 'examples/demo_experiment/room.yaml' 9 | eval_env: 'examples/demo_experiment/room.yaml' 10 | agents: 11 | - 'examples/demo_experiment/rsucbvi.yaml' 12 | - 'examples/demo_experiment/rsucbvi_alternative.yaml' 13 | -------------------------------------------------------------------------------- /examples/demo_experiment/room.yaml: -------------------------------------------------------------------------------- 1 | # """ 2 | # ===================== 3 | # Demo: room.yaml 4 | # ===================== 5 | # """ 6 | constructor: 'rlberry_research.envs.benchmarks.grid_exploration.nroom.NRoom' 7 | params: 8 | reward_free: false 9 | array_observation: true 10 | nrooms: 5 11 | -------------------------------------------------------------------------------- /examples/demo_experiment/rsucbvi.yaml: -------------------------------------------------------------------------------- 1 | # """ 2 | # ===================== 3 | # Demo: rsucbvi.yaml 4 | # ===================== 5 | # """ 6 | agent_class: 'rlberry_research.agents.kernel_based.rs_ucbvi.RSUCBVIAgent' 7 | init_kwargs: 8 | gamma: 1.0 9 | lp_metric: 2 10 | min_dist: 0.0 11 | max_repr: 800 12 | bonus_scale_factor: 1.0 13 | reward_free: True 14 | horizon: 50 15 | eval_kwargs: 16 | eval_horizon: 50 17 | fit_kwargs: 18 | fit_budget: 100 19 | -------------------------------------------------------------------------------- /examples/demo_experiment/rsucbvi_alternative.yaml: -------------------------------------------------------------------------------- 1 | # """ 2 | # ===================== 3 | # Demo: rsucbvi_alternative.yaml 4 | # ===================== 5 | # """ 6 | base_config: 'examples/demo_experiment/rsucbvi.yaml' 7 | init_kwargs: 8 | gamma: 0.9 9 | -------------------------------------------------------------------------------- /examples/demo_experiment/run.py: -------------------------------------------------------------------------------- 1 | """ 2 | ===================== 3 | Demo: run 4 | ===================== 5 | To run the experiment: 6 | 7 | $ python examples/demo_experiment/run.py 
examples/demo_experiment/params_experiment.yaml 8 | 9 | To see more options: 10 | 11 | $ python examples/demo_experiment/run.py 12 | """ 13 | 14 | from rlberry.experiment import load_experiment_results 15 | from rlberry.experiment import experiment_generator 16 | from rlberry.manager.multiple_managers import MultipleManagers 17 | 18 | 19 | if __name__ == "__main__": 20 | multimanagers = MultipleManagers(parallelization="thread") 21 | 22 | for experiment_manager in experiment_generator(): 23 | multimanagers.append(experiment_manager) 24 | 25 | multimanagers.run() 26 | multimanagers.save() 27 | 28 | # Reading the results 29 | del multimanagers 30 | 31 | data = load_experiment_results("results", "params_experiment") 32 | 33 | print(data) 34 | 35 | # Fit one of the managers for a few more episodes 36 | # If tensorboard is enabled, you should see more episodes ran for 'rsucbvi_alternative' 37 | data["manager"]["rsucbvi_alternative"].fit(50) 38 | -------------------------------------------------------------------------------- /examples/plot_kernels.py: -------------------------------------------------------------------------------- 1 | """ 2 | ===================== 3 | Plot kernel functions 4 | ===================== 5 | 6 | This script requires matplotlib 7 | """ 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | from rlberry_research.agents.kernel_based.kernels import kernel_func 12 | 13 | kernel_types = [ 14 | "uniform", 15 | "triangular", 16 | "gaussian", 17 | "epanechnikov", 18 | "quartic", 19 | "triweight", 20 | "tricube", 21 | "cosine", 22 | "exp-4", 23 | ] 24 | 25 | z = np.linspace(-2, 2, 100) 26 | 27 | 28 | fig, axes = plt.subplots(1, len(kernel_types), figsize=(15, 5)) 29 | for ii, k_type in enumerate(kernel_types): 30 | kernel_vals = kernel_func(z, k_type) 31 | axes[ii].plot(z, kernel_vals) 32 | axes[ii].set_title(k_type) 33 | plt.show() 34 | -------------------------------------------------------------------------------- /examples/plot_writer_wrapper.py: -------------------------------------------------------------------------------- 1 | """ 2 | ============================================== 3 | Record reward during training and then plot it 4 | ============================================== 5 | 6 | This script shows how to modify an agent to easily record reward or action 7 | during the fit of the agent and then use the plot utils. 8 | 9 | .. note:: 10 | If you already ran this script once, the fitted agent has been saved 11 | in rlberry_data folder. Then, you can comment-out the line 12 | 13 | .. code-block:: python 14 | 15 | agent.fit(budget=10) 16 | 17 | and avoid fitting the agent one more time, the statistics from the last 18 | time you fitted the agent will automatically be loaded. See 19 | `rlberry.manager.plot_writer_data` documentation for more information. 20 | """ 21 | 22 | 23 | import numpy as np 24 | 25 | from rlberry_scool.envs import GridWorld 26 | from rlberry.manager import plot_writer_data, ExperimentManager 27 | from rlberry_scool.agents import UCBVIAgent 28 | import matplotlib.pyplot as plt 29 | 30 | # We wrape the default writer of the agent in a WriterWrapper 31 | # to record rewards. 
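# Concretely, the VIAgent subclass below forwards writer_extra="reward" to
# UCBVIAgent, so that the agent's default writer records the reward at each step;
# plot_writer_data later reads these logs back under the tag "reward".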
32 | 33 | 34 | class VIAgent(UCBVIAgent): 35 | name = "UCBVIAgent" 36 | 37 | def __init__(self, env, **kwargs): 38 | UCBVIAgent.__init__(self, env, writer_extra="reward", horizon=50, **kwargs) 39 | 40 | 41 | env_ctor = GridWorld 42 | env_kwargs = dict( 43 | nrows=3, 44 | ncols=10, 45 | reward_at={(1, 1): 0.1, (2, 9): 1.0}, 46 | walls=((1, 4), (2, 4), (1, 5)), 47 | success_probability=0.7, 48 | ) 49 | 50 | env = env_ctor(**env_kwargs) 51 | xp_manager = ExperimentManager(VIAgent, (env_ctor, env_kwargs), fit_budget=10, n_fit=3) 52 | 53 | xp_manager.fit(budget=10) 54 | # comment the line above if you only want to load data from rlberry_data. 55 | 56 | 57 | # We use the following preprocessing function to plot the cumulative reward. 58 | def compute_reward(rewards): 59 | return np.cumsum(rewards) 60 | 61 | 62 | # Plot of the cumulative reward. 63 | output = plot_writer_data( 64 | xp_manager, tag="reward", preprocess_func=compute_reward, title="Cumulative Reward" 65 | ) 66 | # The output is for 500 global steps because it uses 10 fit_budget * horizon 67 | 68 | # Log-Log plot : 69 | fig, ax = plt.subplots(1, 1) 70 | plot_writer_data( 71 | xp_manager, 72 | tag="reward", 73 | preprocess_func=compute_reward, 74 | title="Cumulative Reward", 75 | ax=ax, 76 | show=False, # necessary to customize axes 77 | ) 78 | ax.set_xlim(100, 500) 79 | ax.relim() 80 | ax.set_xscale("log") 81 | ax.set_yscale("log") 82 | -------------------------------------------------------------------------------- /rlberry/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import metadata 2 | 3 | __version__ = metadata.version("rlberry") 4 | 5 | import logging 6 | 7 | logger = logging.getLogger("rlberry_logger") 8 | 9 | from rlberry.utils.logging import configure_logging 10 | 11 | 12 | __path__ = __import__("pkgutil").extend_path(__path__, __name__) 13 | 14 | # Initialize logging level 15 | configure_logging(level="INFO") 16 | 17 | 18 | # define __version__ 19 | 20 | __all__ = ["__version__", "logger"] 21 | -------------------------------------------------------------------------------- /rlberry/agents/__init__.py: -------------------------------------------------------------------------------- 1 | # Interfaces 2 | from .agent import Agent 3 | from .agent import AgentWithSimplePolicy 4 | from .agent import AgentTorch 5 | -------------------------------------------------------------------------------- /rlberry/agents/stable_baselines/__init__.py: -------------------------------------------------------------------------------- 1 | from .stable_baselines import StableBaselinesAgent 2 | -------------------------------------------------------------------------------- /rlberry/agents/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/rlberry/agents/tests/__init__.py -------------------------------------------------------------------------------- /rlberry/agents/tests/test_stable_baselines.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | 3 | from stable_baselines3 import A2C 4 | 5 | from rlberry.envs import gym_make 6 | from rlberry.agents.stable_baselines import StableBaselinesAgent 7 | from rlberry.utils.check_agent import check_rl_agent 8 | 9 | 10 | def test_sb3_agent(): 11 | # Test only one algorithm per action space type 12 | check_rl_agent( 13 | StableBaselinesAgent, 14 | 
env=(gym_make, {"id": "Pendulum-v1"}), 15 | init_kwargs={"algo_cls": A2C, "policy": "MlpPolicy", "verbose": 1}, 16 | ) 17 | check_rl_agent( 18 | StableBaselinesAgent, 19 | env=(gym_make, {"id": "CartPole-v1"}), 20 | init_kwargs={"algo_cls": A2C, "policy": "MlpPolicy", "verbose": 1}, 21 | ) 22 | 23 | 24 | def test_sb3_tensorboard_log(): 25 | # Test tensorboard support 26 | with tempfile.TemporaryDirectory() as tmpdir: 27 | check_rl_agent( 28 | StableBaselinesAgent, 29 | env=(gym_make, {"id": "Pendulum-v1"}), 30 | init_kwargs={ 31 | "algo_cls": A2C, 32 | "policy": "MlpPolicy", 33 | "verbose": 1, 34 | "tensorboard_log": tmpdir, 35 | }, 36 | ) 37 | -------------------------------------------------------------------------------- /rlberry/agents/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/rlberry/agents/utils/__init__.py -------------------------------------------------------------------------------- /rlberry/check_packages.py: -------------------------------------------------------------------------------- 1 | # Define import flags 2 | 3 | TORCH_INSTALLED = True 4 | try: 5 | import torch 6 | except ModuleNotFoundError: # pragma: no cover 7 | TORCH_INSTALLED = False # pragma: no cover 8 | 9 | TENSORBOARD_INSTALLED = True 10 | try: 11 | import torch.utils.tensorboard 12 | except ModuleNotFoundError: # pragma: no cover 13 | TENSORBOARD_INSTALLED = False # pragma: no cover 14 | -------------------------------------------------------------------------------- /rlberry/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .gym_make import gym_make, atari_make 2 | from .basewrapper import Wrapper 3 | from .interface import Model 4 | from .pipeline import PipelineEnv 5 | from .finite_mdp import FiniteMDP 6 | -------------------------------------------------------------------------------- /rlberry/envs/interface/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Model 2 | -------------------------------------------------------------------------------- /rlberry/envs/pipeline.py: -------------------------------------------------------------------------------- 1 | from rlberry.envs import Wrapper 2 | 3 | 4 | class PipelineEnv(Wrapper): 5 | """ 6 | Environment defined as a pipeline of wrappers and an environment to wrap. 7 | 8 | Parameters 9 | ---------- 10 | env_ctor: environment class 11 | 12 | env_kwargs: dictionary 13 | kwargs fed to the environment 14 | 15 | wrappers: list of tuple (wrapper, wrapper_kwargs) 16 | list of tuple (wrapper, wrapper_kwargs) to be applied to the environment. 
17 | The list [wrapper1, wrapper2] will be applied in the order wrapper1(wrapper2(env)) 18 | 19 | Examples 20 | -------- 21 | >>> from rlberry.envs import PipelineEnv 22 | >>> from rlberry.envs import gym_make 23 | >>> from rlberry.wrappers import RescaleRewardWrapper 24 | >>> 25 | >>> env_ctor, env_kwargs = PipelineEnv, { 26 | >>> "env_ctor": gym_make, 27 | >>> "env_kwargs": {"id": "Acrobot-v1"}, 28 | >>> "wrappers": [(RescaleRewardWrapper, {"reward_range": (0, 1)})], 29 | >>> } 30 | >>> eval_env = (gym_make, {"id":"Acrobot-v1"}) # unscaled env for evaluation 31 | 32 | """ 33 | 34 | def __init__(self, env_ctor, env_kwargs, wrappers): 35 | env = env_ctor(**env_kwargs) 36 | for wrapper in wrappers[::-1]: 37 | env = wrapper[0](env, **wrapper[1]) 38 | env.reset() 39 | Wrapper.__init__(self, env) 40 | -------------------------------------------------------------------------------- /rlberry/envs/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/rlberry/envs/tests/__init__.py -------------------------------------------------------------------------------- /rlberry/envs/tests/test_env_seeding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import rlberry.seeding as seeding 4 | 5 | from copy import deepcopy 6 | from rlberry_research.envs.classic_control import MountainCar, Pendulum 7 | from rlberry_scool.envs.finite import Chain 8 | from rlberry_scool.envs.finite import GridWorld 9 | from rlberry_research.envs.benchmarks.grid_exploration.four_room import FourRoom 10 | from rlberry_research.envs.benchmarks.grid_exploration.six_room import SixRoom 11 | from rlberry_research.envs.benchmarks.grid_exploration.apple_gold import AppleGold 12 | from rlberry_research.envs.benchmarks.ball_exploration import PBall2D, SimplePBallND 13 | 14 | classes = [ 15 | MountainCar, 16 | GridWorld, 17 | Chain, 18 | PBall2D, 19 | SimplePBallND, 20 | Pendulum, 21 | FourRoom, 22 | SixRoom, 23 | AppleGold, 24 | ] 25 | 26 | 27 | def get_env_trajectory(env, horizon): 28 | states = [] 29 | ss, info = env.reset() 30 | for ii in range(horizon): 31 | states.append(ss) 32 | ss, _, _, _, _ = env.step(env.action_space.sample()) 33 | return states 34 | 35 | 36 | def compare_trajectories(traj1, traj2): 37 | """ 38 | returns true if trajectories are equal 39 | """ 40 | for ss1, ss2 in zip(traj1, traj2): 41 | if not np.array_equal(ss1, ss2): 42 | return False 43 | return True 44 | 45 | 46 | @pytest.mark.parametrize("ModelClass", classes) 47 | def test_env_seeding(ModelClass): 48 | env1 = ModelClass() 49 | seeder1 = seeding.Seeder(123) 50 | env1.reseed(seeder1) 51 | 52 | env2 = ModelClass() 53 | seeder2 = seeder1.spawn() 54 | env2.reseed(seeder2) 55 | 56 | env3 = ModelClass() 57 | seeder3 = seeding.Seeder(123) 58 | env3.reseed(seeder3) 59 | 60 | env4 = ModelClass() 61 | seeder4 = seeding.Seeder(123) 62 | env4.reseed(seeder4) 63 | 64 | env5 = ModelClass() 65 | env5.reseed( 66 | seeder1 67 | ) # same seeder as env1, but different trajectories. This is expected. 
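    # Reusing seeder1 here is expected to give env5 a different random stream than
    # env1 (each reseed call is expected to draw a fresh child seed from the seeder),
    # hence the traj1 != traj5 assertion below.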
68 | 69 | seeding.safe_reseed(env4, seeder4) 70 | 71 | if deepcopy(env1).is_online(): 72 | traj1 = get_env_trajectory(env1, 500) 73 | traj2 = get_env_trajectory(env2, 500) 74 | traj3 = get_env_trajectory(env3, 500) 75 | traj4 = get_env_trajectory(env4, 500) 76 | traj5 = get_env_trajectory(env5, 500) 77 | 78 | assert not compare_trajectories(traj1, traj2) 79 | assert compare_trajectories(traj1, traj3) 80 | assert not compare_trajectories(traj3, traj4) 81 | assert not compare_trajectories(traj1, traj5) 82 | 83 | 84 | @pytest.mark.parametrize("ModelClass", classes) 85 | def test_copy_reseeding(ModelClass): 86 | seeder = seeding.Seeder(123) 87 | env = ModelClass() 88 | env.reseed(seeder) 89 | 90 | c_env = deepcopy(env) 91 | c_env.reseed() 92 | 93 | if deepcopy(env).is_online(): 94 | traj1 = get_env_trajectory(env, 500) 95 | traj2 = get_env_trajectory(c_env, 500) 96 | assert not compare_trajectories(traj1, traj2) 97 | -------------------------------------------------------------------------------- /rlberry/envs/tests/test_gym_env_seeding.py: -------------------------------------------------------------------------------- 1 | from rlberry.seeding.seeding import safe_reseed 2 | import gymnasium as gym 3 | import numpy as np 4 | import pytest 5 | from rlberry.seeding import Seeder 6 | from rlberry.envs import gym_make 7 | 8 | from copy import deepcopy 9 | 10 | gym_envs = [ 11 | "Acrobot-v1", 12 | "CartPole-v1", 13 | "MountainCar-v0", 14 | ] 15 | 16 | 17 | def get_env_trajectory(env, horizon): 18 | states = [] 19 | observation, info = env.reset() 20 | for ii in range(horizon): 21 | states.append(observation) 22 | observation, _, terminated, truncated, _ = env.step(env.action_space.sample()) 23 | done = terminated or truncated 24 | if done: 25 | observation, info = env.reset() 26 | return states 27 | 28 | 29 | def compare_trajectories(traj1, traj2): 30 | for ss1, ss2 in zip(traj1, traj2): 31 | if not np.array_equal(ss1, ss2): 32 | return False 33 | return True 34 | 35 | 36 | @pytest.mark.parametrize("env_name", gym_envs) 37 | def test_env_seeding(env_name): 38 | seeder1 = Seeder(123) 39 | env1 = gym_make(env_name, module_import="numpy") 40 | env1.reseed(seeder1) 41 | 42 | seeder2 = Seeder(456) 43 | env2 = gym_make(env_name) 44 | env2.reseed(seeder2) 45 | 46 | seeder3 = Seeder(123) 47 | env3 = gym_make(env_name) 48 | env3.reseed(seeder3) 49 | 50 | if deepcopy(env1).is_online(): 51 | traj1 = get_env_trajectory(env1, 500) 52 | traj2 = get_env_trajectory(env2, 500) 53 | traj3 = get_env_trajectory(env3, 500) 54 | 55 | assert not compare_trajectories(traj1, traj2) 56 | assert compare_trajectories(traj1, traj3) 57 | 58 | 59 | @pytest.mark.parametrize("env_name", gym_envs) 60 | def test_copy_reseeding(env_name): 61 | seeder = Seeder(123) 62 | env = gym_make(env_name) 63 | env.reseed(seeder) 64 | 65 | c_env = deepcopy(env) 66 | c_env.reseed() 67 | 68 | if deepcopy(env).is_online(): 69 | traj1 = get_env_trajectory(env, 500) 70 | traj2 = get_env_trajectory(c_env, 500) 71 | assert not compare_trajectories(traj1, traj2) 72 | 73 | 74 | @pytest.mark.parametrize("env_name", gym_envs) 75 | def test_gym_safe_reseed(env_name): 76 | seeder = Seeder(123) 77 | seeder_aux = Seeder(123) 78 | 79 | env1 = gym.make(env_name) 80 | env2 = gym.make(env_name) 81 | env3 = gym.make(env_name) 82 | 83 | safe_reseed(env1, seeder) 84 | safe_reseed(env2, seeder) 85 | safe_reseed(env3, seeder_aux) 86 | 87 | traj1 = get_env_trajectory(env1, 500) 88 | traj2 = get_env_trajectory(env2, 500) 89 | traj3 = get_env_trajectory(env3, 500) 90 | 
assert not compare_trajectories(traj1, traj2) 91 | assert compare_trajectories(traj1, traj3) 92 | -------------------------------------------------------------------------------- /rlberry/envs/tests/test_gym_make.py: -------------------------------------------------------------------------------- 1 | from rlberry.envs.gym_make import atari_make 2 | import gymnasium as gym 3 | import ale_py 4 | 5 | gym.register_envs(ale_py) 6 | 7 | 8 | def test_atari_make(): 9 | wrappers_dict = dict(terminal_on_life_loss=True, frame_skip=8) 10 | env = atari_make( 11 | "ALE/Freeway-v5", render_mode="rgb_array", atari_SB3_wrappers_dict=wrappers_dict 12 | ) 13 | assert "EpisodicLifeEnv" in str(env) 14 | assert "MaxAndSkipEnv" in str(env) 15 | assert "ClipRewardEnv" in str(env) 16 | assert env.render_mode == "rgb_array" 17 | 18 | wrappers_dict2 = dict(terminal_on_life_loss=False, frame_skip=0) 19 | env2 = atari_make( 20 | "ALE/Breakout-v5", render_mode="human", atari_SB3_wrappers_dict=wrappers_dict2 21 | ) 22 | assert "EpisodicLifeEnv" not in str(env2) 23 | assert "MaxAndSkipEnv" not in str(env2) 24 | assert "ClipRewardEnv" in str(env2) 25 | assert env2.render_mode == "human" 26 | 27 | 28 | def test_rendering_with_atari_make(): 29 | from rlberry.manager import ExperimentManager 30 | 31 | from gymnasium.wrappers.rendering import RecordVideo 32 | import os 33 | from rlberry.envs.gym_make import atari_make 34 | 35 | import tempfile 36 | 37 | with tempfile.TemporaryDirectory() as tmpdirname: 38 | from stable_baselines3 import ( 39 | PPO, 40 | ) # https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html 41 | 42 | from rlberry.agents.stable_baselines import StableBaselinesAgent 43 | 44 | tuned_xp = ExperimentManager( 45 | StableBaselinesAgent, # The Agent class. 46 | ( 47 | atari_make, 48 | dict(id="ALE/Breakout-v5"), 49 | ), # The Environment to solve. 50 | init_kwargs=dict( # Where to put the agent's hyperparameters 51 | algo_cls=PPO, 52 | policy="MlpPolicy", 53 | ), 54 | fit_budget=1000, # The number of interactions between the agent and the environment during training. 55 | eval_kwargs=dict( 56 | eval_horizon=500 57 | ), # The number of interactions between the agent and the environment during evaluations. 58 | n_fit=1, # The number of agents to train. Usually, it is good to do more than 1 because the training is stochastic. 59 | agent_name="PPO_tuned", # The agent's name. 
60 | output_dir=str(tmpdirname) + "/PPO_for_breakout", 61 | ) 62 | 63 | tuned_xp.fit() 64 | 65 | env = atari_make("ALE/Breakout-v5", render_mode="rgb_array") 66 | env = RecordVideo(env, str(tmpdirname) + "/_video/temp") 67 | 68 | if "render_modes" in env.metadata: 69 | env.metadata["render.modes"] = env.metadata[ 70 | "render_modes" 71 | ] # bug with some 'gym' version 72 | 73 | observation, info = env.reset() 74 | for tt in range(3000): 75 | action = tuned_xp.get_agent_instances()[0].policy(observation) 76 | observation, reward, terminated, truncated, info = env.step(action) 77 | done = terminated or truncated 78 | if done: 79 | break 80 | 81 | env.close() 82 | 83 | assert os.path.exists(str(tmpdirname) + "/_video/temp/rl-video-episode-0.mp4") 84 | -------------------------------------------------------------------------------- /rlberry/envs/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | from copy import deepcopy 3 | from rlberry.seeding import safe_reseed 4 | 5 | 6 | import rlberry 7 | 8 | logger = rlberry.logger 9 | 10 | 11 | def process_env(env, seeder, copy_env=True): 12 | if isinstance(env, Tuple): 13 | constructor = env[0] 14 | if constructor is None: 15 | return None 16 | kwargs = env[1] or {} 17 | processed_env = constructor(**kwargs) 18 | else: 19 | if env is None: 20 | return None 21 | if copy_env: 22 | try: 23 | processed_env = deepcopy(env) 24 | except Exception as ex: 25 | raise RuntimeError("[Agent] Not possible to deepcopy env: " + str(ex)) 26 | else: 27 | processed_env = env 28 | reseeded = safe_reseed(processed_env, seeder) 29 | if not reseeded: 30 | logger.warning("[Agent] Not possible to reseed environment.") 31 | return processed_env 32 | -------------------------------------------------------------------------------- /rlberry/experiment/__init__.py: -------------------------------------------------------------------------------- 1 | from .yaml_utils import parse_experiment_config 2 | from .generator import experiment_generator 3 | from .load_results import load_experiment_results 4 | -------------------------------------------------------------------------------- /rlberry/experiment/generator.py: -------------------------------------------------------------------------------- 1 | """Run experiments. 2 | 3 | Usage: 4 | run.py [--enable_tensorboard] [--n_fit=] [--output_dir=] [--parallelization=] [--max_workers=] 5 | run.py (-h | --help) 6 | 7 | Options: 8 | -h --help Show this screen. 9 | --enable_tensorboard Enable tensorboard writer in ExperimentManager. 10 | --n_fit= Number of times each agent is fit [default: 4]. 11 | --output_dir= Directory to save the results [default: results]. 12 | --parallelization= Either 'thread' or 'process' [default: process]. 13 | --max_workers= Number of workers used by ExperimentManager.fit. Set to -1 for the maximum value. [default: -1] 14 | """ 15 | 16 | from docopt import docopt 17 | from pathlib import Path 18 | from rlberry.experiment.yaml_utils import parse_experiment_config 19 | from rlberry.manager import ExperimentManager 20 | from rlberry import check_packages 21 | 22 | import rlberry 23 | 24 | logger = rlberry.logger 25 | 26 | 27 | def experiment_generator(): 28 | """ 29 | Parse command line arguments and yields ExperimentManager instances. 
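
    Typical use, as in rlberry/experiment/tests/test_load_results.py or
    examples/demo_experiment/run.py: iterate over the yielded managers, e.g.::

        for experiment_manager in experiment_generator():
            experiment_manager.fit()
            experiment_manager.save()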
30 | """ 31 | args = docopt(__doc__) 32 | max_workers = int(args["--max_workers"]) 33 | if max_workers == -1: 34 | max_workers = None 35 | for _, experiment_manager_kwargs in parse_experiment_config( 36 | Path(args[""]), 37 | n_fit=int(args["--n_fit"]), 38 | max_workers=max_workers, 39 | output_base_dir=args["--output_dir"], 40 | parallelization=args["--parallelization"], 41 | ): 42 | if args["--enable_tensorboard"]: 43 | if check_packages.TENSORBOARD_INSTALLED: 44 | experiment_manager_kwargs.update(dict(enable_tensorboard=True)) 45 | else: 46 | logger.warning( 47 | "Option --enable_tensorboard is not available: tensorboard is not installed." 48 | ) 49 | 50 | yield ExperimentManager(**experiment_manager_kwargs) 51 | -------------------------------------------------------------------------------- /rlberry/experiment/tests/params_experiment.yaml: -------------------------------------------------------------------------------- 1 | description: 'RSUCBVI in NRoom' 2 | seed: 4575458 3 | train_env: 'rlberry/experiment/tests/room.yaml' 4 | eval_env: 'rlberry/experiment/tests/room.yaml' 5 | global_init_kwargs: 6 | reward_free: True 7 | horizon: 2 8 | global_eval_kwargs: 9 | eval_horizon: 4 10 | global_fit_kwargs: 11 | fit_budget: 3 12 | agents: 13 | - 'rlberry/experiment/tests/rsucbvi.yaml' 14 | - 'rlberry/experiment/tests/rsucbvi_alternative.yaml' 15 | -------------------------------------------------------------------------------- /rlberry/experiment/tests/room.yaml: -------------------------------------------------------------------------------- 1 | constructor: 'rlberry_research.envs.benchmarks.grid_exploration.nroom.NRoom' 2 | params: 3 | reward_free: false 4 | array_observation: true 5 | nrooms: 5 6 | -------------------------------------------------------------------------------- /rlberry/experiment/tests/rsucbvi.yaml: -------------------------------------------------------------------------------- 1 | agent_class: 'rlberry_research.agents.kernel_based.rs_ucbvi.RSUCBVIAgent' 2 | init_kwargs: 3 | gamma: 1.0 4 | lp_metric: 2 5 | min_dist: 0.0 6 | max_repr: 800 7 | bonus_scale_factor: 1.0 8 | reward_free: True 9 | horizon: 50 10 | eval_kwargs: 11 | eval_horizon: 50 12 | fit_kwargs: 13 | fit_budget: 123 14 | -------------------------------------------------------------------------------- /rlberry/experiment/tests/rsucbvi_alternative.yaml: -------------------------------------------------------------------------------- 1 | base_config: 'rlberry/experiment/tests/rsucbvi.yaml' 2 | init_kwargs: 3 | gamma: 0.9 4 | eval_kwargs: 5 | horizon: 50 6 | -------------------------------------------------------------------------------- /rlberry/experiment/tests/test_experiment_generator.py: -------------------------------------------------------------------------------- 1 | from rlberry.experiment import experiment_generator 2 | from rlberry_research.agents.kernel_based.rs_ucbvi import RSUCBVIAgent 3 | 4 | import numpy as np 5 | 6 | 7 | def test_mock_args(monkeypatch): 8 | monkeypatch.setattr( 9 | "sys.argv", ["", "rlberry/experiment/tests/params_experiment.yaml"] 10 | ) 11 | random_numbers = [] 12 | 13 | for experiment_manager in experiment_generator(): 14 | rng = experiment_manager.seeder.rng 15 | random_numbers.append(rng.uniform(size=10)) 16 | 17 | assert experiment_manager.agent_class is RSUCBVIAgent 18 | assert experiment_manager._base_init_kwargs["horizon"] == 2 19 | assert experiment_manager.fit_budget == 3 20 | assert experiment_manager.eval_kwargs["eval_horizon"] == 4 21 | 22 | assert 
experiment_manager._base_init_kwargs["lp_metric"] == 2 23 | assert experiment_manager._base_init_kwargs["min_dist"] == 0.0 24 | assert experiment_manager._base_init_kwargs["max_repr"] == 800 25 | assert experiment_manager._base_init_kwargs["bonus_scale_factor"] == 1.0 26 | assert experiment_manager._base_init_kwargs["reward_free"] is True 27 | 28 | train_env = experiment_manager.train_env[0](**experiment_manager.train_env[1]) 29 | assert train_env.reward_free is False 30 | assert train_env.array_observation is True 31 | 32 | if experiment_manager.agent_name == "rsucbvi": 33 | assert experiment_manager._base_init_kwargs["gamma"] == 1.0 34 | 35 | elif experiment_manager.agent_name == "rsucbvi_alternative": 36 | assert experiment_manager._base_init_kwargs["gamma"] == 0.9 37 | 38 | else: 39 | raise ValueError() 40 | 41 | # check that seeding is the same for each ExperimentManager instance 42 | for ii in range(1, len(random_numbers)): 43 | assert np.array_equal(random_numbers[ii - 1], random_numbers[ii]) 44 | -------------------------------------------------------------------------------- /rlberry/experiment/tests/test_load_results.py: -------------------------------------------------------------------------------- 1 | from rlberry.experiment import load_experiment_results 2 | import tempfile 3 | from rlberry.experiment import experiment_generator 4 | import os 5 | import sys 6 | 7 | TEST_DIR = os.path.dirname(os.path.abspath(__file__)) 8 | 9 | 10 | def test_save_and_load(): 11 | TEST_DIR = os.path.dirname(os.path.abspath(__file__)) 12 | sys.argv = [TEST_DIR + "/test_load_results.py"] 13 | with tempfile.TemporaryDirectory() as tmpdirname: 14 | sys.argv.append(TEST_DIR + "/params_experiment.yaml") 15 | sys.argv.append("--parallelization=thread") 16 | sys.argv.append("--output_dir=" + tmpdirname) 17 | print(sys.argv) 18 | for experiment_manager in experiment_generator(): 19 | experiment_manager.fit() 20 | experiment_manager.save() 21 | data = load_experiment_results(tmpdirname, "params_experiment") 22 | 23 | assert len(data) > 0 24 | -------------------------------------------------------------------------------- /rlberry/manager/__init__.py: -------------------------------------------------------------------------------- 1 | from .experiment_manager import ExperimentManager 2 | from .experiment_manager import preset_manager 3 | from .multiple_managers import MultipleManagers 4 | from .evaluation import evaluate_agents, read_writer_data 5 | from .comparison import compare_agents, AdastopComparator 6 | from .plotting import plot_smoothed_curves, plot_writer_data, plot_synchronized_curves 7 | from .env_tools import with_venv, run_venv_xp 8 | from .utils import tensorboard_to_dataframe 9 | 10 | # AgentManager alias for the ExperimentManager class, for backward compatibility 11 | AgentManager = ExperimentManager 12 | -------------------------------------------------------------------------------- /rlberry/manager/multiple_managers.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import functools 3 | import multiprocessing 4 | from typing import Optional 5 | 6 | 7 | def fit_stats(stats, save): 8 | stats.fit() 9 | if save: 10 | stats.save() 11 | return stats 12 | 13 | 14 | class MultipleManagers: 15 | """ 16 | Class to fit multiple ExperimentManager instances in parallel with multiple threads. 
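
    Typical usage, following examples/demo_experiment/run.py (with
    experiment_generator from rlberry.experiment)::

        multimanagers = MultipleManagers(parallelization="thread")
        for experiment_manager in experiment_generator():
            multimanagers.append(experiment_manager)
        multimanagers.run()
        multimanagers.save()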
17 | 18 | Parameters 19 | ---------- 20 | max_workers: int, default=None 21 | max number of workers (ExperimentManager instances) fitted at the same time. 22 | parallelization: {'thread', 'process'}, default: 'process' 23 | Whether to parallelize agent training using threads or processes. 24 | mp_context: {'spawn', 'fork', 'forkserver'}, default: 'spawn'. 25 | Context for python multiprocessing module. 26 | Warning: If you're using JAX or PyTorch, it only works with 'spawn'. 27 | If running code on a notebook or interpreter, use 'fork'. 28 | """ 29 | 30 | def __init__( 31 | self, 32 | max_workers: Optional[int] = None, 33 | parallelization: str = "process", 34 | mp_context="spawn", 35 | ) -> None: 36 | super().__init__() 37 | self.instances = [] 38 | self.max_workers = max_workers 39 | self.parallelization = parallelization 40 | self.mp_context = mp_context 41 | 42 | def append(self, experiment_manager): 43 | """ 44 | Append new ExperimentManager instance. 45 | 46 | Parameters 47 | ---------- 48 | experiment_manager : ExperimentManager 49 | """ 50 | self.instances.append(experiment_manager) 51 | 52 | def run(self, save=True): 53 | """ 54 | Fit ExperimentManager instances in parallel. 55 | 56 | Parameters 57 | ---------- 58 | save: bool, default: True 59 | If true, save ExperimentManager intances immediately after fitting. 60 | ExperimentManager.save() is called. 61 | """ 62 | if self.parallelization == "thread": 63 | executor_class = concurrent.futures.ThreadPoolExecutor 64 | elif self.parallelization == "process": 65 | executor_class = functools.partial( 66 | concurrent.futures.ProcessPoolExecutor, 67 | mp_context=multiprocessing.get_context(self.mp_context), 68 | ) 69 | else: 70 | raise ValueError( 71 | f"Invalid backend for parallelization: {self.parallelization}" 72 | ) 73 | 74 | with executor_class(max_workers=self.max_workers) as executor: 75 | futures = [] 76 | for inst in self.instances: 77 | futures.append(executor.submit(fit_stats, inst, save=save)) 78 | 79 | fitted_instances = [] 80 | for future in concurrent.futures.as_completed(futures): 81 | fitted_instances.append(future.result()) 82 | 83 | self.instances = fitted_instances 84 | 85 | def save(self): 86 | """ 87 | Pickle ExperimentManager instances and saves fit statistics in .csv files. 88 | The output folder is defined in each of the ExperimentManager instances. 
89 | """ 90 | for stats in self.instances: 91 | stats.save() 92 | 93 | @property 94 | def managers(self): 95 | return self.instances 96 | -------------------------------------------------------------------------------- /rlberry/manager/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/rlberry/manager/tests/__init__.py -------------------------------------------------------------------------------- /rlberry/manager/tests/test_shared_data.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import numpy as np 3 | from rlberry.agents import Agent 4 | from rlberry.manager import ExperimentManager 5 | 6 | 7 | class DummyAgent(Agent): 8 | def __init__(self, **kwargs): 9 | Agent.__init__(self, **kwargs) 10 | self.name = "DummyAgent" 11 | self.shared_data_id = id(self.thread_shared_data) 12 | 13 | def fit(self, budget, **kwargs): 14 | del budget, kwargs 15 | 16 | def eval(self, **kwargs): 17 | del kwargs 18 | return self.shared_data_id 19 | 20 | 21 | @pytest.mark.parametrize("paralellization", ["thread", "process"]) 22 | def test_data_sharing(paralellization): 23 | shared_data = dict(X=np.arange(10)) 24 | manager = ExperimentManager( 25 | agent_class=DummyAgent, 26 | fit_budget=-1, 27 | n_fit=4, 28 | parallelization=paralellization, 29 | thread_shared_data=shared_data, 30 | ) 31 | manager.fit() 32 | data_ids = [agent.eval() for agent in manager.get_agent_instances()] 33 | unique_data_ids = list(set(data_ids)) 34 | if paralellization == "thread": 35 | # id() is unique for each object: make sure that shared data have same id 36 | assert len(unique_data_ids) == 1 37 | else: 38 | # when using processes, make sure that data is copied and each instance 39 | # has its own data id 40 | assert len(unique_data_ids) == manager.n_fit 41 | -------------------------------------------------------------------------------- /rlberry/manager/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from rlberry.manager import tensorboard_to_dataframe 2 | from stable_baselines3 import PPO, A2C 3 | import tempfile 4 | import os 5 | import pandas as pd 6 | import pytest 7 | 8 | 9 | def test_tensorboard_to_dataframe(): 10 | with tempfile.TemporaryDirectory() as tmpdirname: 11 | # create data to test 12 | path_ppo = str(tmpdirname + "/ppo_cartpole_tensorboard/") 13 | path_a2c = str(tmpdirname + "/a2c_cartpole_tensorboard/") 14 | model = PPO("MlpPolicy", "CartPole-v1", tensorboard_log=path_ppo) 15 | model2 = A2C("MlpPolicy", "CartPole-v1", tensorboard_log=path_a2c) 16 | model2_seed2 = A2C("MlpPolicy", "CartPole-v1", tensorboard_log=path_a2c) 17 | model.learn(total_timesteps=5_000, tb_log_name="ppo") 18 | model2.learn(total_timesteps=5_000, tb_log_name="A2C") 19 | model2_seed2.learn(total_timesteps=5_000, tb_log_name="A2C") 20 | 21 | assert os.path.exists(path_ppo) 22 | assert os.path.exists(path_a2c) 23 | 24 | # check with parent folder 25 | data_in_dataframe = tensorboard_to_dataframe(tmpdirname) 26 | 27 | assert isinstance(data_in_dataframe, dict) 28 | assert "rollout/ep_rew_mean" in data_in_dataframe 29 | a_dict = data_in_dataframe["rollout/ep_rew_mean"] 30 | 31 | assert isinstance(a_dict, pd.DataFrame) 32 | assert "name" in a_dict.columns 33 | assert "n_simu" in a_dict.columns 34 | assert "x" in a_dict.columns 35 | assert "y" in a_dict.columns 36 | 37 | # check with list of 
folder 38 | folder_ppo_1 = str(path_ppo + "ppo_1/") 39 | folder_A2C_1 = str(path_a2c + "A2C_1/") 40 | folder_A2C_2 = str(path_a2c + "A2C_2/") 41 | 42 | path_event_ppo_1 = str(folder_ppo_1 + os.listdir(folder_ppo_1)[0]) 43 | path_event_A2C_1 = str(folder_A2C_1 + os.listdir(folder_A2C_1)[0]) 44 | path_event_A2C_2 = str(folder_A2C_2 + os.listdir(folder_A2C_2)[0]) 45 | 46 | input_dict = { 47 | "ppo_cartpole_tensorboard": [path_event_ppo_1], 48 | "a2c_cartpole_tensorboard": [path_event_A2C_1, path_event_A2C_2], 49 | } 50 | 51 | data_in_dataframe2 = tensorboard_to_dataframe(input_dict) 52 | assert isinstance(data_in_dataframe2, dict) 53 | assert "rollout/ep_rew_mean" in data_in_dataframe2 54 | a_dict2 = data_in_dataframe2["rollout/ep_rew_mean"] 55 | 56 | assert isinstance(a_dict2, pd.DataFrame) 57 | assert "name" in a_dict2.columns 58 | assert "n_simu" in a_dict2.columns 59 | assert "x" in a_dict2.columns 60 | assert "y" in a_dict2.columns 61 | 62 | # check both strategies give the same result 63 | assert set(a_dict.keys()) == set(a_dict2.keys()) 64 | for key in a_dict: 65 | if ( 66 | key != "n_simu" 67 | ): # don't test n_simu/seed, it is different because one come from the folder name, and the other come for the index in the list 68 | assert set(a_dict[key]) == set(a_dict2[key]) 69 | 70 | 71 | def test_tensorboard_to_dataframe_errorIO(): 72 | msg = "Input of 'tensorboard_to_dataframe' must be a str or a dict... not a " 73 | with pytest.raises(IOError, match=msg): 74 | tensorboard_to_dataframe(1) 75 | -------------------------------------------------------------------------------- /rlberry/manager/tests/test_venv.py: -------------------------------------------------------------------------------- 1 | from rlberry.manager import with_venv, run_venv_xp 2 | 3 | 4 | @with_venv(import_libs=["tqdm"], verbose=True) 5 | def run_tqdm(): 6 | from tqdm import tqdm # noqa 7 | 8 | 9 | def test_venv(): 10 | run_venv_xp(verbose=True) 11 | -------------------------------------------------------------------------------- /rlberry/metadata_utils.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import uuid 3 | import hashlib 4 | from typing import Optional, NamedTuple 5 | 6 | 7 | # Default output directory used by the library. 8 | RLBERRY_DEFAULT_DATA_DIR = "rlberry_data/" 9 | 10 | # Temporary directory used by the library 11 | RLBERRY_TEMP_DATA_DIR = "rlberry_data/temp/" 12 | 13 | 14 | def get_timestamp_str(): 15 | """ 16 | Get a string containing current time stamp. 17 | """ 18 | now = datetime.now() 19 | date_time = now.strftime("%Y-%m-%d_%H-%M-%S") 20 | timestamp = date_time 21 | return timestamp 22 | 23 | 24 | def get_readable_id(obj): 25 | """ 26 | Create a more readable id than get_unique_id(), 27 | combining a timestamp to a 8-character hash. 28 | """ 29 | long_id = get_unique_id(obj) 30 | timestamp = get_timestamp_str() 31 | short_id = f"{timestamp}_{long_id[:8]}" 32 | return short_id 33 | 34 | 35 | def get_unique_id(obj): 36 | """ 37 | Get a unique id for an obj. Use it in __init__ methods when necessary. 38 | """ 39 | # id() is guaranteed to be unique among simultaneously existing objects (uses memory address). 40 | # uuid4() is an universal id, but there might be issues if called simultaneously in different processes. 41 | # This function combines id(), uuid4(), and a timestamp in a single ID, and hashes it. 
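    # The md5 hex digest computed below is 32 characters long; get_readable_id()
    # above keeps only its first 8 characters next to a timestamp.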
42 | timestamp = datetime.timestamp(datetime.now()) 43 | timestamp = str(timestamp).replace(".", "") 44 | str_id = timestamp + str(id(obj)) + uuid.uuid4().hex 45 | str_id = hashlib.md5(str_id.encode()).hexdigest() 46 | return str_id 47 | 48 | 49 | class ExecutionMetadata(NamedTuple): 50 | """ 51 | Metadata for objects handled by rlberry. 52 | 53 | Attributes 54 | ---------- 55 | obj_worker_id : int, default: -1 56 | If given, must be >= 0, and inform the worker id (thread or process) where the 57 | object was created. It is not necessarity unique across all the workers launched by 58 | rlberry, it is mainly for debug purposes. 59 | obj_info : dict, default: None 60 | Extra info about the object. 61 | """ 62 | 63 | obj_worker_id: int = -1 64 | obj_info: Optional[dict] = None 65 | -------------------------------------------------------------------------------- /rlberry/rendering/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import Scene, GeometricPrimitive 2 | from .render_interface import RenderInterface 3 | from .render_interface import RenderInterface2D 4 | -------------------------------------------------------------------------------- /rlberry/rendering/common_shapes.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rlberry.rendering import GeometricPrimitive 3 | 4 | 5 | def bar_shape(p0, p1, width): 6 | shape = GeometricPrimitive("QUADS") 7 | 8 | x0, y0 = p0 9 | x1, y1 = p1 10 | 11 | direction = np.array([x1 - x0, y1 - y0]) 12 | norm = np.sqrt((direction * direction).sum()) 13 | direction = direction / norm 14 | 15 | # get vector perpendicular to direction 16 | u_vec = np.zeros(2) 17 | u_vec[0] = -direction[1] 18 | u_vec[1] = direction[0] 19 | 20 | u_vec = u_vec * width / 2 21 | 22 | shape.add_vertex((x0 + u_vec[0], y0 + u_vec[1])) 23 | shape.add_vertex((x0 - u_vec[0], y0 - u_vec[1])) 24 | shape.add_vertex((x1 - u_vec[0], y1 - u_vec[1])) 25 | shape.add_vertex((x1 + u_vec[0], y1 + u_vec[1])) 26 | return shape 27 | 28 | 29 | def circle_shape(center, radius, n_points=50): 30 | shape = GeometricPrimitive("POLYGON") 31 | 32 | x0, y0 = center 33 | theta = np.linspace(0.0, 2 * np.pi, n_points) 34 | for tt in theta: 35 | xx = radius * np.cos(tt) 36 | yy = radius * np.sin(tt) 37 | shape.add_vertex((x0 + xx, y0 + yy)) 38 | 39 | return shape 40 | -------------------------------------------------------------------------------- /rlberry/rendering/core.py: -------------------------------------------------------------------------------- 1 | """ 2 | Provide classes for geometric primitives in OpenGL and scenes. 3 | """ 4 | 5 | 6 | class Scene: 7 | """ 8 | Class representing a scene, which is a vector of GeometricPrimitive objects 9 | """ 10 | 11 | def __init__(self): 12 | self.shapes = [] 13 | 14 | def add_shape(self, shape): 15 | self.shapes.append(shape) 16 | 17 | 18 | class GeometricPrimitive: 19 | """ 20 | Class representing an OpenGL geometric primitive. 
21 | 22 | Primitive type (GL_LINE_LOOP by defaut) 23 | 24 | If using OpenGLRender2D, one of the following: 25 | POINTS 26 | LINES 27 | LINE_STRIP 28 | LINE_LOOP 29 | POLYGON 30 | TRIANGLES 31 | TRIANGLE_STRIP 32 | TRIANGLE_FAN 33 | QUADS 34 | QUAD_STRIP 35 | 36 | If using PyGameRender2D: 37 | POLYGON 38 | 39 | 40 | TODO: Add support to more pygame shapes, 41 | see https://www.pygame.org/docs/ref/draw.html 42 | """ 43 | 44 | def __init__(self, primitive_type="GL_LINE_LOOP"): 45 | # primitive type 46 | self.type = primitive_type 47 | # color in RGB 48 | self.color = (0.25, 0.25, 0.25) 49 | # list of vertices. each vertex is a tuple with coordinates in space 50 | self.vertices = [] 51 | 52 | def add_vertex(self, vertex): 53 | self.vertices.append(vertex) 54 | 55 | def set_color(self, color): 56 | self.color = color 57 | -------------------------------------------------------------------------------- /rlberry/rendering/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/rlberry/rendering/tests/__init__.py -------------------------------------------------------------------------------- /rlberry/rendering/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import rlberry 3 | 4 | logger = rlberry.logger 5 | import imageio 6 | 7 | 8 | _FFMPEG_INSTALLED = True 9 | try: 10 | import ffmpeg 11 | except Exception: 12 | _FFMPEG_INSTALLED = False 13 | 14 | 15 | def video_write(fn, images, framerate=60, vcodec="libx264"): 16 | """ 17 | Save list of images to a video file. 18 | 19 | Source: 20 | https://github.com/kkroening/ffmpeg-python/issues/246#issuecomment-520200981 21 | Modified so that framerate is given to .input(), as suggested in the 22 | thread, to avoid 23 | skipping frames. 24 | 25 | Parameters 26 | ---------- 27 | fn : string 28 | filename 29 | images : list or np.array 30 | list of images to save to a video. 31 | framerate : int 32 | """ 33 | global _FFMPEG_INSTALLED 34 | 35 | try: 36 | if len(images) == 0: 37 | logger.warning("Calling video_write() with empty images.") 38 | return 39 | 40 | if not _FFMPEG_INSTALLED: 41 | logger.error( 42 | "video_write(): Unable to save video, ffmpeg-python \ 43 | package required (https://github.com/kkroening/ffmpeg-python)" 44 | ) 45 | return 46 | 47 | if not isinstance(images, np.ndarray): 48 | images = np.asarray(images) 49 | _, height, width, channels = images.shape 50 | process = ( 51 | ffmpeg.input( 52 | "pipe:", 53 | format="rawvideo", 54 | pix_fmt="rgb24", 55 | s="{}x{}".format(width, height), 56 | r=framerate, 57 | ) 58 | .output(fn, pix_fmt="yuv420p", vcodec=vcodec) 59 | .overwrite_output() 60 | .run_async(pipe_stdin=True) 61 | ) 62 | for frame in images: 63 | process.stdin.write(frame.astype(np.uint8).tobytes()) 64 | process.stdin.close() 65 | process.wait() 66 | 67 | except Exception as ex: 68 | logger.warning( 69 | "Not possible to save \ 70 | video, due to exception: {}".format( 71 | str(ex) 72 | ) 73 | ) 74 | 75 | 76 | def gif_write(fn, images): 77 | """ 78 | Save list of images to a gif file 79 | 80 | Parameters 81 | ---------- 82 | fn : string 83 | filename 84 | images : list or np.array 85 | list of images to save to a gif. 
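
    Examples
    --------
    A minimal sketch with dummy all-black frames (the filename and frame contents
    are arbitrary placeholders)::

        import numpy as np
        frames = [np.zeros((64, 64, 3), dtype=np.uint8) for _ in range(10)]
        gif_write("dummy.gif", frames)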
86 | """ 87 | 88 | try: 89 | if len(images) == 0: 90 | logger.warning("Calling gif_write() with empty images.") 91 | return 92 | 93 | if not isinstance(images, np.ndarray): 94 | images = np.asarray(images) 95 | 96 | with imageio.get_writer(fn, mode="I") as writer: 97 | for frame in images: 98 | writer.append_data(frame) 99 | 100 | except Exception as ex: 101 | logger.warning( 102 | "Not possible to save \ 103 | gif, due to exception: {}".format( 104 | str(ex) 105 | ) 106 | ) 107 | -------------------------------------------------------------------------------- /rlberry/seeding/__init__.py: -------------------------------------------------------------------------------- 1 | from .seeder import Seeder 2 | from .seeding import safe_reseed 3 | from .seeding import set_external_seed 4 | -------------------------------------------------------------------------------- /rlberry/seeding/seeding.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import rlberry.check_packages as check_packages 3 | from rlberry.seeding.seeder import Seeder 4 | 5 | if check_packages.TORCH_INSTALLED: 6 | import torch 7 | 8 | 9 | def set_external_seed(seeder): 10 | """ 11 | Set seeds of external libraries. 12 | 13 | To do: 14 | Check (torch seeding): 15 | https://github.com/pytorch/pytorch/issues/7068#issuecomment-487907668 16 | 17 | Parameters 18 | --------- 19 | seeder: seeding.Seeder or int 20 | Integer or Seeder object from which to generate random seeds. 21 | 22 | Examples 23 | -------- 24 | >>> from rlberry.seeding import set_external_seed 25 | >>> set_external_seed(seeder) 26 | """ 27 | if np.issubdtype(type(seeder), np.integer): 28 | seeder = Seeder(seeder) 29 | 30 | # seed torch 31 | if check_packages.TORCH_INSTALLED: 32 | torch.manual_seed(seeder.seed_seq.generate_state(1, dtype=np.uint32)[0]) 33 | 34 | 35 | def safe_reseed(obj, seeder, reseed_spaces=True): 36 | """ 37 | Calls obj.reseed(seed_seq) method if available; 38 | If a obj.seed() method is available, call obj.seed(seed_val), 39 | where seed_val is generated by the seeder. 40 | Otherwise, does nothing. 41 | 42 | Parameters 43 | ---------- 44 | obj : object 45 | Object to be reseeded. 46 | seeder: :class:`~rlberry.seeding.seeder.Seeder` 47 | Seeder object from which to generate random seeds. 48 | reseed_spaces: bool, default = True. 49 | If False, do not try to reseed observation_space and action_space (if 50 | they exist as attributes of `obj`). 51 | 52 | Returns 53 | ------- 54 | True if reseeding was done, False otherwise. 55 | 56 | """ 57 | reseeded = False 58 | try: 59 | obj.reseed(seeder) 60 | reseeded = True 61 | except AttributeError: 62 | seed_val = seeder.rng.integers(2**32).item() 63 | try: 64 | obj.seed(seed_val) 65 | reseeded = True 66 | except AttributeError: 67 | try: 68 | obj.reset(seed=seed_val) 69 | reseeded = True 70 | except AttributeError: 71 | reseeded = False 72 | 73 | # check if the object has observation and action spaces to be reseeded. 
74 | if reseed_spaces: 75 | try: 76 | safe_reseed(obj.observation_space, seeder) 77 | safe_reseed(obj.action_space, seeder) 78 | except AttributeError: 79 | pass 80 | 81 | return reseeded 82 | -------------------------------------------------------------------------------- /rlberry/seeding/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/rlberry/seeding/tests/__init__.py -------------------------------------------------------------------------------- /rlberry/seeding/tests/test_seeding.py: -------------------------------------------------------------------------------- 1 | from rlberry.seeding import Seeder 2 | 3 | 4 | def test_seeder_basic(): 5 | seeder1 = Seeder(43) 6 | data1 = seeder1.rng.integers(100, size=1000) 7 | 8 | seeder2 = Seeder(44) 9 | data2 = seeder2.rng.integers(100, size=1000) 10 | 11 | seeder3 = Seeder(44) 12 | data3 = seeder3.rng.integers(100, size=1000) 13 | 14 | assert (data1 != data2).sum() > 5 15 | assert (data2 != data3).sum() == 0 16 | assert ( 17 | seeder2.spawn(1).generate_state(1)[0] == seeder3.spawn(1).generate_state(1)[0] 18 | ) 19 | assert ( 20 | seeder1.spawn(1).generate_state(1)[0] != seeder3.spawn(1).generate_state(1)[0] 21 | ) 22 | 23 | 24 | def test_seeder_initialized_from_seeder(): 25 | """ 26 | Check that Seeder(seed_seq) respawns seed_seq in the constructor. 27 | """ 28 | seeder1 = Seeder(43) 29 | seeder_temp = Seeder(43) 30 | seeder2 = Seeder(seeder_temp) 31 | 32 | data1 = seeder1.rng.integers(100, size=1000) 33 | data2 = seeder2.rng.integers(100, size=1000) 34 | assert (data1 != data2).sum() > 5 35 | 36 | 37 | def test_seeder_spawning(): 38 | """ 39 | Check that Seeder(seed_seq) respawns seed_seq in the constructor. 40 | """ 41 | seeder1 = Seeder(43) 42 | seeder2 = seeder1.spawn() 43 | seeder3 = seeder2.spawn() 44 | 45 | print(seeder1) 46 | print(seeder2) 47 | print(seeder3) 48 | 49 | data1 = seeder1.rng.integers(100, size=1000) 50 | data2 = seeder2.rng.integers(100, size=1000) 51 | assert (data1 != data2).sum() > 5 52 | 53 | 54 | def test_seeder_reseeding(): 55 | """ 56 | Check that reseeding with a Seeder instance works properly. 
57 | """ 58 | # seeders 1 and 2 are identical 59 | seeder1 = Seeder(43) 60 | seeder2 = Seeder(43) 61 | 62 | # reseed seeder 2 using seeder 1 63 | seeder2.reseed(seeder1) 64 | 65 | data1 = seeder1.rng.integers(100, size=1000) 66 | data2 = seeder2.rng.integers(100, size=1000) 67 | assert (data1 != data2).sum() > 5 68 | -------------------------------------------------------------------------------- /rlberry/seeding/tests/test_threads.py: -------------------------------------------------------------------------------- 1 | from rlberry.seeding.seeder import Seeder 2 | import concurrent.futures 3 | 4 | 5 | def get_random_number_setting_seed(seeder): 6 | return seeder.rng.integers(2**32) 7 | 8 | 9 | def test_multithread_seeding(): 10 | """ 11 | Checks that different seeds are given to different threads 12 | """ 13 | for ii in range(5): 14 | main_seeder = Seeder(123) 15 | for jj in range(10): 16 | with concurrent.futures.ThreadPoolExecutor() as executor: 17 | futures = [] 18 | for seed in main_seeder.spawn(2): 19 | futures.append( 20 | executor.submit(get_random_number_setting_seed, seed) 21 | ) 22 | 23 | results = [] 24 | for future in concurrent.futures.as_completed(futures): 25 | results.append(future.result()) 26 | assert results[0] != results[1], f"error in simulation {(ii, jj)}" 27 | -------------------------------------------------------------------------------- /rlberry/seeding/tests/test_threads_torch.py: -------------------------------------------------------------------------------- 1 | from rlberry.seeding.seeder import Seeder 2 | from rlberry.seeding import set_external_seed 3 | import concurrent.futures 4 | 5 | _TORCH_INSTALLED = True 6 | try: 7 | import torch 8 | except Exception: 9 | _TORCH_INSTALLED = False 10 | 11 | 12 | def get_torch_random_number_setting_seed(seeder): 13 | set_external_seed(seeder) 14 | return torch.randint(2**32, (1,))[0].item() 15 | 16 | 17 | def test_torch_multithread_seeding(): 18 | """ 19 | Checks that different seeds are given to different threads 20 | """ 21 | for ii in range(5): 22 | main_seeder = Seeder(123) 23 | for jj in range(10): 24 | with concurrent.futures.ThreadPoolExecutor() as executor: 25 | futures = [] 26 | for seed in main_seeder.spawn(2): 27 | futures.append( 28 | executor.submit(get_torch_random_number_setting_seed, seed) 29 | ) 30 | 31 | results = [] 32 | for future in concurrent.futures.as_completed(futures): 33 | results.append(future.result()) 34 | assert results[0] != results[1], f"error in simulation {(ii, jj)}" 35 | -------------------------------------------------------------------------------- /rlberry/spaces/__init__.py: -------------------------------------------------------------------------------- 1 | from .discrete import Discrete 2 | from .box import Box 3 | from .tuple import Tuple 4 | from .multi_discrete import MultiDiscrete 5 | from .multi_binary import MultiBinary 6 | from .dict import Dict 7 | -------------------------------------------------------------------------------- /rlberry/spaces/box.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | import numpy as np 3 | from rlberry.seeding import Seeder 4 | 5 | 6 | class Box(gym.spaces.Box): 7 | """ 8 | Class that represents a space that is a cartesian product in R^n: 9 | 10 | [a_1, b_1] x [a_2, b_2] x ... x [a_n, b_n] 11 | 12 | 13 | Inherited from gymnasium.spaces.Box for compatibility with gym. 
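
    A short usage sketch (reseeding is detailed below)::

        from rlberry.spaces import Box
        from rlberry.seeding import Seeder

        space = Box(low=0.0, high=1.0, shape=(2,))
        space.reseed(Seeder(123))
        sample = space.sample()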
14 | 15 | rlberry wraps gym.spaces to make sure the seeding 16 | mechanism is unified in the library (rlberry.seeding) 17 | 18 | Attributes 19 | ---------- 20 | rng : numpy.random._generator.Generator 21 | random number generator provided by rlberry.seeding 22 | 23 | Methods 24 | ------- 25 | reseed() 26 | get new random number generator 27 | """ 28 | 29 | def __init__(self, low, high, shape=None, dtype=np.float64): 30 | gym.spaces.Box.__init__(self, low, high, shape=shape, dtype=dtype) 31 | self.seeder = Seeder() 32 | 33 | @property 34 | def rng(self): 35 | return self.seeder.rng 36 | 37 | def reseed(self, seed_seq=None): 38 | """ 39 | Get new random number generator. 40 | 41 | Parameters 42 | ---------- 43 | seed_seq : np.random.SeedSequence, rlberry.seeding.Seeder or int, default : None 44 | Seed sequence from which to spawn the random number generator. 45 | If None, generate random seed. 46 | If int, use as entropy for SeedSequence. 47 | If seeder, use seeder.seed_seq 48 | """ 49 | self.seeder.reseed(seed_seq) 50 | 51 | def sample(self): 52 | """ 53 | Adapted from: 54 | https://raw.githubusercontent.com/openai/gym/master/gym/spaces/box.py 55 | 56 | 57 | Generates a single random sample inside of the Box. 58 | 59 | In creating a sample of the box, each coordinate is sampled according 60 | to the form of the interval: 61 | 62 | * [a, b] : uniform distribution 63 | * [a, oo) : shifted exponential distribution 64 | * (-oo, b] : shifted negative exponential distribution 65 | * (-oo, oo) : normal distribution 66 | """ 67 | high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1 68 | sample = np.empty(self.shape) 69 | 70 | # Masking arrays which classify the coordinates according to interval 71 | # type 72 | unbounded = ~self.bounded_below & ~self.bounded_above 73 | upp_bounded = ~self.bounded_below & self.bounded_above 74 | low_bounded = self.bounded_below & ~self.bounded_above 75 | bounded = self.bounded_below & self.bounded_above 76 | 77 | # Vectorized sampling by interval type 78 | sample[unbounded] = self.rng.normal(size=unbounded[unbounded].shape) 79 | 80 | sample[low_bounded] = ( 81 | self.rng.exponential(size=low_bounded[low_bounded].shape) 82 | + self.low[low_bounded] 83 | ) 84 | 85 | sample[upp_bounded] = ( 86 | -self.rng.exponential(size=upp_bounded[upp_bounded].shape) 87 | + self.high[upp_bounded] 88 | ) 89 | 90 | sample[bounded] = self.rng.uniform( 91 | low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape 92 | ) 93 | if self.dtype.kind == "i": 94 | sample = np.floor(sample) 95 | 96 | return sample.astype(self.dtype) 97 | -------------------------------------------------------------------------------- /rlberry/spaces/dict.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from rlberry.seeding import Seeder 3 | 4 | 5 | class Dict(gym.spaces.Dict): 6 | """ 7 | 8 | Inherited from gymnasium.spaces.Dict for compatibility with gym. 
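    Minimal usage sketch (hedged; the key names below are illustrative only):

        from rlberry.spaces import Box, Dict, Discrete

        space = Dict({"position": Box(0.0, 1.0, shape=(2,)), "mode": Discrete(3)})
        space.reseed(7)  # reseed() is forwarded to every sub-space
        sample = {key: sub.sample() for key, sub in space.spaces.items()}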
9 | 10 | rlberry wraps gym.spaces to make sure the seeding 11 | mechanism is unified in the library (rlberry.seeding) 12 | 13 | Attributes 14 | ---------- 15 | rng : numpy.random._generator.Generator 16 | random number generator provided by rlberry.seeding 17 | 18 | Methods 19 | ------- 20 | reseed() 21 | get new random number generator 22 | """ 23 | 24 | def __init__(self, spaces=None, **spaces_kwargs): 25 | gym.spaces.Dict.__init__(self, spaces, **spaces_kwargs) 26 | self.seeder = Seeder() 27 | 28 | @property 29 | def rng(self): 30 | return self.seeder.rng 31 | 32 | def reseed(self, seed_seq=None): 33 | _ = [space.reseed(seed_seq) for space in self.spaces.values()] 34 | -------------------------------------------------------------------------------- /rlberry/spaces/discrete.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from rlberry.seeding import Seeder 3 | 4 | 5 | class Discrete(gym.spaces.Discrete): 6 | """ 7 | Class that represents discrete spaces. 8 | 9 | 10 | Inherited from gymnasium.spaces.Discrete for compatibility with gym. 11 | 12 | rlberry wraps gym.spaces to make sure the seeding 13 | mechanism is unified in the library (rlberry.seeding) 14 | 15 | Attributes 16 | ---------- 17 | rng : numpy.random._generator.Generator 18 | random number generator provided by rlberry.seeding 19 | 20 | Methods 21 | ------- 22 | reseed() 23 | get new random number generator 24 | """ 25 | 26 | def __init__(self, n): 27 | """ 28 | Parameters 29 | ---------- 30 | n : int 31 | number of elements in the space 32 | """ 33 | assert n >= 0, "The number of elements in Discrete must be >= 0" 34 | gym.spaces.Discrete.__init__(self, n) 35 | self.seeder = Seeder() 36 | 37 | @property 38 | def rng(self): 39 | return self.seeder.rng 40 | 41 | def reseed(self, seed_seq=None): 42 | """ 43 | Get new random number generator. 44 | 45 | Parameters 46 | ---------- 47 | seed_seq : np.random.SeedSequence, rlberry.seeding.Seeder or int, default : None 48 | Seed sequence from which to spawn the random number generator. 49 | If None, generate random seed. 50 | If int, use as entropy for SeedSequence. 
51 | If seeder, use seeder.seed_seq 52 | """ 53 | self.seeder.reseed(seed_seq) 54 | 55 | def sample(self): 56 | return self.rng.integers(0, self.n) 57 | 58 | def __str__(self): 59 | objstr = "%d-element Discrete space" % self.n 60 | return objstr 61 | -------------------------------------------------------------------------------- /rlberry/spaces/from_gym.py: -------------------------------------------------------------------------------- 1 | import rlberry.spaces 2 | import gymnasium.spaces 3 | 4 | 5 | def convert_space_from_gym(space): 6 | if isinstance(space, gymnasium.spaces.Box) and ( 7 | not isinstance(space, rlberry.spaces.Box) 8 | ): 9 | return rlberry.spaces.Box( 10 | space.low, space.high, shape=space.shape, dtype=space.dtype 11 | ) 12 | if isinstance(space, gymnasium.spaces.Discrete) and ( 13 | not isinstance(space, rlberry.spaces.Discrete) 14 | ): 15 | return rlberry.spaces.Discrete(n=space.n) 16 | if isinstance(space, gymnasium.spaces.MultiBinary) and ( 17 | not isinstance(space, rlberry.spaces.MultiBinary) 18 | ): 19 | return rlberry.spaces.MultiBinary(n=space.n) 20 | if isinstance(space, gymnasium.spaces.MultiDiscrete) and ( 21 | not isinstance(space, rlberry.spaces.MultiDiscrete) 22 | ): 23 | return rlberry.spaces.MultiDiscrete( 24 | nvec=space.nvec, 25 | dtype=space.dtype, 26 | ) 27 | if isinstance(space, gymnasium.spaces.Tuple) and ( 28 | not isinstance(space, rlberry.spaces.Tuple) 29 | ): 30 | return rlberry.spaces.Tuple( 31 | spaces=[convert_space_from_gym(sp) for sp in space.spaces] 32 | ) 33 | if isinstance(space, gymnasium.spaces.Dict) and ( 34 | not isinstance(space, rlberry.spaces.Dict) 35 | ): 36 | converted_spaces = dict() 37 | for key in space.spaces: 38 | converted_spaces[key] = convert_space_from_gym(space.spaces[key]) 39 | return rlberry.spaces.Dict(spaces=converted_spaces) 40 | 41 | return space 42 | -------------------------------------------------------------------------------- /rlberry/spaces/multi_binary.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from rlberry.seeding import Seeder 3 | 4 | 5 | class MultiBinary(gym.spaces.MultiBinary): 6 | """ 7 | 8 | Inherited from gymnasium.spaces.MultiBinary for compatibility with gym. 9 | 10 | rlberry wraps gym.spaces to make sure the seeding 11 | mechanism is unified in the library (rlberry.seeding) 12 | 13 | Attributes 14 | ---------- 15 | rng : numpy.random._generator.Generator 16 | random number generator provided by rlberry.seeding 17 | 18 | Methods 19 | ------- 20 | reseed() 21 | get new random number generator 22 | """ 23 | 24 | def __init__(self, n): 25 | gym.spaces.MultiBinary.__init__(self, n) 26 | self.seeder = Seeder() 27 | 28 | @property 29 | def rng(self): 30 | return self.seeder.rng 31 | 32 | def reseed(self, seed_seq=None): 33 | """ 34 | Get new random number generator. 35 | 36 | Parameters 37 | ---------- 38 | seed_seq : np.random.SeedSequence, rlberry.seeding.Seeder or int, default : None 39 | Seed sequence from which to spawn the random number generator. 40 | If None, generate random seed. 41 | If int, use as entropy for SeedSequence. 
42 | If seeder, use seeder.seed_seq 43 | """ 44 | self.seeder.reseed(seed_seq) 45 | 46 | def sample(self): 47 | return self.rng.integers(low=0, high=2, size=self.n, dtype=self.dtype) 48 | -------------------------------------------------------------------------------- /rlberry/spaces/multi_discrete.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | import numpy as np 3 | from rlberry.seeding import Seeder 4 | 5 | 6 | class MultiDiscrete(gym.spaces.MultiDiscrete): 7 | """ 8 | 9 | Inherited from gymnasium.spaces.MultiDiscrete for compatibility with gym. 10 | 11 | rlberry wraps gym.spaces to make sure the seeding 12 | mechanism is unified in the library (rlberry.seeding) 13 | 14 | Attributes 15 | ---------- 16 | rng : numpy.random._generator.Generator 17 | random number generator provided by rlberry.seeding 18 | 19 | Methods 20 | ------- 21 | reseed() 22 | get new random number generator 23 | """ 24 | 25 | def __init__(self, nvec, dtype=np.int64): 26 | gym.spaces.MultiDiscrete.__init__(self, nvec, dtype=dtype) 27 | self.seeder = Seeder() 28 | 29 | @property 30 | def rng(self): 31 | return self.seeder.rng 32 | 33 | def reseed(self, seed_seq=None): 34 | """ 35 | Get new random number generator. 36 | 37 | Parameters 38 | ---------- 39 | seed_seq : np.random.SeedSequence, rlberry.seeding.Seeder or int, default : None 40 | Seed sequence from which to spawn the random number generator. 41 | If None, generate random seed. 42 | If int, use as entropy for SeedSequence. 43 | If seeder, use seeder.seed_seq 44 | """ 45 | self.seeder.reseed(seed_seq) 46 | 47 | def sample(self): 48 | sample = self.rng.random(self.nvec.shape) * self.nvec 49 | return sample.astype(self.dtype) 50 | -------------------------------------------------------------------------------- /rlberry/spaces/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/rlberry/spaces/tests/__init__.py -------------------------------------------------------------------------------- /rlberry/spaces/tuple.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from rlberry.seeding import Seeder 3 | 4 | 5 | class Tuple(gym.spaces.Tuple): 6 | """ 7 | 8 | Inherited from gymnasium.spaces.Tuple for compatibility with gym. 
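    Minimal usage sketch (hedged): reseed() is simply forwarded to each sub-space,
    and this class does not override sample(), so drawing from rlberry's own
    generators can be done through the sub-spaces directly:

        from rlberry.spaces import Box, Discrete, Tuple

        space = Tuple([Box(0.0, 1.0, shape=(2,)), Discrete(4)])
        space.reseed(123)
        values = tuple(sub.sample() for sub in space.spaces)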
9 | 10 | rlberry wraps gym.spaces to make sure the seeding 11 | mechanism is unified in the library (rlberry.seeding) 12 | 13 | Attributes 14 | ---------- 15 | rng : numpy.random._generator.Generator 16 | random number generator provided by rlberry.seeding 17 | 18 | Methods 19 | ------- 20 | reseed() 21 | get new random number generator 22 | """ 23 | 24 | def __init__(self, spaces): 25 | gym.spaces.Tuple.__init__(self, spaces) 26 | self.seeder = Seeder() 27 | 28 | @property 29 | def rng(self): 30 | return self.seeder.rng 31 | 32 | def reseed(self, seed_seq=None): 33 | _ = [space.reseed(seed_seq) for space in self.spaces] 34 | -------------------------------------------------------------------------------- /rlberry/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/rlberry/tests/__init__.py -------------------------------------------------------------------------------- /rlberry/tests/test_agent_extra.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import rlberry_scool.agents as agents_scool 3 | import rlberry_research.agents.torch as torch_agents 4 | from rlberry.utils.check_agent import ( 5 | check_rl_agent, 6 | check_rlberry_agent, 7 | check_vectorized_env_agent, 8 | check_hyperparam_optimisation_agent, 9 | ) 10 | from rlberry_scool.agents.features import FeatureMap 11 | import numpy as np 12 | import sys 13 | 14 | 15 | class OneHotFeatureMap(FeatureMap): 16 | def __init__(self, S, A): 17 | self.S = S 18 | self.A = A 19 | self.shape = (S * A,) 20 | 21 | def map(self, observation, action): 22 | feat = np.zeros((self.S, self.A)) 23 | feat[observation, action] = 1.0 24 | return feat.flatten() 25 | 26 | 27 | # LSVIUCBAgent needs a feature map function to work. 28 | class OneHotLSVI(agents_scool.LSVIUCBAgent): 29 | def __init__(self, env, **kwargs): 30 | def feature_map_fn(_env): 31 | return OneHotFeatureMap(5, 2) # values for Chain 32 | 33 | agents_scool.LSVIUCBAgent.__init__( 34 | self, env, feature_map_fn=feature_map_fn, horizon=10, **kwargs 35 | ) 36 | 37 | 38 | # No agent "FINITE_MDP" in extra 39 | # FINITE_MDP_AGENTS = [ 40 | # ] 41 | 42 | 43 | CONTINUOUS_STATE_AGENTS = [ 44 | torch_agents.DQNAgent, 45 | torch_agents.MunchausenDQNAgent, 46 | torch_agents.REINFORCEAgent, 47 | torch_agents.PPOAgent, 48 | torch_agents.A2CAgent, 49 | ] 50 | 51 | # Maybe add PPO ? 
52 | CONTINUOUS_ACTIONS_AGENTS = [torch_agents.SACAgent] 53 | 54 | 55 | HYPERPARAM_OPTI_AGENTS = [ 56 | torch_agents.PPOAgent, 57 | torch_agents.REINFORCEAgent, 58 | torch_agents.A2CAgent, 59 | torch_agents.SACAgent, 60 | ] 61 | 62 | 63 | MULTI_ENV_AGENTS = [ 64 | torch_agents.PPOAgent, 65 | ] 66 | 67 | # No agent "FINITE_MDP" in extra 68 | # @pytest.mark.parametrize("agent", FINITE_MDP_AGENTS) 69 | # def test_finite_state_agent(agent): 70 | # check_rl_agent(agent, env="discrete_state") 71 | # check_rlberry_agent(agent, env="discrete_state") 72 | 73 | 74 | @pytest.mark.xfail(sys.platform == "win32", reason="bug with windows???") 75 | @pytest.mark.parametrize("agent", CONTINUOUS_STATE_AGENTS) 76 | def test_continuous_state_agent(agent): 77 | check_rl_agent(agent, env="continuous_state") 78 | check_rlberry_agent(agent, env="continuous_state") 79 | 80 | 81 | @pytest.mark.xfail(sys.platform == "win32", reason="bug with windows???") 82 | @pytest.mark.parametrize("agent", CONTINUOUS_ACTIONS_AGENTS) 83 | def test_continuous_action_agent(agent): 84 | check_rl_agent(agent, env="continuous_action") 85 | check_rlberry_agent(agent, env="continuous_action") 86 | 87 | 88 | @pytest.mark.xfail(sys.platform == "win32", reason="bug with windows???") 89 | @pytest.mark.parametrize("agent", MULTI_ENV_AGENTS) 90 | def test_continuous_vectorized_env_agent(agent): 91 | check_vectorized_env_agent(agent, env="vectorized_env_continuous") 92 | 93 | 94 | @pytest.mark.xfail(sys.platform == "win32", reason="bug with windows???") 95 | @pytest.mark.parametrize("agent", HYPERPARAM_OPTI_AGENTS) 96 | def test_hyperparam_optimisation_agent(agent): 97 | check_hyperparam_optimisation_agent(agent, env="continuous_state") 98 | -------------------------------------------------------------------------------- /rlberry/tests/test_agents_base.py: -------------------------------------------------------------------------------- 1 | """ 2 | =============================================== 3 | Tests for the installation without extra (StableBaselines3, torch, optuna, ...) 4 | =============================================== 5 | tests based on test_agent.py and test_envs.py 6 | 7 | """ 8 | 9 | 10 | import pytest 11 | import numpy as np 12 | import sys 13 | 14 | import rlberry_research.agents as agents_research 15 | import rlberry_scool.agents as agents_scool 16 | from rlberry_scool.agents.features import FeatureMap 17 | 18 | from rlberry.utils.check_agent import ( 19 | check_rl_agent, 20 | check_rlberry_agent, 21 | ) 22 | 23 | 24 | class OneHotFeatureMap(FeatureMap): 25 | def __init__(self, S, A): 26 | self.S = S 27 | self.A = A 28 | self.shape = (S * A,) 29 | 30 | def map(self, observation, action): 31 | feat = np.zeros((self.S, self.A)) 32 | feat[observation, action] = 1.0 33 | return feat.flatten() 34 | 35 | 36 | # LSVIUCBAgent needs a feature map function to work. 
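# A plausible reading (assumption, not verified against check_rl_agent's internals):
# the check utilities build agents as Agent(env, **kwargs), so the subclass below
# freezes feature_map_fn and horizon to fit that signature; OneHotFeatureMap(5, 2)
# matches the 5-state / 2-action Chain setup mentioned in the comment below.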
37 | class OneHotLSVI(agents_scool.LSVIUCBAgent): 38 | def __init__(self, env, **kwargs): 39 | def feature_map_fn(_env): 40 | return OneHotFeatureMap(5, 2) # values for Chain 41 | 42 | agents_scool.LSVIUCBAgent.__init__( 43 | self, env, feature_map_fn=feature_map_fn, horizon=10, **kwargs 44 | ) 45 | 46 | 47 | FINITE_MDP_AGENTS = [ 48 | agents_scool.QLAgent, 49 | agents_scool.SARSAAgent, 50 | agents_scool.ValueIterationAgent, 51 | agents_scool.MBQVIAgent, 52 | agents_scool.UCBVIAgent, 53 | agents_research.OptQLAgent, 54 | agents_research.PSRLAgent, 55 | agents_research.RLSVIAgent, 56 | OneHotLSVI, 57 | ] 58 | 59 | 60 | CONTINUOUS_STATE_AGENTS = [ 61 | agents_research.RSUCBVIAgent, 62 | agents_research.RSKernelUCBVIAgent, 63 | ] 64 | 65 | 66 | @pytest.mark.parametrize("agent", FINITE_MDP_AGENTS) 67 | def test_finite_state_agent(agent): 68 | check_rl_agent(agent, env="discrete_state") 69 | check_rlberry_agent(agent, env="discrete_state") 70 | 71 | 72 | @pytest.mark.xfail(sys.platform == "win32", reason="bug with windows???") 73 | @pytest.mark.parametrize("agent", CONTINUOUS_STATE_AGENTS) 74 | def test_continuous_state_agent(agent): 75 | check_rl_agent(agent, env="continuous_state") 76 | check_rlberry_agent(agent, env="continuous_state") 77 | -------------------------------------------------------------------------------- /rlberry/tests/test_envs.py: -------------------------------------------------------------------------------- 1 | from rlberry.utils.check_env import check_env, check_rlberry_env 2 | from rlberry_research.envs.benchmarks.ball_exploration import PBall2D 3 | from rlberry_research.envs.benchmarks.generalization.twinrooms import TwinRooms 4 | from rlberry_research.envs.benchmarks.grid_exploration.apple_gold import AppleGold 5 | from rlberry_research.envs.benchmarks.grid_exploration.nroom import NRoom 6 | from rlberry_research.envs.classic_control import MountainCar, SpringCartPole 7 | from rlberry_scool.envs.finite import Chain, GridWorld 8 | import pytest 9 | 10 | ALL_ENVS = [ 11 | PBall2D, 12 | TwinRooms, 13 | AppleGold, 14 | NRoom, 15 | MountainCar, 16 | Chain, 17 | GridWorld, 18 | SpringCartPole, 19 | ] 20 | 21 | 22 | @pytest.mark.parametrize("Env", ALL_ENVS) 23 | def test_env(Env): 24 | check_env(Env()) 25 | check_rlberry_env(Env()) 26 | -------------------------------------------------------------------------------- /rlberry/tests/test_imports.py: -------------------------------------------------------------------------------- 1 | def test_imports(): 2 | import rlberry # noqa 3 | from rlberry.manager import ( # noqa 4 | ExperimentManager, # noqa 5 | evaluate_agents, # noqa 6 | plot_writer_data, # noqa 7 | ) # noqa 8 | from rlberry.agents import AgentWithSimplePolicy # noqa 9 | from rlberry.wrappers import WriterWrapper # noqa 10 | -------------------------------------------------------------------------------- /rlberry/types.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from typing import Any, Callable, Mapping, Tuple, Union 3 | from rlberry.seeding import Seeder 4 | 5 | # either a gymnasium.Env or a tuple containing (constructor, kwargs) to build the env 6 | Env = Union[gym.Env, Tuple[Callable[..., gym.Env], Mapping[str, Any]]] 7 | 8 | # 9 | Seed = Union[Seeder, int] 10 | -------------------------------------------------------------------------------- /rlberry/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .check_agent import ( 2 | 
check_rl_agent, 3 | check_save_load, 4 | check_fit_additive, 5 | check_seeding_agent, 6 | check_experiment_manager, 7 | ) 8 | from .check_env import check_env 9 | -------------------------------------------------------------------------------- /rlberry/utils/binsearch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def binary_search_nd(x_vec, bins): 5 | """n-dimensional binary search 6 | 7 | Parameters 8 | ----------- 9 | x_vec : numpy.ndarray 10 | numpy 1d array to be searched in the bins 11 | bins : list 12 | list of numpy 1d array, bins[d] = bins of the d-th dimension 13 | 14 | 15 | Returns 16 | -------- 17 | index (int) corresponding to the position of x in the partition 18 | defined by the bins. 19 | """ 20 | dim = len(bins) 21 | flat_index = 0 22 | aux = 1 23 | assert dim == len(x_vec), "dimension mismatch in binary_search_nd()" 24 | for dd in range(dim): 25 | index_dd = np.searchsorted(bins[dd], x_vec[dd], side="right") - 1 26 | assert index_dd != -1, "error in binary_search_nd()" 27 | flat_index += aux * index_dd 28 | aux *= len(bins[dd]) - 1 29 | return flat_index 30 | 31 | 32 | def unravel_index_uniform_bin(flat_index, dim, n_per_dim): 33 | index = [] 34 | aux_index = flat_index 35 | for _ in range(dim): 36 | index.append(aux_index % n_per_dim) 37 | aux_index = aux_index // n_per_dim 38 | return tuple(index) 39 | 40 | 41 | if __name__ == "__main__": 42 | bins = [(0, 1, 2, 3, 4), (0, 1, 2, 3, 4)] 43 | x = [3.9, 3.5] 44 | index = binary_search_nd(x, bins) 45 | print(index) 46 | -------------------------------------------------------------------------------- /rlberry/utils/check_env.py: -------------------------------------------------------------------------------- 1 | from rlberry.seeding import safe_reseed 2 | from rlberry.seeding import Seeder 3 | import numpy as np 4 | from rlberry.utils.check_gym_env import check_gym_env 5 | 6 | seeder = Seeder(42) 7 | 8 | 9 | def check_env(env): 10 | """ 11 | Check that the environment is (almost) gym-compatible and that it is reproducible 12 | in the sense that it returns the same states when given the same seed. 13 | 14 | Parameters 15 | ---------- 16 | env: gymnasium.env or rlberry env 17 | Environment that we want to check. 18 | """ 19 | # Small reproducibility test 20 | action = env.action_space.sample() 21 | safe_reseed(env, Seeder(42)) 22 | env.reset() 23 | a = env.step(action)[0] 24 | 25 | safe_reseed(env, Seeder(42)) 26 | env.reset() 27 | b = env.step(action)[0] 28 | if hasattr(a, "__len__"): 29 | assert np.all( 30 | np.array(a) == np.array(b) 31 | ), "The environment does not seem to be reproducible" 32 | else: 33 | assert a == b, "The environment does not seem to be reproducible" 34 | 35 | # Modified check suite from gym 36 | check_gym_env(env) 37 | 38 | 39 | def check_rlberry_env(env): 40 | """ 41 | Companion to check_env, contains additional tests. It is not mandatory 42 | for an environment to satisfy this check but satisfying this check give access to 43 | additional features in rlberry. 44 | 45 | Parameters 46 | ---------- 47 | env: gymnasium.env or rlberry env 48 | Environment that we want to check. 
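    A minimal call sketch (hedged; mirrors the usage in rlberry/tests/test_envs.py):

        from rlberry_scool.envs.finite import Chain

        check_rlberry_env(Chain())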
49 | """ 50 | try: 51 | env.get_params() 52 | except Exception: 53 | raise RuntimeError("Fail to call get_params on the environment.") 54 | -------------------------------------------------------------------------------- /rlberry/utils/factory.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from typing import Callable 3 | 4 | 5 | def load(path: str) -> Callable: 6 | module_name, class_name = path.rsplit(".", 1) 7 | return getattr(importlib.import_module(module_name), class_name) 8 | -------------------------------------------------------------------------------- /rlberry/utils/space_discretizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from gymnasium.spaces import Box, Discrete 3 | from rlberry.utils.binsearch import binary_search_nd 4 | from rlberry.utils.binsearch import unravel_index_uniform_bin 5 | 6 | 7 | class Discretizer: 8 | def __init__(self, space, n_bins): 9 | assert isinstance( 10 | space, Box 11 | ), "Discretization is only implemented for Box spaces." 12 | assert space.is_bounded() 13 | self.space = space 14 | self.n_bins = n_bins 15 | 16 | # initialize bins 17 | assert n_bins > 0, "Discretizer requires n_bins > 0" 18 | n_elements = 1 19 | tol = 1e-8 20 | self.dim = len(self.space.low) 21 | n_elements = n_bins**self.dim 22 | self._bins = [] 23 | self._open_bins = [] 24 | for dd in range(self.dim): 25 | range_dd = self.space.high[dd] - self.space.low[dd] 26 | epsilon = range_dd / n_bins 27 | bins_dd = [] 28 | for bb in range(n_bins + 1): 29 | val = self.space.low[dd] + epsilon * bb 30 | bins_dd.append(val) 31 | self._open_bins.append(tuple(bins_dd[1:])) 32 | bins_dd[-1] += tol # "close" the last interval 33 | self._bins.append(tuple(bins_dd)) 34 | 35 | # set observation space 36 | self.discrete_space = Discrete(n_elements) 37 | 38 | # List of discretized elements 39 | self.discretized_elements = np.zeros((self.dim, n_elements)) 40 | for ii in range(n_elements): 41 | self.discretized_elements[:, ii] = self.get_coordinates(ii, False) 42 | 43 | def discretize(self, coordinates): 44 | return binary_search_nd(coordinates, self._bins) 45 | 46 | def get_coordinates(self, flat_index, randomize=False): 47 | assert self.discrete_space.contains(flat_index), "invalid flat_index" 48 | # get multi-index 49 | index = unravel_index_uniform_bin(flat_index, self.dim, self.n_bins) 50 | 51 | # get coordinates 52 | coordinates = np.zeros(self.dim) 53 | for dd in range(self.dim): 54 | coordinates[dd] = self._bins[dd][index[dd]] 55 | if randomize: 56 | range_dd = self.space.high[dd] - self.space.low[dd] 57 | epsilon = range_dd / self.n_bins 58 | coordinates[dd] += epsilon * self.space.rng.uniform() 59 | return coordinates 60 | -------------------------------------------------------------------------------- /rlberry/utils/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/rlberry/utils/tests/__init__.py -------------------------------------------------------------------------------- /rlberry/utils/tests/test_binsearch.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from rlberry.utils.binsearch import binary_search_nd 5 | from rlberry.utils.binsearch import unravel_index_uniform_bin 6 | 7 | 8 | def test_binary_search_nd(): 9 | bin1 = 
np.array([0.0, 1.0, 2.0, 3.0]) # 3 intervals 10 | bin2 = np.array([1.0, 2.0, 3.0, 4.0]) # 3 intervals 11 | bin3 = np.array([2.0, 3.0, 4.0, 5.0, 6.0]) # 4 intervals 12 | 13 | bins = [bin1, bin2, bin3] 14 | 15 | vec1 = np.array([0.0, 1.0, 2.0]) 16 | vec2 = np.array([2.9, 3.9, 5.9]) 17 | vec3 = np.array([1.5, 2.5, 2.5]) 18 | vec4 = np.array([1.5, 2.5, 2.5]) 19 | 20 | # index = i + Ni * j + Ni * Nj * k 21 | assert binary_search_nd(vec1, bins) == 0 22 | assert binary_search_nd(vec2, bins) == 2 + 3 * 2 + 3 * 3 * 3 23 | assert binary_search_nd(vec3, bins) == 1 + 3 * 1 + 3 * 3 * 0 24 | 25 | 26 | @pytest.mark.parametrize( 27 | "i, j, k, N", [(0, 0, 0, 5), (0, 1, 2, 5), (4, 3, 2, 5), (4, 4, 4, 5)] 28 | ) 29 | def test_unravel_index_uniform_bin(i, j, k, N): 30 | # index = i + N * j + N * N * k 31 | dim = 3 32 | flat_index = i + N * j + N * N * k 33 | assert (i, j, k) == unravel_index_uniform_bin(flat_index, dim, N) 34 | -------------------------------------------------------------------------------- /rlberry/utils/tests/test_writer.py: -------------------------------------------------------------------------------- 1 | import time 2 | from rlberry_research.envs import GridWorld 3 | from rlberry.agents import AgentWithSimplePolicy 4 | from rlberry.manager import ExperimentManager 5 | 6 | 7 | class DummyAgent(AgentWithSimplePolicy): 8 | def __init__(self, env, hyperparameter1=0, hyperparameter2=0, **kwargs): 9 | AgentWithSimplePolicy.__init__(self, env, **kwargs) 10 | self.name = "DummyAgent" 11 | self.fitted = False 12 | self.hyperparameter1 = hyperparameter1 13 | self.hyperparameter2 = hyperparameter2 14 | 15 | self.total_budget = 0.0 16 | 17 | def fit(self, budget, **kwargs): 18 | del kwargs 19 | self.fitted = True 20 | self.total_budget += budget 21 | for ii in range(budget): 22 | if self.writer is not None: 23 | self.writer.add_scalar("a", ii, ii) 24 | scalar_dict = dict(multi1=1, multi2=2) 25 | self.writer.add_scalars("multi_scalar_test", scalar_dict) 26 | time.sleep(1) 27 | 28 | return None 29 | 30 | def policy(self, observation): 31 | return 0 32 | 33 | 34 | def test_myoutput(capsys): # or use "capfd" for fd-level 35 | env_ctor = GridWorld 36 | env_kwargs = dict() 37 | 38 | env = env_ctor(**env_kwargs) 39 | xp_manager = ExperimentManager( 40 | DummyAgent, 41 | (env_ctor, env_kwargs), 42 | fit_budget=3, 43 | n_fit=1, 44 | default_writer_kwargs={"log_interval": 1}, 45 | ) 46 | budget_size = 22 47 | xp_manager.fit(budget=budget_size) 48 | 49 | assert xp_manager.agent_handlers[0].writer.summary_writer == None 50 | assert list(xp_manager.agent_handlers[0].writer.read_tag_value("a")) == list( 51 | range(budget_size) 52 | ) 53 | assert xp_manager.agent_handlers[0].writer.read_first_tag_value("a") == 0 54 | assert ( 55 | xp_manager.agent_handlers[0].writer.read_last_tag_value("a") == budget_size - 1 56 | ) # start at 0 57 | 58 | captured = capsys.readouterr() 59 | # test that what is written to stderr is longer than 50 char, 60 | assert ( 61 | len(captured.err) + len(captured.out) > 50 62 | ), "the logging did not print the info to stderr" 63 | -------------------------------------------------------------------------------- /rlberry/utils/torch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import shutil 4 | from subprocess import check_output, run, PIPE 5 | import numpy as np 6 | import torch 7 | 8 | 9 | import rlberry 10 | 11 | logger = rlberry.logger 12 | 13 | 14 | def get_gpu_memory_map(): 15 | result = check_output( 16 | 
["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"] 17 | ) 18 | return [int(x) for x in result.split()] 19 | 20 | 21 | def least_used_device(): 22 | """Get the GPU device with most available memory.""" 23 | if not torch.cuda.is_available(): 24 | raise RuntimeError("cuda unavailable") 25 | 26 | if shutil.which("nvidia-smi") is None: 27 | raise RuntimeError( 28 | "nvidia-smi unavailable: \ 29 | cannot select device with most least memory used." 30 | ) 31 | 32 | memory_map = get_gpu_memory_map() 33 | device_id = np.argmin(memory_map) 34 | logger.debug( 35 | f"Choosing GPU device: {device_id}, " f"memory used: {memory_map[device_id]}" 36 | ) 37 | return torch.device("cuda:{}".format(device_id)) 38 | 39 | 40 | def choose_device(preferred_device, default_device="cpu"): 41 | """Choose torch device, use default if choice is not available. 42 | 43 | Parameters 44 | ---------- 45 | preferred_device: str 46 | Torch device to be used (if available), e.g. "cpu", "cuda:0", "cuda:best". 47 | If "cuda:best", returns the least used device in the machine. 48 | default_device: str, default = "cpu" 49 | Default device if preferred_device is not available. 50 | """ 51 | if preferred_device == "cuda:best": 52 | try: 53 | preferred_device = least_used_device() 54 | except RuntimeError: 55 | logger.debug( 56 | f"Could not find least used device (nvidia-smi might be missing), use cuda:0 instead" 57 | ) 58 | if torch.cuda.is_available(): 59 | return choose_device("cuda:0") 60 | else: 61 | return choose_device("cpu") 62 | try: 63 | torch.zeros((1,), device=preferred_device) # Test availability 64 | except (RuntimeError, AssertionError) as e: 65 | logger.debug( 66 | f"Preferred device {preferred_device} unavailable ({e})." 67 | f"Switching to default {default_device}" 68 | ) 69 | return default_device 70 | return preferred_device 71 | 72 | 73 | def get_memory(pid=None): 74 | if not pid: 75 | pid = os.getpid() 76 | command = "nvidia-smi" 77 | result = run( 78 | command, stdout=PIPE, stderr=PIPE, universal_newlines=True, shell=True 79 | ).stdout 80 | m = re.findall( 81 | r"\| *[0-9] *" + str(pid) + r" *C *.*python.*? +([0-9]+).*\|", 82 | result, 83 | re.MULTILINE, 84 | ) 85 | return [int(mem) for mem in m] 86 | -------------------------------------------------------------------------------- /rlberry/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from .discretize_state import DiscretizeStateWrapper 2 | from .rescale_reward import RescaleRewardWrapper 3 | from .writer_utils import WriterWrapper 4 | from .discrete2onehot import DiscreteToOneHotWrapper 5 | -------------------------------------------------------------------------------- /rlberry/wrappers/autoreset.py: -------------------------------------------------------------------------------- 1 | from rlberry.envs import Wrapper 2 | 3 | 4 | class AutoResetWrapper(Wrapper): 5 | """ 6 | Auto reset the environment after "horizon" steps have passed. 
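    Usage sketch (hedged; any environment accepted by rlberry's Wrapper should work,
    the environment name below is illustrative):

        import gymnasium as gym

        env = AutoResetWrapper(gym.make("CartPole-v1"), horizon=200)
        obs, info = env.reset()
        # after `horizon` calls to step(), the wrapper resets the underlying
        # environment and returns the new initial observation with terminated=True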
7 | """ 8 | 9 | def __init__(self, env, horizon): 10 | """ 11 | Parameters 12 | ---------- 13 | horizon: int 14 | """ 15 | Wrapper.__init__(self, env) 16 | self.horizon = horizon 17 | assert self.horizon >= 1 18 | self.current_step = 0 19 | 20 | def reset(self, seed=None, options=None): 21 | self.current_step = 0 22 | return self.env.reset(seed=seed, options=options) 23 | 24 | def step(self, action): 25 | observation, reward, terminated, truncated, info = self.env.step(action) 26 | self.current_step += 1 27 | # At H, always return to the initial state. 28 | # Also, set done to True. 29 | if self.current_step == self.horizon: 30 | self.current_step = 0 31 | observation, info = self.env.reset() 32 | terminated = True 33 | truncated = False 34 | return observation, reward, terminated, truncated, info 35 | -------------------------------------------------------------------------------- /rlberry/wrappers/discrete2onehot.py: -------------------------------------------------------------------------------- 1 | from rlberry.spaces import Box, Discrete 2 | from rlberry.envs import Wrapper 3 | import numpy as np 4 | 5 | 6 | class DiscreteToOneHotWrapper(Wrapper): 7 | """Converts observation spaces from Discrete to Box via one-hot encoding.""" 8 | 9 | def __init__(self, env): 10 | Wrapper.__init__(self, env, wrap_spaces=True) 11 | obs_space = self.env.observation_space 12 | assert isinstance(obs_space, Discrete) 13 | self.observation_space = Box( 14 | low=0.0, high=1.0, shape=(obs_space.n,), dtype=np.uint32 15 | ) 16 | 17 | def process_obs(self, obs): 18 | one_hot_obs = np.zeros(self.env.observation_space.n, dtype=np.uint32) 19 | one_hot_obs[obs] = 1.0 20 | return one_hot_obs 21 | 22 | def reset(self, seed=None, options=None): 23 | obs, info = self.env.reset(seed=seed, options=options) 24 | return self.process_obs(obs), info 25 | 26 | def step(self, action): 27 | observation, reward, terminated, truncated, info = self.env.step(action) 28 | observation = self.process_obs(observation) 29 | return observation, reward, terminated, truncated, info 30 | -------------------------------------------------------------------------------- /rlberry/wrappers/gym_utils.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from gymnasium.utils.step_api_compatibility import step_api_compatibility 3 | from rlberry.spaces import Discrete 4 | from rlberry.spaces import Box 5 | from rlberry.spaces import Tuple 6 | from rlberry.spaces import MultiDiscrete 7 | from rlberry.spaces import MultiBinary 8 | from rlberry.spaces import Dict 9 | 10 | from rlberry.envs import Wrapper 11 | 12 | 13 | def convert_space_from_gym(gym_space): 14 | if isinstance(gym_space, gym.spaces.Discrete): 15 | return Discrete(gym_space.n) 16 | # 17 | # 18 | elif isinstance(gym_space, gym.spaces.Box): 19 | return Box(gym_space.low, gym_space.high, gym_space.shape, gym_space.dtype) 20 | # 21 | # 22 | elif isinstance(gym_space, gym.spaces.Tuple): 23 | spaces = [] 24 | for sp in gym_space.spaces: 25 | spaces.append(convert_space_from_gym(sp)) 26 | return Tuple(spaces) 27 | # 28 | # 29 | elif isinstance(gym_space, gym.spaces.MultiDiscrete): 30 | return MultiDiscrete(gym_space.nvec) 31 | # 32 | # 33 | elif isinstance(gym_space, gym.spaces.MultiBinary): 34 | return MultiBinary(gym_space.n) 35 | # 36 | # 37 | elif isinstance(gym_space, gym.spaces.Dict): 38 | spaces = {} 39 | for key in gym_space.spaces: 40 | spaces[key] = convert_space_from_gym(gym_space[key]) 41 | return Dict(spaces) 42 | else: 
43 | raise ValueError("Unknown space class: {}".format(type(gym_space))) 44 | 45 | 46 | class OldGymCompatibilityWrapper(Wrapper): 47 | """ 48 | Allow to use old gym env (V0.21) with rlberry (gymnasium). 49 | (for basic use only) 50 | """ 51 | 52 | def __init__(self, env): 53 | Wrapper.__init__(self, env) 54 | 55 | def reset(self, seed=None, options=None): 56 | if seed: 57 | self.env.reseed(seed) 58 | observation = self.env.reset() 59 | return observation, {} 60 | 61 | def step(self, action): 62 | obs, rewards, terminated, truncated, info = step_api_compatibility( 63 | self.env.step(action), output_truncation_bool=True 64 | ) 65 | return obs, rewards, terminated, truncated, info 66 | -------------------------------------------------------------------------------- /rlberry/wrappers/rescale_reward.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from rlberry.envs import Wrapper 3 | 4 | 5 | class RescaleRewardWrapper(Wrapper): 6 | """ 7 | Rescale the reward function to a bounded range. 8 | 9 | Parameters 10 | ---------- 11 | reward_range: tuple (double, double) 12 | tuple with the desired reward range, which needs to be bounded. 13 | """ 14 | 15 | def __init__(self, env, reward_range): 16 | Wrapper.__init__(self, env) 17 | self.reward_range = reward_range 18 | assert reward_range[0] < reward_range[1] 19 | assert reward_range[0] > -np.inf and reward_range[1] < np.inf 20 | 21 | def _linear_rescaling(self, x, x0, x1, u0, u1): 22 | """ 23 | For x a value in [x0, x1], maps x linearly to the interval [u0, u1]. 24 | """ 25 | a = (u1 - u0) / (x1 - x0) 26 | b = (x1 * u0 - x0 * u1) / (x1 - x0) 27 | return a * x + b 28 | 29 | def _rescale(self, reward): 30 | x0, x1 = self.env.reward_range 31 | u0, u1 = self.reward_range 32 | # bounded reward 33 | if x0 > -np.inf and x1 < np.inf: 34 | return self._linear_rescaling(reward, x0, x1, u0, u1) 35 | # unbounded 36 | elif x0 > -np.inf and x1 == np.inf: 37 | x = reward - x0 # [0, infty] 38 | x = 2.0 / (1.0 + np.exp(-x)) - 1.0 # [0, 1] 39 | return self._linear_rescaling(x, 0.0, 1.0, u0, u1) 40 | # unbouded below 41 | elif x0 == -np.inf and x1 < np.inf: 42 | x = reward - x1 # [-infty, 0] 43 | x = 2.0 / (1.0 + np.exp(-x)) # [0, 1] 44 | return self._linear_rescaling(x, 0.0, 1.0, u0, u1) 45 | # unbounded 46 | else: 47 | x = 1.0 / (1.0 + np.exp(-reward)) # [0, 1] 48 | return self._linear_rescaling(x, 0.0, 1.0, u0, u1) 49 | 50 | def step(self, action): 51 | observation, reward, terminated, truncated, info = self.env.step(action) 52 | rescaled_reward = self._rescale(reward) 53 | return observation, rescaled_reward, terminated, truncated, info 54 | 55 | def sample(self, state, action): 56 | observation, reward, terminated, truncated, info = self.env.sample( 57 | state, action 58 | ) 59 | rescaled_reward = self._rescale(reward) 60 | return observation, rescaled_reward, terminated, truncated, info 61 | -------------------------------------------------------------------------------- /rlberry/wrappers/tests/__init__.py: -------------------------------------------------------------------------------- 1 | from .old_env import ( 2 | old_acrobot, 3 | old_twinrooms, 4 | old_six_room, 5 | old_apple_gold, 6 | old_ball2d, 7 | old_finite_mdp, 8 | old_four_room, 9 | old_gridworld, 10 | old_mountain_car, 11 | old_nroom, 12 | old_pball, 13 | old_pendulum, 14 | ) 15 | -------------------------------------------------------------------------------- /rlberry/wrappers/tests/old_env/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .old_acrobot import Old_Acrobot 2 | from .old_apple_gold import Old_AppleGold 3 | from .old_four_room import Old_FourRoom 4 | from .old_gridworld import Old_GridWorld 5 | from .old_mountain_car import Old_MountainCar 6 | from .old_nroom import Old_NRoom 7 | from .old_pendulum import Old_Pendulum 8 | from .old_pball import Old_PBall2D, Old_SimplePBallND 9 | from .old_six_room import Old_SixRoom 10 | from .old_twinrooms import Old_TwinRooms 11 | -------------------------------------------------------------------------------- /rlberry/wrappers/tests/test_basewrapper.py: -------------------------------------------------------------------------------- 1 | from rlberry.envs.interface import Model 2 | from rlberry.envs import Wrapper 3 | from rlberry_research.envs import GridWorld 4 | import gymnasium as gym 5 | 6 | 7 | def test_wrapper(): 8 | env = GridWorld() 9 | wrapped = Wrapper(env) 10 | assert isinstance(wrapped, Model) 11 | assert wrapped.is_online() 12 | assert wrapped.is_generative() 13 | 14 | # calling some functions 15 | wrapped.reset() 16 | wrapped.step(wrapped.action_space.sample()) 17 | wrapped.sample(wrapped.observation_space.sample(), wrapped.action_space.sample()) 18 | 19 | 20 | def test_gym_wrapper(): 21 | gym_env = gym.make("Acrobot-v1") 22 | wrapped = Wrapper(gym_env) 23 | assert isinstance(wrapped, Model) 24 | assert wrapped.is_online() 25 | assert not wrapped.is_generative() 26 | 27 | wrapped.reseed() 28 | 29 | # calling some gym functions 30 | wrapped.close() 31 | wrapped.seed() 32 | -------------------------------------------------------------------------------- /rlberry/wrappers/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from rlberry.wrappers.utils import get_base_env 3 | 4 | 5 | def test_get_base_env(): 6 | """Test that the utils function 'get_base_env' return the wrapped env without the wrappers""" 7 | 8 | from rlberry.envs.basewrapper import Wrapper 9 | from stable_baselines3.common.monitor import Monitor 10 | 11 | from stable_baselines3.common.atari_wrappers import ( # isort:skip 12 | FireResetEnv, 13 | MaxAndSkipEnv, 14 | NoopResetEnv, 15 | NoopResetEnv, 16 | StickyActionEnv, 17 | ) 18 | 19 | env = gym.make("ALE/Breakout-v5") 20 | original_env = env 21 | 22 | # add wrappers 23 | env = Wrapper(env) 24 | env = Monitor(env) 25 | env = StickyActionEnv(env, 0.2) 26 | env = NoopResetEnv(env, noop_max=2) 27 | env = MaxAndSkipEnv(env, skip=4) 28 | env = FireResetEnv(env) 29 | env = gym.wrappers.GrayscaleObservation(env) 30 | env = gym.wrappers.FrameStackObservation(env, 8) 31 | assert original_env != env 32 | 33 | # use the tool 34 | unwrapped_env = get_base_env(env) 35 | 36 | # test the result 37 | assert unwrapped_env != env 38 | assert isinstance(unwrapped_env, gym.Env) 39 | assert unwrapped_env == get_base_env(original_env) 40 | -------------------------------------------------------------------------------- /rlberry/wrappers/tests/test_writer_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from rlberry_scool.envs import GridWorld 4 | 5 | from rlberry_scool.agents import UCBVIAgent 6 | 7 | 8 | @pytest.mark.parametrize( 9 | "write_scalar", [None, "action", "reward", "action_and_reward"] 10 | ) 11 | def test_wrapper(write_scalar): 12 | """Test that the wrapper record data""" 13 | 14 | class MyAgent(UCBVIAgent): 15 | def 
__init__(self, env, **kwargs): 16 | UCBVIAgent.__init__(self, env, writer_extra=write_scalar, **kwargs) 17 | 18 | env = GridWorld() 19 | agent = MyAgent(env) 20 | agent.fit(budget=10) 21 | assert len(agent.writer.data) > 0 22 | 23 | 24 | def test_invalid_wrapper_name(): 25 | """Test that warning message is thrown when unsupported write_scalar.""" 26 | 27 | class MyAgent(UCBVIAgent): 28 | def __init__(self, env, **kwargs): 29 | UCBVIAgent.__init__(self, env, writer_extra="invalid", **kwargs) 30 | 31 | msg = "write_scalar invalid is not known" 32 | env = GridWorld() 33 | agent = MyAgent(env) 34 | with pytest.raises(ValueError, match=msg): 35 | agent.fit(budget=10) 36 | -------------------------------------------------------------------------------- /rlberry/wrappers/utils.py: -------------------------------------------------------------------------------- 1 | def get_base_env(env): 2 | """Traverse the wrappers to find the base environment.""" 3 | while hasattr(env, "env"): 4 | env = env.env 5 | return env 6 | -------------------------------------------------------------------------------- /rlberry/wrappers/writer_utils.py: -------------------------------------------------------------------------------- 1 | from rlberry.envs import Wrapper 2 | 3 | 4 | class WriterWrapper(Wrapper): 5 | """ 6 | Wrapper for environment to automatically record reward or action in writer. 7 | 8 | Parameters 9 | ---------- 10 | env : gymnasium.Env or tuple (constructor, kwargs) 11 | Environment used to fit the agent. 12 | 13 | writer : object, default: None 14 | Writer object (e.g. tensorboard SummaryWriter). 15 | 16 | write_scalar : string in {"reward", "action", "action_and_reward"}, 17 | default = "reward" 18 | Scalar that will be recorded in the writer. 19 | 20 | """ 21 | 22 | def __init__(self, env, writer, write_scalar="reward"): 23 | Wrapper.__init__(self, env) 24 | self.writer = writer 25 | self.write_scalar = write_scalar 26 | self.iteration_ = 0 27 | 28 | def step(self, action): 29 | observation, reward, terminated, truncated, info = self.env.step(action) 30 | 31 | self.iteration_ += 1 32 | if self.write_scalar == "reward": 33 | self.writer.add_scalar("reward", reward, self.iteration_) 34 | elif self.write_scalar == "action": 35 | self.writer.add_scalar("action", action, self.iteration_) 36 | elif self.write_scalar == "action_and_reward": 37 | self.writer.add_scalar("reward", reward, self.iteration_) 38 | self.writer.add_scalar("action", action, self.iteration_) 39 | else: 40 | raise ValueError("write_scalar %s is not known" % (self.write_scalar)) 41 | 42 | return observation, reward, terminated, truncated, info 43 | -------------------------------------------------------------------------------- /scripts/apptainer_for_tests/README.md: -------------------------------------------------------------------------------- 1 | 2 | To test rlberry, you can use this script to create a container that install the latest version, run the tests, and send the result by email. 3 | (or you can check inside the .sh file to only get the part you need) 4 | 5 | :warning: **WARNING** :warning: : In both files, you have to update the paths and names 6 | 7 | ## .def 8 | Scripts to build your apptainer. 
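For reference, the commands used by the .sh script below are roughly `apptainer build --fakeroot --force <image>.sif <recipe>.def` to build an image from a recipe, then `apptainer run --fakeroot <image>.sif` to execute its `%runscript` section (image and recipe names here are placeholders to adapt).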
9 | 2 scripts : 10 | - 1 with the "current" version of python (from ubuntu:last) 11 | - 1 with a specific version of python to choose 12 | 13 | ## .sh 14 | Script to run your apptainer and send the report 15 | use chmod +x [name].sh to make it executable 16 | 17 | To run this script you need to install "mailutils" first (to send the report by email) 18 | -------------------------------------------------------------------------------- /scripts/apptainer_for_tests/monthly_test_base.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Set the email recipient and subject 4 | recipient="email1@inria.fr,email2@inria.fr" 5 | 6 | # Build, then run, the Apptainer container and capture the test results 7 | cd [Path_to_your_apptainer_folder] 8 | #build the apptainer from the .def file into an .sif image 9 | apptainer build --fakeroot --force rlberry_apptainer_base.sif rlberry_apptainer_base.def 10 | #Run the "runscript" section, and export the result inside a file (named previously) 11 | apptainer run --fakeroot --overlay my_overlay/ rlberry_apptainer.sif > "$attachment" 12 | 13 | 14 | # Send the test results by email 15 | exit_code1=$(cat [path]/exit_code1.txt) # Read the exit code from the file (tests) 16 | exit_code2=$(cat [path]/exit_code2.txt) # Read the exit code from the file (long tests) 17 | exit_code3=$(cat [path]/exit_code3.txt) # Read the exit code from the file (doc test) 18 | 19 | if [ $exit_code -eq 0 ]; then 20 | # Initialization when the exit code is 0 (success) 21 | subject="Rlberry : Success Monthly Test Report" 22 | core_message="Success. Please find attached the monthly test reports." 23 | else 24 | # Initialization when the exit code is not 0 (failed) 25 | subject="Rlberry : Failed Monthly Test Report" 26 | core_message="Failed. Please find attached the monthly test reports." 27 | fi 28 | 29 | 30 | echo "$core_message" | mail -s "$subject" -A "test_result.txt" -A "long_test_result.txt" -A "doc_test_result.txt" -A"lib_versions.txt" "$recipient" -aFrom:"Rlberry_Monthly_tests" 31 | -------------------------------------------------------------------------------- /scripts/apptainer_for_tests/rlberry_apptainer__specific_python.def: -------------------------------------------------------------------------------- 1 | Bootstrap: docker 2 | From: ubuntu:latest 3 | 4 | #script for the build 5 | %post -c /bin/bash 6 | 7 | #get the last Ubuntu Update, and add the desdsnakes ppa to acces other python version 8 | apt-get update \ 9 | && apt-get upgrade -y 10 | apt-get install -y software-properties-common 11 | add-apt-repository ppa:deadsnakes/ppa -y 12 | apt-get update 13 | 14 | # Install python, and graphic and basic libs. 
Don't forget to change [version] by the python you want (python[version] > python3.11), then set the new python as "main" python 15 | DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get install -y python[version] python[version]-dev python[version]-venv python3-pip git ffmpeg libsm6 libxext6 libsdl2-dev xvfb x11-xkb-utils --no-install-recommends 16 | update-alternatives --install /usr/bin/python3 python3 /usr/bin/python[version] 1 17 | pip3 install --upgrade pip 18 | 19 | 20 | #Remove the old tmp folder if it exist, then install rlberry 21 | if [ -d /tmp/rlberry_test_dir[version] ]; then /bin/rm -r /tmp/rlberry_test_dir[version]; fi 22 | git clone https://github.com/rlberry-py/rlberry.git /tmp/rlberry_test_dir[version] 23 | 24 | #Install all the lib we need to run rlberry and its tests 25 | pip3 install rlberry[torch_agents] opencv-python pytest pytest-xvfb pytest-xprocess tensorboard #--break-system-packages 26 | pip3 install gymnasium[other] 27 | 28 | #Environmment variable, Don't forget to change [version] 29 | %environment 30 | export LC_ALL=C 31 | export PATH="/usr/bin/python[version]:$PATH" 32 | 33 | #script that will be executed with the "run" command : run the tests in rlberry, then export the exit code inside a text file 34 | %runscript 35 | cd /tmp/rlberry_test_dir[version] && \ 36 | pytest rlberry 37 | echo $? > [path]/exit_code.txt 38 | -------------------------------------------------------------------------------- /scripts/apptainer_for_tests/rlberry_apptainer_base.def: -------------------------------------------------------------------------------- 1 | Bootstrap: docker 2 | From: ubuntu:latest 3 | 4 | %post -c /bin/bash 5 | 6 | #get the last Ubuntu Update 7 | apt-get update \ 8 | && apt-get upgrade -y 9 | 10 | echo "export PS1=Apptainer-\[\e]0;\u@\h: \w\a\]${debian_chroot:+($debian_chroot)}\[\033[01;32m\]\u@\h\[\033[00m\]:\[\033[01;34m\]\w\[\033[00m\]\$" >> /etc/profile 11 | 12 | # Install python, and graphic and basic libs. 13 | apt-get install -y software-properties-common python3-pip git ffmpeg libsm6 libxext6 libsdl2-dev xvfb x11-xkb-utils libblas-dev liblapack-dev 14 | pip3 install --upgrade pip setuptools wheel 15 | 16 | # remove the old folder 17 | if [ -d /tmp/rlberry_test_dir ]; then /bin/rm -r /tmp/rlberry_test_dir; fi 18 | if [ -d /tmp/rlberry-research_test_dir ]; then /bin/rm -r /tmp/rlberry-research_test_dir; fi 19 | 20 | # get all the git repos to do their tests 21 | git clone https://github.com/rlberry-py/rlberry.git /tmp/rlberry_test_dir 22 | 23 | cd /tmp/rlberry_test_dir/ 24 | git fetch --tags 25 | latestTag=$(git describe --tags `git rev-list --tags --max-count=1`) 26 | git checkout $latestTag 27 | 28 | 29 | git clone https://github.com/rlberry-py/rlberry-research.git /tmp/rlberry-research_test_dir 30 | 31 | cd /tmp/rlberry-research_test_dir/ 32 | git fetch --tags 33 | latestTagResearch=$(git describe --tags `git rev-list --tags --max-count=1`) 34 | git checkout $latestTagResearch 35 | 36 | # install rlberry, rlberry-scool and rlberry-research 37 | pip3 install rlberry[torch,doc,extras] rlberry-scool opencv-python pytest pytest-xvfb pytest-xprocess tensorboard #--break-system-packages 38 | pip3 install git+https://github.com/rlberry-py/rlberry-research.git 39 | 40 | 41 | 42 | %environment 43 | export LC_ALL=C 44 | 45 | %runscript 46 | # info about current versions 47 | pip list > [path]/lib_versions.txt 48 | 49 | #run tests 50 | (cd /tmp/rlberry_test_dir && \ 51 | date && \ 52 | pytest rlberry && \ 53 | date) > [path]/test_result.txt 54 | echo $? 
> [path]/exit_code1.txt 55 | 56 | #run long tests 57 | (cd /tmp/rlberry-research_test_dir && \ 58 | date && \ 59 | pytest long_tests/rl_agent/ltest_mbqvi_applegold.py long_tests/torch_agent/ltest_a2c_cartpole.py long_tests/torch_agent/ltest_ctn_ppo_a2c_pendulum.py long_tests/torch_agent/ltest_dqn_montaincar.py && \ 60 | date) > [path]/long_test_result.txt 61 | #pytest --ignore=long_tests/torch_agent/ltest_dqn_vs_mdqn_acrobot.py long_tests/**/*.py 62 | echo $? > [path]exit_code2.txt 63 | 64 | #run doc test 65 | (cd /tmp/rlberry_test_dir/docs/ && \ 66 | date && \ 67 | ./markdown_to_py.sh && \ 68 | for f in python_scripts/*.py; do python3 $f ; done && \ 69 | date) > [path]/doc_test_result.txt 70 | echo $? > [path]/exit_code3.txt 71 | -------------------------------------------------------------------------------- /scripts/build_docs.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../docs/ 4 | sphinx-apidoc -o source/ ../rlberry 5 | make html 6 | cd .. 7 | 8 | # Useful: https://samnicholls.net/2016/06/15/how-to-sphinx-readthedocs/ 9 | -------------------------------------------------------------------------------- /scripts/conda_env_setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd $CONDA_PREFIX 4 | mkdir -p ./etc/conda/activate.d 5 | mkdir -p ./etc/conda/deactivate.d 6 | touch ./etc/conda/activate.d/env_vars.sh 7 | touch ./etc/conda/deactivate.d/env_vars.sh 8 | 9 | echo '#!/bin/sh' > ./etc/conda/activate.d/env_vars.sh 10 | echo >> ./etc/conda/activate.d/env_vars.sh 11 | echo "export LD_LIBRARY_PATH=$CONDA_PREFIX/lib" >> ./etc/conda/activate.d/env_vars.sh 12 | 13 | echo '#!/bin/sh' > ./etc/conda/deactivate.d/env_vars.sh 14 | echo >> ./etc/conda/deactivate.d/env_vars.sh 15 | echo "unset LD_LIBRARY_PATH" >> ./etc/conda/deactivate.d/env_vars.sh 16 | 17 | echo "Contents of $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh:" 18 | cat ./etc/conda/activate.d/env_vars.sh 19 | echo "" 20 | 21 | echo "Contents of $CONDA_PREFIX/etc/conda/deactivate.d/env_vars.sh:" 22 | cat ./etc/conda/deactivate.d/env_vars.sh 23 | echo "" 24 | -------------------------------------------------------------------------------- /scripts/construct_video_examples.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Script used to construct the videos for the examples that output videos. 4 | # Please use a script name that begins with video_plot as with the other examples 5 | # and in the script, there should be a line to save the video in the right place 6 | # and a line to load the video in the headers. Look at existing examples for 7 | # the correct syntax. Be careful that you must remove the _build folder before 8 | # recompiling the doc when a video has been updated/added. 9 | 10 | 11 | for f in ../examples/video_plot*.py ; 12 | do 13 | # construct the mp4 14 | python $f 15 | name=$(basename $f) 16 | # make a thumbnail. Warning : the video should have the same name as the python script 17 | # i.e. video_plot_SOMETHING.mp4 to be detected 18 | ffmpeg -i ../docs/_video/${name%%.py}.mp4 -vframes 1 -f image2 ../docs/thumbnails/${name%%.py}.jpg 19 | done 20 | -------------------------------------------------------------------------------- /scripts/full_install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Install everything! 
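# NOTE (assumption): unlike run_testscov.sh and build_docs.sh, this script does not
# `cd ..` first, so it presumably has to be run from the repository root for the
# editable installs below to resolve the package.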
4 | 5 | pip install -e .[default] 6 | pip install -e .[jax_agents] 7 | pip install -e .[torch_agents] 8 | 9 | pip install pytest 10 | pip install pytest-cov 11 | conda install -c conda-forge jupyterlab 12 | -------------------------------------------------------------------------------- /scripts/run_testscov.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # disable JIT to get complete coverage report 4 | export NUMBA_DISABLE_JIT=1 5 | 6 | # run pytest 7 | cd .. 8 | pytest --cov=rlberry --cov-report html:cov_html 9 | --------------------------------------------------------------------------------
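Usage sketch for run_testscov.sh (hedged): run it from the scripts/ directory (it performs `cd ..` itself before invoking pytest), then open `cov_html/index.html` to browse the HTML coverage report produced by `--cov-report html:cov_html`.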