├── .github
├── ISSUE_TEMPLATE
│ ├── bug_report.md
│ ├── feature_request.md
│ └── other_issue.md
├── PULL_REQUEST_TEMPLATE.md
├── dependabot.yml
└── workflows
│ ├── doc_dev.yml
│ ├── doc_stable.yml
│ ├── post_merge.yml
│ ├── preview.yml
│ └── ready_for_CI.yml
├── .gitignore
├── .pre-commit-config.yaml
├── CITATION.cff
├── CODE_OF_CONDUCT.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── assets
├── logo_avatar.png
├── logo_square.svg
└── logo_wide.svg
├── azure-pipelines.yml
├── codecov.yml
├── docs
├── Makefile
├── _video
│ ├── example_plot_atari_atlantis_vectorized_ppo.mp4
│ ├── example_plot_atari_breakout_vectorized_ppo.mp4
│ ├── user_guide_video
│ │ ├── _agent_page_CartPole.mp4
│ │ ├── _agent_page_chain1.mp4
│ │ ├── _agent_page_chain2.mp4
│ │ ├── _agent_page_frozenLake.mp4
│ │ ├── _env_page_Breakout.mp4
│ │ ├── _env_page_MountainCar.mp4
│ │ ├── _env_page_chain.mp4
│ │ └── _experimentManager_page_CartPole.mp4
│ ├── video_chain_quickstart.mp4
│ ├── video_plot_a2c.mp4
│ ├── video_plot_acrobot.mp4
│ ├── video_plot_apple_gold.mp4
│ ├── video_plot_atari_freeway.mp4
│ ├── video_plot_chain.mp4
│ ├── video_plot_dqn.mp4
│ ├── video_plot_gridworld.mp4
│ ├── video_plot_mbqvi.mp4
│ ├── video_plot_mdqn.mp4
│ ├── video_plot_montain_car.mp4
│ ├── video_plot_old_gym_acrobot.mp4
│ ├── video_plot_pball.mp4
│ ├── video_plot_ppo.mp4
│ ├── video_plot_rooms.mp4
│ ├── video_plot_rs_kernel_ucbvi.mp4
│ ├── video_plot_rsucbvi.mp4
│ ├── video_plot_springcartpole.mp4
│ ├── video_plot_twinrooms.mp4
│ └── video_plot_vi.mp4
├── about.rst
├── api.rst
├── basics
│ ├── DeepRLTutorial
│ │ ├── TutorialDeepRL.md
│ │ ├── output_10_3.png
│ │ ├── output_5_3.png
│ │ ├── output_6_3.png
│ │ └── output_9_3.png
│ ├── agent_manager_diagram.png
│ ├── comparison.md
│ ├── create_agent.rst
│ ├── evaluate_agent.rst
│ ├── experiment_setup.rst
│ ├── multiprocess.rst
│ ├── quick_start_rl
│ │ ├── Figure_1.png
│ │ ├── Figure_2.png
│ │ ├── Figure_3.png
│ │ ├── experiment_manager_diagram.png
│ │ ├── gif_chain.gif
│ │ ├── quickstart.md
│ │ └── video_chain.mp4
│ ├── rlberry how to.rst
│ ├── seeding.rst
│ └── userguide
│ │ ├── adastop.md
│ │ ├── adastop_boxplots.png
│ │ ├── agent.md
│ │ ├── entropy_loss.png
│ │ ├── environment.md
│ │ ├── example_eval.png
│ │ ├── expManager_multieval.png
│ │ ├── experimentManager.md
│ │ ├── export_training_data.md
│ │ ├── external_lib.md
│ │ ├── gif_chain.gif
│ │ ├── logging.md
│ │ ├── read_writer_example.png
│ │ ├── save_load.md
│ │ ├── seeding.md
│ │ ├── visu_gymnasium_gif.gif
│ │ └── visualization.md
├── beginner_dev_guide.md
├── changelog.rst
├── conf.py
├── contributing.md
├── contributors.rst
├── index.md
├── installation.md
├── make.bat
├── markdown_to_py.sh
├── requirements.txt
├── templates
│ ├── class.rst
│ ├── class_with_call.rst
│ ├── function.rst
│ ├── nice_toc.md
│ └── numpydoc_docstring.rst
├── themes
│ └── scikit-learn-fork
│ │ ├── README.md
│ │ ├── javascript.html
│ │ ├── layout.html
│ │ ├── nav.html
│ │ ├── search.html
│ │ ├── static
│ │ ├── css
│ │ │ ├── theme.css
│ │ │ └── vendor
│ │ │ │ └── bootstrap.min.css
│ │ └── js
│ │ │ ├── searchtools.js
│ │ │ └── vendor
│ │ │ ├── bootstrap.min.js
│ │ │ └── jquery-3.6.3.slim.min.js
│ │ ├── theme.conf
│ │ └── toc.css
├── thumbnails
│ ├── adastop_boxplots.png
│ ├── chain_thumb.jpg
│ ├── code.png
│ ├── example_plot_atari_atlantis_vectorized_ppo.jpg
│ ├── example_plot_atari_breakout_vectorized_ppo.jpg
│ ├── experiment_manager_diagram.png
│ ├── output_9_3.png
│ ├── video_plot_a2c.jpg
│ ├── video_plot_acrobot.jpg
│ ├── video_plot_apple_gold.jpg
│ ├── video_plot_atari_freeway.jpg
│ ├── video_plot_chain.jpg
│ ├── video_plot_dqn.jpg
│ ├── video_plot_gridworld.jpg
│ ├── video_plot_mbqvi.jpg
│ ├── video_plot_mdqn.jpg
│ ├── video_plot_montain_car.jpg
│ ├── video_plot_old_gym_acrobot.jpg
│ ├── video_plot_pball.jpg
│ ├── video_plot_ppo.jpg
│ ├── video_plot_rooms.jpg
│ ├── video_plot_rs_kernel_ucbvi.jpg
│ ├── video_plot_rsucbvi.jpg
│ ├── video_plot_springcartpole.jpg
│ ├── video_plot_twinrooms.jpg
│ └── video_plot_vi.jpg
├── user_guide.md
├── user_guide2.rst
└── versions.rst
├── examples
├── README.md
├── adastop_example.py
├── comparison_agents.py
├── demo_agents
│ ├── README.md
│ ├── demo_SAC.py
│ ├── gym_videos
│ │ ├── openaigym.episode_batch.0.454210.stats.json
│ │ ├── openaigym.manifest.0.454210.manifest.json
│ │ ├── openaigym.video.0.454210.video000000.meta.json
│ │ └── openaigym.video.0.454210.video000000.mp4
│ ├── video_plot_a2c.py
│ ├── video_plot_dqn.py
│ ├── video_plot_mbqvi.py
│ ├── video_plot_mdqn.py
│ ├── video_plot_ppo.py
│ ├── video_plot_rs_kernel_ucbvi.py
│ ├── video_plot_rsucbvi.py
│ └── video_plot_vi.py
├── demo_bandits
│ ├── README.md
│ ├── plot_TS_bandit.py
│ ├── plot_compare_index_bandits.py
│ ├── plot_exp3_bandit.py
│ ├── plot_mirror_bandit.py
│ └── plot_ucb_bandit.py
├── demo_env
│ ├── README.md
│ ├── example_atari_atlantis_vectorized_ppo.py
│ ├── example_atari_breakout_vectorized_ppo.py
│ ├── video_plot_acrobot.py
│ ├── video_plot_apple_gold.py
│ ├── video_plot_atari_freeway.py
│ ├── video_plot_chain.py
│ ├── video_plot_gridworld.py
│ ├── video_plot_mountain_car.py
│ ├── video_plot_old_gym_compatibility_wrapper_old_acrobot.py
│ ├── video_plot_pball.py
│ ├── video_plot_rooms.py
│ ├── video_plot_springcartpole.py
│ └── video_plot_twinrooms.py
├── demo_experiment
│ ├── params_experiment.yaml
│ ├── room.yaml
│ ├── rsucbvi.yaml
│ ├── rsucbvi_alternative.yaml
│ └── run.py
├── example_venv.py
├── plot_agent_manager.py
├── plot_checkpointing.py
├── plot_kernels.py
├── plot_smooth.py
└── plot_writer_wrapper.py
├── pyproject.toml
├── rlberry
├── __init__.py
├── agents
│ ├── __init__.py
│ ├── agent.py
│ ├── stable_baselines
│ │ ├── __init__.py
│ │ └── stable_baselines.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_replay.py
│ │ └── test_stable_baselines.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── replay.py
│ │ └── replay_utils.py
├── check_packages.py
├── envs
│ ├── __init__.py
│ ├── basewrapper.py
│ ├── finite_mdp.py
│ ├── gym_make.py
│ ├── interface
│ │ ├── __init__.py
│ │ └── model.py
│ ├── pipeline.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_env_seeding.py
│ │ ├── test_gym_env_seeding.py
│ │ └── test_gym_make.py
│ └── utils.py
├── experiment
│ ├── __init__.py
│ ├── generator.py
│ ├── load_results.py
│ ├── tests
│ │ ├── params_experiment.yaml
│ │ ├── room.yaml
│ │ ├── rsucbvi.yaml
│ │ ├── rsucbvi_alternative.yaml
│ │ ├── test_experiment_generator.py
│ │ └── test_load_results.py
│ └── yaml_utils.py
├── manager
│ ├── __init__.py
│ ├── comparison.py
│ ├── env_tools.py
│ ├── evaluation.py
│ ├── experiment_manager.py
│ ├── multiple_managers.py
│ ├── plotting.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_comparisons.py
│ │ ├── test_evaluation.py
│ │ ├── test_experiment_manager.py
│ │ ├── test_experiment_manager_seeding.py
│ │ ├── test_hyperparam_optim.py
│ │ ├── test_plot.py
│ │ ├── test_shared_data.py
│ │ ├── test_utils.py
│ │ └── test_venv.py
│ └── utils.py
├── metadata_utils.py
├── rendering
│ ├── __init__.py
│ ├── common_shapes.py
│ ├── core.py
│ ├── opengl_render2d.py
│ ├── pygame_render2d.py
│ ├── render_interface.py
│ ├── tests
│ │ ├── __init__.py
│ │ └── test_rendering_interface.py
│ └── utils.py
├── seeding
│ ├── __init__.py
│ ├── seeder.py
│ ├── seeding.py
│ └── tests
│ │ ├── __init__.py
│ │ ├── test_seeding.py
│ │ ├── test_threads.py
│ │ └── test_threads_torch.py
├── spaces
│ ├── __init__.py
│ ├── box.py
│ ├── dict.py
│ ├── discrete.py
│ ├── from_gym.py
│ ├── multi_binary.py
│ ├── multi_discrete.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_from_gym.py
│ │ └── test_spaces.py
│ └── tuple.py
├── tests
│ ├── __init__.py
│ ├── test_agent_extra.py
│ ├── test_agents_base.py
│ ├── test_envs.py
│ ├── test_imports.py
│ └── test_rlberry_main_agents_and_env.py
├── types.py
├── utils
│ ├── __init__.py
│ ├── binsearch.py
│ ├── check_agent.py
│ ├── check_env.py
│ ├── check_gym_env.py
│ ├── factory.py
│ ├── loading_tools.py
│ ├── logging.py
│ ├── space_discretizer.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_binsearch.py
│ │ ├── test_check.py
│ │ ├── test_loading_tools.py
│ │ └── test_writer.py
│ ├── torch.py
│ └── writers.py
└── wrappers
│ ├── __init__.py
│ ├── autoreset.py
│ ├── discrete2onehot.py
│ ├── discretize_state.py
│ ├── gym_utils.py
│ ├── rescale_reward.py
│ ├── tests
│ ├── __init__.py
│ ├── old_env
│ │ ├── __init__.py
│ │ ├── old_acrobot.py
│ │ ├── old_apple_gold.py
│ │ ├── old_ball2d.py
│ │ ├── old_finite_mdp.py
│ │ ├── old_four_room.py
│ │ ├── old_gridworld.py
│ │ ├── old_mountain_car.py
│ │ ├── old_nroom.py
│ │ ├── old_pball.py
│ │ ├── old_pendulum.py
│ │ ├── old_six_room.py
│ │ └── old_twinrooms.py
│ ├── test_basewrapper.py
│ ├── test_common_wrappers.py
│ ├── test_gym_space_conversion.py
│ ├── test_utils.py
│ ├── test_wrapper_seeding.py
│ └── test_writer_utils.py
│ ├── uncertainty_estimator_wrapper.py
│ ├── utils.py
│ └── writer_utils.py
└── scripts
├── apptainer_for_tests
├── README.md
├── monthly_test_base.sh
├── rlberry_apptainer__specific_python.def
└── rlberry_apptainer_base.def
├── build_docs.sh
├── conda_env_setup.sh
├── construct_video_examples.sh
├── fetch_contributors.py
├── full_install.sh
└── run_testscov.sh
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
 1 | ---
 2 | name: Bug report
 3 | about: Create a report to help us improve
 4 | title: ''
 5 | labels: ''
 6 | assignees: ''
 7 |
 8 | ---
 9 |
10 | **Describe the bug**
11 | A clear and concise description of what the bug is.
12 |
13 | **To Reproduce**
14 | Steps to reproduce the behavior.
15 |
16 | **Expected behavior**
17 | A clear and concise description of what you expected to happen.
18 |
19 | **Screenshots**
20 | If applicable, add screenshots to help explain your problem.
21 |
22 | **Desktop (please complete the following information):**
23 |
24 | - OS: \[e.g. iOS]
25 | - Version \[e.g. 22]
26 | - Python version
27 | - PyTorch version
28 |
29 | **Additional context**
30 | Add any other context about the problem here.
31 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
 1 | ---
 2 | name: Feature request
 3 | about: Suggest an idea for this project
 4 | title: ''
 5 | labels: ''
 6 | assignees: ''
 7 |
 8 | ---
 9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when \[...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Describe alternatives you've considered**
17 | A clear and concise description of any alternative solutions or features you've considered.
18 |
19 | **Additional context**
20 | Add any other context or screenshots about the feature request here.
21 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/other_issue.md:
--------------------------------------------------------------------------------
 1 | ---
 2 | name: Other issue
 3 | about: Create an issue that is not a bug report nor a feature request
 4 | title: ''
 5 | labels: ''
 6 | assignees: ''
 7 |
 8 | ---
 9 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Description
4 |
5 | Please include a summary of the change and which issue is fixed.
6 | Please also include relevant motivation and context, in particular link the relevant issue if it is appropriate.
7 | List any dependencies that are required for this change.
8 |
9 |
24 |
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # To get started with Dependabot version updates, you'll need to specify which
2 | # package ecosystems to update and where the package manifests are located.
3 | # Please see the documentation for all configuration options:
4 | # https://help.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
5 |
6 | version: 2
7 | updates:
 8 |   - package-ecosystem: "pip" # See documentation for possible values
 9 |     directory: "/" # Location of package manifests
10 |     schedule:
11 |       interval: "daily"
12 |
--------------------------------------------------------------------------------
/.github/workflows/doc_dev.yml:
--------------------------------------------------------------------------------
1 | name: documentation_dev
2 | on:
3 | pull_request_target:
4 | branches:
5 | - main
6 | types: [closed]
7 | push:
8 | branches:
9 | - main
10 |
11 |
12 | permissions:
13 | contents: write
14 |
15 | jobs:
16 | docs:
17 | runs-on: ubuntu-latest
18 | steps:
19 | - name: Wait for version update
20 | run: |
21 | sleep 60
22 | - name: Checkout
23 | uses: actions/checkout@v4
24 | - uses: actions/setup-python@v4
25 | with:
26 | python-version: '3.10'
27 |
28 | - name: Install and configure Poetry
29 | uses: snok/install-poetry@v1
30 |
31 | - name: Install dependencies
32 | run: |
33 | curl -C - https://raw.githubusercontent.com/rlberry-py/rlberry/main/pyproject.toml > pyproject.toml
34 | poetry sync --all-extras --with dev
35 | - name: Sphinx build
36 | run: |
37 | poetry run sphinx-build docs _build
38 | - uses: actions/checkout@v4
39 | with:
40 | # This is necessary so that we have the tags.
41 | fetch-depth: 0
42 | ref: gh-pages
43 | path: gh_pages
44 | - name: copy stable and preview version changes
45 | run: |
46 | cp -rv gh_pages/stable _build/stable || echo "Ignoring exit status"
47 | cp -rv gh_pages/preview_pr _build/preview_pr || echo "Ignoring exit status"
48 | - name: Deploy to GitHub Pages
49 | uses: peaceiris/actions-gh-pages@v3
50 | if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
51 | with:
52 | publish_branch: gh-pages
53 | github_token: ${{ secrets.GITHUB_TOKEN }}
54 | publish_dir: _build/
55 | force_orphan: true
56 |
--------------------------------------------------------------------------------
/.github/workflows/doc_stable.yml:
--------------------------------------------------------------------------------
1 | name: documentation_stable
2 | on:
3 | push:
4 | # Pattern matched against refs/tags
5 | tags:
6 | - '*' # Push events to every tag not containing /
7 |
8 | permissions:
9 | contents: write
10 |
11 | jobs:
12 | docs:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v4
16 | with:
17 | fetch-depth: 0
18 | fetch-tags: true
19 | path: main
20 | - name: checkout latest
21 | run: |
22 | cd main
23 | git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
24 | cd ..
25 | - uses: actions/setup-python@v4
26 | with:
27 | python-version: '3.10'
28 | - name: Install and configure Poetry
29 | uses: snok/install-poetry@v1
30 |
31 | - name: Install dependencies
32 | run: |
33 | cd main
34 | poetry sync --all-extras --with dev
35 | - name: Sphinx build
36 | run: |
37 | poetry run sphinx-build docs ../_build
38 | cd ..
39 | - uses: actions/checkout@v4
40 | with:
41 | # This is necessary so that we have the tags.
42 | fetch-depth: 0
43 | ref: gh-pages
44 | path: gh_pages
45 | - name: Commit documentation changes
46 | run: |
47 | cd gh_pages
48 | rm -r stable || echo "Ignoring exit status"
49 | mkdir stable
50 | cp -rv ../_build/* stable
51 | git config user.name github-actions
52 | git config user.email github-actions@github.com
53 | git add .
54 | git commit -m "Documentation Stable"
55 | git push
56 |
--------------------------------------------------------------------------------
/.github/workflows/post_merge.yml:
--------------------------------------------------------------------------------
1 | name: Version Control
2 | on:
3 | pull_request_target:
4 | branches:
5 | - main
6 | types: [closed]
7 |
8 | jobs:
9 | version_lock_job:
10 | if: github.event.pull_request.merged == true
11 | runs-on: ubuntu-latest
12 | steps:
13 | - uses: actions/setup-python@v4
14 | with:
15 | python-version: '3.10'
16 | - uses: actions/checkout@v4
17 | with:
18 | # This is necessary so that we have the tags.
19 | fetch-depth: 0
20 | ref: main
21 | - uses: mtkennerly/dunamai-action@v1
22 | with:
23 | env-var: MY_VERSION
24 | - run: echo $MY_VERSION
25 | - uses: snok/install-poetry@v1
26 | - run: poetry version "v$MY_VERSION "
27 | - run: poetry lock
28 | - uses: EndBug/add-and-commit@v8
29 | with:
30 | add: '["pyproject.toml", "poetry.lock"]'
31 | default_author: github_actor
32 | message: 'Writing version and lock with github action [skip ci]'
33 |
--------------------------------------------------------------------------------
/.github/workflows/ready_for_CI.yml:
--------------------------------------------------------------------------------
1 | name: ready_for_CI
2 |
3 | on:
4 | pull_request:
5 | types: [labeled, opened, reopened, synchronize]
6 |
7 | jobs:
8 | build:
9 | if: contains( github.event.pull_request.labels.*.name, 'ready for CI')
10 | runs-on: ubuntu-latest
11 |
12 | steps:
13 | - uses: actions/checkout@v4
14 | - name: Run a one-line script
15 | run: echo ready for CI!
16 | - name: Save the PR number in an artifact
17 | shell: bash
18 | env:
19 | PR_NUM: ${{ github.event.number }}
20 | run: echo $PR_NUM > pr_num.txt
21 |
22 | - name: Upload the PR number
23 | uses: actions/upload-artifact@v4
24 | with:
25 | name: pr_num
26 | path: ./pr_num.txt
27 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # test markdown
2 | docs/python_scripts/
3 |
4 | #test venvs
5 | rlberry_venvs
6 |
7 | # tensorboard runs
8 | runs/
9 |
10 | # videos
11 | #*.mp4
12 | notebooks/videos/*
13 |
14 | # pickled objects and csv
15 | *.pickle
16 | *.csv
17 |
18 | # test coverage folder
19 | cov_html/*
20 |
21 | # dev and results folder
22 | dev/*
23 | results/*
24 | temp/*
25 | rlberry_data/*
26 | */rlberry_data/*
27 |
28 |
29 | # Byte-compiled / optimized / DLL files
30 | __pycache__/
31 | *.py[cod]
32 | *$py.class
33 |
34 | # C extensions
35 | *.so
36 |
37 | # Distribution / packaging
38 | .Python
39 | build/
40 | develop-eggs/
41 | dist/
42 | downloads/
43 | eggs/
44 | .eggs/
45 | lib/
46 | lib64/
47 | parts/
48 | sdist/
49 | var/
50 | wheels/
51 | pip-wheel-metadata/
52 | share/python-wheels/
53 | *.egg-info/
54 | .installed.cfg
55 | *.egg
56 | MANIFEST
57 |
58 | # PyInstaller
59 | # Usually these files are written by a python script from a template
60 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
61 | *.manifest
62 | *.spec
63 |
64 | # Installer logs
65 | pip-log.txt
66 | pip-delete-this-directory.txt
67 |
68 | # Unit test / coverage reports
69 | htmlcov/
70 | .tox/
71 | .nox/
72 | .coverage
73 | .coverage.*
74 | .cache
75 | nosetests.xml
76 | coverage.xml
77 | *.cover
78 | *.py,cover
79 | .hypothesis/
80 | .pytest_cache/
81 | profile.prof
82 |
83 | # Translations
84 | *.mo
85 | *.pot
86 |
87 | # Django stuff:
88 | *.log
89 | local_settings.py
90 | db.sqlite3
91 | db.sqlite3-journal
92 |
93 | # Flask stuff:
94 | instance/
95 | .webassets-cache
96 |
97 | # Scrapy stuff:
98 | .scrapy
99 |
100 | # Sphinx documentation
101 | docs/_build/
102 | docs/generated
103 | docs/auto_examples
104 |
105 |
106 |
107 | # PyBuilder
108 | target/
109 |
110 | # Jupyter Notebook
111 | .ipynb_checkpoints
112 |
113 | # IPython
114 | profile_default/
115 | ipython_config.py
116 |
117 | # pyenv
118 | .python-version
119 |
120 | # pipenv
121 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
122 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
123 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
124 | # install all needed dependencies.
125 | #Pipfile.lock
126 |
127 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
128 | __pypackages__/
129 |
130 | # Celery stuff
131 | celerybeat-schedule
132 | celerybeat.pid
133 |
134 | # SageMath parsed files
135 | *.sage.py
136 |
137 | # Environments
138 | .env
139 | .venv
140 | env/
141 | venv/
142 | ENV/
143 | env.bak/
144 | venv.bak/
145 |
146 | # Spyder project settings
147 | .spyderproject
148 | .spyproject
149 |
150 | # Rope project settings
151 | .ropeproject
152 |
153 | # mkdocs documentation
154 | /site
155 |
156 | # mypy
157 | .mypy_cache/
158 | .dmypy.json
159 | dmypy.json
160 |
161 | # Pyre type checker
162 | .pyre/
163 |
164 | # macOS
165 | .DS_Store
166 |
167 | # vscode
168 | .vscode
169 |
170 | # PyCharm
171 | .idea
172 | .project
173 | .pydevproject
174 |
175 |
176 | *.prof
177 |
178 | # poetry.lock
179 | poetry.lock
180 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v4.4.0
6 | hooks:
7 | - id: trailing-whitespace
8 | - id: end-of-file-fixer
9 | - id: check-yaml
10 | - id: check-docstring-first
11 |
12 | - repo: https://github.com/psf/black
13 | rev: 23.9.1
14 | hooks:
15 | - id: black
16 |
17 | - repo: https://github.com/asottile/blacken-docs
18 | rev: 1.16.0
19 | hooks:
20 | - id: blacken-docs
21 | additional_dependencies: [black==23.1.0]
22 |
23 | - repo: https://github.com/pycqa/flake8
24 | rev: 6.1.0
25 | hooks:
26 | - id: flake8
27 | additional_dependencies: [flake8-docstrings]
28 | types: [file, python]
29 | exclude: (.*/__init__.py|rlberry/check_packages.py)
30 | args: ['--select=F401,F405,D410,D411,D412']
31 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
 3 | authors:
 4 |   - family-names: "Domingues"
 5 |     given-names: "Omar Darwiche"
 6 |   - family-names: "Flet-Berliac"
 7 |     given-names: "Yannis"
 8 |   - family-names: "Leurent"
 9 |     given-names: "Edouard"
10 |   - family-names: "Ménard"
11 |     given-names: "Pierre"
12 |   - family-names: "Shang"
13 |     given-names: "Xuedong"
14 |   - family-names: "Valko"
15 |     given-names: "Michal"
16 |
17 | title: "rlberry - A Reinforcement Learning Library for Research and Education"
18 | abbreviation: rlberry
19 | version: 0.2.2-dev
20 | doi: 10.5281/zenodo.5223307
21 | date-released: 2021-10-01
22 | url: "https://github.com/rlberry-py/rlberry"
23 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 rlberry-py
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include pyproject.toml
2 |
3 | recursive-include assets *.svg
4 |
--------------------------------------------------------------------------------
/assets/logo_avatar.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/assets/logo_avatar.png
--------------------------------------------------------------------------------
/codecov.yml:
--------------------------------------------------------------------------------
1 |
2 | comment: false
3 |
4 | coverage:
5 | precision: 2
6 | round: down
7 | range: "70...100"
8 | status:
9 | project:
10 | default:
11 | # basic
12 | target: auto
13 | threshold: 1% # allow coverage to drop at most 1%
14 |
15 | parsers:
16 | gcov:
17 | branch_detection:
18 | conditional: yes
19 | loop: yes
20 | method: no
21 | macro: no
22 |
23 | ignore:
24 | - "./rlberry/wrappers/tests/old_env/*.py"
25 | - "./rlberry/utils/torch.py"
26 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | # clean example gallery files
16 | clean:
17 | rm -rf $(BUILDDIR)/*
18 | rm -rf auto_examples/
19 |
20 | # Script used to construct the videos for the examples that output videos.
21 | # Please use a script name that begins with video_plot as with the other examples
22 | # and in the script, there should be a line to save the video in the right place
23 | # and a line to load the video in the headers. Look at existing examples for
24 | the correct syntax. Note that you must remove the _build folder before
25 | # recompiling the doc when a video has been updated/added.
26 | video:
27 | # Make videos
28 | $(foreach file, $(wildcard ../examples/**/video_plot*.py), \
29 | @echo $(basename $(notdir $(file)));\
30 | python $(file)) ;\
31 | # Make thumbnails
32 | $(foreach file, $(wildcard _video/*.mp4), \
33 | ffmpeg -y -i $(file) -vframes 1 -f image2 \
34 | thumbnails/$(basename $(notdir $(file))).jpg ;\
35 | )
36 | # Remove unused metadata json
37 | @rm _video/*.json
38 |
39 | thumbnail_images:
40 | $(foreach file, $(wildcard _video/*.mp4), \
41 | ffmpeg -y -i $(file) -vframes 1 -f image2 \
42 | thumbnails/$(basename $(notdir $(file))).jpg ;\
43 | )
44 |
45 | .PHONY: help Makefile
46 |
47 | # Catch-all target: route all unknown targets to Sphinx using the new
48 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
49 | %: Makefile
50 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
51 |
--------------------------------------------------------------------------------
/docs/_video/example_plot_atari_atlantis_vectorized_ppo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/example_plot_atari_atlantis_vectorized_ppo.mp4
--------------------------------------------------------------------------------
/docs/_video/example_plot_atari_breakout_vectorized_ppo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/example_plot_atari_breakout_vectorized_ppo.mp4
--------------------------------------------------------------------------------
/docs/_video/user_guide_video/_agent_page_CartPole.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/user_guide_video/_agent_page_CartPole.mp4
--------------------------------------------------------------------------------
/docs/_video/user_guide_video/_agent_page_chain1.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/user_guide_video/_agent_page_chain1.mp4
--------------------------------------------------------------------------------
/docs/_video/user_guide_video/_agent_page_chain2.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/user_guide_video/_agent_page_chain2.mp4
--------------------------------------------------------------------------------
/docs/_video/user_guide_video/_agent_page_frozenLake.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/user_guide_video/_agent_page_frozenLake.mp4
--------------------------------------------------------------------------------
/docs/_video/user_guide_video/_env_page_Breakout.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/user_guide_video/_env_page_Breakout.mp4
--------------------------------------------------------------------------------
/docs/_video/user_guide_video/_env_page_MountainCar.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/user_guide_video/_env_page_MountainCar.mp4
--------------------------------------------------------------------------------
/docs/_video/user_guide_video/_env_page_chain.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/user_guide_video/_env_page_chain.mp4
--------------------------------------------------------------------------------
/docs/_video/user_guide_video/_experimentManager_page_CartPole.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/user_guide_video/_experimentManager_page_CartPole.mp4
--------------------------------------------------------------------------------
/docs/_video/video_chain_quickstart.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_chain_quickstart.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_a2c.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_a2c.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_acrobot.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_acrobot.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_apple_gold.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_apple_gold.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_atari_freeway.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_atari_freeway.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_chain.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_chain.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_dqn.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_dqn.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_gridworld.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_gridworld.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_mbqvi.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_mbqvi.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_mdqn.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_mdqn.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_montain_car.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_montain_car.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_old_gym_acrobot.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_old_gym_acrobot.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_pball.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_pball.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_ppo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_ppo.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_rooms.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_rooms.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_rs_kernel_ucbvi.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_rs_kernel_ucbvi.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_rsucbvi.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_rsucbvi.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_springcartpole.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_springcartpole.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_twinrooms.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_twinrooms.mp4
--------------------------------------------------------------------------------
/docs/_video/video_plot_vi.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/_video/video_plot_vi.mp4
--------------------------------------------------------------------------------
/docs/about.rst:
--------------------------------------------------------------------------------
1 | .. _about:
2 |
3 | About us
4 | ========
5 |
6 | This project was initiated and is actively maintained by the
7 | `INRIA SCOOL team `_.
8 |
9 | Contributors
10 | ------------
11 |
12 | The following people contributed actively to rlberry.
13 |
14 | .. include:: contributors.rst
15 |
16 |
17 | .. _citing-rlberry:
18 |
19 | Citing rlberry
20 | --------------
21 |
22 | If you use rlberry in scientific publications, we would appreciate citations using the following Bibtex entry::
23 |
24 |     @misc{rlberry,
25 |       author = {Domingues, Omar Darwiche and Flet-Berliac, Yannis and Leurent, Edouard and M{\'e}nard, Pierre and Shang, Xuedong and Valko, Michal},
26 |       doi = {10.5281/zenodo.5544540},
27 |       month = {10},
28 |       title = {{rlberry - A Reinforcement Learning Library for Research and Education}},
29 |       url = {https://github.com/rlberry-py/rlberry},
30 |       year = {2021}
31 |     }
32 |
33 |
34 |
35 | Funding
36 | -------
37 |
38 | The project would like to thank the following for their participation in funding
39 | the PhD positions and research projects that made this happen.
40 |
41 | - Inria, and in particular the Scool team, for the working environment
42 | - Université de Lille
43 | - I-SITE ULNE
44 | - ANR
45 | - ANRT
46 | - Renault
47 | - European CHIST-ERA project DELTA
48 | - La Région Hauts-de-France and the MEL
49 |
--------------------------------------------------------------------------------
/docs/basics/DeepRLTutorial/output_10_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/DeepRLTutorial/output_10_3.png
--------------------------------------------------------------------------------
/docs/basics/DeepRLTutorial/output_5_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/DeepRLTutorial/output_5_3.png
--------------------------------------------------------------------------------
/docs/basics/DeepRLTutorial/output_6_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/DeepRLTutorial/output_6_3.png
--------------------------------------------------------------------------------
/docs/basics/DeepRLTutorial/output_9_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/DeepRLTutorial/output_9_3.png
--------------------------------------------------------------------------------
/docs/basics/agent_manager_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/agent_manager_diagram.png
--------------------------------------------------------------------------------
/docs/basics/create_agent.rst:
--------------------------------------------------------------------------------
1 | .. _rlberry: https://github.com/rlberry-py/rlberry
2 |
3 | .. _create_agent:
4 |
5 |
6 | Create an agent
7 | ===============
8 |
9 | rlberry_ requires you to use a **very simple interface** to write agents, with basically
10 | two methods to implement: :code:`fit()` and :code:`eval()`.
11 |
12 | The example below shows how to create an agent.
13 |
14 |
15 | .. code-block:: python
16 |
17 |     import numpy as np
18 |     from rlberry.agents import Agent
19 |
20 |
21 |     class MyAgent(Agent):
22 |         name = "MyAgent"
23 |
24 |         def __init__(
25 |             self, env, param1=0.99, param2=1e-5, **kwargs
26 |         ): # it's important to put **kwargs to ensure compatibility with the base class
27 |             # self.env is initialized in the base class
28 |             # An evaluation environment is also initialized: self.eval_env
29 |             Agent.__init__(self, env, **kwargs)
30 |
31 |             self.param1 = param1
32 |             self.param2 = param2
33 |
34 |         def fit(self, budget, **kwargs):
35 |             """
36 |             The parameter budget can represent the number of steps, the number of episodes etc,
37 |             depending on the agent.
38 |             * Interact with the environment (self.env);
39 |             * Train the agent
40 |             * Return useful information
41 |             """
42 |             n_episodes = budget
43 |             rewards = np.zeros(n_episodes)
44 |
45 |             for ep in range(n_episodes):
46 |                 state, info = self.env.reset()
47 |                 done = False
48 |                 while not done:
49 |                     action = ...
50 |                     observation, reward, terminated, truncated, info = self.env.step(action)
51 |                     done = terminated or truncated
52 |                     rewards[ep] += reward
53 |
54 |             info = {"episode_rewards": rewards}
55 |             return info
56 |
57 |         def eval(self, **kwargs):
58 |             """
59 |             Returns a value corresponding to the evaluation of the agent on the
60 |             evaluation environment.
61 |
62 |             For instance, it can be a Monte-Carlo evaluation of the policy learned in fit().
63 |             """
64 |             return 0.0
65 |
66 |
67 | .. note:: It's important that your agent accepts optional `**kwargs` and passes them to the base class as :code:`Agent.__init__(self, env, **kwargs)`.
68 |
69 |
70 | .. seealso::
71 | Documentation of the classes :class:`~rlberry.agents.agent.Agent`
72 | and :class:`~rlberry.agents.agent.AgentWithSimplePolicy`.
73 |
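74 | Once the placeholder :code:`action = ...` above is replaced by a real action choice, such an
75 | agent can be handed to an :class:`~rlberry.manager.experiment_manager.ExperimentManager`.
76 | The snippet below is a minimal usage sketch (the CartPole environment and the budget value are
77 | arbitrary illustration choices, not part of the agent interface):
78 |
79 | .. code-block:: python
80 |
81 |     from rlberry.envs import gym_make
82 |     from rlberry.manager import ExperimentManager
83 |
84 |     manager = ExperimentManager(
85 |         MyAgent,
86 |         (gym_make, dict(id="CartPole-v1")),  # environment constructor and kwargs
87 |         init_kwargs=dict(param1=0.5),  # forwarded to MyAgent.__init__
88 |         fit_budget=10,  # forwarded to MyAgent.fit (here, the number of episodes)
89 |         n_fit=2,  # train two independent instances
90 |     )
91 |     manager.fit()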
--------------------------------------------------------------------------------
/docs/basics/experiment_setup.rst:
--------------------------------------------------------------------------------
1 | .. _rlberry: https://github.com/rlberry-py/rlberry
2 |
3 | .. _experiment_setup:
4 |
5 |
6 | Setup and run experiments using yaml config files
7 | =================================================
8 |
9 |
10 | To set up an experiment with rlberry, you can use yaml files. You'll need:
11 |
12 | * An **experiment.yaml** with some global parameters: seed, number of episodes, horizon, environments (for training and evaluation) and a list of agents to run.
13 |
14 | * yaml files describing the environments and the agents.
15 |
16 | * A main python script that reads the files and generates :class:`~rlberry.manager.experiment_manager.ExperimentManager` instances to run each agent.
17 |
18 |
19 | This can be done very succinctly as in the example below:
20 |
21 |
22 | **experiment.yaml**
23 |
24 | .. code-block:: yaml
25 |
26 |     description: 'RSUCBVI in NRoom'
27 |     seed: 123
28 |     train_env: 'examples/demo_experiment/room.yaml'
29 |     eval_env: 'examples/demo_experiment/room.yaml'
30 |     agents:
31 |       - 'examples/demo_experiment/rsucbvi.yaml'
32 |       - 'examples/demo_experiment/rsucbvi_alternative.yaml'
33 |
34 |
35 | **room.yaml**
36 |
37 | .. code-block:: yaml
38 |
39 |     constructor: 'rlberry_research.envs.benchmarks.grid_exploration.nroom.NRoom'
40 |     params:
41 |       reward_free: false
42 |       array_observation: true
43 |       nrooms: 5
44 |
45 | **rsucbvi.yaml**
46 |
47 | .. code-block:: yaml
48 |
49 |     agent_class: 'rlberry_research.agents.kernel_based.rs_ucbvi.RSUCBVIAgent'
50 |     init_kwargs:
51 |       gamma: 1.0
52 |       lp_metric: 2
53 |       min_dist: 0.0
54 |       max_repr: 800
55 |       bonus_scale_factor: 1.0
56 |       reward_free: True
57 |       horizon: 50
58 |     eval_kwargs:
59 |       eval_horizon: 50
60 |     fit_kwargs:
61 |       fit_budget: 100
62 |
63 | **rsucbvi_alternative.yaml**
64 |
65 | .. code-block:: yaml
66 |
67 |     base_config: 'examples/demo_experiment/rsucbvi.yaml'
68 |     init_kwargs:
69 |       gamma: 0.9
70 |
71 |
72 |
73 | **run.py**
74 |
75 | .. code-block:: python
76 |
77 | """
78 | To run the experiment:
79 |
80 | $ python run.py experiment.yaml
81 |
82 | To see more options:
83 |
84 | $ python run.py -h
85 | """
86 |
87 | from rlberry.experiment import experiment_generator
88 | from rlberry.manager.multiple_managers import MultipleManagers
89 |
90 | multimanagers = MultipleManagers()
91 |
92 | for experiment_manager in experiment_generator():
93 | multimanagers.append(experiment_manager)
94 |
95 | # Alternatively:
96 | # experiment_manager.fit()
97 | # experiment_manager.save()
98 |
99 | multimanagers.run()
100 | multimanagers.save()
101 |
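102 | After :code:`multimanagers.run()`, each fitted :class:`~rlberry.manager.experiment_manager.ExperimentManager`
103 | can be inspected with the usual manager tools. The snippet below is a minimal sketch; it assumes the
104 | :code:`managers` attribute of :class:`~rlberry.manager.multiple_managers.MultipleManagers` and the
105 | :code:`evaluate_agents` / :code:`plot_writer_data` helpers from :code:`rlberry.manager`, and the
106 | writer tag name is an arbitrary illustration choice:
107 |
108 | .. code-block:: python
109 |
110 |     from rlberry.manager import evaluate_agents, plot_writer_data
111 |
112 |     managers = list(multimanagers.managers)
113 |
114 |     # Compare the evaluation performance of the fitted agents...
115 |     evaluate_agents(managers, n_simulations=16, show=True)
116 |
117 |     # ...and plot a quantity logged by the agents' writers during fit().
118 |     plot_writer_data(managers, tag="episode_rewards", show=True)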
--------------------------------------------------------------------------------
/docs/basics/multiprocess.rst:
--------------------------------------------------------------------------------
1 | .. _multiprocess:
2 |
3 | Parallelization in rlberry
4 | ==========================
5 |
6 | rlberry uses Python's standard multiprocessing library to execute the fit of agents in parallel on CPUs. The parallelization is done via
7 | :class:`~rlberry.manager.ExperimentManager` and :class:`~rlberry.manager.MultipleManagers`.
8 |
9 | If a user wants to use a third-party parallelization library like joblib, they must be aware of where the seeding is done so as not to bias the results. rlberry automatically handles seeding when its native parallelization schemes are used.
10 |
11 | Several multiprocessing schemes are implemented in rlberry.
12 |
13 | Threading
14 | ---------
15 |
16 | Thread multiprocessing "constructs higher-level threading interfaces on top of the lower level _thread module" (see the doc on `python's website `_). This is the default scheme in rlberry; most of the time it results in
17 | practically no parallelization, unless the code executed in each thread (i.e. each fit) releases the GIL (for example, Cython or NumPy code).
18 |
19 | Process: spawn or forkserver
20 | ----------------------------
21 |
22 | To have an efficient parallelization, it is often better to use processes (see the doc on `python's website `_) using the parameter :code:`parallelization="process"` in :class:`~rlberry.manager.ExperimentManager` or :class:`~rlberry.manager.MultipleManagers`.
23 |
24 | This implies that a new process will be launched for each fit of the ExperimentManager.
25 |
26 | The advised parallelization method is spawn (parameter :code:`mp_context="spawn"`); however, the spawn method has several drawbacks:
27 |
28 | - The fit code needs to be encapsulated in an :code:`if __name__ == '__main__'` directive. Example:
29 |
30 | .. code:: python
31 |
32 |     from rlberry_research.agents.torch import A2CAgent
33 |     from rlberry.manager import ExperimentManager
34 |     from rlberry_research.envs.benchmarks.ball_exploration import PBall2D
35 |
36 |     n_steps = 1e5
37 |     batch_size = 256
38 |
39 |     if __name__ == "__main__":
40 |         manager = ExperimentManager(
41 |             A2CAgent,
42 |             (PBall2D, {}),
43 |             init_kwargs=dict(batch_size=batch_size, gamma=0.99, learning_rate=0.001),
44 |             n_fit=4,
45 |             fit_budget=n_steps,
46 |             parallelization="process",
47 |             mp_context="spawn",
48 |         )
49 |         manager.fit()
50 |
51 | - As a consequence, :code:`spawn` parallelization only works if called from the main script.
52 | - :code:`spawn` does not work when called from a notebook. To work in a notebook, use :code:`fork` instead.
53 | - :code:`forkserver` is an alternative to :code:`spawn` that is sometimes faster than :code:`spawn`. :code:`forkserver` parallelization must also be encapsulated in an :code:`if __name__ == '__main__'` directive and, for now, it is available only on Unix systems (MacOS, Linux, ...).
54 |
55 |
56 | Process: fork
57 | -------------
58 |
59 | Fork multiprocessing is only possible on Unix systems.
60 | It is available through the parameter :code:`mp_context="fork"` when :code:`parallelization="process"`.
61 | Note that there can be logging errors and hangs when using :code:`fork`. The usage of fork in rlberry is still experimental and may be unstable.
62 |
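63 | A minimal sketch of the :code:`fork` variant is given below; it mirrors the spawn example above.
64 | With :code:`fork`, the :code:`if __name__ == '__main__'` guard is not strictly required because the
65 | child processes inherit the parent's state (the agent, environment and budget are the same arbitrary
66 | choices as in the spawn example):
67 |
68 | .. code-block:: python
69 |
70 |     from rlberry_research.agents.torch import A2CAgent
71 |     from rlberry.manager import ExperimentManager
72 |     from rlberry_research.envs.benchmarks.ball_exploration import PBall2D
73 |
74 |     manager = ExperimentManager(
75 |         A2CAgent,
76 |         (PBall2D, {}),
77 |         init_kwargs=dict(batch_size=256, gamma=0.99, learning_rate=0.001),
78 |         n_fit=4,
79 |         fit_budget=1e5,
80 |         parallelization="process",
81 |         mp_context="fork",  # Unix only; see the caveats above
82 |     )
83 |     manager.fit()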
--------------------------------------------------------------------------------
/docs/basics/quick_start_rl/Figure_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/quick_start_rl/Figure_1.png
--------------------------------------------------------------------------------
/docs/basics/quick_start_rl/Figure_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/quick_start_rl/Figure_2.png
--------------------------------------------------------------------------------
/docs/basics/quick_start_rl/Figure_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/quick_start_rl/Figure_3.png
--------------------------------------------------------------------------------
/docs/basics/quick_start_rl/experiment_manager_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/quick_start_rl/experiment_manager_diagram.png
--------------------------------------------------------------------------------
/docs/basics/quick_start_rl/gif_chain.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/quick_start_rl/gif_chain.gif
--------------------------------------------------------------------------------
/docs/basics/quick_start_rl/video_chain.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/quick_start_rl/video_chain.mp4
--------------------------------------------------------------------------------
/docs/basics/seeding.rst:
--------------------------------------------------------------------------------
1 | .. _seeding:
2 |
3 | Seeding & Reproducibility
4 | ==========================
5 |
6 | rlberry_ has a class :class:`~rlberry.seeding.seeder.Seeder` that conveniently wraps a `NumPy SeedSequence `_,
7 | and allows us to create independent random number generators for different objects and threads, using a single
8 | :class:`~rlberry.seeding.seeder.Seeder` instance.
9 |
10 | It works as follows:
11 |
12 |
13 | .. code-block:: python
14 |
15 | from rlberry.seeding import Seeder
16 |
17 | seeder = Seeder(123)
18 |
19 | # Each Seeder instance has a random number generator (rng)
20 | # See https://numpy.org/doc/stable/reference/random/generator.html to check the
21 | # methods available in rng.
22 | seeder.rng.integers(5)
23 | seeder.rng.normal()
24 | print(type(seeder.rng))
25 | # etc
26 |
27 | # Environments and agents should be seeded using a single seeder,
28 | # to ensure that their random number generators are independent.
29 | from rlberry.envs import gym_make
30 | from rlberry.agents import RSUCBVIAgent
31 |
32 | env = gym_make("MountainCar-v0")
33 | env.reseed(seeder)
34 |
35 | agent = RSUCBVIAgent(env)
36 | agent.reseed(seeder)
37 |
38 |
39 | # Environments and Agents have their own seeder and rng.
40 | # When writing your own agents and inheriting from the Agent class,
41 | # you should use agent.rng whenever you need to generate random numbers;
42 | # the same applies to your environments.
43 | # This is necessary to ensure reproducibility.
44 | print("env seeder: ", env.seeder)
45 | print("random sample from env rng: ", env.rng.normal())
46 | print("agent seeder: ", agent.seeder)
47 | print("random sample from agent rng: ", agent.rng.normal())
48 |
49 |
50 | # A seeder can spawn other seeders that are independent from it.
51 | # This is useful to seed two different threads, using seeder1
52 | # in the first thread, and seeder2 in the second thread.
53 | seeder1, seeder2 = seeder.spawn(2)
54 |
55 |
56 | # You can also use a seeder to seed external libraries (such as torch)
57 | # using the function set_external_seed
58 | from rlberry.seeding import set_external_seed
59 |
60 | set_external_seed(seeder)
61 |
62 |
63 | .. note::
64 | The class :class:`~rlberry.manager.experiment_manager.ExperimentManager` provides a :code:`seed` parameter in its constructor,
65 | and automatically handles the seeding of all environments and agents that it uses.
66 |
67 | .. note::
68 |
69 | The :meth:`optimize_hyperparams` method of
70 | :class:`~rlberry.manager.experiment_manager.ExperimentManager` uses the `Optuna `_
71 | library for hyperparameter optimization and is **inherently non-deterministic**
72 | (see `Optuna FAQ `_).
73 |
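74 | As an illustration of the first note above, a single integer seed passed to
75 | :class:`~rlberry.manager.experiment_manager.ExperimentManager` is enough to make the whole
76 | experiment reproducible. The snippet below is a minimal sketch (the agent, environment and
77 | budget are arbitrary illustration choices):
78 |
79 | .. code-block:: python
80 |
81 |     from rlberry.envs import gym_make
82 |     from rlberry.agents import RSUCBVIAgent
83 |     from rlberry.manager import ExperimentManager
84 |
85 |     manager = ExperimentManager(
86 |         RSUCBVIAgent,
87 |         (gym_make, dict(id="MountainCar-v0")),
88 |         fit_budget=10,
89 |         n_fit=2,
90 |         seed=123,  # seeds every agent and environment created by the manager
91 |     )
92 |     manager.fit()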
--------------------------------------------------------------------------------
/docs/basics/userguide/adastop_boxplots.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/userguide/adastop_boxplots.png
--------------------------------------------------------------------------------
/docs/basics/userguide/entropy_loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/userguide/entropy_loss.png
--------------------------------------------------------------------------------
/docs/basics/userguide/example_eval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/userguide/example_eval.png
--------------------------------------------------------------------------------
/docs/basics/userguide/expManager_multieval.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/userguide/expManager_multieval.png
--------------------------------------------------------------------------------
/docs/basics/userguide/gif_chain.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/userguide/gif_chain.gif
--------------------------------------------------------------------------------
/docs/basics/userguide/read_writer_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/userguide/read_writer_example.png
--------------------------------------------------------------------------------
/docs/basics/userguide/visu_gymnasium_gif.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/basics/userguide/visu_gymnasium_gif.gif
--------------------------------------------------------------------------------
/docs/contributors.rst:
--------------------------------------------------------------------------------
1 | .. raw :: html
2 |
[The original file embeds an HTML gallery of contributor avatars; the markup is
not reproduced in this dump. Contributors listed in the gallery: AleShi94,
brahimdriss, Matheus M. Centa, Omar D., Rémy Degenne, Yannis Flet-Berliac,
Hector Kohler, Edouard Leurent, Pierre Ménard, Waris Radji, sauxpa,
Xuedong Shang, Ju T, TimotheeMathieu, Riccardo Della Vecchia, YannBerthelot]
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/markdown_to_py.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env -S bash
2 |
3 |
4 | mkdir python_scripts
5 |
6 | set -e
7 |
8 | shopt -s globstar
9 | list_files=$(ls $PWD/**/*.md)
10 | # For each markdown file, write the code inside its ```python fences to python_scripts/<file>.md.py
11 | for script in $list_files; do
12 | echo "Processing " $script
13 | sed -n '/^```python/,/^```/ p' < $script | sed '/^```/ d' > python_scripts/${script##*/}.py
14 | done
15 |
16 | # read -p "Do you wish to execute all python scripts? (y/n)" yn
17 | # case $yn in
18 | # [Yy]* ) for f in python_scripts/*.py; do python3 $f ; done ;;
19 | # * ) exit;;
20 | # esac
21 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx<7
2 | sphinx-gallery
3 | sphinx-math-dollar
4 | numpydoc
5 | myst-parser
6 | git+https://github.com/sphinx-contrib/video
7 | matplotlib
8 | sphinx-copybutton
9 | sphinx-design
10 |
--------------------------------------------------------------------------------
/docs/templates/class.rst:
--------------------------------------------------------------------------------
1 | :mod:`{{module}}`.{{objname}}
2 | {{ underline }}==============
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autoclass:: {{ objname }}
7 |
8 | .. include:: {{module}}.{{objname}}.examples
9 |
10 | .. raw:: html
11 |
12 |
13 |
--------------------------------------------------------------------------------
/docs/templates/class_with_call.rst:
--------------------------------------------------------------------------------
1 | :mod:`{{module}}`.{{objname}}
2 | {{ underline }}===============
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autoclass:: {{ objname }}
7 |
8 | {% block methods %}
9 | .. automethod:: __call__
10 | {% endblock %}
11 |
12 | .. include:: {{module}}.{{objname}}.examples
13 |
14 | .. raw:: html
15 |
16 |
17 |
--------------------------------------------------------------------------------
/docs/templates/function.rst:
--------------------------------------------------------------------------------
1 | :mod:`{{module}}`.{{objname}}
2 | {{ underline }}====================
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autofunction:: {{ objname }}
7 |
8 | .. include:: {{module}}.{{objname}}.examples
9 |
10 | .. raw:: html
11 |
12 |
13 |
--------------------------------------------------------------------------------
/docs/templates/numpydoc_docstring.rst:
--------------------------------------------------------------------------------
1 | {{index}}
2 | {{summary}}
3 | {{extended_summary}}
4 | {{parameters}}
5 | {{returns}}
6 | {{yields}}
7 | {{other_parameters}}
8 | {{attributes}}
9 | {{raises}}
10 | {{warns}}
11 | {{warnings}}
12 | {{see_also}}
13 | {{notes}}
14 | {{references}}
15 | {{examples}}
16 | {{methods}}
17 |
--------------------------------------------------------------------------------
/docs/themes/scikit-learn-fork/README.md:
--------------------------------------------------------------------------------
1 | Theme forked from [scikit-learn theme](https://github.com/scikit-learn/scikit-learn).
2 |
--------------------------------------------------------------------------------
/docs/themes/scikit-learn-fork/search.html:
--------------------------------------------------------------------------------
1 | {%- extends "basic/search.html" %}
2 | {% block extrahead %}
3 |
4 |
5 |
6 |
7 |
8 | {% endblock %}
9 |
--------------------------------------------------------------------------------
/docs/themes/scikit-learn-fork/theme.conf:
--------------------------------------------------------------------------------
1 | [theme]
2 | inherit = basic
3 | pygments_style = default
4 | stylesheet = css/theme.css
5 |
6 | [options]
7 | google_analytics = true
8 | mathjax_path =
9 |
--------------------------------------------------------------------------------
/docs/themes/scikit-learn-fork/toc.css:
--------------------------------------------------------------------------------
1 | ..
2 | File to ..include in a document with a big table of content, to give
3 | it 'style'
4 |
5 | .. raw:: html
6 |
7 |
38 |
--------------------------------------------------------------------------------
/docs/thumbnails/adastop_boxplots.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/adastop_boxplots.png
--------------------------------------------------------------------------------
/docs/thumbnails/chain_thumb.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/chain_thumb.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/code.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/code.png
--------------------------------------------------------------------------------
/docs/thumbnails/example_plot_atari_atlantis_vectorized_ppo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/example_plot_atari_atlantis_vectorized_ppo.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/example_plot_atari_breakout_vectorized_ppo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/example_plot_atari_breakout_vectorized_ppo.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/experiment_manager_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/experiment_manager_diagram.png
--------------------------------------------------------------------------------
/docs/thumbnails/output_9_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/output_9_3.png
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_a2c.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_a2c.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_acrobot.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_acrobot.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_apple_gold.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_apple_gold.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_atari_freeway.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_atari_freeway.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_chain.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_chain.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_dqn.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_dqn.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_gridworld.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_gridworld.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_mbqvi.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_mbqvi.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_mdqn.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_mdqn.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_montain_car.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_montain_car.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_old_gym_acrobot.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_old_gym_acrobot.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_pball.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_pball.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_ppo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_ppo.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_rooms.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_rooms.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_rs_kernel_ucbvi.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_rs_kernel_ucbvi.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_rsucbvi.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_rsucbvi.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_springcartpole.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_springcartpole.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_twinrooms.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_twinrooms.jpg
--------------------------------------------------------------------------------
/docs/thumbnails/video_plot_vi.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/docs/thumbnails/video_plot_vi.jpg
--------------------------------------------------------------------------------
/docs/user_guide.md:
--------------------------------------------------------------------------------
1 | (user_guide)=
2 |
3 |
4 | # User Guide
5 |
6 | ## Introduction
7 | Welcome to rlberry.
8 | Use rlberry's [ExperimentManager](experimentManager_page) to train, evaluate and compare RL agents.
9 | Like other popular RL libraries, rlberry also provides basic tools for plotting, multiprocessing and logging. In this user guide, we take you through the core features of rlberry and illustrate them with [examples](/auto_examples/index) and the [API documentation](/api).
10 |
11 | To run all the examples, you will need to install additional libraries such as [rlberry-scool](https://github.com/rlberry-py/rlberry-scool).
12 |
13 | The easiest way to do this is:
14 | ```none
15 | pip install rlberry[torch,extras]
16 | pip install rlberry-scool
17 | ```
18 |
19 | [rlberry-scool](https://github.com/rlberry-py/rlberry-scool):
20 | This is the repository used for teaching purposes. It mainly contains basic agents and environments, in a version that makes them easier for students to learn.
21 |
22 | You can find more details about installation [here](installation)!
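Before jumping to the quick starts, here is a minimal sketch (added for illustration; the agent, environment and evaluation settings are assumptions, not part of the original guide) of the train/evaluate workflow with ExperimentManager, using a tabular agent from rlberry-scool:

```python
# Minimal sketch (assumes rlberry-scool is installed, as recommended above).
from rlberry.manager import ExperimentManager, evaluate_agents
from rlberry_scool.agents.dynprog import ValueIterationAgent
from rlberry_scool.envs.finite import Chain

xp = ExperimentManager(
    ValueIterationAgent,                 # agent class to train
    (Chain, {}),                         # (constructor, kwargs) of the training environment
    fit_budget=1,                        # forwarded to the agent's fit()
    init_kwargs=dict(gamma=0.95),        # forwarded to the agent's constructor
    eval_kwargs=dict(eval_horizon=100),  # horizon used for Monte-Carlo evaluation (assumption)
    n_fit=2,                             # number of independent training runs
    seed=42,
)
xp.fit()                                                    # train all instances
print(evaluate_agents([xp], n_simulations=10, show=False))  # evaluation summary as a DataFrame
```

The quick starts below walk through the same workflow in detail, including plotting and comparing several agents.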
23 |
24 | You can find our quick starts here:
25 | ```{toctree}
26 | :maxdepth: 2
27 | basics/quick_start_rl/quickstart.md
28 | basics/DeepRLTutorial/TutorialDeepRL.md
29 | ```
30 |
31 | ## Set up an experiment
32 | ```{include} templates/nice_toc.md
33 | ```
34 |
35 | ```{toctree}
36 | :maxdepth: 2
37 | basics/userguide/environment.md
38 | basics/userguide/agent.md
39 | basics/userguide/experimentManager.md
40 | basics/userguide/logging.md
41 | basics/userguide/visualization.md
42 | ```
43 |
44 | ## Experimenting with Deep agents
45 | [(In construction)](https://github.com/rlberry-py/rlberry/issues/459)
46 | ## Reproducibility
47 | ```{toctree}
48 | :maxdepth: 2
49 | basics/userguide/seeding.md
50 | basics/userguide/save_load.md
51 | basics/userguide/export_training_data.md
52 | ```
53 |
54 | ## Advanced Usage
55 | ```{toctree}
56 | :maxdepth: 2
57 | basics/userguide/adastop.md
58 | basics/comparison.md
59 | basics/userguide/external_lib.md
60 | ```
61 | - Custom Agents (In construction)
62 | - Custom Environments (In construction)
63 | - Transfer Learning (In construction)
64 |
65 | # Contributing to rlberry
66 | If you want to contribute to rlberry, check out [the contribution guidelines](contributing).
67 |
--------------------------------------------------------------------------------
/docs/user_guide2.rst:
--------------------------------------------------------------------------------
1 | .. title:: User guide : contents
2 |
3 | .. _user_guide2:
4 |
5 | ==========
6 | User guide
7 | ==========
8 |
9 | .. Introduction
10 | .. ============
11 | .. Welcome to rlberry. Use rlberry's ExperimentManager (add ref) to train, evaluate and compare rl agents. In addition to
12 | .. the core ExperimentManager (add ref), rlberry provides the user with a set of bandit (add ref), tabular rl (add ref), and
13 | .. deep rl agents (add ref) as well as a wrapper for stablebaselines3 (add link, and ref) agents.
14 | .. Like other popular rl libraries, rlberry also provides basic tools for plotting, multiprocessing and logging (add refs).
15 | .. In this user guide, we take you through the core features of rlberry and illustrate them with examples (add ref) and API documentation (add ref).
16 |
17 | If you are new to rlberry, check the :ref:`Tutorials` below and :ref:`the quickstart` documentation.
18 | In the quick start, you will learn how to set up an experiment and evaluate the
19 | efficiency of different agents.
20 |
21 | For more information see :ref:`the gallery of examples`.
22 |
23 |
24 | Tutorials
25 | =========
26 |
27 | The tutorials below will present to you the main functionalities of ``rlberry`` in a few minutes.
28 |
29 |
30 |
31 |
32 | Quick start: setup an experiment and evaluate different agents
33 | --------------------------------------------------------------
34 |
35 | .. toctree::
36 | :maxdepth: 2
37 |
38 | basics/quick_start_rl/quickstart.md
39 | basics/DeepRLTutorial/TutorialDeepRL.md
40 |
41 |
42 | Agents, hyperparameter optimization and experiment setup
43 | ---------------------------------------------------------
44 |
45 | .. toctree::
46 | :maxdepth: 1
47 |
48 | basics/create_agent.rst
49 | basics/evaluate_agent.rst
50 | basics/experiment_setup.rst
51 | basics/seeding.rst
52 | basics/multiprocess.rst
53 | basics/comparison.md
54 |
55 | We also provide examples to show how to use :ref:`torch checkpointing`
56 | in rlberry and :ref:`tensorboard`.
57 |
58 | Compatibility with External Libraries
59 | =====================================
60 |
61 | We provide examples to show you how to use rlberry with:
62 |
63 | - :ref:`Gymnasium `;
64 | - :ref:`Stable Baselines `.
65 |
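As a complement to the Stable Baselines example referenced above, here is a minimal sketch (added for illustration; it mirrors examples/adastop_example.py from this repository and assumes stable-baselines3 is installed) of wrapping an SB3 algorithm with StableBaselinesAgent so it can be trained through ExperimentManager:

```python
# Sketch mirroring examples/adastop_example.py (assumes stable-baselines3 is installed).
from rlberry.envs import gym_make
from rlberry.manager import ExperimentManager
from rlberry.agents.stable_baselines import StableBaselinesAgent
from stable_baselines3 import PPO

xp = ExperimentManager(
    StableBaselinesAgent,                # rlberry wrapper around SB3 algorithms
    (gym_make, dict(id="CartPole-v1")),  # Gymnasium environment constructor and kwargs
    fit_budget=10_000,                   # training timesteps passed to SB3
    init_kwargs=dict(algo_cls=PPO, policy="MlpPolicy", verbose=0),
    n_fit=1,
    agent_name="PPO",
)
xp.fit()
```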
66 |
67 | How to contribute?
68 | ==================
69 |
70 | If you want to contribute to rlberry, check out :doc:`the contribution guidelines`.
71 |
--------------------------------------------------------------------------------
/docs/versions.rst:
--------------------------------------------------------------------------------
1 | Documentation versions
2 | ======================
3 |
4 |
5 | * `Stable `_
6 | * :ref:`Development`
7 |
8 | * For dev team: `PR preview `_
9 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | .. _examples:
2 |
3 | ===================
4 | Gallery of examples
5 | ===================
6 |
--------------------------------------------------------------------------------
/examples/adastop_example.py:
--------------------------------------------------------------------------------
1 | """
2 | ===========================================
3 | Compare PPO and A2C on Acrobot with AdaStop
4 | ===========================================
5 |
6 | This example illustrates the use of AdastopComparator, which uses adaptive multiple testing to assess whether trained agents
7 | are statistically different or not.
8 |
9 | Note that when two agents are not deemed statistically different, it can mean either that they are equally efficient,
10 | or that there have not been enough fits to assess the variability of the agents.
11 |
12 | Results in
13 |
14 | .. code-block::
15 |
16 | [INFO] 13:35: Test finished
17 | [INFO] 13:35: Results are
18 | Agent1 vs Agent2 mean Agent1 mean Agent2 mean diff std Agent 1 std Agent 2 decisions
19 | 0 A2C vs PPO -274.274 -85.068 -189.206 185.82553 2.71784 smaller
20 |
21 |
22 | """
23 |
24 | from rlberry.envs import gym_make
25 | from stable_baselines3 import A2C, PPO
26 | from rlberry.agents.stable_baselines import StableBaselinesAgent
27 | from rlberry.manager import AdastopComparator
28 |
29 | env_ctor, env_kwargs = gym_make, dict(id="Acrobot-v1")
30 |
31 | managers = [
32 | {
33 | "agent_class": StableBaselinesAgent,
34 | "train_env": (env_ctor, env_kwargs),
35 | "fit_budget": 5e4,
36 | "agent_name": "A2C",
37 | "init_kwargs": {"algo_cls": A2C, "policy": "MlpPolicy", "verbose": 1},
38 | },
39 | {
40 | "agent_class": StableBaselinesAgent,
41 | "train_env": (env_ctor, env_kwargs),
42 | "agent_name": "PPO",
43 | "fit_budget": 5e4,
44 | "init_kwargs": {"algo_cls": PPO, "policy": "MlpPolicy", "verbose": 1},
45 | },
46 | ]
47 |
48 | comparator = AdastopComparator()
49 | comparator.compare(managers)
50 | print(comparator.managers_paths)
51 |
--------------------------------------------------------------------------------
/examples/comparison_agents.py:
--------------------------------------------------------------------------------
1 | """
2 | =========================
3 | Compare Bandit Algorithms
4 | =========================
5 |
6 | This example illustrates the use of compare_agents, a function that uses multiple testing to assess whether trained agents
7 | are statistically different or not.
8 |
9 | Note that when two agents are not deemed statistically different, it can mean either that they are equally efficient,
10 | or that there have not been enough fits to assess the variability of the agents.
11 |
12 | """
13 |
14 | import numpy as np
15 |
16 | from rlberry.manager.comparison import compare_agents
17 | from rlberry.manager import AgentManager
18 | from rlberry_research.envs.bandits import BernoulliBandit
19 | from rlberry_research.agents.bandits import (
20 | IndexAgent,
21 | makeBoundedMOSSIndex,
22 | makeBoundedNPTSIndex,
23 | makeBoundedUCBIndex,
24 | makeETCIndex,
25 | )
26 |
27 | # Parameters of the problem
28 | means = np.array([0.6, 0.6, 0.6, 0.9]) # means of the arms
29 | A = len(means)
30 | T = 2000 # Horizon
31 | N = 50 # number of fits
32 |
33 | # Construction of the experiment
34 |
35 | env_ctor = BernoulliBandit
36 | env_kwargs = {"p": means}
37 |
38 |
39 | class UCBAgent(IndexAgent):
40 | name = "UCB"
41 |
42 | def __init__(self, env, **kwargs):
43 | index, _ = makeBoundedUCBIndex()
44 | IndexAgent.__init__(self, env, index, writer_extra="reward", **kwargs)
45 |
46 |
47 | class ETCAgent(IndexAgent):
48 | name = "ETC"
49 |
50 | def __init__(self, env, m=20, **kwargs):
51 | index, _ = makeETCIndex(A, m)
52 | IndexAgent.__init__(
53 | self, env, index, writer_extra="action_and_reward", **kwargs
54 | )
55 |
56 |
57 | class MOSSAgent(IndexAgent):
58 | name = "MOSS"
59 |
60 | def __init__(self, env, **kwargs):
61 | index, _ = makeBoundedMOSSIndex(T, A)
62 | IndexAgent.__init__(
63 | self, env, index, writer_extra="action_and_reward", **kwargs
64 | )
65 |
66 |
67 | class NPTSAgent(IndexAgent):
68 | name = "NPTS"
69 |
70 | def __init__(self, env, **kwargs):
71 | index, tracker_params = makeBoundedNPTSIndex()
72 | IndexAgent.__init__(
73 | self,
74 | env,
75 | index,
76 | writer_extra="reward",
77 | tracker_params=tracker_params,
78 | **kwargs,
79 | )
80 |
81 |
82 | Agents_class = [MOSSAgent, NPTSAgent, UCBAgent, ETCAgent]
83 |
84 | managers = [
85 | AgentManager(
86 | Agent,
87 | train_env=(env_ctor, env_kwargs),
88 | fit_budget=T,
89 | parallelization="process",
90 | mp_context="fork",
91 | n_fit=N,
92 | )
93 | for Agent in Agents_class
94 | ]
95 |
96 |
97 | for manager in managers:
98 | manager.fit()
99 |
100 |
101 | def eval_function(manager, eval_budget=None, agent_id=0):
102 | df = manager.get_writer_data()[agent_id]
103 | return T * np.max(means) - np.sum(df.loc[df["tag"] == "reward", "value"])
104 |
105 |
106 | print(
107 | compare_agents(managers, method="tukey_hsd", eval_function=eval_function, B=10_000)
108 | )
109 |
--------------------------------------------------------------------------------
/examples/demo_agents/README.md:
--------------------------------------------------------------------------------
1 | Illustration of rlberry agents
2 | ==============================
3 |
--------------------------------------------------------------------------------
/examples/demo_agents/demo_SAC.py:
--------------------------------------------------------------------------------
1 | """
2 | =============================
3 | SAC Soft Actor-Critic
4 | =============================
5 |
6 | This script shows how to train a SAC agent on a Pendulum environment.
7 | """
8 |
9 | import time
10 |
11 | import gymnasium as gym
12 | from rlberry_research.agents.torch.sac import SACAgent
13 | from rlberry_research.envs import Pendulum
14 | from rlberry.manager import ExperimentManager
15 |
16 |
17 | def env_ctor(env, wrap_spaces=True):
18 | return env
19 |
20 |
21 | # Setup agent parameters
22 | env_name = "Pendulum"
23 | fit_budget = int(2e5)
24 | agent_name = f"{env_name}_{fit_budget}_{int(time.time())}"
25 |
26 | # Setup environment parameters
27 | env = Pendulum()
28 | env = gym.wrappers.TimeLimit(env, max_episode_steps=200)
29 | env = gym.wrappers.RecordEpisodeStatistics(env)
30 | env_kwargs = dict(env=env)
31 |
32 | # Create agent instance
33 | xp_manager = ExperimentManager(
34 | SACAgent,
35 | (env_ctor, env_kwargs),
36 | fit_budget=fit_budget,
37 | n_fit=1,
38 | enable_tensorboard=True,
39 | agent_name=agent_name,
40 | )
41 |
42 | # Start training
43 | xp_manager.fit()
44 |
--------------------------------------------------------------------------------
/examples/demo_agents/gym_videos/openaigym.episode_batch.0.454210.stats.json:
--------------------------------------------------------------------------------
1 | {"initial_reset_timestamp": 22935.230051929, "timestamps": [22938.041554318], "episode_lengths": [71], "episode_rewards": [71.0], "episode_types": ["t"]}
2 |
--------------------------------------------------------------------------------
/examples/demo_agents/gym_videos/openaigym.manifest.0.454210.manifest.json:
--------------------------------------------------------------------------------
1 | {"stats": "openaigym.episode_batch.0.454210.stats.json", "videos": [["openaigym.video.0.454210.video000000.mp4", "openaigym.video.0.454210.video000000.meta.json"]], "env_info": {"gym_version": "0.21.0", "env_id": "CartPole-v0"}}
2 |
--------------------------------------------------------------------------------
/examples/demo_agents/gym_videos/openaigym.video.0.454210.video000000.meta.json:
--------------------------------------------------------------------------------
1 | {"episode_id": 0, "content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version n4.4.1 Copyright (c) 2000-2021 the FFmpeg developers\\nbuilt with gcc 11.1.0 (GCC)\\nconfiguration: --prefix=/usr --disable-debug --disable-static --disable-stripping --enable-amf --enable-avisynth --enable-cuda-llvm --enable-lto --enable-fontconfig --enable-gmp --enable-gnutls --enable-gpl --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libdav1d --enable-libdrm --enable-libfreetype --enable-libfribidi --enable-libgsm --enable-libiec61883 --enable-libjack --enable-libmfx --enable-libmodplug --enable-libmp3lame --enable-libopencore_amrnb --enable-libopencore_amrwb --enable-libopenjpeg --enable-libopus --enable-libpulse --enable-librav1e --enable-librsvg --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libsvtav1 --enable-libtheora --enable-libv4l2 --enable-libvidstab --enable-libvmaf --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx264 --enable-libx265 --enable-libxcb --enable-libxml2 --enable-libxvid --enable-libzimg --enable-nvdec --enable-nvenc --enable-shared --enable-version3\\nlibavutil 56. 70.100 / 56. 70.100\\nlibavcodec 58.134.100 / 58.134.100\\nlibavformat 58. 76.100 / 58. 76.100\\nlibavdevice 58. 13.100 / 58. 13.100\\nlibavfilter 7.110.100 / 7.110.100\\nlibswscale 5. 9.100 / 5. 9.100\\nlibswresample 3. 9.100 / 3. 9.100\\nlibpostproc 55. 9.100 / 55. 9.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "600x400", "-pix_fmt", "rgb24", "-framerate", "50", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "50", "/home/frost/tmpmath/rlberry/examples/demo_agents/gym_videos/openaigym.video.0.454210.video000000.mp4"]}}
2 |
--------------------------------------------------------------------------------
/examples/demo_agents/gym_videos/openaigym.video.0.454210.video000000.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/examples/demo_agents/gym_videos/openaigym.video.0.454210.video000000.mp4
--------------------------------------------------------------------------------
/examples/demo_agents/video_plot_a2c.py:
--------------------------------------------------------------------------------
1 | """
2 | ==============================================
3 | A demo of A2C algorithm in PBall2D environment
4 | ==============================================
5 | Illustration of how to set up an A2C algorithm in rlberry.
6 | The environment chosen here is the PBall2D environment.
7 |
8 | .. video:: ../../video_plot_a2c.mp4
9 | :width: 600
10 |
11 | """
12 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_a2c.jpg'
13 |
14 | from rlberry_research.agents.torch import A2CAgent
15 | from rlberry_research.envs.benchmarks.ball_exploration import PBall2D
16 | from gymnasium.wrappers import TimeLimit
17 |
18 |
19 | env = PBall2D()
20 | env = TimeLimit(env, max_episode_steps=256)
21 | n_timesteps = 50_000
22 | agent = A2CAgent(env, gamma=0.99, learning_rate=0.001)
23 | agent.fit(budget=n_timesteps)
24 |
25 | env.enable_rendering()
26 |
27 | observation, info = env.reset()
28 | for tt in range(200):
29 | action = agent.policy(observation)
30 | observation, reward, terminated, truncated, info = env.step(action)
31 | done = terminated or truncated
32 |
33 | video = env.save_video("_video/video_plot_a2c.mp4")
34 |
--------------------------------------------------------------------------------
/examples/demo_agents/video_plot_dqn.py:
--------------------------------------------------------------------------------
1 | """
2 | ===============================================
3 | A demo of DQN algorithm in CartPole environment
4 | ===============================================
5 | Illustration of how to set up a DQN algorithm in rlberry.
6 | The environment chosen here is gym's cartpole environment.
7 |
8 | As DQN can be computationally intensive and hard to tune,
9 | one can use tensorboard to visualize the training of the DQN
10 | using the following command:
11 |
12 | .. code-block:: bash
13 |
14 | tensorboard --logdir {Path(agent.writer.log_dir).parent}
15 |
16 | .. video:: ../../video_plot_dqn.mp4
17 | :width: 600
18 |
19 | """
20 |
21 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_dqn.jpg'
22 |
23 | from rlberry.envs import gym_make
24 | from torch.utils.tensorboard import SummaryWriter
25 |
26 | from rlberry_research.agents.torch.dqn import DQNAgent
27 | from rlberry.utils.logging import configure_logging
28 |
29 | from gymnasium.wrappers.rendering import RecordVideo
30 | import shutil
31 | import os
32 |
33 |
34 | configure_logging(level="INFO")
35 |
36 | env = gym_make("CartPole-v1", render_mode="rgb_array")
37 | agent = DQNAgent(env, epsilon_decay_interval=1000)
38 | agent.set_writer(SummaryWriter())
39 |
40 | print(f"Running DQN on {env}")
41 |
42 | agent.fit(budget=50)
43 | env = RecordVideo(env, "_video/temp")
44 |
45 | for episode in range(3):
46 | done = False
47 | observation, info = env.reset()
48 | while not done:
49 | action = agent.policy(observation)
50 | observation, reward, terminated, truncated, info = env.step(action)
51 | done = terminated or truncated
52 | env.close()
53 |
54 | # need to move the final result inside the folder used for documentation
55 | os.rename("_video/temp/rl-video-episode-0.mp4", "_video/video_plot_dqn.mp4")
56 | shutil.rmtree("_video/temp/")
57 |
--------------------------------------------------------------------------------
/examples/demo_agents/video_plot_mbqvi.py:
--------------------------------------------------------------------------------
1 | """
2 | ==================================================
3 | A demo of MBQVI algorithm in Gridworld environment
4 | ==================================================
5 | Illustration of how to set up an MBQVI algorithm in rlberry.
6 | The environment chosen here is the GridWorld environment.
7 |
8 | .. video:: ../../video_plot_mbqvi.mp4
9 | :width: 600
10 |
11 | """
12 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_mbqvi.jpg'
13 | from rlberry_scool.agents.mbqvi import MBQVIAgent
14 | from rlberry_research.envs.finite import GridWorld
15 |
16 | params = {}
17 | params["n_samples"] = 100 # samples per state-action pair
18 | params["gamma"] = 0.99
19 | params["horizon"] = None
20 |
21 | env = GridWorld(7, 10, walls=((2, 2), (3, 3)), success_probability=0.6)
22 | agent = MBQVIAgent(env, **params)
23 | info = agent.fit()
24 | print(info)
25 |
26 | # evaluate policy in a deterministic version of the environment
27 | env_eval = GridWorld(7, 10, walls=((2, 2), (3, 3)), success_probability=1.0)
28 | env_eval.enable_rendering()
29 | state, info = env_eval.reset()
30 | for tt in range(50):
31 | action = agent.policy(state)
32 | next_s, _, _, _, _ = env_eval.step(action)
33 | state = next_s
34 | video = env_eval.save_video("_video/video_plot_mbqvi.mp4")
35 |
--------------------------------------------------------------------------------
/examples/demo_agents/video_plot_mdqn.py:
--------------------------------------------------------------------------------
1 | """
2 | =================================================
3 | A demo of M-DQN algorithm in CartPole environment
4 | =================================================
5 | Illustration of how to set up an M-DQN algorithm in rlberry.
6 | The environment chosen here is gym's cartpole environment.
7 |
8 | As DQN can be computationally intensive and hard to tune,
9 | one can use tensorboard to visualize the training of the DQN
10 | using the following command:
11 |
12 | .. code-block:: bash
13 |
14 | tensorboard --logdir {Path(agent.writer.log_dir).parent}
15 |
16 | .. video:: ../../video_plot_mdqn.mp4
17 | :width: 600
18 |
19 | """
20 |
21 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_mdqn.jpg'
22 |
23 | from rlberry.envs import gym_make
24 | from torch.utils.tensorboard import SummaryWriter
25 |
26 | from rlberry_research.agents.torch.dqn import MunchausenDQNAgent
27 | from rlberry.utils.logging import configure_logging
28 |
29 | from gymnasium.wrappers.rendering import RecordVideo
30 | import shutil
31 | import os
32 |
33 |
34 | configure_logging(level="INFO")
35 |
36 | env = gym_make("CartPole-v1", render_mode="rgb_array")
37 | agent = MunchausenDQNAgent(env, epsilon_decay_interval=1000)
38 | agent.set_writer(SummaryWriter())
39 |
40 | print(f"Running Munchausen DQN on {env}")
41 |
42 | agent.fit(budget=10**5)
43 | env = RecordVideo(env, "_video/temp")
44 |
45 |
46 | for episode in range(3):
47 | done = False
48 | observation, info = env.reset()
49 | while not done:
50 | action = agent.policy(observation)
51 | observation, reward, terminated, truncated, info = env.step(action)
52 | done = terminated or truncated
53 | env.close()
54 |
55 | # need to move the final result inside the folder used for documentation
56 | os.rename("_video/temp/rl-video-episode-0.mp4", "_video/video_plot_mdqn.mp4")
57 | shutil.rmtree("_video/temp/")
58 |
--------------------------------------------------------------------------------
/examples/demo_agents/video_plot_ppo.py:
--------------------------------------------------------------------------------
1 | """
2 | ==============================================
3 | A demo of PPO algorithm in PBall2D environment
4 | ==============================================
5 | Illustration of how to set up a PPO algorithm in rlberry.
6 | The environment chosen here is the PBall2D environment.
7 |
8 | .. video:: ../../video_plot_ppo.mp4
9 | :width: 600
10 |
11 | """
12 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_ppo.jpg'
13 |
14 | from rlberry_research.agents.torch import PPOAgent
15 | from rlberry_research.envs.benchmarks.ball_exploration import PBall2D
16 |
17 |
18 | env = PBall2D()
19 | n_steps = 3e3
20 |
21 | agent = PPOAgent(env)
22 | agent.fit(budget=n_steps)
23 |
24 | env.enable_rendering()
25 | observation, info = env.reset()
26 | for tt in range(200):
27 | action = agent.policy(observation)
28 | observation, reward, terminated, truncated, info = env.step(action)
29 | done = terminated or truncated
30 |
31 | video = env.save_video("_video/video_plot_ppo.mp4")
32 |
--------------------------------------------------------------------------------
/examples/demo_agents/video_plot_rs_kernel_ucbvi.py:
--------------------------------------------------------------------------------
1 | """
2 | =============================================================
3 | A demo of RSKernelUCBVIAgent algorithm in Acrobot environment
4 | =============================================================
5 | Illustration of how to set up an RSKernelUCBVI algorithm in rlberry.
6 | The environment chosen here is the Acrobot environment.
7 |
8 | .. video:: ../../video_plot_rs_kernel_ucbvi.mp4
9 | :width: 600
10 |
11 | """
12 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_rs_kernel_ucbvi.jpg'
13 |
14 | from rlberry_research.envs import Acrobot
15 | from rlberry_research.agents import RSKernelUCBVIAgent
16 | from rlberry.wrappers import RescaleRewardWrapper
17 |
18 | env = Acrobot()
19 | # rescale rewards to [0, 1]
20 | env = RescaleRewardWrapper(env, (0.0, 1.0))
21 |
22 | agent = RSKernelUCBVIAgent(
23 | env,
24 | gamma=0.99,
25 | horizon=300,
26 | bonus_scale_factor=0.01,
27 | min_dist=0.2,
28 | bandwidth=0.05,
29 | beta=1.0,
30 | kernel_type="gaussian",
31 | )
32 | agent.fit(budget=500)
33 |
34 | env.enable_rendering()
35 | observation, info = env.reset()
36 |
37 | time_before_done = 0
38 | ended = False
39 | for tt in range(2 * agent.horizon):
40 | action = agent.policy(observation)
41 | observation, reward, terminated, truncated, info = env.step(action)
42 | done = terminated or truncated
43 | if not done and not ended:
44 | time_before_done += 1
45 | if done:
46 | ended = True
47 |
48 | print("steps to achieve the goal for the first time = ", time_before_done)
49 | video = env.save_video("_video/video_plot_rs_kernel_ucbvi.mp4")
50 |
--------------------------------------------------------------------------------
/examples/demo_agents/video_plot_rsucbvi.py:
--------------------------------------------------------------------------------
1 | """
2 | ======================================================
3 | A demo of RSUCBVI algorithm in MountainCar environment
4 | ======================================================
5 | Illustration of how to set up an RSUCBVI algorithm in rlberry.
6 | The environment chosen here is the MountainCar environment.
7 |
8 | .. video:: ../../video_plot_rsucbvi.mp4
9 | :width: 600
10 |
11 | """
12 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_rsucbvi.jpg'
13 |
14 | from rlberry_research.agents import RSUCBVIAgent
15 | from rlberry_research.envs.classic_control import MountainCar
16 |
17 | env = MountainCar()
18 | horizon = 170
19 | print("Running RS-UCBVI on %s" % env.name)
20 | agent = RSUCBVIAgent(env, gamma=0.99, horizon=horizon, bonus_scale_factor=0.1)
21 | agent.fit(budget=500)
22 |
23 | env.enable_rendering()
24 | observation, info = env.reset()
25 | for tt in range(200):
26 | action = agent.policy(observation)
27 | observation, reward, terminated, truncated, info = env.step(action)
28 | done = terminated or truncated
29 |
30 | video = env.save_video("_video/video_plot_rsucbvi.mp4")
31 |
--------------------------------------------------------------------------------
/examples/demo_agents/video_plot_vi.py:
--------------------------------------------------------------------------------
1 | """
2 | =======================================================
3 | A demo of ValueIteration algorithm in Chain environment
4 | =======================================================
5 | Illustration of how to set up a ValueIteration algorithm in rlberry.
6 | The environment chosen here is the Chain environment.
7 |
8 | .. video:: ../../video_plot_vi.mp4
9 | :width: 600
10 |
11 | """
12 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_vi.jpg'
13 |
14 | from rlberry_scool.agents.dynprog import ValueIterationAgent
15 | from rlberry_scool.envs.finite import Chain
16 |
17 | env = Chain()
18 | agent = ValueIterationAgent(env, gamma=0.95)
19 | info = agent.fit()
20 | print(info)
21 |
22 | env.enable_rendering()
23 | observation, info = env.reset()
24 | for tt in range(50):
25 | action = agent.policy(observation)
26 | observation, reward, terminated, truncated, info = env.step(action)
27 | done = terminated or truncated
28 | if done:
29 | break
30 | video = env.save_video("_video/video_plot_vi.mp4")
31 |
--------------------------------------------------------------------------------
/examples/demo_bandits/README.md:
--------------------------------------------------------------------------------
1 | Illustration of bandits in rlberry
2 | ==================================
3 |
--------------------------------------------------------------------------------
/examples/demo_bandits/plot_exp3_bandit.py:
--------------------------------------------------------------------------------
1 | """
2 | =============================
3 | EXP3 Bandit cumulative regret
4 | =============================
5 |
6 | This script shows how to define a bandit environment and an EXP3
7 | randomized algorithm.
8 | """
9 |
10 | import numpy as np
11 | from rlberry_research.envs.bandits import AdversarialBandit
12 | from rlberry_research.agents.bandits import (
13 | RandomizedAgent,
14 | TSAgent,
15 | makeEXP3Index,
16 | makeBetaPrior,
17 | )
18 | from rlberry.manager import ExperimentManager, plot_writer_data
19 |
20 |
21 | # Agents definition
22 |
23 |
24 | class EXP3Agent(RandomizedAgent):
25 | name = "EXP3"
26 |
27 | def __init__(self, env, **kwargs):
28 | prob, tracker_params = makeEXP3Index()
29 | RandomizedAgent.__init__(
30 | self,
31 | env,
32 | prob,
33 | writer_extra="action",
34 | tracker_params=tracker_params,
35 | **kwargs
36 | )
37 |
38 |
39 | class BernoulliTSAgent(TSAgent):
40 | """Thompson sampling for Bernoulli bandit"""
41 |
42 | name = "TS"
43 |
44 | def __init__(self, env, **kwargs):
45 | prior, _ = makeBetaPrior()
46 | TSAgent.__init__(self, env, prior, writer_extra="action", **kwargs)
47 |
48 |
49 | # Parameters of the problem
50 | T = 3000 # Horizon
51 | M = 20 # number of MC simu
52 |
53 |
54 | def switching_rewards(T, gap=0.1, rate=1.6):
55 | """Adversarially switching rewards over exponentially long phases.
56 | Inspired by Zimmert, Julian, and Yevgeny Seldin.
57 | "Tsallis-INF: An Optimal Algorithm for Stochastic and Adversarial Bandits."
58 | J. Mach. Learn. Res. 22 (2021): 28-1.
59 | """
60 | rewards = np.zeros((T, 2))
61 | t = 0
62 | exp = 1
63 | high_rewards = True
64 | for t in range(T):
65 | if t > np.floor(rate**exp):
66 | high_rewards = not high_rewards
67 | exp += 1
68 | if high_rewards:
69 | rewards[t] = [1.0 - gap, 1.0]
70 | else:
71 | rewards[t] = [0.0, gap]
72 | return rewards
73 |
74 |
75 | rewards = switching_rewards(T, rate=5.0)
76 |
77 |
78 | # Construction of the experiment
79 |
80 | env_ctor = AdversarialBandit
81 | env_kwargs = {"rewards": rewards}
82 |
83 | Agents_class = [EXP3Agent, BernoulliTSAgent]
84 |
85 | agents = [
86 | ExperimentManager(
87 | Agent,
88 | (env_ctor, env_kwargs),
89 | init_kwargs={},
90 | fit_budget=T,
91 | n_fit=M,
92 | parallelization="process",
93 | mp_context="fork",
94 | )
95 | for Agent in Agents_class
96 | ]
97 |
98 | # these parameters should give parallel computing even in notebooks
99 |
100 |
101 | # Agent training
102 | for agent in agents:
103 | agent.fit()
104 |
105 |
106 | # Compute and plot (pseudo-)regret
107 | def compute_pseudo_regret(actions):
108 | selected_rewards = np.array(
109 | [rewards[t, int(action)] for t, action in enumerate(actions)]
110 | )
111 | return np.cumsum(np.max(rewards, axis=1) - selected_rewards)
112 |
113 |
114 | output = plot_writer_data(
115 | agents,
116 | tag="action",
117 | preprocess_func=compute_pseudo_regret,
118 | title="Cumulative Pseudo-Regret",
119 | )
120 |
--------------------------------------------------------------------------------
/examples/demo_bandits/plot_ucb_bandit.py:
--------------------------------------------------------------------------------
1 | """
2 | =============================
3 | UCB Bandit cumulative regret
4 | =============================
5 |
6 | This script shows how to define a bandit environment and a UCB index-based algorithm.
7 | """
8 |
9 | import numpy as np
10 | from rlberry_research.envs.bandits import NormalBandit
11 | from rlberry_research.agents.bandits import IndexAgent, makeSubgaussianUCBIndex
12 | from rlberry.manager import ExperimentManager, plot_writer_data
13 | import matplotlib.pyplot as plt
14 |
15 |
16 | # Agents definition
17 |
18 |
19 | class UCBAgent(IndexAgent):
20 | """UCB agent for sigma-subgaussian bandits"""
21 |
22 | name = "UCB Agent"
23 |
24 | def __init__(self, env, sigma=1, **kwargs):
25 | index, _ = makeSubgaussianUCBIndex(sigma)
26 | IndexAgent.__init__(self, env, index, writer_extra="action", **kwargs)
27 |
28 |
29 | # Parameters of the problem
30 | means = np.array([0, 0.9, 1]) # means of the arms
31 | T = 3000 # Horizon
32 | M = 20 # number of MC simu
33 |
34 | # Construction of the experiment
35 |
36 | env_ctor = NormalBandit
37 | env_kwargs = {"means": means, "stds": 2 * np.ones(len(means))}
38 |
39 | xp_manager = ExperimentManager(
40 | UCBAgent,
41 | (env_ctor, env_kwargs),
42 | fit_budget=T,
43 | init_kwargs={"sigma": 2},
44 | n_fit=M,
45 | parallelization="process",
46 | mp_context="fork",
47 | )
48 | # these parameters should give parallel computing even in notebooks
49 |
50 |
51 | # Agent training
52 |
53 | xp_manager.fit()
54 |
55 |
56 | # Compute and plot (pseudo-)regret
57 | def compute_pseudo_regret(actions):
58 | return np.cumsum(np.max(means) - means[actions.astype(int)])
59 |
60 |
61 | fig = plt.figure(1, figsize=(5, 3))
62 | ax = plt.gca()
63 | output = plot_writer_data(
64 | [xp_manager],
65 | tag="action",
66 | preprocess_func=compute_pseudo_regret,
67 | title="Cumulative Pseudo-Regret",
68 | ax=ax,
69 | )
70 |
--------------------------------------------------------------------------------
/examples/demo_env/README.md:
--------------------------------------------------------------------------------
1 | Illustration of rlberry environments
2 | ====================================
3 |
--------------------------------------------------------------------------------
/examples/demo_env/video_plot_acrobot.py:
--------------------------------------------------------------------------------
1 | """
2 | ===============================================
3 | A demo of Acrobot environment with RSUCBVIAgent
4 | ===============================================
5 | Illustration of the training and video rendering of the RSUCBVI agent in the
6 | Acrobot environment.
7 |
8 | .. video:: ../../video_plot_acrobot.mp4
9 | :width: 600
10 |
11 | """
12 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_acrobot.jpg'
13 |
14 | from rlberry_research.envs import Acrobot
15 | from rlberry_research.agents import RSUCBVIAgent
16 | from rlberry.wrappers import RescaleRewardWrapper
17 |
18 | env = Acrobot()
19 | # rescale rewards to [0, 1]
20 | env = RescaleRewardWrapper(env, (0.0, 1.0))
21 | n_episodes = 300
22 | agent = RSUCBVIAgent(
23 | env, gamma=0.99, horizon=300, bonus_scale_factor=0.01, min_dist=0.25
24 | )
25 | agent.fit(budget=n_episodes)
26 |
27 | env.enable_rendering()
28 | observation, info = env.reset()
29 | for tt in range(2 * agent.horizon):
30 | action = agent.policy(observation)
31 | observation, reward, terminated, truncated, info = env.step(action)
32 | done = terminated or truncated
33 |
34 | # Save video
35 | video = env.save_video("_video/video_plot_acrobot.mp4")
36 |
--------------------------------------------------------------------------------
/examples/demo_env/video_plot_apple_gold.py:
--------------------------------------------------------------------------------
1 | """
2 | ===============================
3 | A demo of AppleGold environment
4 | ===============================
5 | Illustration of the AppleGold environment, on which we train a ValueIteration
6 | algorithm.
7 |
8 | .. video:: ../../video_plot_apple_gold.mp4
9 | :width: 600
10 |
11 | """
12 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_apple_gold.jpg'
13 | from rlberry_research.envs.benchmarks.grid_exploration.apple_gold import AppleGold
14 | from rlberry_scool.agents.dynprog import ValueIterationAgent
15 |
16 | env = AppleGold(reward_free=False, array_observation=False)
17 |
18 | agent = ValueIterationAgent(env, gamma=0.9)
19 | info = agent.fit()
20 | print(info)
21 |
22 | env.enable_rendering()
23 |
24 | observation, info = env.reset()
25 | for tt in range(5):
26 | action = agent.policy(observation)
27 | observation, reward, terminated, truncated, info = env.step(action)
28 | done = terminated or truncated
29 | if done:
30 | break
31 | env.render()
32 | video = env.save_video("_video/video_plot_apple_gold.mp4")
33 |
--------------------------------------------------------------------------------
/examples/demo_env/video_plot_chain.py:
--------------------------------------------------------------------------------
1 | """
2 | ===============================
3 | A demo of Chain environment
4 | ===============================
5 | Illustration of Chain environment
6 |
7 | .. video:: ../../video_plot_chain.mp4
8 | :width: 600
9 |
10 | """
11 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_chain.jpg'
12 |
13 |
14 | from rlberry_scool.envs.finite import Chain
15 |
16 | env = Chain(10, 0.1)
17 | env.enable_rendering()
18 | for tt in range(5):
19 | env.step(env.action_space.sample())
20 | env.render()
21 | env.save_video("_video/video_plot_chain.mp4")
22 |
--------------------------------------------------------------------------------
/examples/demo_env/video_plot_gridworld.py:
--------------------------------------------------------------------------------
1 | """
2 | .. _gridworld_example:
3 |
4 | ========================================================
5 | A demo of Gridworld environment with ValueIterationAgent
6 | ========================================================
7 | Illustration of the training and video rendering of the ValueIteration agent in
8 | the GridWorld environment.
9 |
10 | .. video:: ../../video_plot_gridworld.mp4
11 | :width: 600
12 | """
13 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_gridworld.jpg'
14 |
15 | from rlberry_scool.agents.dynprog import ValueIterationAgent
16 | from rlberry_scool.envs.finite import GridWorld
17 |
18 |
19 | env = GridWorld(7, 10, walls=((2, 2), (3, 3)))
20 |
21 | agent = ValueIterationAgent(env, gamma=0.95)
22 | info = agent.fit()
23 | print(info)
24 |
25 | env.enable_rendering()
26 | observation, info = env.reset()
27 | for tt in range(50):
28 | action = agent.policy(observation)
29 | observation, reward, terminated, truncated, info = env.step(action)
30 | done = terminated or truncated
31 | if done:
32 | # Warning: this will never happen in the present case because there is no terminal state.
33 | # See the doc of GridWorld for more informations on the default parameters of GridWorld.
34 | break
35 | # Save the video
36 | env.save_video("_video/video_plot_gridworld.mp4", framerate=10)
37 |
--------------------------------------------------------------------------------
/examples/demo_env/video_plot_mountain_car.py:
--------------------------------------------------------------------------------
1 | """
2 | =================================
3 | A demo of MountainCar environment
4 | =================================
5 | Illustration of the MountainCar environment
6 |
7 | .. video:: ../../video_plot_montain_car.mp4
8 | :width: 600
9 |
10 | """
11 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_montain_car.jpg'
12 |
13 | from rlberry_scool.agents.mbqvi import MBQVIAgent
14 | from rlberry_research.envs.classic_control import MountainCar
15 | from rlberry.wrappers import DiscretizeStateWrapper
16 |
17 | _env = MountainCar()
18 | env = DiscretizeStateWrapper(_env, 20)
19 | agent = MBQVIAgent(env, n_samples=40, gamma=0.99)
20 | agent.fit()
21 |
22 | env.enable_rendering()
23 | observation, info = env.reset()
24 | for tt in range(200):
25 | action = agent.policy(observation)
26 | observation, reward, terminated, truncated, info = env.step(action)
27 | done = terminated or truncated
28 |
29 | video = env.save_video("_video/video_plot_montain_car.mp4")
30 |
--------------------------------------------------------------------------------
/examples/demo_env/video_plot_old_gym_compatibility_wrapper_old_acrobot.py:
--------------------------------------------------------------------------------
1 | """
2 | =================================================================
3 | A demo of OldGymCompatibilityWrapper with old_Acrobot environment
4 | =================================================================
5 | Illustration of the wrapper for old environments (old Acrobot).
6 |
7 | .. video:: ../../video_plot_old_gym_acrobot.mp4
8 | :width: 600
9 |
10 | """
11 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_old_gym_acrobot.jpg'
12 |
13 |
14 | from rlberry.wrappers.tests.old_env.old_acrobot import Old_Acrobot
15 | from rlberry_research.agents import RSUCBVIAgent
16 | from rlberry.wrappers import RescaleRewardWrapper
17 | from rlberry.wrappers.gym_utils import OldGymCompatibilityWrapper
18 |
19 | env = Old_Acrobot()
20 | env = OldGymCompatibilityWrapper(env)
21 | env = RescaleRewardWrapper(env, (0.0, 1.0))
22 | n_episodes = 300
23 | agent = RSUCBVIAgent(
24 | env, gamma=0.99, horizon=300, bonus_scale_factor=0.01, min_dist=0.25
25 | )
26 | result = env.reset(seed=42)
27 |
28 | agent.fit(budget=n_episodes)
29 |
30 | env.enable_rendering()
31 | observation, info = env.reset()
32 | for tt in range(2 * agent.horizon):
33 | action = agent.policy(observation)
34 | observation, reward, terminated, truncated, info = env.step(action)
35 | done = terminated or truncated
36 |
37 | # Save video
38 | video = env.save_video("_video/video_plot_old_gym_acrobot.mp4")
39 |
--------------------------------------------------------------------------------
/examples/demo_env/video_plot_pball.py:
--------------------------------------------------------------------------------
1 | """
2 | ===============================
3 | A demo of PBALL2D environment
4 | ===============================
5 | Illustration of PBall2D environment
6 |
7 | .. video:: ../../video_plot_pball.mp4
8 | :width: 600
9 |
10 | """
11 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_pball.jpg'
12 |
13 | import numpy as np
14 | from rlberry_research.envs.benchmarks.ball_exploration import PBall2D
15 |
16 | p = 5
17 | A = np.array([[1.0, 0.1], [-0.1, 1.0]])
18 |
19 | reward_amplitudes = np.array([1.0, 0.5, 0.5])
20 | reward_smoothness = np.array([0.25, 0.25, 0.25])
21 |
22 | reward_centers = [
23 | np.array([0.75 * np.cos(np.pi / 2), 0.75 * np.sin(np.pi / 2)]),
24 | np.array([0.75 * np.cos(np.pi / 6), 0.75 * np.sin(np.pi / 6)]),
25 | np.array([0.75 * np.cos(5 * np.pi / 6), 0.75 * np.sin(5 * np.pi / 6)]),
26 | ]
27 |
28 | action_list = [
29 | 0.1 * np.array([1, 0]),
30 | -0.1 * np.array([1, 0]),
31 | 0.1 * np.array([0, 1]),
32 | -0.1 * np.array([0, 1]),
33 | ]
34 |
35 | env = PBall2D(
36 | p=p,
37 | A=A,
38 | reward_amplitudes=reward_amplitudes,
39 | reward_centers=reward_centers,
40 | reward_smoothness=reward_smoothness,
41 | action_list=action_list,
42 | )
43 |
44 | env.enable_rendering()
45 |
46 | for ii in range(5):
47 | env.step(1)
48 | env.step(3)
49 |
50 | env.render()
51 | video = env.save_video("_video/video_plot_pball.mp4")
52 |
--------------------------------------------------------------------------------
/examples/demo_env/video_plot_rooms.py:
--------------------------------------------------------------------------------
1 | """
2 | ===============================
3 | A demo of rooms environment
4 | ===============================
5 | Illustration of NRooms environment
6 |
7 | .. video:: ../../video_plot_rooms.mp4
8 | :width: 600
9 |
10 | """
11 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_rooms.jpg'
12 |
13 | from rlberry_research.envs.benchmarks.grid_exploration.nroom import NRoom
14 | from rlberry_scool.agents.dynprog import ValueIterationAgent
15 |
16 | env = NRoom(
17 | nrooms=9,
18 | remove_walls=False,
19 | room_size=9,
20 | initial_state_distribution="center",
21 | include_traps=True,
22 | )
23 | horizon = env.observation_space.n
24 |
25 | agent = ValueIterationAgent(env, gamma=0.999, horizon=horizon)
26 | print("fitting...")
27 | info = agent.fit()
28 | print(info)
29 |
30 | env.enable_rendering()
31 |
32 | for _ in range(10):
33 | observation, info = env.reset()
34 | for tt in range(horizon):
35 | # action = agent.policy(observation)
36 | action = env.action_space.sample()
37 | observation, reward, terminated, truncated, info = env.step(action)
38 | done = terminated or truncated
39 | if done:
40 | break
41 | env.render()
42 | video = env.save_video("_video/video_plot_rooms.mp4")
43 |
--------------------------------------------------------------------------------
/examples/demo_env/video_plot_springcartpole.py:
--------------------------------------------------------------------------------
1 | """
2 | ==================================================
3 | A demo of SpringCartPole environment with DQNAgent
4 | ==================================================
5 | Illustration of the training and video rendering of DQN Agent in
6 | SpringCartPole environment.
7 |
8 | The agent is only lightly tuned and not optimal; this is just for illustration purposes.
9 |
10 | .. video:: ../../video_plot_springcartpole.mp4
11 | :width: 600
12 |
13 | """
14 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_springcartpole.jpg'
15 |
16 | from rlberry_research.envs.classic_control import SpringCartPole
17 | from rlberry_research.agents.torch import DQNAgent
18 | from gymnasium.wrappers.time_limit import TimeLimit
19 |
20 | model_configs = {
21 | "type": "MultiLayerPerceptron",
22 | "layer_sizes": (256, 256),
23 | "reshape": False,
24 | }
25 |
26 | init_kwargs = dict(
27 | q_net_constructor="rlberry_research.agents.torch.utils.training.model_factory_from_env",
28 | q_net_kwargs=model_configs,
29 | )
30 |
31 | env = SpringCartPole(obs_trans=False, swing_up=True)
32 | env = TimeLimit(env, max_episode_steps=500)
33 | agent = DQNAgent(env, **init_kwargs)
34 | agent.fit(budget=1e5)
35 |
36 | env.enable_rendering()
37 | observation, info = env.reset()
38 |
39 | for tt in range(1000):
40 | action = agent.policy(observation)
41 | observation, reward, terminated, truncated, info = env.step(action)
42 | done = terminated or truncated
43 | if done:
44 | observation, info = env.reset()
45 |
46 | # Save video
47 | video = env.save_video("_video/video_plot_springcartpole.mp4")
48 |
--------------------------------------------------------------------------------
/examples/demo_env/video_plot_twinrooms.py:
--------------------------------------------------------------------------------
1 | """
2 | ===============================
3 | A demo of twinrooms environment
4 | ===============================
5 | Illustration of TwinRooms environment
6 |
7 | .. video:: ../../video_plot_twinrooms.mp4
8 | :width: 600
9 |
10 | """
11 | # sphinx_gallery_thumbnail_path = 'thumbnails/video_plot_twinrooms.jpg'
12 |
13 | from rlberry_research.envs.benchmarks.generalization.twinrooms import TwinRooms
14 | from rlberry_scool.agents.mbqvi import MBQVIAgent
15 | from rlberry.wrappers.discretize_state import DiscretizeStateWrapper
16 | from rlberry.seeding import Seeder
17 |
18 | seeder = Seeder(123)
19 |
20 | env = TwinRooms()
21 | env = DiscretizeStateWrapper(env, n_bins=20)
22 | env.reseed(seeder)
23 | horizon = 20
24 | agent = MBQVIAgent(env, n_samples=10, gamma=1.0, horizon=horizon)
25 | agent.reseed(seeder)
26 | agent.fit()
27 |
28 | observation, info = env.reset()
29 | env.enable_rendering()
30 | for ii in range(10):
31 | action = agent.policy(observation)
32 | observation, reward, terminated, truncated, info = env.step(action)
33 | done = terminated or truncated
34 |
35 | if (ii + 1) % horizon == 0:
36 | observation, info = env.reset()
37 |
38 | env.render()
39 | video = env.save_video("_video/video_plot_twinrooms.mp4")
40 |
--------------------------------------------------------------------------------
/examples/demo_experiment/params_experiment.yaml:
--------------------------------------------------------------------------------
1 | # """
2 | # =====================
3 | # Demo: params_experiment.yaml
4 | # =====================
5 | # """
6 | description: 'RSUCBVI in NRoom'
7 | seed: 123
8 | train_env: 'examples/demo_experiment/room.yaml'
9 | eval_env: 'examples/demo_experiment/room.yaml'
10 | agents:
11 | - 'examples/demo_experiment/rsucbvi.yaml'
12 | - 'examples/demo_experiment/rsucbvi_alternative.yaml'
13 |
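A minimal sketch of consuming a config like the one above programmatically, assuming the `parse_experiment_config` helper exported by `rlberry.experiment` (the same one used by `rlberry/experiment/generator.py` further below); the keyword values are arbitrary:

from pathlib import Path
from rlberry.experiment import parse_experiment_config

# One entry is yielded per file listed under 'agents:' in the YAML above.
config = Path("examples/demo_experiment/params_experiment.yaml")
for _, manager_kwargs in parse_experiment_config(config, n_fit=2, parallelization="thread"):
    print(sorted(manager_kwargs.keys()))  # kwargs later passed to ExperimentManager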
--------------------------------------------------------------------------------
/examples/demo_experiment/room.yaml:
--------------------------------------------------------------------------------
1 | # """
2 | # =====================
3 | # Demo: room.yaml
4 | # =====================
5 | # """
6 | constructor: 'rlberry_research.envs.benchmarks.grid_exploration.nroom.NRoom'
7 | params:
8 | reward_free: false
9 | array_observation: true
10 | nrooms: 5
11 |
--------------------------------------------------------------------------------
/examples/demo_experiment/rsucbvi.yaml:
--------------------------------------------------------------------------------
1 | # """
2 | # =====================
3 | # Demo: rsucbvi.yaml
4 | # =====================
5 | # """
6 | agent_class: 'rlberry_research.agents.kernel_based.rs_ucbvi.RSUCBVIAgent'
7 | init_kwargs:
8 | gamma: 1.0
9 | lp_metric: 2
10 | min_dist: 0.0
11 | max_repr: 800
12 | bonus_scale_factor: 1.0
13 | reward_free: True
14 | horizon: 50
15 | eval_kwargs:
16 | eval_horizon: 50
17 | fit_kwargs:
18 | fit_budget: 100
19 |
--------------------------------------------------------------------------------
/examples/demo_experiment/rsucbvi_alternative.yaml:
--------------------------------------------------------------------------------
1 | # """
2 | # =====================
3 | # Demo: rsucbvi_alternative.yaml
4 | # =====================
5 | # """
6 | base_config: 'examples/demo_experiment/rsucbvi.yaml'
7 | init_kwargs:
8 | gamma: 0.9
9 |
--------------------------------------------------------------------------------
/examples/demo_experiment/run.py:
--------------------------------------------------------------------------------
1 | """
2 | =====================
3 | Demo: run
4 | =====================
5 | To run the experiment:
6 |
7 | $ python examples/demo_experiment/run.py examples/demo_experiment/params_experiment.yaml
8 |
9 | To see more options:
10 |
11 | $ python examples/demo_experiment/run.py
12 | """
13 |
14 | from rlberry.experiment import load_experiment_results
15 | from rlberry.experiment import experiment_generator
16 | from rlberry.manager.multiple_managers import MultipleManagers
17 |
18 |
19 | if __name__ == "__main__":
20 | multimanagers = MultipleManagers(parallelization="thread")
21 |
22 | for experiment_manager in experiment_generator():
23 | multimanagers.append(experiment_manager)
24 |
25 | multimanagers.run()
26 | multimanagers.save()
27 |
28 | # Reading the results
29 | del multimanagers
30 |
31 | data = load_experiment_results("results", "params_experiment")
32 |
33 | print(data)
34 |
35 | # Fit one of the managers for a few more episodes
36 | # If tensorboard is enabled, you should see more episodes run for 'rsucbvi_alternative'
37 | data["manager"]["rsucbvi_alternative"].fit(50)
38 |
--------------------------------------------------------------------------------
/examples/plot_kernels.py:
--------------------------------------------------------------------------------
1 | """
2 | =====================
3 | Plot kernel functions
4 | =====================
5 |
6 | This script requires matplotlib
7 | """
8 |
9 | import matplotlib.pyplot as plt
10 | import numpy as np
11 | from rlberry_research.agents.kernel_based.kernels import kernel_func
12 |
13 | kernel_types = [
14 | "uniform",
15 | "triangular",
16 | "gaussian",
17 | "epanechnikov",
18 | "quartic",
19 | "triweight",
20 | "tricube",
21 | "cosine",
22 | "exp-4",
23 | ]
24 |
25 | z = np.linspace(-2, 2, 100)
26 |
27 |
28 | fig, axes = plt.subplots(1, len(kernel_types), figsize=(15, 5))
29 | for ii, k_type in enumerate(kernel_types):
30 | kernel_vals = kernel_func(z, k_type)
31 | axes[ii].plot(z, kernel_vals)
32 | axes[ii].set_title(k_type)
33 | plt.show()
34 |
--------------------------------------------------------------------------------
/examples/plot_writer_wrapper.py:
--------------------------------------------------------------------------------
1 | """
2 | ==============================================
3 | Record reward during training and then plot it
4 | ==============================================
5 |
6 | This script shows how to modify an agent to easily record reward or action
7 | during the fit of the agent and then use the plot utils.
8 |
9 | .. note::
10 | If you already ran this script once, the fitted agent has been saved
11 | in rlberry_data folder. Then, you can comment-out the line
12 |
13 | .. code-block:: python
14 |
15 | xp_manager.fit(budget=10)
16 |
17 | and avoid fitting the agent one more time; the statistics from the last
18 | time you fitted the agent will be loaded automatically. See the
19 | `rlberry.manager.plot_writer_data` documentation for more information.
20 | """
21 |
22 |
23 | import numpy as np
24 |
25 | from rlberry_scool.envs import GridWorld
26 | from rlberry.manager import plot_writer_data, ExperimentManager
27 | from rlberry_scool.agents import UCBVIAgent
28 | import matplotlib.pyplot as plt
29 |
30 | # We configure the agent's default writer to record rewards
31 | # by passing writer_extra="reward" to the agent below.
32 |
33 |
34 | class VIAgent(UCBVIAgent):
35 | name = "UCBVIAgent"
36 |
37 | def __init__(self, env, **kwargs):
38 | UCBVIAgent.__init__(self, env, writer_extra="reward", horizon=50, **kwargs)
39 |
40 |
41 | env_ctor = GridWorld
42 | env_kwargs = dict(
43 | nrows=3,
44 | ncols=10,
45 | reward_at={(1, 1): 0.1, (2, 9): 1.0},
46 | walls=((1, 4), (2, 4), (1, 5)),
47 | success_probability=0.7,
48 | )
49 |
50 | env = env_ctor(**env_kwargs)
51 | xp_manager = ExperimentManager(VIAgent, (env_ctor, env_kwargs), fit_budget=10, n_fit=3)
52 |
53 | xp_manager.fit(budget=10)
54 | # comment the line above if you only want to load data from rlberry_data.
55 |
56 |
57 | # We use the following preprocessing function to plot the cumulative reward.
58 | def compute_reward(rewards):
59 | return np.cumsum(rewards)
60 |
61 |
62 | # Plot of the cumulative reward.
63 | output = plot_writer_data(
64 | xp_manager, tag="reward", preprocess_func=compute_reward, title="Cumulative Reward"
65 | )
66 | # The output covers 500 global steps: fit_budget=10 episodes times horizon=50 steps.
67 |
68 | # Log-Log plot :
69 | fig, ax = plt.subplots(1, 1)
70 | plot_writer_data(
71 | xp_manager,
72 | tag="reward",
73 | preprocess_func=compute_reward,
74 | title="Cumulative Reward",
75 | ax=ax,
76 | show=False, # necessary to customize axes
77 | )
78 | ax.set_xlim(100, 500)
79 | ax.relim()
80 | ax.set_xscale("log")
81 | ax.set_yscale("log")
82 |
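The `preprocess_func` hook is not limited to a cumulative sum; any array-to-array transform can be plotted the same way. A minimal sketch, reusing `np`, `xp_manager` and `plot_writer_data` from the script above (the window size is an arbitrary choice):

# Sketch: plot a moving average of the recorded rewards instead of their cumulative sum.
def smooth_reward(rewards, window=25):
    kernel = np.ones(window) / window
    return np.convolve(rewards, kernel, mode="same")


plot_writer_data(
    xp_manager,
    tag="reward",
    preprocess_func=smooth_reward,
    title="Smoothed reward (moving average)",
)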
--------------------------------------------------------------------------------
/rlberry/__init__.py:
--------------------------------------------------------------------------------
1 | from importlib import metadata
2 |
3 | __version__ = metadata.version("rlberry")
4 |
5 | import logging
6 |
7 | logger = logging.getLogger("rlberry_logger")
8 |
9 | from rlberry.utils.logging import configure_logging
10 |
11 |
12 | __path__ = __import__("pkgutil").extend_path(__path__, __name__)
13 |
14 | # Initialize logging level
15 | configure_logging(level="INFO")
16 |
17 |
18 | # public names exported by the package
19 |
20 | __all__ = ["__version__", "logger"]
21 |
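The package configures an INFO-level logger at import time. A minimal sketch of changing the verbosity from user code, assuming only the `level` keyword shown above:

import rlberry
from rlberry.utils.logging import configure_logging

configure_logging(level="DEBUG")  # switch from the default "INFO" set at import time
rlberry.logger.debug("debug messages are now visible")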
--------------------------------------------------------------------------------
/rlberry/agents/__init__.py:
--------------------------------------------------------------------------------
1 | # Interfaces
2 | from .agent import Agent
3 | from .agent import AgentWithSimplePolicy
4 | from .agent import AgentTorch
5 |
--------------------------------------------------------------------------------
/rlberry/agents/stable_baselines/__init__.py:
--------------------------------------------------------------------------------
1 | from .stable_baselines import StableBaselinesAgent
2 |
--------------------------------------------------------------------------------
/rlberry/agents/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/rlberry/agents/tests/__init__.py
--------------------------------------------------------------------------------
/rlberry/agents/tests/test_stable_baselines.py:
--------------------------------------------------------------------------------
1 | import tempfile
2 |
3 | from stable_baselines3 import A2C
4 |
5 | from rlberry.envs import gym_make
6 | from rlberry.agents.stable_baselines import StableBaselinesAgent
7 | from rlberry.utils.check_agent import check_rl_agent
8 |
9 |
10 | def test_sb3_agent():
11 | # Test only one algorithm per action space type
12 | check_rl_agent(
13 | StableBaselinesAgent,
14 | env=(gym_make, {"id": "Pendulum-v1"}),
15 | init_kwargs={"algo_cls": A2C, "policy": "MlpPolicy", "verbose": 1},
16 | )
17 | check_rl_agent(
18 | StableBaselinesAgent,
19 | env=(gym_make, {"id": "CartPole-v1"}),
20 | init_kwargs={"algo_cls": A2C, "policy": "MlpPolicy", "verbose": 1},
21 | )
22 |
23 |
24 | def test_sb3_tensorboard_log():
25 | # Test tensorboard support
26 | with tempfile.TemporaryDirectory() as tmpdir:
27 | check_rl_agent(
28 | StableBaselinesAgent,
29 | env=(gym_make, {"id": "Pendulum-v1"}),
30 | init_kwargs={
31 | "algo_cls": A2C,
32 | "policy": "MlpPolicy",
33 | "verbose": 1,
34 | "tensorboard_log": tmpdir,
35 | },
36 | )
37 |
--------------------------------------------------------------------------------
/rlberry/agents/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/rlberry/agents/utils/__init__.py
--------------------------------------------------------------------------------
/rlberry/check_packages.py:
--------------------------------------------------------------------------------
1 | # Define import flags
2 |
3 | TORCH_INSTALLED = True
4 | try:
5 | import torch
6 | except ModuleNotFoundError: # pragma: no cover
7 | TORCH_INSTALLED = False # pragma: no cover
8 |
9 | TENSORBOARD_INSTALLED = True
10 | try:
11 | import torch.utils.tensorboard
12 | except ModuleNotFoundError: # pragma: no cover
13 | TENSORBOARD_INSTALLED = False # pragma: no cover
14 |
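These module-level flags let callers degrade gracefully when an optional dependency is missing (seeding.py below uses TORCH_INSTALLED this way). A minimal sketch of the same guard pattern in user code:

from rlberry import check_packages

if check_packages.TORCH_INSTALLED:
    import torch  # only imported when the optional dependency is present

    device = "cuda" if torch.cuda.is_available() else "cpu"
else:
    device = None  # torch-free code path
print("torch available:", check_packages.TORCH_INSTALLED, "device:", device)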
--------------------------------------------------------------------------------
/rlberry/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .gym_make import gym_make, atari_make
2 | from .basewrapper import Wrapper
3 | from .interface import Model
4 | from .pipeline import PipelineEnv
5 | from .finite_mdp import FiniteMDP
6 |
--------------------------------------------------------------------------------
/rlberry/envs/interface/__init__.py:
--------------------------------------------------------------------------------
1 | from .model import Model
2 |
--------------------------------------------------------------------------------
/rlberry/envs/pipeline.py:
--------------------------------------------------------------------------------
1 | from rlberry.envs import Wrapper
2 |
3 |
4 | class PipelineEnv(Wrapper):
5 | """
6 | Environment defined as a pipeline of wrappers and an environment to wrap.
7 |
8 | Parameters
9 | ----------
10 | env_ctor: environment class
11 |
12 | env_kwargs: dictionary
13 | kwargs fed to the environment
14 |
15 | wrappers: list of tuple (wrapper, wrapper_kwargs)
16 | list of tuple (wrapper, wrapper_kwargs) to be applied to the environment.
17 | The list [wrapper1, wrapper2] will be applied in the order wrapper1(wrapper2(env))
18 |
19 | Examples
20 | --------
21 | >>> from rlberry.envs import PipelineEnv
22 | >>> from rlberry.envs import gym_make
23 | >>> from rlberry.wrappers import RescaleRewardWrapper
24 | >>>
25 | >>> env_ctor, env_kwargs = PipelineEnv, {
26 | >>> "env_ctor": gym_make,
27 | >>> "env_kwargs": {"id": "Acrobot-v1"},
28 | >>> "wrappers": [(RescaleRewardWrapper, {"reward_range": (0, 1)})],
29 | >>> }
30 | >>> eval_env = (gym_make, {"id":"Acrobot-v1"}) # unscaled env for evaluation
31 |
32 | """
33 |
34 | def __init__(self, env_ctor, env_kwargs, wrappers):
35 | env = env_ctor(**env_kwargs)
36 | for wrapper in wrappers[::-1]:
37 | env = wrapper[0](env, **wrapper[1])
38 | env.reset()
39 | Wrapper.__init__(self, env)
40 |
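The Examples section above builds a constructor/kwargs pair for later use; PipelineEnv can also be instantiated directly. A minimal sketch under that assumption, using the same gym_make and RescaleRewardWrapper, where the first wrapper in the list ends up outermost:

from rlberry.envs import PipelineEnv, gym_make
from rlberry.wrappers import RescaleRewardWrapper

env = PipelineEnv(
    env_ctor=gym_make,
    env_kwargs={"id": "Acrobot-v1"},
    wrappers=[(RescaleRewardWrapper, {"reward_range": (0, 1)})],  # first entry is outermost
)
observation, info = env.reset()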
--------------------------------------------------------------------------------
/rlberry/envs/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/rlberry/envs/tests/__init__.py
--------------------------------------------------------------------------------
/rlberry/envs/tests/test_env_seeding.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import rlberry.seeding as seeding
4 |
5 | from copy import deepcopy
6 | from rlberry_research.envs.classic_control import MountainCar, Pendulum
7 | from rlberry_scool.envs.finite import Chain
8 | from rlberry_scool.envs.finite import GridWorld
9 | from rlberry_research.envs.benchmarks.grid_exploration.four_room import FourRoom
10 | from rlberry_research.envs.benchmarks.grid_exploration.six_room import SixRoom
11 | from rlberry_research.envs.benchmarks.grid_exploration.apple_gold import AppleGold
12 | from rlberry_research.envs.benchmarks.ball_exploration import PBall2D, SimplePBallND
13 |
14 | classes = [
15 | MountainCar,
16 | GridWorld,
17 | Chain,
18 | PBall2D,
19 | SimplePBallND,
20 | Pendulum,
21 | FourRoom,
22 | SixRoom,
23 | AppleGold,
24 | ]
25 |
26 |
27 | def get_env_trajectory(env, horizon):
28 | states = []
29 | ss, info = env.reset()
30 | for ii in range(horizon):
31 | states.append(ss)
32 | ss, _, _, _, _ = env.step(env.action_space.sample())
33 | return states
34 |
35 |
36 | def compare_trajectories(traj1, traj2):
37 | """
38 | Return True if the trajectories are equal.
39 | """
40 | for ss1, ss2 in zip(traj1, traj2):
41 | if not np.array_equal(ss1, ss2):
42 | return False
43 | return True
44 |
45 |
46 | @pytest.mark.parametrize("ModelClass", classes)
47 | def test_env_seeding(ModelClass):
48 | env1 = ModelClass()
49 | seeder1 = seeding.Seeder(123)
50 | env1.reseed(seeder1)
51 |
52 | env2 = ModelClass()
53 | seeder2 = seeder1.spawn()
54 | env2.reseed(seeder2)
55 |
56 | env3 = ModelClass()
57 | seeder3 = seeding.Seeder(123)
58 | env3.reseed(seeder3)
59 |
60 | env4 = ModelClass()
61 | seeder4 = seeding.Seeder(123)
62 | env4.reseed(seeder4)
63 |
64 | env5 = ModelClass()
65 | env5.reseed(
66 | seeder1
67 | ) # same seeder as env1, but different trajectories. This is expected.
68 |
69 | seeding.safe_reseed(env4, seeder4)
70 |
71 | if deepcopy(env1).is_online():
72 | traj1 = get_env_trajectory(env1, 500)
73 | traj2 = get_env_trajectory(env2, 500)
74 | traj3 = get_env_trajectory(env3, 500)
75 | traj4 = get_env_trajectory(env4, 500)
76 | traj5 = get_env_trajectory(env5, 500)
77 |
78 | assert not compare_trajectories(traj1, traj2)
79 | assert compare_trajectories(traj1, traj3)
80 | assert not compare_trajectories(traj3, traj4)
81 | assert not compare_trajectories(traj1, traj5)
82 |
83 |
84 | @pytest.mark.parametrize("ModelClass", classes)
85 | def test_copy_reseeding(ModelClass):
86 | seeder = seeding.Seeder(123)
87 | env = ModelClass()
88 | env.reseed(seeder)
89 |
90 | c_env = deepcopy(env)
91 | c_env.reseed()
92 |
93 | if deepcopy(env).is_online():
94 | traj1 = get_env_trajectory(env, 500)
95 | traj2 = get_env_trajectory(c_env, 500)
96 | assert not compare_trajectories(traj1, traj2)
97 |
--------------------------------------------------------------------------------
/rlberry/envs/tests/test_gym_env_seeding.py:
--------------------------------------------------------------------------------
1 | from rlberry.seeding.seeding import safe_reseed
2 | import gymnasium as gym
3 | import numpy as np
4 | import pytest
5 | from rlberry.seeding import Seeder
6 | from rlberry.envs import gym_make
7 |
8 | from copy import deepcopy
9 |
10 | gym_envs = [
11 | "Acrobot-v1",
12 | "CartPole-v1",
13 | "MountainCar-v0",
14 | ]
15 |
16 |
17 | def get_env_trajectory(env, horizon):
18 | states = []
19 | observation, info = env.reset()
20 | for ii in range(horizon):
21 | states.append(observation)
22 | observation, _, terminated, truncated, _ = env.step(env.action_space.sample())
23 | done = terminated or truncated
24 | if done:
25 | observation, info = env.reset()
26 | return states
27 |
28 |
29 | def compare_trajectories(traj1, traj2):
30 | for ss1, ss2 in zip(traj1, traj2):
31 | if not np.array_equal(ss1, ss2):
32 | return False
33 | return True
34 |
35 |
36 | @pytest.mark.parametrize("env_name", gym_envs)
37 | def test_env_seeding(env_name):
38 | seeder1 = Seeder(123)
39 | env1 = gym_make(env_name, module_import="numpy")
40 | env1.reseed(seeder1)
41 |
42 | seeder2 = Seeder(456)
43 | env2 = gym_make(env_name)
44 | env2.reseed(seeder2)
45 |
46 | seeder3 = Seeder(123)
47 | env3 = gym_make(env_name)
48 | env3.reseed(seeder3)
49 |
50 | if deepcopy(env1).is_online():
51 | traj1 = get_env_trajectory(env1, 500)
52 | traj2 = get_env_trajectory(env2, 500)
53 | traj3 = get_env_trajectory(env3, 500)
54 |
55 | assert not compare_trajectories(traj1, traj2)
56 | assert compare_trajectories(traj1, traj3)
57 |
58 |
59 | @pytest.mark.parametrize("env_name", gym_envs)
60 | def test_copy_reseeding(env_name):
61 | seeder = Seeder(123)
62 | env = gym_make(env_name)
63 | env.reseed(seeder)
64 |
65 | c_env = deepcopy(env)
66 | c_env.reseed()
67 |
68 | if deepcopy(env).is_online():
69 | traj1 = get_env_trajectory(env, 500)
70 | traj2 = get_env_trajectory(c_env, 500)
71 | assert not compare_trajectories(traj1, traj2)
72 |
73 |
74 | @pytest.mark.parametrize("env_name", gym_envs)
75 | def test_gym_safe_reseed(env_name):
76 | seeder = Seeder(123)
77 | seeder_aux = Seeder(123)
78 |
79 | env1 = gym.make(env_name)
80 | env2 = gym.make(env_name)
81 | env3 = gym.make(env_name)
82 |
83 | safe_reseed(env1, seeder)
84 | safe_reseed(env2, seeder)
85 | safe_reseed(env3, seeder_aux)
86 |
87 | traj1 = get_env_trajectory(env1, 500)
88 | traj2 = get_env_trajectory(env2, 500)
89 | traj3 = get_env_trajectory(env3, 500)
90 | assert not compare_trajectories(traj1, traj2)
91 | assert compare_trajectories(traj1, traj3)
92 |
--------------------------------------------------------------------------------
/rlberry/envs/tests/test_gym_make.py:
--------------------------------------------------------------------------------
1 | from rlberry.envs.gym_make import atari_make
2 | import gymnasium as gym
3 | import ale_py
4 |
5 | gym.register_envs(ale_py)
6 |
7 |
8 | def test_atari_make():
9 | wrappers_dict = dict(terminal_on_life_loss=True, frame_skip=8)
10 | env = atari_make(
11 | "ALE/Freeway-v5", render_mode="rgb_array", atari_SB3_wrappers_dict=wrappers_dict
12 | )
13 | assert "EpisodicLifeEnv" in str(env)
14 | assert "MaxAndSkipEnv" in str(env)
15 | assert "ClipRewardEnv" in str(env)
16 | assert env.render_mode == "rgb_array"
17 |
18 | wrappers_dict2 = dict(terminal_on_life_loss=False, frame_skip=0)
19 | env2 = atari_make(
20 | "ALE/Breakout-v5", render_mode="human", atari_SB3_wrappers_dict=wrappers_dict2
21 | )
22 | assert "EpisodicLifeEnv" not in str(env2)
23 | assert "MaxAndSkipEnv" not in str(env2)
24 | assert "ClipRewardEnv" in str(env2)
25 | assert env2.render_mode == "human"
26 |
27 |
28 | def test_rendering_with_atari_make():
29 | from rlberry.manager import ExperimentManager
30 |
31 | from gymnasium.wrappers.rendering import RecordVideo
32 | import os
33 | from rlberry.envs.gym_make import atari_make
34 |
35 | import tempfile
36 |
37 | with tempfile.TemporaryDirectory() as tmpdirname:
38 | from stable_baselines3 import (
39 | PPO,
40 | ) # https://stable-baselines3.readthedocs.io/en/master/modules/ppo.html
41 |
42 | from rlberry.agents.stable_baselines import StableBaselinesAgent
43 |
44 | tuned_xp = ExperimentManager(
45 | StableBaselinesAgent, # The Agent class.
46 | (
47 | atari_make,
48 | dict(id="ALE/Breakout-v5"),
49 | ), # The Environment to solve.
50 | init_kwargs=dict( # Where to put the agent's hyperparameters
51 | algo_cls=PPO,
52 | policy="MlpPolicy",
53 | ),
54 | fit_budget=1000, # The number of interactions between the agent and the environment during training.
55 | eval_kwargs=dict(
56 | eval_horizon=500
57 | ), # The number of interactions between the agent and the environment during evaluations.
58 | n_fit=1, # The number of agents to train. Usually, it is good to do more than 1 because the training is stochastic.
59 | agent_name="PPO_tuned", # The agent's name.
60 | output_dir=str(tmpdirname) + "/PPO_for_breakout",
61 | )
62 |
63 | tuned_xp.fit()
64 |
65 | env = atari_make("ALE/Breakout-v5", render_mode="rgb_array")
66 | env = RecordVideo(env, str(tmpdirname) + "/_video/temp")
67 |
68 | if "render_modes" in env.metadata:
69 | env.metadata["render.modes"] = env.metadata[
70 | "render_modes"
71 | ] # bug with some 'gym' version
72 |
73 | observation, info = env.reset()
74 | for tt in range(3000):
75 | action = tuned_xp.get_agent_instances()[0].policy(observation)
76 | observation, reward, terminated, truncated, info = env.step(action)
77 | done = terminated or truncated
78 | if done:
79 | break
80 |
81 | env.close()
82 |
83 | assert os.path.exists(str(tmpdirname) + "/_video/temp/rl-video-episode-0.mp4")
84 |
--------------------------------------------------------------------------------
/rlberry/envs/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Tuple
2 | from copy import deepcopy
3 | from rlberry.seeding import safe_reseed
4 |
5 |
6 | import rlberry
7 |
8 | logger = rlberry.logger
9 |
10 |
11 | def process_env(env, seeder, copy_env=True):
12 | if isinstance(env, Tuple):
13 | constructor = env[0]
14 | if constructor is None:
15 | return None
16 | kwargs = env[1] or {}
17 | processed_env = constructor(**kwargs)
18 | else:
19 | if env is None:
20 | return None
21 | if copy_env:
22 | try:
23 | processed_env = deepcopy(env)
24 | except Exception as ex:
25 | raise RuntimeError("[Agent] Not possible to deepcopy env: " + str(ex))
26 | else:
27 | processed_env = env
28 | reseeded = safe_reseed(processed_env, seeder)
29 | if not reseeded:
30 | logger.warning("[Agent] Not possible to reseed environment.")
31 | return processed_env
32 |
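A minimal sketch of the two input forms process_env accepts, an already-built environment (deep-copied and reseeded) or a (constructor, kwargs) tuple; the environment id is an arbitrary choice:

from rlberry.envs import gym_make
from rlberry.envs.utils import process_env
from rlberry.seeding import Seeder

seeder = Seeder(42)
# tuple form: the constructor is called with the given kwargs, then reseeded
env_a = process_env((gym_make, {"id": "CartPole-v1"}), seeder)
# instance form: the environment is deep-copied (copy_env=True) and reseeded
env_b = process_env(gym_make("CartPole-v1"), seeder, copy_env=True)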
--------------------------------------------------------------------------------
/rlberry/experiment/__init__.py:
--------------------------------------------------------------------------------
1 | from .yaml_utils import parse_experiment_config
2 | from .generator import experiment_generator
3 | from .load_results import load_experiment_results
4 |
--------------------------------------------------------------------------------
/rlberry/experiment/generator.py:
--------------------------------------------------------------------------------
1 | """Run experiments.
2 |
3 | Usage:
4 | run.py <experiment_path> [--enable_tensorboard] [--n_fit=<n_fit>] [--output_dir=<output_dir>] [--parallelization=<parallelization>] [--max_workers=<max_workers>]
5 | run.py (-h | --help)
6 |
7 | Options:
8 | -h --help Show this screen.
9 | --enable_tensorboard Enable tensorboard writer in ExperimentManager.
10 | --n_fit=<n_fit> Number of times each agent is fit [default: 4].
11 | --output_dir=<output_dir> Directory to save the results [default: results].
12 | --parallelization=<parallelization> Either 'thread' or 'process' [default: process].
13 | --max_workers=<max_workers> Number of workers used by ExperimentManager.fit. Set to -1 for the maximum value. [default: -1]
14 | """
15 |
16 | from docopt import docopt
17 | from pathlib import Path
18 | from rlberry.experiment.yaml_utils import parse_experiment_config
19 | from rlberry.manager import ExperimentManager
20 | from rlberry import check_packages
21 |
22 | import rlberry
23 |
24 | logger = rlberry.logger
25 |
26 |
27 | def experiment_generator():
28 | """
29 | Parse command-line arguments and yield ExperimentManager instances.
30 | """
31 | args = docopt(__doc__)
32 | max_workers = int(args["--max_workers"])
33 | if max_workers == -1:
34 | max_workers = None
35 | for _, experiment_manager_kwargs in parse_experiment_config(
36 | Path(args["<experiment_path>"]),
37 | n_fit=int(args["--n_fit"]),
38 | max_workers=max_workers,
39 | output_base_dir=args["--output_dir"],
40 | parallelization=args["--parallelization"],
41 | ):
42 | if args["--enable_tensorboard"]:
43 | if check_packages.TENSORBOARD_INSTALLED:
44 | experiment_manager_kwargs.update(dict(enable_tensorboard=True))
45 | else:
46 | logger.warning(
47 | "Option --enable_tensorboard is not available: tensorboard is not installed."
48 | )
49 |
50 | yield ExperimentManager(**experiment_manager_kwargs)
51 |
--------------------------------------------------------------------------------
/rlberry/experiment/tests/params_experiment.yaml:
--------------------------------------------------------------------------------
1 | description: 'RSUCBVI in NRoom'
2 | seed: 4575458
3 | train_env: 'rlberry/experiment/tests/room.yaml'
4 | eval_env: 'rlberry/experiment/tests/room.yaml'
5 | global_init_kwargs:
6 | reward_free: True
7 | horizon: 2
8 | global_eval_kwargs:
9 | eval_horizon: 4
10 | global_fit_kwargs:
11 | fit_budget: 3
12 | agents:
13 | - 'rlberry/experiment/tests/rsucbvi.yaml'
14 | - 'rlberry/experiment/tests/rsucbvi_alternative.yaml'
15 |
--------------------------------------------------------------------------------
/rlberry/experiment/tests/room.yaml:
--------------------------------------------------------------------------------
1 | constructor: 'rlberry_research.envs.benchmarks.grid_exploration.nroom.NRoom'
2 | params:
3 | reward_free: false
4 | array_observation: true
5 | nrooms: 5
6 |
--------------------------------------------------------------------------------
/rlberry/experiment/tests/rsucbvi.yaml:
--------------------------------------------------------------------------------
1 | agent_class: 'rlberry_research.agents.kernel_based.rs_ucbvi.RSUCBVIAgent'
2 | init_kwargs:
3 | gamma: 1.0
4 | lp_metric: 2
5 | min_dist: 0.0
6 | max_repr: 800
7 | bonus_scale_factor: 1.0
8 | reward_free: True
9 | horizon: 50
10 | eval_kwargs:
11 | eval_horizon: 50
12 | fit_kwargs:
13 | fit_budget: 123
14 |
--------------------------------------------------------------------------------
/rlberry/experiment/tests/rsucbvi_alternative.yaml:
--------------------------------------------------------------------------------
1 | base_config: 'rlberry/experiment/tests/rsucbvi.yaml'
2 | init_kwargs:
3 | gamma: 0.9
4 | eval_kwargs:
5 | horizon: 50
6 |
--------------------------------------------------------------------------------
/rlberry/experiment/tests/test_experiment_generator.py:
--------------------------------------------------------------------------------
1 | from rlberry.experiment import experiment_generator
2 | from rlberry_research.agents.kernel_based.rs_ucbvi import RSUCBVIAgent
3 |
4 | import numpy as np
5 |
6 |
7 | def test_mock_args(monkeypatch):
8 | monkeypatch.setattr(
9 | "sys.argv", ["", "rlberry/experiment/tests/params_experiment.yaml"]
10 | )
11 | random_numbers = []
12 |
13 | for experiment_manager in experiment_generator():
14 | rng = experiment_manager.seeder.rng
15 | random_numbers.append(rng.uniform(size=10))
16 |
17 | assert experiment_manager.agent_class is RSUCBVIAgent
18 | assert experiment_manager._base_init_kwargs["horizon"] == 2
19 | assert experiment_manager.fit_budget == 3
20 | assert experiment_manager.eval_kwargs["eval_horizon"] == 4
21 |
22 | assert experiment_manager._base_init_kwargs["lp_metric"] == 2
23 | assert experiment_manager._base_init_kwargs["min_dist"] == 0.0
24 | assert experiment_manager._base_init_kwargs["max_repr"] == 800
25 | assert experiment_manager._base_init_kwargs["bonus_scale_factor"] == 1.0
26 | assert experiment_manager._base_init_kwargs["reward_free"] is True
27 |
28 | train_env = experiment_manager.train_env[0](**experiment_manager.train_env[1])
29 | assert train_env.reward_free is False
30 | assert train_env.array_observation is True
31 |
32 | if experiment_manager.agent_name == "rsucbvi":
33 | assert experiment_manager._base_init_kwargs["gamma"] == 1.0
34 |
35 | elif experiment_manager.agent_name == "rsucbvi_alternative":
36 | assert experiment_manager._base_init_kwargs["gamma"] == 0.9
37 |
38 | else:
39 | raise ValueError()
40 |
41 | # check that seeding is the same for each ExperimentManager instance
42 | for ii in range(1, len(random_numbers)):
43 | assert np.array_equal(random_numbers[ii - 1], random_numbers[ii])
44 |
--------------------------------------------------------------------------------
/rlberry/experiment/tests/test_load_results.py:
--------------------------------------------------------------------------------
1 | from rlberry.experiment import load_experiment_results
2 | import tempfile
3 | from rlberry.experiment import experiment_generator
4 | import os
5 | import sys
6 |
7 | TEST_DIR = os.path.dirname(os.path.abspath(__file__))
8 |
9 |
10 | def test_save_and_load():
11 | TEST_DIR = os.path.dirname(os.path.abspath(__file__))
12 | sys.argv = [TEST_DIR + "/test_load_results.py"]
13 | with tempfile.TemporaryDirectory() as tmpdirname:
14 | sys.argv.append(TEST_DIR + "/params_experiment.yaml")
15 | sys.argv.append("--parallelization=thread")
16 | sys.argv.append("--output_dir=" + tmpdirname)
17 | print(sys.argv)
18 | for experiment_manager in experiment_generator():
19 | experiment_manager.fit()
20 | experiment_manager.save()
21 | data = load_experiment_results(tmpdirname, "params_experiment")
22 |
23 | assert len(data) > 0
24 |
--------------------------------------------------------------------------------
/rlberry/manager/__init__.py:
--------------------------------------------------------------------------------
1 | from .experiment_manager import ExperimentManager
2 | from .experiment_manager import preset_manager
3 | from .multiple_managers import MultipleManagers
4 | from .evaluation import evaluate_agents, read_writer_data
5 | from .comparison import compare_agents, AdastopComparator
6 | from .plotting import plot_smoothed_curves, plot_writer_data, plot_synchronized_curves
7 | from .env_tools import with_venv, run_venv_xp
8 | from .utils import tensorboard_to_dataframe
9 |
10 | # AgentManager alias for the ExperimentManager class, for backward compatibility
11 | AgentManager = ExperimentManager
12 |
--------------------------------------------------------------------------------
/rlberry/manager/multiple_managers.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures
2 | import functools
3 | import multiprocessing
4 | from typing import Optional
5 |
6 |
7 | def fit_stats(stats, save):
8 | stats.fit()
9 | if save:
10 | stats.save()
11 | return stats
12 |
13 |
14 | class MultipleManagers:
15 | """
16 | Class to fit multiple ExperimentManager instances in parallel, using threads or processes.
17 |
18 | Parameters
19 | ----------
20 | max_workers: int, default=None
21 | max number of workers (ExperimentManager instances) fitted at the same time.
22 | parallelization: {'thread', 'process'}, default: 'process'
23 | Whether to parallelize agent training using threads or processes.
24 | mp_context: {'spawn', 'fork', 'forkserver'}, default: 'spawn'.
25 | Context for python multiprocessing module.
26 | Warning: If you're using JAX or PyTorch, it only works with 'spawn'.
27 | If running code on a notebook or interpreter, use 'fork'.
28 | """
29 |
30 | def __init__(
31 | self,
32 | max_workers: Optional[int] = None,
33 | parallelization: str = "process",
34 | mp_context="spawn",
35 | ) -> None:
36 | super().__init__()
37 | self.instances = []
38 | self.max_workers = max_workers
39 | self.parallelization = parallelization
40 | self.mp_context = mp_context
41 |
42 | def append(self, experiment_manager):
43 | """
44 | Append new ExperimentManager instance.
45 |
46 | Parameters
47 | ----------
48 | experiment_manager : ExperimentManager
49 | """
50 | self.instances.append(experiment_manager)
51 |
52 | def run(self, save=True):
53 | """
54 | Fit ExperimentManager instances in parallel.
55 |
56 | Parameters
57 | ----------
58 | save: bool, default: True
59 | If True, save ExperimentManager instances immediately after fitting.
60 | ExperimentManager.save() is called.
61 | """
62 | if self.parallelization == "thread":
63 | executor_class = concurrent.futures.ThreadPoolExecutor
64 | elif self.parallelization == "process":
65 | executor_class = functools.partial(
66 | concurrent.futures.ProcessPoolExecutor,
67 | mp_context=multiprocessing.get_context(self.mp_context),
68 | )
69 | else:
70 | raise ValueError(
71 | f"Invalid backend for parallelization: {self.parallelization}"
72 | )
73 |
74 | with executor_class(max_workers=self.max_workers) as executor:
75 | futures = []
76 | for inst in self.instances:
77 | futures.append(executor.submit(fit_stats, inst, save=save))
78 |
79 | fitted_instances = []
80 | for future in concurrent.futures.as_completed(futures):
81 | fitted_instances.append(future.result())
82 |
83 | self.instances = fitted_instances
84 |
85 | def save(self):
86 | """
87 | Pickle ExperimentManager instances and save fit statistics to .csv files.
88 | The output folder is defined in each of the ExperimentManager instances.
89 | """
90 | for stats in self.instances:
91 | stats.save()
92 |
93 | @property
94 | def managers(self):
95 | return self.instances
96 |
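A minimal sketch of the intended workflow, assuming the ExperimentManager and StableBaselinesAgent usage shown elsewhere in this repository; the agents, budget, and names are arbitrary choices:

from stable_baselines3 import A2C

from rlberry.agents.stable_baselines import StableBaselinesAgent
from rlberry.envs import gym_make
from rlberry.manager import ExperimentManager, MultipleManagers

multimanagers = MultipleManagers(parallelization="thread")
for name in ("A2C_run_1", "A2C_run_2"):
    multimanagers.append(
        ExperimentManager(
            StableBaselinesAgent,
            (gym_make, {"id": "CartPole-v1"}),
            init_kwargs={"algo_cls": A2C, "policy": "MlpPolicy"},
            fit_budget=1000,
            n_fit=1,
            agent_name=name,
        )
    )
multimanagers.run(save=False)  # fit all managers, one thread each
print(len(multimanagers.managers), "managers fitted")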
--------------------------------------------------------------------------------
/rlberry/manager/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/rlberry/manager/tests/__init__.py
--------------------------------------------------------------------------------
/rlberry/manager/tests/test_shared_data.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import numpy as np
3 | from rlberry.agents import Agent
4 | from rlberry.manager import ExperimentManager
5 |
6 |
7 | class DummyAgent(Agent):
8 | def __init__(self, **kwargs):
9 | Agent.__init__(self, **kwargs)
10 | self.name = "DummyAgent"
11 | self.shared_data_id = id(self.thread_shared_data)
12 |
13 | def fit(self, budget, **kwargs):
14 | del budget, kwargs
15 |
16 | def eval(self, **kwargs):
17 | del kwargs
18 | return self.shared_data_id
19 |
20 |
21 | @pytest.mark.parametrize("paralellization", ["thread", "process"])
22 | def test_data_sharing(paralellization):
23 | shared_data = dict(X=np.arange(10))
24 | manager = ExperimentManager(
25 | agent_class=DummyAgent,
26 | fit_budget=-1,
27 | n_fit=4,
28 | parallelization=paralellization,
29 | thread_shared_data=shared_data,
30 | )
31 | manager.fit()
32 | data_ids = [agent.eval() for agent in manager.get_agent_instances()]
33 | unique_data_ids = list(set(data_ids))
34 | if paralellization == "thread":
35 | # id() is unique for each object: make sure that the shared data has the same id
36 | assert len(unique_data_ids) == 1
37 | else:
38 | # when using processes, make sure that data is copied and each instance
39 | # has its own data id
40 | assert len(unique_data_ids) == manager.n_fit
41 |
--------------------------------------------------------------------------------
/rlberry/manager/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | from rlberry.manager import tensorboard_to_dataframe
2 | from stable_baselines3 import PPO, A2C
3 | import tempfile
4 | import os
5 | import pandas as pd
6 | import pytest
7 |
8 |
9 | def test_tensorboard_to_dataframe():
10 | with tempfile.TemporaryDirectory() as tmpdirname:
11 | # create data to test
12 | path_ppo = str(tmpdirname + "/ppo_cartpole_tensorboard/")
13 | path_a2c = str(tmpdirname + "/a2c_cartpole_tensorboard/")
14 | model = PPO("MlpPolicy", "CartPole-v1", tensorboard_log=path_ppo)
15 | model2 = A2C("MlpPolicy", "CartPole-v1", tensorboard_log=path_a2c)
16 | model2_seed2 = A2C("MlpPolicy", "CartPole-v1", tensorboard_log=path_a2c)
17 | model.learn(total_timesteps=5_000, tb_log_name="ppo")
18 | model2.learn(total_timesteps=5_000, tb_log_name="A2C")
19 | model2_seed2.learn(total_timesteps=5_000, tb_log_name="A2C")
20 |
21 | assert os.path.exists(path_ppo)
22 | assert os.path.exists(path_a2c)
23 |
24 | # check with parent folder
25 | data_in_dataframe = tensorboard_to_dataframe(tmpdirname)
26 |
27 | assert isinstance(data_in_dataframe, dict)
28 | assert "rollout/ep_rew_mean" in data_in_dataframe
29 | a_dict = data_in_dataframe["rollout/ep_rew_mean"]
30 |
31 | assert isinstance(a_dict, pd.DataFrame)
32 | assert "name" in a_dict.columns
33 | assert "n_simu" in a_dict.columns
34 | assert "x" in a_dict.columns
35 | assert "y" in a_dict.columns
36 |
37 | # check with list of folder
38 | folder_ppo_1 = str(path_ppo + "ppo_1/")
39 | folder_A2C_1 = str(path_a2c + "A2C_1/")
40 | folder_A2C_2 = str(path_a2c + "A2C_2/")
41 |
42 | path_event_ppo_1 = str(folder_ppo_1 + os.listdir(folder_ppo_1)[0])
43 | path_event_A2C_1 = str(folder_A2C_1 + os.listdir(folder_A2C_1)[0])
44 | path_event_A2C_2 = str(folder_A2C_2 + os.listdir(folder_A2C_2)[0])
45 |
46 | input_dict = {
47 | "ppo_cartpole_tensorboard": [path_event_ppo_1],
48 | "a2c_cartpole_tensorboard": [path_event_A2C_1, path_event_A2C_2],
49 | }
50 |
51 | data_in_dataframe2 = tensorboard_to_dataframe(input_dict)
52 | assert isinstance(data_in_dataframe2, dict)
53 | assert "rollout/ep_rew_mean" in data_in_dataframe2
54 | a_dict2 = data_in_dataframe2["rollout/ep_rew_mean"]
55 |
56 | assert isinstance(a_dict2, pd.DataFrame)
57 | assert "name" in a_dict2.columns
58 | assert "n_simu" in a_dict2.columns
59 | assert "x" in a_dict2.columns
60 | assert "y" in a_dict2.columns
61 |
62 | # check both strategies give the same result
63 | assert set(a_dict.keys()) == set(a_dict2.keys())
64 | for key in a_dict:
65 | if (
66 | key != "n_simu"
67 | ): # don't test n_simu/seed: it differs because one comes from the folder name and the other comes from the index in the list
68 | assert set(a_dict[key]) == set(a_dict2[key])
69 |
70 |
71 | def test_tensorboard_to_dataframe_errorIO():
72 | msg = "Input of 'tensorboard_to_dataframe' must be a str or a dict... not a "
73 | with pytest.raises(IOError, match=msg):
74 | tensorboard_to_dataframe(1)
75 |
--------------------------------------------------------------------------------
/rlberry/manager/tests/test_venv.py:
--------------------------------------------------------------------------------
1 | from rlberry.manager import with_venv, run_venv_xp
2 |
3 |
4 | @with_venv(import_libs=["tqdm"], verbose=True)
5 | def run_tqdm():
6 | from tqdm import tqdm # noqa
7 |
8 |
9 | def test_venv():
10 | run_venv_xp(verbose=True)
11 |
--------------------------------------------------------------------------------
/rlberry/metadata_utils.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import uuid
3 | import hashlib
4 | from typing import Optional, NamedTuple
5 |
6 |
7 | # Default output directory used by the library.
8 | RLBERRY_DEFAULT_DATA_DIR = "rlberry_data/"
9 |
10 | # Temporary directory used by the library
11 | RLBERRY_TEMP_DATA_DIR = "rlberry_data/temp/"
12 |
13 |
14 | def get_timestamp_str():
15 | """
16 | Get a string containing current time stamp.
17 | """
18 | now = datetime.now()
19 | date_time = now.strftime("%Y-%m-%d_%H-%M-%S")
20 | timestamp = date_time
21 | return timestamp
22 |
23 |
24 | def get_readable_id(obj):
25 | """
26 | Create a more readable id than get_unique_id(),
27 | combining a timestamp with an 8-character hash.
28 | """
29 | long_id = get_unique_id(obj)
30 | timestamp = get_timestamp_str()
31 | short_id = f"{timestamp}_{long_id[:8]}"
32 | return short_id
33 |
34 |
35 | def get_unique_id(obj):
36 | """
37 | Get a unique id for an obj. Use it in __init__ methods when necessary.
38 | """
39 | # id() is guaranteed to be unique among simultaneously existing objects (uses memory address).
40 | # uuid4() is an universal id, but there might be issues if called simultaneously in different processes.
41 | # This function combines id(), uuid4(), and a timestamp in a single ID, and hashes it.
42 | timestamp = datetime.timestamp(datetime.now())
43 | timestamp = str(timestamp).replace(".", "")
44 | str_id = timestamp + str(id(obj)) + uuid.uuid4().hex
45 | str_id = hashlib.md5(str_id.encode()).hexdigest()
46 | return str_id
47 |
48 |
49 | class ExecutionMetadata(NamedTuple):
50 | """
51 | Metadata for objects handled by rlberry.
52 |
53 | Attributes
54 | ----------
55 | obj_worker_id : int, default: -1
56 | If given, must be >= 0, and informs the worker id (thread or process) where the
57 | object was created. It is not necessarily unique across all the workers launched by
58 | rlberry; it is mainly for debugging purposes.
59 | obj_info : dict, default: None
60 | Extra info about the object.
61 | """
62 |
63 | obj_worker_id: int = -1
64 | obj_info: Optional[dict] = None
65 |
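A minimal sketch contrasting the two identifier helpers above; the metadata values are arbitrary and the printed strings are illustrative only:

from rlberry.metadata_utils import ExecutionMetadata, get_readable_id, get_unique_id

meta = ExecutionMetadata(obj_worker_id=0, obj_info={"purpose": "demo"})
print(get_unique_id(meta))    # 32-character md5 hex digest
print(get_readable_id(meta))  # '<timestamp>_<first 8 hash characters>'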
--------------------------------------------------------------------------------
/rlberry/rendering/__init__.py:
--------------------------------------------------------------------------------
1 | from .core import Scene, GeometricPrimitive
2 | from .render_interface import RenderInterface
3 | from .render_interface import RenderInterface2D
4 |
--------------------------------------------------------------------------------
/rlberry/rendering/common_shapes.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from rlberry.rendering import GeometricPrimitive
3 |
4 |
5 | def bar_shape(p0, p1, width):
6 | shape = GeometricPrimitive("QUADS")
7 |
8 | x0, y0 = p0
9 | x1, y1 = p1
10 |
11 | direction = np.array([x1 - x0, y1 - y0])
12 | norm = np.sqrt((direction * direction).sum())
13 | direction = direction / norm
14 |
15 | # get vector perpendicular to direction
16 | u_vec = np.zeros(2)
17 | u_vec[0] = -direction[1]
18 | u_vec[1] = direction[0]
19 |
20 | u_vec = u_vec * width / 2
21 |
22 | shape.add_vertex((x0 + u_vec[0], y0 + u_vec[1]))
23 | shape.add_vertex((x0 - u_vec[0], y0 - u_vec[1]))
24 | shape.add_vertex((x1 - u_vec[0], y1 - u_vec[1]))
25 | shape.add_vertex((x1 + u_vec[0], y1 + u_vec[1]))
26 | return shape
27 |
28 |
29 | def circle_shape(center, radius, n_points=50):
30 | shape = GeometricPrimitive("POLYGON")
31 |
32 | x0, y0 = center
33 | theta = np.linspace(0.0, 2 * np.pi, n_points)
34 | for tt in theta:
35 | xx = radius * np.cos(tt)
36 | yy = radius * np.sin(tt)
37 | shape.add_vertex((x0 + xx, y0 + yy))
38 |
39 | return shape
40 |
--------------------------------------------------------------------------------
/rlberry/rendering/core.py:
--------------------------------------------------------------------------------
1 | """
2 | Provide classes for geometric primitives in OpenGL and scenes.
3 | """
4 |
5 |
6 | class Scene:
7 | """
8 | Class representing a scene, which is a vector of GeometricPrimitive objects
9 | """
10 |
11 | def __init__(self):
12 | self.shapes = []
13 |
14 | def add_shape(self, shape):
15 | self.shapes.append(shape)
16 |
17 |
18 | class GeometricPrimitive:
19 | """
20 | Class representing an OpenGL geometric primitive.
21 |
22 | Primitive type (GL_LINE_LOOP by default).
23 |
24 | If using OpenGLRender2D, one of the following:
25 | POINTS
26 | LINES
27 | LINE_STRIP
28 | LINE_LOOP
29 | POLYGON
30 | TRIANGLES
31 | TRIANGLE_STRIP
32 | TRIANGLE_FAN
33 | QUADS
34 | QUAD_STRIP
35 |
36 | If using PyGameRender2D:
37 | POLYGON
38 |
39 |
40 | TODO: Add support to more pygame shapes,
41 | see https://www.pygame.org/docs/ref/draw.html
42 | """
43 |
44 | def __init__(self, primitive_type="GL_LINE_LOOP"):
45 | # primitive type
46 | self.type = primitive_type
47 | # color in RGB
48 | self.color = (0.25, 0.25, 0.25)
49 | # list of vertices. each vertex is a tuple with coordinates in space
50 | self.vertices = []
51 |
52 | def add_vertex(self, vertex):
53 | self.vertices.append(vertex)
54 |
55 | def set_color(self, color):
56 | self.color = color
57 |
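A minimal sketch that builds a scene from these primitives using only the methods defined above; the coordinates and color are arbitrary:

from rlberry.rendering import GeometricPrimitive, Scene

triangle = GeometricPrimitive("TRIANGLES")
for vertex in [(0.0, 0.0), (1.0, 0.0), (0.5, 1.0)]:
    triangle.add_vertex(vertex)
triangle.set_color((0.8, 0.2, 0.2))

scene = Scene()
scene.add_shape(triangle)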
--------------------------------------------------------------------------------
/rlberry/rendering/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/rlberry/rendering/tests/__init__.py
--------------------------------------------------------------------------------
/rlberry/rendering/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import rlberry
3 |
4 | logger = rlberry.logger
5 | import imageio
6 |
7 |
8 | _FFMPEG_INSTALLED = True
9 | try:
10 | import ffmpeg
11 | except Exception:
12 | _FFMPEG_INSTALLED = False
13 |
14 |
15 | def video_write(fn, images, framerate=60, vcodec="libx264"):
16 | """
17 | Save list of images to a video file.
18 |
19 | Source:
20 | https://github.com/kkroening/ffmpeg-python/issues/246#issuecomment-520200981
21 | Modified so that framerate is given to .input(), as suggested in the
22 | thread, to avoid
23 | skipping frames.
24 |
25 | Parameters
26 | ----------
27 | fn : string
28 | filename
29 | images : list or np.array
30 | list of images to save to a video.
31 | framerate : int
32 | """
33 | global _FFMPEG_INSTALLED
34 |
35 | try:
36 | if len(images) == 0:
37 | logger.warning("Calling video_write() with empty images.")
38 | return
39 |
40 | if not _FFMPEG_INSTALLED:
41 | logger.error(
42 | "video_write(): Unable to save video, ffmpeg-python \
43 | package required (https://github.com/kkroening/ffmpeg-python)"
44 | )
45 | return
46 |
47 | if not isinstance(images, np.ndarray):
48 | images = np.asarray(images)
49 | _, height, width, channels = images.shape
50 | process = (
51 | ffmpeg.input(
52 | "pipe:",
53 | format="rawvideo",
54 | pix_fmt="rgb24",
55 | s="{}x{}".format(width, height),
56 | r=framerate,
57 | )
58 | .output(fn, pix_fmt="yuv420p", vcodec=vcodec)
59 | .overwrite_output()
60 | .run_async(pipe_stdin=True)
61 | )
62 | for frame in images:
63 | process.stdin.write(frame.astype(np.uint8).tobytes())
64 | process.stdin.close()
65 | process.wait()
66 |
67 | except Exception as ex:
68 | logger.warning(
69 | "Not possible to save \
70 | video, due to exception: {}".format(
71 | str(ex)
72 | )
73 | )
74 |
75 |
76 | def gif_write(fn, images):
77 | """
78 | Save list of images to a gif file
79 |
80 | Parameters
81 | ----------
82 | fn : string
83 | filename
84 | images : list or np.array
85 | list of images to save to a gif.
86 | """
87 |
88 | try:
89 | if len(images) == 0:
90 | logger.warning("Calling gif_write() with empty images.")
91 | return
92 |
93 | if not isinstance(images, np.ndarray):
94 | images = np.asarray(images)
95 |
96 | with imageio.get_writer(fn, mode="I") as writer:
97 | for frame in images:
98 | writer.append_data(frame)
99 |
100 | except Exception as ex:
101 | logger.warning(
102 | "Not possible to save \
103 | gif, due to exception: {}".format(
104 | str(ex)
105 | )
106 | )
107 |
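Both helpers expect a list (or array) of H x W x 3 uint8 frames. A minimal sketch with synthetic frames; the output paths and frame contents are arbitrary, and video_write additionally needs the optional ffmpeg-python package (it only logs a warning otherwise):

import numpy as np

from rlberry.rendering.utils import gif_write, video_write

# four solid gray frames, from black to white
frames = [np.full((64, 64, 3), shade, dtype=np.uint8) for shade in (0, 85, 170, 255)]
gif_write("/tmp/rlberry_demo.gif", frames)
video_write("/tmp/rlberry_demo.mp4", frames, framerate=4)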
--------------------------------------------------------------------------------
/rlberry/seeding/__init__.py:
--------------------------------------------------------------------------------
1 | from .seeder import Seeder
2 | from .seeding import safe_reseed
3 | from .seeding import set_external_seed
4 |
--------------------------------------------------------------------------------
/rlberry/seeding/seeding.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import rlberry.check_packages as check_packages
3 | from rlberry.seeding.seeder import Seeder
4 |
5 | if check_packages.TORCH_INSTALLED:
6 | import torch
7 |
8 |
9 | def set_external_seed(seeder):
10 | """
11 | Set seeds of external libraries.
12 |
13 | To do:
14 | Check (torch seeding):
15 | https://github.com/pytorch/pytorch/issues/7068#issuecomment-487907668
16 |
17 | Parameters
18 | ----------
19 | seeder: seeding.Seeder or int
20 | Integer or Seeder object from which to generate random seeds.
21 |
22 | Examples
23 | --------
24 | >>> from rlberry.seeding import Seeder, set_external_seed
25 | >>> set_external_seed(Seeder(123))
26 | """
27 | if np.issubdtype(type(seeder), np.integer):
28 | seeder = Seeder(seeder)
29 |
30 | # seed torch
31 | if check_packages.TORCH_INSTALLED:
32 | torch.manual_seed(seeder.seed_seq.generate_state(1, dtype=np.uint32)[0])
33 |
34 |
35 | def safe_reseed(obj, seeder, reseed_spaces=True):
36 | """
37 | Call obj.reseed(seeder) if available; otherwise, if an obj.seed() method is
38 | available, call obj.seed(seed_val) with seed_val generated by the seeder;
39 | otherwise, try obj.reset(seed=seed_val).
40 | If none of these methods is available, do nothing and return False.
41 |
42 | Parameters
43 | ----------
44 | obj : object
45 | Object to be reseeded.
46 | seeder: :class:`~rlberry.seeding.seeder.Seeder`
47 | Seeder object from which to generate random seeds.
48 | reseed_spaces: bool, default = True.
49 | If False, do not try to reseed observation_space and action_space (if
50 | they exist as attributes of `obj`).
51 |
52 | Returns
53 | -------
54 | True if reseeding was done, False otherwise.
55 |
56 | """
57 | reseeded = False
58 | try:
59 | obj.reseed(seeder)
60 | reseeded = True
61 | except AttributeError:
62 | seed_val = seeder.rng.integers(2**32).item()
63 | try:
64 | obj.seed(seed_val)
65 | reseeded = True
66 | except AttributeError:
67 | try:
68 | obj.reset(seed=seed_val)
69 | reseeded = True
70 | except AttributeError:
71 | reseeded = False
72 |
73 | # check if the object has observation and action spaces to be reseeded.
74 | if reseed_spaces:
75 | try:
76 | safe_reseed(obj.observation_space, seeder)
77 | safe_reseed(obj.action_space, seeder)
78 | except AttributeError:
79 | pass
80 |
81 | return reseeded
82 |
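A minimal sketch of safe_reseed on a Gymnasium environment, mirroring the pattern used in the tests below; the seed value is arbitrary:

import gymnasium as gym

from rlberry.seeding import Seeder, safe_reseed, set_external_seed

seeder = Seeder(123)
env = gym.make("CartPole-v1")
reseeded = safe_reseed(env, seeder)  # falls back to env.reset(seed=...) and reseeds the spaces
print("reseeded:", reseeded)
set_external_seed(seeder.spawn())  # also seeds torch when it is installed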
--------------------------------------------------------------------------------
/rlberry/seeding/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/rlberry/seeding/tests/__init__.py
--------------------------------------------------------------------------------
/rlberry/seeding/tests/test_seeding.py:
--------------------------------------------------------------------------------
1 | from rlberry.seeding import Seeder
2 |
3 |
4 | def test_seeder_basic():
5 | seeder1 = Seeder(43)
6 | data1 = seeder1.rng.integers(100, size=1000)
7 |
8 | seeder2 = Seeder(44)
9 | data2 = seeder2.rng.integers(100, size=1000)
10 |
11 | seeder3 = Seeder(44)
12 | data3 = seeder3.rng.integers(100, size=1000)
13 |
14 | assert (data1 != data2).sum() > 5
15 | assert (data2 != data3).sum() == 0
16 | assert (
17 | seeder2.spawn(1).generate_state(1)[0] == seeder3.spawn(1).generate_state(1)[0]
18 | )
19 | assert (
20 | seeder1.spawn(1).generate_state(1)[0] != seeder3.spawn(1).generate_state(1)[0]
21 | )
22 |
23 |
24 | def test_seeder_initialized_from_seeder():
25 | """
26 | Check that Seeder(seed_seq) respawns seed_seq in the constructor.
27 | """
28 | seeder1 = Seeder(43)
29 | seeder_temp = Seeder(43)
30 | seeder2 = Seeder(seeder_temp)
31 |
32 | data1 = seeder1.rng.integers(100, size=1000)
33 | data2 = seeder2.rng.integers(100, size=1000)
34 | assert (data1 != data2).sum() > 5
35 |
36 |
37 | def test_seeder_spawning():
38 | """
39 |     Check that a spawned Seeder generates a stream that differs from its parent's.
40 | """
41 | seeder1 = Seeder(43)
42 | seeder2 = seeder1.spawn()
43 | seeder3 = seeder2.spawn()
44 |
45 | print(seeder1)
46 | print(seeder2)
47 | print(seeder3)
48 |
49 | data1 = seeder1.rng.integers(100, size=1000)
50 | data2 = seeder2.rng.integers(100, size=1000)
51 | assert (data1 != data2).sum() > 5
52 |
53 |
54 | def test_seeder_reseeding():
55 | """
56 | Check that reseeding with a Seeder instance works properly.
57 | """
58 | # seeders 1 and 2 are identical
59 | seeder1 = Seeder(43)
60 | seeder2 = Seeder(43)
61 |
62 | # reseed seeder 2 using seeder 1
63 | seeder2.reseed(seeder1)
64 |
65 | data1 = seeder1.rng.integers(100, size=1000)
66 | data2 = seeder2.rng.integers(100, size=1000)
67 | assert (data1 != data2).sum() > 5
68 |
--------------------------------------------------------------------------------
/rlberry/seeding/tests/test_threads.py:
--------------------------------------------------------------------------------
1 | from rlberry.seeding.seeder import Seeder
2 | import concurrent.futures
3 |
4 |
5 | def get_random_number_setting_seed(seeder):
6 | return seeder.rng.integers(2**32)
7 |
8 |
9 | def test_multithread_seeding():
10 | """
11 | Checks that different seeds are given to different threads
12 | """
13 | for ii in range(5):
14 | main_seeder = Seeder(123)
15 | for jj in range(10):
16 | with concurrent.futures.ThreadPoolExecutor() as executor:
17 | futures = []
18 | for seed in main_seeder.spawn(2):
19 | futures.append(
20 | executor.submit(get_random_number_setting_seed, seed)
21 | )
22 |
23 | results = []
24 | for future in concurrent.futures.as_completed(futures):
25 | results.append(future.result())
26 | assert results[0] != results[1], f"error in simulation {(ii, jj)}"
27 |
--------------------------------------------------------------------------------
/rlberry/seeding/tests/test_threads_torch.py:
--------------------------------------------------------------------------------
1 | from rlberry.seeding.seeder import Seeder
2 | from rlberry.seeding import set_external_seed
3 | import concurrent.futures
4 |
5 | _TORCH_INSTALLED = True
6 | try:
7 | import torch
8 | except Exception:
9 | _TORCH_INSTALLED = False
10 |
11 |
12 | def get_torch_random_number_setting_seed(seeder):
13 | set_external_seed(seeder)
14 | return torch.randint(2**32, (1,))[0].item()
15 |
16 |
17 | def test_torch_multithread_seeding():
18 | """
19 | Checks that different seeds are given to different threads
20 | """
21 | for ii in range(5):
22 | main_seeder = Seeder(123)
23 | for jj in range(10):
24 | with concurrent.futures.ThreadPoolExecutor() as executor:
25 | futures = []
26 | for seed in main_seeder.spawn(2):
27 | futures.append(
28 | executor.submit(get_torch_random_number_setting_seed, seed)
29 | )
30 |
31 | results = []
32 | for future in concurrent.futures.as_completed(futures):
33 | results.append(future.result())
34 | assert results[0] != results[1], f"error in simulation {(ii, jj)}"
35 |
--------------------------------------------------------------------------------
/rlberry/spaces/__init__.py:
--------------------------------------------------------------------------------
1 | from .discrete import Discrete
2 | from .box import Box
3 | from .tuple import Tuple
4 | from .multi_discrete import MultiDiscrete
5 | from .multi_binary import MultiBinary
6 | from .dict import Dict
7 |
--------------------------------------------------------------------------------
/rlberry/spaces/box.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 | import numpy as np
3 | from rlberry.seeding import Seeder
4 |
5 |
6 | class Box(gym.spaces.Box):
7 | """
8 | Class that represents a space that is a cartesian product in R^n:
9 |
10 | [a_1, b_1] x [a_2, b_2] x ... x [a_n, b_n]
11 |
12 |
13 | Inherited from gymnasium.spaces.Box for compatibility with gym.
14 |
15 | rlberry wraps gym.spaces to make sure the seeding
16 | mechanism is unified in the library (rlberry.seeding)
17 |
18 | Attributes
19 | ----------
20 | rng : numpy.random._generator.Generator
21 | random number generator provided by rlberry.seeding
22 |
23 | Methods
24 | -------
25 | reseed()
26 | get new random number generator
27 | """
28 |
29 | def __init__(self, low, high, shape=None, dtype=np.float64):
30 | gym.spaces.Box.__init__(self, low, high, shape=shape, dtype=dtype)
31 | self.seeder = Seeder()
32 |
33 | @property
34 | def rng(self):
35 | return self.seeder.rng
36 |
37 | def reseed(self, seed_seq=None):
38 | """
39 | Get new random number generator.
40 |
41 | Parameters
42 | ----------
43 | seed_seq : np.random.SeedSequence, rlberry.seeding.Seeder or int, default : None
44 | Seed sequence from which to spawn the random number generator.
45 | If None, generate random seed.
46 | If int, use as entropy for SeedSequence.
47 | If seeder, use seeder.seed_seq
48 | """
49 | self.seeder.reseed(seed_seq)
50 |
51 | def sample(self):
52 | """
53 | Adapted from:
54 | https://raw.githubusercontent.com/openai/gym/master/gym/spaces/box.py
55 |
56 |
57 | Generates a single random sample inside of the Box.
58 |
59 | In creating a sample of the box, each coordinate is sampled according
60 | to the form of the interval:
61 |
62 | * [a, b] : uniform distribution
63 | * [a, oo) : shifted exponential distribution
64 | * (-oo, b] : shifted negative exponential distribution
65 | * (-oo, oo) : normal distribution
66 | """
67 | high = self.high if self.dtype.kind == "f" else self.high.astype("int64") + 1
68 | sample = np.empty(self.shape)
69 |
70 | # Masking arrays which classify the coordinates according to interval
71 | # type
72 | unbounded = ~self.bounded_below & ~self.bounded_above
73 | upp_bounded = ~self.bounded_below & self.bounded_above
74 | low_bounded = self.bounded_below & ~self.bounded_above
75 | bounded = self.bounded_below & self.bounded_above
76 |
77 | # Vectorized sampling by interval type
78 | sample[unbounded] = self.rng.normal(size=unbounded[unbounded].shape)
79 |
80 | sample[low_bounded] = (
81 | self.rng.exponential(size=low_bounded[low_bounded].shape)
82 | + self.low[low_bounded]
83 | )
84 |
85 | sample[upp_bounded] = (
86 | -self.rng.exponential(size=upp_bounded[upp_bounded].shape)
87 | + self.high[upp_bounded]
88 | )
89 |
90 | sample[bounded] = self.rng.uniform(
91 | low=self.low[bounded], high=high[bounded], size=bounded[bounded].shape
92 | )
93 | if self.dtype.kind == "i":
94 | sample = np.floor(sample)
95 |
96 | return sample.astype(self.dtype)
97 |
--------------------------------------------------------------------------------
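For illustration, a small sketch (not from the repository) of the reseeding behaviour documented above: two Box spaces reseeded from identically seeded Seeder objects are expected to produce the same samples, mirroring the reproducibility check in rlberry/utils/check_env.py below.

    from rlberry.seeding import Seeder
    from rlberry.spaces import Box

    space_a = Box(low=0.0, high=1.0, shape=(3,))
    space_b = Box(low=0.0, high=1.0, shape=(3,))
    space_a.reseed(Seeder(42))
    space_b.reseed(Seeder(42))
    # identical seed sequences are expected to yield identical samples
    assert (space_a.sample() == space_b.sample()).all()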
/rlberry/spaces/dict.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 | from rlberry.seeding import Seeder
3 |
4 |
5 | class Dict(gym.spaces.Dict):
6 | """
7 |
8 | Inherited from gymnasium.spaces.Dict for compatibility with gym.
9 |
10 | rlberry wraps gym.spaces to make sure the seeding
11 | mechanism is unified in the library (rlberry.seeding)
12 |
13 | Attributes
14 | ----------
15 | rng : numpy.random._generator.Generator
16 | random number generator provided by rlberry.seeding
17 |
18 | Methods
19 | -------
20 | reseed()
21 | get new random number generator
22 | """
23 |
24 | def __init__(self, spaces=None, **spaces_kwargs):
25 | gym.spaces.Dict.__init__(self, spaces, **spaces_kwargs)
26 | self.seeder = Seeder()
27 |
28 | @property
29 | def rng(self):
30 | return self.seeder.rng
31 |
32 | def reseed(self, seed_seq=None):
33 | _ = [space.reseed(seed_seq) for space in self.spaces.values()]
34 |
--------------------------------------------------------------------------------
/rlberry/spaces/discrete.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 | from rlberry.seeding import Seeder
3 |
4 |
5 | class Discrete(gym.spaces.Discrete):
6 | """
7 | Class that represents discrete spaces.
8 |
9 |
10 | Inherited from gymnasium.spaces.Discrete for compatibility with gym.
11 |
12 | rlberry wraps gym.spaces to make sure the seeding
13 | mechanism is unified in the library (rlberry.seeding)
14 |
15 | Attributes
16 | ----------
17 | rng : numpy.random._generator.Generator
18 | random number generator provided by rlberry.seeding
19 |
20 | Methods
21 | -------
22 | reseed()
23 | get new random number generator
24 | """
25 |
26 | def __init__(self, n):
27 | """
28 | Parameters
29 | ----------
30 | n : int
31 | number of elements in the space
32 | """
33 | assert n >= 0, "The number of elements in Discrete must be >= 0"
34 | gym.spaces.Discrete.__init__(self, n)
35 | self.seeder = Seeder()
36 |
37 | @property
38 | def rng(self):
39 | return self.seeder.rng
40 |
41 | def reseed(self, seed_seq=None):
42 | """
43 | Get new random number generator.
44 |
45 | Parameters
46 | ----------
47 | seed_seq : np.random.SeedSequence, rlberry.seeding.Seeder or int, default : None
48 | Seed sequence from which to spawn the random number generator.
49 | If None, generate random seed.
50 | If int, use as entropy for SeedSequence.
51 | If seeder, use seeder.seed_seq
52 | """
53 | self.seeder.reseed(seed_seq)
54 |
55 | def sample(self):
56 | return self.rng.integers(0, self.n)
57 |
58 | def __str__(self):
59 | objstr = "%d-element Discrete space" % self.n
60 | return objstr
61 |
--------------------------------------------------------------------------------
/rlberry/spaces/from_gym.py:
--------------------------------------------------------------------------------
1 | import rlberry.spaces
2 | import gymnasium.spaces
3 |
4 |
5 | def convert_space_from_gym(space):
6 | if isinstance(space, gymnasium.spaces.Box) and (
7 | not isinstance(space, rlberry.spaces.Box)
8 | ):
9 | return rlberry.spaces.Box(
10 | space.low, space.high, shape=space.shape, dtype=space.dtype
11 | )
12 | if isinstance(space, gymnasium.spaces.Discrete) and (
13 | not isinstance(space, rlberry.spaces.Discrete)
14 | ):
15 | return rlberry.spaces.Discrete(n=space.n)
16 | if isinstance(space, gymnasium.spaces.MultiBinary) and (
17 | not isinstance(space, rlberry.spaces.MultiBinary)
18 | ):
19 | return rlberry.spaces.MultiBinary(n=space.n)
20 | if isinstance(space, gymnasium.spaces.MultiDiscrete) and (
21 | not isinstance(space, rlberry.spaces.MultiDiscrete)
22 | ):
23 | return rlberry.spaces.MultiDiscrete(
24 | nvec=space.nvec,
25 | dtype=space.dtype,
26 | )
27 | if isinstance(space, gymnasium.spaces.Tuple) and (
28 | not isinstance(space, rlberry.spaces.Tuple)
29 | ):
30 | return rlberry.spaces.Tuple(
31 | spaces=[convert_space_from_gym(sp) for sp in space.spaces]
32 | )
33 | if isinstance(space, gymnasium.spaces.Dict) and (
34 | not isinstance(space, rlberry.spaces.Dict)
35 | ):
36 | converted_spaces = dict()
37 | for key in space.spaces:
38 | converted_spaces[key] = convert_space_from_gym(space.spaces[key])
39 | return rlberry.spaces.Dict(spaces=converted_spaces)
40 |
41 | return space
42 |
--------------------------------------------------------------------------------
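A short sketch (not part of the repository) of how convert_space_from_gym can turn a nested gymnasium space into its rlberry counterpart, so that it follows rlberry's unified seeding mechanism:

    import gymnasium as gym
    import rlberry.spaces
    from rlberry.spaces.from_gym import convert_space_from_gym

    gym_space = gym.spaces.Dict(
        {
            "position": gym.spaces.Box(low=-1.0, high=1.0, shape=(2,)),
            "mode": gym.spaces.Discrete(3),
        }
    )
    space = convert_space_from_gym(gym_space)
    assert isinstance(space, rlberry.spaces.Dict)
    assert isinstance(space["position"], rlberry.spaces.Box)  # converted recursively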
/rlberry/spaces/multi_binary.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 | from rlberry.seeding import Seeder
3 |
4 |
5 | class MultiBinary(gym.spaces.MultiBinary):
6 | """
7 |
8 | Inherited from gymnasium.spaces.MultiBinary for compatibility with gym.
9 |
10 | rlberry wraps gym.spaces to make sure the seeding
11 | mechanism is unified in the library (rlberry.seeding)
12 |
13 | Attributes
14 | ----------
15 | rng : numpy.random._generator.Generator
16 | random number generator provided by rlberry.seeding
17 |
18 | Methods
19 | -------
20 | reseed()
21 | get new random number generator
22 | """
23 |
24 | def __init__(self, n):
25 | gym.spaces.MultiBinary.__init__(self, n)
26 | self.seeder = Seeder()
27 |
28 | @property
29 | def rng(self):
30 | return self.seeder.rng
31 |
32 | def reseed(self, seed_seq=None):
33 | """
34 | Get new random number generator.
35 |
36 | Parameters
37 | ----------
38 | seed_seq : np.random.SeedSequence, rlberry.seeding.Seeder or int, default : None
39 | Seed sequence from which to spawn the random number generator.
40 | If None, generate random seed.
41 | If int, use as entropy for SeedSequence.
42 | If seeder, use seeder.seed_seq
43 | """
44 | self.seeder.reseed(seed_seq)
45 |
46 | def sample(self):
47 | return self.rng.integers(low=0, high=2, size=self.n, dtype=self.dtype)
48 |
--------------------------------------------------------------------------------
/rlberry/spaces/multi_discrete.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 | import numpy as np
3 | from rlberry.seeding import Seeder
4 |
5 |
6 | class MultiDiscrete(gym.spaces.MultiDiscrete):
7 | """
8 |
9 | Inherited from gymnasium.spaces.MultiDiscrete for compatibility with gym.
10 |
11 | rlberry wraps gym.spaces to make sure the seeding
12 | mechanism is unified in the library (rlberry.seeding)
13 |
14 | Attributes
15 | ----------
16 | rng : numpy.random._generator.Generator
17 | random number generator provided by rlberry.seeding
18 |
19 | Methods
20 | -------
21 | reseed()
22 | get new random number generator
23 | """
24 |
25 | def __init__(self, nvec, dtype=np.int64):
26 | gym.spaces.MultiDiscrete.__init__(self, nvec, dtype=dtype)
27 | self.seeder = Seeder()
28 |
29 | @property
30 | def rng(self):
31 | return self.seeder.rng
32 |
33 | def reseed(self, seed_seq=None):
34 | """
35 | Get new random number generator.
36 |
37 | Parameters
38 | ----------
39 | seed_seq : np.random.SeedSequence, rlberry.seeding.Seeder or int, default : None
40 | Seed sequence from which to spawn the random number generator.
41 | If None, generate random seed.
42 | If int, use as entropy for SeedSequence.
43 | If seeder, use seeder.seed_seq
44 | """
45 | self.seeder.reseed(seed_seq)
46 |
47 | def sample(self):
48 | sample = self.rng.random(self.nvec.shape) * self.nvec
49 | return sample.astype(self.dtype)
50 |
--------------------------------------------------------------------------------
/rlberry/spaces/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/rlberry/spaces/tests/__init__.py
--------------------------------------------------------------------------------
/rlberry/spaces/tuple.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 | from rlberry.seeding import Seeder
3 |
4 |
5 | class Tuple(gym.spaces.Tuple):
6 | """
7 |
8 | Inherited from gymnasium.spaces.Tuple for compatibility with gym.
9 |
10 | rlberry wraps gym.spaces to make sure the seeding
11 | mechanism is unified in the library (rlberry.seeding)
12 |
13 | Attributes
14 | ----------
15 | rng : numpy.random._generator.Generator
16 | random number generator provided by rlberry.seeding
17 |
18 | Methods
19 | -------
20 | reseed()
21 | get new random number generator
22 | """
23 |
24 | def __init__(self, spaces):
25 | gym.spaces.Tuple.__init__(self, spaces)
26 | self.seeder = Seeder()
27 |
28 | @property
29 | def rng(self):
30 | return self.seeder.rng
31 |
32 | def reseed(self, seed_seq=None):
33 | _ = [space.reseed(seed_seq) for space in self.spaces]
34 |
--------------------------------------------------------------------------------
/rlberry/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/rlberry/tests/__init__.py
--------------------------------------------------------------------------------
/rlberry/tests/test_agent_extra.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import rlberry_scool.agents as agents_scool
3 | import rlberry_research.agents.torch as torch_agents
4 | from rlberry.utils.check_agent import (
5 | check_rl_agent,
6 | check_rlberry_agent,
7 | check_vectorized_env_agent,
8 | check_hyperparam_optimisation_agent,
9 | )
10 | from rlberry_scool.agents.features import FeatureMap
11 | import numpy as np
12 | import sys
13 |
14 |
15 | class OneHotFeatureMap(FeatureMap):
16 | def __init__(self, S, A):
17 | self.S = S
18 | self.A = A
19 | self.shape = (S * A,)
20 |
21 | def map(self, observation, action):
22 | feat = np.zeros((self.S, self.A))
23 | feat[observation, action] = 1.0
24 | return feat.flatten()
25 |
26 |
27 | # LSVIUCBAgent needs a feature map function to work.
28 | class OneHotLSVI(agents_scool.LSVIUCBAgent):
29 | def __init__(self, env, **kwargs):
30 | def feature_map_fn(_env):
31 | return OneHotFeatureMap(5, 2) # values for Chain
32 |
33 | agents_scool.LSVIUCBAgent.__init__(
34 | self, env, feature_map_fn=feature_map_fn, horizon=10, **kwargs
35 | )
36 |
37 |
38 | # No agent "FINITE_MDP" in extra
39 | # FINITE_MDP_AGENTS = [
40 | # ]
41 |
42 |
43 | CONTINUOUS_STATE_AGENTS = [
44 | torch_agents.DQNAgent,
45 | torch_agents.MunchausenDQNAgent,
46 | torch_agents.REINFORCEAgent,
47 | torch_agents.PPOAgent,
48 | torch_agents.A2CAgent,
49 | ]
50 |
51 | # Maybe add PPO ?
52 | CONTINUOUS_ACTIONS_AGENTS = [torch_agents.SACAgent]
53 |
54 |
55 | HYPERPARAM_OPTI_AGENTS = [
56 | torch_agents.PPOAgent,
57 | torch_agents.REINFORCEAgent,
58 | torch_agents.A2CAgent,
59 | torch_agents.SACAgent,
60 | ]
61 |
62 |
63 | MULTI_ENV_AGENTS = [
64 | torch_agents.PPOAgent,
65 | ]
66 |
67 | # No agent "FINITE_MDP" in extra
68 | # @pytest.mark.parametrize("agent", FINITE_MDP_AGENTS)
69 | # def test_finite_state_agent(agent):
70 | # check_rl_agent(agent, env="discrete_state")
71 | # check_rlberry_agent(agent, env="discrete_state")
72 |
73 |
74 | @pytest.mark.xfail(sys.platform == "win32", reason="bug with windows???")
75 | @pytest.mark.parametrize("agent", CONTINUOUS_STATE_AGENTS)
76 | def test_continuous_state_agent(agent):
77 | check_rl_agent(agent, env="continuous_state")
78 | check_rlberry_agent(agent, env="continuous_state")
79 |
80 |
81 | @pytest.mark.xfail(sys.platform == "win32", reason="bug with windows???")
82 | @pytest.mark.parametrize("agent", CONTINUOUS_ACTIONS_AGENTS)
83 | def test_continuous_action_agent(agent):
84 | check_rl_agent(agent, env="continuous_action")
85 | check_rlberry_agent(agent, env="continuous_action")
86 |
87 |
88 | @pytest.mark.xfail(sys.platform == "win32", reason="bug with windows???")
89 | @pytest.mark.parametrize("agent", MULTI_ENV_AGENTS)
90 | def test_continuous_vectorized_env_agent(agent):
91 | check_vectorized_env_agent(agent, env="vectorized_env_continuous")
92 |
93 |
94 | @pytest.mark.xfail(sys.platform == "win32", reason="bug with windows???")
95 | @pytest.mark.parametrize("agent", HYPERPARAM_OPTI_AGENTS)
96 | def test_hyperparam_optimisation_agent(agent):
97 | check_hyperparam_optimisation_agent(agent, env="continuous_state")
98 |
--------------------------------------------------------------------------------
/rlberry/tests/test_agents_base.py:
--------------------------------------------------------------------------------
1 | """
2 | ===============================================
3 | Tests for an installation without the extras (StableBaselines3, torch, optuna, ...)
4 | ===============================================
5 | Tests based on test_agent.py and test_envs.py
6 |
7 | """
8 |
9 |
10 | import pytest
11 | import numpy as np
12 | import sys
13 |
14 | import rlberry_research.agents as agents_research
15 | import rlberry_scool.agents as agents_scool
16 | from rlberry_scool.agents.features import FeatureMap
17 |
18 | from rlberry.utils.check_agent import (
19 | check_rl_agent,
20 | check_rlberry_agent,
21 | )
22 |
23 |
24 | class OneHotFeatureMap(FeatureMap):
25 | def __init__(self, S, A):
26 | self.S = S
27 | self.A = A
28 | self.shape = (S * A,)
29 |
30 | def map(self, observation, action):
31 | feat = np.zeros((self.S, self.A))
32 | feat[observation, action] = 1.0
33 | return feat.flatten()
34 |
35 |
36 | # LSVIUCBAgent needs a feature map function to work.
37 | class OneHotLSVI(agents_scool.LSVIUCBAgent):
38 | def __init__(self, env, **kwargs):
39 | def feature_map_fn(_env):
40 | return OneHotFeatureMap(5, 2) # values for Chain
41 |
42 | agents_scool.LSVIUCBAgent.__init__(
43 | self, env, feature_map_fn=feature_map_fn, horizon=10, **kwargs
44 | )
45 |
46 |
47 | FINITE_MDP_AGENTS = [
48 | agents_scool.QLAgent,
49 | agents_scool.SARSAAgent,
50 | agents_scool.ValueIterationAgent,
51 | agents_scool.MBQVIAgent,
52 | agents_scool.UCBVIAgent,
53 | agents_research.OptQLAgent,
54 | agents_research.PSRLAgent,
55 | agents_research.RLSVIAgent,
56 | OneHotLSVI,
57 | ]
58 |
59 |
60 | CONTINUOUS_STATE_AGENTS = [
61 | agents_research.RSUCBVIAgent,
62 | agents_research.RSKernelUCBVIAgent,
63 | ]
64 |
65 |
66 | @pytest.mark.parametrize("agent", FINITE_MDP_AGENTS)
67 | def test_finite_state_agent(agent):
68 | check_rl_agent(agent, env="discrete_state")
69 | check_rlberry_agent(agent, env="discrete_state")
70 |
71 |
72 | @pytest.mark.xfail(sys.platform == "win32", reason="bug with windows???")
73 | @pytest.mark.parametrize("agent", CONTINUOUS_STATE_AGENTS)
74 | def test_continuous_state_agent(agent):
75 | check_rl_agent(agent, env="continuous_state")
76 | check_rlberry_agent(agent, env="continuous_state")
77 |
--------------------------------------------------------------------------------
/rlberry/tests/test_envs.py:
--------------------------------------------------------------------------------
1 | from rlberry.utils.check_env import check_env, check_rlberry_env
2 | from rlberry_research.envs.benchmarks.ball_exploration import PBall2D
3 | from rlberry_research.envs.benchmarks.generalization.twinrooms import TwinRooms
4 | from rlberry_research.envs.benchmarks.grid_exploration.apple_gold import AppleGold
5 | from rlberry_research.envs.benchmarks.grid_exploration.nroom import NRoom
6 | from rlberry_research.envs.classic_control import MountainCar, SpringCartPole
7 | from rlberry_scool.envs.finite import Chain, GridWorld
8 | import pytest
9 |
10 | ALL_ENVS = [
11 | PBall2D,
12 | TwinRooms,
13 | AppleGold,
14 | NRoom,
15 | MountainCar,
16 | Chain,
17 | GridWorld,
18 | SpringCartPole,
19 | ]
20 |
21 |
22 | @pytest.mark.parametrize("Env", ALL_ENVS)
23 | def test_env(Env):
24 | check_env(Env())
25 | check_rlberry_env(Env())
26 |
--------------------------------------------------------------------------------
/rlberry/tests/test_imports.py:
--------------------------------------------------------------------------------
1 | def test_imports():
2 | import rlberry # noqa
3 | from rlberry.manager import ( # noqa
4 | ExperimentManager, # noqa
5 | evaluate_agents, # noqa
6 | plot_writer_data, # noqa
7 | ) # noqa
8 | from rlberry.agents import AgentWithSimplePolicy # noqa
9 | from rlberry.wrappers import WriterWrapper # noqa
10 |
--------------------------------------------------------------------------------
/rlberry/types.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 | from typing import Any, Callable, Mapping, Tuple, Union
3 | from rlberry.seeding import Seeder
4 |
5 | # either a gymnasium.Env or a tuple containing (constructor, kwargs) to build the env
6 | Env = Union[gym.Env, Tuple[Callable[..., gym.Env], Mapping[str, Any]]]
7 |
8 | #
9 | Seed = Union[Seeder, int]
10 |
--------------------------------------------------------------------------------
/rlberry/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .check_agent import (
2 | check_rl_agent,
3 | check_save_load,
4 | check_fit_additive,
5 | check_seeding_agent,
6 | check_experiment_manager,
7 | )
8 | from .check_env import check_env
9 |
--------------------------------------------------------------------------------
/rlberry/utils/binsearch.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def binary_search_nd(x_vec, bins):
5 | """n-dimensional binary search
6 |
7 | Parameters
8 | -----------
9 | x_vec : numpy.ndarray
10 | numpy 1d array to be searched in the bins
11 | bins : list
12 | list of numpy 1d array, bins[d] = bins of the d-th dimension
13 |
14 |
15 | Returns
16 | --------
17 | index (int) corresponding to the position of x in the partition
18 | defined by the bins.
19 | """
20 | dim = len(bins)
21 | flat_index = 0
22 | aux = 1
23 | assert dim == len(x_vec), "dimension mismatch in binary_search_nd()"
24 | for dd in range(dim):
25 | index_dd = np.searchsorted(bins[dd], x_vec[dd], side="right") - 1
26 | assert index_dd != -1, "error in binary_search_nd()"
27 | flat_index += aux * index_dd
28 | aux *= len(bins[dd]) - 1
29 | return flat_index
30 |
31 |
32 | def unravel_index_uniform_bin(flat_index, dim, n_per_dim):
33 | index = []
34 | aux_index = flat_index
35 | for _ in range(dim):
36 | index.append(aux_index % n_per_dim)
37 | aux_index = aux_index // n_per_dim
38 | return tuple(index)
39 |
40 |
41 | if __name__ == "__main__":
42 | bins = [(0, 1, 2, 3, 4), (0, 1, 2, 3, 4)]
43 | x = [3.9, 3.5]
44 | index = binary_search_nd(x, bins)
45 | print(index)
46 |
--------------------------------------------------------------------------------
/rlberry/utils/check_env.py:
--------------------------------------------------------------------------------
1 | from rlberry.seeding import safe_reseed
2 | from rlberry.seeding import Seeder
3 | import numpy as np
4 | from rlberry.utils.check_gym_env import check_gym_env
5 |
6 | seeder = Seeder(42)
7 |
8 |
9 | def check_env(env):
10 | """
11 | Check that the environment is (almost) gym-compatible and that it is reproducible
12 | in the sense that it returns the same states when given the same seed.
13 |
14 | Parameters
15 | ----------
16 | env: gymnasium.env or rlberry env
17 | Environment that we want to check.
18 | """
19 | # Small reproducibility test
20 | action = env.action_space.sample()
21 | safe_reseed(env, Seeder(42))
22 | env.reset()
23 | a = env.step(action)[0]
24 |
25 | safe_reseed(env, Seeder(42))
26 | env.reset()
27 | b = env.step(action)[0]
28 | if hasattr(a, "__len__"):
29 | assert np.all(
30 | np.array(a) == np.array(b)
31 | ), "The environment does not seem to be reproducible"
32 | else:
33 | assert a == b, "The environment does not seem to be reproducible"
34 |
35 | # Modified check suite from gym
36 | check_gym_env(env)
37 |
38 |
39 | def check_rlberry_env(env):
40 | """
41 |     Companion to check_env that contains additional tests. It is not mandatory
42 |     for an environment to satisfy this check, but satisfying it gives access to
43 |     additional features in rlberry.
44 |
45 | Parameters
46 | ----------
47 | env: gymnasium.env or rlberry env
48 | Environment that we want to check.
49 | """
50 | try:
51 | env.get_params()
52 | except Exception:
53 |         raise RuntimeError("Failed to call get_params on the environment.")
54 |
--------------------------------------------------------------------------------
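For illustration, a minimal sketch (not in the repository) of how these checks are typically invoked; it assumes the Chain environment from rlberry_scool, as imported in rlberry/tests/test_envs.py shown earlier.

    from rlberry.utils.check_env import check_env, check_rlberry_env
    from rlberry_scool.envs.finite import Chain  # assumed available, as in test_envs.py

    env = Chain()
    check_env(env)           # gym-compatibility and reproducibility checks
    check_rlberry_env(env)   # optional rlberry-specific checks (get_params, ...)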
/rlberry/utils/factory.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from typing import Callable
3 |
4 |
5 | def load(path: str) -> Callable:
6 | module_name, class_name = path.rsplit(".", 1)
7 | return getattr(importlib.import_module(module_name), class_name)
8 |
--------------------------------------------------------------------------------
/rlberry/utils/space_discretizer.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from gymnasium.spaces import Box, Discrete
3 | from rlberry.utils.binsearch import binary_search_nd
4 | from rlberry.utils.binsearch import unravel_index_uniform_bin
5 |
6 |
7 | class Discretizer:
8 | def __init__(self, space, n_bins):
9 | assert isinstance(
10 | space, Box
11 | ), "Discretization is only implemented for Box spaces."
12 | assert space.is_bounded()
13 | self.space = space
14 | self.n_bins = n_bins
15 |
16 | # initialize bins
17 | assert n_bins > 0, "Discretizer requires n_bins > 0"
18 | n_elements = 1
19 | tol = 1e-8
20 | self.dim = len(self.space.low)
21 | n_elements = n_bins**self.dim
22 | self._bins = []
23 | self._open_bins = []
24 | for dd in range(self.dim):
25 | range_dd = self.space.high[dd] - self.space.low[dd]
26 | epsilon = range_dd / n_bins
27 | bins_dd = []
28 | for bb in range(n_bins + 1):
29 | val = self.space.low[dd] + epsilon * bb
30 | bins_dd.append(val)
31 | self._open_bins.append(tuple(bins_dd[1:]))
32 | bins_dd[-1] += tol # "close" the last interval
33 | self._bins.append(tuple(bins_dd))
34 |
35 | # set observation space
36 | self.discrete_space = Discrete(n_elements)
37 |
38 | # List of discretized elements
39 | self.discretized_elements = np.zeros((self.dim, n_elements))
40 | for ii in range(n_elements):
41 | self.discretized_elements[:, ii] = self.get_coordinates(ii, False)
42 |
43 | def discretize(self, coordinates):
44 | return binary_search_nd(coordinates, self._bins)
45 |
46 | def get_coordinates(self, flat_index, randomize=False):
47 | assert self.discrete_space.contains(flat_index), "invalid flat_index"
48 | # get multi-index
49 | index = unravel_index_uniform_bin(flat_index, self.dim, self.n_bins)
50 |
51 | # get coordinates
52 | coordinates = np.zeros(self.dim)
53 | for dd in range(self.dim):
54 | coordinates[dd] = self._bins[dd][index[dd]]
55 | if randomize:
56 | range_dd = self.space.high[dd] - self.space.low[dd]
57 | epsilon = range_dd / self.n_bins
58 | coordinates[dd] += epsilon * self.space.rng.uniform()
59 | return coordinates
60 |
--------------------------------------------------------------------------------
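A quick sketch (not part of the repository) of the Discretizer on a 2D bounded Box: with n_bins per dimension it exposes a Discrete space of n_bins**dim cells, and discretize() returns the flat index of the cell containing a point.

    import numpy as np
    from gymnasium.spaces import Box
    from rlberry.utils.space_discretizer import Discretizer

    space = Box(low=np.array([0.0, 0.0]), high=np.array([1.0, 1.0]))
    discretizer = Discretizer(space, n_bins=4)          # 4**2 = 16 cells
    flat_index = discretizer.discretize(np.array([0.3, 0.9]))
    corner = discretizer.get_coordinates(flat_index)    # lower corner of that cell
    print(flat_index, corner)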
/rlberry/utils/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/rlberry-py/rlberry/f698f2636f159e772cacca79cdf36cf51e8cbf6b/rlberry/utils/tests/__init__.py
--------------------------------------------------------------------------------
/rlberry/utils/tests/test_binsearch.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from rlberry.utils.binsearch import binary_search_nd
5 | from rlberry.utils.binsearch import unravel_index_uniform_bin
6 |
7 |
8 | def test_binary_search_nd():
9 | bin1 = np.array([0.0, 1.0, 2.0, 3.0]) # 3 intervals
10 | bin2 = np.array([1.0, 2.0, 3.0, 4.0]) # 3 intervals
11 | bin3 = np.array([2.0, 3.0, 4.0, 5.0, 6.0]) # 4 intervals
12 |
13 | bins = [bin1, bin2, bin3]
14 |
15 | vec1 = np.array([0.0, 1.0, 2.0])
16 | vec2 = np.array([2.9, 3.9, 5.9])
17 | vec3 = np.array([1.5, 2.5, 2.5])
18 | vec4 = np.array([1.5, 2.5, 2.5])
19 |
20 | # index = i + Ni * j + Ni * Nj * k
21 | assert binary_search_nd(vec1, bins) == 0
22 | assert binary_search_nd(vec2, bins) == 2 + 3 * 2 + 3 * 3 * 3
23 | assert binary_search_nd(vec3, bins) == 1 + 3 * 1 + 3 * 3 * 0
24 |
25 |
26 | @pytest.mark.parametrize(
27 | "i, j, k, N", [(0, 0, 0, 5), (0, 1, 2, 5), (4, 3, 2, 5), (4, 4, 4, 5)]
28 | )
29 | def test_unravel_index_uniform_bin(i, j, k, N):
30 | # index = i + N * j + N * N * k
31 | dim = 3
32 | flat_index = i + N * j + N * N * k
33 | assert (i, j, k) == unravel_index_uniform_bin(flat_index, dim, N)
34 |
--------------------------------------------------------------------------------
/rlberry/utils/tests/test_writer.py:
--------------------------------------------------------------------------------
1 | import time
2 | from rlberry_research.envs import GridWorld
3 | from rlberry.agents import AgentWithSimplePolicy
4 | from rlberry.manager import ExperimentManager
5 |
6 |
7 | class DummyAgent(AgentWithSimplePolicy):
8 | def __init__(self, env, hyperparameter1=0, hyperparameter2=0, **kwargs):
9 | AgentWithSimplePolicy.__init__(self, env, **kwargs)
10 | self.name = "DummyAgent"
11 | self.fitted = False
12 | self.hyperparameter1 = hyperparameter1
13 | self.hyperparameter2 = hyperparameter2
14 |
15 | self.total_budget = 0.0
16 |
17 | def fit(self, budget, **kwargs):
18 | del kwargs
19 | self.fitted = True
20 | self.total_budget += budget
21 | for ii in range(budget):
22 | if self.writer is not None:
23 | self.writer.add_scalar("a", ii, ii)
24 | scalar_dict = dict(multi1=1, multi2=2)
25 | self.writer.add_scalars("multi_scalar_test", scalar_dict)
26 | time.sleep(1)
27 |
28 | return None
29 |
30 | def policy(self, observation):
31 | return 0
32 |
33 |
34 | def test_myoutput(capsys): # or use "capfd" for fd-level
35 | env_ctor = GridWorld
36 | env_kwargs = dict()
37 |
38 | env = env_ctor(**env_kwargs)
39 | xp_manager = ExperimentManager(
40 | DummyAgent,
41 | (env_ctor, env_kwargs),
42 | fit_budget=3,
43 | n_fit=1,
44 | default_writer_kwargs={"log_interval": 1},
45 | )
46 | budget_size = 22
47 | xp_manager.fit(budget=budget_size)
48 |
49 |     assert xp_manager.agent_handlers[0].writer.summary_writer is None
50 | assert list(xp_manager.agent_handlers[0].writer.read_tag_value("a")) == list(
51 | range(budget_size)
52 | )
53 | assert xp_manager.agent_handlers[0].writer.read_first_tag_value("a") == 0
54 | assert (
55 | xp_manager.agent_handlers[0].writer.read_last_tag_value("a") == budget_size - 1
56 | ) # start at 0
57 |
58 | captured = capsys.readouterr()
59 |     # test that what is written to stderr/stdout is longer than 50 characters
60 | assert (
61 | len(captured.err) + len(captured.out) > 50
62 | ), "the logging did not print the info to stderr"
63 |
--------------------------------------------------------------------------------
/rlberry/utils/torch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import shutil
4 | from subprocess import check_output, run, PIPE
5 | import numpy as np
6 | import torch
7 |
8 |
9 | import rlberry
10 |
11 | logger = rlberry.logger
12 |
13 |
14 | def get_gpu_memory_map():
15 | result = check_output(
16 | ["nvidia-smi", "--query-gpu=memory.used", "--format=csv,nounits,noheader"]
17 | )
18 | return [int(x) for x in result.split()]
19 |
20 |
21 | def least_used_device():
22 | """Get the GPU device with most available memory."""
23 | if not torch.cuda.is_available():
24 | raise RuntimeError("cuda unavailable")
25 |
26 | if shutil.which("nvidia-smi") is None:
27 | raise RuntimeError(
28 | "nvidia-smi unavailable: \
29 |             cannot select the device with the least memory used."
30 | )
31 |
32 | memory_map = get_gpu_memory_map()
33 | device_id = np.argmin(memory_map)
34 | logger.debug(
35 | f"Choosing GPU device: {device_id}, " f"memory used: {memory_map[device_id]}"
36 | )
37 | return torch.device("cuda:{}".format(device_id))
38 |
39 |
40 | def choose_device(preferred_device, default_device="cpu"):
41 | """Choose torch device, use default if choice is not available.
42 |
43 | Parameters
44 | ----------
45 | preferred_device: str
46 | Torch device to be used (if available), e.g. "cpu", "cuda:0", "cuda:best".
47 | If "cuda:best", returns the least used device in the machine.
48 | default_device: str, default = "cpu"
49 | Default device if preferred_device is not available.
50 | """
51 | if preferred_device == "cuda:best":
52 | try:
53 | preferred_device = least_used_device()
54 | except RuntimeError:
55 | logger.debug(
56 |                 "Could not find the least used device (nvidia-smi might be missing); using cuda:0 instead"
57 | )
58 | if torch.cuda.is_available():
59 | return choose_device("cuda:0")
60 | else:
61 | return choose_device("cpu")
62 | try:
63 | torch.zeros((1,), device=preferred_device) # Test availability
64 | except (RuntimeError, AssertionError) as e:
65 | logger.debug(
66 |             f"Preferred device {preferred_device} unavailable ({e}). "
67 | f"Switching to default {default_device}"
68 | )
69 | return default_device
70 | return preferred_device
71 |
72 |
73 | def get_memory(pid=None):
74 | if not pid:
75 | pid = os.getpid()
76 | command = "nvidia-smi"
77 | result = run(
78 | command, stdout=PIPE, stderr=PIPE, universal_newlines=True, shell=True
79 | ).stdout
80 | m = re.findall(
81 | r"\| *[0-9] *" + str(pid) + r" *C *.*python.*? +([0-9]+).*\|",
82 | result,
83 | re.MULTILINE,
84 | )
85 | return [int(mem) for mem in m]
86 |
--------------------------------------------------------------------------------
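As a usage note (not part of the repository), choose_device can be asked for the best CUDA device and quietly falls back to the CPU when no GPU (or nvidia-smi) is available; torch must be installed to import this module.

    from rlberry.utils.torch import choose_device

    device = choose_device("cuda:best")  # least-used GPU if available, else "cpu"
    print(device)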
/rlberry/wrappers/__init__.py:
--------------------------------------------------------------------------------
1 | from .discretize_state import DiscretizeStateWrapper
2 | from .rescale_reward import RescaleRewardWrapper
3 | from .writer_utils import WriterWrapper
4 | from .discrete2onehot import DiscreteToOneHotWrapper
5 |
--------------------------------------------------------------------------------
/rlberry/wrappers/autoreset.py:
--------------------------------------------------------------------------------
1 | from rlberry.envs import Wrapper
2 |
3 |
4 | class AutoResetWrapper(Wrapper):
5 | """
6 | Auto reset the environment after "horizon" steps have passed.
7 | """
8 |
9 | def __init__(self, env, horizon):
10 | """
11 | Parameters
12 | ----------
13 | horizon: int
14 | """
15 | Wrapper.__init__(self, env)
16 | self.horizon = horizon
17 | assert self.horizon >= 1
18 | self.current_step = 0
19 |
20 | def reset(self, seed=None, options=None):
21 | self.current_step = 0
22 | return self.env.reset(seed=seed, options=options)
23 |
24 | def step(self, action):
25 | observation, reward, terminated, truncated, info = self.env.step(action)
26 | self.current_step += 1
27 |         # After `horizon` steps, always return to the initial state
28 |         # and set terminated to True.
29 | if self.current_step == self.horizon:
30 | self.current_step = 0
31 | observation, info = self.env.reset()
32 | terminated = True
33 | truncated = False
34 | return observation, reward, terminated, truncated, info
35 |
--------------------------------------------------------------------------------
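A minimal sketch (not in the repository) of the wrapper; it assumes the GridWorld environment from rlberry_scool, as imported in the wrapper tests later in this dump.

    from rlberry.wrappers.autoreset import AutoResetWrapper
    from rlberry_scool.envs import GridWorld  # assumed available

    env = AutoResetWrapper(GridWorld(), horizon=5)
    observation, info = env.reset()
    for _ in range(12):
        # every 5th step the wrapper resets the environment and reports terminated=True
        observation, reward, terminated, truncated, info = env.step(
            env.action_space.sample()
        )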
/rlberry/wrappers/discrete2onehot.py:
--------------------------------------------------------------------------------
1 | from rlberry.spaces import Box, Discrete
2 | from rlberry.envs import Wrapper
3 | import numpy as np
4 |
5 |
6 | class DiscreteToOneHotWrapper(Wrapper):
7 | """Converts observation spaces from Discrete to Box via one-hot encoding."""
8 |
9 | def __init__(self, env):
10 | Wrapper.__init__(self, env, wrap_spaces=True)
11 | obs_space = self.env.observation_space
12 | assert isinstance(obs_space, Discrete)
13 | self.observation_space = Box(
14 | low=0.0, high=1.0, shape=(obs_space.n,), dtype=np.uint32
15 | )
16 |
17 | def process_obs(self, obs):
18 | one_hot_obs = np.zeros(self.env.observation_space.n, dtype=np.uint32)
19 | one_hot_obs[obs] = 1.0
20 | return one_hot_obs
21 |
22 | def reset(self, seed=None, options=None):
23 | obs, info = self.env.reset(seed=seed, options=options)
24 | return self.process_obs(obs), info
25 |
26 | def step(self, action):
27 | observation, reward, terminated, truncated, info = self.env.step(action)
28 | observation = self.process_obs(observation)
29 | return observation, reward, terminated, truncated, info
30 |
--------------------------------------------------------------------------------
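For illustration (not part of the repository), wrapping a finite-MDP environment whose observation space is Discrete turns integer states into one-hot vectors; GridWorld from rlberry_scool is assumed here.

    from rlberry.wrappers import DiscreteToOneHotWrapper
    from rlberry_scool.envs import GridWorld  # assumed available

    env = DiscreteToOneHotWrapper(GridWorld())
    observation, info = env.reset()
    # the integer state becomes a one-hot vector with exactly one nonzero entry
    print(observation.shape, observation.sum())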
/rlberry/wrappers/gym_utils.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 | from gymnasium.utils.step_api_compatibility import step_api_compatibility
3 | from rlberry.spaces import Discrete
4 | from rlberry.spaces import Box
5 | from rlberry.spaces import Tuple
6 | from rlberry.spaces import MultiDiscrete
7 | from rlberry.spaces import MultiBinary
8 | from rlberry.spaces import Dict
9 |
10 | from rlberry.envs import Wrapper
11 |
12 |
13 | def convert_space_from_gym(gym_space):
14 | if isinstance(gym_space, gym.spaces.Discrete):
15 | return Discrete(gym_space.n)
16 | #
17 | #
18 | elif isinstance(gym_space, gym.spaces.Box):
19 | return Box(gym_space.low, gym_space.high, gym_space.shape, gym_space.dtype)
20 | #
21 | #
22 | elif isinstance(gym_space, gym.spaces.Tuple):
23 | spaces = []
24 | for sp in gym_space.spaces:
25 | spaces.append(convert_space_from_gym(sp))
26 | return Tuple(spaces)
27 | #
28 | #
29 | elif isinstance(gym_space, gym.spaces.MultiDiscrete):
30 | return MultiDiscrete(gym_space.nvec)
31 | #
32 | #
33 | elif isinstance(gym_space, gym.spaces.MultiBinary):
34 | return MultiBinary(gym_space.n)
35 | #
36 | #
37 | elif isinstance(gym_space, gym.spaces.Dict):
38 | spaces = {}
39 | for key in gym_space.spaces:
40 | spaces[key] = convert_space_from_gym(gym_space[key])
41 | return Dict(spaces)
42 | else:
43 | raise ValueError("Unknown space class: {}".format(type(gym_space)))
44 |
45 |
46 | class OldGymCompatibilityWrapper(Wrapper):
47 | """
48 |     Allows using old gym envs (v0.21) with rlberry (gymnasium).
49 |     (For basic use only.)
50 | """
51 |
52 | def __init__(self, env):
53 | Wrapper.__init__(self, env)
54 |
55 | def reset(self, seed=None, options=None):
56 | if seed:
57 | self.env.reseed(seed)
58 | observation = self.env.reset()
59 | return observation, {}
60 |
61 | def step(self, action):
62 | obs, rewards, terminated, truncated, info = step_api_compatibility(
63 | self.env.step(action), output_truncation_bool=True
64 | )
65 | return obs, rewards, terminated, truncated, info
66 |
--------------------------------------------------------------------------------
/rlberry/wrappers/rescale_reward.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from rlberry.envs import Wrapper
3 |
4 |
5 | class RescaleRewardWrapper(Wrapper):
6 | """
7 | Rescale the reward function to a bounded range.
8 |
9 | Parameters
10 | ----------
11 | reward_range: tuple (double, double)
12 | tuple with the desired reward range, which needs to be bounded.
13 | """
14 |
15 | def __init__(self, env, reward_range):
16 | Wrapper.__init__(self, env)
17 | self.reward_range = reward_range
18 | assert reward_range[0] < reward_range[1]
19 | assert reward_range[0] > -np.inf and reward_range[1] < np.inf
20 |
21 | def _linear_rescaling(self, x, x0, x1, u0, u1):
22 | """
23 | For x a value in [x0, x1], maps x linearly to the interval [u0, u1].
24 | """
25 | a = (u1 - u0) / (x1 - x0)
26 | b = (x1 * u0 - x0 * u1) / (x1 - x0)
27 | return a * x + b
28 |
29 | def _rescale(self, reward):
30 | x0, x1 = self.env.reward_range
31 | u0, u1 = self.reward_range
32 | # bounded reward
33 | if x0 > -np.inf and x1 < np.inf:
34 | return self._linear_rescaling(reward, x0, x1, u0, u1)
35 |         # bounded below, unbounded above
36 | elif x0 > -np.inf and x1 == np.inf:
37 | x = reward - x0 # [0, infty]
38 | x = 2.0 / (1.0 + np.exp(-x)) - 1.0 # [0, 1]
39 | return self._linear_rescaling(x, 0.0, 1.0, u0, u1)
40 |         # unbounded below, bounded above
41 | elif x0 == -np.inf and x1 < np.inf:
42 | x = reward - x1 # [-infty, 0]
43 | x = 2.0 / (1.0 + np.exp(-x)) # [0, 1]
44 | return self._linear_rescaling(x, 0.0, 1.0, u0, u1)
45 |         # unbounded below and above
46 | else:
47 | x = 1.0 / (1.0 + np.exp(-reward)) # [0, 1]
48 | return self._linear_rescaling(x, 0.0, 1.0, u0, u1)
49 |
50 | def step(self, action):
51 | observation, reward, terminated, truncated, info = self.env.step(action)
52 | rescaled_reward = self._rescale(reward)
53 | return observation, rescaled_reward, terminated, truncated, info
54 |
55 | def sample(self, state, action):
56 | observation, reward, terminated, truncated, info = self.env.sample(
57 | state, action
58 | )
59 | rescaled_reward = self._rescale(reward)
60 | return observation, rescaled_reward, terminated, truncated, info
61 |
--------------------------------------------------------------------------------
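A small sketch (not in the repository) showing the intended use: rewards from the wrapped environment are mapped into the requested bounded range, linearly when the native range is bounded and through a sigmoid otherwise. GridWorld from rlberry_scool is assumed as the base environment.

    from rlberry.wrappers import RescaleRewardWrapper
    from rlberry_scool.envs import GridWorld  # assumed available

    env = RescaleRewardWrapper(GridWorld(), reward_range=(-1.0, 1.0))
    env.reset()
    _, reward, _, _, _ = env.step(env.action_space.sample())
    assert -1.0 <= reward <= 1.0  # rescaled into the requested range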
/rlberry/wrappers/tests/__init__.py:
--------------------------------------------------------------------------------
1 | from .old_env import (
2 | old_acrobot,
3 | old_twinrooms,
4 | old_six_room,
5 | old_apple_gold,
6 | old_ball2d,
7 | old_finite_mdp,
8 | old_four_room,
9 | old_gridworld,
10 | old_mountain_car,
11 | old_nroom,
12 | old_pball,
13 | old_pendulum,
14 | )
15 |
--------------------------------------------------------------------------------
/rlberry/wrappers/tests/old_env/__init__.py:
--------------------------------------------------------------------------------
1 | from .old_acrobot import Old_Acrobot
2 | from .old_apple_gold import Old_AppleGold
3 | from .old_four_room import Old_FourRoom
4 | from .old_gridworld import Old_GridWorld
5 | from .old_mountain_car import Old_MountainCar
6 | from .old_nroom import Old_NRoom
7 | from .old_pendulum import Old_Pendulum
8 | from .old_pball import Old_PBall2D, Old_SimplePBallND
9 | from .old_six_room import Old_SixRoom
10 | from .old_twinrooms import Old_TwinRooms
11 |
--------------------------------------------------------------------------------
/rlberry/wrappers/tests/test_basewrapper.py:
--------------------------------------------------------------------------------
1 | from rlberry.envs.interface import Model
2 | from rlberry.envs import Wrapper
3 | from rlberry_research.envs import GridWorld
4 | import gymnasium as gym
5 |
6 |
7 | def test_wrapper():
8 | env = GridWorld()
9 | wrapped = Wrapper(env)
10 | assert isinstance(wrapped, Model)
11 | assert wrapped.is_online()
12 | assert wrapped.is_generative()
13 |
14 | # calling some functions
15 | wrapped.reset()
16 | wrapped.step(wrapped.action_space.sample())
17 | wrapped.sample(wrapped.observation_space.sample(), wrapped.action_space.sample())
18 |
19 |
20 | def test_gym_wrapper():
21 | gym_env = gym.make("Acrobot-v1")
22 | wrapped = Wrapper(gym_env)
23 | assert isinstance(wrapped, Model)
24 | assert wrapped.is_online()
25 | assert not wrapped.is_generative()
26 |
27 | wrapped.reseed()
28 |
29 | # calling some gym functions
30 | wrapped.close()
31 | wrapped.seed()
32 |
--------------------------------------------------------------------------------
/rlberry/wrappers/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import gymnasium as gym
2 | from rlberry.wrappers.utils import get_base_env
3 |
4 |
5 | def test_get_base_env():
6 |     """Test that the utils function 'get_base_env' returns the base environment, stripped of all wrappers"""
7 |
8 | from rlberry.envs.basewrapper import Wrapper
9 | from stable_baselines3.common.monitor import Monitor
10 |
11 | from stable_baselines3.common.atari_wrappers import ( # isort:skip
12 | FireResetEnv,
13 | MaxAndSkipEnv,
14 | NoopResetEnv,
15 | NoopResetEnv,
16 | StickyActionEnv,
17 | )
18 |
19 | env = gym.make("ALE/Breakout-v5")
20 | original_env = env
21 |
22 | # add wrappers
23 | env = Wrapper(env)
24 | env = Monitor(env)
25 | env = StickyActionEnv(env, 0.2)
26 | env = NoopResetEnv(env, noop_max=2)
27 | env = MaxAndSkipEnv(env, skip=4)
28 | env = FireResetEnv(env)
29 | env = gym.wrappers.GrayscaleObservation(env)
30 | env = gym.wrappers.FrameStackObservation(env, 8)
31 | assert original_env != env
32 |
33 | # use the tool
34 | unwrapped_env = get_base_env(env)
35 |
36 | # test the result
37 | assert unwrapped_env != env
38 | assert isinstance(unwrapped_env, gym.Env)
39 | assert unwrapped_env == get_base_env(original_env)
40 |
--------------------------------------------------------------------------------
/rlberry/wrappers/tests/test_writer_utils.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from rlberry_scool.envs import GridWorld
4 |
5 | from rlberry_scool.agents import UCBVIAgent
6 |
7 |
8 | @pytest.mark.parametrize(
9 | "write_scalar", [None, "action", "reward", "action_and_reward"]
10 | )
11 | def test_wrapper(write_scalar):
12 |     """Test that the wrapper records data"""
13 |
14 | class MyAgent(UCBVIAgent):
15 | def __init__(self, env, **kwargs):
16 | UCBVIAgent.__init__(self, env, writer_extra=write_scalar, **kwargs)
17 |
18 | env = GridWorld()
19 | agent = MyAgent(env)
20 | agent.fit(budget=10)
21 | assert len(agent.writer.data) > 0
22 |
23 |
24 | def test_invalid_wrapper_name():
25 |     """Test that a ValueError is raised when write_scalar is unsupported."""
26 |
27 | class MyAgent(UCBVIAgent):
28 | def __init__(self, env, **kwargs):
29 | UCBVIAgent.__init__(self, env, writer_extra="invalid", **kwargs)
30 |
31 | msg = "write_scalar invalid is not known"
32 | env = GridWorld()
33 | agent = MyAgent(env)
34 | with pytest.raises(ValueError, match=msg):
35 | agent.fit(budget=10)
36 |
--------------------------------------------------------------------------------
/rlberry/wrappers/utils.py:
--------------------------------------------------------------------------------
1 | def get_base_env(env):
2 | """Traverse the wrappers to find the base environment."""
3 | while hasattr(env, "env"):
4 | env = env.env
5 | return env
6 |
--------------------------------------------------------------------------------
/rlberry/wrappers/writer_utils.py:
--------------------------------------------------------------------------------
1 | from rlberry.envs import Wrapper
2 |
3 |
4 | class WriterWrapper(Wrapper):
5 | """
6 |     Wrapper for an environment to automatically record the reward or action in the writer.
7 |
8 | Parameters
9 | ----------
10 | env : gymnasium.Env or tuple (constructor, kwargs)
11 | Environment used to fit the agent.
12 |
13 | writer : object, default: None
14 | Writer object (e.g. tensorboard SummaryWriter).
15 |
16 | write_scalar : string in {"reward", "action", "action_and_reward"},
17 | default = "reward"
18 | Scalar that will be recorded in the writer.
19 |
20 | """
21 |
22 | def __init__(self, env, writer, write_scalar="reward"):
23 | Wrapper.__init__(self, env)
24 | self.writer = writer
25 | self.write_scalar = write_scalar
26 | self.iteration_ = 0
27 |
28 | def step(self, action):
29 | observation, reward, terminated, truncated, info = self.env.step(action)
30 |
31 | self.iteration_ += 1
32 | if self.write_scalar == "reward":
33 | self.writer.add_scalar("reward", reward, self.iteration_)
34 | elif self.write_scalar == "action":
35 | self.writer.add_scalar("action", action, self.iteration_)
36 | elif self.write_scalar == "action_and_reward":
37 | self.writer.add_scalar("reward", reward, self.iteration_)
38 | self.writer.add_scalar("action", action, self.iteration_)
39 | else:
40 | raise ValueError("write_scalar %s is not known" % (self.write_scalar))
41 |
42 | return observation, reward, terminated, truncated, info
43 |
--------------------------------------------------------------------------------
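A sketch (not part of the repository) of the usual pattern: inside an agent, the training environment is wrapped so that every step logs its reward to the agent's writer. The agent below is purely illustrative and assumes GridWorld from rlberry_scool; AgentWithSimplePolicy is the base class already used in rlberry/utils/tests/test_writer.py above.

    from rlberry.agents import AgentWithSimplePolicy
    from rlberry.wrappers import WriterWrapper
    from rlberry_scool.envs import GridWorld  # assumed available

    class RandomLoggingAgent(AgentWithSimplePolicy):
        """Toy agent whose environment interactions are recorded by the writer."""

        name = "RandomLoggingAgent"

        def __init__(self, env, **kwargs):
            AgentWithSimplePolicy.__init__(self, env, **kwargs)
            # from now on, every env.step() records the reward in self.writer
            self.env = WriterWrapper(self.env, self.writer, write_scalar="reward")

        def fit(self, budget, **kwargs):
            observation, info = self.env.reset()
            for _ in range(budget):
                observation, reward, terminated, truncated, info = self.env.step(
                    self.env.action_space.sample()
                )
                if terminated or truncated:
                    observation, info = self.env.reset()

        def policy(self, observation):
            return self.env.action_space.sample()

    agent = RandomLoggingAgent(GridWorld())
    agent.fit(budget=20)
    print(len(agent.writer.data))  # number of recorded scalars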
/scripts/apptainer_for_tests/README.md:
--------------------------------------------------------------------------------
1 |
2 | To test rlberry, you can use this script to create a container that installs the latest version, runs the tests, and sends the results by email.
3 | (Or you can look inside the .sh file to take only the parts you need.)
4 |
5 | :warning: **WARNING** :warning: : In both files, you have to update the paths and names
6 |
7 | ## .def
8 | Scripts to build your Apptainer image.
9 | There are 2 scripts:
10 | - one with the "current" version of Python (from ubuntu:latest)
11 | - one with a specific Python version of your choice
12 |
13 | ## .sh
14 | Script to run your Apptainer image and send the report.
15 | Use `chmod +x [name].sh` to make it executable.
16 |
17 | To run this script you need to install "mailutils" first (to send the report by email)
18 |
--------------------------------------------------------------------------------
/scripts/apptainer_for_tests/monthly_test_base.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Set the email recipient and subject
4 | recipient="email1@inria.fr,email2@inria.fr"
5 |
6 | # Build, then run, the Apptainer container and capture the test results
7 | cd [Path_to_your_apptainer_folder]
8 | #build the apptainer from the .def file into an .sif image
9 | apptainer build --fakeroot --force rlberry_apptainer_base.sif rlberry_apptainer_base.def
10 | #Run the "runscript" section and export the result into a file (the $attachment variable must be set to the output file name)
11 | apptainer run --fakeroot --overlay my_overlay/ rlberry_apptainer_base.sif > "$attachment"
12 |
13 |
14 | # Send the test results by email
15 | exit_code1=$(cat [path]/exit_code1.txt) # Read the exit code from the file (tests)
16 | exit_code2=$(cat [path]/exit_code2.txt) # Read the exit code from the file (long tests)
17 | exit_code3=$(cat [path]/exit_code3.txt) # Read the exit code from the file (doc test)
18 |
19 | if [ "$exit_code1" -eq 0 ] && [ "$exit_code2" -eq 0 ] && [ "$exit_code3" -eq 0 ]; then
20 |     # Initialization when all exit codes are 0 (success)
21 | subject="Rlberry : Success Monthly Test Report"
22 | core_message="Success. Please find attached the monthly test reports."
23 | else
24 |     # Initialization when at least one exit code is not 0 (failed)
25 | subject="Rlberry : Failed Monthly Test Report"
26 | core_message="Failed. Please find attached the monthly test reports."
27 | fi
28 |
29 |
30 | echo "$core_message" | mail -s "$subject" -A "test_result.txt" -A "long_test_result.txt" -A "doc_test_result.txt" -A"lib_versions.txt" "$recipient" -aFrom:"Rlberry_Monthly_tests"
31 |
--------------------------------------------------------------------------------
/scripts/apptainer_for_tests/rlberry_apptainer__specific_python.def:
--------------------------------------------------------------------------------
1 | Bootstrap: docker
2 | From: ubuntu:latest
3 |
4 | #script for the build
5 | %post -c /bin/bash
6 |
7 | #get the latest Ubuntu updates, and add the deadsnakes PPA to access other Python versions
8 | apt-get update \
9 | && apt-get upgrade -y
10 | apt-get install -y software-properties-common
11 | add-apt-repository ppa:deadsnakes/ppa -y
12 | apt-get update
13 |
14 | # Install Python, plus graphic and basic libs. Don't forget to replace [version] with the Python version you want (python[version] > python3.11), then set the new Python as the "main" python
15 | DEBIAN_FRONTEND=noninteractive TZ=Etc/UTC apt-get install -y python[version] python[version]-dev python[version]-venv python3-pip git ffmpeg libsm6 libxext6 libsdl2-dev xvfb x11-xkb-utils --no-install-recommends
16 | update-alternatives --install /usr/bin/python3 python3 /usr/bin/python[version] 1
17 | pip3 install --upgrade pip
18 |
19 |
20 | #Remove the old tmp folder if it exists, then clone rlberry
21 | if [ -d /tmp/rlberry_test_dir[version] ]; then /bin/rm -r /tmp/rlberry_test_dir[version]; fi
22 | git clone https://github.com/rlberry-py/rlberry.git /tmp/rlberry_test_dir[version]
23 |
24 | #Install all the libs we need to run rlberry and its tests
25 | pip3 install rlberry[torch_agents] opencv-python pytest pytest-xvfb pytest-xprocess tensorboard #--break-system-packages
26 | pip3 install gymnasium[other]
27 |
28 | #Environment variables. Don't forget to change [version]
29 | %environment
30 | export LC_ALL=C
31 | export PATH="/usr/bin/python[version]:$PATH"
32 |
33 | #Script executed with the "run" command: run the tests in rlberry, then export the exit code into a text file
34 | %runscript
35 | cd /tmp/rlberry_test_dir[version] && \
36 | pytest rlberry
37 | echo $? > [path]/exit_code.txt
38 |
--------------------------------------------------------------------------------
/scripts/apptainer_for_tests/rlberry_apptainer_base.def:
--------------------------------------------------------------------------------
1 | Bootstrap: docker
2 | From: ubuntu:latest
3 |
4 | %post -c /bin/bash
5 |
6 | #get the latest Ubuntu updates
7 | apt-get update \
8 | && apt-get upgrade -y
9 |
10 | echo "export PS1=Apptainer-\[\e]0;\u@\h: \w\a\]${debian_chroot:+($debian_chroot)}\[\033[01;32m\]\u@\h\[\033[00m\]:\[\033[01;34m\]\w\[\033[00m\]\$" >> /etc/profile
11 |
12 | # Install Python, plus graphic and basic libs.
13 | apt-get install -y software-properties-common python3-pip git ffmpeg libsm6 libxext6 libsdl2-dev xvfb x11-xkb-utils libblas-dev liblapack-dev
14 | pip3 install --upgrade pip setuptools wheel
15 |
16 | # remove the old folder
17 | if [ -d /tmp/rlberry_test_dir ]; then /bin/rm -r /tmp/rlberry_test_dir; fi
18 | if [ -d /tmp/rlberry-research_test_dir ]; then /bin/rm -r /tmp/rlberry-research_test_dir; fi
19 |
20 | # get all the git repos to run their tests
21 | git clone https://github.com/rlberry-py/rlberry.git /tmp/rlberry_test_dir
22 |
23 | cd /tmp/rlberry_test_dir/
24 | git fetch --tags
25 | latestTag=$(git describe --tags `git rev-list --tags --max-count=1`)
26 | git checkout $latestTag
27 |
28 |
29 | git clone https://github.com/rlberry-py/rlberry-research.git /tmp/rlberry-research_test_dir
30 |
31 | cd /tmp/rlberry-research_test_dir/
32 | git fetch --tags
33 | latestTagResearch=$(git describe --tags `git rev-list --tags --max-count=1`)
34 | git checkout $latestTagResearch
35 |
36 | # install rlberry, rlberry-scool and rlberry-research
37 | pip3 install rlberry[torch,doc,extras] rlberry-scool opencv-python pytest pytest-xvfb pytest-xprocess tensorboard #--break-system-packages
38 | pip3 install git+https://github.com/rlberry-py/rlberry-research.git
39 |
40 |
41 |
42 | %environment
43 | export LC_ALL=C
44 |
45 | %runscript
46 | # info about current versions
47 | pip list > [path]/lib_versions.txt
48 |
49 | #run tests
50 | (cd /tmp/rlberry_test_dir && \
51 | date && \
52 | pytest rlberry && \
53 | date) > [path]/test_result.txt
54 | echo $? > [path]/exit_code1.txt
55 |
56 | #run long tests
57 | (cd /tmp/rlberry-research_test_dir && \
58 | date && \
59 | pytest long_tests/rl_agent/ltest_mbqvi_applegold.py long_tests/torch_agent/ltest_a2c_cartpole.py long_tests/torch_agent/ltest_ctn_ppo_a2c_pendulum.py long_tests/torch_agent/ltest_dqn_montaincar.py && \
60 | date) > [path]/long_test_result.txt
61 | #pytest --ignore=long_tests/torch_agent/ltest_dqn_vs_mdqn_acrobot.py long_tests/**/*.py
62 | echo $? > [path]/exit_code2.txt
63 |
64 | #run doc test
65 | (cd /tmp/rlberry_test_dir/docs/ && \
66 | date && \
67 | ./markdown_to_py.sh && \
68 | for f in python_scripts/*.py; do python3 $f ; done && \
69 | date) > [path]/doc_test_result.txt
70 | echo $? > [path]/exit_code3.txt
71 |
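A minimal run sketch, not part of the file above: assuming the image has been built as rlberry_base.sif and [path] has been replaced by a writable host directory such as /results, a run produces the four report files attached by the monthly mail command shown earlier (test_result.txt, long_test_result.txt, doc_test_result.txt, lib_versions.txt) plus three exit_code*.txt status files.

    apptainer run rlberry_base.sif
    cat /results/exit_code1.txt   # exit status of the main rlberry test suite
    cat /results/exit_code2.txt   # exit status of the long tests
    cat /results/exit_code3.txt   # exit status of the doc-script run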
--------------------------------------------------------------------------------
/scripts/build_docs.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../docs/
4 | sphinx-apidoc -o source/ ../rlberry
5 | make html
6 | cd ..
7 |
8 | # Useful: https://samnicholls.net/2016/06/15/how-to-sphinx-readthedocs/
9 |
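Assuming the standard Sphinx Makefile shipped in docs/, the generated site lands in docs/_build/html and can be checked locally from the repository root, for example:

    xdg-open docs/_build/html/index.html   # or open the file in any browser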
--------------------------------------------------------------------------------
/scripts/conda_env_setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd $CONDA_PREFIX
4 | mkdir -p ./etc/conda/activate.d
5 | mkdir -p ./etc/conda/deactivate.d
6 | touch ./etc/conda/activate.d/env_vars.sh
7 | touch ./etc/conda/deactivate.d/env_vars.sh
8 |
9 | echo '#!/bin/sh' > ./etc/conda/activate.d/env_vars.sh
10 | echo >> ./etc/conda/activate.d/env_vars.sh
11 | echo "export LD_LIBRARY_PATH=$CONDA_PREFIX/lib" >> ./etc/conda/activate.d/env_vars.sh
12 |
13 | echo '#!/bin/sh' > ./etc/conda/deactivate.d/env_vars.sh
14 | echo >> ./etc/conda/deactivate.d/env_vars.sh
15 | echo "unset LD_LIBRARY_PATH" >> ./etc/conda/deactivate.d/env_vars.sh
16 |
17 | echo "Contents of $CONDA_PREFIX/etc/conda/activate.d/env_vars.sh:"
18 | cat ./etc/conda/activate.d/env_vars.sh
19 | echo ""
20 |
21 | echo "Contents of $CONDA_PREFIX/etc/conda/deactivate.d/env_vars.sh:"
22 | cat ./etc/conda/deactivate.d/env_vars.sh
23 | echo ""
24 |
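A quick check of the effect (the environment name is illustrative): activating the environment exports LD_LIBRARY_PATH via the new hook, and deactivating unsets it.

    conda activate rlberry
    echo "$LD_LIBRARY_PATH"    # prints <env prefix>/lib
    conda deactivate
    echo "$LD_LIBRARY_PATH"    # empty: unset by the deactivate hook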
--------------------------------------------------------------------------------
/scripts/construct_video_examples.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Script used to construct the videos for the examples that output videos.
4 | # Please use a script name that begins with video_plot, as in the other examples.
5 | # The script should contain a line that saves the video in the right place and a
6 | # line that loads the video in the header. Look at existing examples for
7 | # the correct syntax. Be careful: you must remove the _build folder before
8 | # recompiling the doc when a video has been updated/added.
9 |
10 |
11 | for f in ../examples/video_plot*.py ;
12 | do
13 | # construct the mp4
14 | python $f
15 | name=$(basename $f)
16 | # make a thumbnail. Warning: the video should have the same name as the python script,
17 | # i.e. video_plot_SOMETHING.mp4, to be detected
18 | ffmpeg -i ../docs/_video/${name%%.py}.mp4 -vframes 1 -f image2 ../docs/thumbnails/${name%%.py}.jpg
19 | done
20 |
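As a worked instance of the loop above (the example name is illustrative), a script ../examples/video_plot_chain.py is expected to write ../docs/_video/video_plot_chain.mp4, and ffmpeg then extracts its first frame as the thumbnail:

    python ../examples/video_plot_chain.py
    ffmpeg -i ../docs/_video/video_plot_chain.mp4 -vframes 1 -f image2 ../docs/thumbnails/video_plot_chain.jpg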
--------------------------------------------------------------------------------
/scripts/full_install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Install everything!
4 |
5 | pip install -e .[default]
6 | pip install -e .[jax_agents]
7 | pip install -e .[torch_agents]
8 |
9 | pip install pytest
10 | pip install pytest-cov
11 | conda install -c conda-forge jupyterlab
12 |
--------------------------------------------------------------------------------
/scripts/run_testscov.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # disable JIT to get complete coverage report
4 | export NUMBA_DISABLE_JIT=1
5 |
6 | # run pytest
7 | cd ..
8 | pytest --cov=rlberry --cov-report html:cov_html
9 |
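A variant sketch: pytest-cov accepts repeated --cov-report flags, so the same run can also print a terminal summary of uncovered lines alongside the HTML report.

    NUMBA_DISABLE_JIT=1 pytest --cov=rlberry --cov-report=term-missing --cov-report=html:cov_html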
--------------------------------------------------------------------------------