├── .coveragerc
├── .editorconfig
├── .github
├── ISSUE_TEMPLATE
│ ├── bug_report.md
│ └── feature_request.md
└── workflows
│ ├── format_check.yml
│ └── test.yml
├── .gitignore
├── .readthedocs.yaml
├── CITATION.cff
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── ROADMAP.md
├── assets
├── breakout.gif
├── breakout.png
├── hopper.png
├── logo.png
└── mujoco_hopper.gif
├── d3rlpy
├── __init__.py
├── _version.py
├── algos
│ ├── __init__.py
│ ├── qlearning
│ │ ├── __init__.py
│ │ ├── awac.py
│ │ ├── base.py
│ │ ├── bc.py
│ │ ├── bcq.py
│ │ ├── bear.py
│ │ ├── cal_ql.py
│ │ ├── cql.py
│ │ ├── crr.py
│ │ ├── ddpg.py
│ │ ├── dqn.py
│ │ ├── explorers.py
│ │ ├── iql.py
│ │ ├── nfq.py
│ │ ├── plas.py
│ │ ├── prdc.py
│ │ ├── random_policy.py
│ │ ├── rebrac.py
│ │ ├── sac.py
│ │ ├── td3.py
│ │ ├── td3_plus_bc.py
│ │ └── torch
│ │ │ ├── __init__.py
│ │ │ ├── awac_impl.py
│ │ │ ├── bc_impl.py
│ │ │ ├── bcq_impl.py
│ │ │ ├── bear_impl.py
│ │ │ ├── cal_ql_impl.py
│ │ │ ├── cql_impl.py
│ │ │ ├── crr_impl.py
│ │ │ ├── ddpg_impl.py
│ │ │ ├── dqn_impl.py
│ │ │ ├── iql_impl.py
│ │ │ ├── plas_impl.py
│ │ │ ├── prdc_impl.py
│ │ │ ├── rebrac_impl.py
│ │ │ ├── sac_impl.py
│ │ │ ├── td3_impl.py
│ │ │ ├── td3_plus_bc_impl.py
│ │ │ └── utility.py
│ ├── transformer
│ │ ├── __init__.py
│ │ ├── action_samplers.py
│ │ ├── base.py
│ │ ├── decision_transformer.py
│ │ ├── inputs.py
│ │ ├── tacr.py
│ │ └── torch
│ │ │ ├── __init__.py
│ │ │ ├── decision_transformer_impl.py
│ │ │ └── tacr_impl.py
│ └── utility.py
├── base.py
├── cli.py
├── constants.py
├── dataclass_utils.py
├── dataset
│ ├── __init__.py
│ ├── buffers.py
│ ├── compat.py
│ ├── components.py
│ ├── episode_generator.py
│ ├── io.py
│ ├── mini_batch.py
│ ├── replay_buffer.py
│ ├── trajectory_slicers.py
│ ├── transition_pickers.py
│ ├── utils.py
│ └── writers.py
├── datasets.py
├── distributed.py
├── envs
│ ├── __init__.py
│ ├── utility.py
│ └── wrappers.py
├── healthcheck.py
├── interface.py
├── itertools.py
├── logging
│ ├── __init__.py
│ ├── file_adapter.py
│ ├── logger.py
│ ├── noop_adapter.py
│ ├── tensorboard_adapter.py
│ ├── utils.py
│ └── wandb_adapter.py
├── metrics
│ ├── __init__.py
│ ├── evaluators.py
│ └── utility.py
├── models
│ ├── __init__.py
│ ├── builders.py
│ ├── encoders.py
│ ├── q_functions.py
│ ├── torch
│ │ ├── __init__.py
│ │ ├── distributions.py
│ │ ├── encoders.py
│ │ ├── imitators.py
│ │ ├── parameters.py
│ │ ├── policies.py
│ │ ├── q_functions
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── ensemble_q_function.py
│ │ │ ├── iqn_q_function.py
│ │ │ ├── mean_q_function.py
│ │ │ ├── qr_q_function.py
│ │ │ └── utility.py
│ │ ├── transformers.py
│ │ └── v_functions.py
│ └── utility.py
├── notebook_utils.py
├── ope
│ ├── __init__.py
│ ├── fqe.py
│ └── torch
│ │ ├── __init__.py
│ │ └── fqe_impl.py
├── optimizers
│ ├── __init__.py
│ ├── lr_schedulers.py
│ └── optimizers.py
├── preprocessing
│ ├── __init__.py
│ ├── action_scalers.py
│ ├── base.py
│ ├── observation_scalers.py
│ └── reward_scalers.py
├── serializable_config.py
├── tokenizers
│ ├── __init__.py
│ ├── tokenizers.py
│ └── utils.py
├── torch_utility.py
└── types.py
├── dev.requirements.txt
├── docker
└── Dockerfile
├── docs
├── Makefile
├── _static
│ ├── css
│ │ └── d3rlpy.css
│ └── logo.png
├── _templates
│ └── autosummary
│ │ └── class.rst
├── assets
│ ├── design.png
│ ├── dqn_cartpole.png
│ ├── fqe_cartpole_init_value.png
│ ├── fqe_cartpole_soft_opc.png
│ ├── mdp_dataset.png
│ ├── plot_all_example.png
│ └── plot_example.png
├── cli.rst
├── conf.py
├── index.rst
├── installation.rst
├── license.rst
├── make.bat
├── notebooks.rst
├── references
│ ├── algos.rst
│ ├── dataset.rst
│ ├── datasets.rst
│ ├── index.rst
│ ├── logging.rst
│ ├── metrics.rst
│ ├── network_architectures.rst
│ ├── off_policy_evaluation.rst
│ ├── online.rst
│ ├── optimizers.rst
│ ├── preprocessing.rst
│ └── q_functions.rst
├── reproductions.rst
├── requirements.txt
├── software_design.rst
├── tips.rst
└── tutorials
│ ├── after_training_policies.rst
│ ├── create_your_dataset.rst
│ ├── customize_neural_network.rst
│ ├── data_collection.rst
│ ├── finetuning.rst
│ ├── getting_started.rst
│ ├── index.rst
│ ├── offline_policy_selection.rst
│ ├── online_rl.rst
│ ├── preprocess_and_postprocess.rst
│ └── use_distributional_q_function.rst
├── examples
├── custom_algo.py
├── deepmind_control.py
├── distributed_offline_training.py
├── fine_tuning.py
├── frame_stacking.py
├── gymnasium_env.py
├── lr_scheduler.py
├── multi_step_training.py
├── preprocessors.py
└── tuple_observation.py
├── mypy.ini
├── reproductions
├── finetuning
│ ├── awac_finetune.py
│ ├── cal_ql_finetune.py
│ └── iql_finetune.py
├── offline
│ ├── awac.py
│ ├── bcq.py
│ ├── bear.py
│ ├── cql.py
│ ├── crr.py
│ ├── decision_transformer.py
│ ├── discrete_bcq.py
│ ├── discrete_cql.py
│ ├── discrete_decision_transformer.py
│ ├── iql.py
│ ├── nfq.py
│ ├── plas.py
│ ├── plas_with_perturbation.py
│ ├── prdc.py
│ ├── qdt.py
│ ├── qr_dqn.py
│ ├── rebrac.py
│ ├── sac.py
│ ├── tacr.py
│ ├── td3.py
│ └── td3_plus_bc.py
└── online
│ ├── double_dqn_online.py
│ ├── dqn_online.py
│ ├── iqn_online.py
│ ├── qr_dqn_online.py
│ └── sac_online.py
├── requirements.txt
├── ruff.toml
├── scripts
├── build-dist
├── build-docker
├── build-docs
├── create_cartpole_dataset
├── create_cartpole_random_dataset
├── create_pendulum_dataset
├── create_pendulum_random_dataset
├── lint
└── test
├── setup.py
├── tests
├── __init__.py
├── algos
│ ├── __init__.py
│ ├── qlearning
│ │ ├── __init__.py
│ │ ├── algo_test.py
│ │ ├── test_awac.py
│ │ ├── test_bc.py
│ │ ├── test_bcq.py
│ │ ├── test_bear.py
│ │ ├── test_cal_ql.py
│ │ ├── test_cql.py
│ │ ├── test_crr.py
│ │ ├── test_ddpg.py
│ │ ├── test_dqn.py
│ │ ├── test_explorers.py
│ │ ├── test_iql.py
│ │ ├── test_iterators.py
│ │ ├── test_nfq.py
│ │ ├── test_plas.py
│ │ ├── test_prdc.py
│ │ ├── test_random_policy.py
│ │ ├── test_rebrac.py
│ │ ├── test_sac.py
│ │ ├── test_td3.py
│ │ ├── test_td3_plus_bc.py
│ │ └── torch
│ │ │ ├── __init__.py
│ │ │ └── test_utility.py
│ └── transformer
│ │ ├── __init__.py
│ │ ├── algo_test.py
│ │ ├── test_action_samplers.py
│ │ ├── test_decision_transformer.py
│ │ ├── test_inputs.py
│ │ └── test_tacr.py
├── base_test.py
├── dataset
│ ├── __init__.py
│ ├── test_buffers.py
│ ├── test_compat.py
│ ├── test_components.py
│ ├── test_episode_generator.py
│ ├── test_io.py
│ ├── test_mini_batch.py
│ ├── test_replay_buffer.py
│ ├── test_trajectory_slicer.py
│ ├── test_transition_pickers.py
│ ├── test_utils.py
│ └── test_writers.py
├── dummy_env.py
├── dummy_scalers.py
├── envs
│ ├── __init__.py
│ └── test_wrappers.py
├── logging
│ └── test_logger.py
├── metrics
│ ├── __init__.py
│ ├── test_evaluators.py
│ └── test_utility.py
├── models
│ ├── __init__.py
│ ├── test_builders.py
│ ├── test_encoders.py
│ ├── test_lr_schedulers.py
│ ├── test_q_functions.py
│ └── torch
│ │ ├── __init__.py
│ │ ├── model_test.py
│ │ ├── q_functions
│ │ ├── __init__.py
│ │ ├── test_ensemble_q_function.py
│ │ ├── test_iqn_q_function.py
│ │ ├── test_mean_q_function.py
│ │ ├── test_qr_q_function.py
│ │ └── test_utility.py
│ │ ├── test_distributions.py
│ │ ├── test_encoders.py
│ │ ├── test_imitators.py
│ │ ├── test_parameters.py
│ │ ├── test_policies.py
│ │ ├── test_q_functions.py
│ │ ├── test_transformers.py
│ │ └── test_v_functions.py
├── ope
│ ├── __init__.py
│ └── test_fqe.py
├── optimizers
│ ├── __init__.py
│ ├── test_lr_schedulers.py
│ └── test_optimizers.py
├── preprocessing
│ ├── __init__.py
│ ├── test_action_scalers.py
│ ├── test_base.py
│ ├── test_observation_scalers.py
│ └── test_reward_scalers.py
├── test_dataclass_utils.py
├── test_datasets.py
├── test_itertools.py
├── test_torch_utility.py
├── testing_utils.py
└── tokenizers
│ ├── __init__.py
│ ├── test_tokenizers.py
│ └── test_utils.py
└── tutorials
├── atari.ipynb
├── cartpole.ipynb
├── online.ipynb
└── tpu.ipynb
/.coveragerc:
--------------------------------------------------------------------------------
1 | [run]
2 | omit = d3rlpy/cli.py
3 |
--------------------------------------------------------------------------------
/.editorconfig:
--------------------------------------------------------------------------------
1 | root = true
2 |
3 | [*]
4 | indent_size = 2
5 | indent_style = space
6 | end_of_line = lf
7 | charset = utf-8
8 | trim_trailing_whitespace = true
9 | insert_final_newline = true
10 |
11 | [*.py]
12 | indent_size = 4
13 |
14 | [*.pyx]
15 | indent_size = 4
16 |
17 | [*.pxd]
18 | indent_size = 4
19 |
20 | [*.pyi]
21 | indent_size = 4
22 |
23 | [*.md]
24 | trim_trailing_whitespace = false
25 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug_report.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Bug report
3 | about: Create a report to help us improve
4 | title: "[BUG] bug title"
5 | labels: bug
6 | assignees: ''
7 |
8 | ---
9 |
10 | :warning:
11 | - Please don't post bug reports without minimal example code. Otherwise, it will take much longer to resolve the issue.
12 | - Please be polite in discussions. This is an open-source project maintained by volunteer contributors.
13 |
14 | **Describe the bug**
15 | A clear and concise description of what the bug is.
16 |
17 | **To Reproduce**
18 | Steps to reproduce the behavior.
19 |
20 | **Expected behavior**
21 | A clear and concise description of what you expected to happen.
22 |
23 | **Additional context**
24 | Add any other context about the problem here.
25 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature_request.md:
--------------------------------------------------------------------------------
1 | ---
2 | name: Feature request
3 | about: Suggest an idea for this project
4 | title: "[REQUEST] request title"
5 | labels: enhancement
6 | assignees: ''
7 |
8 | ---
9 |
10 | **Is your feature request related to a problem? Please describe.**
11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
12 |
13 | **Describe the solution you'd like**
14 | A clear and concise description of what you want to happen.
15 |
16 | **Additional context**
17 | Add any other context or screenshots about the feature request here.
18 |
--------------------------------------------------------------------------------
/.github/workflows/format_check.yml:
--------------------------------------------------------------------------------
1 | name: format check
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | build:
7 | runs-on: ubuntu-22.04
8 | steps:
9 | - uses: actions/checkout@v2
10 | - name: Setup Python.3.9.x
11 | uses: actions/setup-python@v1
12 | with:
13 | python-version: 3.9.x
14 | - name: Cache pip
15 | uses: actions/cache@v4
16 | with:
17 | path: ~/.cache/pip
18 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev.requirements.txt') }}
19 | restore-keys: |
20 | ${{ runner.os }}-pip-
21 | ${{ runner.os }}-
22 | - name: Install packages
23 | run: |
24 | python -m pip install --upgrade pip
25 | pip install Cython numpy
26 | pip install -e .
27 | pip install -r dev.requirements.txt
28 | - name: Static analysis
29 | run: |
30 | ./scripts/lint
31 |
32 | concurrency:
33 | group: ${{ github.workflow }}-${{ github.ref }}
34 | cancel-in-progress: true
35 |
--------------------------------------------------------------------------------
/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: test
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | build:
7 | runs-on: ${{ matrix.os }}
8 | strategy:
9 | matrix:
10 | # temporary drop windows test
11 | # os: [ubuntu-22.04, macos-latest, windows-latest]
12 | os: [ubuntu-22.04, macos-latest]
13 | steps:
14 | - uses: actions/checkout@v4
15 | - name: Setup Python.3.10.x
16 | uses: actions/setup-python@v5
17 | with:
18 | python-version: '3.10'
19 | - name: Cache pip
20 | uses: actions/cache@v4
21 | with:
22 | path: ~/.cache/pip
23 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev.requirements.txt') }}
24 | restore-keys: |
25 | ${{ runner.os }}-pip-
26 | ${{ runner.os }}-
27 | - name: Install dependencies for Windows
28 | if: ${{ matrix.os == 'windows-latest' }}
29 | run: |
30 | pip install torch==1.13.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
31 | - name: Install dependencies for macOS
32 | if: ${{ matrix.os == 'macos-latest' }}
33 | run: |
34 | brew install libomp
35 | - name: Install packages
36 | run: |
37 | python -m pip install --upgrade pip
38 | pip install Cython numpy
39 | pip install -e .
40 | pip install -r dev.requirements.txt
41 | - name: Unit tests
42 | run: |
43 | mkdir -p test_data
44 | pytest --cov-report=xml --cov=d3rlpy --cov-config=.coveragerc tests -p no:warnings -v
45 | - name: Upload coverage
46 | if: ${{ matrix.os == 'ubuntu-22.04' }}
47 | env:
48 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
49 | run: |
50 | bash <(curl -s https://codecov.io/bash)
51 |
52 | concurrency:
53 | group: ${{ github.workflow }}-${{ github.ref }}
54 | cancel-in-progress: true
55 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *__pycache__
3 | *.pkl
4 | *.c
5 | *.cpp
6 | *.so
7 | .python-version
8 | d3rlpy_data
9 | test_data
10 | runs
11 | d3rlpy_logs
12 | docs/_build
13 | docs/d3rlpy*.rst
14 | docs/modules.rst
15 | docs/references/generated
16 | coverage.xml
17 | .coverage
18 | .mypy_cache
19 | .ipynb_checkpoints
20 | build
21 | dist
22 | /.idea/
23 | *.egg-info
24 | /.vscode/
25 |
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 | build:
3 | os: ubuntu-22.04
4 | tools:
5 | python: "3.10"
6 | sphinx:
7 | builder: html
8 | configuration: docs/conf.py
9 | formats: all
10 | python:
11 | install:
12 | - requirements: docs/requirements.txt
13 | - method: pip
14 | path: .
15 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | message: "If you use this software, please cite it as below."
3 | authors:
4 | - family-names: "Seno"
5 | given-names: "Takuma"
6 | title: "d3rlpy: An offline deep reinforcement learning library"
7 | version: 1.1.1
8 | date-released: 2022-11-26
9 | url: "https://github.com/takuseno/d3rlpy"
10 | preferred-citation:
11 | type: article
12 | authors:
13 | - family-names: "Seno"
14 | given-names: "Takuma"
15 | - family-names: "Imai"
16 | given-names: "Michita"
17 | journal: "Journal of Machine Learning Research"
18 | month: 11
19 | title: "d3rlpy: An Offline Deep Reinforcement Learning Library"
20 | year: 2022
21 | url: http://jmlr.org/papers/v23/22-0017.html
22 | volume: 23
23 | issue: 315
24 | start: 1
25 | end: 20
26 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to d3rlpy
2 |
3 | Any kind of contribution to d3rlpy would be highly appreciated!
4 |
5 | Contribution examples:
6 | - Giving a thumbs-up to good issues or pull requests :+1:
7 | - Opening issues about questions, bugs, installation problems, feature requests, algorithm requests etc.
8 | - Sending pull requests
9 |
10 | ## Development Guide
11 |
12 | ### Build from source
13 | ```
14 | $ git clone git@github.com:takuseno/d3rlpy
15 | $ cd d3rlpy
16 | $ pip install -e .
17 | ```
18 |
19 | Before opening your PR, please run the following commands to check code quality.
20 |
21 | ### Install additional dependencies for development
22 | ```
23 | $ pip install -r dev.requirements.txt
24 | ```
25 |
26 | ### Testing
27 | ```
28 | $ ./scripts/test
29 | ```
30 |
31 | ### Coding style check
32 | This repository is styled and analyzed with [Ruff](https://docs.astral.sh/ruff/).
33 | [docformatter](https://github.com/PyCQA/docformatter) is additionally used to format docstrings.
34 | This repository is fully type-annotated and checked by [mypy](https://github.com/python/mypy).
35 | Before you submit your PR, please execute this command:
36 | ```
37 | $ ./scripts/lint
38 | ```
39 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Takuma Seno
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/ROADMAP.md:
--------------------------------------------------------------------------------
1 | This list describes the planned features including breaking changes.
2 |
3 | ## Roadmap to v1.0.0
4 | - [x] Benchmark MuJoCo datasets
5 | - [x] Benchmark Atari 2600 datasets
6 |
7 | ## Roadmap to v2.x.x
8 | - [x] Change MDPDataset format to align with D4RL datasets
9 | - [x] Sophisticated config system using dataclasses
10 | - [x] Dump configuration and model parameters in a single file
11 | - [x] Support large dataset
12 | - [x] Support tuple observation
13 | - [x] Support large-scale data-parallel offline training
14 | - [x] Support Transformer architecture (e.g. Decision Transformer)
15 | - [x] Speed up training with CUDA Graphs and torch.compile
16 | - [ ] Support training foundation models (e.g. Gato)
17 | - [ ] Support large-scale distributed online training
18 | - [ ] Change library name to represent unification of offline and online
19 |
--------------------------------------------------------------------------------
/assets/breakout.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/assets/breakout.gif
--------------------------------------------------------------------------------
/assets/breakout.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/assets/breakout.png
--------------------------------------------------------------------------------
/assets/hopper.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/assets/hopper.png
--------------------------------------------------------------------------------
/assets/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/assets/logo.png
--------------------------------------------------------------------------------
/assets/mujoco_hopper.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/assets/mujoco_hopper.gif
--------------------------------------------------------------------------------
/d3rlpy/__init__.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=protected-access
2 | import random
3 |
4 | import gymnasium
5 | import numpy as np
6 | import torch
7 |
8 | from . import (
9 | algos,
10 | dataset,
11 | datasets,
12 | distributed,
13 | envs,
14 | logging,
15 | metrics,
16 | models,
17 | notebook_utils,
18 | ope,
19 | optimizers,
20 | preprocessing,
21 | tokenizers,
22 | types,
23 | )
24 | from ._version import __version__
25 | from .base import load_learnable
26 | from .constants import ActionSpace, LoggingStrategy, PositionEncodingType
27 | from .healthcheck import run_healthcheck
28 | from .torch_utility import Modules, TorchMiniBatch
29 |
30 | __all__ = [
31 | "algos",
32 | "dataset",
33 | "datasets",
34 | "distributed",
35 | "envs",
36 | "logging",
37 | "metrics",
38 | "models",
39 | "optimizers",
40 | "notebook_utils",
41 | "ope",
42 | "preprocessing",
43 | "tokenizers",
44 | "types",
45 | "__version__",
46 | "load_learnable",
47 | "ActionSpace",
48 | "LoggingStrategy",
49 | "PositionEncodingType",
50 | "Modules",
51 | "TorchMiniBatch",
52 | "seed",
53 | ]
54 |
55 |
56 | def seed(n: int) -> None:
57 | """Sets random seed value.
58 |
59 | Args:
60 | n (int): seed value.
61 | """
62 | random.seed(n)
63 | np.random.seed(n)
64 | torch.manual_seed(n)
65 | torch.cuda.manual_seed(n)
66 | torch.backends.cudnn.deterministic = True
67 |
68 |
69 | # run healthcheck
70 | run_healthcheck()
71 |
72 | if torch.cuda.is_available():
73 | # enable autograd compilation
74 | torch._dynamo.config.compiled_autograd = True
75 | torch.set_float32_matmul_precision("high")
76 |
77 | # register Shimmy if available
78 | try:
79 | import shimmy
80 |
81 | gymnasium.register_envs(shimmy)
82 | logging.LOG.info("Register Shimmy environments.")
83 | except ImportError:
84 | pass
85 |
--------------------------------------------------------------------------------
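A minimal usage sketch of the module-level `seed` helper defined above; note that environment randomness is seeded separately via `seed_env` (see `d3rlpy/envs/utility.py` later in this dump):

```
import d3rlpy

# fix the Python, NumPy and PyTorch RNGs for a reproducible run
d3rlpy.seed(313)
```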
/d3rlpy/_version.py:
--------------------------------------------------------------------------------
1 | __version__ = "2.8.1"
2 |
--------------------------------------------------------------------------------
/d3rlpy/algos/__init__.py:
--------------------------------------------------------------------------------
1 | from .qlearning import *
2 | from .transformer import *
3 | from .utility import *
4 |
--------------------------------------------------------------------------------
/d3rlpy/algos/qlearning/__init__.py:
--------------------------------------------------------------------------------
1 | from .awac import *
2 | from .base import *
3 | from .bc import *
4 | from .bcq import *
5 | from .bear import *
6 | from .cal_ql import *
7 | from .cql import *
8 | from .crr import *
9 | from .ddpg import *
10 | from .dqn import *
11 | from .explorers import *
12 | from .iql import *
13 | from .nfq import *
14 | from .plas import *
15 | from .prdc import *
16 | from .random_policy import *
17 | from .rebrac import *
18 | from .sac import *
19 | from .td3 import *
20 | from .td3_plus_bc import *
21 |
--------------------------------------------------------------------------------
/d3rlpy/algos/qlearning/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from .awac_impl import *
2 | from .bc_impl import *
3 | from .bcq_impl import *
4 | from .bear_impl import *
5 | from .cal_ql_impl import *
6 | from .cql_impl import *
7 | from .crr_impl import *
8 | from .ddpg_impl import *
9 | from .dqn_impl import *
10 | from .iql_impl import *
11 | from .plas_impl import *
12 | from .prdc_impl import *
13 | from .rebrac_impl import *
14 | from .sac_impl import *
15 | from .td3_impl import *
16 | from .td3_plus_bc_impl import *
17 |
--------------------------------------------------------------------------------
/d3rlpy/algos/qlearning/torch/cal_ql_impl.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from ....types import TorchObservation
4 | from .cql_impl import CQLImpl
5 |
6 | __all__ = ["CalQLImpl"]
7 |
8 |
9 | class CalQLImpl(CQLImpl):
10 | def _compute_policy_is_values(
11 | self,
12 | policy_obs: TorchObservation,
13 | value_obs: TorchObservation,
14 | returns_to_go: torch.Tensor,
15 | ) -> tuple[torch.Tensor, torch.Tensor]:
16 | values, log_probs = super()._compute_policy_is_values(
17 | policy_obs=policy_obs,
18 | value_obs=value_obs,
19 | returns_to_go=returns_to_go,
20 | )
21 | return torch.maximum(values, returns_to_go.view(1, -1, 1)), log_probs
22 |
--------------------------------------------------------------------------------
/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=too-many-ancestors
2 | import dataclasses
3 |
4 | import torch
5 |
6 | from ....models.torch import ActionOutput, ContinuousEnsembleQFunctionForwarder
7 | from ....torch_utility import TorchMiniBatch
8 | from ....types import Shape
9 | from .ddpg_impl import DDPGBaseActorLoss, DDPGModules
10 | from .td3_impl import TD3Impl
11 |
12 | __all__ = ["TD3PlusBCImpl"]
13 |
14 |
15 | @dataclasses.dataclass(frozen=True)
16 | class TD3PlusBCActorLoss(DDPGBaseActorLoss):
17 | bc_loss: torch.Tensor
18 |
19 |
20 | class TD3PlusBCImpl(TD3Impl):
21 | _alpha: float
22 |
23 | def __init__(
24 | self,
25 | observation_shape: Shape,
26 | action_size: int,
27 | modules: DDPGModules,
28 | q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
29 | targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder,
30 | gamma: float,
31 | tau: float,
32 | target_smoothing_sigma: float,
33 | target_smoothing_clip: float,
34 | alpha: float,
35 | update_actor_interval: int,
36 | compiled: bool,
37 | device: str,
38 | ):
39 | super().__init__(
40 | observation_shape=observation_shape,
41 | action_size=action_size,
42 | modules=modules,
43 | q_func_forwarder=q_func_forwarder,
44 | targ_q_func_forwarder=targ_q_func_forwarder,
45 | gamma=gamma,
46 | tau=tau,
47 | target_smoothing_sigma=target_smoothing_sigma,
48 | target_smoothing_clip=target_smoothing_clip,
49 | update_actor_interval=update_actor_interval,
50 | compiled=compiled,
51 | device=device,
52 | )
53 | self._alpha = alpha
54 |
55 | def compute_actor_loss(
56 | self, batch: TorchMiniBatch, action: ActionOutput
57 | ) -> TD3PlusBCActorLoss:
58 | q_t = self._q_func_forwarder.compute_expected_q(
59 | batch.observations, action.squashed_mu, "none"
60 | )[0]
61 | lam = self._alpha / (q_t.abs().mean()).detach()
62 | bc_loss = ((batch.actions - action.squashed_mu) ** 2).mean()
63 | return TD3PlusBCActorLoss(
64 | actor_loss=lam * -q_t.mean() + bc_loss, bc_loss=bc_loss
65 | )
66 |
--------------------------------------------------------------------------------
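The actor objective implemented by `compute_actor_loss` above can be written compactly as follows; this is a sketch matching the code (with `B` denoting the mini-batch and `alpha` the constructor hyperparameter), not a restatement of the TD3+BC paper:

```latex
\lambda = \frac{\alpha}{\frac{1}{|B|} \sum_{(s, a) \in B} \left| Q\big(s, \pi(s)\big) \right|},
\qquad
L_\pi = -\lambda \cdot \frac{1}{|B|} \sum_{(s, a) \in B} Q\big(s, \pi(s)\big)
      + \frac{1}{|B|} \sum_{(s, a) \in B} \big(\pi(s) - a\big)^2
```

The scaling factor lambda is detached from the graph, so it only rescales the Q-term relative to the behavioral-cloning term without contributing gradients of its own.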
/d3rlpy/algos/transformer/__init__.py:
--------------------------------------------------------------------------------
1 | from .action_samplers import *
2 | from .base import *
3 | from .decision_transformer import *
4 | from .inputs import *
5 | from .tacr import *
6 |
--------------------------------------------------------------------------------
/d3rlpy/algos/transformer/action_samplers.py:
--------------------------------------------------------------------------------
1 | from typing import Protocol, Union
2 |
3 | import numpy as np
4 |
5 | from ...types import NDArray
6 |
7 | __all__ = [
8 | "TransformerActionSampler",
9 | "IdentityTransformerActionSampler",
10 | "SoftmaxTransformerActionSampler",
11 | "GreedyTransformerActionSampler",
12 | ]
13 |
14 |
15 | class TransformerActionSampler(Protocol):
16 | r"""Interface of TransformerActionSampler."""
17 |
18 | def __call__(self, transformer_output: NDArray) -> Union[NDArray, int]:
19 | r"""Returns sampled action from Transformer output.
20 |
21 | Args:
22 | transformer_output: Output of Transformer algorithms.
23 |
24 | Returns:
25 | Sampled action.
26 | """
27 | raise NotImplementedError
28 |
29 |
30 | class IdentityTransformerActionSampler(TransformerActionSampler):
31 | r"""Identity action-sampler.
32 |
33 | This class implements an identity function to process Transformer output.
34 | The sampled action is exactly the same as ``transformer_output``.
35 | """
36 |
37 | def __call__(self, transformer_output: NDArray) -> Union[NDArray, int]:
38 | return transformer_output
39 |
40 |
41 | class SoftmaxTransformerActionSampler(TransformerActionSampler):
42 | r"""Softmax action-sampler.
43 |
44 | This class implements a softmax function to sample an action from a
45 | discrete probability distribution.
46 |
47 | Args:
48 | temperature (float): Softmax temperature.
49 | """
50 |
51 | _temperature: float
52 |
53 | def __init__(self, temperature: float = 1.0):
54 | self._temperature = temperature
55 |
56 | def __call__(self, transformer_output: NDArray) -> Union[NDArray, int]:
57 | assert transformer_output.ndim == 1
58 | logits = transformer_output / self._temperature
59 | x = np.exp(logits - np.max(logits))
60 | probs = x / np.sum(x)
61 | action = np.random.choice(probs.shape[0], p=probs)
62 | return int(action)
63 |
64 |
65 | class GreedyTransformerActionSampler(TransformerActionSampler):
66 | r"""Greedy action-sampler.
67 |
68 | This class implements a greedy function to determine an action from a
69 | discrete probability distribution.
70 | """
71 |
72 | def __call__(self, transformer_output: NDArray) -> Union[NDArray, int]:
73 | assert transformer_output.ndim == 1
74 | return int(np.argmax(transformer_output))
75 |
--------------------------------------------------------------------------------
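A short sketch of how the samplers above behave on a raw logit vector; the logit values are made up for illustration:

```
import numpy as np

from d3rlpy.algos.transformer.action_samplers import (
    GreedyTransformerActionSampler,
    SoftmaxTransformerActionSampler,
)

logits = np.array([0.1, 2.0, -0.5], dtype=np.float32)

# stochastic: samples an action index from softmax(logits / temperature)
stochastic = SoftmaxTransformerActionSampler(temperature=1.0)
print(stochastic(logits))

# deterministic: always picks the argmax, here index 1
greedy = GreedyTransformerActionSampler()
print(greedy(logits))  # -> 1
```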
/d3rlpy/algos/transformer/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from .decision_transformer_impl import *
2 | from .tacr_impl import *
3 |
--------------------------------------------------------------------------------
/d3rlpy/constants.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | from ._version import __version__
4 |
5 | __all__ = [
6 | "IMPL_NOT_INITIALIZED_ERROR",
7 | "ALGO_NOT_GIVEN_ERROR",
8 | "DISCRETE_ACTION_SPACE_MISMATCH_ERROR",
9 | "CONTINUOUS_ACTION_SPACE_MISMATCH_ERROR",
10 | "ActionSpace",
11 | "PositionEncodingType",
12 | ]
13 |
14 | IMPL_NOT_INITIALIZED_ERROR = (
15 | "The neural network parameters are not "
16 | "initialized. Please call build_with_dataset, "
17 | "build_with_env, or directly call fit or "
18 | "fit_online method."
19 | )
20 |
21 | ALGO_NOT_GIVEN_ERROR = (
22 | "The algorithm to evaluate is not given. Please give the trained algorithm"
23 | " to the argument."
24 | )
25 |
26 | DISCRETE_ACTION_SPACE_MISMATCH_ERROR = (
27 | "The action-space of the given dataset is not compatible with the"
28 | " algorithm. Please use discrete action-space algorithms. The algorithm"
29 | " list is available below.\n"
30 | f"https://d3rlpy.readthedocs.io/en/v{__version__}/references/algos.html"
31 | )
32 |
33 | CONTINUOUS_ACTION_SPACE_MISMATCH_ERROR = (
34 | "The action-space of the given dataset is not compatible with the"
35 | " algorithm. Please use continuous action-space algorithms. The algorithm"
36 | " list is available below.\n"
37 | f"https://d3rlpy.readthedocs.io/en/v{__version__}/references/algos.html"
38 | )
39 |
40 |
41 | class ActionSpace(Enum):
42 | CONTINUOUS = 1
43 | DISCRETE = 2
44 | BOTH = 3
45 |
46 |
47 | class PositionEncodingType(Enum):
48 | SIMPLE = "simple"
49 | GLOBAL = "global"
50 |
51 |
52 | class LoggingStrategy(Enum):
53 | STEPS = "steps"
54 | EPOCH = "epoch"
55 |
--------------------------------------------------------------------------------
/d3rlpy/dataclass_utils.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from typing import Any
3 |
4 | import torch
5 |
6 | __all__ = ["asdict_without_copy", "asdict_as_float"]
7 |
8 |
9 | def asdict_without_copy(obj: Any) -> dict[str, Any]:
10 | assert dataclasses.is_dataclass(obj)
11 | fields = dataclasses.fields(obj)
12 | return {field.name: getattr(obj, field.name) for field in fields}
13 |
14 |
15 | def asdict_as_float(obj: Any) -> dict[str, float]:
16 | assert dataclasses.is_dataclass(obj)
17 | fields = dataclasses.fields(obj)
18 | ret: dict[str, float] = {}
19 | for field in fields:
20 | value = getattr(obj, field.name)
21 | if isinstance(value, torch.Tensor):
22 | assert (
23 | value.ndim == 0
24 | ), f"{field.name} needs to be scalar. {value.shape}."
25 | ret[field.name] = float(value.cpu().detach().numpy())
26 | else:
27 | ret[field.name] = float(value)
28 | return ret
29 |
--------------------------------------------------------------------------------
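An illustrative sketch of `asdict_as_float`; the `Losses` dataclass here is hypothetical, not part of the library:

```
import dataclasses

import torch

from d3rlpy.dataclass_utils import asdict_as_float


@dataclasses.dataclass(frozen=True)
class Losses:
    critic_loss: torch.Tensor
    actor_loss: float


losses = Losses(critic_loss=torch.tensor(0.5), actor_loss=1.25)

# every field is converted to a plain float; tensor fields must be scalars
print(asdict_as_float(losses))  # {'critic_loss': 0.5, 'actor_loss': 1.25}
```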
/d3rlpy/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .buffers import *
2 | from .compat import *
3 | from .components import *
4 | from .episode_generator import *
5 | from .io import *
6 | from .mini_batch import *
7 | from .replay_buffer import *
8 | from .trajectory_slicers import *
9 | from .transition_pickers import *
10 | from .utils import *
11 | from .writers import *
12 |
--------------------------------------------------------------------------------
/d3rlpy/distributed.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 |
3 | import torch.distributed as dist
4 |
5 | from .logging import set_log_context
6 |
7 | __all__ = ["init_process_group", "destroy_process_group"]
8 |
9 |
10 | @dataclasses.dataclass(frozen=True)
11 | class DistributedWorkerInfo:
12 | rank: int
13 | backend: str
14 | world_size: int
15 |
16 |
17 | def init_process_group(backend: str) -> int:
18 | """Initializes process group of distributed workers.
19 |
20 | Internally, distributed worker information is injected into log outputs.
21 |
22 | Args:
23 | backend: Backend of communication. Available options are ``gloo``,
24 | ``mpi`` and ``nccl``.
25 |
26 | Returns:
27 | Rank of the current process.
28 | """
29 | dist.init_process_group(backend)
30 | rank = dist.get_rank()
31 | set_log_context(
32 | distributed=DistributedWorkerInfo(
33 | rank=dist.get_rank(),
34 | backend=dist.get_backend(),
35 | world_size=dist.get_world_size(),
36 | )
37 | )
38 | return int(rank)
39 |
40 |
41 | def destroy_process_group() -> None:
42 | """Destroys process group of distributed workers."""
43 | dist.destroy_process_group()
44 |
--------------------------------------------------------------------------------
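A minimal sketch of how the helpers above would be used in a training script launched with a multi-process launcher such as `torchrun`; the training code itself is omitted and the `nccl` backend assumes CUDA-capable workers:

```
import d3rlpy

# each worker calls this once at start-up; the worker info is also
# injected into the log context so log lines can be attributed to a rank
rank = d3rlpy.distributed.init_process_group("nccl")
print(f"running as rank {rank}")

# ... build the dataset/algorithm and train ...

d3rlpy.distributed.destroy_process_group()
```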
/d3rlpy/envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .utility import *
2 | from .wrappers import *
3 |
--------------------------------------------------------------------------------
/d3rlpy/envs/utility.py:
--------------------------------------------------------------------------------
1 | from ..types import GymEnv
2 |
3 | __all__ = ["seed_env"]
4 |
5 |
6 | def seed_env(env: GymEnv, seed: int) -> None:
7 | env.reset(seed=seed)
8 |
--------------------------------------------------------------------------------
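A quick sketch of `seed_env` with a Gymnasium environment; the environment id is arbitrary:

```
import gymnasium

from d3rlpy.envs import seed_env

env = gymnasium.make("CartPole-v1")

# resets the environment once with the given seed to fix its RNG
seed_env(env, seed=313)
```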
/d3rlpy/healthcheck.py:
--------------------------------------------------------------------------------
1 | __all__ = ["run_healthcheck"]
2 |
3 |
4 | def run_healthcheck() -> None:
5 | _check_gym()
6 | _check_pytorch()
7 |
8 |
9 | def _check_gym() -> None:
10 | import gymnasium
11 | from gym.version import VERSION
12 |
13 | if VERSION < "0.26.0":
14 | raise ValueError(
15 | "Gym version is too outdated. "
16 | "Please upgrade Gym to 0.26.0 or later."
17 | )
18 |
19 | if gymnasium.__version__ < "1.0.0":
20 | raise ValueError(
21 | "Gymnasium version is too outdated. "
22 | "Please upgrade Gymnasium to 1.0.0 or later."
23 | )
24 |
25 |
26 | def _check_pytorch() -> None:
27 | import torch
28 |
29 | if torch.__version__ < "2.5.0":
30 | raise ValueError(
31 | "PyTorch version is too outdated. "
32 | "Please upgrade PyTorch to 2.5.0 or later."
33 | )
34 |
--------------------------------------------------------------------------------
/d3rlpy/interface.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Protocol, Union
2 |
3 | from .preprocessing import ActionScaler, ObservationScaler, RewardScaler
4 | from .types import NDArray, Observation
5 |
6 | __all__ = ["QLearningAlgoProtocol", "StatefulTransformerAlgoProtocol"]
7 |
8 |
9 | class QLearningAlgoProtocol(Protocol):
10 | def predict(self, x: Observation) -> NDArray: ...
11 |
12 | def predict_value(self, x: Observation, action: NDArray) -> NDArray: ...
13 |
14 | def sample_action(self, x: Observation) -> NDArray: ...
15 |
16 | @property
17 | def gamma(self) -> float: ...
18 |
19 | @property
20 | def observation_scaler(self) -> Optional[ObservationScaler]: ...
21 |
22 | @property
23 | def action_scaler(self) -> Optional[ActionScaler]: ...
24 |
25 | @property
26 | def reward_scaler(self) -> Optional[RewardScaler]: ...
27 |
28 | @property
29 | def action_size(self) -> Optional[int]: ...
30 |
31 |
32 | class StatefulTransformerAlgoProtocol(Protocol):
33 | def predict(self, x: Observation, reward: float) -> Union[NDArray, int]: ...
34 |
35 | def reset(self) -> None: ...
36 |
--------------------------------------------------------------------------------
/d3rlpy/itertools.py:
--------------------------------------------------------------------------------
1 | from typing import Iterable, Iterator, TypeVar
2 |
3 | __all__ = ["last_flag", "first_flag"]
4 |
5 | T = TypeVar("T")
6 |
7 |
8 | def last_flag(iterator: Iterable[T]) -> Iterator[tuple[bool, T]]:
9 | items = list(iterator)
10 | for i, item in enumerate(items):
11 | yield i == len(items) - 1, item
12 |
13 |
14 | def first_flag(iterator: Iterable[T]) -> Iterator[tuple[bool, T]]:
15 | items = list(iterator)
16 | for i, item in enumerate(items):
17 | yield i == 0, item
18 |
--------------------------------------------------------------------------------
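A short sketch of the two iterator helpers above:

```
from d3rlpy.itertools import first_flag, last_flag

items = ["a", "b", "c"]

# first_flag marks the first element, last_flag marks the last one
print(list(first_flag(items)))  # [(True, 'a'), (False, 'b'), (False, 'c')]
print(list(last_flag(items)))   # [(False, 'a'), (False, 'b'), (True, 'c')]
```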
/d3rlpy/logging/__init__.py:
--------------------------------------------------------------------------------
1 | from .file_adapter import *
2 | from .logger import *
3 | from .noop_adapter import *
4 | from .tensorboard_adapter import *
5 | from .utils import *
6 | from .wandb_adapter import *
7 |
--------------------------------------------------------------------------------
/d3rlpy/logging/noop_adapter.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | from .logger import (
4 | AlgProtocol,
5 | LoggerAdapter,
6 | LoggerAdapterFactory,
7 | SaveProtocol,
8 | )
9 |
10 | __all__ = ["NoopAdapter", "NoopAdapterFactory"]
11 |
12 |
13 | class NoopAdapter(LoggerAdapter):
14 | r"""NoopAdapter class.
15 |
16 | This class does not save anything. It is especially useful when programs
17 | are not allowed to write anything to disk.
18 | """
19 |
20 | def write_params(self, params: dict[str, Any]) -> None:
21 | pass
22 |
23 | def before_write_metric(self, epoch: int, step: int) -> None:
24 | pass
25 |
26 | def write_metric(
27 | self, epoch: int, step: int, name: str, value: float
28 | ) -> None:
29 | pass
30 |
31 | def after_write_metric(self, epoch: int, step: int) -> None:
32 | pass
33 |
34 | def save_model(self, epoch: int, algo: SaveProtocol) -> None:
35 | pass
36 |
37 | def close(self) -> None:
38 | pass
39 |
40 | def watch_model(
41 | self,
42 | epoch: int,
43 | step: int,
44 | ) -> None:
45 | pass
46 |
47 |
48 | class NoopAdapterFactory(LoggerAdapterFactory):
49 | r"""NoopAdapterFactory class.
50 |
51 | This class instantiates ``NoopAdapter`` object.
52 | """
53 |
54 | def create(
55 | self, algo: AlgProtocol, experiment_name: str, n_steps_per_epoch: int
56 | ) -> NoopAdapter:
57 | return NoopAdapter()
58 |
--------------------------------------------------------------------------------
/d3rlpy/logging/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Sequence
2 |
3 | from .logger import (
4 | AlgProtocol,
5 | LoggerAdapter,
6 | LoggerAdapterFactory,
7 | SaveProtocol,
8 | )
9 |
10 | __all__ = ["CombineAdapter", "CombineAdapterFactory"]
11 |
12 |
13 | class CombineAdapter(LoggerAdapter):
14 | r"""CombineAdapter class.
15 |
16 | This class combines multiple LoggerAdapter objects to write metrics
17 | through different adapters at the same time.
18 |
19 | Args:
20 | adapters (Sequence[LoggerAdapter]): List of LoggerAdapter.
21 | """
22 |
23 | def __init__(self, adapters: Sequence[LoggerAdapter]):
24 | self._adapters = adapters
25 |
26 | def write_params(self, params: dict[str, Any]) -> None:
27 | for adapter in self._adapters:
28 | adapter.write_params(params)
29 |
30 | def before_write_metric(self, epoch: int, step: int) -> None:
31 | for adapter in self._adapters:
32 | adapter.before_write_metric(epoch, step)
33 |
34 | def write_metric(
35 | self, epoch: int, step: int, name: str, value: float
36 | ) -> None:
37 | for adapter in self._adapters:
38 | adapter.write_metric(epoch, step, name, value)
39 |
40 | def after_write_metric(self, epoch: int, step: int) -> None:
41 | for adapter in self._adapters:
42 | adapter.after_write_metric(epoch, step)
43 |
44 | def save_model(self, epoch: int, algo: SaveProtocol) -> None:
45 | for adapter in self._adapters:
46 | adapter.save_model(epoch, algo)
47 |
48 | def close(self) -> None:
49 | for adapter in self._adapters:
50 | adapter.close()
51 |
52 | def watch_model(
53 | self,
54 | epoch: int,
55 | step: int,
56 | ) -> None:
57 | for adapter in self._adapters:
58 | adapter.watch_model(epoch, step)
59 |
60 |
61 | class CombineAdapterFactory(LoggerAdapterFactory):
62 | r"""CombineAdapterFactory class.
63 |
64 | This class instantiates ``CombineAdapter`` object.
65 |
66 | Args:
67 | adapter_factories (Sequence[LoggerAdapterFactory]):
68 | List of LoggerAdapterFactory.
69 | """
70 |
71 | _adapter_factories: Sequence[LoggerAdapterFactory]
72 |
73 | def __init__(self, adapter_factories: Sequence[LoggerAdapterFactory]):
74 | self._adapter_factories = adapter_factories
75 |
76 | def create(
77 | self, algo: AlgProtocol, experiment_name: str, n_steps_per_epoch: int
78 | ) -> CombineAdapter:
79 | return CombineAdapter(
80 | [
81 | factory.create(algo, experiment_name, n_steps_per_epoch)
82 | for factory in self._adapter_factories
83 | ]
84 | )
85 |
--------------------------------------------------------------------------------
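A sketch of combining two adapters at once. `FileAdapterFactory` and `TensorboardAdapterFactory` are assumed to be the factories provided by `file_adapter.py` and `tensorboard_adapter.py` listed in the tree, and their constructor arguments here are illustrative:

```
import d3rlpy

# write metrics both to local files and to TensorBoard at the same time
# (factory names and arguments below are assumptions, not verified API)
logger_adapter = d3rlpy.logging.CombineAdapterFactory(
    [
        d3rlpy.logging.FileAdapterFactory(root_dir="d3rlpy_logs"),
        d3rlpy.logging.TensorboardAdapterFactory(root_dir="runs"),
    ]
)

# the combined factory would then be passed to a fit() call, e.g.
# algo.fit(dataset, n_steps=100000, logger_adapter=logger_adapter)
```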
/d3rlpy/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .evaluators import *
2 | from .utility import *
3 |
--------------------------------------------------------------------------------
/d3rlpy/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .builders import *
2 | from .encoders import *
3 | from .q_functions import *
4 |
--------------------------------------------------------------------------------
/d3rlpy/models/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from .distributions import *
2 | from .encoders import *
3 | from .imitators import *
4 | from .parameters import *
5 | from .policies import *
6 | from .q_functions import *
7 | from .transformers import *
8 | from .v_functions import *
9 |
--------------------------------------------------------------------------------
/d3rlpy/models/torch/parameters.py:
--------------------------------------------------------------------------------
1 | from typing import NoReturn
2 |
3 | import torch
4 | from torch import nn
5 |
6 | __all__ = ["Parameter", "get_parameter"]
7 |
8 |
9 | class Parameter(nn.Module): # type: ignore
10 | _parameter: nn.Parameter
11 |
12 | def __init__(self, data: torch.Tensor):
13 | super().__init__()
14 | self._parameter = nn.Parameter(data)
15 |
16 | def forward(self) -> NoReturn:
17 | raise NotImplementedError(
18 | "Parameter does not support __call__. Use parameter property "
19 | "instead."
20 | )
21 |
22 | def __call__(self) -> NoReturn:
23 | raise NotImplementedError(
24 | "Parameter does not support __call__. Use parameter property "
25 | "instead."
26 | )
27 |
28 |
29 | def get_parameter(parameter: Parameter) -> nn.Parameter:
30 | return next(parameter.parameters())
31 |
--------------------------------------------------------------------------------
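A small sketch of the `Parameter` wrapper above; calling it raises by design, so the underlying tensor is accessed with `get_parameter`:

```
import torch

from d3rlpy.models.torch.parameters import Parameter, get_parameter

# wrap a learnable scalar (e.g. a temperature parameter)
log_temp = Parameter(torch.zeros(1, 1))

# forward/__call__ intentionally raise, so the raw nn.Parameter is
# retrieved through the helper instead
value = get_parameter(log_temp)
print(value.shape)  # torch.Size([1, 1])
```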
/d3rlpy/models/torch/q_functions/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import *
2 | from .ensemble_q_function import *
3 | from .iqn_q_function import *
4 | from .mean_q_function import *
5 | from .qr_q_function import *
6 | from .utility import *
7 |
--------------------------------------------------------------------------------
/d3rlpy/models/torch/v_functions.py:
--------------------------------------------------------------------------------
1 | from typing import cast
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn
6 |
7 | from ...types import TorchObservation
8 | from .encoders import Encoder
9 |
10 | __all__ = ["ValueFunction", "compute_v_function_error"]
11 |
12 |
13 | class ValueFunction(nn.Module): # type: ignore
14 | _encoder: Encoder
15 | _fc: nn.Linear
16 |
17 | def __init__(self, encoder: Encoder, hidden_size: int):
18 | super().__init__()
19 | self._encoder = encoder
20 | self._fc = nn.Linear(hidden_size, 1)
21 |
22 | def forward(self, x: TorchObservation) -> torch.Tensor:
23 | h = self._encoder(x)
24 | return cast(torch.Tensor, self._fc(h))
25 |
26 | def __call__(self, x: TorchObservation) -> torch.Tensor:
27 | return cast(torch.Tensor, super().__call__(x))
28 |
29 |
30 | def compute_v_function_error(
31 | v_function: ValueFunction,
32 | observations: TorchObservation,
33 | target: torch.Tensor,
34 | ) -> torch.Tensor:
35 | v_t = v_function(observations)
36 | loss = F.mse_loss(v_t, target)
37 | return loss
38 |
--------------------------------------------------------------------------------
/d3rlpy/models/utility.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 |
3 | from ..torch_utility import GEGLU, Swish
4 |
5 | __all__ = ["create_activation"]
6 |
7 |
8 | def create_activation(activation_type: str) -> nn.Module:
9 | if activation_type == "relu":
10 | return nn.ReLU()
11 | elif activation_type == "gelu":
12 | return nn.GELU()
13 | elif activation_type == "tanh":
14 | return nn.Tanh()
15 | elif activation_type == "swish":
16 | return Swish()
17 | elif activation_type == "none":
18 | return nn.Identity()
19 | elif activation_type == "geglu":
20 | return GEGLU()
21 | raise ValueError("invalid activation_type.")
22 |
--------------------------------------------------------------------------------
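A trivial sketch of `create_activation`, which maps the activation names used in configs to `nn.Module` instances:

```
from d3rlpy.models.utility import create_activation

relu = create_activation("relu")    # nn.ReLU()
swish = create_activation("swish")  # Swish from torch_utility
print(relu, swish)

# unknown names raise ValueError
# create_activation("not_an_activation")
```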
/d3rlpy/notebook_utils.py:
--------------------------------------------------------------------------------
1 | import base64
2 |
3 | __all__ = ["start_virtual_display", "render_video"]
4 |
5 |
6 | def start_virtual_display() -> None:
7 | """Starts virtual display."""
8 | try:
9 | from pyvirtualdisplay.display import Display
10 |
11 | display = Display()
12 | display.start()
13 | except ImportError as e:
14 | raise ImportError(
15 | "pyvirtualdisplay is not installed.\n"
16 | "$ pip install pyvirtualdisplay"
17 | ) from e
18 |
19 |
20 | def render_video(path: str) -> None:
21 | """Renders video file in Jupyter Notebook.
22 |
23 | Args:
24 | path: Path to video file.
25 | """
26 | try:
27 | from IPython import display as ipythondisplay
28 | from IPython.core.display import HTML
29 |
30 | with open(path, "r+b") as f:
31 | encoded = base64.b64encode(f.read())
32 | template = """
33 | <video controls style="height: 400px;">
34 | <source src="data:video/mp4;base64,{0}" type="video/mp4" />
35 | </video>
36 | """
37 | ipythondisplay.display(
38 | HTML(data=template.format(encoded.decode("ascii")))
39 | )
40 | except ImportError as e:
41 | raise ImportError(
42 | "This should be executed inside Jupyter Notebook."
43 | ) from e
44 |
--------------------------------------------------------------------------------
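A sketch of the notebook helpers above, e.g. inside a Colab/Jupyter session; the video path is illustrative and assumes a recording produced elsewhere:

```
import d3rlpy

# start a headless display so environments can render off-screen
d3rlpy.notebook_utils.start_virtual_display()

# ... roll out a policy and record "video.mp4" with your own tooling ...

# embed the recorded file in the notebook output
d3rlpy.notebook_utils.render_video("video.mp4")
```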
/d3rlpy/ope/__init__.py:
--------------------------------------------------------------------------------
1 | from .fqe import *
2 |
--------------------------------------------------------------------------------
/d3rlpy/ope/torch/__init__.py:
--------------------------------------------------------------------------------
1 | from .fqe_impl import *
2 |
--------------------------------------------------------------------------------
/d3rlpy/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .lr_schedulers import *
2 | from .optimizers import *
3 |
--------------------------------------------------------------------------------
/d3rlpy/optimizers/lr_schedulers.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 |
3 | from torch.optim import Optimizer
4 | from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, LRScheduler
5 |
6 | from ..serializable_config import (
7 | DynamicConfig,
8 | generate_optional_config_generation,
9 | )
10 |
11 | __all__ = [
12 | "LRSchedulerFactory",
13 | "WarmupSchedulerFactory",
14 | "CosineAnnealingLRFactory",
15 | "make_lr_scheduler_field",
16 | ]
17 |
18 |
19 | @dataclasses.dataclass()
20 | class LRSchedulerFactory(DynamicConfig):
21 | """A factory class that creates a learning rate scheduler in a lazy way."""
22 |
23 | def create(self, optim: Optimizer) -> LRScheduler:
24 | """Returns a learning rate scheduler object.
25 |
26 | Args:
27 | optim: PyTorch optimizer.
28 |
29 | Returns:
30 | Learning rate scheduler.
31 | """
32 | raise NotImplementedError
33 |
34 |
35 | @dataclasses.dataclass()
36 | class WarmupSchedulerFactory(LRSchedulerFactory):
37 | r"""A warmup learning rate scheduler.
38 |
39 | .. math::
40 |
41 | lr = \min((t + 1) / warmup\_steps, 1)
42 |
43 | Args:
44 | warmup_steps: Warmup steps.
45 | """
46 |
47 | warmup_steps: int
48 |
49 | def create(self, optim: Optimizer) -> LRScheduler:
50 | return LambdaLR(
51 | optim,
52 | lambda steps: min((steps + 1) / self.warmup_steps, 1),
53 | )
54 |
55 | @staticmethod
56 | def get_type() -> str:
57 | return "warmup"
58 |
59 |
60 | @dataclasses.dataclass()
61 | class CosineAnnealingLRFactory(LRSchedulerFactory):
62 | """A cosine annealing learning rate scheduler.
63 |
64 | Args:
65 | T_max: Maximum time step.
66 | eta_min: Minimum learning rate.
67 | last_epoch: Last epoch.
68 | """
69 |
70 | T_max: int
71 | eta_min: float = 0.0
72 | last_epoch: int = -1
73 |
74 | def create(self, optim: Optimizer) -> LRScheduler:
75 | return CosineAnnealingLR(
76 | optim,
77 | T_max=self.T_max,
78 | eta_min=self.eta_min,
79 | last_epoch=self.last_epoch,
80 | )
81 |
82 | @staticmethod
83 | def get_type() -> str:
84 | return "cosine_annealing"
85 |
86 |
87 | register_lr_scheduler_factory, make_lr_scheduler_field = (
88 | generate_optional_config_generation(
89 | LRSchedulerFactory,
90 | )
91 | )
92 |
93 | register_lr_scheduler_factory(WarmupSchedulerFactory)
94 | register_lr_scheduler_factory(CosineAnnealingLRFactory)
95 |
--------------------------------------------------------------------------------
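A sketch of the scheduler factories above applied directly to a plain PyTorch optimizer. In practice the factory is typically handed to an optimizer factory inside an algorithm config, but the direct `create` call defined in this file is enough to show the behaviour:

```
import torch

from d3rlpy.optimizers import WarmupSchedulerFactory

model = torch.nn.Linear(4, 2)
optim = torch.optim.Adam(model.parameters(), lr=3e-4)

# linearly scales the learning rate from ~0 up to its full value
# over the first 1000 optimizer steps
scheduler = WarmupSchedulerFactory(warmup_steps=1000).create(optim)

for step in range(5):
    optim.step()
    scheduler.step()
    print(step, scheduler.get_last_lr())
```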
/d3rlpy/preprocessing/__init__.py:
--------------------------------------------------------------------------------
1 | from .action_scalers import *
2 | from .base import *
3 | from .observation_scalers import *
4 | from .reward_scalers import *
5 |
--------------------------------------------------------------------------------
/d3rlpy/tokenizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .tokenizers import *
2 | from .utils import *
3 |
--------------------------------------------------------------------------------
/d3rlpy/tokenizers/tokenizers.py:
--------------------------------------------------------------------------------
1 | from typing import Protocol, runtime_checkable
2 |
3 | import numpy as np
4 |
5 | from ..types import Float32NDArray, Int32NDArray, NDArray
6 | from .utils import mu_law_decode, mu_law_encode
7 |
8 | __all__ = [
9 | "Tokenizer",
10 | "FloatTokenizer",
11 | ]
12 |
13 |
14 | @runtime_checkable
15 | class Tokenizer(Protocol):
16 | def __call__(self, x: NDArray) -> NDArray: ...
17 |
18 | def decode(self, y: Int32NDArray) -> NDArray: ...
19 |
20 |
21 | class FloatTokenizer(Tokenizer):
22 | _bins: Float32NDArray
23 | _use_mu_law_encode: bool
24 | _mu: float
25 | _basis: float
26 | _token_offset: int
27 |
28 | def __init__(
29 | self,
30 | num_bins: int,
31 | minimum: float = -1.0,
32 | maximum: float = 1.0,
33 | use_mu_law_encode: bool = True,
34 | mu: float = 100.0,
35 | basis: float = 256.0,
36 | token_offset: int = 0,
37 | ):
38 | self._bins = np.array(
39 | (maximum - minimum) * np.arange(num_bins) / num_bins + minimum,
40 | dtype=np.float32,
41 | )
42 | self._use_mu_law_encode = use_mu_law_encode
43 | self._mu = mu
44 | self._basis = basis
45 | self._token_offset = token_offset
46 |
47 | def __call__(self, x: NDArray) -> Int32NDArray:
48 | if self._use_mu_law_encode:
49 | x = mu_law_encode(x, self._mu, self._basis)
50 | return np.digitize(x, self._bins) - 1 + self._token_offset
51 |
52 | def decode(self, y: Int32NDArray) -> NDArray:
53 | x = self._bins[y - self._token_offset]
54 | if self._use_mu_law_encode:
55 | x = mu_law_decode(x, mu=self._mu, basis=self._basis)
56 | return x # type: ignore
57 |
--------------------------------------------------------------------------------
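A round-trip sketch of `FloatTokenizer`; the input values are arbitrary:

```
import numpy as np

from d3rlpy.tokenizers import FloatTokenizer

tokenizer = FloatTokenizer(num_bins=256)

x = np.array([-0.8, 0.0, 0.3], dtype=np.float32)

# continuous values -> discrete token ids (mu-law encoded, then binned)
tokens = tokenizer(x)
print(tokens)

# token ids -> approximate continuous values (quantization error remains)
print(tokenizer.decode(tokens))
```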
/d3rlpy/tokenizers/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from ..types import Float32NDArray, NDArray
4 |
5 | __all__ = ["mu_law_encode", "mu_law_decode"]
6 |
7 |
8 | def mu_law_encode(x: NDArray, mu: float, basis: float) -> Float32NDArray:
9 | x = np.array(x, dtype=np.float32)
10 | y = np.sign(x) * np.log(np.abs(x) * mu + 1.0) / np.log(basis * mu + 1.0)
11 | return y # type: ignore
12 |
13 |
14 | def mu_law_decode(y: Float32NDArray, mu: float, basis: float) -> Float32NDArray:
15 | x = np.sign(y) * (np.exp(np.log(basis * mu + 1.0) * np.abs(y)) - 1.0) / mu
16 | return x # type: ignore
17 |
--------------------------------------------------------------------------------
/d3rlpy/types.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Mapping, Protocol, Sequence, Union, runtime_checkable
2 |
3 | import gym
4 | import gymnasium
5 | import numpy as np
6 | import numpy.typing as npt
7 | import torch
8 | from torch.optim import Optimizer
9 |
10 | __all__ = [
11 | "NDArray",
12 | "Float32NDArray",
13 | "Int32NDArray",
14 | "UInt8NDArray",
15 | "DType",
16 | "Observation",
17 | "ObservationSequence",
18 | "Shape",
19 | "TorchObservation",
20 | "GymEnv",
21 | "OptimizerWrapperProto",
22 | ]
23 |
24 |
25 | NDArray = npt.NDArray[Any]
26 | Float32NDArray = npt.NDArray[np.float32]
27 | Int32NDArray = npt.NDArray[np.int32]
28 | UInt8NDArray = npt.NDArray[np.uint8]
29 | DType = npt.DTypeLike
30 |
31 | Observation = Union[NDArray, Sequence[NDArray]]
32 | ObservationSequence = Union[NDArray, Sequence[NDArray]]
33 | Shape = Union[Sequence[int], Sequence[Sequence[int]]]
34 | TorchObservation = Union[torch.Tensor, Sequence[torch.Tensor]]
35 |
36 | GymEnv = Union[gym.Env[Any, Any], gymnasium.Env[Any, Any]]
37 |
38 |
39 | @runtime_checkable
40 | class OptimizerWrapperProto(Protocol):
41 | @property
42 | def optim(self) -> Optimizer:
43 | raise NotImplementedError
44 |
45 | def state_dict(self) -> Mapping[str, Any]:
46 | raise NotImplementedError
47 |
48 | def load_state_dict(self, state_dict: Mapping[str, Any]) -> None:
49 | raise NotImplementedError
50 |
--------------------------------------------------------------------------------
/dev.requirements.txt:
--------------------------------------------------------------------------------
1 | pytest
2 | pytest-cov
3 | onnxruntime
4 | onnx
5 | matplotlib
6 | tensorboardX
7 | wandb
8 | mypy
9 | numpy<2
10 | docformatter
11 | ruff
12 | black
13 | minari[all]>=0.5.2
14 | gymnasium-robotics>=1.3.1
15 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel
2 |
3 | # needed to avoid the interactive time zone prompt
4 | ENV DEBIAN_FRONTEND=noninteractive
5 |
6 | # install dependencies
7 | RUN apt-get update && \
8 | apt-get install -y --no-install-recommends \
9 | build-essential \
10 | software-properties-common \
11 | cmake \
12 | git \
13 | wget \
14 | unzip \
15 | python3-dev \
16 | zlib1g \
17 | zlib1g-dev \
18 | libgl1-mesa-dri \
19 | libgl1-mesa-glx \
20 | libglu1-mesa-dev \
21 | libasio-dev \
22 | pkg-config \
23 | python3-tk \
24 | libsm6 \
25 | libxext6 \
26 | libxrender1 \
27 | libpcre3-dev && \
28 | pip install --no-cache-dir \
29 | Cython==0.29.28 \
30 | git+https://github.com/takuseno/d4rl-atari && \
31 | rm -rf /var/lib/apt/lists/* && \
32 | rm -rf /tmp/* && \
33 | wget https://github.com/takuseno/d3rlpy/archive/master.zip && \
34 | unzip master.zip && \
35 | cd d3rlpy-master && \
36 | pip install --no-cache-dir . && \
37 | cd .. && \
38 | rm -rf d3rlpy-master master.zip
39 |
40 | EXPOSE 6006
41 |
42 | CMD ["tail", "-f", "/dev/null"]
43 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/docs/_static/css/d3rlpy.css:
--------------------------------------------------------------------------------
1 |
2 | :root{
3 | --main-bg-color: #2c3e50;
4 | }
5 |
6 | /* Docs background */
7 | .wy-side-nav-search{
8 | background-color: var(--main-bg-color);
9 | }
10 |
11 | /* Mobile version */
12 | .wy-nav-top{
13 | background-color: var(--main-bg-color);
14 | }
15 |
--------------------------------------------------------------------------------
/docs/_static/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/docs/_static/logo.png
--------------------------------------------------------------------------------
/docs/_templates/autosummary/class.rst:
--------------------------------------------------------------------------------
1 | {{ fullname }}
2 | {{ underline }}
3 |
4 | .. currentmodule:: {{ module }}
5 |
6 | .. autoclass:: {{ objname }}
7 |
8 | ..
9 | Methods
10 |
11 | {% block methods %}
12 |
13 | .. rubric:: Methods
14 |
15 | ..
16 | Special methods
17 |
18 | {% for item in ('__call__', '__enter__', '__exit__', '__getitem__', '__setitem__', '__len__', '__next__', '__iter__', '__copy__') %}
19 | {% if item in all_methods %}
20 | .. automethod:: {{ item }}
21 | {% endif %}
22 | {%- endfor %}
23 |
24 | ..
25 | Ordinary methods
26 |
27 | {% for item in methods %}
28 | {% if item not in ('__init__',) %}
29 | .. automethod:: {{ item }}
30 | {% endif %}
31 | {%- endfor %}
32 |
33 | {% endblock %}
34 |
35 | ..
36 | Attributes
37 |
38 | {% block attributes %} {% if attributes %}
39 |
40 | .. rubric:: Attributes
41 |
42 | {% for item in attributes %}
43 | .. autoattribute:: {{ item }}
44 | {%- endfor %}
45 | {% endif %} {% endblock %}
46 |
--------------------------------------------------------------------------------
/docs/assets/design.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/docs/assets/design.png
--------------------------------------------------------------------------------
/docs/assets/dqn_cartpole.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/docs/assets/dqn_cartpole.png
--------------------------------------------------------------------------------
/docs/assets/fqe_cartpole_init_value.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/docs/assets/fqe_cartpole_init_value.png
--------------------------------------------------------------------------------
/docs/assets/fqe_cartpole_soft_opc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/docs/assets/fqe_cartpole_soft_opc.png
--------------------------------------------------------------------------------
/docs/assets/mdp_dataset.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/docs/assets/mdp_dataset.png
--------------------------------------------------------------------------------
/docs/assets/plot_all_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/docs/assets/plot_all_example.png
--------------------------------------------------------------------------------
/docs/assets/plot_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/docs/assets/plot_example.png
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. d3rlpy documentation master file, created by
2 | sphinx-quickstart on Mon Jul 6 18:20:21 2020.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | d3rlpy - An offline deep reinforcement learning library.
7 | ========================================================
8 |
9 | `d3rlpy `_ is an easy-to-use offline deep
10 | reinforcement learning library.
11 |
12 | .. code-block::
13 |
14 | $ pip install d3rlpy
15 |
16 | d3rlpy provides state-of-the-art offline deep reinforcement learning
17 | algorithms through out-of-the-box scikit-learn-style APIs.
18 | Unlike other RL libraries, the provided algorithms can achieve performance
19 | beyond the original papers thanks to a number of implementation tweaks.
20 |
21 | .. toctree::
22 | :maxdepth: 2
23 | :caption: Tutorials
24 |
25 | tutorials/index
26 | notebooks
27 |
28 | .. toctree::
29 | :maxdepth: 2
30 | :caption: References
31 |
32 | software_design
33 | references/index
34 | cli
35 | installation
36 | tips
37 |
38 | .. toctree::
39 | :maxdepth: 2
40 | :caption: Other
41 |
42 | reproductions
43 | license
44 |
45 |
46 | Indices and tables
47 | ==================
48 |
49 | * :ref:`genindex`
50 | * :ref:`modindex`
51 | * :ref:`search`
52 |
--------------------------------------------------------------------------------
/docs/installation.rst:
--------------------------------------------------------------------------------
1 | Installation
2 | ============
3 |
4 | Recommended Platforms
5 | ---------------------
6 |
7 | d3rlpy supports Linux, macOS, and Windows.
8 |
9 |
10 | Install d3rlpy
11 | --------------
12 |
13 | Install via PyPI
14 | ~~~~~~~~~~~~~~~~
15 |
16 | `pip` is the recommended way to install d3rlpy::
17 |
18 | $ pip install d3rlpy
19 |
20 | Install via Anaconda
21 | ~~~~~~~~~~~~~~~~~~~~
22 |
23 | d3rlpy is also available on `conda-forge`::
24 |
25 | $ conda install -c conda-forge d3rlpy
26 |
27 |
28 | Install via Docker
29 | ~~~~~~~~~~~~~~~~~~
30 |
31 | d3rlpy is also available on Docker Hub::
32 |
33 | $ docker run -it --gpus all --name d3rlpy takuseno/d3rlpy:latest bash
34 |
35 |
36 | .. _install_from_source:
37 |
38 | Install from source
39 | ~~~~~~~~~~~~~~~~~~~
40 |
41 | You can also install from the GitHub repository::
42 |
43 | $ git clone https://github.com/takuseno/d3rlpy
44 | $ cd d3rlpy
45 | $ pip install -e .
46 |
--------------------------------------------------------------------------------
/docs/license.rst:
--------------------------------------------------------------------------------
1 | License
2 | =======
3 |
4 | MIT License
5 |
6 | Copyright (c) 2021 Takuma Seno
7 |
8 | Permission is hereby granted, free of charge, to any person obtaining a copy
9 | of this software and associated documentation files (the "Software"), to deal
10 | in the Software without restriction, including without limitation the rights
11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | copies of the Software, and to permit persons to whom the Software is
13 | furnished to do so, subject to the following conditions:
14 |
15 | The above copyright notice and this permission notice shall be included in all
16 | copies or substantial portions of the Software.
17 |
18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24 | SOFTWARE.
25 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/docs/notebooks.rst:
--------------------------------------------------------------------------------
1 | Jupyter Notebooks
2 | =================
3 |
4 | * `CartPole `_
5 | * `CartPole (online) `_
6 | * `Discrete Control with Atari `_
7 | * `TPU Example `_
8 |
--------------------------------------------------------------------------------
/docs/references/datasets.rst:
--------------------------------------------------------------------------------
1 | Datasets
2 | ========
3 |
4 | .. module:: d3rlpy.datasets
5 |
6 | d3rlpy provides datasets for experimenting with data-driven deep
7 | reinforcement learning algorithms.
8 |
9 | .. autosummary::
10 | :toctree: generated/
11 | :nosignatures:
12 |
13 | d3rlpy.datasets.get_cartpole
14 | d3rlpy.datasets.get_pendulum
15 | d3rlpy.datasets.get_atari
16 | d3rlpy.datasets.get_atari_transitions
17 | d3rlpy.datasets.get_d4rl
18 | d3rlpy.datasets.get_dataset
19 | d3rlpy.datasets.get_minari
20 |
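A minimal usage sketch for these loaders; get_cartpole is shown, and most of the other loaders follow the same pattern of returning a dataset together with its environment:

    import d3rlpy

    # returns an offline dataset and the matching Gym environment
    dataset, env = d3rlpy.datasets.get_cartpole()
    print(len(dataset.episodes), env.observation_space)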
--------------------------------------------------------------------------------
/docs/references/index.rst:
--------------------------------------------------------------------------------
1 | *************
2 | API Reference
3 | *************
4 |
5 | .. module:: d3rlpy
6 |
7 | .. toctree::
8 | :maxdepth: 2
9 |
10 | algos
11 | q_functions
12 | dataset
13 | datasets
14 | preprocessing
15 | optimizers
16 | network_architectures
17 | metrics
18 | off_policy_evaluation
19 | logging
20 | online
21 |
--------------------------------------------------------------------------------
/docs/references/metrics.rst:
--------------------------------------------------------------------------------
1 | Metrics
2 | =======
3 |
4 | .. module:: d3rlpy.metrics
5 |
6 | d3rlpy provides scoring functions for offline Q-learning-based training.
7 | You can also check :doc:`../references/logging` to understand how to write
8 | metrics to files.
9 |
10 | .. code-block:: python
11 |
12 | import d3rlpy
13 |
14 | dataset, env = d3rlpy.datasets.get_cartpole()
15 | # use a subset of the episodes as test data
16 | test_episodes = dataset.episodes[:10]
17 |
18 | dqn = d3rlpy.algos.DQNConfig().create()
19 |
20 | dqn.fit(
21 | dataset,
22 | n_steps=100000,
23 | evaluators={
24 | 'td_error': d3rlpy.metrics.TDErrorEvaluator(test_episodes),
25 | 'value_scale': d3rlpy.metrics.AverageValueEstimationEvaluator(test_episodes),
26 | 'environment': d3rlpy.metrics.EnvironmentEvaluator(env),
27 | },
28 | )
29 |
30 | You can also implement your own metrics.
31 |
32 |
33 | .. code-block:: python
34 |
35 | class CustomEvaluator(d3rlpy.metrics.EvaluatorProtocol):
36 | def __call__(self, algo: d3rlpy.algos.QLearningAlgoBase, dataset: d3rlpy.dataset.ReplayBuffer) -> float:
37 | # do some evaluation
38 |
39 |
40 | .. autosummary::
41 | :toctree: generated/
42 | :nosignatures:
43 |
44 | d3rlpy.metrics.TDErrorEvaluator
45 | d3rlpy.metrics.DiscountedSumOfAdvantageEvaluator
46 | d3rlpy.metrics.AverageValueEstimationEvaluator
47 | d3rlpy.metrics.InitialStateValueEstimationEvaluator
48 | d3rlpy.metrics.SoftOPCEvaluator
49 | d3rlpy.metrics.ContinuousActionDiffEvaluator
50 | d3rlpy.metrics.DiscreteActionMatchEvaluator
51 | d3rlpy.metrics.EnvironmentEvaluator
52 | d3rlpy.metrics.CompareContinuousActionDiffEvaluator
53 | d3rlpy.metrics.CompareDiscreteActionMatchEvaluator
54 |
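A complete, if simplistic, custom evaluator sketch; the metric itself (mean reward over held-out episodes) and the assumption that each episode exposes a rewards array are only illustrative:

    import numpy as np

    import d3rlpy


    class AverageRewardEvaluator(d3rlpy.metrics.EvaluatorProtocol):
        def __init__(self, episodes):
            self._episodes = episodes

        def __call__(self, algo, dataset) -> float:
            # real evaluators usually query algo.predict or algo.predict_value;
            # here we only aggregate logged rewards from the held-out episodes
            rewards = np.concatenate([episode.rewards for episode in self._episodes])
            return float(np.mean(rewards))

    # usage: evaluators={"avg_reward": AverageRewardEvaluator(test_episodes)}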
--------------------------------------------------------------------------------
/docs/references/off_policy_evaluation.rst:
--------------------------------------------------------------------------------
1 | Off-Policy Evaluation
2 | =====================
3 |
4 | .. module:: d3rlpy.ope
5 |
6 | Off-policy evaluation is a method to estimate the performance of a trained
7 | policy using only offline datasets.
8 |
9 | .. code-block:: python
10 |
11 | import d3rlpy
12 |
13 | # prepare the trained algorithm
14 | cql = d3rlpy.load_learnable("model.d3")
15 |
16 | # dataset to evaluate with
17 | dataset, env = d3rlpy.datasets.get_pendulum()
18 |
19 | # off-policy evaluation algorithm
20 | fqe = d3rlpy.ope.FQE(algo=cql, config=d3rlpy.ope.FQEConfig())
21 |
22 | # train estimators to evaluate the trained policy
23 | fqe.fit(
24 | dataset,
25 | n_steps=100000,
26 | evaluators={
27 | 'init_value': d3rlpy.metrics.InitialStateValueEstimationEvaluator(),
28 | 'soft_opc': d3rlpy.metrics.SoftOPCEvaluator(return_threshold=-300),
29 | },
30 | )
31 |
32 | The evaluation metrics computed during fitting estimate the performance of the trained policy.
33 |
34 | For continuous control algorithms
35 | ---------------------------------
36 |
37 | .. autosummary::
38 | :toctree: generated/
39 | :nosignatures:
40 |
41 | d3rlpy.ope.FQE
42 |
43 |
44 | For discrete control algorithms
45 | -------------------------------
46 |
47 | .. autosummary::
48 | :toctree: generated/
49 | :nosignatures:
50 |
51 | d3rlpy.ope.DiscreteFQE
52 |
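For discrete-action algorithms the setup is analogous; a sketch assuming DiscreteFQE accepts the same algo and config arguments as FQE, with a hypothetical saved model path:

    import d3rlpy

    # trained discrete-action policy to evaluate (path is hypothetical)
    dqn = d3rlpy.load_learnable("dqn_model.d3")

    fqe = d3rlpy.ope.DiscreteFQE(algo=dqn, config=d3rlpy.ope.FQEConfig())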
--------------------------------------------------------------------------------
/docs/references/online.rst:
--------------------------------------------------------------------------------
1 | Online Training
2 | ===============
3 |
4 | .. module:: d3rlpy.algos
5 |
6 | d3rlpy provides not only offline training, but also online training utilities.
7 | Although d3rlpy is designed for offline training, its algorithms can also
8 | be trained in an online manner with a few extra utilities.
9 |
10 | .. code-block:: python
11 |
12 | import d3rlpy
13 | import gym
14 |
15 | # setup environment
16 | env = gym.make('CartPole-v1')
17 | eval_env = gym.make('CartPole-v1')
18 |
19 | # setup algorithm
20 | dqn = d3rlpy.algos.DQNConfig(
21 | batch_size=32,
22 | learning_rate=2.5e-4,
23 | target_update_interval=100,
24 | ).create(device="cuda:0")
25 |
26 | # setup replay buffer
27 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=100000, env=env)
28 |
29 | # setup explorers
30 | explorer = d3rlpy.algos.LinearDecayEpsilonGreedy(
31 | start_epsilon=1.0,
32 | end_epsilon=0.1,
33 | duration=10000,
34 | )
35 |
36 | # start training
37 | dqn.fit_online(
38 | env,
39 | buffer,
40 | explorer=explorer, # you don't need this with probabilistic policy algorithms
41 | eval_env=eval_env,
42 | n_steps=30000, # the number of total steps to train.
43 | n_steps_per_epoch=1000,
44 | update_interval=10, # update parameters every 10 steps.
45 | )
46 |
47 |
48 | Explorers
49 | ~~~~~~~~~
50 |
51 | .. autosummary::
52 | :toctree: generated/
53 | :nosignatures:
54 |
55 | d3rlpy.algos.ConstantEpsilonGreedy
56 | d3rlpy.algos.LinearDecayEpsilonGreedy
57 | d3rlpy.algos.NormalNoise
58 |
--------------------------------------------------------------------------------
/docs/references/optimizers.rst:
--------------------------------------------------------------------------------
1 | Optimizers
2 | ==========
3 |
4 | .. module:: d3rlpy.optimizers
5 |
6 | d3rlpy provides ``OptimizerFactory``, which gives you flexible control over
7 | optimizers.
8 | ``OptimizerFactory`` takes a PyTorch optimizer class and the arguments used to
9 | initialize it, which you can check `here `_.
10 |
11 | .. code-block:: python
12 |
13 | import d3rlpy
14 | from torch.optim import Adam
15 |
16 | # modify weight decay
17 | optim_factory = d3rlpy.optimizers.OptimizerFactory(Adam, weight_decay=1e-4)
18 |
19 | # set OptimizerFactory
20 | dqn = d3rlpy.algos.DQNConfig(optim_factory=optim_factory).create()
21 |
22 | There are also convenient aliases.
23 |
24 | .. code-block:: python
25 |
26 | # alias for Adam optimizer
27 | optim_factory = d3rlpy.optimizers.AdamFactory(weight_decay=1e-4)
28 |
29 | dqn = d3rlpy.algos.DQNConfig(optim_factory=optim_factory).create()
30 |
31 |
32 | .. autosummary::
33 | :toctree: generated/
34 | :nosignatures:
35 |
36 | d3rlpy.optimizers.OptimizerFactory
37 | d3rlpy.optimizers.SGDFactory
38 | d3rlpy.optimizers.AdamFactory
39 | d3rlpy.optimizers.RMSpropFactory
40 | d3rlpy.optimizers.GPTAdamWFactory
41 |
42 |
43 | Learning rate scheduler
44 | ~~~~~~~~~~~~~~~~~~~~~~~
45 |
46 | d3rlpy provides ``LRSchedulerFactory``, which lets you configure learning rate
47 | schedulers through ``OptimizerFactory``.
48 |
49 | .. code-block:: python
50 |
51 | import d3rlpy
52 |
53 | # set lr_scheduler_factory
54 | optim_factory = d3rlpy.optimizers.AdamFactory(
55 | lr_scheduler_factory=d3rlpy.optimizers.WarmupSchedulerFactory(
56 | warmup_steps=10000
57 | )
58 | )
59 |
60 |
61 | .. autosummary::
62 | :toctree: generated/
63 | :nosignatures:
64 |
65 | d3rlpy.optimizers.LRSchedulerFactory
66 | d3rlpy.optimizers.WarmupSchedulerFactory
67 | d3rlpy.optimizers.CosineAnnealingLRFactory
68 |
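The resulting factory is then passed to an algorithm config like any other optimizer factory; a short sketch using SAC's actor optimizer (see examples/lr_scheduler.py in this repository for a full script):

    import d3rlpy

    optim_factory = d3rlpy.optimizers.AdamFactory(
        lr_scheduler_factory=d3rlpy.optimizers.WarmupSchedulerFactory(
            warmup_steps=10000
        )
    )

    # attach the factory to the actor optimizer
    sac = d3rlpy.algos.SACConfig(
        actor_learning_rate=3e-4,
        actor_optim_factory=optim_factory,
    ).create()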
--------------------------------------------------------------------------------
/docs/references/q_functions.rst:
--------------------------------------------------------------------------------
1 | Q Functions
2 | ===========
3 |
4 | .. module:: d3rlpy.models
5 |
6 | d3rlpy provides various Q functions, including state-of-the-art ones, which
7 | are used internally in algorithm objects.
8 | You can switch Q functions by passing the ``q_func_factory`` argument at
9 | algorithm initialization.
10 |
11 | .. code-block:: python
12 |
13 | import d3rlpy
14 |
15 | cql = d3rlpy.algos.CQLConfig(q_func_factory=d3rlpy.models.QRQFunctionFactory())
16 |
17 | You can also change hyperparameters.
18 |
19 | .. code-block:: python
20 |
21 | q_func = d3rlpy.models.QRQFunctionFactory(n_quantiles=32)
22 |
23 | cql = d3rlpy.algos.CQLConfig(q_func_factory=q_func).create()
24 |
25 | The default Q function is the ``mean`` approximator, which estimates the
26 | expected scalar action-value.
27 | However, recent advances in deep reinforcement learning have introduced a
28 | new type of action-value approximator, the so-called
29 | `distributional` Q functions.
30 |
31 | Unlike the ``mean`` approximator, `distributional` Q functions estimate the
32 | distribution of action-values.
33 | These `distributional` approaches have consistently shown much stronger
34 | performance than the ``mean`` approximator.
35 |
36 | Here is a list of the available Q functions in ascending order of
37 | performance.
38 | As a trade-off between performance and computational complexity,
39 | higher performance comes with more expensive computational costs.
40 |
41 | .. autosummary::
42 | :toctree: generated/
43 | :nosignatures:
44 |
45 | d3rlpy.models.MeanQFunctionFactory
46 | d3rlpy.models.QRQFunctionFactory
47 | d3rlpy.models.IQNQFunctionFactory
48 |
--------------------------------------------------------------------------------
/docs/reproductions.rst:
--------------------------------------------------------------------------------
1 | Paper Reproductions
2 | -------------------
3 |
4 | For the experiment code, please take a look at the
5 | `reproductions `_ directory.
6 |
7 | All the experimental results are available in the `d3rlpy-benchmarks `_ repository.
8 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | tensorboardX==2.0
3 | Sphinx==5.0.2
4 | sphinx-rtd-theme==0.5.0
5 |
--------------------------------------------------------------------------------
/docs/tutorials/create_your_dataset.rst:
--------------------------------------------------------------------------------
1 | *******************
2 | Create Your Dataset
3 | *******************
4 |
5 | The data collection API is introduced in :doc:`data_collection`.
6 | In this tutorial, you can learn how to build your dataset from logged data
7 | such as the user data collected in your web service.
8 |
9 | Prepare Logged Data
10 | -------------------
11 |
12 | First of all, you need to prepare your logged data.
13 | In this tutorial, let's use randomly generated data.
14 | ``terminals`` flags the last step of each episode.
15 | If ``terminals[i] == 1.0``, the i-th step is a terminal state.
16 | Otherwise, set zeros for non-terminal states.
17 |
18 | .. code-block:: python
19 |
20 | import numpy as np
21 |
22 | # vector observation
23 | # 1000 steps of observations with shape of (100,)
24 | observations = np.random.random((1000, 100))
25 |
26 | # 1000 steps of actions with shape of (4,)
27 | actions = np.random.random((1000, 4))
28 |
29 | # 1000 steps of rewards
30 | rewards = np.random.random(1000)
31 |
32 | # 1000 steps of terminal flags
33 | terminals = np.random.randint(2, size=1000)
34 |
35 | Build MDPDataset
36 | ----------------
37 |
38 | Once your logged data is ready, you can build ``MDPDataset`` object.
39 |
40 | .. code-block:: python
41 |
42 | import d3rlpy
43 |
44 | dataset = d3rlpy.dataset.MDPDataset(
45 | observations=observations,
46 | actions=actions,
47 | rewards=rewards,
48 | terminals=terminals,
49 | )
50 |
51 | Set Timeout Flags
52 | -----------------
53 |
54 | In RL, there are cases where you want to stop an episode without reaching a
55 | terminal state.
56 | For example, if you're collecting data of a four-legged robot walking forward,
57 | the walking task essentially never ends as long as the robot keeps walking,
58 | yet the logged episode has to stop somewhere.
59 | In this case, you can use ``timeouts`` to represent these timeout states.
60 |
61 | .. code-block:: python
62 |
63 | # terminal states
64 | terminals = np.zeros(1000)
65 |
66 | # timeout states
67 | timeouts = np.random.randint(2, size=1000)
68 |
69 | dataset = d3rlpy.dataset.MDPDataset(
70 | observations=observations,
71 | actions=actions,
72 | rewards=rewards,
73 | terminals=terminals,
74 | timeouts=timeouts,
75 | )
76 |
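Once built, the dataset can be inspected before training; a small sketch that assumes each episode exposes its rewards array (episode boundaries follow the terminal and timeout flags above):

    # number of episodes obtained by splitting at terminal/timeout steps
    print(len(dataset.episodes))

    # total number of logged steps across all episodes
    print(sum(len(episode.rewards) for episode in dataset.episodes))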
--------------------------------------------------------------------------------
/docs/tutorials/finetuning.rst:
--------------------------------------------------------------------------------
1 | **********
2 | Finetuning
3 | **********
4 |
5 | d3rlpy supports a smooth transition from offline training to online training.
6 |
7 | Prepare Dataset and Environment
8 | -------------------------------
9 |
10 | In this tutorial, let's use a built-in dataset for the CartPole-v0 environment.
11 |
12 | .. code-block:: python
13 |
14 | import d3rlpy
15 |
16 | # setup random CartPole-v0 dataset and environment
17 | dataset, env = d3rlpy.datasets.get_dataset("cartpole-random")
18 |
19 | Pretrain with Dataset
20 | ---------------------
21 |
22 | .. code-block:: python
23 |
24 | # setup algorithm
25 | dqn = d3rlpy.algos.DQNConfig().create()
26 |
27 | # start offline training
28 | dqn.fit(dataset, n_steps=100000)
29 |
30 | Finetune with Environment
31 | -------------------------
32 |
33 | .. code-block:: python
34 |
35 | # setup experience replay buffer
36 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=100000, env=env)
37 |
38 | # setup exploration strategy if necessary
39 | explorer = d3rlpy.algos.ConstantEpsilonGreedy(0.1)
40 |
41 | # start finetuning
42 | dqn.fit_online(env, buffer, explorer, n_steps=100000)
43 |
44 | Finetune with Saved Policy
45 | --------------------------
46 |
47 | If you want to finetune the saved policy, that's also easy to do with d3rlpy.
48 |
49 | .. code-block:: python
50 |
51 | # setup algorithm
52 | dqn = d3rlpy.load_learnable("dqn_model.d3")
53 |
54 | # start finetuning
55 | dqn.fit_online(env, buffer, explorer, n_steps=100000)
56 |
57 | Finetune with Different Algorithm
58 | ---------------------------------
59 |
60 | If you want to finetune a policy trained offline using a different online RL
61 | algorithm, that also works out of the box.
62 |
63 | .. code-block:: python
64 |
65 | # setup offline RL algorithm
66 | cql = d3rlpy.algos.DiscreteCQLConfig().create()
67 |
68 | # train offline
69 | cql.fit(dataset, n_steps=100000)
70 |
71 | # transfer to DQN
72 | dqn = d3rlpy.algos.DQNConfig().create()
73 | dqn.build_with_env(env)
74 | dqn.copy_q_function_from(cql)
75 |
76 | # start finetuning
77 | dqn.fit_online(env, buffer, explorer, n_steps=100000)
78 |
79 | In actor-critic cases, you should also transfer the policy function.
80 |
81 | .. code-block:: python
82 |
83 | # offline RL
84 | cql = d3rlpy.algos.CQLConfig().create()
85 | cql.fit(dataset, n_steps=100000)
86 |
87 | # transfer to SAC
88 | sac = d3rlpy.algos.SACConfig().create()
89 | sac.build_with_env(env)
90 | sac.copy_q_function_from(cql)
91 | sac.copy_policy_from(cql)
92 |
93 | # online RL
94 | sac.fit_online(env, buffer, n_steps=100000)
95 |
--------------------------------------------------------------------------------
/docs/tutorials/index.rst:
--------------------------------------------------------------------------------
1 | *********
2 | Tutorials
3 | *********
4 |
5 | .. toctree::
6 | :maxdepth: 2
7 |
8 | getting_started
9 | data_collection
10 | create_your_dataset
11 | preprocess_and_postprocess
12 | customize_neural_network
13 | online_rl
14 | finetuning
15 | offline_policy_selection
16 | use_distributional_q_function
17 | after_training_policies
18 |
--------------------------------------------------------------------------------
/docs/tutorials/online_rl.rst:
--------------------------------------------------------------------------------
1 | *********
2 | Online RL
3 | *********
4 |
5 | Prepare Environment
6 | -------------------
7 |
8 | d3rlpy supports environments with the OpenAI Gym interface.
9 | In this tutorial, let's use the simple CartPole environment.
10 |
11 | .. code-block:: python
12 |
13 | import gym
14 |
15 | # for training
16 | env = gym.make("CartPole-v1")
17 |
18 | # for evaluation
19 | eval_env = gym.make("CartPole-v1")
20 |
21 | Setup Algorithm
22 | ---------------
23 |
24 | Just like offline RL training, you can set up an algorithm object.
25 |
26 | .. code-block:: python
27 |
28 | import d3rlpy
29 |
30 | # if you don't use a GPU, omit the device argument below.
31 | dqn = d3rlpy.algos.DQNConfig(
32 | batch_size=32,
33 | learning_rate=2.5e-4,
34 | target_update_interval=100,
35 | ).create(device="cuda:0")
36 |
37 | # initialize neural networks with the given environment object.
38 | # this is not necessary when you directly call fit or fit_online method.
39 | dqn.build_with_env(env)
40 |
41 |
42 | Setup Online RL Utilities
43 | -------------------------
44 |
45 | Unlike offline RL training, you'll need to set up an experience replay buffer and
46 | an exploration strategy.
47 |
48 | .. code-block:: python
49 |
50 | # experience replay buffer
51 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=100000, env=env)
52 |
53 | # exploration strategy
54 | # in this tutorial, epsilon-greedy policy with static epsilon=0.3
55 | explorer = d3rlpy.algos.ConstantEpsilonGreedy(0.3)
56 |
57 |
58 | Start Training
59 | --------------
60 |
61 | Now, you have everything you need to start online RL training.
62 | Let's put them together!
63 |
64 | .. code-block:: python
65 |
66 | dqn.fit_online(
67 | env,
68 | buffer,
69 | explorer,
70 | n_steps=100000, # train for 100K steps
71 | eval_env=eval_env,
72 | n_steps_per_epoch=1000, # evaluation is performed every 1K steps
73 | update_start_step=1000, # parameter update starts after 1K steps
74 | )
75 |
76 | Train with Stochastic Policy
77 | ----------------------------
78 |
79 | If the algorithm uses a stochastic policy (e.g. SAC), you can train it
80 | without setting an exploration strategy.
81 |
82 | .. code-block:: python
83 |
84 | sac = d3rlpy.algos.DiscreteSACConfig().create()
85 | sac.fit_online(
86 | env,
87 | buffer,
88 | n_steps=100000,
89 | eval_env=eval_env,
90 | n_steps_per_epoch=1000,
91 | update_start_step=1000,
92 | )
93 |
--------------------------------------------------------------------------------
/docs/tutorials/use_distributional_q_function.rst:
--------------------------------------------------------------------------------
1 | *****************************
2 | Use Distributional Q-Function
3 | *****************************
4 |
5 | One of the unique features of d3rlpy is the ability to use distributional
6 | Q-functions with arbitrary d3rlpy algorithms.
7 | Distributional Q-functions are powerful and potentially capable of
8 | improving the performance of any algorithm.
9 | In this tutorial, you can learn how to use them.
10 | Check :doc:`../references/q_functions` for more information.
11 |
12 | .. code-block:: python
13 |
14 | # default standard Q-function
15 | mean_q_function = d3rlpy.models.MeanQFunctionFactory()
16 | sac = d3rlpy.algos.SACConfig(q_func_factory=mean_q_function).create()
17 |
18 | # Quantile Regression Q-function
19 | qr_q_function = d3rlpy.models.QRQFunctionFactory(n_quantiles=200)
20 | sac = d3rlpy.algos.SACConfig(q_func_factory=qr_q_function).create()
21 |
22 | # Implicit Quantile Network Q-function
23 | iqn_q_function = d3rlpy.models.IQNQFunctionFactory(
24 | n_quantiles=32,
25 | n_greedy_quantiles=64,
26 | embed_size=64,
27 | )
28 | sac = d3rlpy.algos.SACConfig(q_func_factory=iqn_q_function).create()
29 |
--------------------------------------------------------------------------------
/examples/deepmind_control.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gymnasium
4 | from gymnasium.wrappers import FlattenObservation
5 |
6 | import d3rlpy
7 |
8 | # You need to install DeepMind Control (DMC) and Shimmy beforehand as follows:
9 | #
10 | # $ d3rlpy install dm_control
11 | #
12 | # After you install the packages, d3rlpy internally registers DMC environments
13 | # for Gymnasium.
14 |
15 |
16 | def main() -> None:
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument("--env", type=str, default="hopper-stand")
19 | parser.add_argument("--seed", type=int, default=1)
20 | parser.add_argument("--gpu", action="store_true")
21 | args = parser.parse_args()
22 |
23 | env_id = f"dm_control/{args.env}"
24 | env = FlattenObservation(gymnasium.make(env_id)) # type: ignore
25 | eval_env = FlattenObservation(gymnasium.make(env_id)) # type: ignore
26 |
27 | # fix seed
28 | d3rlpy.seed(args.seed)
29 | d3rlpy.envs.seed_env(env, args.seed)
30 | d3rlpy.envs.seed_env(eval_env, args.seed)
31 |
32 | # setup algorithm
33 | sac = d3rlpy.algos.SACConfig(
34 | batch_size=256,
35 | actor_learning_rate=3e-4,
36 | critic_learning_rate=3e-4,
37 | temp_learning_rate=3e-4,
38 | ).create(device=args.gpu)
39 |
40 | # replay buffer for experience replay
41 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=1000000, env=env)
42 |
43 | # start training
44 | sac.fit_online(
45 | env,
46 | buffer,
47 | eval_env=eval_env,
48 | n_steps=1000000,
49 | n_steps_per_epoch=10000,
50 | update_interval=1,
51 | update_start_step=1000,
52 | )
53 |
54 |
55 | if __name__ == "__main__":
56 | main()
57 |
--------------------------------------------------------------------------------
/examples/distributed_offline_training.py:
--------------------------------------------------------------------------------
1 | import d3rlpy
2 |
3 | # This script needs to be launched by using torchrun command.
4 | # $ torchrun \
5 | # --nnodes=1 \
6 | # --nproc_per_node=3 \
7 | # --rdzv_id=100 \
8 | # --rdzv_backend=c10d \
9 | # --rdzv_endpoint=localhost:29400 \
10 | # examples/distributed_offline_training.py
11 |
12 |
13 | def main() -> None:
14 | # GPU version:
15 | # rank = d3rlpy.distributed.init_process_group("nccl")
16 | rank = d3rlpy.distributed.init_process_group("gloo")
17 | print(f"Start running on rank={rank}.")
18 |
19 | # GPU version:
20 | # device = f"cuda:{rank}"
21 | device = "cpu:0"
22 |
23 | # setup algorithm
24 | cql = d3rlpy.algos.CQLConfig(
25 | actor_learning_rate=1e-3,
26 | critic_learning_rate=1e-3,
27 | alpha_learning_rate=1e-3,
28 | ).create(device=device, enable_ddp=True)
29 |
30 | # prepare dataset
31 | dataset, env = d3rlpy.datasets.get_pendulum()
32 |
33 | # disable logging on rank != 0 workers
34 | logger_adapter: d3rlpy.logging.LoggerAdapterFactory
35 | evaluators: dict[str, d3rlpy.metrics.EvaluatorProtocol]
36 | if rank == 0:
37 | evaluators = {"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}
38 | logger_adapter = d3rlpy.logging.FileAdapterFactory()
39 | else:
40 | evaluators = {}
41 | logger_adapter = d3rlpy.logging.NoopAdapterFactory()
42 |
43 | # start training
44 | cql.fit(
45 | dataset,
46 | n_steps=10000,
47 | n_steps_per_epoch=1000,
48 | evaluators=evaluators,
49 | logger_adapter=logger_adapter,
50 | show_progress=rank == 0,
51 | )
52 |
53 | d3rlpy.distributed.destroy_process_group()
54 |
55 |
56 | if __name__ == "__main__":
57 | main()
58 |
--------------------------------------------------------------------------------
/examples/fine_tuning.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 |
4 | import d3rlpy
5 |
6 |
7 | def main() -> None:
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0")
10 | parser.add_argument("--seed", type=int, default=1)
11 | parser.add_argument("--gpu", type=int)
12 | args = parser.parse_args()
13 |
14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset)
15 |
16 | # fix seed
17 | d3rlpy.seed(args.seed)
18 | d3rlpy.envs.seed_env(env, args.seed)
19 |
20 | cql = d3rlpy.algos.CQLConfig(
21 | actor_learning_rate=1e-4,
22 | critic_learning_rate=3e-4,
23 | temp_learning_rate=1e-4,
24 | batch_size=256,
25 | n_action_samples=10,
26 | alpha_learning_rate=0.0,
27 | conservative_weight=10.0,
28 | ).create(device=args.gpu)
29 |
30 | # pretraining
31 | cql.fit(
32 | dataset,
33 | n_steps=100000,
34 | n_steps_per_epoch=1000,
35 | save_interval=10,
36 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
37 | experiment_name=f"CQL_pretraining_{args.dataset}_{args.seed}",
38 | )
39 |
40 | sac = d3rlpy.algos.SACConfig(
41 | actor_learning_rate=3e-4,
42 | critic_learning_rate=3e-4,
43 | temp_learning_rate=3e-4,
44 | batch_size=256,
45 | ).create(device=args.gpu)
46 |
47 | # copy pretrained models to SAC
48 | sac.build_with_env(env)
49 | sac.copy_policy_from(cql) # type: ignore
50 | sac.copy_q_function_from(cql) # type: ignore
51 |
52 | # prepare FIFO buffer filled with dataset episodes
53 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(
54 | limit=100000,
55 | episodes=dataset.episodes,
56 | )
57 |
58 | # finetuning
59 | eval_env = copy.deepcopy(env)
60 | d3rlpy.envs.seed_env(eval_env, args.seed)
61 | sac.fit_online(
62 | env,
63 | buffer=buffer,
64 | eval_env=eval_env,
65 | experiment_name=f"SAC_finetuning_{args.dataset}_{args.seed}",
66 | n_steps=100000,
67 | n_steps_per_epoch=1000,
68 | save_interval=10,
69 | )
70 |
71 |
72 | if __name__ == "__main__":
73 | main()
74 |
--------------------------------------------------------------------------------
/examples/frame_stacking.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gym
4 |
5 | import d3rlpy
6 |
7 |
8 | def main() -> None:
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--env", type=str, default="PongNoFrameskip-v4")
11 | parser.add_argument("--seed", type=int, default=1)
12 | parser.add_argument("--gpu", action="store_true")
13 | args = parser.parse_args()
14 |
15 | # get wrapped atari environment with 4 frame stacking
16 | # observation shape is [4, 84, 84]
17 | env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4)
18 | eval_env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4, is_eval=True)
19 |
20 | # fix seed
21 | d3rlpy.seed(args.seed)
22 | d3rlpy.envs.seed_env(env, args.seed)
23 | d3rlpy.envs.seed_env(eval_env, args.seed)
24 |
25 | # setup algorithm
26 | dqn = d3rlpy.algos.DQNConfig(
27 | batch_size=32,
28 | learning_rate=2.5e-4,
29 | optim_factory=d3rlpy.optimizers.RMSpropFactory(),
30 | target_update_interval=10000 // 4,
31 | observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(),
32 | ).create(device=args.gpu)
33 |
34 | # replay buffer for experience replay
35 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(
36 | limit=1000000,
37 | # stack last 4 frames (stacked shape is [4, 84, 84])
38 | transition_picker=d3rlpy.dataset.FrameStackTransitionPicker(n_frames=4),
39 | # store only last frame to save memory (stored shape is [1, 84, 84])
40 | writer_preprocessor=d3rlpy.dataset.LastFrameWriterPreprocess(),
41 | env=env,
42 | )
43 |
44 | # epsilon-greedy explorer
45 | explorer = d3rlpy.algos.LinearDecayEpsilonGreedy(
46 | start_epsilon=1.0, end_epsilon=0.1, duration=1000000
47 | )
48 |
49 | # start training
50 | dqn.fit_online(
51 | env,
52 | buffer,
53 | explorer,
54 | eval_env=eval_env,
55 | eval_epsilon=0.01,
56 | n_steps=1000000,
57 | n_steps_per_epoch=100000,
58 | update_interval=4,
59 | update_start_step=50000,
60 | )
61 |
62 |
63 | if __name__ == "__main__":
64 | main()
65 |
--------------------------------------------------------------------------------
/examples/gymnasium_env.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gymnasium
4 |
5 | import d3rlpy
6 |
7 |
8 | def main() -> None:
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--env", type=str, default="Hopper-v2")
11 | parser.add_argument("--seed", type=int, default=1)
12 | parser.add_argument("--gpu", action="store_true")
13 | args = parser.parse_args()
14 |
15 | # d3rlpy supports both Gym and Gymnasium
16 | env = gymnasium.make(args.env)
17 | eval_env = gymnasium.make(args.env)
18 |
19 | # fix seed
20 | d3rlpy.seed(args.seed)
21 | d3rlpy.envs.seed_env(env, args.seed)
22 | d3rlpy.envs.seed_env(eval_env, args.seed)
23 |
24 | # setup algorithm
25 | sac = d3rlpy.algos.SACConfig(
26 | batch_size=256,
27 | actor_learning_rate=3e-4,
28 | critic_learning_rate=3e-4,
29 | temp_learning_rate=3e-4,
30 | ).create(device=args.gpu)
31 |
32 | # replay buffer for experience replay
33 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=1000000, env=env)
34 |
35 | # start training
36 | sac.fit_online(
37 | env,
38 | buffer,
39 | eval_env=eval_env,
40 | n_steps=1000000,
41 | n_steps_per_epoch=10000,
42 | update_interval=1,
43 | update_start_step=1000,
44 | )
45 |
46 |
47 | if __name__ == "__main__":
48 | main()
49 |
--------------------------------------------------------------------------------
/examples/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gymnasium
4 |
5 | import d3rlpy
6 |
7 |
8 | def main() -> None:
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--env", type=str, default="Hopper-v2")
11 | parser.add_argument("--seed", type=int, default=1)
12 | parser.add_argument("--gpu", action="store_true")
13 | args = parser.parse_args()
14 |
15 | env = gymnasium.make(args.env)
16 | eval_env = gymnasium.make(args.env)
17 |
18 | # fix seed
19 | d3rlpy.seed(args.seed)
20 | d3rlpy.envs.seed_env(env, args.seed)
21 | d3rlpy.envs.seed_env(eval_env, args.seed)
22 |
23 | # setup algorithm
24 | sac = d3rlpy.algos.SACConfig(
25 | batch_size=256,
26 | actor_learning_rate=3e-4,
27 | critic_learning_rate=3e-4,
28 | actor_optim_factory=d3rlpy.optimizers.AdamFactory(
29 | # setup learning rate scheduler
30 | lr_scheduler_factory=d3rlpy.optimizers.WarmupSchedulerFactory(
31 | warmup_steps=10000
32 | ),
33 | ),
34 | critic_optim_factory=d3rlpy.optimizers.AdamFactory(
35 | # setup learning rate scheduler
36 | lr_scheduler_factory=d3rlpy.optimizers.WarmupSchedulerFactory(
37 | warmup_steps=10000
38 | ),
39 | ),
40 | temp_learning_rate=3e-4,
41 | ).create(device=args.gpu)
42 |
43 | # replay buffer for experience replay
44 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=1000000, env=env)
45 |
46 | # start training
47 | sac.fit_online(
48 | env,
49 | buffer,
50 | eval_env=eval_env,
51 | n_steps=1000000,
52 | n_steps_per_epoch=10000,
53 | update_interval=1,
54 | update_start_step=1000,
55 | )
56 |
57 |
58 | if __name__ == "__main__":
59 | main()
60 |
--------------------------------------------------------------------------------
/examples/multi_step_training.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gym
4 |
5 | import d3rlpy
6 |
7 | GAMMA = 0.99
8 |
9 |
10 | def main() -> None:
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("--env", type=str, default="Pendulum-v1")
13 | parser.add_argument("--seed", type=int, default=1)
14 | parser.add_argument("--n-steps", type=int, default=1)
15 | parser.add_argument("--gpu", action="store_true")
16 | args = parser.parse_args()
17 |
18 | env = gym.make(args.env)
19 | eval_env = gym.make(args.env)
20 |
21 | # fix seed
22 | d3rlpy.seed(args.seed)
23 | d3rlpy.envs.seed_env(env, args.seed)
24 | d3rlpy.envs.seed_env(eval_env, args.seed)
25 |
26 | # setup algorithm
27 | sac = d3rlpy.algos.SACConfig(
28 | batch_size=256,
29 | gamma=GAMMA,
30 | actor_learning_rate=3e-4,
31 | critic_learning_rate=3e-4,
32 | temp_learning_rate=3e-4,
33 | action_scaler=d3rlpy.preprocessing.MinMaxActionScaler(),
34 | ).create(device=args.gpu)
35 |
36 | # multi-step transition sampling
37 | transition_picker = d3rlpy.dataset.MultiStepTransitionPicker(
38 | n_steps=args.n_steps,
39 | gamma=GAMMA,
40 | )
41 |
42 | # replay buffer for experience replay
43 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(
44 | limit=100000,
45 | env=env,
46 | transition_picker=transition_picker,
47 | )
48 |
49 | # start training
50 | sac.fit_online(
51 | env,
52 | buffer,
53 | eval_env=eval_env,
54 | n_steps=100000,
55 | n_steps_per_epoch=1000,
56 | update_interval=1,
57 | update_start_step=1000,
58 | )
59 |
60 |
61 | if __name__ == "__main__":
62 | main()
63 |
--------------------------------------------------------------------------------
/examples/preprocessors.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gym
4 |
5 | import d3rlpy
6 |
7 |
8 | def main() -> None:
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--env", type=str, default="Pendulum-v1")
11 | parser.add_argument("--seed", type=int, default=1)
12 | parser.add_argument("--gpu", action="store_true")
13 | args = parser.parse_args()
14 |
15 | env = gym.make(args.env)
16 | eval_env = gym.make(args.env)
17 |
18 | # fix seed
19 | d3rlpy.seed(args.seed)
20 | d3rlpy.envs.seed_env(env, args.seed)
21 | d3rlpy.envs.seed_env(eval_env, args.seed)
22 |
23 | # setup algorithm
24 | sac = d3rlpy.algos.SACConfig(
25 | batch_size=256,
26 | actor_learning_rate=3e-4,
27 | critic_learning_rate=3e-4,
28 | temp_learning_rate=3e-4,
29 | # normalizes observations within [-1, 1] range
30 | observation_scaler=d3rlpy.preprocessing.MinMaxObservationScaler(),
31 | # normalizes actions within [-1, 1] range
32 | action_scaler=d3rlpy.preprocessing.MinMaxActionScaler(),
33 | # multiply rewards by 0.1
34 | reward_scaler=d3rlpy.preprocessing.MultiplyRewardScaler(0.1),
35 | ).create(device=args.gpu)
36 |
37 | # replay buffer for experience replay
38 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(
39 | limit=100000,
40 | env=env,
41 | )
42 |
43 | # start training
44 | sac.fit_online(
45 | env,
46 | buffer,
47 | eval_env=eval_env,
48 | n_steps=100000,
49 | n_steps_per_epoch=1000,
50 | update_interval=1,
51 | update_start_step=1000,
52 | )
53 |
54 |
55 | if __name__ == "__main__":
56 | main()
57 |
--------------------------------------------------------------------------------
/mypy.ini:
--------------------------------------------------------------------------------
1 | [mypy]
2 | python_version = 3.9
3 | strict = True
4 | strict_optional = True
5 | disallow_untyped_defs = True
6 | disallow_incomplete_defs = True
7 | disallow_untyped_decorators = True
8 | no_implicit_optional = True
9 | warn_redundant_casts = True
10 | warn_unused_ignores = True
11 | warn_return_any = True
12 | warn_unused_configs = True
13 | plugins = numpy.typing.mypy_plugin
14 |
15 | [mypy-torch.*]
16 | ignore_missing_imports = True
17 | follow_imports = skip
18 | follow_imports_for_stubs = True
19 |
20 | [mypy-tqdm.*]
21 | ignore_missing_imports = True
22 |
23 | [mypy-tensorboardX]
24 | ignore_missing_imports = True
25 |
26 | [mypy-wandb]
27 | ignore_missing_imports = True
28 | follow_imports = skip
29 | follow_imports_for_stubs = True
30 |
31 | [mypy-matplotlib.*]
32 | ignore_missing_imports = True
33 |
34 | [mypy-seaborn.*]
35 | ignore_missing_imports = True
36 |
37 | [mypy-cv2.*]
38 | ignore_missing_imports = True
39 |
40 | [mypy-h5py.*]
41 | ignore_missing_imports = True
42 |
43 | [mypy-dataclasses_json.*]
44 | ignore_missing_imports = True
45 |
46 | [mypy-onnxruntime.*]
47 | ignore_missing_imports = True
48 |
49 | [mypy-click.*]
50 | ignore_missing_imports = True
51 | follow_imports = skip
52 | follow_imports_for_stubs = True
53 |
54 | [mypy-pyvirtualdisplay.*]
55 | ignore_missing_imports = True
56 |
57 | [mypy-IPython.*]
58 | ignore_missing_imports = True
59 | follow_imports = skip
60 | follow_imports_for_stubs = True
61 |
62 | [mypy-minari.*]
63 | ignore_missing_imports = True
64 |
65 | [mypy-d4rl.*]
66 | ignore_missing_imports = True
67 | follow_imports = skip
68 | follow_imports_for_stubs = True
69 |
70 | [mypy-shimmy.*]
71 | ignore_missing_imports = True
72 | follow_imports = skip
73 | follow_imports_for_stubs = True
74 |
75 | [mypy-sklearn.*]
76 | ignore_missing_imports = True
77 |
--------------------------------------------------------------------------------
/reproductions/finetuning/awac_finetune.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 |
4 | import d3rlpy
5 |
6 |
7 | def main() -> None:
8 | parser = argparse.ArgumentParser()
9 | parser.add_argument("--dataset", type=str, default="D4RL/antmaze/umaze-v1")
10 | parser.add_argument("--seed", type=int, default=1)
11 | parser.add_argument("--gpu", type=int)
12 | parser.add_argument("--compile", action="store_true")
13 | args = parser.parse_args()
14 |
15 | dataset, env = d3rlpy.datasets.get_minari(args.dataset)
16 |
17 | # fix seed
18 | d3rlpy.seed(args.seed)
19 | d3rlpy.envs.seed_env(env, args.seed)
20 |
21 | encoder = d3rlpy.models.encoders.VectorEncoderFactory([256, 256, 256, 256])
22 | optim = d3rlpy.optimizers.AdamFactory(weight_decay=1e-4)
23 | # for antmaze datasets
24 | reward_scaler = d3rlpy.preprocessing.ConstantShiftRewardScaler(shift=-1)
25 |
26 | awac = d3rlpy.algos.AWACConfig(
27 | actor_learning_rate=3e-4,
28 | actor_encoder_factory=encoder,
29 | actor_optim_factory=optim,
30 | critic_learning_rate=3e-4,
31 | batch_size=1024,
32 | lam=1.0,
33 | reward_scaler=reward_scaler,
34 | compile_graph=args.compile,
35 | ).create(device=args.gpu)
36 |
37 | awac.fit(
38 | dataset,
39 | n_steps=25000,
40 | n_steps_per_epoch=5000,
41 | save_interval=10,
42 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
43 | experiment_name=f"AWAC_pretraining_{args.dataset}_{args.seed}",
44 | )
45 |
46 | # prepare FIFO buffer filled with dataset episodes
47 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(1000000)
48 | for episode in dataset.episodes:
49 | buffer.append_episode(episode)
50 |
51 | # finetuning
52 | eval_env = copy.deepcopy(env)
53 | d3rlpy.envs.seed_env(eval_env, args.seed)
54 | awac.fit_online(
55 | env,
56 | buffer=buffer,
57 | eval_env=eval_env,
58 | experiment_name=f"AWAC_finetuning_{args.dataset}_{args.seed}",
59 | n_steps=1000000,
60 | n_steps_per_epoch=1000,
61 | save_interval=10,
62 | )
63 |
64 |
65 | if __name__ == "__main__":
66 | main()
67 |
--------------------------------------------------------------------------------
/reproductions/finetuning/iql_finetune.py:
--------------------------------------------------------------------------------
1 | # pylint: disable=protected-access
2 | import argparse
3 | import copy
4 |
5 | import d3rlpy
6 |
7 |
8 | def main() -> None:
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--dataset", type=str, default="D4RL/antmaze/umaze-v1")
11 | parser.add_argument("--seed", type=int, default=1)
12 | parser.add_argument("--gpu", type=int)
13 | parser.add_argument("--compile", action="store_true")
14 | args = parser.parse_args()
15 |
16 | dataset, env = d3rlpy.datasets.get_minari(args.dataset)
17 |
18 | # fix seed
19 | d3rlpy.seed(args.seed)
20 | d3rlpy.envs.seed_env(env, args.seed)
21 |
22 | # for antmaze datasets
23 | reward_scaler = d3rlpy.preprocessing.ConstantShiftRewardScaler(shift=-1)
24 |
25 | iql = d3rlpy.algos.IQLConfig(
26 | actor_learning_rate=3e-4,
27 | critic_learning_rate=3e-4,
28 | actor_optim_factory=d3rlpy.optimizers.AdamFactory(
29 | lr_scheduler_factory=d3rlpy.optimizers.CosineAnnealingLRFactory(
30 | T_max=1000000
31 | ),
32 | ),
33 | batch_size=256,
34 | weight_temp=10.0, # hyperparameter for antmaze
35 | max_weight=100.0,
36 | expectile=0.9, # hyperparameter for antmaze
37 | reward_scaler=reward_scaler,
38 | compile_graph=args.compile,
39 | ).create(device=args.gpu)
40 |
41 | # pretraining
42 | iql.fit(
43 | dataset,
44 | n_steps=1000000,
45 | n_steps_per_epoch=100000,
46 | save_interval=10,
47 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
48 | experiment_name=f"IQL_pretraining_{args.dataset}_{args.seed}",
49 | )
50 |
51 | # reset learning rate
52 | assert iql.impl
53 | for g in iql.impl._modules.actor_optim.optim.param_groups:
54 | g["lr"] = iql.config.actor_learning_rate
55 |
56 | # prepare FIFO buffer filled with dataset episodes
57 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(1000000)
58 | for episode in dataset.episodes:
59 | buffer.append_episode(episode)
60 |
61 | # finetuning
62 | eval_env = copy.deepcopy(env)
63 | d3rlpy.envs.seed_env(eval_env, args.seed)
64 | iql.fit_online(
65 | env,
66 | buffer=buffer,
67 | eval_env=eval_env,
68 | experiment_name=f"IQL_finetuning_{args.dataset}_{args.seed}",
69 | n_steps=1000000,
70 | n_steps_per_epoch=1000,
71 | save_interval=10,
72 | )
73 |
74 |
75 | if __name__ == "__main__":
76 | main()
77 |
--------------------------------------------------------------------------------
/reproductions/offline/awac.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 |
6 | def main() -> None:
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0")
9 | parser.add_argument("--seed", type=int, default=1)
10 | parser.add_argument("--gpu", type=int)
11 | parser.add_argument("--compile", action="store_true")
12 | args = parser.parse_args()
13 |
14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset)
15 |
16 | # fix seed
17 | d3rlpy.seed(args.seed)
18 | d3rlpy.envs.seed_env(env, args.seed)
19 |
20 | encoder = d3rlpy.models.encoders.VectorEncoderFactory([256, 256, 256, 256])
21 | optim = d3rlpy.optimizers.AdamFactory(weight_decay=1e-4)
22 |
23 | awac = d3rlpy.algos.AWACConfig(
24 | actor_learning_rate=3e-4,
25 | actor_encoder_factory=encoder,
26 | actor_optim_factory=optim,
27 | critic_learning_rate=3e-4,
28 | critic_encoder_factory=encoder,
29 | batch_size=1024,
30 | lam=1.0,
31 | compile_graph=args.compile,
32 | ).create(args.gpu)
33 |
34 | awac.fit(
35 | dataset,
36 | n_steps=500000,
37 | n_steps_per_epoch=1000,
38 | save_interval=10,
39 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
40 | experiment_name=f"AWAC_{args.dataset}_{args.seed}",
41 | )
42 |
43 |
44 | if __name__ == "__main__":
45 | main()
46 |
--------------------------------------------------------------------------------
/reproductions/offline/bcq.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 |
6 | def main() -> None:
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0")
9 | parser.add_argument("--seed", type=int, default=1)
10 | parser.add_argument("--gpu", type=int)
11 | parser.add_argument("--compile", action="store_true")
12 | args = parser.parse_args()
13 |
14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset)
15 |
16 | # fix seed
17 | d3rlpy.seed(args.seed)
18 | d3rlpy.envs.seed_env(env, args.seed)
19 |
20 | vae_encoder = d3rlpy.models.encoders.VectorEncoderFactory([750, 750])
21 | rl_encoder = d3rlpy.models.encoders.VectorEncoderFactory([400, 300])
22 |
23 | bcq = d3rlpy.algos.BCQConfig(
24 | actor_encoder_factory=rl_encoder,
25 | actor_learning_rate=1e-3,
26 | critic_encoder_factory=rl_encoder,
27 | critic_learning_rate=1e-3,
28 | imitator_encoder_factory=vae_encoder,
29 | imitator_learning_rate=1e-3,
30 | batch_size=100,
31 | lam=0.75,
32 | action_flexibility=0.05,
33 | n_action_samples=100,
34 | compile_graph=args.compile,
35 | ).create(args.gpu)
36 |
37 | bcq.fit(
38 | dataset,
39 | n_steps=500000,
40 | n_steps_per_epoch=1000,
41 | save_interval=10,
42 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
43 | experiment_name=f"BCQ_{args.dataset}_{args.seed}",
44 | )
45 |
46 |
47 | if __name__ == "__main__":
48 | main()
49 |
--------------------------------------------------------------------------------
/reproductions/offline/bear.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 |
6 | def main() -> None:
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0")
9 | parser.add_argument("--seed", type=int, default=1)
10 | parser.add_argument("--gpu", type=int)
11 | parser.add_argument("--compile", action="store_true")
12 | args = parser.parse_args()
13 |
14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset)
15 |
16 | # fix seed
17 | d3rlpy.seed(args.seed)
18 | d3rlpy.envs.seed_env(env, args.seed)
19 |
20 | vae_encoder = d3rlpy.models.encoders.VectorEncoderFactory([750, 750])
21 |
22 | if "halfcheetah" in args.dataset:
23 | kernel = "gaussian"
24 | else:
25 | kernel = "laplacian"
26 |
27 | bear = d3rlpy.algos.BEARConfig(
28 | actor_learning_rate=1e-4,
29 | critic_learning_rate=3e-4,
30 | imitator_learning_rate=3e-4,
31 | alpha_learning_rate=1e-3,
32 | imitator_encoder_factory=vae_encoder,
33 | temp_learning_rate=0.0,
34 | initial_temperature=1e-20,
35 | batch_size=256,
36 | mmd_sigma=20.0,
37 | mmd_kernel=kernel,
38 | n_mmd_action_samples=4,
39 | alpha_threshold=0.05,
40 | n_target_samples=10,
41 | n_action_samples=100,
42 | warmup_steps=40000,
43 | compile_graph=args.compile,
44 | ).create(device=args.gpu)
45 |
46 | bear.fit(
47 | dataset,
48 | n_steps=500000,
49 | n_steps_per_epoch=1000,
50 | save_interval=10,
51 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
52 | experiment_name=f"BEAR_{args.dataset}_{args.seed}",
53 | )
54 |
55 |
56 | if __name__ == "__main__":
57 | main()
58 |
--------------------------------------------------------------------------------
/reproductions/offline/cql.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 |
6 | def main() -> None:
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0")
9 | parser.add_argument("--seed", type=int, default=1)
10 | parser.add_argument("--gpu", type=int)
11 | parser.add_argument("--compile", action="store_true")
12 | args = parser.parse_args()
13 |
14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset)
15 |
16 | # fix seed
17 | d3rlpy.seed(args.seed)
18 | d3rlpy.envs.seed_env(env, args.seed)
19 |
20 | encoder = d3rlpy.models.encoders.VectorEncoderFactory([256, 256, 256])
21 |
22 | if "medium-v0" in args.dataset:
23 | conservative_weight = 10.0
24 | else:
25 | conservative_weight = 5.0
26 |
27 | cql = d3rlpy.algos.CQLConfig(
28 | actor_learning_rate=1e-4,
29 | critic_learning_rate=3e-4,
30 | temp_learning_rate=1e-4,
31 | actor_encoder_factory=encoder,
32 | critic_encoder_factory=encoder,
33 | batch_size=256,
34 | n_action_samples=10,
35 | alpha_learning_rate=0.0,
36 | conservative_weight=conservative_weight,
37 | compile_graph=args.compile,
38 | ).create(device=args.gpu)
39 |
40 | cql.fit(
41 | dataset,
42 | n_steps=500000,
43 | n_steps_per_epoch=1000,
44 | save_interval=10,
45 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
46 | experiment_name=f"CQL_{args.dataset}_{args.seed}",
47 | )
48 |
49 |
50 | if __name__ == "__main__":
51 | main()
52 |
--------------------------------------------------------------------------------
/reproductions/offline/crr.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 |
6 | def main() -> None:
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0")
9 | parser.add_argument("--seed", type=int, default=1)
10 | parser.add_argument("--gpu", type=int)
11 | parser.add_argument("--compile", action="store_true")
12 | args = parser.parse_args()
13 |
14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset)
15 |
16 | # fix seed
17 | d3rlpy.seed(args.seed)
18 | d3rlpy.envs.seed_env(env, args.seed)
19 |
20 | crr = d3rlpy.algos.CRRConfig(
21 | actor_learning_rate=3e-4,
22 | critic_learning_rate=3e-4,
23 | batch_size=256,
24 | weight_type="binary",
25 | advantage_type="mean",
26 | target_update_type="soft",
27 | compile_graph=args.compile,
28 | ).create(device=args.gpu)
29 |
30 | crr.fit(
31 | dataset,
32 | n_steps=500000,
33 | n_steps_per_epoch=1000,
34 | save_interval=10,
35 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
36 | experiment_name=f"CRR_{args.dataset}_{args.seed}",
37 | )
38 |
39 |
40 | if __name__ == "__main__":
41 | main()
42 |
--------------------------------------------------------------------------------
/reproductions/offline/decision_transformer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 |
6 | def main() -> None:
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0")
9 | parser.add_argument("--seed", type=int, default=1)
10 | parser.add_argument("--gpu", type=int)
11 | parser.add_argument("--compile", action="store_true")
12 | args = parser.parse_args()
13 |
14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset)
15 |
16 | # fix seed
17 | d3rlpy.seed(args.seed)
18 | d3rlpy.envs.seed_env(env, args.seed)
19 |
20 | if "halfcheetah" in args.dataset:
21 | target_return = 6000
22 | elif "hopper" in args.dataset:
23 | target_return = 3600
24 | elif "walker" in args.dataset:
25 | target_return = 5000
26 | else:
27 | raise ValueError("unsupported dataset")
28 |
29 | dt = d3rlpy.algos.DecisionTransformerConfig(
30 | batch_size=64,
31 | learning_rate=1e-4,
32 | optim_factory=d3rlpy.optimizers.AdamWFactory(
33 | weight_decay=1e-4,
34 | clip_grad_norm=0.25,
35 | lr_scheduler_factory=d3rlpy.optimizers.WarmupSchedulerFactory(
36 | warmup_steps=10000
37 | ),
38 | ),
39 | encoder_factory=d3rlpy.models.VectorEncoderFactory(
40 | [128],
41 | exclude_last_activation=True,
42 | ),
43 | observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(),
44 | reward_scaler=d3rlpy.preprocessing.MultiplyRewardScaler(0.001),
45 | position_encoding_type=d3rlpy.PositionEncodingType.SIMPLE,
46 | context_size=20,
47 | num_heads=1,
48 | num_layers=3,
49 | max_timestep=1000,
50 | compile_graph=args.compile,
51 | ).create(device=args.gpu)
52 |
53 | dt.fit(
54 | dataset,
55 | n_steps=100000,
56 | n_steps_per_epoch=1000,
57 | save_interval=10,
58 | eval_env=env,
59 | eval_target_return=target_return,
60 | experiment_name=f"DT_{args.dataset}_{args.seed}",
61 | )
62 |
63 |
64 | if __name__ == "__main__":
65 | main()
66 |
--------------------------------------------------------------------------------
/reproductions/offline/discrete_bcq.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 |
6 | def main() -> None:
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--game", type=str, default="breakout")
9 | parser.add_argument("--seed", type=int, default=1)
10 | parser.add_argument("--gpu", type=int)
11 | parser.add_argument("--compile", action="store_true")
12 | args = parser.parse_args()
13 |
14 | # fix seed
15 | d3rlpy.seed(args.seed)
16 |
17 | dataset, env = d3rlpy.datasets.get_atari_transitions(
18 | args.game,
19 | fraction=0.01,
20 | index=1 if args.game == "asterix" else 0,
21 | num_stack=4,
22 | )
23 |
24 | d3rlpy.envs.seed_env(env, args.seed)
25 |
26 | bcq = d3rlpy.algos.DiscreteBCQConfig(
27 | learning_rate=5e-5,
28 | optim_factory=d3rlpy.optimizers.AdamFactory(eps=1e-2 / 32),
29 | batch_size=32,
30 | q_func_factory=d3rlpy.models.q_functions.QRQFunctionFactory(
31 | n_quantiles=200
32 | ),
33 | observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(),
34 | target_update_interval=2000,
35 | reward_scaler=d3rlpy.preprocessing.ClipRewardScaler(-1.0, 1.0),
36 | action_flexibility=0.3,
37 | beta=0.01,
38 | compile_graph=args.compile,
39 | ).create(device=args.gpu)
40 |
41 | env_scorer = d3rlpy.metrics.EnvironmentEvaluator(env, epsilon=0.001)
42 |
43 | bcq.fit(
44 | dataset,
45 | n_steps=50000000 // 4,
46 | n_steps_per_epoch=125000,
47 | save_interval=10,
48 | evaluators={"environment": env_scorer},
49 | experiment_name=f"DiscreteBCQ_{args.game}_{args.seed}",
50 | )
51 |
52 |
53 | if __name__ == "__main__":
54 | main()
55 |
--------------------------------------------------------------------------------
/reproductions/offline/discrete_cql.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 |
6 | def main() -> None:
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--game", type=str, default="breakout")
9 | parser.add_argument("--seed", type=int, default=1)
10 | parser.add_argument("--gpu", type=int)
11 | parser.add_argument("--compile", action="store_true")
12 | args = parser.parse_args()
13 |
14 | d3rlpy.seed(args.seed)
15 |
16 | dataset, env = d3rlpy.datasets.get_atari_transitions(
17 | args.game,
18 | fraction=0.01,
19 | index=1 if args.game == "asterix" else 0,
20 | num_stack=4,
21 | )
22 |
23 | d3rlpy.envs.seed_env(env, args.seed)
24 |
25 | cql = d3rlpy.algos.DiscreteCQLConfig(
26 | learning_rate=5e-5,
27 | optim_factory=d3rlpy.optimizers.AdamFactory(eps=1e-2 / 32),
28 | batch_size=32,
29 | alpha=4.0,
30 | q_func_factory=d3rlpy.models.q_functions.QRQFunctionFactory(
31 | n_quantiles=200
32 | ),
33 | observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(),
34 | target_update_interval=2000,
35 | reward_scaler=d3rlpy.preprocessing.ClipRewardScaler(-1.0, 1.0),
36 | compile_graph=args.compile,
37 | ).create(device=args.gpu)
38 |
39 | env_scorer = d3rlpy.metrics.EnvironmentEvaluator(env, epsilon=0.001)
40 |
41 | cql.fit(
42 | dataset,
43 | n_steps=50000000 // 4,
44 | n_steps_per_epoch=125000,
45 | evaluators={"environment": env_scorer},
46 | experiment_name=f"DiscreteCQL_{args.game}_{args.seed}",
47 | )
48 |
49 |
50 | if __name__ == "__main__":
51 | main()
52 |
--------------------------------------------------------------------------------
/reproductions/offline/iql.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 |
6 | def main() -> None:
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0")
9 | parser.add_argument("--seed", type=int, default=1)
10 | parser.add_argument("--gpu", type=int)
11 | parser.add_argument("--compile", action="store_true")
12 | args = parser.parse_args()
13 |
14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset)
15 |
16 | # fix seed
17 | d3rlpy.seed(args.seed)
18 | d3rlpy.envs.seed_env(env, args.seed)
19 |
20 | reward_scaler = d3rlpy.preprocessing.ReturnBasedRewardScaler(
21 | multiplier=1000.0
22 | )
23 |
24 | iql = d3rlpy.algos.IQLConfig(
25 | actor_learning_rate=3e-4,
26 | critic_learning_rate=3e-4,
27 | actor_optim_factory=d3rlpy.optimizers.AdamFactory(
28 | lr_scheduler_factory=d3rlpy.optimizers.CosineAnnealingLRFactory(
29 | T_max=500000
30 | ),
31 | ),
32 | batch_size=256,
33 | weight_temp=3.0,
34 | max_weight=100.0,
35 | expectile=0.7,
36 | reward_scaler=reward_scaler,
37 | compile_graph=args.compile,
38 | ).create(device=args.gpu)
39 |
40 | iql.fit(
41 | dataset,
42 | n_steps=500000,
43 | n_steps_per_epoch=1000,
44 | save_interval=10,
45 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
46 | experiment_name=f"IQL_{args.dataset}_{args.seed}",
47 | )
48 |
49 |
50 | if __name__ == "__main__":
51 | main()
52 |
--------------------------------------------------------------------------------
/reproductions/offline/nfq.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 |
6 | def main() -> None:
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--game", type=str, default="breakout")
9 | parser.add_argument("--seed", type=int, default=1)
10 | parser.add_argument("--gpu", type=int)
11 | parser.add_argument("--compile", action="store_true")
12 | args = parser.parse_args()
13 |
14 | # fix seed
15 | d3rlpy.seed(args.seed)
16 |
17 | dataset, env = d3rlpy.datasets.get_atari_transitions(
18 | args.game,
19 | fraction=0.01,
20 | index=1 if args.game == "asterix" else 0,
21 | num_stack=4,
22 | )
23 |
24 | d3rlpy.envs.seed_env(env, args.seed)
25 |
26 | nfq = d3rlpy.algos.NFQConfig(
27 | learning_rate=5e-5,
28 | optim_factory=d3rlpy.optimizers.AdamFactory(),
29 | batch_size=32,
30 | observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(),
31 | reward_scaler=d3rlpy.preprocessing.ClipRewardScaler(-1.0, 1.0),
32 | compile_graph=args.compile,
33 | ).create(device=args.gpu)
34 |
35 | env_scorer = d3rlpy.metrics.EnvironmentEvaluator(env, epsilon=0.001)
36 |
37 | nfq.fit(
38 | dataset,
39 | n_steps=50000000 // 4,
40 | n_steps_per_epoch=125000,
41 | save_interval=10,
42 | evaluators={"environment": env_scorer},
43 | experiment_name=f"NFQ_{args.game}_{args.seed}",
44 | )
45 |
46 |
47 | if __name__ == "__main__":
48 | main()
49 |
--------------------------------------------------------------------------------
/reproductions/offline/plas.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 |
6 | def main() -> None:
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0")
9 | parser.add_argument("--seed", type=int, default=1)
10 | parser.add_argument("--gpu", type=int)
11 | parser.add_argument("--compile", action="store_true")
12 | args = parser.parse_args()
13 |
14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset)
15 |
16 | # fix seed
17 | d3rlpy.seed(args.seed)
18 | d3rlpy.envs.seed_env(env, args.seed)
19 |
20 | if "medium-replay" in args.dataset:
21 | vae_encoder = d3rlpy.models.encoders.VectorEncoderFactory([128, 128])
22 | else:
23 | vae_encoder = d3rlpy.models.encoders.VectorEncoderFactory([750, 750])
24 | encoder = d3rlpy.models.encoders.VectorEncoderFactory([400, 300])
25 |
26 | plas = d3rlpy.algos.PLASConfig(
27 | actor_learning_rate=1e-4,
28 | actor_encoder_factory=encoder,
29 | critic_learning_rate=1e-3,
30 | critic_encoder_factory=encoder,
31 | imitator_learning_rate=1e-4,
32 | imitator_encoder_factory=vae_encoder,
33 | batch_size=100,
34 | lam=1.0,
35 | warmup_steps=500000,
36 | compile_graph=args.compile,
37 | ).create(device=args.gpu)
38 |
39 | plas.fit(
40 | dataset,
41 |         n_steps=1000000,  # RL starts at step 500000
42 | n_steps_per_epoch=1000,
43 | save_interval=10,
44 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
45 | experiment_name=f"PLAS_{args.dataset}_{args.seed}",
46 | )
47 |
48 |
49 | if __name__ == "__main__":
50 | main()
51 |
--------------------------------------------------------------------------------
/reproductions/offline/plas_with_perturbation.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 | ACTION_FLEXIBILITY = {
6 | "walker2d-random-v0": 0.05,
7 | "hopper-random-v0": 0.5,
8 | "halfcheetah-random-v0": 0.01,
9 | "walker2d-medium-v0": 0.1,
10 | "hopper-medium-v0": 0.01,
11 | "halfcheetah-medium-v0": 0.1,
12 | "walker2d-medium-relay-v0": 0.01,
13 | "hopper-medium-replay-v0": 0.2,
14 | "halfcheetah-medium-replay-v0": 0.05,
15 | "walker2d-medium-expert-v0": 0.01,
16 | "hopper-medium-expert-v0": 0.01,
17 | "halfcheetah-medium-expert-v0": 0.01,
18 | }
19 |
20 |
21 | def main() -> None:
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0")
24 | parser.add_argument("--seed", type=int, default=1)
25 | parser.add_argument("--gpu", type=int)
26 | parser.add_argument("--compile", action="store_true")
27 | args = parser.parse_args()
28 |
29 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset)
30 |
31 | # fix seed
32 | d3rlpy.seed(args.seed)
33 | d3rlpy.envs.seed_env(env, args.seed)
34 |
35 | if "medium-replay" in args.dataset:
36 | vae_encoder = d3rlpy.models.encoders.VectorEncoderFactory([128, 128])
37 | else:
38 | vae_encoder = d3rlpy.models.encoders.VectorEncoderFactory([750, 750])
39 | encoder = d3rlpy.models.encoders.VectorEncoderFactory([400, 300])
40 |
41 | plas = d3rlpy.algos.PLASWithPerturbationConfig(
42 | actor_learning_rate=1e-4,
43 | actor_encoder_factory=encoder,
44 | critic_learning_rate=1e-3,
45 | critic_encoder_factory=encoder,
46 | imitator_learning_rate=1e-4,
47 | imitator_encoder_factory=vae_encoder,
48 | batch_size=100,
49 | lam=1.0,
50 | warmup_steps=500000,
51 | action_flexibility=ACTION_FLEXIBILITY[args.dataset],
52 | compile_graph=args.compile,
53 | ).create(device=args.gpu)
54 |
55 | plas.fit(
56 | dataset,
57 |         n_steps=1000000,  # RL starts at step 500000
58 | n_steps_per_epoch=1000,
59 | save_interval=10,
60 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
61 | experiment_name=f"PLASWithPerturbation_{args.dataset}_{args.seed}",
62 | )
63 |
64 |
65 | if __name__ == "__main__":
66 | main()
67 |
--------------------------------------------------------------------------------
/reproductions/offline/prdc.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 |
6 | def main() -> None:
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v2")
9 | parser.add_argument("--seed", type=int, default=1)
10 | parser.add_argument("--gpu", type=int, default=0)
11 | parser.add_argument("--compile", action="store_true")
12 | args = parser.parse_args()
13 |
14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset)
15 |
16 | # fix seed
17 | d3rlpy.seed(args.seed)
18 | d3rlpy.envs.seed_env(env, args.seed)
19 |
20 | prdc = d3rlpy.algos.PRDCConfig(
21 | actor_learning_rate=3e-4,
22 | critic_learning_rate=3e-4,
23 | batch_size=256,
24 | target_smoothing_sigma=0.2,
25 | target_smoothing_clip=0.5,
26 | alpha=2.5,
27 | update_actor_interval=2,
28 | observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(),
29 | compile_graph=args.compile,
30 | ).create(device=args.gpu)
31 |
32 | prdc.fit(
33 | dataset,
34 | n_steps=500000,
35 | n_steps_per_epoch=1000,
36 | save_interval=10,
37 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
38 | experiment_name=f"PRDC_{args.dataset}_{args.seed}",
39 | )
40 |
41 |
42 | if __name__ == "__main__":
43 | main()
44 |
--------------------------------------------------------------------------------
/reproductions/offline/qr_dqn.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 |
6 | def main() -> None:
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--game", type=str, default="breakout")
9 | parser.add_argument("--seed", type=int, default=1)
10 | parser.add_argument("--gpu", type=int)
11 | parser.add_argument("--compile", action="store_true")
12 | args = parser.parse_args()
13 |
14 | # fix seed
15 | d3rlpy.seed(args.seed)
16 |
17 | dataset, env = d3rlpy.datasets.get_atari_transitions(
18 | args.game,
19 | fraction=0.01,
20 | index=1 if args.game == "asterix" else 0,
21 | num_stack=4,
22 | )
23 |
24 | d3rlpy.envs.seed_env(env, args.seed)
25 |
26 | dqn = d3rlpy.algos.DQNConfig(
27 | learning_rate=5e-5,
28 | optim_factory=d3rlpy.optimizers.AdamFactory(eps=1e-2 / 32),
29 | batch_size=32,
30 | q_func_factory=d3rlpy.models.q_functions.QRQFunctionFactory(
31 | n_quantiles=200
32 | ),
33 | observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(),
34 | target_update_interval=2000,
35 | reward_scaler=d3rlpy.preprocessing.ClipRewardScaler(-1.0, 1.0),
36 | compile_graph=args.compile,
37 | ).create(device=args.gpu)
38 |
39 | env_scorer = d3rlpy.metrics.EnvironmentEvaluator(env, epsilon=0.001)
40 |
41 | dqn.fit(
42 | dataset,
43 | n_steps=50000000 // 4,
44 | n_steps_per_epoch=125000,
45 | save_interval=10,
46 | evaluators={"environment": env_scorer},
47 | experiment_name=f"QRDQN_{args.game}_{args.seed}",
48 | )
49 |
50 |
51 | if __name__ == "__main__":
52 | main()
53 |
--------------------------------------------------------------------------------
/reproductions/offline/rebrac.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 | BETA_TABLE: dict[str, tuple[float, float]] = {
6 | "halfcheetah-random": (0.001, 0.1),
7 | "halfcheetah-medium": (0.001, 0.01),
8 | "halfcheetah-expert": (0.01, 0.01),
9 | "halfcheetah-medium-replay": (0.01, 0.001),
10 | "halfcheetah-full-replay": (0.001, 0.1),
11 | "hopper-random": (0.001, 0.01),
12 | "hopper-medium": (0.01, 0.001),
13 | "hopper-expert": (0.1, 0.001),
14 | "hopper-medium-expert": (0.1, 0.01),
15 | "hopper-medium-replay": (0.05, 0.5),
16 | "hopper-full-replay": (0.01, 0.01),
17 | "walker2d-random": (0.01, 0.0),
18 | "walker2d-medium": (0.05, 0.1),
19 | "walker2d-expert": (0.01, 0.5),
20 | "walker2d-medium-expert": (0.01, 0.01),
21 | "walker2d-medium-replay": (0.05, 0.01),
22 | "walker2d-full-replay": (0.01, 0.01),
23 | }
24 |
25 |
26 | def main() -> None:
27 | parser = argparse.ArgumentParser()
28 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0")
29 | parser.add_argument("--seed", type=int, default=1)
30 | parser.add_argument("--gpu", type=int)
31 | parser.add_argument("--compile", action="store_true")
32 | args = parser.parse_args()
33 |
34 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset)
35 |
36 | # fix seed
37 | d3rlpy.seed(args.seed)
38 | d3rlpy.envs.seed_env(env, args.seed)
39 |
40 | # deeper network
41 | actor_encoder = d3rlpy.models.VectorEncoderFactory([256, 256, 256])
42 | critic_encoder = d3rlpy.models.VectorEncoderFactory(
43 | [256, 256, 256], use_layer_norm=True
44 | )
45 |
46 | actor_beta, critic_beta = 0.01, 0.01
47 | for dataset_name, beta_from_paper in BETA_TABLE.items():
48 | if dataset_name in args.dataset:
49 | actor_beta, critic_beta = beta_from_paper
50 | break
51 |
52 | rebrac = d3rlpy.algos.ReBRACConfig(
53 | actor_learning_rate=1e-3,
54 | critic_learning_rate=1e-3,
55 | batch_size=1024,
56 | gamma=0.99,
57 | actor_encoder_factory=actor_encoder,
58 | critic_encoder_factory=critic_encoder,
59 | target_smoothing_sigma=0.2,
60 | target_smoothing_clip=0.5,
61 | update_actor_interval=2,
62 | actor_beta=actor_beta,
63 | critic_beta=critic_beta,
64 | observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(),
65 | compile_graph=args.compile,
66 | ).create(device=args.gpu)
67 |
68 | rebrac.fit(
69 | dataset,
70 | n_steps=1000000,
71 | n_steps_per_epoch=1000,
72 | save_interval=10,
73 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
74 | experiment_name=f"ReBRAC_{args.dataset}_{args.seed}",
75 | )
76 |
77 |
78 | if __name__ == "__main__":
79 | main()
80 |
--------------------------------------------------------------------------------
/reproductions/offline/sac.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 |
6 | def main() -> None:
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0")
9 | parser.add_argument("--seed", type=int, default=1)
10 | parser.add_argument("--gpu", type=int)
11 | parser.add_argument("--compile", action="store_true")
12 | args = parser.parse_args()
13 |
14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset)
15 |
16 | # fix seed
17 | d3rlpy.seed(args.seed)
18 | d3rlpy.envs.seed_env(env, args.seed)
19 |
20 | sac = d3rlpy.algos.SACConfig(
21 | actor_learning_rate=3e-4,
22 | critic_learning_rate=3e-4,
23 | temp_learning_rate=3e-4,
24 | batch_size=256,
25 | compile_graph=args.compile,
26 | ).create(device=args.gpu)
27 |
28 | sac.fit(
29 | dataset,
30 | n_steps=500000,
31 | n_steps_per_epoch=1000,
32 | save_interval=10,
33 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
34 | experiment_name=f"SAC_{args.dataset}_{args.seed}",
35 | )
36 |
37 |
38 | if __name__ == "__main__":
39 | main()
40 |
--------------------------------------------------------------------------------
/reproductions/offline/tacr.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 |
6 | def main() -> None:
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0")
9 | parser.add_argument("--seed", type=int, default=1)
10 | parser.add_argument("--gpu", type=int)
11 | parser.add_argument("--compile", action="store_true")
12 | args = parser.parse_args()
13 |
14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset)
15 |
16 | # fix seed
17 | d3rlpy.seed(args.seed)
18 | d3rlpy.envs.seed_env(env, args.seed)
19 |
20 | if "halfcheetah" in args.dataset:
21 | target_return = 6000
22 | elif "hopper" in args.dataset:
23 | target_return = 3600
24 | elif "walker" in args.dataset:
25 | target_return = 5000
26 | else:
27 | raise ValueError("unsupported dataset")
28 |
29 | tacr = d3rlpy.algos.TACRConfig(
30 | batch_size=64,
31 | actor_learning_rate=1e-4,
32 | actor_optim_factory=d3rlpy.optimizers.AdamWFactory(
33 | weight_decay=1e-4,
34 | clip_grad_norm=0.25,
35 | lr_scheduler_factory=d3rlpy.optimizers.WarmupSchedulerFactory(
36 | warmup_steps=10000
37 | ),
38 | ),
39 | actor_encoder_factory=d3rlpy.models.VectorEncoderFactory(
40 | [128],
41 | exclude_last_activation=True,
42 | ),
43 | observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(),
44 | position_encoding_type=d3rlpy.PositionEncodingType.SIMPLE,
45 | context_size=20,
46 | num_heads=1,
47 | num_layers=3,
48 | max_timestep=1000,
49 | compile_graph=args.compile,
50 | ).create(device=args.gpu)
51 |
52 | tacr.fit(
53 | dataset,
54 | n_steps=100000,
55 | n_steps_per_epoch=1000,
56 | save_interval=10,
57 | eval_env=env,
58 | eval_target_return=target_return,
59 | experiment_name=f"TACR_{args.dataset}_{args.seed}",
60 | )
61 |
62 |
63 | if __name__ == "__main__":
64 | main()
65 |
--------------------------------------------------------------------------------
/reproductions/offline/td3.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 |
6 | def main() -> None:
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0")
9 | parser.add_argument("--seed", type=int, default=1)
10 | parser.add_argument("--gpu", type=int)
11 | parser.add_argument("--compile", action="store_true")
12 | args = parser.parse_args()
13 |
14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset)
15 |
16 | # fix seed
17 | d3rlpy.seed(args.seed)
18 | d3rlpy.envs.seed_env(env, args.seed)
19 |
20 | td3 = d3rlpy.algos.TD3Config(
21 | actor_learning_rate=3e-4,
22 | critic_learning_rate=3e-4,
23 | batch_size=256,
24 | target_smoothing_sigma=0.2,
25 | target_smoothing_clip=0.5,
26 | update_actor_interval=2,
27 | compile_graph=args.compile,
28 | ).create(device=args.gpu)
29 |
30 | td3.fit(
31 | dataset,
32 | n_steps=500000,
33 | n_steps_per_epoch=1000,
34 | save_interval=10,
35 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
36 | experiment_name=f"TD3_{args.dataset}_{args.seed}",
37 | )
38 |
39 |
40 | if __name__ == "__main__":
41 | main()
42 |
--------------------------------------------------------------------------------
/reproductions/offline/td3_plus_bc.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import d3rlpy
4 |
5 |
6 | def main() -> None:
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0")
9 | parser.add_argument("--seed", type=int, default=1)
10 | parser.add_argument("--gpu", type=int)
11 | parser.add_argument("--compile", action="store_true")
12 | args = parser.parse_args()
13 |
14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset)
15 |
16 | # fix seed
17 | d3rlpy.seed(args.seed)
18 | d3rlpy.envs.seed_env(env, args.seed)
19 |
20 | td3 = d3rlpy.algos.TD3PlusBCConfig(
21 | actor_learning_rate=3e-4,
22 | critic_learning_rate=3e-4,
23 | batch_size=256,
24 | target_smoothing_sigma=0.2,
25 | target_smoothing_clip=0.5,
26 | alpha=2.5,
27 | update_actor_interval=2,
28 | observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(),
29 | compile_graph=args.compile,
30 | ).create(device=args.gpu)
31 |
32 | td3.fit(
33 | dataset,
34 | n_steps=500000,
35 | n_steps_per_epoch=1000,
36 | save_interval=10,
37 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)},
38 | experiment_name=f"TD3PlusBC_{args.dataset}_{args.seed}",
39 | )
40 |
41 |
42 | if __name__ == "__main__":
43 | main()
44 |
--------------------------------------------------------------------------------
/reproductions/online/double_dqn_online.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gym
4 |
5 | import d3rlpy
6 |
7 |
8 | def main() -> None:
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4")
11 | parser.add_argument("--seed", type=int, default=1)
12 | parser.add_argument("--gpu", action="store_true")
13 | parser.add_argument("--compile", action="store_true")
14 | args = parser.parse_args()
15 |
16 | # get wrapped atari environment
17 | env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4)
18 | eval_env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4, is_eval=True)
19 |
20 | # fix seed
21 | d3rlpy.seed(args.seed)
22 | d3rlpy.envs.seed_env(env, args.seed)
23 | d3rlpy.envs.seed_env(eval_env, args.seed)
24 |
25 | # setup algorithm
26 | dqn = d3rlpy.algos.DoubleDQNConfig(
27 | batch_size=32,
28 | learning_rate=2.5e-4,
29 | optim_factory=d3rlpy.optimizers.RMSpropFactory(),
30 | target_update_interval=10000,
31 | observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(),
32 | compile_graph=args.compile,
33 | ).create(device=args.gpu)
34 |
35 | # replay buffer for experience replay
36 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(
37 | limit=1000000,
38 | transition_picker=d3rlpy.dataset.FrameStackTransitionPicker(n_frames=4),
39 | writer_preprocessor=d3rlpy.dataset.LastFrameWriterPreprocess(),
40 | env=env,
41 | )
42 |
43 |     # epsilon-greedy explorer
44 | explorer = d3rlpy.algos.LinearDecayEpsilonGreedy(
45 | start_epsilon=1.0, end_epsilon=0.1, duration=1000000
46 | )
47 |
48 | # start training
49 | dqn.fit_online(
50 | env,
51 | buffer,
52 | explorer,
53 | eval_env=eval_env,
54 | eval_epsilon=0.01,
55 | n_steps=50000000,
56 | n_steps_per_epoch=100000,
57 | update_interval=4,
58 | update_start_step=50000,
59 | )
60 |
61 |
62 | if __name__ == "__main__":
63 | main()
64 |
--------------------------------------------------------------------------------
/reproductions/online/dqn_online.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gym
4 |
5 | import d3rlpy
6 |
7 |
8 | def main() -> None:
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4")
11 | parser.add_argument("--seed", type=int, default=1)
12 | parser.add_argument("--gpu", action="store_true")
13 | parser.add_argument("--compile", action="store_true")
14 | args = parser.parse_args()
15 |
16 | # get wrapped atari environment
17 | env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4)
18 | eval_env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4, is_eval=True)
19 |
20 | # fix seed
21 | d3rlpy.seed(args.seed)
22 | d3rlpy.envs.seed_env(env, args.seed)
23 | d3rlpy.envs.seed_env(eval_env, args.seed)
24 |
25 | # setup algorithm
26 | dqn = d3rlpy.algos.DQNConfig(
27 | batch_size=32,
28 | learning_rate=2.5e-4,
29 | optim_factory=d3rlpy.optimizers.RMSpropFactory(),
30 | target_update_interval=10000 // 4,
31 | observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(),
32 | compile_graph=args.compile,
33 | ).create(device=args.gpu)
34 |
35 | # replay buffer for experience replay
36 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(
37 | limit=1000000,
38 | transition_picker=d3rlpy.dataset.FrameStackTransitionPicker(n_frames=4),
39 | writer_preprocessor=d3rlpy.dataset.LastFrameWriterPreprocess(),
40 | env=env,
41 | )
42 |
43 |     # epsilon-greedy explorer
44 | explorer = d3rlpy.algos.LinearDecayEpsilonGreedy(
45 | start_epsilon=1.0, end_epsilon=0.1, duration=1000000
46 | )
47 |
48 | # start training
49 | dqn.fit_online(
50 | env,
51 | buffer,
52 | explorer,
53 | eval_env=eval_env,
54 | eval_epsilon=0.01,
55 | n_steps=50000000,
56 | n_steps_per_epoch=100000,
57 | update_interval=4,
58 | update_start_step=50000,
59 | )
60 |
61 |
62 | if __name__ == "__main__":
63 | main()
64 |
--------------------------------------------------------------------------------
/reproductions/online/iqn_online.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gym
4 |
5 | import d3rlpy
6 |
7 |
8 | def main() -> None:
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4")
11 | parser.add_argument("--seed", type=int, default=1)
12 | parser.add_argument("--gpu", action="store_true")
13 | parser.add_argument("--compile", action="store_true")
14 | args = parser.parse_args()
15 |
16 | # get wrapped atari environment
17 | env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4)
18 | eval_env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4, is_eval=True)
19 |
20 | # fix seed
21 | d3rlpy.seed(args.seed)
22 | d3rlpy.envs.seed_env(env, args.seed)
23 | d3rlpy.envs.seed_env(eval_env, args.seed)
24 |
25 | # setup algorithm
26 | dqn = d3rlpy.algos.DQNConfig(
27 | batch_size=32,
28 | learning_rate=5e-5,
29 | optim_factory=d3rlpy.optimizers.AdamFactory(eps=1e-2 / 32),
30 | target_update_interval=10000 // 4,
31 | q_func_factory=d3rlpy.models.IQNQFunctionFactory(),
32 | observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(),
33 | compile_graph=args.compile,
34 | ).create(device=args.gpu)
35 |
36 | # replay buffer for experience replay
37 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(
38 | limit=1000000,
39 | transition_picker=d3rlpy.dataset.FrameStackTransitionPicker(n_frames=4),
40 | writer_preprocessor=d3rlpy.dataset.LastFrameWriterPreprocess(),
41 | env=env,
42 | )
43 |
44 |     # epsilon-greedy explorer
45 | explorer = d3rlpy.algos.LinearDecayEpsilonGreedy(
46 | start_epsilon=1.0, end_epsilon=0.01, duration=1000000
47 | )
48 |
49 | # start training
50 | dqn.fit_online(
51 | env,
52 | buffer,
53 | explorer,
54 | eval_env=eval_env,
55 | eval_epsilon=0.001,
56 | n_steps=50000000,
57 | n_steps_per_epoch=100000,
58 | update_interval=4,
59 | update_start_step=50000,
60 | )
61 |
62 |
63 | if __name__ == "__main__":
64 | main()
65 |
--------------------------------------------------------------------------------
/reproductions/online/qr_dqn_online.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gym
4 |
5 | import d3rlpy
6 |
7 |
8 | def main() -> None:
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4")
11 | parser.add_argument("--seed", type=int, default=1)
12 | parser.add_argument("--gpu", action="store_true")
13 | parser.add_argument("--compile", action="store_true")
14 | args = parser.parse_args()
15 |
16 | # get wrapped atari environment
17 | env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4)
18 | eval_env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4, is_eval=True)
19 |
20 | # fix seed
21 | d3rlpy.seed(args.seed)
22 | d3rlpy.envs.seed_env(env, args.seed)
23 | d3rlpy.envs.seed_env(eval_env, args.seed)
24 |
25 | # setup algorithm
26 | dqn = d3rlpy.algos.DQNConfig(
27 | batch_size=32,
28 | learning_rate=5e-5,
29 | optim_factory=d3rlpy.optimizers.AdamFactory(eps=1e-2 / 32),
30 | target_update_interval=10000 // 4,
31 | q_func_factory=d3rlpy.models.q_functions.QRQFunctionFactory(
32 | n_quantiles=200
33 | ),
34 | observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(),
35 | compile_graph=args.compile,
36 | ).create(device=args.gpu)
37 |
38 | # replay buffer for experience replay
39 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(
40 | limit=1000000,
41 | transition_picker=d3rlpy.dataset.FrameStackTransitionPicker(n_frames=4),
42 | writer_preprocessor=d3rlpy.dataset.LastFrameWriterPreprocess(),
43 | env=env,
44 | )
45 |
46 |     # epsilon-greedy explorer
47 | explorer = d3rlpy.algos.LinearDecayEpsilonGreedy(
48 | start_epsilon=1.0, end_epsilon=0.01, duration=1000000
49 | )
50 |
51 | # start training
52 | dqn.fit_online(
53 | env,
54 | buffer,
55 | explorer,
56 | eval_env=eval_env,
57 | eval_epsilon=0.001,
58 | n_steps=50000000,
59 | n_steps_per_epoch=100000,
60 | update_interval=4,
61 | update_start_step=50000,
62 | )
63 |
64 |
65 | if __name__ == "__main__":
66 | main()
67 |
--------------------------------------------------------------------------------
/reproductions/online/sac_online.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import gym
4 |
5 | import d3rlpy
6 |
7 |
8 | def main() -> None:
9 | parser = argparse.ArgumentParser()
10 | parser.add_argument("--env", type=str, default="Hopper-v2")
11 | parser.add_argument("--seed", type=int, default=1)
12 | parser.add_argument("--gpu", action="store_true")
13 | parser.add_argument("--compile", action="store_true")
14 | args = parser.parse_args()
15 |
16 | env = gym.make(args.env)
17 | eval_env = gym.make(args.env)
18 |
19 | # fix seed
20 | d3rlpy.seed(args.seed)
21 | d3rlpy.envs.seed_env(env, args.seed)
22 | d3rlpy.envs.seed_env(eval_env, args.seed)
23 |
24 | # setup algorithm
25 | sac = d3rlpy.algos.SACConfig(
26 | batch_size=256,
27 | actor_learning_rate=3e-4,
28 | critic_learning_rate=3e-4,
29 | temp_learning_rate=3e-4,
30 | compile_graph=args.compile,
31 | ).create(device=args.gpu)
32 |
33 | # replay buffer for experience replay
34 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=1000000, env=env)
35 |
36 | # start training
37 | sac.fit_online(
38 | env,
39 | buffer,
40 | eval_env=eval_env,
41 | n_steps=1000000,
42 | n_steps_per_epoch=10000,
43 | update_interval=1,
44 | update_start_step=1000,
45 | )
46 |
47 |
48 | if __name__ == "__main__":
49 | main()
50 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.5.0
2 | tqdm>=4.66.1
3 | h5py==2.10.0
4 | gym==0.26.2
5 | click==8.0.1
6 | typing-extensions==3.7.4.3
7 | structlog==20.2.0
8 | colorama==0.4.4
9 | gymnasium==1.0.0
10 | scikit-learn==1.5.2
11 |
--------------------------------------------------------------------------------
/ruff.toml:
--------------------------------------------------------------------------------
1 | line-length = 80
2 | indent-width = 4
3 | target-version = "py39"
4 | unsafe-fixes = true
5 |
6 | [lint]
7 | select = ["E4", "E5", "E7", "E9", "F", "UP006", "I", "W"]
8 | ignore = ["F403"]
9 |
10 | # Allow fix for all enabled rules (when `--fix` is provided).
11 | fixable = ["ALL"]
12 | unfixable = []
13 |
14 | [format]
15 | # Like Black, use double quotes for strings.
16 | quote-style = "double"
17 |
18 | # Like Black, indent with spaces, rather than tabs.
19 | indent-style = "space"
20 |
21 | # Like Black, respect magic trailing commas.
22 | skip-magic-trailing-comma = false
23 |
24 | # Like Black, automatically detect the appropriate line ending.
25 | line-ending = "auto"
26 |
--------------------------------------------------------------------------------
/scripts/build-dist:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # clean
4 | rm -rf d3rlpy.egg-info
5 | rm -rf dist
6 | rm -rf build
7 |
8 | python setup.py test
9 |
10 | python setup.py sdist bdist_wheel
11 |
--------------------------------------------------------------------------------
/scripts/build-docker:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | docker build -t takuseno/d3rlpy:latest docker
4 |
--------------------------------------------------------------------------------
/scripts/build-docs:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | sphinx-apidoc -f -o ./docs d3rlpy *tests
4 |
5 | sphinx-build -b html ./docs ./docs/_build
6 |
--------------------------------------------------------------------------------
/scripts/create_cartpole_dataset:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import gym
4 | import d3rlpy
5 |
6 | d3rlpy.seed(100)
7 |
8 | # prepare environment
9 | env = gym.make('CartPole-v0')
10 | eval_env = gym.make('CartPole-v0')
11 |
12 | # prepare algorithms
13 | dqn = d3rlpy.algos.DQN(learning_rate=1e-3, target_update_interval=100)
14 |
15 | # prepare utilities
16 | buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=1000000, env=env)
17 | explorer = d3rlpy.online.explorers.ConstantEpsilonGreedy(epsilon=0.3)
18 |
19 | # start training
20 | dqn.fit_online(
21 | env, buffer=buffer, explorer=explorer, eval_env=eval_env, n_steps=100000
22 | )
23 |
24 | # export replay buffer as MDPDataset
25 | dataset = buffer.to_mdp_dataset()
26 |
27 | # save MDPDataset
28 | dataset.dump('cartpole.h5')
29 |
--------------------------------------------------------------------------------
/scripts/create_cartpole_random_dataset:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import gym
4 | import d3rlpy
5 |
6 | d3rlpy.seed(100)
7 |
8 | # prepare environment
9 | env = gym.make('CartPole-v0')
10 |
11 | # prepare algorithms
12 | policy = d3rlpy.algos.DiscreteRandomPolicy()
13 |
14 | # prepare utilities
15 | buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=1000000, env=env)
16 |
17 | # start training
18 | policy.collect(env, buffer=buffer, n_steps=100000)
19 |
20 | # export replay buffer as MDPDataset
21 | dataset = buffer.to_mdp_dataset()
22 |
23 | # save MDPDataset
24 | dataset.dump('cartpole_random.h5')
25 |
--------------------------------------------------------------------------------
/scripts/create_pendulum_dataset:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import gym
4 | import d3rlpy
5 |
6 | d3rlpy.seed(100)
7 |
8 | # prepare environment
9 | env = gym.make('Pendulum-v0')
10 | eval_env = gym.make('Pendulum-v0')
11 |
12 | # prepare algorithms
13 | sac = d3rlpy.algos.SAC(action_scaler='min_max')
14 |
15 | # prepare utilities
16 | buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=1000000, env=env)
17 |
18 | # start training
19 | sac.fit_online(env, buffer=buffer, eval_env=eval_env, n_steps=100000)
20 |
21 | # export replay buffer as MDPDataset
22 | dataset = buffer.to_mdp_dataset()
23 |
24 | # save MDPDataset
25 | dataset.dump('pendulum.h5')
26 |
--------------------------------------------------------------------------------
/scripts/create_pendulum_random_dataset:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import gym
4 | import d3rlpy
5 |
6 | d3rlpy.seed(100)
7 |
8 | # prepare environment
9 | env = gym.make('Pendulum-v0')
10 |
11 | # prepare algorithms
12 | policy = d3rlpy.algos.RandomPolicy(
13 | distribution='normal',
14 | action_scaler='min_max',
15 | )
16 |
17 | # prepare utilities
18 | buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=1000000, env=env)
19 |
20 | # start training
21 | policy.collect(env, buffer=buffer, n_steps=100000)
22 |
23 | # export replay buffer as MDPDataset
24 | dataset = buffer.to_mdp_dataset()
25 |
26 | # save MDPDataset
27 | dataset.dump('pendulum_random.h5')
28 |
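Note: the four dataset-creation scripts above still use the pre-2.x d3rlpy API (d3rlpy.online.buffers.ReplayBuffer, d3rlpy.algos.DQN(...), buffer.to_mdp_dataset(), dataset.dump()), which no longer exists in the 2.x API used by the reproduction scripts earlier in this dump. Below is a minimal sketch of a 2.x-style equivalent of the CartPole script; the hyperparameters are carried over, while d3rlpy.algos.ConstantEpsilonGreedy and ReplayBuffer.dump() are assumptions about the current API rather than verified calls.

import gym

import d3rlpy

d3rlpy.seed(100)

# prepare environments
env = gym.make("CartPole-v0")
eval_env = gym.make("CartPole-v0")

# 2.x-style config in place of d3rlpy.algos.DQN(...)
dqn = d3rlpy.algos.DQNConfig(
    learning_rate=1e-3,
    target_update_interval=100,
).create()

# FIFO replay buffer in place of d3rlpy.online.buffers.ReplayBuffer
buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=1000000, env=env)

# constant epsilon-greedy exploration (assumed class name in 2.x)
explorer = d3rlpy.algos.ConstantEpsilonGreedy(epsilon=0.3)

# collect transitions while training online
dqn.fit_online(env, buffer, explorer, eval_env=eval_env, n_steps=100000)

# export the collected transitions as HDF5 (assumed 2.x replacement for
# buffer.to_mdp_dataset().dump('cartpole.h5'))
with open("cartpole.h5", "wb") as f:
    buffer.dump(f)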
--------------------------------------------------------------------------------
/scripts/lint:
--------------------------------------------------------------------------------
1 | #!/bin/bash -ex
2 |
3 | if [[ -z $CI ]]; then
4 | BLACK_ARG=""
5 | RUFF_ARG="check --fix"
6 | DOCFORMATTER_ARG="--in-place"
7 | else
8 |   BLACK_ARG="--check --diff"
9 | RUFF_ARG="check"
10 | DOCFORMATTER_ARG="--check --diff"
11 | fi
12 |
13 | # use black for better formatting of type annotations
14 | black -l 80 $BLACK_ARG d3rlpy tests setup.py reproductions examples
15 |
16 | # formatter and linter
17 | ruff $RUFF_ARG d3rlpy tests examples reproductions setup.py
18 |
19 | # format docstrings
20 | docformatter $DOCFORMATTER_ARG --black --wrap-summaries 80 --wrap-descriptions 80 -r d3rlpy
21 |
22 | # type check
23 | mypy d3rlpy reproductions tests examples
24 |
--------------------------------------------------------------------------------
/scripts/test:
--------------------------------------------------------------------------------
1 | #!/bin/bash -eux
2 |
3 | FLAG_P="FALSE"
4 |
5 | while getopts p OPT
6 | do
7 | case $OPT in
8 | "p" ) FLAG_P="TRUE" ;; # do performance test
9 | * ) echo "Usage: ./scripts/test [-p]" 1>&2
10 | exit 1 ;;
11 | esac
12 | done
13 |
14 | # create temporary directory for tests
15 | mkdir -p test_data
16 |
17 | # set flag
18 | if [ "$FLAG_P" = "TRUE" ]; then
19 | TEST_PERFORMANCE="TRUE"
20 | else
21 | TEST_PERFORMANCE="FALSE"
22 | fi
23 |
24 | # run tests
25 | echo "Run unit tests"
26 | TEST_PERFORMANCE=$TEST_PERFORMANCE \
27 | KMP_DUPLICATE_LIB_OK='True' pytest --cov-report=xml \
28 | --cov=d3rlpy \
29 | --cov-config=.coveragerc \
30 | tests -p no:warnings -v
31 |
32 | # clean up
33 | rm -r test_data
34 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from setuptools import find_packages, setup
4 |
5 | # get __version__ variable
6 | here = os.path.abspath(os.path.dirname(__file__))
7 | exec(open(os.path.join(here, "d3rlpy", "_version.py")).read())
8 |
9 | if __name__ == "__main__":
10 | setup(
11 | name="d3rlpy",
12 | version=__version__, # noqa
13 | description="An offline deep reinforcement learning library",
14 | long_description=open("README.md").read(),
15 | long_description_content_type="text/markdown",
16 | url="https://github.com/takuseno/d3rlpy",
17 | author="Takuma Seno",
18 | author_email="takuma.seno@gmail.com",
19 | license="MIT License",
20 | classifiers=[
21 | "Development Status :: 5 - Production/Stable",
22 | "Intended Audience :: Developers",
23 | "Intended Audience :: Education",
24 | "Intended Audience :: Science/Research",
25 | "Topic :: Scientific/Engineering",
26 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
27 | "Programming Language :: Python :: 3.9",
28 | "Programming Language :: Python :: 3.10",
29 | "Programming Language :: Python :: 3.11",
30 | "Programming Language :: Python :: Implementation :: CPython",
31 | "Operating System :: POSIX :: Linux",
32 | "Operating System :: Microsoft :: Windows",
33 | "Operating System :: MacOS :: MacOS X",
34 | ],
35 | install_requires=[
36 | "torch>=2.5.0",
37 | "tqdm>=4.66.3",
38 | "h5py",
39 | "gym>=0.26.0",
40 | "click",
41 | "typing-extensions",
42 | "structlog",
43 | "colorama",
44 | "dataclasses-json",
45 | "gymnasium==1.0.0",
46 | "scikit-learn",
47 | ],
48 | packages=find_packages(exclude=["tests*"]),
49 | python_requires=">=3.9.0",
50 | zip_safe=True,
51 | entry_points={"console_scripts": ["d3rlpy=d3rlpy.cli:cli"]},
52 | )
53 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pytest
4 |
5 | is_skipping_performance_test = os.environ.get("TEST_PERFORMANCE") != "TRUE"
6 | performance_test = pytest.mark.skipif(
7 | is_skipping_performance_test, reason="skip performance tests"
8 | )
9 |
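For reference, the performance_test marker defined above is applied as a decorator; tests wrapped with it run only when TEST_PERFORMANCE=TRUE is set, as scripts/test -p does. A minimal, hypothetical usage sketch (the test name and import path are illustrative):

from tests import performance_test


@performance_test
def test_long_training_run() -> None:
    # executed only when TEST_PERFORMANCE=TRUE; skipped otherwise
    assert True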
--------------------------------------------------------------------------------
/tests/algos/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/algos/__init__.py
--------------------------------------------------------------------------------
/tests/algos/qlearning/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/algos/qlearning/__init__.py
--------------------------------------------------------------------------------
/tests/algos/qlearning/test_awac.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos.qlearning.awac import AWACConfig
6 | from d3rlpy.models import (
7 | MeanQFunctionFactory,
8 | QFunctionFactory,
9 | QRQFunctionFactory,
10 | )
11 | from d3rlpy.types import Shape
12 |
13 | from ...models.torch.model_test import DummyEncoderFactory
14 | from ...testing_utils import create_scaler_tuple
15 | from .algo_test import algo_tester
16 |
17 |
18 | @pytest.mark.parametrize(
19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
20 | )
21 | @pytest.mark.parametrize(
22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
23 | )
24 | @pytest.mark.parametrize("scalers", [None, "min_max"])
25 | def test_awac(
26 | observation_shape: Shape,
27 | q_func_factory: QFunctionFactory,
28 | scalers: Optional[str],
29 | ) -> None:
30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
31 | scalers, observation_shape
32 | )
33 | config = AWACConfig(
34 | actor_encoder_factory=DummyEncoderFactory(),
35 | critic_encoder_factory=DummyEncoderFactory(),
36 | q_func_factory=q_func_factory,
37 | observation_scaler=observation_scaler,
38 | action_scaler=action_scaler,
39 | reward_scaler=reward_scaler,
40 | )
41 | awac = config.create()
42 | algo_tester(awac, observation_shape) # type: ignore
43 |
--------------------------------------------------------------------------------
/tests/algos/qlearning/test_bc.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos.qlearning.bc import BCConfig, DiscreteBCConfig
6 | from d3rlpy.types import Shape
7 |
8 | from ...models.torch.model_test import DummyEncoderFactory
9 | from ...testing_utils import create_scaler_tuple
10 | from .algo_test import algo_tester
11 |
12 |
13 | @pytest.mark.parametrize(
14 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
15 | )
16 | @pytest.mark.parametrize("policy_type", ["deterministic", "stochastic"])
17 | @pytest.mark.parametrize("scalers", [None, "min_max"])
18 | def test_bc(
19 | observation_shape: Shape, policy_type: str, scalers: Optional[str]
20 | ) -> None:
21 | observation_scaler, action_scaler, _ = create_scaler_tuple(
22 | scalers, observation_shape
23 | )
24 | config = BCConfig(
25 | encoder_factory=DummyEncoderFactory(),
26 | observation_scaler=observation_scaler,
27 | action_scaler=action_scaler,
28 | policy_type=policy_type,
29 | )
30 | bc = config.create()
31 | algo_tester(
32 | bc, # type: ignore
33 | observation_shape,
34 | test_predict_value=False,
35 | test_policy_copy=False,
36 | test_policy_optim_copy=False,
37 | test_q_function_optim_copy=False,
38 | test_q_function_copy=False,
39 | )
40 |
41 |
42 | @pytest.mark.parametrize(
43 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
44 | )
45 | @pytest.mark.parametrize("scaler", [None, "min_max"])
46 | def test_discrete_bc(observation_shape: Shape, scaler: Optional[str]) -> None:
47 | observation_scaler, _, _ = create_scaler_tuple(scaler, observation_shape)
48 | config = DiscreteBCConfig(
49 | encoder_factory=DummyEncoderFactory(),
50 | observation_scaler=observation_scaler,
51 | )
52 | bc = config.create()
53 | algo_tester(
54 | bc, # type: ignore
55 | observation_shape,
56 | test_policy_copy=False,
57 | test_predict_value=False,
58 | test_policy_optim_copy=False,
59 | test_q_function_optim_copy=False,
60 | test_q_function_copy=False,
61 | )
62 |
--------------------------------------------------------------------------------
/tests/algos/qlearning/test_bcq.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos.qlearning.bcq import BCQConfig, DiscreteBCQConfig
6 | from d3rlpy.models import (
7 | MeanQFunctionFactory,
8 | QFunctionFactory,
9 | QRQFunctionFactory,
10 | )
11 | from d3rlpy.types import Shape
12 |
13 | from ...models.torch.model_test import DummyEncoderFactory
14 | from ...testing_utils import create_scaler_tuple
15 | from .algo_test import algo_tester
16 |
17 |
18 | @pytest.mark.parametrize(
19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
20 | )
21 | @pytest.mark.parametrize(
22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
23 | )
24 | @pytest.mark.parametrize("scalers", [None, "min_max"])
25 | def test_bcq(
26 | observation_shape: Shape,
27 | q_func_factory: QFunctionFactory,
28 | scalers: Optional[str],
29 | ) -> None:
30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
31 | scalers, observation_shape
32 | )
33 | config = BCQConfig(
34 | actor_encoder_factory=DummyEncoderFactory(),
35 | critic_encoder_factory=DummyEncoderFactory(),
36 | imitator_encoder_factory=DummyEncoderFactory(),
37 | q_func_factory=q_func_factory,
38 | observation_scaler=observation_scaler,
39 | action_scaler=action_scaler,
40 | reward_scaler=reward_scaler,
41 | rl_start_step=0,
42 | )
43 | bcq = config.create()
44 | algo_tester(
45 | bcq, # type: ignore
46 | observation_shape,
47 | deterministic_best_action=False,
48 | test_policy_copy=False,
49 | )
50 |
51 |
52 | @pytest.mark.parametrize(
53 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
54 | )
55 | @pytest.mark.parametrize("n_critics", [1])
56 | @pytest.mark.parametrize(
57 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
58 | )
59 | @pytest.mark.parametrize("scalers", [None, "min_max"])
60 | def test_discrete_bcq(
61 | observation_shape: Shape,
62 | n_critics: int,
63 | q_func_factory: QFunctionFactory,
64 | scalers: Optional[str],
65 | ) -> None:
66 | observation_scaler, _, reward_scaler = create_scaler_tuple(
67 | scalers, observation_shape
68 | )
69 | config = DiscreteBCQConfig(
70 | encoder_factory=DummyEncoderFactory(),
71 | n_critics=n_critics,
72 | q_func_factory=q_func_factory,
73 | observation_scaler=observation_scaler,
74 | reward_scaler=reward_scaler,
75 | )
76 | bcq = config.create()
77 | algo_tester(
78 | bcq, # type: ignore
79 | observation_shape,
80 | test_policy_copy=False,
81 | test_policy_optim_copy=False,
82 | )
83 |
--------------------------------------------------------------------------------
/tests/algos/qlearning/test_bear.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos.qlearning.bear import BEARConfig
6 | from d3rlpy.models import (
7 | MeanQFunctionFactory,
8 | QFunctionFactory,
9 | QRQFunctionFactory,
10 | )
11 | from d3rlpy.types import Shape
12 |
13 | from ...models.torch.model_test import DummyEncoderFactory
14 | from ...testing_utils import create_scaler_tuple
15 | from .algo_test import algo_tester
16 |
17 |
18 | @pytest.mark.parametrize(
19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
20 | )
21 | @pytest.mark.parametrize(
22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
23 | )
24 | @pytest.mark.parametrize("scalers", [None, "min_max"])
25 | def test_bear(
26 | observation_shape: Shape,
27 | q_func_factory: QFunctionFactory,
28 | scalers: Optional[str],
29 | ) -> None:
30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
31 | scalers, observation_shape
32 | )
33 | config = BEARConfig(
34 | actor_encoder_factory=DummyEncoderFactory(),
35 | critic_encoder_factory=DummyEncoderFactory(),
36 | imitator_encoder_factory=DummyEncoderFactory(),
37 | q_func_factory=q_func_factory,
38 | observation_scaler=observation_scaler,
39 | action_scaler=action_scaler,
40 | reward_scaler=reward_scaler,
41 | warmup_steps=0,
42 | )
43 | bear = config.create()
44 | algo_tester(
45 | bear, # type: ignore
46 | observation_shape,
47 | deterministic_best_action=False,
48 | test_policy_copy=False,
49 | )
50 |
--------------------------------------------------------------------------------
/tests/algos/qlearning/test_cal_ql.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos.qlearning.cal_ql import CalQLConfig
6 | from d3rlpy.models import (
7 | MeanQFunctionFactory,
8 | QFunctionFactory,
9 | QRQFunctionFactory,
10 | )
11 | from d3rlpy.types import Shape
12 |
13 | from ...models.torch.model_test import DummyEncoderFactory
14 | from ...testing_utils import create_scaler_tuple
15 | from .algo_test import algo_tester
16 |
17 |
18 | @pytest.mark.parametrize(
19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
20 | )
21 | @pytest.mark.parametrize(
22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
23 | )
24 | @pytest.mark.parametrize("scalers", [None, "min_max"])
25 | def test_cal_ql(
26 | observation_shape: Shape,
27 | q_func_factory: QFunctionFactory,
28 | scalers: Optional[str],
29 | ) -> None:
30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
31 | scalers, observation_shape
32 | )
33 | config = CalQLConfig(
34 | actor_encoder_factory=DummyEncoderFactory(),
35 | critic_encoder_factory=DummyEncoderFactory(),
36 | q_func_factory=q_func_factory,
37 | observation_scaler=observation_scaler,
38 | action_scaler=action_scaler,
39 | reward_scaler=reward_scaler,
40 | )
41 | cal_ql = config.create()
42 | algo_tester(cal_ql, observation_shape) # type: ignore
43 |
--------------------------------------------------------------------------------
/tests/algos/qlearning/test_cql.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos.qlearning.cql import CQLConfig, DiscreteCQLConfig
6 | from d3rlpy.models import (
7 | MeanQFunctionFactory,
8 | QFunctionFactory,
9 | QRQFunctionFactory,
10 | )
11 | from d3rlpy.types import Shape
12 |
13 | from ...models.torch.model_test import DummyEncoderFactory
14 | from ...testing_utils import create_scaler_tuple
15 | from .algo_test import algo_tester
16 |
17 |
18 | @pytest.mark.parametrize(
19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
20 | )
21 | @pytest.mark.parametrize(
22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
23 | )
24 | @pytest.mark.parametrize("scalers", [None, "min_max"])
25 | def test_cql(
26 | observation_shape: Shape,
27 | q_func_factory: QFunctionFactory,
28 | scalers: Optional[str],
29 | ) -> None:
30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
31 | scalers, observation_shape
32 | )
33 | config = CQLConfig(
34 | actor_encoder_factory=DummyEncoderFactory(),
35 | critic_encoder_factory=DummyEncoderFactory(),
36 | q_func_factory=q_func_factory,
37 | observation_scaler=observation_scaler,
38 | action_scaler=action_scaler,
39 | reward_scaler=reward_scaler,
40 | )
41 | cql = config.create()
42 | algo_tester(cql, observation_shape) # type: ignore
43 |
44 |
45 | @pytest.mark.parametrize(
46 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
47 | )
48 | @pytest.mark.parametrize("n_critics", [1])
49 | @pytest.mark.parametrize(
50 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
51 | )
52 | @pytest.mark.parametrize("scalers", [None, "min_max"])
53 | def test_discrete_cql(
54 | observation_shape: Shape,
55 | n_critics: int,
56 | q_func_factory: QFunctionFactory,
57 | scalers: Optional[str],
58 | ) -> None:
59 | observation_scaler, _, reward_scaler = create_scaler_tuple(
60 | scalers, observation_shape
61 | )
62 | config = DiscreteCQLConfig(
63 | encoder_factory=DummyEncoderFactory(),
64 | n_critics=n_critics,
65 | q_func_factory=q_func_factory,
66 | observation_scaler=observation_scaler,
67 | reward_scaler=reward_scaler,
68 | )
69 | cql = config.create()
70 | algo_tester(
71 | cql, # type: ignore
72 | observation_shape,
73 | test_policy_copy=False,
74 | test_policy_optim_copy=False,
75 | )
76 |
--------------------------------------------------------------------------------
/tests/algos/qlearning/test_crr.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos.qlearning.crr import CRRConfig
6 | from d3rlpy.models import (
7 | MeanQFunctionFactory,
8 | QFunctionFactory,
9 | QRQFunctionFactory,
10 | )
11 | from d3rlpy.types import Shape
12 |
13 | from ...models.torch.model_test import DummyEncoderFactory
14 | from ...testing_utils import create_scaler_tuple
15 | from .algo_test import algo_tester
16 |
17 |
18 | @pytest.mark.parametrize(
19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
20 | )
21 | @pytest.mark.parametrize(
22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
23 | )
24 | @pytest.mark.parametrize("scalers", [None, "min_max"])
25 | @pytest.mark.parametrize("advantage_type", ["mean", "max"])
26 | @pytest.mark.parametrize("weight_type", ["exp", "binary"])
27 | @pytest.mark.parametrize("target_update_type", ["hard", "soft"])
28 | def test_crr(
29 | observation_shape: Shape,
30 | q_func_factory: QFunctionFactory,
31 | scalers: Optional[str],
32 | advantage_type: str,
33 | weight_type: str,
34 | target_update_type: str,
35 | ) -> None:
36 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
37 | scalers, observation_shape
38 | )
39 | config = CRRConfig(
40 | actor_encoder_factory=DummyEncoderFactory(),
41 | critic_encoder_factory=DummyEncoderFactory(),
42 | q_func_factory=q_func_factory,
43 | observation_scaler=observation_scaler,
44 | action_scaler=action_scaler,
45 | reward_scaler=reward_scaler,
46 | advantage_type=advantage_type,
47 | weight_type=weight_type,
48 | target_update_type=target_update_type,
49 | )
50 | crr = config.create()
51 | algo_tester(
52 | crr, # type: ignore
53 | observation_shape,
54 | deterministic_best_action=False,
55 | test_policy_copy=False,
56 | )
57 |
--------------------------------------------------------------------------------
/tests/algos/qlearning/test_ddpg.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos.qlearning.ddpg import DDPGConfig
6 | from d3rlpy.models import (
7 | MeanQFunctionFactory,
8 | QFunctionFactory,
9 | QRQFunctionFactory,
10 | )
11 | from d3rlpy.types import Shape
12 |
13 | from ...models.torch.model_test import DummyEncoderFactory
14 | from ...testing_utils import create_scaler_tuple
15 | from .algo_test import algo_tester
16 |
17 |
18 | @pytest.mark.parametrize(
19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
20 | )
21 | @pytest.mark.parametrize(
22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
23 | )
24 | @pytest.mark.parametrize("scalers", [None, "min_max"])
25 | def test_ddpg(
26 | observation_shape: Shape,
27 | q_func_factory: QFunctionFactory,
28 | scalers: Optional[str],
29 | ) -> None:
30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
31 | scalers, observation_shape
32 | )
33 | config = DDPGConfig(
34 | actor_encoder_factory=DummyEncoderFactory(),
35 | critic_encoder_factory=DummyEncoderFactory(),
36 | q_func_factory=q_func_factory,
37 | observation_scaler=observation_scaler,
38 | action_scaler=action_scaler,
39 | reward_scaler=reward_scaler,
40 | )
41 | ddpg = config.create()
42 | algo_tester(ddpg, observation_shape) # type: ignore
43 |
--------------------------------------------------------------------------------
/tests/algos/qlearning/test_dqn.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos.qlearning.dqn import DoubleDQNConfig, DQNConfig
6 | from d3rlpy.models import (
7 | MeanQFunctionFactory,
8 | QFunctionFactory,
9 | QRQFunctionFactory,
10 | )
11 | from d3rlpy.types import Shape
12 |
13 | from ...models.torch.model_test import DummyEncoderFactory
14 | from ...testing_utils import create_scaler_tuple
15 | from .algo_test import algo_tester
16 |
17 |
18 | @pytest.mark.parametrize(
19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
20 | )
21 | @pytest.mark.parametrize("n_critics", [1])
22 | @pytest.mark.parametrize(
23 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
24 | )
25 | @pytest.mark.parametrize("scalers", [None, "min_max"])
26 | def test_dqn(
27 | observation_shape: Shape,
28 | n_critics: int,
29 | q_func_factory: QFunctionFactory,
30 | scalers: Optional[str],
31 | ) -> None:
32 | observation_scaler, _, reward_scaler = create_scaler_tuple(
33 | scalers, observation_shape
34 | )
35 | config = DQNConfig(
36 | encoder_factory=DummyEncoderFactory(),
37 | n_critics=n_critics,
38 | q_func_factory=q_func_factory,
39 | observation_scaler=observation_scaler,
40 | reward_scaler=reward_scaler,
41 | )
42 | dqn = config.create()
43 | algo_tester(
44 | dqn, # type: ignore
45 | observation_shape,
46 | test_policy_copy=False,
47 | test_policy_optim_copy=False,
48 | )
49 |
50 |
51 | @pytest.mark.parametrize(
52 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
53 | )
54 | @pytest.mark.parametrize("n_critics", [1])
55 | @pytest.mark.parametrize(
56 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
57 | )
58 | @pytest.mark.parametrize("scalers", [None, "min_max"])
59 | def test_double_dqn(
60 | observation_shape: Shape,
61 | n_critics: int,
62 | q_func_factory: QFunctionFactory,
63 | scalers: Optional[str],
64 | ) -> None:
65 | observation_scaler, _, reward_scaler = create_scaler_tuple(
66 | scalers, observation_shape
67 | )
68 | config = DoubleDQNConfig(
69 | encoder_factory=DummyEncoderFactory(),
70 | n_critics=n_critics,
71 | q_func_factory=q_func_factory,
72 | observation_scaler=observation_scaler,
73 | reward_scaler=reward_scaler,
74 | )
75 | double_dqn = config.create()
76 | algo_tester(
77 | double_dqn, # type: ignore
78 | observation_shape,
79 | test_policy_copy=False,
80 | test_policy_optim_copy=False,
81 | )
82 |
--------------------------------------------------------------------------------
/tests/algos/qlearning/test_iql.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos.qlearning.iql import IQLConfig
6 | from d3rlpy.types import Shape
7 |
8 | from ...models.torch.model_test import DummyEncoderFactory
9 | from ...testing_utils import create_scaler_tuple
10 | from .algo_test import algo_tester
11 |
12 |
13 | @pytest.mark.parametrize(
14 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
15 | )
16 | @pytest.mark.parametrize("scalers", [None, "min_max"])
17 | def test_iql(observation_shape: Shape, scalers: Optional[str]) -> None:
18 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
19 | scalers, observation_shape
20 | )
21 | config = IQLConfig(
22 | actor_encoder_factory=DummyEncoderFactory(),
23 | critic_encoder_factory=DummyEncoderFactory(),
24 | value_encoder_factory=DummyEncoderFactory(),
25 | observation_scaler=observation_scaler,
26 | action_scaler=action_scaler,
27 | reward_scaler=reward_scaler,
28 | )
29 | iql = config.create()
30 | algo_tester(iql, observation_shape) # type: ignore
31 |
--------------------------------------------------------------------------------
/tests/algos/qlearning/test_nfq.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos.qlearning.nfq import NFQConfig
6 | from d3rlpy.models import (
7 | MeanQFunctionFactory,
8 | QFunctionFactory,
9 | QRQFunctionFactory,
10 | )
11 | from d3rlpy.types import Shape
12 |
13 | from ...models.torch.model_test import DummyEncoderFactory
14 | from ...testing_utils import create_scaler_tuple
15 | from .algo_test import algo_tester
16 |
17 |
18 | @pytest.mark.parametrize(
19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
20 | )
21 | @pytest.mark.parametrize("n_critics", [1])
22 | @pytest.mark.parametrize(
23 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
24 | )
25 | @pytest.mark.parametrize("scalers", [None, "min_max"])
26 | def test_nfq(
27 | observation_shape: Shape,
28 | n_critics: int,
29 | q_func_factory: QFunctionFactory,
30 | scalers: Optional[str],
31 | ) -> None:
32 | observation_scaler, _, reward_scaler = create_scaler_tuple(
33 | scalers, observation_shape
34 | )
35 | config = NFQConfig(
36 | encoder_factory=DummyEncoderFactory(),
37 | n_critics=n_critics,
38 | q_func_factory=q_func_factory,
39 | observation_scaler=observation_scaler,
40 | reward_scaler=reward_scaler,
41 | )
42 | nfq = config.create()
43 | algo_tester(
44 | nfq, # type: ignore
45 | observation_shape,
46 | test_policy_copy=False,
47 | test_policy_optim_copy=False,
48 | )
49 |
--------------------------------------------------------------------------------
/tests/algos/qlearning/test_plas.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos.qlearning.plas import PLASConfig, PLASWithPerturbationConfig
6 | from d3rlpy.models import (
7 | MeanQFunctionFactory,
8 | QFunctionFactory,
9 | QRQFunctionFactory,
10 | )
11 | from d3rlpy.types import Shape
12 |
13 | from ...models.torch.model_test import DummyEncoderFactory
14 | from ...testing_utils import create_scaler_tuple
15 | from .algo_test import algo_tester
16 |
17 |
18 | @pytest.mark.parametrize(
19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
20 | )
21 | @pytest.mark.parametrize(
22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
23 | )
24 | @pytest.mark.parametrize("scalers", [None, "min_max"])
25 | def test_plas(
26 | observation_shape: Shape,
27 | q_func_factory: QFunctionFactory,
28 | scalers: Optional[str],
29 | ) -> None:
30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
31 | scalers, observation_shape
32 | )
33 | config = PLASConfig(
34 | actor_encoder_factory=DummyEncoderFactory(),
35 | critic_encoder_factory=DummyEncoderFactory(),
36 | imitator_encoder_factory=DummyEncoderFactory(),
37 | q_func_factory=q_func_factory,
38 | observation_scaler=observation_scaler,
39 | action_scaler=action_scaler,
40 | reward_scaler=reward_scaler,
41 | warmup_steps=0,
42 | )
43 | plas = config.create()
44 | algo_tester(plas, observation_shape, test_policy_copy=False) # type: ignore
45 |
46 |
47 | @pytest.mark.parametrize(
48 | "observation_shape", [(100,), (4, 84, 84), ((100,), (200,))]
49 | )
50 | @pytest.mark.parametrize(
51 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
52 | )
53 | @pytest.mark.parametrize("scalers", [None, "min_max"])
54 | def test_plas_with_perturbation(
55 | observation_shape: Shape,
56 | q_func_factory: QFunctionFactory,
57 | scalers: Optional[str],
58 | ) -> None:
59 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
60 | scalers, observation_shape
61 | )
62 | config = PLASWithPerturbationConfig(
63 | actor_encoder_factory=DummyEncoderFactory(),
64 | critic_encoder_factory=DummyEncoderFactory(),
65 | imitator_encoder_factory=DummyEncoderFactory(),
66 | q_func_factory=q_func_factory,
67 | observation_scaler=observation_scaler,
68 | action_scaler=action_scaler,
69 | reward_scaler=reward_scaler,
70 | warmup_steps=0,
71 | )
72 | plas = config.create()
73 | algo_tester(plas, observation_shape, test_policy_copy=False) # type: ignore
74 |
--------------------------------------------------------------------------------
/tests/algos/qlearning/test_prdc.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos.qlearning.prdc import PRDCConfig
6 | from d3rlpy.models import (
7 | MeanQFunctionFactory,
8 | QFunctionFactory,
9 | QRQFunctionFactory,
10 | )
11 | from d3rlpy.types import Shape
12 |
13 | from ...models.torch.model_test import DummyEncoderFactory
14 | from ...testing_utils import create_scaler_tuple
15 | from .algo_test import algo_tester
16 |
17 |
18 | @pytest.mark.parametrize("observation_shape", [(100,), (17,)])
19 | @pytest.mark.parametrize(
20 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
21 | )
22 | @pytest.mark.parametrize("scalers", [None, "min_max"])
23 | def test_prdc(
24 | observation_shape: Shape,
25 | q_func_factory: QFunctionFactory,
26 | scalers: Optional[str],
27 | ) -> None:
28 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
29 | scalers, observation_shape
30 | )
31 | config = PRDCConfig(
32 | actor_encoder_factory=DummyEncoderFactory(),
33 | critic_encoder_factory=DummyEncoderFactory(),
34 | q_func_factory=q_func_factory,
35 | observation_scaler=observation_scaler,
36 | action_scaler=action_scaler,
37 | reward_scaler=reward_scaler,
38 | )
39 | prdc = config.create()
40 | algo_tester(prdc, observation_shape) # type: ignore
41 |
--------------------------------------------------------------------------------
/tests/algos/qlearning/test_random_policy.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | import numpy as np
4 | import pytest
5 |
6 | from d3rlpy.algos.qlearning.random_policy import (
7 | DiscreteRandomPolicyConfig,
8 | RandomPolicyConfig,
9 | )
10 |
11 |
12 | @pytest.mark.parametrize("distribution", ["uniform", "normal"])
13 | @pytest.mark.parametrize("action_size", [4])
14 | @pytest.mark.parametrize("batch_size", [32])
15 | @pytest.mark.parametrize("observation_shape", [(100,)])
16 | def test_random_policy(
17 | distribution: str,
18 | action_size: int,
19 | batch_size: int,
20 | observation_shape: Sequence[int],
21 | ) -> None:
22 | config = RandomPolicyConfig(distribution=distribution)
23 | algo = config.create()
24 | algo.create_impl(observation_shape, action_size)
25 |
26 | x = np.random.random((batch_size, *observation_shape))
27 |
28 | # check predict
29 | action = algo.predict(x)
30 | assert action.shape == (batch_size, action_size)
31 |
32 | # check sample_action
33 | action = algo.sample_action(x)
34 | assert action.shape == (batch_size, action_size)
35 |
36 |
37 | @pytest.mark.parametrize("action_size", [4])
38 | @pytest.mark.parametrize("batch_size", [32])
39 | @pytest.mark.parametrize("observation_shape", [(100,)])
40 | def test_discrete_random_policy(
41 | action_size: int, batch_size: int, observation_shape: Sequence[int]
42 | ) -> None:
43 | algo = DiscreteRandomPolicyConfig().create()
44 | algo.create_impl(observation_shape, action_size)
45 |
46 | x = np.random.random((batch_size, *observation_shape))
47 |
48 | # check predict
49 | action = algo.predict(x)
50 | assert action.shape == (batch_size,)
51 | assert np.all(action < action_size)
52 |
53 | # check sample_action
54 | action = algo.sample_action(x)
55 | assert action.shape == (batch_size,)
56 | assert np.all(action < action_size)
57 |
--------------------------------------------------------------------------------
/tests/algos/qlearning/test_rebrac.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos.qlearning.rebrac import ReBRACConfig
6 | from d3rlpy.models import (
7 | MeanQFunctionFactory,
8 | QFunctionFactory,
9 | QRQFunctionFactory,
10 | )
11 | from d3rlpy.types import Shape
12 |
13 | from ...models.torch.model_test import DummyEncoderFactory
14 | from ...testing_utils import create_scaler_tuple
15 | from .algo_test import algo_tester
16 |
17 |
18 | @pytest.mark.parametrize(
19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
20 | )
21 | @pytest.mark.parametrize(
22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
23 | )
24 | @pytest.mark.parametrize("scalers", [None, "min_max"])
25 | def test_rebrac(
26 | observation_shape: Shape,
27 | q_func_factory: QFunctionFactory,
28 | scalers: Optional[str],
29 | ) -> None:
30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
31 | scalers, observation_shape
32 | )
33 | config = ReBRACConfig(
34 | actor_encoder_factory=DummyEncoderFactory(),
35 | critic_encoder_factory=DummyEncoderFactory(),
36 | q_func_factory=q_func_factory,
37 | observation_scaler=observation_scaler,
38 | action_scaler=action_scaler,
39 | reward_scaler=reward_scaler,
40 | )
41 | rebrac = config.create()
42 | algo_tester(rebrac, observation_shape) # type: ignore
43 |
--------------------------------------------------------------------------------
/tests/algos/qlearning/test_sac.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos.qlearning.sac import DiscreteSACConfig, SACConfig
6 | from d3rlpy.models import (
7 | MeanQFunctionFactory,
8 | QFunctionFactory,
9 | QRQFunctionFactory,
10 | )
11 | from d3rlpy.types import Shape
12 |
13 | from ...models.torch.model_test import DummyEncoderFactory
14 | from ...testing_utils import create_scaler_tuple
15 | from .algo_test import algo_tester
16 |
17 |
18 | @pytest.mark.parametrize(
19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
20 | )
21 | @pytest.mark.parametrize(
22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
23 | )
24 | @pytest.mark.parametrize("scalers", [None, "min_max"])
25 | def test_sac(
26 | observation_shape: Shape,
27 | q_func_factory: QFunctionFactory,
28 | scalers: Optional[str],
29 | ) -> None:
30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
31 | scalers, observation_shape
32 | )
33 | config = SACConfig(
34 | actor_encoder_factory=DummyEncoderFactory(),
35 | critic_encoder_factory=DummyEncoderFactory(),
36 | q_func_factory=q_func_factory,
37 | observation_scaler=observation_scaler,
38 | action_scaler=action_scaler,
39 | reward_scaler=reward_scaler,
40 | )
41 | sac = config.create()
42 | algo_tester(sac, observation_shape) # type: ignore
43 |
44 |
45 | @pytest.mark.parametrize(
46 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
47 | )
48 | @pytest.mark.parametrize(
49 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
50 | )
51 | @pytest.mark.parametrize("scalers", [None, "min_max"])
52 | def test_discrete_sac(
53 | observation_shape: Shape,
54 | q_func_factory: QFunctionFactory,
55 | scalers: Optional[str],
56 | ) -> None:
57 | observation_scaler, _, reward_scaler = create_scaler_tuple(
58 | scalers, observation_shape
59 | )
60 | config = DiscreteSACConfig(
61 | actor_encoder_factory=DummyEncoderFactory(),
62 | critic_encoder_factory=DummyEncoderFactory(),
63 | q_func_factory=q_func_factory,
64 | observation_scaler=observation_scaler,
65 | reward_scaler=reward_scaler,
66 | )
67 | sac = config.create()
68 | algo_tester(sac, observation_shape, action_size=100) # type: ignore
69 |
--------------------------------------------------------------------------------
/tests/algos/qlearning/test_td3.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos.qlearning.td3 import TD3Config
6 | from d3rlpy.models import (
7 | MeanQFunctionFactory,
8 | QFunctionFactory,
9 | QRQFunctionFactory,
10 | )
11 | from d3rlpy.types import Shape
12 |
13 | from ...models.torch.model_test import DummyEncoderFactory
14 | from ...testing_utils import create_scaler_tuple
15 | from .algo_test import algo_tester
16 |
17 |
18 | @pytest.mark.parametrize(
19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
20 | )
21 | @pytest.mark.parametrize(
22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
23 | )
24 | @pytest.mark.parametrize("scalers", [None, "min_max"])
25 | def test_td3(
26 | observation_shape: Shape,
27 | q_func_factory: QFunctionFactory,
28 | scalers: Optional[str],
29 | ) -> None:
30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
31 | scalers, observation_shape
32 | )
33 | config = TD3Config(
34 | actor_encoder_factory=DummyEncoderFactory(),
35 | critic_encoder_factory=DummyEncoderFactory(),
36 | q_func_factory=q_func_factory,
37 | observation_scaler=observation_scaler,
38 | action_scaler=action_scaler,
39 | reward_scaler=reward_scaler,
40 | )
41 | td3 = config.create()
42 | algo_tester(td3, observation_shape) # type: ignore
43 |
--------------------------------------------------------------------------------
/tests/algos/qlearning/test_td3_plus_bc.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos.qlearning.td3_plus_bc import TD3PlusBCConfig
6 | from d3rlpy.models import (
7 | MeanQFunctionFactory,
8 | QFunctionFactory,
9 | QRQFunctionFactory,
10 | )
11 | from d3rlpy.types import Shape
12 |
13 | from ...models.torch.model_test import DummyEncoderFactory
14 | from ...testing_utils import create_scaler_tuple
15 | from .algo_test import algo_tester
16 |
17 |
18 | @pytest.mark.parametrize(
19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))]
20 | )
21 | @pytest.mark.parametrize(
22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
23 | )
24 | @pytest.mark.parametrize("scalers", [None, "min_max"])
25 | def test_td3_plus_bc(
26 | observation_shape: Shape,
27 | q_func_factory: QFunctionFactory,
28 | scalers: Optional[str],
29 | ) -> None:
30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
31 | scalers, observation_shape
32 | )
33 | config = TD3PlusBCConfig(
34 | actor_encoder_factory=DummyEncoderFactory(),
35 | critic_encoder_factory=DummyEncoderFactory(),
36 | q_func_factory=q_func_factory,
37 | observation_scaler=observation_scaler,
38 | action_scaler=action_scaler,
39 | reward_scaler=reward_scaler,
40 | )
41 | td3 = config.create()
42 | algo_tester(td3, observation_shape) # type: ignore
43 |
--------------------------------------------------------------------------------
/tests/algos/qlearning/torch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/algos/qlearning/torch/__init__.py
--------------------------------------------------------------------------------
/tests/algos/qlearning/torch/test_utility.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from d3rlpy.algos.qlearning.torch.utility import sample_q_values_with_policy
4 | from d3rlpy.models.builders import (
5 | create_continuous_q_function,
6 | create_normal_policy,
7 | )
8 | from d3rlpy.models.q_functions import MeanQFunctionFactory
9 | from d3rlpy.types import Shape
10 |
11 | from ....models.torch.model_test import DummyEncoderFactory
12 | from ....testing_utils import create_torch_observations
13 |
14 |
15 | @pytest.mark.parametrize(
16 | "observation_shape", [(100,), (4, 84, 84), ((100,), (200,))]
17 | )
18 | @pytest.mark.parametrize("action_size", [4])
19 | @pytest.mark.parametrize("n_action_samples", [10])
20 | @pytest.mark.parametrize("batch_size", [256])
21 | @pytest.mark.parametrize("n_critics", [2])
22 | def test_sample_q_values_with_policy(
23 | observation_shape: Shape,
24 | action_size: int,
25 | n_action_samples: int,
26 | batch_size: int,
27 | n_critics: int,
28 | ) -> None:
29 | policy = create_normal_policy(
30 | observation_shape=observation_shape,
31 | action_size=action_size,
32 | encoder_factory=DummyEncoderFactory(),
33 | device="cpu:0",
34 | enable_ddp=False,
35 | )
36 | _, q_func_forwarder = create_continuous_q_function(
37 | observation_shape=observation_shape,
38 | action_size=action_size,
39 | encoder_factory=DummyEncoderFactory(),
40 | q_func_factory=MeanQFunctionFactory(),
41 | n_ensembles=n_critics,
42 | device="cpu:0",
43 | enable_ddp=False,
44 | )
45 |
46 | observations = create_torch_observations(observation_shape, batch_size)
47 |
48 | q_values, log_probs = sample_q_values_with_policy(
49 | policy=policy,
50 | q_func_forwarder=q_func_forwarder,
51 | policy_observations=observations,
52 | value_observations=observations,
53 | n_action_samples=n_action_samples,
54 | detach_policy_output=False,
55 | )
56 | assert q_values.shape == (n_critics, batch_size, n_action_samples)
57 | assert log_probs.shape == (1, batch_size, n_action_samples)
58 |
--------------------------------------------------------------------------------
/tests/algos/transformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/algos/transformer/__init__.py
--------------------------------------------------------------------------------
/tests/algos/transformer/test_action_samplers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from d3rlpy.algos import (
5 | GreedyTransformerActionSampler,
6 | IdentityTransformerActionSampler,
7 | SoftmaxTransformerActionSampler,
8 | )
9 |
10 |
11 | @pytest.mark.parametrize("action_size", [4])
12 | def test_identity_transformer_action_sampler(action_size: int) -> None:
13 | action_sampler = IdentityTransformerActionSampler()
14 |
15 | x = np.random.random(action_size)
16 | action = action_sampler(x)
17 |
18 | assert np.all(action == x)
19 |
20 |
21 | @pytest.mark.parametrize("action_size", [10])
22 | def test_softmax_transformer_action_sampler(action_size: int) -> None:
23 | action_sampler = SoftmaxTransformerActionSampler()
24 |
25 | logits = np.random.random(action_size)
26 | action = action_sampler(logits)
27 | assert isinstance(action, int)
28 |
29 | same_actions = []
30 | for _ in range(100):
31 | same_actions.append(action == action_sampler(logits))
32 | assert not all(same_actions)
33 |
34 |
35 | @pytest.mark.parametrize("action_size", [10])
36 | def test_greedy_transformer_action_sampler(action_size: int) -> None:
37 | action_sampler = GreedyTransformerActionSampler()
38 |
39 | logits = np.random.random(action_size)
40 | action = action_sampler(logits)
41 | assert isinstance(action, int)
42 | assert action == np.argmax(logits)
43 |
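The softmax sampler test above only checks that repeated sampling from the same logits is stochastic, while the greedy sampler returns the argmax. A hypothetical sketch of the softmax-sampling behaviour it relies on (illustrative only, not d3rlpy's `SoftmaxTransformerActionSampler`):

import numpy as np


def softmax_sample(logits: np.ndarray, rng: np.random.Generator) -> int:
    # Turn logits into a categorical distribution and draw one action index.
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()
    return int(rng.choice(len(logits), p=probs))


rng = np.random.default_rng(0)
logits = rng.random(10)
samples = {softmax_sample(logits, rng) for _ in range(100)}
assert len(samples) > 1  # stochastic, unlike the greedy argmax sampler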
--------------------------------------------------------------------------------
/tests/algos/transformer/test_decision_transformer.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos import (
6 | DecisionTransformerConfig,
7 | DiscreteDecisionTransformerConfig,
8 | )
9 | from d3rlpy.types import Shape
10 |
11 | from ...models.torch.model_test import DummyEncoderFactory
12 | from ...testing_utils import create_scaler_tuple
13 | from .algo_test import algo_tester
14 |
15 |
16 | @pytest.mark.parametrize(
17 | "observation_shape", [(100,), (4, 8, 8), ((100,), (200,))]
18 | )
19 | @pytest.mark.parametrize("scalers", [None, "min_max"])
20 | def test_decision_transformer(
21 | observation_shape: Shape,
22 | scalers: Optional[str],
23 | ) -> None:
24 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
25 | scalers, observation_shape
26 | )
27 | config = DecisionTransformerConfig(
28 | encoder_factory=DummyEncoderFactory(),
29 | observation_scaler=observation_scaler,
30 | action_scaler=action_scaler,
31 | reward_scaler=reward_scaler,
32 | )
33 | dt = config.create()
34 | algo_tester(
35 | dt, # type: ignore
36 | observation_shape,
37 | )
38 |
39 |
40 | @pytest.mark.parametrize(
41 | "observation_shape", [(100,), (4, 8, 8), ((100,), (200,))]
42 | )
43 | @pytest.mark.parametrize("scalers", [None, "min_max"])
44 | def test_discrete_decision_transformer(
45 | observation_shape: Shape,
46 | scalers: Optional[str],
47 | ) -> None:
48 | observation_scaler, _, reward_scaler = create_scaler_tuple(
49 | scalers, observation_shape
50 | )
51 | config = DiscreteDecisionTransformerConfig(
52 | encoder_factory=DummyEncoderFactory(),
53 | observation_scaler=observation_scaler,
54 | reward_scaler=reward_scaler,
55 | num_heads=4,
56 | )
57 | dt = config.create()
58 | algo_tester(
59 | dt, # type: ignore
60 | observation_shape,
61 | action_size=10,
62 | )
63 |
--------------------------------------------------------------------------------
/tests/algos/transformer/test_tacr.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import pytest
4 |
5 | from d3rlpy.algos import TACRConfig
6 | from d3rlpy.types import Shape
7 |
8 | from ...models.torch.model_test import DummyEncoderFactory
9 | from ...testing_utils import create_scaler_tuple
10 | from .algo_test import algo_tester
11 |
12 |
13 | @pytest.mark.parametrize(
14 | "observation_shape", [(100,), (4, 8, 8), ((100,), (200,))]
15 | )
16 | @pytest.mark.parametrize("scalers", [None, "min_max"])
17 | def test_tacr(observation_shape: Shape, scalers: Optional[str]) -> None:
18 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple(
19 | scalers, observation_shape
20 | )
21 | config = TACRConfig(
22 | actor_encoder_factory=DummyEncoderFactory(),
23 | critic_encoder_factory=DummyEncoderFactory(),
24 | observation_scaler=observation_scaler,
25 | action_scaler=action_scaler,
26 | reward_scaler=reward_scaler,
27 | )
28 | tacr = config.create()
29 | algo_tester(
30 | tacr, # type: ignore
31 | observation_shape,
32 | )
33 |
--------------------------------------------------------------------------------
/tests/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/dataset/__init__.py
--------------------------------------------------------------------------------
/tests/dataset/test_buffers.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | import pytest
4 |
5 | from d3rlpy.dataset import FIFOBuffer, InfiniteBuffer
6 |
7 | from ..testing_utils import create_episode
8 |
9 |
10 | @pytest.mark.parametrize("observation_shape", [(4,)])
11 | @pytest.mark.parametrize("action_size", [2])
12 | @pytest.mark.parametrize("length", [100])
13 | @pytest.mark.parametrize("terminated", [False, True])
14 | def test_infinite_buffer(
15 | observation_shape: Sequence[int],
16 | action_size: int,
17 | length: int,
18 | terminated: bool,
19 | ) -> None:
20 | buffer = InfiniteBuffer()
21 |
22 | for i in range(10):
23 | episode = create_episode(
24 | observation_shape, action_size, length, terminated=terminated
25 | )
26 | for j in range(episode.transition_count):
27 | buffer.append(episode, j)
28 |
29 | if terminated:
30 | assert buffer.transition_count == (i + 1) * (length)
31 | else:
32 | assert buffer.transition_count == (i + 1) * (length - 1)
33 | assert len(buffer.episodes) == i + 1
34 |
35 |
36 | @pytest.mark.parametrize("observation_shape", [(4,)])
37 | @pytest.mark.parametrize("action_size", [2])
38 | @pytest.mark.parametrize("length", [100])
39 | @pytest.mark.parametrize("limit", [500])
40 | @pytest.mark.parametrize("terminated", [False, True])
41 | def test_fifo_buffer(
42 | observation_shape: Sequence[int],
43 | action_size: int,
44 | length: int,
45 | limit: int,
46 | terminated: bool,
47 | ) -> None:
48 | buffer = FIFOBuffer(limit)
49 |
50 | for i in range(10):
51 | episode = create_episode(
52 | observation_shape, action_size, length, terminated=terminated
53 | )
54 | for j in range(episode.transition_count):
55 | buffer.append(episode, j)
56 |
57 | if i >= 5:
58 | assert buffer.transition_count == limit
59 | if terminated:
60 | assert len(buffer.episodes) == 5
61 | else:
62 | assert len(buffer.episodes) == 6
63 | else:
64 | if terminated:
65 | assert buffer.transition_count == (i + 1) * length
66 | else:
67 | assert buffer.transition_count == (i + 1) * (length - 1)
68 | assert len(buffer.episodes) == i + 1
69 |
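The 5-versus-6 episode counts asserted in `test_fifo_buffer` above follow from arithmetic on the buffer limit; a short worked check of those numbers (plain Python, independent of d3rlpy):

# Worked numbers behind the FIFO assertions above (limit=500, length=100).
limit, length = 500, 100

# A terminated episode stores `length` transitions, so exactly 5 whole episodes
# remain once the buffer is full.
assert limit // length == 5

# A truncated episode stores `length - 1` transitions, so the 500 retained
# transitions span parts of 6 episodes: the 5 newest cover only 495 of them and
# the remaining 5 belong to a sixth, older episode.
assert 5 * (length - 1) < limit <= 6 * (length - 1)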
--------------------------------------------------------------------------------
/tests/dataset/test_compat.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | import numpy as np
4 | import pytest
5 |
6 | from d3rlpy.dataset import MDPDataset, ReplayBuffer
7 |
8 | from ..testing_utils import create_observation
9 |
10 |
11 | @pytest.mark.parametrize("observation_shape", [(4,)])
12 | @pytest.mark.parametrize("action_size", [2])
13 | @pytest.mark.parametrize("length", [100])
14 | @pytest.mark.parametrize("num_episodes", [10])
15 | def test_replay_buffer(
16 | observation_shape: Sequence[int],
17 | action_size: int,
18 | length: int,
19 | num_episodes: int,
20 | ) -> None:
21 | observations = []
22 | actions = []
23 | rewards = []
24 | terminals = []
25 | for _ in range(num_episodes):
26 | for i in range(length):
27 | observations.append(create_observation(observation_shape))
28 | actions.append(np.random.random(action_size))
29 | rewards.append(np.random.random())
30 | terminals.append(float(i == length - 1))
31 |
32 | dataset = MDPDataset(
33 | observations=np.array(observations),
34 | actions=np.array(actions),
35 | rewards=np.array(rewards),
36 | terminals=np.array(terminals),
37 | )
38 |
39 | assert isinstance(dataset, ReplayBuffer)
40 | assert len(dataset.episodes) == num_episodes
41 |
--------------------------------------------------------------------------------
/tests/dataset/test_episode_generator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from d3rlpy.dataset import EpisodeGenerator
5 | from d3rlpy.types import Float32NDArray, Shape
6 |
7 | from ..testing_utils import create_observations
8 |
9 |
10 | @pytest.mark.parametrize("observation_shape", [(4,), ((4,), (8,))])
11 | @pytest.mark.parametrize("action_size", [2])
12 | @pytest.mark.parametrize("length", [1000])
13 | @pytest.mark.parametrize("terminal", [False, True])
14 | def test_episode_generator(
15 | observation_shape: Shape, action_size: int, length: int, terminal: bool
16 | ) -> None:
17 | observations = create_observations(observation_shape, length)
18 | actions = np.random.random((length, action_size))
19 | rewards: Float32NDArray = np.random.random((length, 1)).astype(np.float32)
20 | terminals: Float32NDArray = np.zeros(length, dtype=np.float32)
21 | timeouts: Float32NDArray = np.zeros(length, dtype=np.float32)
22 | for i in range(length // 100):
23 | if terminal:
24 | terminals[(i + 1) * 100 - 1] = 1.0
25 | else:
26 | timeouts[(i + 1) * 100 - 1] = 1.0
27 |
28 | episode_generator = EpisodeGenerator(
29 | observations=observations,
30 | actions=actions,
31 | rewards=rewards,
32 | terminals=terminals,
33 | timeouts=timeouts,
34 | )
35 |
36 | episodes = episode_generator()
37 | assert len(episodes) == length // 100
38 |
39 | for episode in episodes:
40 | assert len(episode) == 100
41 | if isinstance(observation_shape[0], tuple):
42 | for i, shape in enumerate(observation_shape):
43 | assert isinstance(shape, tuple)
44 | assert episode.observations[i].shape == (100, *shape)
45 | else:
46 | assert isinstance(episode.observations, np.ndarray)
47 | assert episode.observations.shape == (100, *observation_shape)
48 | assert episode.actions.shape == (100, action_size)
49 | assert episode.rewards.shape == (100, 1)
50 | assert episode.terminated == terminal
51 |
--------------------------------------------------------------------------------
/tests/dataset/test_io.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import pytest
5 |
6 | from d3rlpy.dataset import Episode, dump, load
7 | from d3rlpy.types import Shape
8 |
9 | from ..testing_utils import create_episode
10 |
11 |
12 | @pytest.mark.parametrize("observation_shape", [(4,), ((4,), (8,))])
13 | @pytest.mark.parametrize("action_size", [2])
14 | @pytest.mark.parametrize("length", [100])
15 | def test_dump_and_load(
16 | observation_shape: Shape, action_size: int, length: int
17 | ) -> None:
18 | episode1 = create_episode(observation_shape, action_size, length)
19 | episode2 = create_episode(observation_shape, action_size, length * 2)
20 |
21 | path = os.path.join("test_data", "data.h5")
22 |
23 | # dump
24 | with open(path, "w+b") as f:
25 | dump([episode1, episode2], f)
26 |
27 | # load
28 | with open(path, "rb") as f:
29 | loaded_episodes = load(Episode, f)
30 | assert len(loaded_episodes) == 2
31 |
32 | for episode, loaded_episode in zip([episode1, episode2], loaded_episodes):
33 | if isinstance(observation_shape[0], tuple):
34 | for i in range(len(observation_shape)):
35 | assert np.all(
36 | episode.observations[i] == loaded_episode.observations[i]
37 | )
38 | else:
39 | assert np.all(episode.observations == loaded_episode.observations)
40 | assert np.all(episode.actions == loaded_episode.actions)
41 | assert np.all(episode.rewards == loaded_episode.rewards)
42 | assert np.all(episode.terminated == loaded_episode.terminated)
43 |
--------------------------------------------------------------------------------
/tests/dummy_env.py:
--------------------------------------------------------------------------------
1 | from typing import Any
2 |
3 | import gym
4 | import numpy as np
5 | from gym.spaces import Box, Discrete
6 |
7 | from d3rlpy.types import NDArray
8 |
9 |
10 | class DummyAtari(gym.Env[NDArray, int]):
11 | def __init__(self, grayscale: bool = True, squeeze: bool = False):
12 | if grayscale:
13 | shape = (84, 84) if squeeze else (84, 84, 1)
14 | else:
15 | shape = (84, 84, 3)
16 | self.observation_space = Box(
17 | low=np.zeros(shape),
18 | high=np.zeros(shape) + 255,
19 | dtype=np.uint8,
20 | )
21 | self.action_space = Discrete(4)
22 | self.t = 1
23 |
24 | def step(
25 | self, action: int
26 | ) -> tuple[NDArray, float, bool, bool, dict[str, Any]]:
27 | observation = self.observation_space.sample()
28 |         reward = np.random.random()
29 |         self.t += 1
30 |         return observation, reward, False, self.t % 80 == 0, {}
31 |
32 |     def reset(self, **kwargs: Any) -> tuple[NDArray, dict[str, Any]]:
33 |         self.t = 1
34 |         return self.observation_space.sample(), {}
35 |
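A brief usage sketch of `DummyAtari` under the five-tuple `step` API used by these tests; the `tests.dummy_env` import path is an assumption about how the test package is importable, and the episode only ever ends by truncation through the internal step counter:

# Usage sketch (assumes the tests package is importable as `tests`).
from tests.dummy_env import DummyAtari

env = DummyAtari(grayscale=True, squeeze=False)
obs, info = env.reset()
assert obs.shape == (84, 84, 1)  # grayscale without squeeze keeps the channel axis

terminated = truncated = False
while not (terminated or truncated):
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())

assert truncated and not terminated  # episodes end only by truncation here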
--------------------------------------------------------------------------------
/tests/envs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/envs/__init__.py
--------------------------------------------------------------------------------
/tests/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/metrics/__init__.py
--------------------------------------------------------------------------------
/tests/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/models/__init__.py
--------------------------------------------------------------------------------
/tests/models/test_lr_schedulers.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | from torch.optim import SGD
4 | from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
5 |
6 | from d3rlpy.optimizers.lr_schedulers import (
7 | CosineAnnealingLRFactory,
8 | WarmupSchedulerFactory,
9 | )
10 |
11 |
12 | @pytest.mark.parametrize("module", [torch.nn.Linear(2, 3)])
13 | def test_warmup_scheduler_factory(module: torch.nn.Module) -> None:
14 | factory = WarmupSchedulerFactory(warmup_steps=1000)
15 |
16 | lr_scheduler = factory.create(SGD(module.parameters(), 1e-4))
17 |
18 | assert isinstance(lr_scheduler, LambdaLR)
19 |
20 | # check serialization and deserialization
21 | WarmupSchedulerFactory.deserialize(factory.serialize())
22 |
23 |
24 | @pytest.mark.parametrize("module", [torch.nn.Linear(2, 3)])
25 | def test_cosine_annealing_lr_factory(module: torch.nn.Module) -> None:
26 | factory = CosineAnnealingLRFactory(T_max=1000)
27 |
28 | lr_scheduler = factory.create(SGD(module.parameters(), 1e-4))
29 |
30 | assert isinstance(lr_scheduler, CosineAnnealingLR)
31 |
32 | # check serialization and deserialization
33 | CosineAnnealingLRFactory.deserialize(factory.serialize())
34 |
--------------------------------------------------------------------------------
/tests/models/torch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/models/torch/__init__.py
--------------------------------------------------------------------------------
/tests/models/torch/q_functions/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/models/torch/q_functions/__init__.py
--------------------------------------------------------------------------------
/tests/models/torch/q_functions/test_utility.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 |
5 | from d3rlpy.models.torch.q_functions.utility import (
6 | compute_quantile_huber_loss,
7 | pick_quantile_value_by_action,
8 | pick_value_by_action,
9 | )
10 |
11 | from ..model_test import ref_quantile_huber_loss
12 |
13 |
14 | @pytest.mark.parametrize("batch_size", [32])
15 | @pytest.mark.parametrize("action_size", [2])
16 | @pytest.mark.parametrize("keepdims", [True, False])
17 | def test_pick_value_by_action(
18 | batch_size: int, action_size: int, keepdims: bool
19 | ) -> None:
20 | values = torch.rand(batch_size, action_size)
21 | action = torch.randint(action_size, size=(batch_size,))
22 |
23 | rets = pick_value_by_action(values, action, keepdims)
24 |
25 | if keepdims:
26 | assert rets.shape == (batch_size, 1)
27 | else:
28 | assert rets.shape == (batch_size,)
29 |
30 | rets = rets.view(batch_size, -1)
31 |
32 | for i in range(batch_size):
33 | assert (rets[i] == values[i][action[i]]).all()
34 |
35 |
36 | @pytest.mark.parametrize("batch_size", [32])
37 | @pytest.mark.parametrize("action_size", [2])
38 | @pytest.mark.parametrize("n_quantiles", [200])
39 | @pytest.mark.parametrize("keepdims", [True, False])
40 | def test_pick_quantile_value_by_action(
41 | batch_size: int,
42 | action_size: int,
43 | n_quantiles: int,
44 | keepdims: bool,
45 | ) -> None:
46 | values = torch.rand(batch_size, action_size, n_quantiles)
47 | action = torch.randint(action_size, size=(batch_size,))
48 |
49 | rets = pick_quantile_value_by_action(values, action, keepdims)
50 |
51 | if keepdims:
52 | assert rets.shape == (batch_size, 1, n_quantiles)
53 | else:
54 | assert rets.shape == (batch_size, n_quantiles)
55 |
56 | rets = rets.view(batch_size, -1)
57 |
58 | for i in range(batch_size):
59 | assert (rets[i] == values[i][action[i]]).all()
60 |
61 |
62 | @pytest.mark.parametrize("batch_size", [32])
63 | @pytest.mark.parametrize("n_quantiles", [200])
64 | def test_compute_quantile_huber_loss(batch_size: int, n_quantiles: int) -> None:
65 | y = np.random.random((batch_size, n_quantiles, 1))
66 | target = np.random.random((batch_size, 1, n_quantiles))
67 | taus = np.random.random((1, 1, n_quantiles))
68 |
69 | ref_loss = ref_quantile_huber_loss(y, target, taus, n_quantiles)
70 | loss = compute_quantile_huber_loss(
71 | torch.tensor(y), torch.tensor(target), torch.tensor(taus)
72 | )
73 |
74 | assert np.allclose(loss.cpu().detach().numpy(), ref_loss)
75 |
--------------------------------------------------------------------------------
/tests/models/torch/test_parameters.py:
--------------------------------------------------------------------------------
1 | from typing import Sequence
2 |
3 | import pytest
4 | import torch
5 |
6 | from d3rlpy.models.torch.parameters import Parameter, get_parameter
7 |
8 |
9 | @pytest.mark.parametrize("shape", [(100,)])
10 | def test_parameter(shape: Sequence[int]) -> None:
11 | data = torch.rand(shape)
12 | parameter = Parameter(data)
13 |
14 | assert get_parameter(parameter).data.shape == shape
15 | assert torch.all(get_parameter(parameter).data == data)
16 |
--------------------------------------------------------------------------------
/tests/models/torch/test_q_functions.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 |
4 | from d3rlpy.models.builders import create_continuous_q_function
5 | from d3rlpy.models.q_functions import (
6 | MeanQFunctionFactory,
7 | QFunctionFactory,
8 | QRQFunctionFactory,
9 | )
10 | from d3rlpy.models.torch.q_functions import compute_max_with_n_actions
11 | from d3rlpy.types import Shape
12 |
13 | from ...testing_utils import create_torch_observations
14 | from .model_test import DummyEncoderFactory
15 |
16 |
17 | @pytest.mark.parametrize(
18 | "observation_shape", [(4, 84, 84), (100,), ((100,), (200,))]
19 | )
20 | @pytest.mark.parametrize("action_size", [3])
21 | @pytest.mark.parametrize(
22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()]
23 | )
24 | @pytest.mark.parametrize("n_ensembles", [2])
25 | @pytest.mark.parametrize("batch_size", [100])
26 | @pytest.mark.parametrize("n_actions", [10])
27 | @pytest.mark.parametrize("lam", [0.75])
28 | def test_compute_max_with_n_actions(
29 | observation_shape: Shape,
30 | action_size: int,
31 | q_func_factory: QFunctionFactory,
32 | n_ensembles: int,
33 | batch_size: int,
34 | n_actions: int,
35 | lam: float,
36 | ) -> None:
37 | _, forwarder = create_continuous_q_function(
38 | observation_shape,
39 | action_size,
40 | DummyEncoderFactory(),
41 | q_func_factory,
42 | n_ensembles=n_ensembles,
43 | device="cpu:0",
44 | enable_ddp=False,
45 | )
46 | x = create_torch_observations(observation_shape, batch_size)
47 | actions = torch.rand(batch_size, n_actions, action_size)
48 |
49 | y = compute_max_with_n_actions(x, actions, forwarder, lam)
50 |
51 | if isinstance(q_func_factory, MeanQFunctionFactory):
52 | assert y.shape == (batch_size, 1)
53 | else:
54 | assert isinstance(q_func_factory, QRQFunctionFactory)
55 | assert y.shape == (batch_size, q_func_factory.n_quantiles)
56 |
--------------------------------------------------------------------------------
/tests/models/torch/test_v_functions.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | import torch
3 | import torch.nn.functional as F
4 |
5 | from d3rlpy.models.torch.v_functions import (
6 | ValueFunction,
7 | compute_v_function_error,
8 | )
9 | from d3rlpy.types import Shape
10 |
11 | from ...testing_utils import create_torch_observations
12 | from .model_test import DummyEncoder, check_parameter_updates
13 |
14 |
15 | @pytest.mark.parametrize("observation_shape", [(100,), ((100,), (200,))])
16 | @pytest.mark.parametrize("batch_size", [32])
17 | def test_value_function(observation_shape: Shape, batch_size: int) -> None:
18 | encoder = DummyEncoder(observation_shape)
19 | v_func = ValueFunction(encoder, encoder.get_feature_size())
20 |
21 | # check output shape
22 | x = create_torch_observations(observation_shape, batch_size)
23 | y = v_func(x)
24 | assert y.shape == (batch_size, 1)
25 |
26 | # check compute_error
27 | returns = torch.rand(batch_size, 1)
28 | loss = compute_v_function_error(v_func, x, returns)
29 | assert torch.allclose(loss, F.mse_loss(y, returns))
30 |
31 | # check layer connections
32 | check_parameter_updates(
33 | v_func,
34 | (x,),
35 | )
36 |
--------------------------------------------------------------------------------
/tests/ope/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/ope/__init__.py
--------------------------------------------------------------------------------
/tests/optimizers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/optimizers/__init__.py
--------------------------------------------------------------------------------
/tests/optimizers/test_lr_schedulers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 | from torch.optim import SGD
5 | from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
6 |
7 | from d3rlpy.optimizers.lr_schedulers import (
8 | CosineAnnealingLRFactory,
9 | WarmupSchedulerFactory,
10 | )
11 |
12 |
13 | @pytest.mark.parametrize("warmup_steps", [100])
14 | @pytest.mark.parametrize("lr", [1e-4])
15 | @pytest.mark.parametrize("module", [torch.nn.Linear(2, 3)])
16 | def test_warmup_scheduler_factory(
17 | warmup_steps: int, lr: float, module: torch.nn.Module
18 | ) -> None:
19 | factory = WarmupSchedulerFactory(warmup_steps)
20 |
21 | lr_scheduler = factory.create(SGD(module.parameters(), lr=lr))
22 |
23 | assert np.allclose(lr_scheduler.get_lr()[0], lr / warmup_steps)
24 | for _ in range(warmup_steps):
25 | lr_scheduler.step()
26 | assert lr_scheduler.get_lr()[0] == lr
27 |
28 | assert isinstance(lr_scheduler, LambdaLR)
29 |
30 | # check serialization and deserialization
31 | WarmupSchedulerFactory.deserialize(factory.serialize())
32 |
33 |
34 | @pytest.mark.parametrize("T_max", [100])
35 | @pytest.mark.parametrize("module", [torch.nn.Linear(2, 3)])
36 | def test_cosine_annealing_factory(T_max: int, module: torch.nn.Module) -> None:
37 | factory = CosineAnnealingLRFactory(T_max=T_max)
38 |
39 | lr_scheduler = factory.create(SGD(module.parameters()))
40 |
41 | assert isinstance(lr_scheduler, CosineAnnealingLR)
42 |
43 | # check serialization and deserialization
44 | CosineAnnealingLRFactory.deserialize(factory.serialize())
45 |
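The two assertions in `test_warmup_scheduler_factory` above pin down the warmup shape: the learning rate starts at `lr / warmup_steps` and reaches `lr` after `warmup_steps` steps. A plain-PyTorch sketch of a schedule with that shape (an illustration only, not d3rlpy's `WarmupSchedulerFactory` implementation):

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

warmup_steps, lr = 100, 1e-4
module = torch.nn.Linear(2, 3)
optim = SGD(module.parameters(), lr=lr)

# Linear warmup: scale the base lr by (step + 1) / warmup_steps, capped at 1.0.
scheduler = LambdaLR(optim, lambda step: min((step + 1) / warmup_steps, 1.0))

assert abs(scheduler.get_last_lr()[0] - lr / warmup_steps) < 1e-12
for _ in range(warmup_steps):
    optim.step()
    scheduler.step()
assert scheduler.get_last_lr()[0] == lr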
--------------------------------------------------------------------------------
/tests/preprocessing/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/preprocessing/__init__.py
--------------------------------------------------------------------------------
/tests/preprocessing/test_base.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from d3rlpy.preprocessing.base import add_leading_dims, add_leading_dims_numpy
5 |
6 |
7 | def test_add_leading_dims() -> None:
8 | x = torch.rand(3)
9 | target = torch.rand(1, 2, 3)
10 | assert add_leading_dims(x, target).shape == (1, 1, 3)
11 |
12 |
13 | def test_add_leading_dims_numpy() -> None:
14 | x = np.random.random(3)
15 | target = np.random.random((1, 2, 3))
16 | assert add_leading_dims_numpy(x, target).shape == (1, 1, 3)
17 |
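The shape assertions above characterize what `add_leading_dims` is expected to do: pad `x` with singleton axes on the left until it has as many dimensions as `target`. A hypothetical re-derivation for illustration (not the d3rlpy implementation):

import torch


def add_leading_dims_sketch(x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Prepend singleton axes so that x lines up with (and broadcasts against) target.
    assert x.ndim <= target.ndim
    return x.reshape((1,) * (target.ndim - x.ndim) + tuple(x.shape))


x = torch.rand(3)
target = torch.rand(1, 2, 3)
assert add_leading_dims_sketch(x, target).shape == (1, 1, 3)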
--------------------------------------------------------------------------------
/tests/test_dataclass_utils.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 |
3 | import torch
4 |
5 | from d3rlpy.dataclass_utils import asdict_as_float, asdict_without_copy
6 |
7 |
8 | @dataclasses.dataclass(frozen=True)
9 | class A:
10 | a: int
11 |
12 |
13 | @dataclasses.dataclass(frozen=True)
14 | class D:
15 | a: A
16 | b: float
17 | c: str
18 |
19 |
20 | def test_asdict_without_copy() -> None:
21 | a = A(1)
22 | d = D(a, 2.0, "3")
23 | dict_d = asdict_without_copy(d)
24 | assert dict_d["a"] is a
25 | assert dict_d["b"] == 2.0
26 | assert dict_d["c"] == "3"
27 |
28 |
29 | @dataclasses.dataclass(frozen=True)
30 | class D2:
31 | a: float
32 | b: torch.Tensor
33 |
34 |
35 | def test_asdict_as_float() -> None:
36 | b = torch.rand([], dtype=torch.float32)
37 | d = D2(a=1.0, b=b)
38 | dict_d = asdict_as_float(d)
39 | assert dict_d["a"] == 1.0
40 | assert dict_d["b"] == b.numpy()
41 |
--------------------------------------------------------------------------------
/tests/test_datasets.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from d3rlpy.datasets import get_cartpole, get_dataset, get_minari, get_pendulum
4 |
5 |
6 | @pytest.mark.parametrize("dataset_type", ["replay", "random"])
7 | def test_get_cartpole(dataset_type: str) -> None:
8 | get_cartpole(dataset_type=dataset_type)
9 |
10 |
11 | @pytest.mark.parametrize("dataset_type", ["replay", "random"])
12 | def test_get_pendulum(dataset_type: str) -> None:
13 | get_pendulum(dataset_type=dataset_type)
14 |
15 |
16 | @pytest.mark.parametrize(
17 | "env_name",
18 | ["cartpole-random", "pendulum-random"],
19 | )
20 | def test_get_dataset(env_name: str) -> None:
21 | _, env = get_dataset(env_name)
22 | if env_name == "cartpole-random":
23 | assert env.unwrapped.spec.id == "CartPole-v1"
24 | elif env_name == "pendulum-random":
25 | assert env.unwrapped.spec.id == "Pendulum-v1"
26 |
27 |
28 | @pytest.mark.parametrize(
29 | "dataset_name, env_name",
30 | [
31 | ("D4RL/door/cloned-v2", "AdroitHandDoor-v1"),
32 | ("D4RL/kitchen/complete-v2", "FrankaKitchen-v1"),
33 | ],
34 | )
35 | @pytest.mark.parametrize("tuple_observation", [False, True])
36 | def test_get_minari(
37 | dataset_name: str, env_name: str, tuple_observation: bool
38 | ) -> None:
39 | dataset, env = get_minari(dataset_name, tuple_observation=tuple_observation)
40 | assert env.unwrapped.spec.id == env_name # type: ignore
41 |
42 | if tuple_observation:
43 | # check shape
44 | ep = dataset.episodes[0]
45 | ref_shape0 = ep.observations[0].shape[1:]
46 | ref_shape1 = ep.observations[1].shape[1:]
47 | obs, _ = env.reset()
48 | assert obs[0].shape == ref_shape0
49 | assert obs[1].shape == ref_shape1
50 | obs, _, _, _, _ = env.step(env.action_space.sample())
51 | assert obs[0].shape == ref_shape0
52 | assert obs[1].shape == ref_shape1
53 | else:
54 | # check shape
55 | ref_shape = dataset.episodes[0].observations.shape[1:] # type: ignore
56 | obs, _ = env.reset()
57 | assert obs.shape == ref_shape
58 | obs, _, _, _, _ = env.step(env.action_space.sample())
59 | assert obs.shape == ref_shape
60 |
--------------------------------------------------------------------------------
/tests/test_itertools.py:
--------------------------------------------------------------------------------
1 | from d3rlpy.itertools import first_flag, last_flag
2 |
3 |
4 | def test_last_flag() -> None:
5 | x = [1, 2, 3, 4, 5]
6 | i = 0
7 | for is_last, value in last_flag(x):
8 | if i == len(x) - 1:
9 | assert is_last
10 | else:
11 | assert not is_last
12 | assert value == x[i]
13 | i += 1
14 |
15 |
16 | def test_first_flag() -> None:
17 | x = [1, 2, 3, 4, 5]
18 | i = 0
19 | for is_first, value in first_flag(x):
20 | if i == 0:
21 | assert is_first
22 | else:
23 | assert not is_first
24 | assert value == x[i]
25 | i += 1
26 |
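
Both helpers pair each element with a boolean flag that is True only at one end of the iterable. A generator-based sketch that satisfies exactly what these loops assert (illustrative; not necessarily how d3rlpy.itertools implements them):

from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def last_flag_sketch(items: Iterable[T]) -> Iterator[Tuple[bool, T]]:
    # Yield (is_last, item); the flag is True only for the final element.
    iterator = iter(items)
    try:
        previous = next(iterator)
    except StopIteration:
        return
    for item in iterator:
        yield False, previous
        previous = item
    yield True, previous


def first_flag_sketch(items: Iterable[T]) -> Iterator[Tuple[bool, T]]:
    # Yield (is_first, item); the flag is True only for the first element.
    for i, item in enumerate(items):
        yield i == 0, item
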
--------------------------------------------------------------------------------
/tests/tokenizers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/tokenizers/__init__.py
--------------------------------------------------------------------------------
/tests/tokenizers/test_tokenizers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from d3rlpy.tokenizers import FloatTokenizer
4 | from d3rlpy.types import NDArray
5 |
6 |
7 | def test_float_tokenizer() -> None:
8 | tokenizer = FloatTokenizer(num_bins=100, use_mu_law_encode=False)
9 | v: NDArray = (2.0 * np.arange(100) / 100 - 1).astype(np.float32)
10 | tokenized_v = tokenizer(v)
11 | assert np.all(tokenized_v == np.arange(100))
12 |
13 | # check mu_law_encode
14 | tokenizer = FloatTokenizer(num_bins=100)
15 | v = np.arange(100) - 50
16 | tokenized_v = tokenizer(v)
17 | assert np.all(tokenized_v >= 0)
18 | assert np.all(tokenized_v < 100)
19 |
20 | # check token_offset
21 | tokenizer = FloatTokenizer(
22 | num_bins=100, use_mu_law_encode=False, token_offset=1
23 | )
24 | v = np.array([-1, 1])
25 | tokenized_v = tokenizer(v)
26 | assert tokenized_v[0] == 1
27 | assert tokenized_v[1] == 100
28 |
29 | # check decode
30 | tokenizer = FloatTokenizer(num_bins=1000000)
31 | v = np.arange(100) - 50
32 | decoded_v = tokenizer.decode(tokenizer(v))
33 | assert np.allclose(decoded_v, v, atol=1e-3)
34 |
35 | # check decode with multi-dimension
36 | v = np.reshape(v, [5, -1])
37 | decoded_v = tokenizer.decode(tokenizer(v))
38 | assert v.shape == decoded_v.shape
39 | assert np.allclose(decoded_v, v, atol=1e-3)
40 |
--------------------------------------------------------------------------------
/tests/tokenizers/test_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from d3rlpy.tokenizers import mu_law_decode, mu_law_encode
4 | from d3rlpy.types import NDArray
5 |
6 |
7 | def test_mu_law_encode() -> None:
8 | v = np.arange(100) - 50
9 | encoded_v = mu_law_encode(v, mu=100, basis=256)
10 | assert np.all(encoded_v < 1)
11 | assert np.all(-1 < encoded_v)
12 |
13 |
14 | def test_mu_law_decode() -> None:
15 | v: NDArray = np.array(np.arange(100) - 50, dtype=np.float32)
16 | decoded_v = mu_law_decode(
17 | mu_law_encode(v, mu=100, basis=256), mu=100, basis=256
18 | )
19 | assert np.allclose(decoded_v, v)
20 |
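
The asserted ranges are consistent with a Gato-style mu-law transform: inputs with |v| < basis are squashed into (-1, 1) and the mapping is exactly invertible. A sketch under that assumption (the formulas are my reading of the standard variant, not copied from d3rlpy's implementation):

import numpy as np


def mu_law_encode_sketch(v: np.ndarray, mu: float, basis: float) -> np.ndarray:
    # Logarithmic compression; values with |v| < basis land inside (-1, 1).
    return np.sign(v) * np.log(np.abs(v) * mu + 1.0) / np.log(basis * mu + 1.0)


def mu_law_decode_sketch(v: np.ndarray, mu: float, basis: float) -> np.ndarray:
    # Exact inverse of the encoder above.
    return np.sign(v) * (np.exp(np.abs(v) * np.log(basis * mu + 1.0)) - 1.0) / mu


x = np.arange(100) - 50.0
roundtrip = mu_law_decode_sketch(
    mu_law_encode_sketch(x, mu=100, basis=256), mu=100, basis=256
)
np.testing.assert_allclose(roundtrip, x, atol=1e-6)
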
--------------------------------------------------------------------------------