├── .coveragerc ├── .editorconfig ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.md │ └── feature_request.md └── workflows │ ├── format_check.yml │ └── test.yml ├── .gitignore ├── .readthedocs.yaml ├── CITATION.cff ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── ROADMAP.md ├── assets ├── breakout.gif ├── breakout.png ├── hopper.png ├── logo.png └── mujoco_hopper.gif ├── d3rlpy ├── __init__.py ├── _version.py ├── algos │ ├── __init__.py │ ├── qlearning │ │ ├── __init__.py │ │ ├── awac.py │ │ ├── base.py │ │ ├── bc.py │ │ ├── bcq.py │ │ ├── bear.py │ │ ├── cal_ql.py │ │ ├── cql.py │ │ ├── crr.py │ │ ├── ddpg.py │ │ ├── dqn.py │ │ ├── explorers.py │ │ ├── iql.py │ │ ├── nfq.py │ │ ├── plas.py │ │ ├── prdc.py │ │ ├── random_policy.py │ │ ├── rebrac.py │ │ ├── sac.py │ │ ├── td3.py │ │ ├── td3_plus_bc.py │ │ └── torch │ │ │ ├── __init__.py │ │ │ ├── awac_impl.py │ │ │ ├── bc_impl.py │ │ │ ├── bcq_impl.py │ │ │ ├── bear_impl.py │ │ │ ├── cal_ql_impl.py │ │ │ ├── cql_impl.py │ │ │ ├── crr_impl.py │ │ │ ├── ddpg_impl.py │ │ │ ├── dqn_impl.py │ │ │ ├── iql_impl.py │ │ │ ├── plas_impl.py │ │ │ ├── prdc_impl.py │ │ │ ├── rebrac_impl.py │ │ │ ├── sac_impl.py │ │ │ ├── td3_impl.py │ │ │ ├── td3_plus_bc_impl.py │ │ │ └── utility.py │ ├── transformer │ │ ├── __init__.py │ │ ├── action_samplers.py │ │ ├── base.py │ │ ├── decision_transformer.py │ │ ├── inputs.py │ │ ├── tacr.py │ │ └── torch │ │ │ ├── __init__.py │ │ │ ├── decision_transformer_impl.py │ │ │ └── tacr_impl.py │ └── utility.py ├── base.py ├── cli.py ├── constants.py ├── dataclass_utils.py ├── dataset │ ├── __init__.py │ ├── buffers.py │ ├── compat.py │ ├── components.py │ ├── episode_generator.py │ ├── io.py │ ├── mini_batch.py │ ├── replay_buffer.py │ ├── trajectory_slicers.py │ ├── transition_pickers.py │ ├── utils.py │ └── writers.py ├── datasets.py ├── distributed.py ├── envs │ ├── __init__.py │ ├── utility.py │ └── wrappers.py ├── healthcheck.py ├── interface.py ├── itertools.py ├── logging │ ├── __init__.py │ ├── file_adapter.py │ ├── logger.py │ ├── noop_adapter.py │ ├── tensorboard_adapter.py │ ├── utils.py │ └── wandb_adapter.py ├── metrics │ ├── __init__.py │ ├── evaluators.py │ └── utility.py ├── models │ ├── __init__.py │ ├── builders.py │ ├── encoders.py │ ├── q_functions.py │ ├── torch │ │ ├── __init__.py │ │ ├── distributions.py │ │ ├── encoders.py │ │ ├── imitators.py │ │ ├── parameters.py │ │ ├── policies.py │ │ ├── q_functions │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── ensemble_q_function.py │ │ │ ├── iqn_q_function.py │ │ │ ├── mean_q_function.py │ │ │ ├── qr_q_function.py │ │ │ └── utility.py │ │ ├── transformers.py │ │ └── v_functions.py │ └── utility.py ├── notebook_utils.py ├── ope │ ├── __init__.py │ ├── fqe.py │ └── torch │ │ ├── __init__.py │ │ └── fqe_impl.py ├── optimizers │ ├── __init__.py │ ├── lr_schedulers.py │ └── optimizers.py ├── preprocessing │ ├── __init__.py │ ├── action_scalers.py │ ├── base.py │ ├── observation_scalers.py │ └── reward_scalers.py ├── serializable_config.py ├── tokenizers │ ├── __init__.py │ ├── tokenizers.py │ └── utils.py ├── torch_utility.py └── types.py ├── dev.requirements.txt ├── docker └── Dockerfile ├── docs ├── Makefile ├── _static │ ├── css │ │ └── d3rlpy.css │ └── logo.png ├── _templates │ └── autosummary │ │ └── class.rst ├── assets │ ├── design.png │ ├── dqn_cartpole.png │ ├── fqe_cartpole_init_value.png │ ├── fqe_cartpole_soft_opc.png │ ├── mdp_dataset.png │ ├── plot_all_example.png │ └── plot_example.png ├── cli.rst ├── conf.py ├── 
index.rst ├── installation.rst ├── license.rst ├── make.bat ├── notebooks.rst ├── references │ ├── algos.rst │ ├── dataset.rst │ ├── datasets.rst │ ├── index.rst │ ├── logging.rst │ ├── metrics.rst │ ├── network_architectures.rst │ ├── off_policy_evaluation.rst │ ├── online.rst │ ├── optimizers.rst │ ├── preprocessing.rst │ └── q_functions.rst ├── reproductions.rst ├── requirements.txt ├── software_design.rst ├── tips.rst └── tutorials │ ├── after_training_policies.rst │ ├── create_your_dataset.rst │ ├── customize_neural_network.rst │ ├── data_collection.rst │ ├── finetuning.rst │ ├── getting_started.rst │ ├── index.rst │ ├── offline_policy_selection.rst │ ├── online_rl.rst │ ├── preprocess_and_postprocess.rst │ └── use_distributional_q_function.rst ├── examples ├── custom_algo.py ├── deepmind_control.py ├── distributed_offline_training.py ├── fine_tuning.py ├── frame_stacking.py ├── gymnasium_env.py ├── lr_scheduler.py ├── multi_step_training.py ├── preprocessors.py └── tuple_observation.py ├── mypy.ini ├── reproductions ├── finetuning │ ├── awac_finetune.py │ ├── cal_ql_finetune.py │ └── iql_finetune.py ├── offline │ ├── awac.py │ ├── bcq.py │ ├── bear.py │ ├── cql.py │ ├── crr.py │ ├── decision_transformer.py │ ├── discrete_bcq.py │ ├── discrete_cql.py │ ├── discrete_decision_transformer.py │ ├── iql.py │ ├── nfq.py │ ├── plas.py │ ├── plas_with_perturbation.py │ ├── prdc.py │ ├── qdt.py │ ├── qr_dqn.py │ ├── rebrac.py │ ├── sac.py │ ├── tacr.py │ ├── td3.py │ └── td3_plus_bc.py └── online │ ├── double_dqn_online.py │ ├── dqn_online.py │ ├── iqn_online.py │ ├── qr_dqn_online.py │ └── sac_online.py ├── requirements.txt ├── ruff.toml ├── scripts ├── build-dist ├── build-docker ├── build-docs ├── create_cartpole_dataset ├── create_cartpole_random_dataset ├── create_pendulum_dataset ├── create_pendulum_random_dataset ├── lint └── test ├── setup.py ├── tests ├── __init__.py ├── algos │ ├── __init__.py │ ├── qlearning │ │ ├── __init__.py │ │ ├── algo_test.py │ │ ├── test_awac.py │ │ ├── test_bc.py │ │ ├── test_bcq.py │ │ ├── test_bear.py │ │ ├── test_cal_ql.py │ │ ├── test_cql.py │ │ ├── test_crr.py │ │ ├── test_ddpg.py │ │ ├── test_dqn.py │ │ ├── test_explorers.py │ │ ├── test_iql.py │ │ ├── test_iterators.py │ │ ├── test_nfq.py │ │ ├── test_plas.py │ │ ├── test_prdc.py │ │ ├── test_random_policy.py │ │ ├── test_rebrac.py │ │ ├── test_sac.py │ │ ├── test_td3.py │ │ ├── test_td3_plus_bc.py │ │ └── torch │ │ │ ├── __init__.py │ │ │ └── test_utility.py │ └── transformer │ │ ├── __init__.py │ │ ├── algo_test.py │ │ ├── test_action_samplers.py │ │ ├── test_decision_transformer.py │ │ ├── test_inputs.py │ │ └── test_tacr.py ├── base_test.py ├── dataset │ ├── __init__.py │ ├── test_buffers.py │ ├── test_compat.py │ ├── test_components.py │ ├── test_episode_generator.py │ ├── test_io.py │ ├── test_mini_batch.py │ ├── test_replay_buffer.py │ ├── test_trajectory_slicer.py │ ├── test_transition_pickers.py │ ├── test_utils.py │ └── test_writers.py ├── dummy_env.py ├── dummy_scalers.py ├── envs │ ├── __init__.py │ └── test_wrappers.py ├── logging │ └── test_logger.py ├── metrics │ ├── __init__.py │ ├── test_evaluators.py │ └── test_utility.py ├── models │ ├── __init__.py │ ├── test_builders.py │ ├── test_encoders.py │ ├── test_lr_schedulers.py │ ├── test_q_functions.py │ └── torch │ │ ├── __init__.py │ │ ├── model_test.py │ │ ├── q_functions │ │ ├── __init__.py │ │ ├── test_ensemble_q_function.py │ │ ├── test_iqn_q_function.py │ │ ├── test_mean_q_function.py │ │ ├── test_qr_q_function.py │ │ └── 
test_utility.py │ │ ├── test_distributions.py │ │ ├── test_encoders.py │ │ ├── test_imitators.py │ │ ├── test_parameters.py │ │ ├── test_policies.py │ │ ├── test_q_functions.py │ │ ├── test_transformers.py │ │ └── test_v_functions.py ├── ope │ ├── __init__.py │ └── test_fqe.py ├── optimizers │ ├── __init__.py │ ├── test_lr_schedulers.py │ └── test_optimizers.py ├── preprocessing │ ├── __init__.py │ ├── test_action_scalers.py │ ├── test_base.py │ ├── test_observation_scalers.py │ └── test_reward_scalers.py ├── test_dataclass_utils.py ├── test_datasets.py ├── test_itertools.py ├── test_torch_utility.py ├── testing_utils.py └── tokenizers │ ├── __init__.py │ ├── test_tokenizers.py │ └── test_utils.py └── tutorials ├── atari.ipynb ├── cartpole.ipynb ├── online.ipynb └── tpu.ipynb /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = d3rlpy/cli.py 3 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | root = true 2 | 3 | [*] 4 | indent_size = 2 5 | indent_style = space 6 | end_of_line = lf 7 | charset = utf-8 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | 11 | [*.py] 12 | indent_size = 4 13 | 14 | [*.pyx] 15 | indent_size = 4 16 | 17 | [*.pxd] 18 | indent_size = 4 19 | 20 | [*.pyi] 21 | indent_size = 4 22 | 23 | [*.md] 24 | trim_trailing_whitespace = false 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: "[BUG] bug title" 5 | labels: bug 6 | assignees: '' 7 | 8 | --- 9 | 10 | :warning: 11 | - Please don't post bug reports without minimal example code. Otherwise, it will take a really long time to resolve the issue. 12 | - Please be polite in discussions. This is an open-source project maintained by volunteer contributors. 13 | 14 | **Describe the bug** 15 | A clear and concise description of what the bug is. 16 | 17 | **To Reproduce** 18 | Steps to reproduce the behavior. 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Additional context** 24 | Add any other context about the problem here. 25 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Feature request 3 | about: Suggest an idea for this project 4 | title: "[REQUEST] request title" 5 | labels: enhancement 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Is your feature request related to a problem? Please describe.** 11 | A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] 12 | 13 | **Describe the solution you'd like** 14 | A clear and concise description of what you want to happen. 15 | 16 | **Additional context** 17 | Add any other context or screenshots about the feature request here.
18 | -------------------------------------------------------------------------------- /.github/workflows/format_check.yml: -------------------------------------------------------------------------------- 1 | name: format check 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-22.04 8 | steps: 9 | - uses: actions/checkout@v2 10 | - name: Setup Python.3.9.x 11 | uses: actions/setup-python@v1 12 | with: 13 | python-version: 3.9.x 14 | - name: Cache pip 15 | uses: actions/cache@v4 16 | with: 17 | path: ~/.cache/pip 18 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev.requirements.txt') }} 19 | restore-keys: | 20 | ${{ runner.os }}-pip- 21 | ${{ runner.os }}- 22 | - name: Install packages 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install Cython numpy 26 | pip install -e . 27 | pip install -r dev.requirements.txt 28 | - name: Static analysis 29 | run: | 30 | ./scripts/lint 31 | 32 | concurrency: 33 | group: ${{ github.workflow }}-${{ github.ref }} 34 | cancel-in-progress: true 35 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | runs-on: ${{ matrix.os }} 8 | strategy: 9 | matrix: 10 | # temporary drop windows test 11 | # os: [ubuntu-22.04, macos-latest, windows-latest] 12 | os: [ubuntu-22.04, macos-latest] 13 | steps: 14 | - uses: actions/checkout@v4 15 | - name: Setup Python.3.10.x 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: '3.10' 19 | - name: Cache pip 20 | uses: actions/cache@v4 21 | with: 22 | path: ~/.cache/pip 23 | key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}-${{ hashFiles('dev.requirements.txt') }} 24 | restore-keys: | 25 | ${{ runner.os }}-pip- 26 | ${{ runner.os }}- 27 | - name: Install dependencies for Windows 28 | if: ${{ matrix.os == 'windows-latest' }} 29 | run: | 30 | pip install torch==1.13.1+cpu -f https://download.pytorch.org/whl/torch_stable.html 31 | - name: Install dependencies for macOS 32 | if: ${{ matrix.os == 'macos-latest' }} 33 | run: | 34 | brew install libomp 35 | - name: Install packages 36 | run: | 37 | python -m pip install --upgrade pip 38 | pip install Cython numpy 39 | pip install -e . 
40 | pip install -r dev.requirements.txt 41 | - name: Unit tests 42 | run: | 43 | mkdir -p test_data 44 | pytest --cov-report=xml --cov=d3rlpy --cov-config=.coveragerc tests -p no:warnings -v 45 | - name: Upload coverage 46 | if: ${{ matrix.os == 'ubuntu-22.04' }} 47 | env: 48 | CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} 49 | run: | 50 | bash <(curl -s https://codecov.io/bash) 51 | 52 | concurrency: 53 | group: ${{ github.workflow }}-${{ github.ref }} 54 | cancel-in-progress: true 55 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *__pycache__ 3 | *.pkl 4 | *.c 5 | *.cpp 6 | *.so 7 | .python-version 8 | d3rlpy_data 9 | test_data 10 | runs 11 | d3rlpy_logs 12 | docs/_build 13 | docs/d3rlpy*.rst 14 | docs/modules.rst 15 | docs/references/generated 16 | coverage.xml 17 | .coverage 18 | .mypy_cache 19 | .ipynb_checkpoints 20 | build 21 | dist 22 | /.idea/ 23 | *.egg-info 24 | /.vscode/ 25 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | build: 3 | os: ubuntu-22.04 4 | tools: 5 | python: "3.10" 6 | sphinx: 7 | builder: html 8 | configuration: docs/conf.py 9 | formats: all 10 | python: 11 | install: 12 | - requirements: docs/requirements.txt 13 | - method: pip 14 | path: . 15 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this software, please cite it as below." 3 | authors: 4 | - family-names: "Seno" 5 | given-names: "Takuma" 6 | title: "d3rlpy: An offline deep reinforcement learning library" 7 | version: 1.1.1 8 | date-released: 2022-11-26 9 | url: "https://github.com/takuseno/d3rlpy" 10 | preferred-citation: 11 | type: article 12 | authors: 13 | - family-names: "Seno" 14 | given-names: "Takuma" 15 | - family-names: "Imai" 16 | given-names: "Michita" 17 | journal: "Journal of Machine Learning Research" 18 | month: 11 19 | title: "d3rlpy: An Offline Deep Reinforcement Learning Library" 20 | year: 2022 21 | url: http://jmlr.org/papers/v23/22-0017.html 22 | volume: 23 23 | issue: 315 24 | start: 1 25 | end: 20 26 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to d3rlpy 2 | 3 | Any kind of contribution to d3rlpy would be highly appreciated! 4 | 5 | Contribution examples: 6 | - Giving a thumbs-up to good issues or pull requests :+1: 7 | - Opening issues about questions, bugs, installation problems, feature requests, algorithm requests, etc. 8 | - Sending pull requests 9 | 10 | ## Development Guide 11 | 12 | ### Build from source 13 | ``` 14 | $ git clone git@github.com:takuseno/d3rlpy 15 | $ cd d3rlpy 16 | $ pip install -e . 17 | ``` 18 | 19 | Before submitting your PR, please run the following commands to check code quality. 20 | 21 | ### Install additional dependencies for development 22 | ``` 23 | $ pip install -r dev.requirements.txt 24 | ``` 25 | 26 | ### Testing 27 | ``` 28 | $ ./scripts/test 29 | ``` 30 | 31 | ### Coding style check 32 | This repository is styled and analyzed with [Ruff](https://docs.astral.sh/ruff/).
33 | [docformatter](https://github.com/PyCQA/docformatter) is additionally used to format docstrings. 34 | This repository is fully type-annotated and checked by [mypy](https://github.com/python/mypy). 35 | Before you submit your PR, please execute this command: 36 | ``` 37 | $ ./scripts/lint 38 | ``` 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Takuma Seno 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /ROADMAP.md: -------------------------------------------------------------------------------- 1 | This list describes the planned features including breaking changes. 2 | 3 | ## Roadmap to v1.0.0 4 | - [x] Benchmark MuJoCo datasets 5 | - [x] Benchmark Atari 2600 datasets 6 | 7 | ## Roadmap to v2.x.x 8 | - [x] Change MDPDataset format to align with D4RL datasets 9 | - [x] Sophisticated config system using dataclasses 10 | - [x] Dump configuration and model parameters in a single file 11 | - [x] Support large dataset 12 | - [x] Support tuple observation 13 | - [x] Support large-scale data-parallel offline training 14 | - [x] Support Transformer architecture (e.g. Decision Transformer) 15 | - [x] Speed up training with CudaGraph and torch.compile 16 | - [ ] Support training foundation models (e.g. 
Gato) 17 | - [ ] Support large-scale distributed online training 18 | - [ ] Change library name to represent unification of offline and online 19 | -------------------------------------------------------------------------------- /assets/breakout.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/assets/breakout.gif -------------------------------------------------------------------------------- /assets/breakout.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/assets/breakout.png -------------------------------------------------------------------------------- /assets/hopper.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/assets/hopper.png -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/assets/logo.png -------------------------------------------------------------------------------- /assets/mujoco_hopper.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/assets/mujoco_hopper.gif -------------------------------------------------------------------------------- /d3rlpy/__init__.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=protected-access 2 | import random 3 | 4 | import gymnasium 5 | import numpy as np 6 | import torch 7 | 8 | from . import ( 9 | algos, 10 | dataset, 11 | datasets, 12 | distributed, 13 | envs, 14 | logging, 15 | metrics, 16 | models, 17 | notebook_utils, 18 | ope, 19 | optimizers, 20 | preprocessing, 21 | tokenizers, 22 | types, 23 | ) 24 | from ._version import __version__ 25 | from .base import load_learnable 26 | from .constants import ActionSpace, LoggingStrategy, PositionEncodingType 27 | from .healthcheck import run_healthcheck 28 | from .torch_utility import Modules, TorchMiniBatch 29 | 30 | __all__ = [ 31 | "algos", 32 | "dataset", 33 | "datasets", 34 | "distributed", 35 | "envs", 36 | "logging", 37 | "metrics", 38 | "models", 39 | "optimizers", 40 | "notebook_utils", 41 | "ope", 42 | "preprocessing", 43 | "tokenizers", 44 | "types", 45 | "__version__", 46 | "load_learnable", 47 | "ActionSpace", 48 | "LoggingStrategy", 49 | "PositionEncodingType", 50 | "Modules", 51 | "TorchMiniBatch", 52 | "seed", 53 | ] 54 | 55 | 56 | def seed(n: int) -> None: 57 | """Sets random seed value. 58 | 59 | Args: 60 | n (int): seed value. 
61 | """ 62 | random.seed(n) 63 | np.random.seed(n) 64 | torch.manual_seed(n) 65 | torch.cuda.manual_seed(n) 66 | torch.backends.cudnn.deterministic = True 67 | 68 | 69 | # run healthcheck 70 | run_healthcheck() 71 | 72 | if torch.cuda.is_available(): 73 | # enable autograd compilation 74 | torch._dynamo.config.compiled_autograd = True 75 | torch.set_float32_matmul_precision("high") 76 | 77 | # register Shimmy if available 78 | try: 79 | import shimmy 80 | 81 | gymnasium.register_envs(shimmy) 82 | logging.LOG.info("Register Shimmy environments.") 83 | except ImportError: 84 | pass 85 | -------------------------------------------------------------------------------- /d3rlpy/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = "2.8.1" 2 | -------------------------------------------------------------------------------- /d3rlpy/algos/__init__.py: -------------------------------------------------------------------------------- 1 | from .qlearning import * 2 | from .transformer import * 3 | from .utility import * 4 | -------------------------------------------------------------------------------- /d3rlpy/algos/qlearning/__init__.py: -------------------------------------------------------------------------------- 1 | from .awac import * 2 | from .base import * 3 | from .bc import * 4 | from .bcq import * 5 | from .bear import * 6 | from .cal_ql import * 7 | from .cql import * 8 | from .crr import * 9 | from .ddpg import * 10 | from .dqn import * 11 | from .explorers import * 12 | from .iql import * 13 | from .nfq import * 14 | from .plas import * 15 | from .prdc import * 16 | from .random_policy import * 17 | from .rebrac import * 18 | from .sac import * 19 | from .td3 import * 20 | from .td3_plus_bc import * 21 | -------------------------------------------------------------------------------- /d3rlpy/algos/qlearning/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from .awac_impl import * 2 | from .bc_impl import * 3 | from .bcq_impl import * 4 | from .bear_impl import * 5 | from .cal_ql_impl import * 6 | from .cql_impl import * 7 | from .crr_impl import * 8 | from .ddpg_impl import * 9 | from .dqn_impl import * 10 | from .iql_impl import * 11 | from .plas_impl import * 12 | from .prdc_impl import * 13 | from .rebrac_impl import * 14 | from .sac_impl import * 15 | from .td3_impl import * 16 | from .td3_plus_bc_impl import * 17 | -------------------------------------------------------------------------------- /d3rlpy/algos/qlearning/torch/cal_ql_impl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ....types import TorchObservation 4 | from .cql_impl import CQLImpl 5 | 6 | __all__ = ["CalQLImpl"] 7 | 8 | 9 | class CalQLImpl(CQLImpl): 10 | def _compute_policy_is_values( 11 | self, 12 | policy_obs: TorchObservation, 13 | value_obs: TorchObservation, 14 | returns_to_go: torch.Tensor, 15 | ) -> tuple[torch.Tensor, torch.Tensor]: 16 | values, log_probs = super()._compute_policy_is_values( 17 | policy_obs=policy_obs, 18 | value_obs=value_obs, 19 | returns_to_go=returns_to_go, 20 | ) 21 | return torch.maximum(values, returns_to_go.view(1, -1, 1)), log_probs 22 | -------------------------------------------------------------------------------- /d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=too-many-ancestors 2 | import 
dataclasses 3 | 4 | import torch 5 | 6 | from ....models.torch import ActionOutput, ContinuousEnsembleQFunctionForwarder 7 | from ....torch_utility import TorchMiniBatch 8 | from ....types import Shape 9 | from .ddpg_impl import DDPGBaseActorLoss, DDPGModules 10 | from .td3_impl import TD3Impl 11 | 12 | __all__ = ["TD3PlusBCImpl"] 13 | 14 | 15 | @dataclasses.dataclass(frozen=True) 16 | class TD3PlusBCActorLoss(DDPGBaseActorLoss): 17 | bc_loss: torch.Tensor 18 | 19 | 20 | class TD3PlusBCImpl(TD3Impl): 21 | _alpha: float 22 | 23 | def __init__( 24 | self, 25 | observation_shape: Shape, 26 | action_size: int, 27 | modules: DDPGModules, 28 | q_func_forwarder: ContinuousEnsembleQFunctionForwarder, 29 | targ_q_func_forwarder: ContinuousEnsembleQFunctionForwarder, 30 | gamma: float, 31 | tau: float, 32 | target_smoothing_sigma: float, 33 | target_smoothing_clip: float, 34 | alpha: float, 35 | update_actor_interval: int, 36 | compiled: bool, 37 | device: str, 38 | ): 39 | super().__init__( 40 | observation_shape=observation_shape, 41 | action_size=action_size, 42 | modules=modules, 43 | q_func_forwarder=q_func_forwarder, 44 | targ_q_func_forwarder=targ_q_func_forwarder, 45 | gamma=gamma, 46 | tau=tau, 47 | target_smoothing_sigma=target_smoothing_sigma, 48 | target_smoothing_clip=target_smoothing_clip, 49 | update_actor_interval=update_actor_interval, 50 | compiled=compiled, 51 | device=device, 52 | ) 53 | self._alpha = alpha 54 | 55 | def compute_actor_loss( 56 | self, batch: TorchMiniBatch, action: ActionOutput 57 | ) -> TD3PlusBCActorLoss: 58 | q_t = self._q_func_forwarder.compute_expected_q( 59 | batch.observations, action.squashed_mu, "none" 60 | )[0] 61 | lam = self._alpha / (q_t.abs().mean()).detach() 62 | bc_loss = ((batch.actions - action.squashed_mu) ** 2).mean() 63 | return TD3PlusBCActorLoss( 64 | actor_loss=lam * -q_t.mean() + bc_loss, bc_loss=bc_loss 65 | ) 66 | -------------------------------------------------------------------------------- /d3rlpy/algos/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .action_samplers import * 2 | from .base import * 3 | from .decision_transformer import * 4 | from .inputs import * 5 | from .tacr import * 6 | -------------------------------------------------------------------------------- /d3rlpy/algos/transformer/action_samplers.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, Union 2 | 3 | import numpy as np 4 | 5 | from ...types import NDArray 6 | 7 | __all__ = [ 8 | "TransformerActionSampler", 9 | "IdentityTransformerActionSampler", 10 | "SoftmaxTransformerActionSampler", 11 | "GreedyTransformerActionSampler", 12 | ] 13 | 14 | 15 | class TransformerActionSampler(Protocol): 16 | r"""Interface of TransformerActionSampler.""" 17 | 18 | def __call__(self, transformer_output: NDArray) -> Union[NDArray, int]: 19 | r"""Returns sampled action from Transformer output. 20 | 21 | Args: 22 | transformer_output: Output of Transformer algorithms. 23 | 24 | Returns: 25 | Sampled action. 26 | """ 27 | raise NotImplementedError 28 | 29 | 30 | class IdentityTransformerActionSampler(TransformerActionSampler): 31 | r"""Identity action-sampler. 32 | 33 | This class implements an identity function to process Transformer output. 34 | The sampled action is exactly the same as ``transformer_output``. 
35 | """ 36 | 37 | def __call__(self, transformer_output: NDArray) -> Union[NDArray, int]: 38 | return transformer_output 39 | 40 | 41 | class SoftmaxTransformerActionSampler(TransformerActionSampler): 42 | r"""Softmax action-sampler. 43 | 44 | This class implements a softmax function to sample an action from a discrete 45 | probability distribution. 46 | 47 | Args: 48 | temperature (float): Softmax temperature. 49 | """ 50 | 51 | _temperature: float 52 | 53 | def __init__(self, temperature: float = 1.0): 54 | self._temperature = temperature 55 | 56 | def __call__(self, transformer_output: NDArray) -> Union[NDArray, int]: 57 | assert transformer_output.ndim == 1 58 | logits = transformer_output / self._temperature 59 | x = np.exp(logits - np.max(logits)) 60 | probs = x / np.sum(x) 61 | action = np.random.choice(probs.shape[0], p=probs) 62 | return int(action) 63 | 64 | 65 | class GreedyTransformerActionSampler(TransformerActionSampler): 66 | r"""Greedy action-sampler. 67 | 68 | This class implements a greedy function to determine an action from a discrete 69 | probability distribution. 70 | """ 71 | 72 | def __call__(self, transformer_output: NDArray) -> Union[NDArray, int]: 73 | assert transformer_output.ndim == 1 74 | return int(np.argmax(transformer_output)) 75 | -------------------------------------------------------------------------------- /d3rlpy/algos/transformer/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from .decision_transformer_impl import * 2 | from .tacr_impl import * 3 | -------------------------------------------------------------------------------- /d3rlpy/constants.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | from ._version import __version__ 4 | 5 | __all__ = [ 6 | "IMPL_NOT_INITIALIZED_ERROR", 7 | "ALGO_NOT_GIVEN_ERROR", 8 | "DISCRETE_ACTION_SPACE_MISMATCH_ERROR", 9 | "CONTINUOUS_ACTION_SPACE_MISMATCH_ERROR", 10 | "ActionSpace", 11 | "PositionEncodingType", 12 | ] 13 | 14 | IMPL_NOT_INITIALIZED_ERROR = ( 15 | "The neural network parameters are not " 16 | "initialized. Please call build_with_dataset, " 17 | "build_with_env, or directly call the fit or " 18 | "fit_online method." 19 | ) 20 | 21 | ALGO_NOT_GIVEN_ERROR = ( 22 | "The algorithm to evaluate is not given. Please give the trained algorithm" 23 | " to the argument." 24 | ) 25 | 26 | DISCRETE_ACTION_SPACE_MISMATCH_ERROR = ( 27 | "The action-space of the given dataset is not compatible with the" 28 | " algorithm. Please use discrete action-space algorithms. The algorithms" 29 | " list is available below.\n" 30 | f"https://d3rlpy.readthedocs.io/en/v{__version__}/references/algos.html" 31 | ) 32 | 33 | CONTINUOUS_ACTION_SPACE_MISMATCH_ERROR = ( 34 | "The action-space of the given dataset is not compatible with the" 35 | " algorithm. Please use continuous action-space algorithms. 
The algorithm" 36 | " list is available below.\n" 37 | f"https://d3rlpy.readthedocs.io/en/v{__version__}/references/algos.html" 38 | ) 39 | 40 | 41 | class ActionSpace(Enum): 42 | CONTINUOUS = 1 43 | DISCRETE = 2 44 | BOTH = 3 45 | 46 | 47 | class PositionEncodingType(Enum): 48 | SIMPLE = "simple" 49 | GLOBAL = "global" 50 | 51 | 52 | class LoggingStrategy(Enum): 53 | STEPS = "steps" 54 | EPOCH = "epoch" 55 | -------------------------------------------------------------------------------- /d3rlpy/dataclass_utils.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Any 3 | 4 | import torch 5 | 6 | __all__ = ["asdict_without_copy", "asdict_as_float"] 7 | 8 | 9 | def asdict_without_copy(obj: Any) -> dict[str, Any]: 10 | assert dataclasses.is_dataclass(obj) 11 | fields = dataclasses.fields(obj) 12 | return {field.name: getattr(obj, field.name) for field in fields} 13 | 14 | 15 | def asdict_as_float(obj: Any) -> dict[str, float]: 16 | assert dataclasses.is_dataclass(obj) 17 | fields = dataclasses.fields(obj) 18 | ret: dict[str, float] = {} 19 | for field in fields: 20 | value = getattr(obj, field.name) 21 | if isinstance(value, torch.Tensor): 22 | assert ( 23 | value.ndim == 0 24 | ), f"{field.name} needs to be scalar. {value.shape}." 25 | ret[field.name] = float(value.cpu().detach().numpy()) 26 | else: 27 | ret[field.name] = float(value) 28 | return ret 29 | -------------------------------------------------------------------------------- /d3rlpy/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .buffers import * 2 | from .compat import * 3 | from .components import * 4 | from .episode_generator import * 5 | from .io import * 6 | from .mini_batch import * 7 | from .replay_buffer import * 8 | from .trajectory_slicers import * 9 | from .transition_pickers import * 10 | from .utils import * 11 | from .writers import * 12 | -------------------------------------------------------------------------------- /d3rlpy/distributed.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import torch.distributed as dist 4 | 5 | from .logging import set_log_context 6 | 7 | __all__ = ["init_process_group", "destroy_process_group"] 8 | 9 | 10 | @dataclasses.dataclass(frozen=True) 11 | class DistributedWorkerInfo: 12 | rank: int 13 | backend: str 14 | world_size: int 15 | 16 | 17 | def init_process_group(backend: str) -> int: 18 | """Initializes process group of distributed workers. 19 | 20 | Internally, distributed worker information is injected to log outputs. 21 | 22 | Args: 23 | backend: Backend of communication. Available options are ``gloo``, 24 | ``mpi`` and ``nccl``. 25 | 26 | Returns: 27 | Rank of the current process. 
28 | """ 29 | dist.init_process_group(backend) 30 | rank = dist.get_rank() 31 | set_log_context( 32 | distributed=DistributedWorkerInfo( 33 | rank=dist.get_rank(), 34 | backend=dist.get_backend(), 35 | world_size=dist.get_world_size(), 36 | ) 37 | ) 38 | return int(rank) 39 | 40 | 41 | def destroy_process_group() -> None: 42 | """Destroys process group of distributed workers.""" 43 | dist.destroy_process_group() 44 | -------------------------------------------------------------------------------- /d3rlpy/envs/__init__.py: -------------------------------------------------------------------------------- 1 | from .utility import * 2 | from .wrappers import * 3 | -------------------------------------------------------------------------------- /d3rlpy/envs/utility.py: -------------------------------------------------------------------------------- 1 | from ..types import GymEnv 2 | 3 | __all__ = ["seed_env"] 4 | 5 | 6 | def seed_env(env: GymEnv, seed: int) -> None: 7 | env.reset(seed=seed) 8 | -------------------------------------------------------------------------------- /d3rlpy/healthcheck.py: -------------------------------------------------------------------------------- 1 | __all__ = ["run_healthcheck"] 2 | 3 | 4 | def run_healthcheck() -> None: 5 | _check_gym() 6 | _check_pytorch() 7 | 8 | 9 | def _check_gym() -> None: 10 | import gymnasium 11 | from gym.version import VERSION 12 | 13 | if VERSION < "0.26.0": 14 | raise ValueError( 15 | "Gym version is too outdated. " 16 | "Please upgrade Gym to 0.26.0 or later." 17 | ) 18 | 19 | if gymnasium.__version__ < "1.0.0": 20 | raise ValueError( 21 | "Gymnasium version is too outdated. " 22 | "Please upgrade Gymnasium to 1.0.0 or later." 23 | ) 24 | 25 | 26 | def _check_pytorch() -> None: 27 | import torch 28 | 29 | if torch.__version__ < "2.5.0": 30 | raise ValueError( 31 | "PyTorch version is too outdated. " 32 | "Please upgrade PyTorch to 2.5.0 or later." 33 | ) 34 | -------------------------------------------------------------------------------- /d3rlpy/interface.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Protocol, Union 2 | 3 | from .preprocessing import ActionScaler, ObservationScaler, RewardScaler 4 | from .types import NDArray, Observation 5 | 6 | __all__ = ["QLearningAlgoProtocol", "StatefulTransformerAlgoProtocol"] 7 | 8 | 9 | class QLearningAlgoProtocol(Protocol): 10 | def predict(self, x: Observation) -> NDArray: ... 11 | 12 | def predict_value(self, x: Observation, action: NDArray) -> NDArray: ... 13 | 14 | def sample_action(self, x: Observation) -> NDArray: ... 15 | 16 | @property 17 | def gamma(self) -> float: ... 18 | 19 | @property 20 | def observation_scaler(self) -> Optional[ObservationScaler]: ... 21 | 22 | @property 23 | def action_scaler(self) -> Optional[ActionScaler]: ... 24 | 25 | @property 26 | def reward_scaler(self) -> Optional[RewardScaler]: ... 27 | 28 | @property 29 | def action_size(self) -> Optional[int]: ... 30 | 31 | 32 | class StatefulTransformerAlgoProtocol(Protocol): 33 | def predict(self, x: Observation, reward: float) -> Union[NDArray, int]: ... 34 | 35 | def reset(self) -> None: ... 
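The two protocols above are structural (duck-typed) interfaces: any trained algorithm object that exposes these methods can be passed to utilities annotated with them, which is why the metrics evaluators are typed against protocols rather than concrete algorithm classes. As a minimal sketch of that design (not part of interface.py; the rollout helper below and the gymnasium-style environment API are illustrative assumptions, not d3rlpy APIs), a plain function can accept anything satisfying QLearningAlgoProtocol:

import numpy as np

from d3rlpy.interface import QLearningAlgoProtocol
from d3rlpy.types import GymEnv


def run_greedy_episode(algo: QLearningAlgoProtocol, env: GymEnv) -> float:
    # Roll out one episode with the greedy policy. predict() expects a batch of
    # observations, so a batch dimension is added and the single action taken out.
    observation, _ = env.reset()
    total_reward = 0.0
    done = False
    while not done:
        action = algo.predict(np.expand_dims(observation, axis=0))[0]
        observation, reward, terminated, truncated, _ = env.step(action)
        total_reward += float(reward)
        done = bool(terminated or truncated)
    return total_reward

Because the check is structural, lightweight mocks and wrappers can be used in place of real algorithm instances in tests and evaluation code.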
36 | -------------------------------------------------------------------------------- /d3rlpy/itertools.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Iterator, TypeVar 2 | 3 | __all__ = ["last_flag", "first_flag"] 4 | 5 | T = TypeVar("T") 6 | 7 | 8 | def last_flag(iterator: Iterable[T]) -> Iterator[tuple[bool, T]]: 9 | items = list(iterator) 10 | for i, item in enumerate(items): 11 | yield i == len(items) - 1, item 12 | 13 | 14 | def first_flag(iterator: Iterable[T]) -> Iterator[tuple[bool, T]]: 15 | items = list(iterator) 16 | for i, item in enumerate(items): 17 | yield i == 0, item 18 | -------------------------------------------------------------------------------- /d3rlpy/logging/__init__.py: -------------------------------------------------------------------------------- 1 | from .file_adapter import * 2 | from .logger import * 3 | from .noop_adapter import * 4 | from .tensorboard_adapter import * 5 | from .utils import * 6 | from .wandb_adapter import * 7 | -------------------------------------------------------------------------------- /d3rlpy/logging/noop_adapter.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | from .logger import ( 4 | AlgProtocol, 5 | LoggerAdapter, 6 | LoggerAdapterFactory, 7 | SaveProtocol, 8 | ) 9 | 10 | __all__ = ["NoopAdapter", "NoopAdapterFactory"] 11 | 12 | 13 | class NoopAdapter(LoggerAdapter): 14 | r"""NoopAdapter class. 15 | 16 | This class does not save anything. This can be used especially when programs 17 | are not allowed to write things to disks. 18 | """ 19 | 20 | def write_params(self, params: dict[str, Any]) -> None: 21 | pass 22 | 23 | def before_write_metric(self, epoch: int, step: int) -> None: 24 | pass 25 | 26 | def write_metric( 27 | self, epoch: int, step: int, name: str, value: float 28 | ) -> None: 29 | pass 30 | 31 | def after_write_metric(self, epoch: int, step: int) -> None: 32 | pass 33 | 34 | def save_model(self, epoch: int, algo: SaveProtocol) -> None: 35 | pass 36 | 37 | def close(self) -> None: 38 | pass 39 | 40 | def watch_model( 41 | self, 42 | epoch: int, 43 | step: int, 44 | ) -> None: 45 | pass 46 | 47 | 48 | class NoopAdapterFactory(LoggerAdapterFactory): 49 | r"""NoopAdapterFactory class. 50 | 51 | This class instantiates ``NoopAdapter`` object. 52 | """ 53 | 54 | def create( 55 | self, algo: AlgProtocol, experiment_name: str, n_steps_per_epoch: int 56 | ) -> NoopAdapter: 57 | return NoopAdapter() 58 | -------------------------------------------------------------------------------- /d3rlpy/logging/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Sequence 2 | 3 | from .logger import ( 4 | AlgProtocol, 5 | LoggerAdapter, 6 | LoggerAdapterFactory, 7 | SaveProtocol, 8 | ) 9 | 10 | __all__ = ["CombineAdapter", "CombineAdapterFactory"] 11 | 12 | 13 | class CombineAdapter(LoggerAdapter): 14 | r"""CombineAdapter class. 15 | 16 | This class combines multiple LoggerAdapter to write metrics through 17 | different adapters at the same time. 18 | 19 | Args: 20 | adapters (Sequence[LoggerAdapter]): List of LoggerAdapter. 
21 | """ 22 | 23 | def __init__(self, adapters: Sequence[LoggerAdapter]): 24 | self._adapters = adapters 25 | 26 | def write_params(self, params: dict[str, Any]) -> None: 27 | for adapter in self._adapters: 28 | adapter.write_params(params) 29 | 30 | def before_write_metric(self, epoch: int, step: int) -> None: 31 | for adapter in self._adapters: 32 | adapter.before_write_metric(epoch, step) 33 | 34 | def write_metric( 35 | self, epoch: int, step: int, name: str, value: float 36 | ) -> None: 37 | for adapter in self._adapters: 38 | adapter.write_metric(epoch, step, name, value) 39 | 40 | def after_write_metric(self, epoch: int, step: int) -> None: 41 | for adapter in self._adapters: 42 | adapter.after_write_metric(epoch, step) 43 | 44 | def save_model(self, epoch: int, algo: SaveProtocol) -> None: 45 | for adapter in self._adapters: 46 | adapter.save_model(epoch, algo) 47 | 48 | def close(self) -> None: 49 | for adapter in self._adapters: 50 | adapter.close() 51 | 52 | def watch_model( 53 | self, 54 | epoch: int, 55 | step: int, 56 | ) -> None: 57 | for adapter in self._adapters: 58 | adapter.watch_model(epoch, step) 59 | 60 | 61 | class CombineAdapterFactory(LoggerAdapterFactory): 62 | r"""CombineAdapterFactory class. 63 | 64 | This class instantiates ``CombineAdapter`` object. 65 | 66 | Args: 67 | adapter_factories (Sequence[LoggerAdapterFactory]): 68 | List of LoggerAdapterFactory. 69 | """ 70 | 71 | _adapter_factories: Sequence[LoggerAdapterFactory] 72 | 73 | def __init__(self, adapter_factories: Sequence[LoggerAdapterFactory]): 74 | self._adapter_factories = adapter_factories 75 | 76 | def create( 77 | self, algo: AlgProtocol, experiment_name: str, n_steps_per_epoch: int 78 | ) -> CombineAdapter: 79 | return CombineAdapter( 80 | [ 81 | factory.create(algo, experiment_name, n_steps_per_epoch) 82 | for factory in self._adapter_factories 83 | ] 84 | ) 85 | -------------------------------------------------------------------------------- /d3rlpy/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluators import * 2 | from .utility import * 3 | -------------------------------------------------------------------------------- /d3rlpy/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .builders import * 2 | from .encoders import * 3 | from .q_functions import * 4 | -------------------------------------------------------------------------------- /d3rlpy/models/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from .distributions import * 2 | from .encoders import * 3 | from .imitators import * 4 | from .parameters import * 5 | from .policies import * 6 | from .q_functions import * 7 | from .transformers import * 8 | from .v_functions import * 9 | -------------------------------------------------------------------------------- /d3rlpy/models/torch/parameters.py: -------------------------------------------------------------------------------- 1 | from typing import NoReturn 2 | 3 | import torch 4 | from torch import nn 5 | 6 | __all__ = ["Parameter", "get_parameter"] 7 | 8 | 9 | class Parameter(nn.Module): # type: ignore 10 | _parameter: nn.Parameter 11 | 12 | def __init__(self, data: torch.Tensor): 13 | super().__init__() 14 | self._parameter = nn.Parameter(data) 15 | 16 | def forward(self) -> NoReturn: 17 | raise NotImplementedError( 18 | "Parameter does not support __call__. Use parameter property " 19 | "instead." 
20 | ) 21 | 22 | def __call__(self) -> NoReturn: 23 | raise NotImplementedError( 24 | "Parameter does not support __call__. Use parameter property " 25 | "instead." 26 | ) 27 | 28 | 29 | def get_parameter(parameter: Parameter) -> nn.Parameter: 30 | return next(parameter.parameters()) 31 | -------------------------------------------------------------------------------- /d3rlpy/models/torch/q_functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from .ensemble_q_function import * 3 | from .iqn_q_function import * 4 | from .mean_q_function import * 5 | from .qr_q_function import * 6 | from .utility import * 7 | -------------------------------------------------------------------------------- /d3rlpy/models/torch/v_functions.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | 7 | from ...types import TorchObservation 8 | from .encoders import Encoder 9 | 10 | __all__ = ["ValueFunction", "compute_v_function_error"] 11 | 12 | 13 | class ValueFunction(nn.Module): # type: ignore 14 | _encoder: Encoder 15 | _fc: nn.Linear 16 | 17 | def __init__(self, encoder: Encoder, hidden_size: int): 18 | super().__init__() 19 | self._encoder = encoder 20 | self._fc = nn.Linear(hidden_size, 1) 21 | 22 | def forward(self, x: TorchObservation) -> torch.Tensor: 23 | h = self._encoder(x) 24 | return cast(torch.Tensor, self._fc(h)) 25 | 26 | def __call__(self, x: TorchObservation) -> torch.Tensor: 27 | return cast(torch.Tensor, super().__call__(x)) 28 | 29 | 30 | def compute_v_function_error( 31 | v_function: ValueFunction, 32 | observations: TorchObservation, 33 | target: torch.Tensor, 34 | ) -> torch.Tensor: 35 | v_t = v_function(observations) 36 | loss = F.mse_loss(v_t, target) 37 | return loss 38 | -------------------------------------------------------------------------------- /d3rlpy/models/utility.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from ..torch_utility import GEGLU, Swish 4 | 5 | __all__ = ["create_activation"] 6 | 7 | 8 | def create_activation(activation_type: str) -> nn.Module: 9 | if activation_type == "relu": 10 | return nn.ReLU() 11 | elif activation_type == "gelu": 12 | return nn.GELU() 13 | elif activation_type == "tanh": 14 | return nn.Tanh() 15 | elif activation_type == "swish": 16 | return Swish() 17 | elif activation_type == "none": 18 | return nn.Identity() 19 | elif activation_type == "geglu": 20 | return GEGLU() 21 | raise ValueError("invalid activation_type.") 22 | -------------------------------------------------------------------------------- /d3rlpy/notebook_utils.py: -------------------------------------------------------------------------------- 1 | import base64 2 | 3 | __all__ = ["start_virtual_display", "render_video"] 4 | 5 | 6 | def start_virtual_display() -> None: 7 | """Starts virtual display.""" 8 | try: 9 | from pyvirtualdisplay.display import Display 10 | 11 | display = Display() 12 | display.start() 13 | except ImportError as e: 14 | raise ImportError( 15 | "pyvirtualdisplay is not installed.\n" 16 | "$ pip install pyvirtualdisplay" 17 | ) from e 18 | 19 | 20 | def render_video(path: str) -> None: 21 | """Renders video file in Jupyter Notebook. 22 | 23 | Args: 24 | path: Path to video file. 
25 | """ 26 | try: 27 | from IPython import display as ipythondisplay 28 | from IPython.core.display import HTML 29 | 30 | with open(path, "r+b") as f: 31 | encoded = base64.b64encode(f.read()) 32 | template = """ 33 | 36 | """ 37 | ipythondisplay.display( 38 | HTML(data=template.format(encoded.decode("ascii"))) 39 | ) 40 | except ImportError as e: 41 | raise ImportError( 42 | "This should be executed inside Jupyter Notebook." 43 | ) from e 44 | -------------------------------------------------------------------------------- /d3rlpy/ope/__init__.py: -------------------------------------------------------------------------------- 1 | from .fqe import * 2 | -------------------------------------------------------------------------------- /d3rlpy/ope/torch/__init__.py: -------------------------------------------------------------------------------- 1 | from .fqe_impl import * 2 | -------------------------------------------------------------------------------- /d3rlpy/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .lr_schedulers import * 2 | from .optimizers import * 3 | -------------------------------------------------------------------------------- /d3rlpy/optimizers/lr_schedulers.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | from torch.optim import Optimizer 4 | from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, LRScheduler 5 | 6 | from ..serializable_config import ( 7 | DynamicConfig, 8 | generate_optional_config_generation, 9 | ) 10 | 11 | __all__ = [ 12 | "LRSchedulerFactory", 13 | "WarmupSchedulerFactory", 14 | "CosineAnnealingLRFactory", 15 | "make_lr_scheduler_field", 16 | ] 17 | 18 | 19 | @dataclasses.dataclass() 20 | class LRSchedulerFactory(DynamicConfig): 21 | """A factory class that creates a learning rate scheduler in a lazy way.""" 22 | 23 | def create(self, optim: Optimizer) -> LRScheduler: 24 | """Returns a learning rate scheduler object. 25 | 26 | Args: 27 | optim: PyTorch optimizer. 28 | 29 | Returns: 30 | Learning rate scheduler. 31 | """ 32 | raise NotImplementedError 33 | 34 | 35 | @dataclasses.dataclass() 36 | class WarmupSchedulerFactory(LRSchedulerFactory): 37 | r"""A warmup learning rate scheduler. 38 | 39 | .. math:: 40 | 41 | lr = \min((t + 1) / warmup\_steps, 1) 42 | 43 | Args: 44 | warmup_steps: Warmup steps. 45 | """ 46 | 47 | warmup_steps: int 48 | 49 | def create(self, optim: Optimizer) -> LRScheduler: 50 | return LambdaLR( 51 | optim, 52 | lambda steps: min((steps + 1) / self.warmup_steps, 1), 53 | ) 54 | 55 | @staticmethod 56 | def get_type() -> str: 57 | return "warmup" 58 | 59 | 60 | @dataclasses.dataclass() 61 | class CosineAnnealingLRFactory(LRSchedulerFactory): 62 | """A cosine annealing learning rate scheduler. 63 | 64 | Args: 65 | T_max: Maximum time step. 66 | eta_min: Minimum learning rate. 67 | last_epoch: Last epoch. 
68 | """ 69 | 70 | T_max: int 71 | eta_min: float = 0.0 72 | last_epoch: int = -1 73 | 74 | def create(self, optim: Optimizer) -> LRScheduler: 75 | return CosineAnnealingLR( 76 | optim, 77 | T_max=self.T_max, 78 | eta_min=self.eta_min, 79 | last_epoch=self.last_epoch, 80 | ) 81 | 82 | @staticmethod 83 | def get_type() -> str: 84 | return "cosine_annealing" 85 | 86 | 87 | register_lr_scheduler_factory, make_lr_scheduler_field = ( 88 | generate_optional_config_generation( 89 | LRSchedulerFactory, 90 | ) 91 | ) 92 | 93 | register_lr_scheduler_factory(WarmupSchedulerFactory) 94 | register_lr_scheduler_factory(CosineAnnealingLRFactory) 95 | -------------------------------------------------------------------------------- /d3rlpy/preprocessing/__init__.py: -------------------------------------------------------------------------------- 1 | from .action_scalers import * 2 | from .base import * 3 | from .observation_scalers import * 4 | from .reward_scalers import * 5 | -------------------------------------------------------------------------------- /d3rlpy/tokenizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .tokenizers import * 2 | from .utils import * 3 | -------------------------------------------------------------------------------- /d3rlpy/tokenizers/tokenizers.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol, runtime_checkable 2 | 3 | import numpy as np 4 | 5 | from ..types import Float32NDArray, Int32NDArray, NDArray 6 | from .utils import mu_law_decode, mu_law_encode 7 | 8 | __all__ = [ 9 | "Tokenizer", 10 | "FloatTokenizer", 11 | ] 12 | 13 | 14 | @runtime_checkable 15 | class Tokenizer(Protocol): 16 | def __call__(self, x: NDArray) -> NDArray: ... 17 | 18 | def decode(self, y: Int32NDArray) -> NDArray: ... 
19 | 20 | 21 | class FloatTokenizer(Tokenizer): 22 | _bins: Float32NDArray 23 | _use_mu_law_encode: bool 24 | _mu: float 25 | _basis: float 26 | _token_offset: int 27 | 28 | def __init__( 29 | self, 30 | num_bins: int, 31 | minimum: float = -1.0, 32 | maximum: float = 1.0, 33 | use_mu_law_encode: bool = True, 34 | mu: float = 100.0, 35 | basis: float = 256.0, 36 | token_offset: int = 0, 37 | ): 38 | self._bins = np.array( 39 | (maximum - minimum) * np.arange(num_bins) / num_bins + minimum, 40 | dtype=np.float32, 41 | ) 42 | self._use_mu_law_encode = use_mu_law_encode 43 | self._mu = mu 44 | self._basis = basis 45 | self._token_offset = token_offset 46 | 47 | def __call__(self, x: NDArray) -> Int32NDArray: 48 | if self._use_mu_law_encode: 49 | x = mu_law_encode(x, self._mu, self._basis) 50 | return np.digitize(x, self._bins) - 1 + self._token_offset 51 | 52 | def decode(self, y: Int32NDArray) -> NDArray: 53 | x = self._bins[y - self._token_offset] 54 | if self._use_mu_law_encode: 55 | x = mu_law_decode(x, mu=self._mu, basis=self._basis) 56 | return x # type: ignore 57 | -------------------------------------------------------------------------------- /d3rlpy/tokenizers/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from ..types import Float32NDArray, NDArray 4 | 5 | __all__ = ["mu_law_encode", "mu_law_decode"] 6 | 7 | 8 | def mu_law_encode(x: NDArray, mu: float, basis: float) -> Float32NDArray: 9 | x = np.array(x, dtype=np.float32) 10 | y = np.sign(x) * np.log(np.abs(x) * mu + 1.0) / np.log(basis * mu + 1.0) 11 | return y # type: ignore 12 | 13 | 14 | def mu_law_decode(y: Float32NDArray, mu: float, basis: float) -> Float32NDArray: 15 | x = np.sign(y) * (np.exp(np.log(basis * mu + 1.0) * np.abs(y)) - 1.0) / mu 16 | return x # type: ignore 17 | -------------------------------------------------------------------------------- /d3rlpy/types.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Mapping, Protocol, Sequence, Union, runtime_checkable 2 | 3 | import gym 4 | import gymnasium 5 | import numpy as np 6 | import numpy.typing as npt 7 | import torch 8 | from torch.optim import Optimizer 9 | 10 | __all__ = [ 11 | "NDArray", 12 | "Float32NDArray", 13 | "Int32NDArray", 14 | "UInt8NDArray", 15 | "DType", 16 | "Observation", 17 | "ObservationSequence", 18 | "Shape", 19 | "TorchObservation", 20 | "GymEnv", 21 | "OptimizerWrapperProto", 22 | ] 23 | 24 | 25 | NDArray = npt.NDArray[Any] 26 | Float32NDArray = npt.NDArray[np.float32] 27 | Int32NDArray = npt.NDArray[np.int32] 28 | UInt8NDArray = npt.NDArray[np.uint8] 29 | DType = npt.DTypeLike 30 | 31 | Observation = Union[NDArray, Sequence[NDArray]] 32 | ObservationSequence = Union[NDArray, Sequence[NDArray]] 33 | Shape = Union[Sequence[int], Sequence[Sequence[int]]] 34 | TorchObservation = Union[torch.Tensor, Sequence[torch.Tensor]] 35 | 36 | GymEnv = Union[gym.Env[Any, Any], gymnasium.Env[Any, Any]] 37 | 38 | 39 | @runtime_checkable 40 | class OptimizerWrapperProto(Protocol): 41 | @property 42 | def optim(self) -> Optimizer: 43 | raise NotImplementedError 44 | 45 | def state_dict(self) -> Mapping[str, Any]: 46 | raise NotImplementedError 47 | 48 | def load_state_dict(self, state_dict: Mapping[str, Any]) -> None: 49 | raise NotImplementedError 50 | -------------------------------------------------------------------------------- /dev.requirements.txt: 
-------------------------------------------------------------------------------- 1 | pytest 2 | pytest-cov 3 | onnxruntime 4 | onnx 5 | matplotlib 6 | tensorboardX 7 | wandb 8 | mypy 9 | numpy<2 10 | docformatter 11 | ruff 12 | black 13 | minari[all]>=0.5.2 14 | gymnasium-robotics>=1.3.1 15 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.5.1-cuda12.4-cudnn9-devel 2 | 3 | # this needs to avoid time zone question 4 | ENV DEBIAN_FRONTEND=noninteractive 5 | 6 | # install dependencies 7 | RUN apt-get update && \ 8 | apt-get install -y --no-install-recommends \ 9 | build-essential \ 10 | software-properties-common \ 11 | cmake \ 12 | git \ 13 | wget \ 14 | unzip \ 15 | python3-dev \ 16 | zlib1g \ 17 | zlib1g-dev \ 18 | libgl1-mesa-dri \ 19 | libgl1-mesa-glx \ 20 | libglu1-mesa-dev \ 21 | libasio-dev \ 22 | pkg-config \ 23 | python3-tk \ 24 | libsm6 \ 25 | libxext6 \ 26 | libxrender1 \ 27 | libpcre3-dev && \ 28 | pip install --no-cache-dir \ 29 | Cython==0.29.28 \ 30 | git+https://github.com/takuseno/d4rl-atari && \ 31 | rm -rf /var/lib/apt/lists/* && \ 32 | rm -rf /tmp/* && \ 33 | wget https://github.com/takuseno/d3rlpy/archive/master.zip && \ 34 | unzip master.zip && \ 35 | cd d3rlpy-master && \ 36 | pip install --no-cache-dir . && \ 37 | cd .. && \ 38 | rm -rf d3rlpy-master master.zip 39 | 40 | EXPOSE 6006 41 | 42 | CMD ["tail", "-f", "/dev/null"] 43 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/_static/css/d3rlpy.css: -------------------------------------------------------------------------------- 1 | 2 | :root{ 3 | --main-bg-color: #2c3e50; 4 | } 5 | 6 | /* Docs background */ 7 | .wy-side-nav-search{ 8 | background-color: var(--main-bg-color); 9 | } 10 | 11 | /* Mobile version */ 12 | .wy-nav-top{ 13 | background-color: var(--main-bg-color); 14 | } 15 | -------------------------------------------------------------------------------- /docs/_static/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/docs/_static/logo.png -------------------------------------------------------------------------------- /docs/_templates/autosummary/class.rst: -------------------------------------------------------------------------------- 1 | {{ fullname }} 2 | {{ underline }} 3 | 4 | .. currentmodule:: {{ module }} 5 | 6 | .. autoclass:: {{ objname }} 7 | 8 | .. 
9 | Methods 10 | 11 | {% block methods %} 12 | 13 | .. rubric:: Methods 14 | 15 | .. 16 | Special methods 17 | 18 | {% for item in ('__call__', '__enter__', '__exit__', '__getitem__', '__setitem__', '__len__', '__next__', '__iter__', '__copy__') %} 19 | {% if item in all_methods %} 20 | .. automethod:: {{ item }} 21 | {% endif %} 22 | {%- endfor %} 23 | 24 | .. 25 | Ordinary methods 26 | 27 | {% for item in methods %} 28 | {% if item not in ('__init__',) %} 29 | .. automethod:: {{ item }} 30 | {% endif %} 31 | {%- endfor %} 32 | 33 | {% endblock %} 34 | 35 | .. 36 | Attributes 37 | 38 | {% block attributes %} {% if attributes %} 39 | 40 | .. rubric:: Attributes 41 | 42 | {% for item in attributes %} 43 | .. autoattribute:: {{ item }} 44 | {%- endfor %} 45 | {% endif %} {% endblock %} 46 | -------------------------------------------------------------------------------- /docs/assets/design.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/docs/assets/design.png -------------------------------------------------------------------------------- /docs/assets/dqn_cartpole.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/docs/assets/dqn_cartpole.png -------------------------------------------------------------------------------- /docs/assets/fqe_cartpole_init_value.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/docs/assets/fqe_cartpole_init_value.png -------------------------------------------------------------------------------- /docs/assets/fqe_cartpole_soft_opc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/docs/assets/fqe_cartpole_soft_opc.png -------------------------------------------------------------------------------- /docs/assets/mdp_dataset.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/docs/assets/mdp_dataset.png -------------------------------------------------------------------------------- /docs/assets/plot_all_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/docs/assets/plot_all_example.png -------------------------------------------------------------------------------- /docs/assets/plot_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/docs/assets/plot_example.png -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | .. d3rlpy documentation master file, created by 2 | sphinx-quickstart on Mon Jul 6 18:20:21 2020. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | d3rlpy - An offline deep reinforcement learning library. 
7 | ======================================================== 8 | 9 | `d3rlpy `_ is a easy-to-use offline deep 10 | reinforcement learning library. 11 | 12 | .. code-block:: 13 | 14 | $ pip install d3rlpy 15 | 16 | d3rlpy provides state-of-the-art offline deep reinforcement learning 17 | algorithms through out-of-the-box scikit-learn-style APIs. 18 | Unlike other RL libraries, the provided algorithms can achieve extremely 19 | powerful performance beyond their papers via several tweaks. 20 | 21 | .. toctree:: 22 | :maxdepth: 2 23 | :caption: Tutorials 24 | 25 | tutorials/index 26 | notebooks 27 | 28 | .. toctree:: 29 | :maxdepth: 2 30 | :caption: References 31 | 32 | software_design 33 | references/index 34 | cli 35 | installation 36 | tips 37 | 38 | .. toctree:: 39 | :maxdepth: 2 40 | :caption: Other 41 | 42 | reproductions 43 | license 44 | 45 | 46 | Indices and tables 47 | ================== 48 | 49 | * :ref:`genindex` 50 | * :ref:`modindex` 51 | * :ref:`search` 52 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | Recommended Platforms 5 | --------------------- 6 | 7 | d3rlpy supports Linux, macOS and also Windows. 8 | 9 | 10 | Install d3rlpy 11 | -------------- 12 | 13 | Install via PyPI 14 | ~~~~~~~~~~~~~~~~ 15 | 16 | `pip` is a recommended way to install d3rlpy:: 17 | 18 | $ pip install d3rlpy 19 | 20 | Install via Anaconda 21 | ~~~~~~~~~~~~~~~~~~~~ 22 | 23 | d3rlpy is also available on `conda-forge`:: 24 | 25 | $ conda install -c conda-forge d3rlpy 26 | 27 | 28 | Install via Docker 29 | ~~~~~~~~~~~~~~~~~~ 30 | 31 | d3rlpy is also available on Docker Hub:: 32 | 33 | $ docker run -it --gpus all --name d3rlpy takuseno/d3rlpy:latest bash 34 | 35 | 36 | .. _install_from_source: 37 | 38 | Install from source 39 | ~~~~~~~~~~~~~~~~~~~ 40 | 41 | You can also install via GitHub repository:: 42 | 43 | $ git clone https://github.com/takuseno/d3rlpy 44 | $ cd d3rlpy 45 | $ pip install -e . 46 | -------------------------------------------------------------------------------- /docs/license.rst: -------------------------------------------------------------------------------- 1 | License 2 | ======= 3 | 4 | MIT License 5 | 6 | Copyright (c) 2021 Takuma Seno 7 | 8 | Permission is hereby granted, free of charge, to any person obtaining a copy 9 | of this software and associated documentation files (the "Software"), to deal 10 | in the Software without restriction, including without limitation the rights 11 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | copies of the Software, and to permit persons to whom the Software is 13 | furnished to do so, subject to the following conditions: 14 | 15 | The above copyright notice and this permission notice shall be included in all 16 | copies or substantial portions of the Software. 17 | 18 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 19 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 20 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 21 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 22 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 23 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 24 | SOFTWARE. 
25 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/notebooks.rst: -------------------------------------------------------------------------------- 1 | Jupyter Notebooks 2 | ================= 3 | 4 | * `CartPole `_ 5 | * `CartPole (online) `_ 6 | * `Discrete Control with Atari `_ 7 | * `TPU Example `_ 8 | -------------------------------------------------------------------------------- /docs/references/datasets.rst: -------------------------------------------------------------------------------- 1 | Datasets 2 | ======== 3 | 4 | .. module:: d3rlpy.datasets 5 | 6 | d3rlpy provides datasets for experimenting data-driven deep reinforcement 7 | learning algorithms. 8 | 9 | .. autosummary:: 10 | :toctree: generated/ 11 | :nosignatures: 12 | 13 | d3rlpy.datasets.get_cartpole 14 | d3rlpy.datasets.get_pendulum 15 | d3rlpy.datasets.get_atari 16 | d3rlpy.datasets.get_atari_transitions 17 | d3rlpy.datasets.get_d4rl 18 | d3rlpy.datasets.get_dataset 19 | d3rlpy.datasets.get_minari 20 | -------------------------------------------------------------------------------- /docs/references/index.rst: -------------------------------------------------------------------------------- 1 | ************* 2 | API Reference 3 | ************* 4 | 5 | .. module:: d3rlpy 6 | 7 | .. toctree:: 8 | :maxdepth: 2 9 | 10 | algos 11 | q_functions 12 | dataset 13 | datasets 14 | preprocessing 15 | optimizers 16 | network_architectures 17 | metrics 18 | off_policy_evaluation 19 | logging 20 | online 21 | -------------------------------------------------------------------------------- /docs/references/metrics.rst: -------------------------------------------------------------------------------- 1 | Metrics 2 | ======= 3 | 4 | .. module:: d3rlpy.metrics 5 | 6 | d3rlpy provides scoring functions for offline Q-learning-based training. 7 | You can also check :doc:`../references/logging` to understand how to write 8 | metrics to files. 9 | 10 | .. 
code-block:: python 11 | 12 | import d3rlpy 13 | 14 | dataset, env = d3rlpy.datasets.get_cartpole() 15 | # use partial episodes as test data 16 | test_episodes = dataset.episodes[:10] 17 | 18 | dqn = d3rlpy.algos.DQNConfig().create() 19 | 20 | dqn.fit( 21 | dataset, 22 | n_steps=100000, 23 | evaluators={ 24 | 'td_error': d3rlpy.metrics.TDErrorEvaluator(test_episodes), 25 | 'value_scale': d3rlpy.metrics.AverageValueEstimationEvaluator(test_episodes), 26 | 'environment': d3rlpy.metrics.EnvironmentEvaluator(env), 27 | }, 28 | ) 29 | 30 | You can also implement your own metrics. 31 | 32 | 33 | .. code-block:: python 34 | 35 | class CustomEvaluator(d3rlpy.metrics.EvaluatorProtocol): 36 | def __call__(self, algo: d3rlpy.algos.QLearningAlgoBase, dataset: ReplayBuffer) -> float: 37 | # do some evaluation 38 | 39 | 40 | .. autosummary:: 41 | :toctree: generated/ 42 | :nosignatures: 43 | 44 | d3rlpy.metrics.TDErrorEvaluator 45 | d3rlpy.metrics.DiscountedSumOfAdvantageEvaluator 46 | d3rlpy.metrics.AverageValueEstimationEvaluator 47 | d3rlpy.metrics.InitialStateValueEstimationEvaluator 48 | d3rlpy.metrics.SoftOPCEvaluator 49 | d3rlpy.metrics.ContinuousActionDiffEvaluator 50 | d3rlpy.metrics.DiscreteActionMatchEvaluator 51 | d3rlpy.metrics.EnvironmentEvaluator 52 | d3rlpy.metrics.CompareContinuousActionDiffEvaluator 53 | d3rlpy.metrics.CompareDiscreteActionMatchEvaluator 54 | -------------------------------------------------------------------------------- /docs/references/off_policy_evaluation.rst: -------------------------------------------------------------------------------- 1 | Off-Policy Evaluation 2 | ===================== 3 | 4 | .. module:: d3rlpy.ope 5 | 6 | The off-policy evaluation is a method to estimate the trained policy 7 | performance only with offline datasets. 8 | 9 | .. code-block:: python 10 | 11 | import d3rlpy 12 | 13 | # prepare the trained algorithm 14 | cql = d3rlpy.load_learnable("model.d3") 15 | 16 | # dataset to evaluate with 17 | dataset, env = d3rlpy.datasets.get_pendulum() 18 | 19 | # off-policy evaluation algorithm 20 | fqe = d3rlpy.ope.FQE(algo=cql, config=d3rlpy.ope.FQEConfig()) 21 | 22 | # train estimators to evaluate the trained policy 23 | fqe.fit( 24 | dataset, 25 | n_steps=100000, 26 | evaluators={ 27 | 'init_value': d3rlpy.metrics.InitialStateValueEstimationEvaluator(), 28 | 'soft_opc': d3rlpy.metrics.SoftOPCEvaluator(return_threshold=-300), 29 | }, 30 | ) 31 | 32 | The evaluation during fitting is evaluating the trained policy. 33 | 34 | For continuous control algorithms 35 | --------------------------------- 36 | 37 | .. autosummary:: 38 | :toctree: generated/ 39 | :nosignatures: 40 | 41 | d3rlpy.ope.FQE 42 | 43 | 44 | For discrete control algorithms 45 | ------------------------------- 46 | 47 | .. autosummary:: 48 | :toctree: generated/ 49 | :nosignatures: 50 | 51 | d3rlpy.ope.DiscreteFQE 52 | -------------------------------------------------------------------------------- /docs/references/online.rst: -------------------------------------------------------------------------------- 1 | Online Training 2 | =============== 3 | 4 | .. module:: d3rlpy.algos 5 | 6 | d3rlpy provides not only offline training, but also online training utilities. 7 | Despite being designed for offline training algorithms, d3rlpy is flexible 8 | enough to be trained in an online manner with a few more utilities. 9 | 10 | .. 
code-block:: python 11 | 12 |     import d3rlpy 13 |     import gym 14 | 15 |     # setup environment 16 |     env = gym.make('CartPole-v1') 17 |     eval_env = gym.make('CartPole-v1') 18 | 19 |     # setup algorithm 20 |     dqn = d3rlpy.algos.DQNConfig( 21 |         batch_size=32, 22 |         learning_rate=2.5e-4, 23 |         target_update_interval=100, 24 |     ).create(device="cuda:0") 25 | 26 |     # setup replay buffer 27 |     buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=100000, env=env) 28 | 29 |     # setup explorers 30 |     explorer = d3rlpy.algos.LinearDecayEpsilonGreedy( 31 |         start_epsilon=1.0, 32 |         end_epsilon=0.1, 33 |         duration=10000, 34 |     ) 35 | 36 |     # start training 37 |     dqn.fit_online( 38 |         env, 39 |         buffer, 40 |         explorer=explorer, # you don't need this with probabilistic policy algorithms 41 |         eval_env=eval_env, 42 |         n_steps=30000, # the number of total steps to train. 43 |         n_steps_per_epoch=1000, 44 |         update_interval=10, # update parameters every 10 steps. 45 |     ) 46 | 47 | 48 | Explorers 49 | ~~~~~~~~~ 50 | 51 | .. autosummary:: 52 |     :toctree: generated/ 53 |     :nosignatures: 54 | 55 |     d3rlpy.algos.ConstantEpsilonGreedy 56 |     d3rlpy.algos.LinearDecayEpsilonGreedy 57 |     d3rlpy.algos.NormalNoise 58 | -------------------------------------------------------------------------------- /docs/references/optimizers.rst: -------------------------------------------------------------------------------- 1 | Optimizers 2 | ========== 3 | 4 | .. module:: d3rlpy.optimizers 5 | 6 | d3rlpy provides ``OptimizerFactory`` that gives you flexible control over 7 | optimizers. 8 | ``OptimizerFactory`` takes PyTorch's optimizer class and its arguments to 9 | initialize, which you can check in more detail `here `_. 10 | 11 | .. code-block:: python 12 | 13 |     import d3rlpy 14 |     from torch.optim import Adam 15 | 16 |     # modify weight decay 17 |     optim_factory = d3rlpy.optimizers.OptimizerFactory(Adam, weight_decay=1e-4) 18 | 19 |     # set OptimizerFactory 20 |     dqn = d3rlpy.algos.DQNConfig(optim_factory=optim_factory).create() 21 | 22 | There are also convenient aliases. 23 | 24 | .. code-block:: python 25 | 26 |     # alias for Adam optimizer 27 |     optim_factory = d3rlpy.optimizers.AdamFactory(weight_decay=1e-4) 28 | 29 |     dqn = d3rlpy.algos.DQNConfig(optim_factory=optim_factory).create() 30 | 31 | 32 | .. autosummary:: 33 |     :toctree: generated/ 34 |     :nosignatures: 35 | 36 |     d3rlpy.optimizers.OptimizerFactory 37 |     d3rlpy.optimizers.SGDFactory 38 |     d3rlpy.optimizers.AdamFactory 39 |     d3rlpy.optimizers.RMSpropFactory 40 |     d3rlpy.optimizers.GPTAdamWFactory 41 | 42 | 43 | Learning rate scheduler 44 | ~~~~~~~~~~~~~~~~~~~~~~~ 45 | 46 | d3rlpy provides ``LRSchedulerFactory`` that lets you configure learning rate 47 | schedulers with ``OptimizerFactory``. 48 | 49 | .. code-block:: python 50 | 51 |     import d3rlpy 52 | 53 |     # set lr_scheduler_factory 54 |     optim_factory = d3rlpy.optimizers.AdamFactory( 55 |         lr_scheduler_factory=d3rlpy.optimizers.WarmupSchedulerFactory( 56 |             warmup_steps=10000 57 |         ) 58 |     ) 59 | 60 | 61 | .. autosummary:: 62 |     :toctree: generated/ 63 |     :nosignatures: 64 | 65 |     d3rlpy.optimizers.LRSchedulerFactory 66 |     d3rlpy.optimizers.WarmupSchedulerFactory 67 |     d3rlpy.optimizers.CosineAnnealingLRFactory 68 | -------------------------------------------------------------------------------- /docs/references/q_functions.rst: -------------------------------------------------------------------------------- 1 | Q Functions 2 | =========== 3 | 4 | ..
module:: d3rlpy.models 5 | 6 | d3rlpy provides various Q functions, including state-of-the-art ones, which 7 | are used internally in algorithm objects. 8 | You can switch Q functions by passing the ``q_func_factory`` argument at 9 | algorithm initialization. 10 | 11 | .. code-block:: python 12 | 13 |     import d3rlpy 14 | 15 |     cql = d3rlpy.algos.CQLConfig(q_func_factory=d3rlpy.models.QRQFunctionFactory()).create() 16 | 17 | You can also change hyperparameters. 18 | 19 | .. code-block:: python 20 | 21 |     q_func = d3rlpy.models.QRQFunctionFactory(n_quantiles=32) 22 | 23 |     cql = d3rlpy.algos.CQLConfig(q_func_factory=q_func).create() 24 | 25 | The default Q function is the ``mean`` approximator, which estimates expected 26 | scalar action-values. 27 | However, recent advances in deep reinforcement learning have introduced a new 28 | type of action-value approximator, the so-called 29 | `distributional` Q function. 30 | 31 | Unlike the ``mean`` approximator, `distributional` Q functions estimate the 32 | distribution of action-values. 33 | These `distributional` approaches have consistently shown much stronger 34 | performance than the ``mean`` approximator. 35 | 36 | Here is a list of available Q functions in ascending order of 37 | performance. 38 | Currently, as a trade-off between performance and computational complexity, 39 | higher performance comes with more expensive computational costs. 40 | 41 | .. autosummary:: 42 |     :toctree: generated/ 43 |     :nosignatures: 44 | 45 |     d3rlpy.models.MeanQFunctionFactory 46 |     d3rlpy.models.QRQFunctionFactory 47 |     d3rlpy.models.IQNQFunctionFactory 48 | -------------------------------------------------------------------------------- /docs/reproductions.rst: -------------------------------------------------------------------------------- 1 | Paper Reproductions 2 | ------------------- 3 | 4 | For the experiment code, please take a look at 5 | the `reproductions `_ directory. 6 | 7 | All the experimental results are available in the `d3rlpy-benchmarks `_ repository. 8 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | tensorboardX==2.0 3 | Sphinx==5.0.2 4 | sphinx-rtd-theme==0.5.0 5 | -------------------------------------------------------------------------------- /docs/tutorials/create_your_dataset.rst: -------------------------------------------------------------------------------- 1 | ******************* 2 | Create Your Dataset 3 | ******************* 4 | 5 | The data collection API is introduced in :doc:`data_collection`. 6 | In this tutorial, you can learn how to build your dataset from logged data 7 | such as the user data collected in your web service. 8 | 9 | Prepare Logged Data 10 | ------------------- 11 | 12 | First of all, you need to prepare your logged data. 13 | In this tutorial, let's use randomly generated data. 14 | ``terminals`` represents the last step of episodes. 15 | If ``terminals[i] == 1.0``, the i-th step is a terminal state. 16 | Otherwise, you need to set zeros for non-terminal states. 17 | 18 | ..
code-block:: python 19 | 20 |     import numpy as np 21 | 22 |     # vector observation 23 |     # 1000 steps of observations with shape (100,) 24 |     observations = np.random.random((1000, 100)) 25 | 26 |     # 1000 steps of actions with shape (4,) 27 |     actions = np.random.random((1000, 4)) 28 | 29 |     # 1000 steps of rewards 30 |     rewards = np.random.random(1000) 31 | 32 |     # 1000 steps of terminal flags 33 |     terminals = np.random.randint(2, size=1000) 34 | 35 | Build MDPDataset 36 | ---------------- 37 | 38 | Once your logged data is ready, you can build an ``MDPDataset`` object. 39 | 40 | .. code-block:: python 41 | 42 |     import d3rlpy 43 | 44 |     dataset = d3rlpy.dataset.MDPDataset( 45 |         observations=observations, 46 |         actions=actions, 47 |         rewards=rewards, 48 |         terminals=terminals, 49 |     ) 50 | 51 | Set Timeout Flags 52 | ----------------- 53 | 54 | In RL, there are cases where you want to stop an episode without reaching a 55 | terminal state. 56 | For example, if you're collecting data from a four-legged robot walking forward, 57 | the walking task basically never ends as long as the robot keeps walking, yet 58 | the logged episode must stop somewhere. 59 | In this case, you can use ``timeouts`` to represent these timeout states. 60 | 61 | .. code-block:: python 62 | 63 |     # terminal states 64 |     terminals = np.zeros(1000) 65 | 66 |     # timeout states 67 |     timeouts = np.random.randint(2, size=1000) 68 | 69 |     dataset = d3rlpy.dataset.MDPDataset( 70 |         observations=observations, 71 |         actions=actions, 72 |         rewards=rewards, 73 |         terminals=terminals, 74 |         timeouts=timeouts, 75 |     ) 76 | -------------------------------------------------------------------------------- /docs/tutorials/finetuning.rst: -------------------------------------------------------------------------------- 1 | ********** 2 | Finetuning 3 | ********** 4 | 5 | d3rlpy supports a smooth transition from offline training to online training. 6 | 7 | Prepare Dataset and Environment 8 | ------------------------------- 9 | 10 | In this tutorial, let's use a built-in dataset for the CartPole-v0 environment. 11 | 12 | .. code-block:: python 13 | 14 |     import d3rlpy 15 | 16 |     # setup random CartPole-v0 dataset and environment 17 |     dataset, env = d3rlpy.datasets.get_dataset("cartpole-random") 18 | 19 | Pretrain with Dataset 20 | --------------------- 21 | 22 | .. code-block:: python 23 | 24 |     # setup algorithm 25 |     dqn = d3rlpy.algos.DQNConfig().create() 26 | 27 |     # start offline training 28 |     dqn.fit(dataset, n_steps=100000) 29 | 30 | Finetune with Environment 31 | ------------------------- 32 | 33 | .. code-block:: python 34 | 35 |     # setup experience replay buffer 36 |     buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=100000, env=env) 37 | 38 |     # setup exploration strategy if necessary 39 |     explorer = d3rlpy.algos.ConstantEpsilonGreedy(0.1) 40 | 41 |     # start finetuning 42 |     dqn.fit_online(env, buffer, explorer, n_steps=100000) 43 | 44 | Finetune with Saved Policy 45 | -------------------------- 46 | 47 | If you want to finetune a saved policy, that's also easy to do with d3rlpy. 48 | 49 | .. code-block:: python 50 | 51 |     # setup algorithm 52 |     dqn = d3rlpy.load_learnable("dqn_model.d3") 53 | 54 |     # start finetuning 55 |     dqn.fit_online(env, buffer, explorer, n_steps=100000) 56 | 57 | Finetune with Different Algorithm 58 | --------------------------------- 59 | 60 | If you want to finetune a policy trained offline using online RL 61 | algorithms, you can do it in an out-of-the-box way. 62 | 63 | ..
code-block:: python 64 | 65 | # setup offline RL algorithm 66 | cql = d3rlpy.algos.DiscreteCQLConfig().create() 67 | 68 | # train offline 69 | cql.fit(dataset, n_steps=100000) 70 | 71 | # transfer to DQN 72 | dqn = d3rlpy.algos.DQNConfig().create() 73 | dqn.build_with_env(env) 74 | dqn.copy_q_function_from(cql) 75 | 76 | # start finetuning 77 | dqn.fit_online(env, buffer, explorer, n_steps=100000) 78 | 79 | In actor-critic cases, you should also transfer the policy function. 80 | 81 | .. code-block:: python 82 | 83 | # offline RL 84 | cql = d3rlpy.algos.CQLConfig().create() 85 | cql.fit(dataset, n_steps=100000) 86 | 87 | # transfer to SAC 88 | sac = d3rlpy.algos.SACConfig().create() 89 | sac.build_with_env(env) 90 | sac.copy_q_function_from(cql) 91 | sac.copy_policy_from(cql) 92 | 93 | # online RL 94 | sac.fit_online(env, buffer, n_steps=100000) 95 | -------------------------------------------------------------------------------- /docs/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | ********* 2 | Tutorials 3 | ********* 4 | 5 | .. toctree:: 6 | :maxdepth: 2 7 | 8 | getting_started 9 | data_collection 10 | create_your_dataset 11 | preprocess_and_postprocess 12 | customize_neural_network 13 | online_rl 14 | finetuning 15 | offline_policy_selection 16 | use_distributional_q_function 17 | after_training_policies 18 | -------------------------------------------------------------------------------- /docs/tutorials/online_rl.rst: -------------------------------------------------------------------------------- 1 | ********* 2 | Online RL 3 | ********* 4 | 5 | Prepare Environment 6 | ------------------- 7 | 8 | d3rlpy supports environments with OpenAI Gym interface. 9 | In this tutorial, let's use simple CartPole environment. 10 | 11 | .. code-block:: python 12 | 13 | import gym 14 | 15 | # for training 16 | env = gym.make("CartPole-v1") 17 | 18 | # for evaluation 19 | eval_env = gym.make("CartPole-v1") 20 | 21 | Setup Algorithm 22 | --------------- 23 | 24 | Just like offline RL training, you can setup an algorithm object. 25 | 26 | .. code-block:: python 27 | 28 | import d3rlpy 29 | 30 | # if you don't use GPU, set use_gpu=False instead. 31 | dqn = d3rlpy.algos.DQNConfig( 32 | batch_size=32, 33 | learning_rate=2.5e-4, 34 | target_update_interval=100, 35 | ).create(device="cuda:0") 36 | 37 | # initialize neural networks with the given environment object. 38 | # this is not necessary when you directly call fit or fit_online method. 39 | dqn.build_with_env(env) 40 | 41 | 42 | Setup Online RL Utilities 43 | ------------------------- 44 | 45 | Unlike offline RL training, you'll need to setup an experience replay buffer and 46 | an exploration strategy. 47 | 48 | .. code-block:: python 49 | 50 | # experience replay buffer 51 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=100000, env=env) 52 | 53 | # exploration strategy 54 | # in this tutorial, epsilon-greedy policy with static epsilon=0.3 55 | explorer = d3rlpy.algos.ConstantEpsilonGreedy(0.3) 56 | 57 | 58 | Start Training 59 | -------------- 60 | 61 | Now, you have everything you need to start online RL training. 62 | Let's put them together! 63 | 64 | .. 
code-block:: python 65 | 66 |     dqn.fit_online( 67 |         env, 68 |         buffer, 69 |         explorer, 70 |         n_steps=100000, # train for 100K steps 71 |         eval_env=eval_env, 72 |         n_steps_per_epoch=1000, # evaluation is performed every 1K steps 73 |         update_start_step=1000, # parameter update starts after 1K steps 74 |     ) 75 | 76 | Train with Stochastic Policy 77 | ---------------------------- 78 | 79 | If the algorithm uses a stochastic policy (e.g. SAC), you can train it 80 | without setting an exploration strategy. 81 | 82 | .. code-block:: python 83 | 84 |     sac = d3rlpy.algos.DiscreteSACConfig().create() 85 |     sac.fit_online( 86 |         env, 87 |         buffer, 88 |         n_steps=100000, 89 |         eval_env=eval_env, 90 |         n_steps_per_epoch=1000, 91 |         update_start_step=1000, 92 |     ) 93 | -------------------------------------------------------------------------------- /docs/tutorials/use_distributional_q_function.rst: -------------------------------------------------------------------------------- 1 | ***************************** 2 | Use Distributional Q-Function 3 | ***************************** 4 | 5 | One of the unique features of d3rlpy is the ability to use distributional 6 | Q-functions with arbitrary d3rlpy algorithms. 7 | Distributional Q-functions are powerful and potentially capable of 8 | improving the performance of any algorithm. 9 | In this tutorial, you can learn how to use them. 10 | Check :doc:`../references/q_functions` for more information. 11 | 12 | .. code-block:: python 13 | 14 |     # default standard Q-function 15 |     mean_q_function = d3rlpy.models.MeanQFunctionFactory() 16 |     sac = d3rlpy.algos.SACConfig(q_func_factory=mean_q_function).create() 17 | 18 |     # Quantile Regression Q-function 19 |     qr_q_function = d3rlpy.models.QRQFunctionFactory(n_quantiles=200) 20 |     sac = d3rlpy.algos.SACConfig(q_func_factory=qr_q_function).create() 21 | 22 |     # Implicit Quantile Network Q-function 23 |     iqn_q_function = d3rlpy.models.IQNQFunctionFactory( 24 |         n_quantiles=32, 25 |         n_greedy_quantiles=64, 26 |         embed_size=64, 27 |     ) 28 |     sac = d3rlpy.algos.SACConfig(q_func_factory=iqn_q_function).create() 29 | -------------------------------------------------------------------------------- /examples/deepmind_control.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gymnasium 4 | from gymnasium.wrappers import FlattenObservation 5 | 6 | import d3rlpy 7 | 8 | # You need to install DeepMind Control (DMC) and Shimmy beforehand as follows: 9 | # 10 | # $ d3rlpy install dm_control 11 | # 12 | # After you install the packages, d3rlpy internally registers DMC environments 13 | # for Gymnasium.
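#
# A typical launch with the arguments defined below (shown for illustration;
# which DMC task names resolve depends on your dm_control/Shimmy install):
#
#   $ python examples/deepmind_control.py --env hopper-stand --gpu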
14 | 15 | 16 | def main() -> None: 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument("--env", type=str, default="hopper-stand") 19 | parser.add_argument("--seed", type=int, default=1) 20 | parser.add_argument("--gpu", action="store_true") 21 | args = parser.parse_args() 22 | 23 | env_id = f"dm_control/{args.env}" 24 | env = FlattenObservation(gymnasium.make(env_id)) # type: ignore 25 | eval_env = FlattenObservation(gymnasium.make(env_id)) # type: ignore 26 | 27 | # fix seed 28 | d3rlpy.seed(args.seed) 29 | d3rlpy.envs.seed_env(env, args.seed) 30 | d3rlpy.envs.seed_env(eval_env, args.seed) 31 | 32 | # setup algorithm 33 | sac = d3rlpy.algos.SACConfig( 34 | batch_size=256, 35 | actor_learning_rate=3e-4, 36 | critic_learning_rate=3e-4, 37 | temp_learning_rate=3e-4, 38 | ).create(device=args.gpu) 39 | 40 | # replay buffer for experience replay 41 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=1000000, env=env) 42 | 43 | # start training 44 | sac.fit_online( 45 | env, 46 | buffer, 47 | eval_env=eval_env, 48 | n_steps=1000000, 49 | n_steps_per_epoch=10000, 50 | update_interval=1, 51 | update_start_step=1000, 52 | ) 53 | 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /examples/distributed_offline_training.py: -------------------------------------------------------------------------------- 1 | import d3rlpy 2 | 3 | # This script needs to be launched by using torchrun command. 4 | # $ torchrun \ 5 | # --nnodes=1 \ 6 | # --nproc_per_node=3 \ 7 | # --rdzv_id=100 \ 8 | # --rdzv_backend=c10d \ 9 | # --rdzv_endpoint=localhost:29400 \ 10 | # examples/distributed_offline_training.py 11 | 12 | 13 | def main() -> None: 14 | # GPU version: 15 | # rank = d3rlpy.distributed.init_process_group("nccl") 16 | rank = d3rlpy.distributed.init_process_group("gloo") 17 | print(f"Start running on rank={rank}.") 18 | 19 | # GPU version: 20 | # device = f"cuda:{rank}" 21 | device = "cpu:0" 22 | 23 | # setup algorithm 24 | cql = d3rlpy.algos.CQLConfig( 25 | actor_learning_rate=1e-3, 26 | critic_learning_rate=1e-3, 27 | alpha_learning_rate=1e-3, 28 | ).create(device=device, enable_ddp=True) 29 | 30 | # prepare dataset 31 | dataset, env = d3rlpy.datasets.get_pendulum() 32 | 33 | # disable logging on rank != 0 workers 34 | logger_adapter: d3rlpy.logging.LoggerAdapterFactory 35 | evaluators: dict[str, d3rlpy.metrics.EvaluatorProtocol] 36 | if rank == 0: 37 | evaluators = {"environment": d3rlpy.metrics.EnvironmentEvaluator(env)} 38 | logger_adapter = d3rlpy.logging.FileAdapterFactory() 39 | else: 40 | evaluators = {} 41 | logger_adapter = d3rlpy.logging.NoopAdapterFactory() 42 | 43 | # start training 44 | cql.fit( 45 | dataset, 46 | n_steps=10000, 47 | n_steps_per_epoch=1000, 48 | evaluators=evaluators, 49 | logger_adapter=logger_adapter, 50 | show_progress=rank == 0, 51 | ) 52 | 53 | d3rlpy.distributed.destroy_process_group() 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /examples/fine_tuning.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | 4 | import d3rlpy 5 | 6 | 7 | def main() -> None: 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0") 10 | parser.add_argument("--seed", type=int, default=1) 11 | parser.add_argument("--gpu", type=int) 12 | args = parser.parse_args() 13 | 14 | 
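# load the offline dataset (D4RL hopper-medium-v0 by default) together with its
# paired evaluation environment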
dataset, env = d3rlpy.datasets.get_dataset(args.dataset) 15 | 16 | # fix seed 17 | d3rlpy.seed(args.seed) 18 | d3rlpy.envs.seed_env(env, args.seed) 19 | 20 | cql = d3rlpy.algos.CQLConfig( 21 | actor_learning_rate=1e-4, 22 | critic_learning_rate=3e-4, 23 | temp_learning_rate=1e-4, 24 | batch_size=256, 25 | n_action_samples=10, 26 | alpha_learning_rate=0.0, 27 | conservative_weight=10.0, 28 | ).create(device=args.gpu) 29 | 30 | # pretraining 31 | cql.fit( 32 | dataset, 33 | n_steps=100000, 34 | n_steps_per_epoch=1000, 35 | save_interval=10, 36 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}, 37 | experiment_name=f"CQL_pretraining_{args.dataset}_{args.seed}", 38 | ) 39 | 40 | sac = d3rlpy.algos.SACConfig( 41 | actor_learning_rate=3e-4, 42 | critic_learning_rate=3e-4, 43 | temp_learning_rate=3e-4, 44 | batch_size=256, 45 | ).create(device=args.gpu) 46 | 47 | # copy pretrained models to SAC 48 | sac.build_with_env(env) 49 | sac.copy_policy_from(cql) # type: ignore 50 | sac.copy_q_function_from(cql) # type: ignore 51 | 52 | # prepare FIFO buffer filled with dataset episodes 53 | buffer = d3rlpy.dataset.create_fifo_replay_buffer( 54 | limit=100000, 55 | episodes=dataset.episodes, 56 | ) 57 | 58 | # finetuning 59 | eval_env = copy.deepcopy(env) 60 | d3rlpy.envs.seed_env(eval_env, args.seed) 61 | sac.fit_online( 62 | env, 63 | buffer=buffer, 64 | eval_env=eval_env, 65 | experiment_name=f"SAC_finetuning_{args.dataset}_{args.seed}", 66 | n_steps=100000, 67 | n_steps_per_epoch=1000, 68 | save_interval=10, 69 | ) 70 | 71 | 72 | if __name__ == "__main__": 73 | main() 74 | -------------------------------------------------------------------------------- /examples/frame_stacking.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | 5 | import d3rlpy 6 | 7 | 8 | def main() -> None: 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--env", type=str, default="PongNoFrameskip-v4") 11 | parser.add_argument("--seed", type=int, default=1) 12 | parser.add_argument("--gpu", action="store_true") 13 | args = parser.parse_args() 14 | 15 | # get wrapped atari environment with 4 frame stacking 16 | # observation shape is [4, 84, 84] 17 | env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4) 18 | eval_env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4, is_eval=True) 19 | 20 | # fix seed 21 | d3rlpy.seed(args.seed) 22 | d3rlpy.envs.seed_env(env, args.seed) 23 | d3rlpy.envs.seed_env(eval_env, args.seed) 24 | 25 | # setup algorithm 26 | dqn = d3rlpy.algos.DQNConfig( 27 | batch_size=32, 28 | learning_rate=2.5e-4, 29 | optim_factory=d3rlpy.optimizers.RMSpropFactory(), 30 | target_update_interval=10000 // 4, 31 | observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), 32 | ).create(device=args.gpu) 33 | 34 | # replay buffer for experience replay 35 | buffer = d3rlpy.dataset.create_fifo_replay_buffer( 36 | limit=1000000, 37 | # stack last 4 frames (stacked shape is [4, 84, 84]) 38 | transition_picker=d3rlpy.dataset.FrameStackTransitionPicker(n_frames=4), 39 | # store only last frame to save memory (stored shape is [1, 84, 84]) 40 | writer_preprocessor=d3rlpy.dataset.LastFrameWriterPreprocess(), 41 | env=env, 42 | ) 43 | 44 | # epilon-greedy explorer 45 | explorer = d3rlpy.algos.LinearDecayEpsilonGreedy( 46 | start_epsilon=1.0, end_epsilon=0.1, duration=1000000 47 | ) 48 | 49 | # start training 50 | dqn.fit_online( 51 | env, 52 | buffer, 53 | explorer, 54 | eval_env=eval_env, 55 | eval_epsilon=0.01, 
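# evaluation acts near-greedily (epsilon=0.01) instead of following the
# training-time exploration schedule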
56 | n_steps=1000000, 57 | n_steps_per_epoch=100000, 58 | update_interval=4, 59 | update_start_step=50000, 60 | ) 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /examples/gymnasium_env.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gymnasium 4 | 5 | import d3rlpy 6 | 7 | 8 | def main() -> None: 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--env", type=str, default="Hopper-v2") 11 | parser.add_argument("--seed", type=int, default=1) 12 | parser.add_argument("--gpu", action="store_true") 13 | args = parser.parse_args() 14 | 15 | # d3rlpy supports both Gym and Gymnasium 16 | env = gymnasium.make(args.env) 17 | eval_env = gymnasium.make(args.env) 18 | 19 | # fix seed 20 | d3rlpy.seed(args.seed) 21 | d3rlpy.envs.seed_env(env, args.seed) 22 | d3rlpy.envs.seed_env(eval_env, args.seed) 23 | 24 | # setup algorithm 25 | sac = d3rlpy.algos.SACConfig( 26 | batch_size=256, 27 | actor_learning_rate=3e-4, 28 | critic_learning_rate=3e-4, 29 | temp_learning_rate=3e-4, 30 | ).create(device=args.gpu) 31 | 32 | # replay buffer for experience replay 33 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=1000000, env=env) 34 | 35 | # start training 36 | sac.fit_online( 37 | env, 38 | buffer, 39 | eval_env=eval_env, 40 | n_steps=1000000, 41 | n_steps_per_epoch=10000, 42 | update_interval=1, 43 | update_start_step=1000, 44 | ) 45 | 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /examples/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gymnasium 4 | 5 | import d3rlpy 6 | 7 | 8 | def main() -> None: 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--env", type=str, default="Hopper-v2") 11 | parser.add_argument("--seed", type=int, default=1) 12 | parser.add_argument("--gpu", action="store_true") 13 | args = parser.parse_args() 14 | 15 | env = gymnasium.make(args.env) 16 | eval_env = gymnasium.make(args.env) 17 | 18 | # fix seed 19 | d3rlpy.seed(args.seed) 20 | d3rlpy.envs.seed_env(env, args.seed) 21 | d3rlpy.envs.seed_env(eval_env, args.seed) 22 | 23 | # setup algorithm 24 | sac = d3rlpy.algos.SACConfig( 25 | batch_size=256, 26 | actor_learning_rate=3e-4, 27 | critic_learning_rate=3e-4, 28 | actor_optim_factory=d3rlpy.optimizers.AdamFactory( 29 | # setup learning rate scheduler 30 | lr_scheduler_factory=d3rlpy.optimizers.WarmupSchedulerFactory( 31 | warmup_steps=10000 32 | ), 33 | ), 34 | critic_optim_factory=d3rlpy.optimizers.AdamFactory( 35 | # setup learning rate scheduler 36 | lr_scheduler_factory=d3rlpy.optimizers.WarmupSchedulerFactory( 37 | warmup_steps=10000 38 | ), 39 | ), 40 | temp_learning_rate=3e-4, 41 | ).create(device=args.gpu) 42 | 43 | # replay buffer for experience replay 44 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=1000000, env=env) 45 | 46 | # start training 47 | sac.fit_online( 48 | env, 49 | buffer, 50 | eval_env=eval_env, 51 | n_steps=1000000, 52 | n_steps_per_epoch=10000, 53 | update_interval=1, 54 | update_start_step=1000, 55 | ) 56 | 57 | 58 | if __name__ == "__main__": 59 | main() 60 | -------------------------------------------------------------------------------- /examples/multi_step_training.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 
3 | import gym 4 | 5 | import d3rlpy 6 | 7 | GAMMA = 0.99 8 | 9 | 10 | def main() -> None: 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--env", type=str, default="Pendulum-v1") 13 | parser.add_argument("--seed", type=int, default=1) 14 | parser.add_argument("--n-steps", type=int, default=1) 15 | parser.add_argument("--gpu", action="store_true") 16 | args = parser.parse_args() 17 | 18 | env = gym.make(args.env) 19 | eval_env = gym.make(args.env) 20 | 21 | # fix seed 22 | d3rlpy.seed(args.seed) 23 | d3rlpy.envs.seed_env(env, args.seed) 24 | d3rlpy.envs.seed_env(eval_env, args.seed) 25 | 26 | # setup algorithm 27 | sac = d3rlpy.algos.SACConfig( 28 | batch_size=256, 29 | gamma=GAMMA, 30 | actor_learning_rate=3e-4, 31 | critic_learning_rate=3e-4, 32 | temp_learning_rate=3e-4, 33 | action_scaler=d3rlpy.preprocessing.MinMaxActionScaler(), 34 | ).create(device=args.gpu) 35 | 36 | # multi-step transition sampling 37 | transition_picker = d3rlpy.dataset.MultiStepTransitionPicker( 38 | n_steps=args.n_steps, 39 | gamma=GAMMA, 40 | ) 41 | 42 | # replay buffer for experience replay 43 | buffer = d3rlpy.dataset.create_fifo_replay_buffer( 44 | limit=100000, 45 | env=env, 46 | transition_picker=transition_picker, 47 | ) 48 | 49 | # start training 50 | sac.fit_online( 51 | env, 52 | buffer, 53 | eval_env=eval_env, 54 | n_steps=100000, 55 | n_steps_per_epoch=1000, 56 | update_interval=1, 57 | update_start_step=1000, 58 | ) 59 | 60 | 61 | if __name__ == "__main__": 62 | main() 63 | -------------------------------------------------------------------------------- /examples/preprocessors.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | 5 | import d3rlpy 6 | 7 | 8 | def main() -> None: 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--env", type=str, default="Pendulum-v1") 11 | parser.add_argument("--seed", type=int, default=1) 12 | parser.add_argument("--gpu", action="store_true") 13 | args = parser.parse_args() 14 | 15 | env = gym.make(args.env) 16 | eval_env = gym.make(args.env) 17 | 18 | # fix seed 19 | d3rlpy.seed(args.seed) 20 | d3rlpy.envs.seed_env(env, args.seed) 21 | d3rlpy.envs.seed_env(eval_env, args.seed) 22 | 23 | # setup algorithm 24 | sac = d3rlpy.algos.SACConfig( 25 | batch_size=256, 26 | actor_learning_rate=3e-4, 27 | critic_learning_rate=3e-4, 28 | temp_learning_rate=3e-4, 29 | # normalizes observations within [-1, 1] range 30 | observation_scaler=d3rlpy.preprocessing.MinMaxObservationScaler(), 31 | # normalizes actions within [-1, 1] range 32 | action_scaler=d3rlpy.preprocessing.MinMaxActionScaler(), 33 | # multiply rewards by 0.1 34 | reward_scaler=d3rlpy.preprocessing.MultiplyRewardScaler(0.1), 35 | ).create(device=args.gpu) 36 | 37 | # replay buffer for experience replay 38 | buffer = d3rlpy.dataset.create_fifo_replay_buffer( 39 | limit=100000, 40 | env=env, 41 | ) 42 | 43 | # start training 44 | sac.fit_online( 45 | env, 46 | buffer, 47 | eval_env=eval_env, 48 | n_steps=100000, 49 | n_steps_per_epoch=1000, 50 | update_interval=1, 51 | update_start_step=1000, 52 | ) 53 | 54 | 55 | if __name__ == "__main__": 56 | main() 57 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | python_version = 3.9 3 | strict = True 4 | strict_optional = True 5 | disallow_untyped_defs = True 6 | disallow_incomplete_defs = True 7 | disallow_untyped_decorators = 
True 8 | no_implicit_optional = True 9 | warn_redundant_casts = True 10 | warn_unused_ignores = True 11 | warn_return_any = True 12 | warn_unused_configs = True 13 | plugins = numpy.typing.mypy_plugin 14 | 15 | [mypy-torch.*] 16 | ignore_missing_imports = True 17 | follow_imports = skip 18 | follow_imports_for_stubs = True 19 | 20 | [mypy-tqdm.*] 21 | ignore_missing_imports = True 22 | 23 | [mypy-tensorboardX] 24 | ignore_missing_imports = True 25 | 26 | [mypy-wandb] 27 | ignore_missing_imports = True 28 | follow_imports = skip 29 | follow_imports_for_stubs = True 30 | 31 | [mypy-matplotlib.*] 32 | ignore_missing_imports = True 33 | 34 | [mypy-seaborn.*] 35 | ignore_missing_imports = True 36 | 37 | [mypy-cv2.*] 38 | ignore_missing_imports = True 39 | 40 | [mypy-h5py.*] 41 | ignore_missing_imports = True 42 | 43 | [mypy-dataclasses_json.*] 44 | ignore_missing_imports = True 45 | 46 | [mypy-onnxruntime.*] 47 | ignore_missing_imports = True 48 | 49 | [mypy-click.*] 50 | ignore_missing_imports = True 51 | follow_imports = skip 52 | follow_imports_for_stubs = True 53 | 54 | [mypy-pyvirtualdisplay.*] 55 | ignore_missing_imports = True 56 | 57 | [mypy-IPython.*] 58 | ignore_missing_imports = True 59 | follow_imports = skip 60 | follow_imports_for_stubs = True 61 | 62 | [mypy-minari.*] 63 | ignore_missing_imports = True 64 | 65 | [mypy-d4rl.*] 66 | ignore_missing_imports = True 67 | follow_imports = skip 68 | follow_imports_for_stubs = True 69 | 70 | [mypy-shimmy.*] 71 | ignore_missing_imports = True 72 | follow_imports = skip 73 | follow_imports_for_stubs = True 74 | 75 | [mypy-sklearn.*] 76 | ignore_missing_imports = True 77 | -------------------------------------------------------------------------------- /reproductions/finetuning/awac_finetune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | 4 | import d3rlpy 5 | 6 | 7 | def main() -> None: 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--dataset", type=str, default="D4RL/antmaze/umaze-v1") 10 | parser.add_argument("--seed", type=int, default=1) 11 | parser.add_argument("--gpu", type=int) 12 | parser.add_argument("--compile", action="store_true") 13 | args = parser.parse_args() 14 | 15 | dataset, env = d3rlpy.datasets.get_minari(args.dataset) 16 | 17 | # fix seed 18 | d3rlpy.seed(args.seed) 19 | d3rlpy.envs.seed_env(env, args.seed) 20 | 21 | encoder = d3rlpy.models.encoders.VectorEncoderFactory([256, 256, 256, 256]) 22 | optim = d3rlpy.optimizers.AdamFactory(weight_decay=1e-4) 23 | # for antmaze datasets 24 | reward_scaler = d3rlpy.preprocessing.ConstantShiftRewardScaler(shift=-1) 25 | 26 | awac = d3rlpy.algos.AWACConfig( 27 | actor_learning_rate=3e-4, 28 | actor_encoder_factory=encoder, 29 | actor_optim_factory=optim, 30 | critic_learning_rate=3e-4, 31 | batch_size=1024, 32 | lam=1.0, 33 | reward_scaler=reward_scaler, 34 | compile_graph=args.compile, 35 | ).create(device=args.gpu) 36 | 37 | awac.fit( 38 | dataset, 39 | n_steps=25000, 40 | n_steps_per_epoch=5000, 41 | save_interval=10, 42 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}, 43 | experiment_name=f"AWAC_pretraining_{args.dataset}_{args.seed}", 44 | ) 45 | 46 | # prepare FIFO buffer filled with dataset episodes 47 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(1000000) 48 | for episode in dataset.episodes: 49 | buffer.append_episode(episode) 50 | 51 | # finetuning 52 | eval_env = copy.deepcopy(env) 53 | d3rlpy.envs.seed_env(eval_env, args.seed) 54 | 
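# online finetuning resumes from the offline-pretrained AWAC policy, with the
# replay buffer warm-started from the offline episodes appended above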
awac.fit_online( 55 | env, 56 | buffer=buffer, 57 | eval_env=eval_env, 58 | experiment_name=f"AWAC_finetuning_{args.dataset}_{args.seed}", 59 | n_steps=1000000, 60 | n_steps_per_epoch=1000, 61 | save_interval=10, 62 | ) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /reproductions/finetuning/iql_finetune.py: -------------------------------------------------------------------------------- 1 | # pylint: disable=protected-access 2 | import argparse 3 | import copy 4 | 5 | import d3rlpy 6 | 7 | 8 | def main() -> None: 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--dataset", type=str, default="D4RL/antmaze/umaze-v1") 11 | parser.add_argument("--seed", type=int, default=1) 12 | parser.add_argument("--gpu", type=int) 13 | parser.add_argument("--compile", action="store_true") 14 | args = parser.parse_args() 15 | 16 | dataset, env = d3rlpy.datasets.get_minari(args.dataset) 17 | 18 | # fix seed 19 | d3rlpy.seed(args.seed) 20 | d3rlpy.envs.seed_env(env, args.seed) 21 | 22 | # for antmaze datasets 23 | reward_scaler = d3rlpy.preprocessing.ConstantShiftRewardScaler(shift=-1) 24 | 25 | iql = d3rlpy.algos.IQLConfig( 26 | actor_learning_rate=3e-4, 27 | critic_learning_rate=3e-4, 28 | actor_optim_factory=d3rlpy.optimizers.AdamFactory( 29 | lr_scheduler_factory=d3rlpy.optimizers.CosineAnnealingLRFactory( 30 | T_max=1000000 31 | ), 32 | ), 33 | batch_size=256, 34 | weight_temp=10.0, # hyperparameter for antmaze 35 | max_weight=100.0, 36 | expectile=0.9, # hyperparameter for antmaze 37 | reward_scaler=reward_scaler, 38 | compile_graph=args.compile, 39 | ).create(device=args.gpu) 40 | 41 | # pretraining 42 | iql.fit( 43 | dataset, 44 | n_steps=1000000, 45 | n_steps_per_epoch=100000, 46 | save_interval=10, 47 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}, 48 | experiment_name=f"IQL_pretraining_{args.dataset}_{args.seed}", 49 | ) 50 | 51 | # reset learning rate 52 | assert iql.impl 53 | for g in iql.impl._modules.actor_optim.optim.param_groups: 54 | g["lr"] = iql.config.actor_learning_rate 55 | 56 | # prepare FIFO buffer filled with dataset episodes 57 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(1000000) 58 | for episode in dataset.episodes: 59 | buffer.append_episode(episode) 60 | 61 | # finetuning 62 | eval_env = copy.deepcopy(env) 63 | d3rlpy.envs.seed_env(eval_env, args.seed) 64 | iql.fit_online( 65 | env, 66 | buffer=buffer, 67 | eval_env=eval_env, 68 | experiment_name=f"IQL_finetuning_{args.dataset}_{args.seed}", 69 | n_steps=1000000, 70 | n_steps_per_epoch=1000, 71 | save_interval=10, 72 | ) 73 | 74 | 75 | if __name__ == "__main__": 76 | main() 77 | -------------------------------------------------------------------------------- /reproductions/offline/awac.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 | 6 | def main() -> None: 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0") 9 | parser.add_argument("--seed", type=int, default=1) 10 | parser.add_argument("--gpu", type=int) 11 | parser.add_argument("--compile", action="store_true") 12 | args = parser.parse_args() 13 | 14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset) 15 | 16 | # fix seed 17 | d3rlpy.seed(args.seed) 18 | d3rlpy.envs.seed_env(env, args.seed) 19 | 20 | encoder = d3rlpy.models.encoders.VectorEncoderFactory([256, 256, 256, 256]) 21 | optim = 
d3rlpy.optimizers.AdamFactory(weight_decay=1e-4) 22 | 23 | awac = d3rlpy.algos.AWACConfig( 24 | actor_learning_rate=3e-4, 25 | actor_encoder_factory=encoder, 26 | actor_optim_factory=optim, 27 | critic_learning_rate=3e-4, 28 | critic_encoder_factory=encoder, 29 | batch_size=1024, 30 | lam=1.0, 31 | compile_graph=args.compile, 32 | ).create(args.gpu) 33 | 34 | awac.fit( 35 | dataset, 36 | n_steps=500000, 37 | n_steps_per_epoch=1000, 38 | save_interval=10, 39 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}, 40 | experiment_name=f"AWAC_{args.dataset}_{args.seed}", 41 | ) 42 | 43 | 44 | if __name__ == "__main__": 45 | main() 46 | -------------------------------------------------------------------------------- /reproductions/offline/bcq.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 | 6 | def main() -> None: 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0") 9 | parser.add_argument("--seed", type=int, default=1) 10 | parser.add_argument("--gpu", type=int) 11 | parser.add_argument("--compile", action="store_true") 12 | args = parser.parse_args() 13 | 14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset) 15 | 16 | # fix seed 17 | d3rlpy.seed(args.seed) 18 | d3rlpy.envs.seed_env(env, args.seed) 19 | 20 | vae_encoder = d3rlpy.models.encoders.VectorEncoderFactory([750, 750]) 21 | rl_encoder = d3rlpy.models.encoders.VectorEncoderFactory([400, 300]) 22 | 23 | bcq = d3rlpy.algos.BCQConfig( 24 | actor_encoder_factory=rl_encoder, 25 | actor_learning_rate=1e-3, 26 | critic_encoder_factory=rl_encoder, 27 | critic_learning_rate=1e-3, 28 | imitator_encoder_factory=vae_encoder, 29 | imitator_learning_rate=1e-3, 30 | batch_size=100, 31 | lam=0.75, 32 | action_flexibility=0.05, 33 | n_action_samples=100, 34 | compile_graph=args.compile, 35 | ).create(args.gpu) 36 | 37 | bcq.fit( 38 | dataset, 39 | n_steps=500000, 40 | n_steps_per_epoch=1000, 41 | save_interval=10, 42 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}, 43 | experiment_name=f"BCQ_{args.dataset}_{args.seed}", 44 | ) 45 | 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /reproductions/offline/bear.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 | 6 | def main() -> None: 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0") 9 | parser.add_argument("--seed", type=int, default=1) 10 | parser.add_argument("--gpu", type=int) 11 | parser.add_argument("--compile", action="store_true") 12 | args = parser.parse_args() 13 | 14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset) 15 | 16 | # fix seed 17 | d3rlpy.seed(args.seed) 18 | d3rlpy.envs.seed_env(env, args.seed) 19 | 20 | vae_encoder = d3rlpy.models.encoders.VectorEncoderFactory([750, 750]) 21 | 22 | if "halfcheetah" in args.dataset: 23 | kernel = "gaussian" 24 | else: 25 | kernel = "laplacian" 26 | 27 | bear = d3rlpy.algos.BEARConfig( 28 | actor_learning_rate=1e-4, 29 | critic_learning_rate=3e-4, 30 | imitator_learning_rate=3e-4, 31 | alpha_learning_rate=1e-3, 32 | imitator_encoder_factory=vae_encoder, 33 | temp_learning_rate=0.0, 34 | initial_temperature=1e-20, 35 | batch_size=256, 36 | mmd_sigma=20.0, 37 | mmd_kernel=kernel, 38 | n_mmd_action_samples=4, 39 | 
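# alpha_threshold below is the target MMD value used when tuning alpha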
alpha_threshold=0.05, 40 | n_target_samples=10, 41 | n_action_samples=100, 42 | warmup_steps=40000, 43 | compile_graph=args.compile, 44 | ).create(device=args.gpu) 45 | 46 | bear.fit( 47 | dataset, 48 | n_steps=500000, 49 | n_steps_per_epoch=1000, 50 | save_interval=10, 51 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}, 52 | experiment_name=f"BEAR_{args.dataset}_{args.seed}", 53 | ) 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /reproductions/offline/cql.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 | 6 | def main() -> None: 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0") 9 | parser.add_argument("--seed", type=int, default=1) 10 | parser.add_argument("--gpu", type=int) 11 | parser.add_argument("--compile", action="store_true") 12 | args = parser.parse_args() 13 | 14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset) 15 | 16 | # fix seed 17 | d3rlpy.seed(args.seed) 18 | d3rlpy.envs.seed_env(env, args.seed) 19 | 20 | encoder = d3rlpy.models.encoders.VectorEncoderFactory([256, 256, 256]) 21 | 22 | if "medium-v0" in args.dataset: 23 | conservative_weight = 10.0 24 | else: 25 | conservative_weight = 5.0 26 | 27 | cql = d3rlpy.algos.CQLConfig( 28 | actor_learning_rate=1e-4, 29 | critic_learning_rate=3e-4, 30 | temp_learning_rate=1e-4, 31 | actor_encoder_factory=encoder, 32 | critic_encoder_factory=encoder, 33 | batch_size=256, 34 | n_action_samples=10, 35 | alpha_learning_rate=0.0, 36 | conservative_weight=conservative_weight, 37 | compile_graph=args.compile, 38 | ).create(device=args.gpu) 39 | 40 | cql.fit( 41 | dataset, 42 | n_steps=500000, 43 | n_steps_per_epoch=1000, 44 | save_interval=10, 45 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}, 46 | experiment_name=f"CQL_{args.dataset}_{args.seed}", 47 | ) 48 | 49 | 50 | if __name__ == "__main__": 51 | main() 52 | -------------------------------------------------------------------------------- /reproductions/offline/crr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 | 6 | def main() -> None: 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0") 9 | parser.add_argument("--seed", type=int, default=1) 10 | parser.add_argument("--gpu", type=int) 11 | parser.add_argument("--compile", action="store_true") 12 | args = parser.parse_args() 13 | 14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset) 15 | 16 | # fix seed 17 | d3rlpy.seed(args.seed) 18 | d3rlpy.envs.seed_env(env, args.seed) 19 | 20 | crr = d3rlpy.algos.CRRConfig( 21 | actor_learning_rate=3e-4, 22 | critic_learning_rate=3e-4, 23 | batch_size=256, 24 | weight_type="binary", 25 | advantage_type="mean", 26 | target_update_type="soft", 27 | compile_graph=args.compile, 28 | ).create(device=args.gpu) 29 | 30 | crr.fit( 31 | dataset, 32 | n_steps=500000, 33 | n_steps_per_epoch=1000, 34 | save_interval=10, 35 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}, 36 | experiment_name=f"CRR_{args.dataset}_{args.seed}", 37 | ) 38 | 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /reproductions/offline/decision_transformer.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 | 6 | def main() -> None: 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0") 9 | parser.add_argument("--seed", type=int, default=1) 10 | parser.add_argument("--gpu", type=int) 11 | parser.add_argument("--compile", action="store_true") 12 | args = parser.parse_args() 13 | 14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset) 15 | 16 | # fix seed 17 | d3rlpy.seed(args.seed) 18 | d3rlpy.envs.seed_env(env, args.seed) 19 | 20 | if "halfcheetah" in args.dataset: 21 | target_return = 6000 22 | elif "hopper" in args.dataset: 23 | target_return = 3600 24 | elif "walker" in args.dataset: 25 | target_return = 5000 26 | else: 27 | raise ValueError("unsupported dataset") 28 | 29 | dt = d3rlpy.algos.DecisionTransformerConfig( 30 | batch_size=64, 31 | learning_rate=1e-4, 32 | optim_factory=d3rlpy.optimizers.AdamWFactory( 33 | weight_decay=1e-4, 34 | clip_grad_norm=0.25, 35 | lr_scheduler_factory=d3rlpy.optimizers.WarmupSchedulerFactory( 36 | warmup_steps=10000 37 | ), 38 | ), 39 | encoder_factory=d3rlpy.models.VectorEncoderFactory( 40 | [128], 41 | exclude_last_activation=True, 42 | ), 43 | observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(), 44 | reward_scaler=d3rlpy.preprocessing.MultiplyRewardScaler(0.001), 45 | position_encoding_type=d3rlpy.PositionEncodingType.SIMPLE, 46 | context_size=20, 47 | num_heads=1, 48 | num_layers=3, 49 | max_timestep=1000, 50 | compile_graph=args.compile, 51 | ).create(device=args.gpu) 52 | 53 | dt.fit( 54 | dataset, 55 | n_steps=100000, 56 | n_steps_per_epoch=1000, 57 | save_interval=10, 58 | eval_env=env, 59 | eval_target_return=target_return, 60 | experiment_name=f"DT_{args.dataset}_{args.seed}", 61 | ) 62 | 63 | 64 | if __name__ == "__main__": 65 | main() 66 | -------------------------------------------------------------------------------- /reproductions/offline/discrete_bcq.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 | 6 | def main() -> None: 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--game", type=str, default="breakout") 9 | parser.add_argument("--seed", type=int, default=1) 10 | parser.add_argument("--gpu", type=int) 11 | parser.add_argument("--compile", action="store_true") 12 | args = parser.parse_args() 13 | 14 | # fix seed 15 | d3rlpy.seed(args.seed) 16 | 17 | dataset, env = d3rlpy.datasets.get_atari_transitions( 18 | args.game, 19 | fraction=0.01, 20 | index=1 if args.game == "asterix" else 0, 21 | num_stack=4, 22 | ) 23 | 24 | d3rlpy.envs.seed_env(env, args.seed) 25 | 26 | bcq = d3rlpy.algos.DiscreteBCQConfig( 27 | learning_rate=5e-5, 28 | optim_factory=d3rlpy.optimizers.AdamFactory(eps=1e-2 / 32), 29 | batch_size=32, 30 | q_func_factory=d3rlpy.models.q_functions.QRQFunctionFactory( 31 | n_quantiles=200 32 | ), 33 | observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), 34 | target_update_interval=2000, 35 | reward_scaler=d3rlpy.preprocessing.ClipRewardScaler(-1.0, 1.0), 36 | action_flexibility=0.3, 37 | beta=0.01, 38 | compile_graph=args.compile, 39 | ).create(device=args.gpu) 40 | 41 | env_scorer = d3rlpy.metrics.EnvironmentEvaluator(env, epsilon=0.001) 42 | 43 | bcq.fit( 44 | dataset, 45 | n_steps=50000000 // 4, 46 | n_steps_per_epoch=125000, 47 | save_interval=10, 48 | evaluators={"environment": env_scorer}, 49 | 
experiment_name=f"DiscreteBCQ_{args.game}_{args.seed}", 50 | ) 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /reproductions/offline/discrete_cql.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 | 6 | def main() -> None: 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--game", type=str, default="breakout") 9 | parser.add_argument("--seed", type=int, default=1) 10 | parser.add_argument("--gpu", type=int) 11 | parser.add_argument("--compile", action="store_true") 12 | args = parser.parse_args() 13 | 14 | d3rlpy.seed(args.seed) 15 | 16 | dataset, env = d3rlpy.datasets.get_atari_transitions( 17 | args.game, 18 | fraction=0.01, 19 | index=1 if args.game == "asterix" else 0, 20 | num_stack=4, 21 | ) 22 | 23 | d3rlpy.envs.seed_env(env, args.seed) 24 | 25 | cql = d3rlpy.algos.DiscreteCQLConfig( 26 | learning_rate=5e-5, 27 | optim_factory=d3rlpy.optimizers.AdamFactory(eps=1e-2 / 32), 28 | batch_size=32, 29 | alpha=4.0, 30 | q_func_factory=d3rlpy.models.q_functions.QRQFunctionFactory( 31 | n_quantiles=200 32 | ), 33 | observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), 34 | target_update_interval=2000, 35 | reward_scaler=d3rlpy.preprocessing.ClipRewardScaler(-1.0, 1.0), 36 | compile_graph=args.compile, 37 | ).create(device=args.gpu) 38 | 39 | env_scorer = d3rlpy.metrics.EnvironmentEvaluator(env, epsilon=0.001) 40 | 41 | cql.fit( 42 | dataset, 43 | n_steps=50000000 // 4, 44 | n_steps_per_epoch=125000, 45 | evaluators={"environment": env_scorer}, 46 | experiment_name=f"DiscreteCQL_{args.game}_{args.seed}", 47 | ) 48 | 49 | 50 | if __name__ == "__main__": 51 | main() 52 | -------------------------------------------------------------------------------- /reproductions/offline/iql.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 | 6 | def main() -> None: 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0") 9 | parser.add_argument("--seed", type=int, default=1) 10 | parser.add_argument("--gpu", type=int) 11 | parser.add_argument("--compile", action="store_true") 12 | args = parser.parse_args() 13 | 14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset) 15 | 16 | # fix seed 17 | d3rlpy.seed(args.seed) 18 | d3rlpy.envs.seed_env(env, args.seed) 19 | 20 | reward_scaler = d3rlpy.preprocessing.ReturnBasedRewardScaler( 21 | multiplier=1000.0 22 | ) 23 | 24 | iql = d3rlpy.algos.IQLConfig( 25 | actor_learning_rate=3e-4, 26 | critic_learning_rate=3e-4, 27 | actor_optim_factory=d3rlpy.optimizers.AdamFactory( 28 | lr_scheduler_factory=d3rlpy.optimizers.CosineAnnealingLRFactory( 29 | T_max=500000 30 | ), 31 | ), 32 | batch_size=256, 33 | weight_temp=3.0, 34 | max_weight=100.0, 35 | expectile=0.7, 36 | reward_scaler=reward_scaler, 37 | compile_graph=args.compile, 38 | ).create(device=args.gpu) 39 | 40 | iql.fit( 41 | dataset, 42 | n_steps=500000, 43 | n_steps_per_epoch=1000, 44 | save_interval=10, 45 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}, 46 | experiment_name=f"IQL_{args.dataset}_{args.seed}", 47 | ) 48 | 49 | 50 | if __name__ == "__main__": 51 | main() 52 | -------------------------------------------------------------------------------- /reproductions/offline/nfq.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 | 6 | def main() -> None: 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--game", type=str, default="breakout") 9 | parser.add_argument("--seed", type=int, default=1) 10 | parser.add_argument("--gpu", type=int) 11 | parser.add_argument("--compile", action="store_true") 12 | args = parser.parse_args() 13 | 14 | # fix seed 15 | d3rlpy.seed(args.seed) 16 | 17 | dataset, env = d3rlpy.datasets.get_atari_transitions( 18 | args.game, 19 | fraction=0.01, 20 | index=1 if args.game == "asterix" else 0, 21 | num_stack=4, 22 | ) 23 | 24 | d3rlpy.envs.seed_env(env, args.seed) 25 | 26 | nfq = d3rlpy.algos.NFQConfig( 27 | learning_rate=5e-5, 28 | optim_factory=d3rlpy.optimizers.AdamFactory(), 29 | batch_size=32, 30 | observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), 31 | reward_scaler=d3rlpy.preprocessing.ClipRewardScaler(-1.0, 1.0), 32 | compile_graph=args.compile, 33 | ).create(device=args.gpu) 34 | 35 | env_scorer = d3rlpy.metrics.EnvironmentEvaluator(env, epsilon=0.001) 36 | 37 | nfq.fit( 38 | dataset, 39 | n_steps=50000000 // 4, 40 | n_steps_per_epoch=125000, 41 | save_interval=10, 42 | evaluators={"environment": env_scorer}, 43 | experiment_name=f"NFQ_{args.game}_{args.seed}", 44 | ) 45 | 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /reproductions/offline/plas.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 | 6 | def main() -> None: 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0") 9 | parser.add_argument("--seed", type=int, default=1) 10 | parser.add_argument("--gpu", type=int) 11 | parser.add_argument("--compile", action="store_true") 12 | args = parser.parse_args() 13 | 14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset) 15 | 16 | # fix seed 17 | d3rlpy.seed(args.seed) 18 | d3rlpy.envs.seed_env(env, args.seed) 19 | 20 | if "medium-replay" in args.dataset: 21 | vae_encoder = d3rlpy.models.encoders.VectorEncoderFactory([128, 128]) 22 | else: 23 | vae_encoder = d3rlpy.models.encoders.VectorEncoderFactory([750, 750]) 24 | encoder = d3rlpy.models.encoders.VectorEncoderFactory([400, 300]) 25 | 26 | plas = d3rlpy.algos.PLASConfig( 27 | actor_learning_rate=1e-4, 28 | actor_encoder_factory=encoder, 29 | critic_learning_rate=1e-3, 30 | critic_encoder_factory=encoder, 31 | imitator_learning_rate=1e-4, 32 | imitator_encoder_factory=vae_encoder, 33 | batch_size=100, 34 | lam=1.0, 35 | warmup_steps=500000, 36 | compile_graph=args.compile, 37 | ).create(device=args.gpu) 38 | 39 | plas.fit( 40 | dataset, 41 | n_steps=1000000, # RL starts at 500000 step 42 | n_steps_per_epoch=1000, 43 | save_interval=10, 44 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}, 45 | experiment_name=f"PLAS_{args.dataset}_{args.seed}", 46 | ) 47 | 48 | 49 | if __name__ == "__main__": 50 | main() 51 | -------------------------------------------------------------------------------- /reproductions/offline/plas_with_perturbation.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 | ACTION_FLEXIBILITY = { 6 | "walker2d-random-v0": 0.05, 7 | "hopper-random-v0": 0.5, 8 | "halfcheetah-random-v0": 0.01, 9 | "walker2d-medium-v0": 0.1, 10 | 
"hopper-medium-v0": 0.01, 11 | "halfcheetah-medium-v0": 0.1, 12 | "walker2d-medium-relay-v0": 0.01, 13 | "hopper-medium-replay-v0": 0.2, 14 | "halfcheetah-medium-replay-v0": 0.05, 15 | "walker2d-medium-expert-v0": 0.01, 16 | "hopper-medium-expert-v0": 0.01, 17 | "halfcheetah-medium-expert-v0": 0.01, 18 | } 19 | 20 | 21 | def main() -> None: 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0") 24 | parser.add_argument("--seed", type=int, default=1) 25 | parser.add_argument("--gpu", type=int) 26 | parser.add_argument("--compile", action="store_true") 27 | args = parser.parse_args() 28 | 29 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset) 30 | 31 | # fix seed 32 | d3rlpy.seed(args.seed) 33 | d3rlpy.envs.seed_env(env, args.seed) 34 | 35 | if "medium-replay" in args.dataset: 36 | vae_encoder = d3rlpy.models.encoders.VectorEncoderFactory([128, 128]) 37 | else: 38 | vae_encoder = d3rlpy.models.encoders.VectorEncoderFactory([750, 750]) 39 | encoder = d3rlpy.models.encoders.VectorEncoderFactory([400, 300]) 40 | 41 | plas = d3rlpy.algos.PLASWithPerturbationConfig( 42 | actor_learning_rate=1e-4, 43 | actor_encoder_factory=encoder, 44 | critic_learning_rate=1e-3, 45 | critic_encoder_factory=encoder, 46 | imitator_learning_rate=1e-4, 47 | imitator_encoder_factory=vae_encoder, 48 | batch_size=100, 49 | lam=1.0, 50 | warmup_steps=500000, 51 | action_flexibility=ACTION_FLEXIBILITY[args.dataset], 52 | compile_graph=args.compile, 53 | ).create(device=args.gpu) 54 | 55 | plas.fit( 56 | dataset, 57 | n_steps=1000000, # RL starts at 500000 step 58 | n_steps_per_epoch=1000, 59 | save_interval=10, 60 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}, 61 | experiment_name=f"PLASWithPerturbation_{args.dataset}_{args.seed}", 62 | ) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /reproductions/offline/prdc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 | 6 | def main() -> None: 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v2") 9 | parser.add_argument("--seed", type=int, default=1) 10 | parser.add_argument("--gpu", type=int, default=0) 11 | parser.add_argument("--compile", action="store_true") 12 | args = parser.parse_args() 13 | 14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset) 15 | 16 | # fix seed 17 | d3rlpy.seed(args.seed) 18 | d3rlpy.envs.seed_env(env, args.seed) 19 | 20 | prdc = d3rlpy.algos.PRDCConfig( 21 | actor_learning_rate=3e-4, 22 | critic_learning_rate=3e-4, 23 | batch_size=256, 24 | target_smoothing_sigma=0.2, 25 | target_smoothing_clip=0.5, 26 | alpha=2.5, 27 | update_actor_interval=2, 28 | observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(), 29 | compile_graph=args.compile, 30 | ).create(device=args.gpu) 31 | 32 | prdc.fit( 33 | dataset, 34 | n_steps=500000, 35 | n_steps_per_epoch=1000, 36 | save_interval=10, 37 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}, 38 | experiment_name=f"PRDC_{args.dataset}_{args.seed}", 39 | ) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /reproductions/offline/qr_dqn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 
| 6 | def main() -> None: 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--game", type=str, default="breakout") 9 | parser.add_argument("--seed", type=int, default=1) 10 | parser.add_argument("--gpu", type=int) 11 | parser.add_argument("--compile", action="store_true") 12 | args = parser.parse_args() 13 | 14 | # fix seed 15 | d3rlpy.seed(args.seed) 16 | 17 | dataset, env = d3rlpy.datasets.get_atari_transitions( 18 | args.game, 19 | fraction=0.01, 20 | index=1 if args.game == "asterix" else 0, 21 | num_stack=4, 22 | ) 23 | 24 | d3rlpy.envs.seed_env(env, args.seed) 25 | 26 | dqn = d3rlpy.algos.DQNConfig( 27 | learning_rate=5e-5, 28 | optim_factory=d3rlpy.optimizers.AdamFactory(eps=1e-2 / 32), 29 | batch_size=32, 30 | q_func_factory=d3rlpy.models.q_functions.QRQFunctionFactory( 31 | n_quantiles=200 32 | ), 33 | observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), 34 | target_update_interval=2000, 35 | reward_scaler=d3rlpy.preprocessing.ClipRewardScaler(-1.0, 1.0), 36 | compile_graph=args.compile, 37 | ).create(device=args.gpu) 38 | 39 | env_scorer = d3rlpy.metrics.EnvironmentEvaluator(env, epsilon=0.001) 40 | 41 | dqn.fit( 42 | dataset, 43 | n_steps=50000000 // 4, 44 | n_steps_per_epoch=125000, 45 | save_interval=10, 46 | evaluators={"environment": env_scorer}, 47 | experiment_name=f"QRDQN_{args.game}_{args.seed}", 48 | ) 49 | 50 | 51 | if __name__ == "__main__": 52 | main() 53 | -------------------------------------------------------------------------------- /reproductions/offline/rebrac.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 | BETA_TABLE: dict[str, tuple[float, float]] = { 6 | "halfcheetah-random": (0.001, 0.1), 7 | "halfcheetah-medium": (0.001, 0.01), 8 | "halfcheetah-expert": (0.01, 0.01), 9 | "halfcheetah-medium-replay": (0.01, 0.001), 10 | "halfcheetah-full-replay": (0.001, 0.1), 11 | "hopper-random": (0.001, 0.01), 12 | "hopper-medium": (0.01, 0.001), 13 | "hopper-expert": (0.1, 0.001), 14 | "hopper-medium-expert": (0.1, 0.01), 15 | "hopper-medium-replay": (0.05, 0.5), 16 | "hopper-full-replay": (0.01, 0.01), 17 | "walker2d-random": (0.01, 0.0), 18 | "walker2d-medium": (0.05, 0.1), 19 | "walker2d-expert": (0.01, 0.5), 20 | "walker2d-medium-expert": (0.01, 0.01), 21 | "walker2d-medium-replay": (0.05, 0.01), 22 | "walker2d-full-replay": (0.01, 0.01), 23 | } 24 | 25 | 26 | def main() -> None: 27 | parser = argparse.ArgumentParser() 28 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0") 29 | parser.add_argument("--seed", type=int, default=1) 30 | parser.add_argument("--gpu", type=int) 31 | parser.add_argument("--compile", action="store_true") 32 | args = parser.parse_args() 33 | 34 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset) 35 | 36 | # fix seed 37 | d3rlpy.seed(args.seed) 38 | d3rlpy.envs.seed_env(env, args.seed) 39 | 40 | # deeper network 41 | actor_encoder = d3rlpy.models.VectorEncoderFactory([256, 256, 256]) 42 | critic_encoder = d3rlpy.models.VectorEncoderFactory( 43 | [256, 256, 256], use_layer_norm=True 44 | ) 45 | 46 | actor_beta, critic_beta = 0.01, 0.01 47 | for dataset_name, beta_from_paper in BETA_TABLE.items(): 48 | if dataset_name in args.dataset: 49 | actor_beta, critic_beta = beta_from_paper 50 | break 51 | 52 | rebrac = d3rlpy.algos.ReBRACConfig( 53 | actor_learning_rate=1e-3, 54 | critic_learning_rate=1e-3, 55 | batch_size=1024, 56 | gamma=0.99, 57 | actor_encoder_factory=actor_encoder, 58 | 
critic_encoder_factory=critic_encoder, 59 | target_smoothing_sigma=0.2, 60 | target_smoothing_clip=0.5, 61 | update_actor_interval=2, 62 | actor_beta=actor_beta, 63 | critic_beta=critic_beta, 64 | observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(), 65 | compile_graph=args.compile, 66 | ).create(device=args.gpu) 67 | 68 | rebrac.fit( 69 | dataset, 70 | n_steps=1000000, 71 | n_steps_per_epoch=1000, 72 | save_interval=10, 73 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}, 74 | experiment_name=f"ReBRAC_{args.dataset}_{args.seed}", 75 | ) 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /reproductions/offline/sac.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 | 6 | def main() -> None: 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0") 9 | parser.add_argument("--seed", type=int, default=1) 10 | parser.add_argument("--gpu", type=int) 11 | parser.add_argument("--compile", action="store_true") 12 | args = parser.parse_args() 13 | 14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset) 15 | 16 | # fix seed 17 | d3rlpy.seed(args.seed) 18 | d3rlpy.envs.seed_env(env, args.seed) 19 | 20 | sac = d3rlpy.algos.SACConfig( 21 | actor_learning_rate=3e-4, 22 | critic_learning_rate=3e-4, 23 | temp_learning_rate=3e-4, 24 | batch_size=256, 25 | compile_graph=args.compile, 26 | ).create(device=args.gpu) 27 | 28 | sac.fit( 29 | dataset, 30 | n_steps=500000, 31 | n_steps_per_epoch=1000, 32 | save_interval=10, 33 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}, 34 | experiment_name=f"SAC_{args.dataset}_{args.seed}", 35 | ) 36 | 37 | 38 | if __name__ == "__main__": 39 | main() 40 | -------------------------------------------------------------------------------- /reproductions/offline/tacr.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 | 6 | def main() -> None: 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0") 9 | parser.add_argument("--seed", type=int, default=1) 10 | parser.add_argument("--gpu", type=int) 11 | parser.add_argument("--compile", action="store_true") 12 | args = parser.parse_args() 13 | 14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset) 15 | 16 | # fix seed 17 | d3rlpy.seed(args.seed) 18 | d3rlpy.envs.seed_env(env, args.seed) 19 | 20 | if "halfcheetah" in args.dataset: 21 | target_return = 6000 22 | elif "hopper" in args.dataset: 23 | target_return = 3600 24 | elif "walker" in args.dataset: 25 | target_return = 5000 26 | else: 27 | raise ValueError("unsupported dataset") 28 | 29 | tacr = d3rlpy.algos.TACRConfig( 30 | batch_size=64, 31 | actor_learning_rate=1e-4, 32 | actor_optim_factory=d3rlpy.optimizers.AdamWFactory( 33 | weight_decay=1e-4, 34 | clip_grad_norm=0.25, 35 | lr_scheduler_factory=d3rlpy.optimizers.WarmupSchedulerFactory( 36 | warmup_steps=10000 37 | ), 38 | ), 39 | actor_encoder_factory=d3rlpy.models.VectorEncoderFactory( 40 | [128], 41 | exclude_last_activation=True, 42 | ), 43 | observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(), 44 | position_encoding_type=d3rlpy.PositionEncodingType.SIMPLE, 45 | context_size=20, 46 | num_heads=1, 47 | num_layers=3, 48 | max_timestep=1000, 49 | compile_graph=args.compile, 50 | 
).create(device=args.gpu) 51 | 52 | tacr.fit( 53 | dataset, 54 | n_steps=100000, 55 | n_steps_per_epoch=1000, 56 | save_interval=10, 57 | eval_env=env, 58 | eval_target_return=target_return, 59 | experiment_name=f"TACR_{args.dataset}_{args.seed}", 60 | ) 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /reproductions/offline/td3.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 | 6 | def main() -> None: 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0") 9 | parser.add_argument("--seed", type=int, default=1) 10 | parser.add_argument("--gpu", type=int) 11 | parser.add_argument("--compile", action="store_true") 12 | args = parser.parse_args() 13 | 14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset) 15 | 16 | # fix seed 17 | d3rlpy.seed(args.seed) 18 | d3rlpy.envs.seed_env(env, args.seed) 19 | 20 | td3 = d3rlpy.algos.TD3Config( 21 | actor_learning_rate=3e-4, 22 | critic_learning_rate=3e-4, 23 | batch_size=256, 24 | target_smoothing_sigma=0.2, 25 | target_smoothing_clip=0.5, 26 | update_actor_interval=2, 27 | compile_graph=args.compile, 28 | ).create(device=args.gpu) 29 | 30 | td3.fit( 31 | dataset, 32 | n_steps=500000, 33 | n_steps_per_epoch=1000, 34 | save_interval=10, 35 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}, 36 | experiment_name=f"TD3_{args.dataset}_{args.seed}", 37 | ) 38 | 39 | 40 | if __name__ == "__main__": 41 | main() 42 | -------------------------------------------------------------------------------- /reproductions/offline/td3_plus_bc.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import d3rlpy 4 | 5 | 6 | def main() -> None: 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--dataset", type=str, default="hopper-medium-v0") 9 | parser.add_argument("--seed", type=int, default=1) 10 | parser.add_argument("--gpu", type=int) 11 | parser.add_argument("--compile", action="store_true") 12 | args = parser.parse_args() 13 | 14 | dataset, env = d3rlpy.datasets.get_dataset(args.dataset) 15 | 16 | # fix seed 17 | d3rlpy.seed(args.seed) 18 | d3rlpy.envs.seed_env(env, args.seed) 19 | 20 | td3 = d3rlpy.algos.TD3PlusBCConfig( 21 | actor_learning_rate=3e-4, 22 | critic_learning_rate=3e-4, 23 | batch_size=256, 24 | target_smoothing_sigma=0.2, 25 | target_smoothing_clip=0.5, 26 | alpha=2.5, 27 | update_actor_interval=2, 28 | observation_scaler=d3rlpy.preprocessing.StandardObservationScaler(), 29 | compile_graph=args.compile, 30 | ).create(device=args.gpu) 31 | 32 | td3.fit( 33 | dataset, 34 | n_steps=500000, 35 | n_steps_per_epoch=1000, 36 | save_interval=10, 37 | evaluators={"environment": d3rlpy.metrics.EnvironmentEvaluator(env)}, 38 | experiment_name=f"TD3PlusBC_{args.dataset}_{args.seed}", 39 | ) 40 | 41 | 42 | if __name__ == "__main__": 43 | main() 44 | -------------------------------------------------------------------------------- /reproductions/online/double_dqn_online.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | 5 | import d3rlpy 6 | 7 | 8 | def main() -> None: 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") 11 | parser.add_argument("--seed", type=int, default=1) 12 | 
parser.add_argument("--gpu", action="store_true") 13 | parser.add_argument("--compile", action="store_true") 14 | args = parser.parse_args() 15 | 16 | # get wrapped atari environment 17 | env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4) 18 | eval_env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4, is_eval=True) 19 | 20 | # fix seed 21 | d3rlpy.seed(args.seed) 22 | d3rlpy.envs.seed_env(env, args.seed) 23 | d3rlpy.envs.seed_env(eval_env, args.seed) 24 | 25 | # setup algorithm 26 | dqn = d3rlpy.algos.DoubleDQNConfig( 27 | batch_size=32, 28 | learning_rate=2.5e-4, 29 | optim_factory=d3rlpy.optimizers.RMSpropFactory(), 30 | target_update_interval=10000, 31 | observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), 32 | compile_graph=args.compile, 33 | ).create(device=args.gpu) 34 | 35 | # replay buffer for experience replay 36 | buffer = d3rlpy.dataset.create_fifo_replay_buffer( 37 | limit=1000000, 38 | transition_picker=d3rlpy.dataset.FrameStackTransitionPicker(n_frames=4), 39 | writer_preprocessor=d3rlpy.dataset.LastFrameWriterPreprocess(), 40 | env=env, 41 | ) 42 | 43 | # epilon-greedy explorer 44 | explorer = d3rlpy.algos.LinearDecayEpsilonGreedy( 45 | start_epsilon=1.0, end_epsilon=0.1, duration=1000000 46 | ) 47 | 48 | # start training 49 | dqn.fit_online( 50 | env, 51 | buffer, 52 | explorer, 53 | eval_env=eval_env, 54 | eval_epsilon=0.01, 55 | n_steps=50000000, 56 | n_steps_per_epoch=100000, 57 | update_interval=4, 58 | update_start_step=50000, 59 | ) 60 | 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /reproductions/online/dqn_online.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | 5 | import d3rlpy 6 | 7 | 8 | def main() -> None: 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") 11 | parser.add_argument("--seed", type=int, default=1) 12 | parser.add_argument("--gpu", action="store_true") 13 | parser.add_argument("--compile", action="store_true") 14 | args = parser.parse_args() 15 | 16 | # get wrapped atari environment 17 | env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4) 18 | eval_env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4, is_eval=True) 19 | 20 | # fix seed 21 | d3rlpy.seed(args.seed) 22 | d3rlpy.envs.seed_env(env, args.seed) 23 | d3rlpy.envs.seed_env(eval_env, args.seed) 24 | 25 | # setup algorithm 26 | dqn = d3rlpy.algos.DQNConfig( 27 | batch_size=32, 28 | learning_rate=2.5e-4, 29 | optim_factory=d3rlpy.optimizers.RMSpropFactory(), 30 | target_update_interval=10000 // 4, 31 | observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), 32 | compile_graph=args.compile, 33 | ).create(device=args.gpu) 34 | 35 | # replay buffer for experience replay 36 | buffer = d3rlpy.dataset.create_fifo_replay_buffer( 37 | limit=1000000, 38 | transition_picker=d3rlpy.dataset.FrameStackTransitionPicker(n_frames=4), 39 | writer_preprocessor=d3rlpy.dataset.LastFrameWriterPreprocess(), 40 | env=env, 41 | ) 42 | 43 | # epilon-greedy explorer 44 | explorer = d3rlpy.algos.LinearDecayEpsilonGreedy( 45 | start_epsilon=1.0, end_epsilon=0.1, duration=1000000 46 | ) 47 | 48 | # start training 49 | dqn.fit_online( 50 | env, 51 | buffer, 52 | explorer, 53 | eval_env=eval_env, 54 | eval_epsilon=0.01, 55 | n_steps=50000000, 56 | n_steps_per_epoch=100000, 57 | update_interval=4, 58 | update_start_step=50000, 59 | ) 60 | 61 | 62 | if 
__name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /reproductions/online/iqn_online.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | 5 | import d3rlpy 6 | 7 | 8 | def main() -> None: 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") 11 | parser.add_argument("--seed", type=int, default=1) 12 | parser.add_argument("--gpu", action="store_true") 13 | parser.add_argument("--compile", action="store_true") 14 | args = parser.parse_args() 15 | 16 | # get wrapped atari environment 17 | env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4) 18 | eval_env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4, is_eval=True) 19 | 20 | # fix seed 21 | d3rlpy.seed(args.seed) 22 | d3rlpy.envs.seed_env(env, args.seed) 23 | d3rlpy.envs.seed_env(eval_env, args.seed) 24 | 25 | # setup algorithm 26 | dqn = d3rlpy.algos.DQNConfig( 27 | batch_size=32, 28 | learning_rate=5e-5, 29 | optim_factory=d3rlpy.optimizers.AdamFactory(eps=1e-2 / 32), 30 | target_update_interval=10000 // 4, 31 | q_func_factory=d3rlpy.models.IQNQFunctionFactory(), 32 | observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), 33 | compile_graph=args.compile, 34 | ).create(device=args.gpu) 35 | 36 | # replay buffer for experience replay 37 | buffer = d3rlpy.dataset.create_fifo_replay_buffer( 38 | limit=1000000, 39 | transition_picker=d3rlpy.dataset.FrameStackTransitionPicker(n_frames=4), 40 | writer_preprocessor=d3rlpy.dataset.LastFrameWriterPreprocess(), 41 | env=env, 42 | ) 43 | 44 | # epilon-greedy explorer 45 | explorer = d3rlpy.algos.LinearDecayEpsilonGreedy( 46 | start_epsilon=1.0, end_epsilon=0.01, duration=1000000 47 | ) 48 | 49 | # start training 50 | dqn.fit_online( 51 | env, 52 | buffer, 53 | explorer, 54 | eval_env=eval_env, 55 | eval_epsilon=0.001, 56 | n_steps=50000000, 57 | n_steps_per_epoch=100000, 58 | update_interval=4, 59 | update_start_step=50000, 60 | ) 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /reproductions/online/qr_dqn_online.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | 5 | import d3rlpy 6 | 7 | 8 | def main() -> None: 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--env", type=str, default="BreakoutNoFrameskip-v4") 11 | parser.add_argument("--seed", type=int, default=1) 12 | parser.add_argument("--gpu", action="store_true") 13 | parser.add_argument("--compile", action="store_true") 14 | args = parser.parse_args() 15 | 16 | # get wrapped atari environment 17 | env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4) 18 | eval_env = d3rlpy.envs.Atari(gym.make(args.env), num_stack=4, is_eval=True) 19 | 20 | # fix seed 21 | d3rlpy.seed(args.seed) 22 | d3rlpy.envs.seed_env(env, args.seed) 23 | d3rlpy.envs.seed_env(eval_env, args.seed) 24 | 25 | # setup algorithm 26 | dqn = d3rlpy.algos.DQNConfig( 27 | batch_size=32, 28 | learning_rate=5e-5, 29 | optim_factory=d3rlpy.optimizers.AdamFactory(eps=1e-2 / 32), 30 | target_update_interval=10000 // 4, 31 | q_func_factory=d3rlpy.models.q_functions.QRQFunctionFactory( 32 | n_quantiles=200 33 | ), 34 | observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), 35 | compile_graph=args.compile, 36 | ).create(device=args.gpu) 37 | 38 | # replay buffer for 
experience replay 39 | buffer = d3rlpy.dataset.create_fifo_replay_buffer( 40 | limit=1000000, 41 | transition_picker=d3rlpy.dataset.FrameStackTransitionPicker(n_frames=4), 42 | writer_preprocessor=d3rlpy.dataset.LastFrameWriterPreprocess(), 43 | env=env, 44 | ) 45 | 46 | # epilon-greedy explorer 47 | explorer = d3rlpy.algos.LinearDecayEpsilonGreedy( 48 | start_epsilon=1.0, end_epsilon=0.01, duration=1000000 49 | ) 50 | 51 | # start training 52 | dqn.fit_online( 53 | env, 54 | buffer, 55 | explorer, 56 | eval_env=eval_env, 57 | eval_epsilon=0.001, 58 | n_steps=50000000, 59 | n_steps_per_epoch=100000, 60 | update_interval=4, 61 | update_start_step=50000, 62 | ) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /reproductions/online/sac_online.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import gym 4 | 5 | import d3rlpy 6 | 7 | 8 | def main() -> None: 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--env", type=str, default="Hopper-v2") 11 | parser.add_argument("--seed", type=int, default=1) 12 | parser.add_argument("--gpu", action="store_true") 13 | parser.add_argument("--compile", action="store_true") 14 | args = parser.parse_args() 15 | 16 | env = gym.make(args.env) 17 | eval_env = gym.make(args.env) 18 | 19 | # fix seed 20 | d3rlpy.seed(args.seed) 21 | d3rlpy.envs.seed_env(env, args.seed) 22 | d3rlpy.envs.seed_env(eval_env, args.seed) 23 | 24 | # setup algorithm 25 | sac = d3rlpy.algos.SACConfig( 26 | batch_size=256, 27 | actor_learning_rate=3e-4, 28 | critic_learning_rate=3e-4, 29 | temp_learning_rate=3e-4, 30 | compile_graph=args.compile, 31 | ).create(device=args.gpu) 32 | 33 | # replay buffer for experience replay 34 | buffer = d3rlpy.dataset.create_fifo_replay_buffer(limit=1000000, env=env) 35 | 36 | # start training 37 | sac.fit_online( 38 | env, 39 | buffer, 40 | eval_env=eval_env, 41 | n_steps=1000000, 42 | n_steps_per_epoch=10000, 43 | update_interval=1, 44 | update_start_step=1000, 45 | ) 46 | 47 | 48 | if __name__ == "__main__": 49 | main() 50 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.5.0 2 | tqdm>=4.66.1 3 | h5py==2.10.0 4 | gym==0.26.2 5 | click==8.0.1 6 | typing-extensions==3.7.4.3 7 | structlog==20.2.0 8 | colorama==0.4.4 9 | gymnasium==1.0.0 10 | scikit-learn==1.5.2 11 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | line-length = 80 2 | indent-width = 4 3 | target-version = "py39" 4 | unsafe-fixes = true 5 | 6 | [lint] 7 | select = ["E4", "E5", "E7", "E9", "F", "UP006", "I", "W"] 8 | ignore = ["F403"] 9 | 10 | # Allow fix for all enabled rules (when `--fix`) is provided. 11 | fixable = ["ALL"] 12 | unfixable = [] 13 | 14 | [format] 15 | # Like Black, use double quotes for strings. 16 | quote-style = "double" 17 | 18 | # Like Black, indent with spaces, rather than tabs. 19 | indent-style = "space" 20 | 21 | # Like Black, respect magic trailing commas. 22 | skip-magic-trailing-comma = false 23 | 24 | # Like Black, automatically detect the appropriate line ending. 
25 | line-ending = "auto" 26 | -------------------------------------------------------------------------------- /scripts/build-dist: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # clean 4 | rm -rf d3rlpy.egg-info 5 | rm -rf dist 6 | rm -rf build 7 | 8 | python setup.py test 9 | 10 | python setup.py sdist bdist_wheel 11 | -------------------------------------------------------------------------------- /scripts/build-docker: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | docker build -t takuseno/d3rlpy:latest docker 4 | -------------------------------------------------------------------------------- /scripts/build-docs: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | sphinx-apidoc -f -o ./docs d3rlpy *tests 4 | 5 | sphinx-build -b html ./docs ./docs/_build 6 | -------------------------------------------------------------------------------- /scripts/create_cartpole_dataset: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import gym 4 | import d3rlpy 5 | 6 | d3rlpy.seed(100) 7 | 8 | # prepare environment 9 | env = gym.make('CartPole-v0') 10 | eval_env = gym.make('CartPole-v0') 11 | 12 | # prepare algorithms 13 | dqn = d3rlpy.algos.DQN(learning_rate=1e-3, target_update_interval=100) 14 | 15 | # prepare utilities 16 | buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=1000000, env=env) 17 | explorer = d3rlpy.online.explorers.ConstantEpsilonGreedy(epsilon=0.3) 18 | 19 | # start training 20 | dqn.fit_online( 21 | env, buffer=buffer, explorer=explorer, eval_env=eval_env, n_steps=100000 22 | ) 23 | 24 | # export replay buffer as MDPDataset 25 | dataset = buffer.to_mdp_dataset() 26 | 27 | # save MDPDataset 28 | dataset.dump('cartpole.h5') 29 | -------------------------------------------------------------------------------- /scripts/create_cartpole_random_dataset: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import gym 4 | import d3rlpy 5 | 6 | d3rlpy.seed(100) 7 | 8 | # prepare environment 9 | env = gym.make('CartPole-v0') 10 | 11 | # prepare algorithms 12 | policy = d3rlpy.algos.DiscreteRandomPolicy() 13 | 14 | # prepare utilities 15 | buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=1000000, env=env) 16 | 17 | # start training 18 | policy.collect(env, buffer=buffer, n_steps=100000) 19 | 20 | # export replay buffer as MDPDataset 21 | dataset = buffer.to_mdp_dataset() 22 | 23 | # save MDPDataset 24 | dataset.dump('cartpole_random.h5') 25 | -------------------------------------------------------------------------------- /scripts/create_pendulum_dataset: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import gym 4 | import d3rlpy 5 | 6 | d3rlpy.seed(100) 7 | 8 | # prepare environment 9 | env = gym.make('Pendulum-v0') 10 | eval_env = gym.make('Pendulum-v0') 11 | 12 | # prepare algorithms 13 | sac = d3rlpy.algos.SAC(action_scaler='min_max') 14 | 15 | # prepare utilities 16 | buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=1000000, env=env) 17 | 18 | # start training 19 | sac.fit_online(env, buffer=buffer, eval_env=eval_env, n_steps=100000) 20 | 21 | # export replay buffer as MDPDataset 22 | dataset = buffer.to_mdp_dataset() 23 | 24 | # save MDPDataset 25 | dataset.dump('pendulum.h5') 26 | 
-------------------------------------------------------------------------------- /scripts/create_pendulum_random_dataset: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import gym 4 | import d3rlpy 5 | 6 | d3rlpy.seed(100) 7 | 8 | # prepare environment 9 | env = gym.make('Pendulum-v0') 10 | 11 | # prepare algorithms 12 | policy = d3rlpy.algos.RandomPolicy( 13 | distribution='normal', 14 | action_scaler='min_max', 15 | ) 16 | 17 | # prepare utilities 18 | buffer = d3rlpy.online.buffers.ReplayBuffer(maxlen=1000000, env=env) 19 | 20 | # start training 21 | policy.collect(env, buffer=buffer, n_steps=100000) 22 | 23 | # export replay buffer as MDPDataset 24 | dataset = buffer.to_mdp_dataset() 25 | 26 | # save MDPDataset 27 | dataset.dump('pendulum_random.h5') 28 | -------------------------------------------------------------------------------- /scripts/lint: -------------------------------------------------------------------------------- 1 | #!/bin/bash -ex 2 | 3 | if [[ -z $CI ]]; then 4 | BLACK_ARG="" 5 | RUFF_ARG="check --fix" 6 | DOCFORMATTER_ARG="--in-place" 7 | else 8 | ISORT_ARG="--check --diff" 9 | RUFF_ARG="check" 10 | DOCFORMATTER_ARG="--check --diff" 11 | fi 12 | 13 | # use black for the better type annotations 14 | black -l 80 $BLACK_ARG d3rlpy tests setup.py reproductions examples 15 | 16 | # formatter and linter 17 | ruff $RUFF_ARG d3rlpy tests examples reproductions setup.py 18 | 19 | # format docstrings 20 | docformatter $DOCFORMATTER_ARG --black --wrap-summaries 80 --wrap-descriptions 80 -r d3rlpy 21 | 22 | # type check 23 | mypy d3rlpy reproductions tests examples 24 | -------------------------------------------------------------------------------- /scripts/test: -------------------------------------------------------------------------------- 1 | #!/bin/bash -eux 2 | 3 | FLAG_P="FALSE" 4 | 5 | while getopts p OPT 6 | do 7 | case $OPT in 8 | "p" ) FLAG_P="TRUE" ;; # do performance test 9 | * ) echo "Usage: ./scripts/test [-p]" 1>&2 10 | exit 1 ;; 11 | esac 12 | done 13 | 14 | # create temporary directory for tests 15 | mkdir -p test_data 16 | 17 | # set flag 18 | if [ "$FLAG_P" = "TRUE" ]; then 19 | TEST_PERFORMANCE="TRUE" 20 | else 21 | TEST_PERFORMANCE="FALSE" 22 | fi 23 | 24 | # run tests 25 | echo "Run unit tests" 26 | TEST_PERFORMANCE=$TEST_PERFORMANCE \ 27 | KMP_DUPLICATE_LIB_OK='True' pytest --cov-report=xml \ 28 | --cov=d3rlpy \ 29 | --cov-config=.coveragerc \ 30 | tests -p no:warnings -v 31 | 32 | # clean up 33 | rm -r test_data 34 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from setuptools import find_packages, setup 4 | 5 | # get __version__ variable 6 | here = os.path.abspath(os.path.dirname(__file__)) 7 | exec(open(os.path.join(here, "d3rlpy", "_version.py")).read()) 8 | 9 | if __name__ == "__main__": 10 | setup( 11 | name="d3rlpy", 12 | version=__version__, # noqa 13 | description="An offline deep reinforcement learning library", 14 | long_description=open("README.md").read(), 15 | long_description_content_type="text/markdown", 16 | url="https://github.com/takuseno/d3rlpy", 17 | author="Takuma Seno", 18 | author_email="takuma.seno@gmail.com", 19 | license="MIT License", 20 | classifiers=[ 21 | "Development Status :: 5 - Production/Stable", 22 | "Intended Audience :: Developers", 23 | "Intended Audience :: Education", 24 | "Intended Audience :: 
Science/Research", 25 | "Topic :: Scientific/Engineering", 26 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 27 | "Programming Language :: Python :: 3.9", 28 | "Programming Language :: Python :: 3.10", 29 | "Programming Language :: Python :: 3.11", 30 | "Programming Language :: Python :: Implementation :: CPython", 31 | "Operating System :: POSIX :: Linux", 32 | "Operating System :: Microsoft :: Windows", 33 | "Operating System :: MacOS :: MacOS X", 34 | ], 35 | install_requires=[ 36 | "torch>=2.5.0", 37 | "tqdm>=4.66.3", 38 | "h5py", 39 | "gym>=0.26.0", 40 | "click", 41 | "typing-extensions", 42 | "structlog", 43 | "colorama", 44 | "dataclasses-json", 45 | "gymnasium==1.0.0", 46 | "scikit-learn", 47 | ], 48 | packages=find_packages(exclude=["tests*"]), 49 | python_requires=">=3.9.0", 50 | zip_safe=True, 51 | entry_points={"console_scripts": ["d3rlpy=d3rlpy.cli:cli"]}, 52 | ) 53 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | is_skipping_performance_test = os.environ.get("TEST_PERFORMANCE") != "TRUE" 6 | performance_test = pytest.mark.skipif( 7 | is_skipping_performance_test, reason="skip performance tests" 8 | ) 9 | -------------------------------------------------------------------------------- /tests/algos/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/algos/__init__.py -------------------------------------------------------------------------------- /tests/algos/qlearning/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/algos/qlearning/__init__.py -------------------------------------------------------------------------------- /tests/algos/qlearning/test_awac.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos.qlearning.awac import AWACConfig 6 | from d3rlpy.models import ( 7 | MeanQFunctionFactory, 8 | QFunctionFactory, 9 | QRQFunctionFactory, 10 | ) 11 | from d3rlpy.types import Shape 12 | 13 | from ...models.torch.model_test import DummyEncoderFactory 14 | from ...testing_utils import create_scaler_tuple 15 | from .algo_test import algo_tester 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 20 | ) 21 | @pytest.mark.parametrize( 22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 23 | ) 24 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 25 | def test_awac( 26 | observation_shape: Shape, 27 | q_func_factory: QFunctionFactory, 28 | scalers: Optional[str], 29 | ) -> None: 30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple( 31 | scalers, observation_shape 32 | ) 33 | config = AWACConfig( 34 | actor_encoder_factory=DummyEncoderFactory(), 35 | critic_encoder_factory=DummyEncoderFactory(), 36 | q_func_factory=q_func_factory, 37 | observation_scaler=observation_scaler, 38 | action_scaler=action_scaler, 39 | reward_scaler=reward_scaler, 40 | ) 41 | awac = config.create() 42 | algo_tester(awac, observation_shape) # type: ignore 43 | 
-------------------------------------------------------------------------------- /tests/algos/qlearning/test_bc.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos.qlearning.bc import BCConfig, DiscreteBCConfig 6 | from d3rlpy.types import Shape 7 | 8 | from ...models.torch.model_test import DummyEncoderFactory 9 | from ...testing_utils import create_scaler_tuple 10 | from .algo_test import algo_tester 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 15 | ) 16 | @pytest.mark.parametrize("policy_type", ["deterministic", "stochastic"]) 17 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 18 | def test_bc( 19 | observation_shape: Shape, policy_type: str, scalers: Optional[str] 20 | ) -> None: 21 | observation_scaler, action_scaler, _ = create_scaler_tuple( 22 | scalers, observation_shape 23 | ) 24 | config = BCConfig( 25 | encoder_factory=DummyEncoderFactory(), 26 | observation_scaler=observation_scaler, 27 | action_scaler=action_scaler, 28 | policy_type=policy_type, 29 | ) 30 | bc = config.create() 31 | algo_tester( 32 | bc, # type: ignore 33 | observation_shape, 34 | test_predict_value=False, 35 | test_policy_copy=False, 36 | test_policy_optim_copy=False, 37 | test_q_function_optim_copy=False, 38 | test_q_function_copy=False, 39 | ) 40 | 41 | 42 | @pytest.mark.parametrize( 43 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 44 | ) 45 | @pytest.mark.parametrize("scaler", [None, "min_max"]) 46 | def test_discrete_bc(observation_shape: Shape, scaler: Optional[str]) -> None: 47 | observation_scaler, _, _ = create_scaler_tuple(scaler, observation_shape) 48 | config = DiscreteBCConfig( 49 | encoder_factory=DummyEncoderFactory(), 50 | observation_scaler=observation_scaler, 51 | ) 52 | bc = config.create() 53 | algo_tester( 54 | bc, # type: ignore 55 | observation_shape, 56 | test_policy_copy=False, 57 | test_predict_value=False, 58 | test_policy_optim_copy=False, 59 | test_q_function_optim_copy=False, 60 | test_q_function_copy=False, 61 | ) 62 | -------------------------------------------------------------------------------- /tests/algos/qlearning/test_bcq.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos.qlearning.bcq import BCQConfig, DiscreteBCQConfig 6 | from d3rlpy.models import ( 7 | MeanQFunctionFactory, 8 | QFunctionFactory, 9 | QRQFunctionFactory, 10 | ) 11 | from d3rlpy.types import Shape 12 | 13 | from ...models.torch.model_test import DummyEncoderFactory 14 | from ...testing_utils import create_scaler_tuple 15 | from .algo_test import algo_tester 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 20 | ) 21 | @pytest.mark.parametrize( 22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 23 | ) 24 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 25 | def test_bcq( 26 | observation_shape: Shape, 27 | q_func_factory: QFunctionFactory, 28 | scalers: Optional[str], 29 | ) -> None: 30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple( 31 | scalers, observation_shape 32 | ) 33 | config = BCQConfig( 34 | actor_encoder_factory=DummyEncoderFactory(), 35 | critic_encoder_factory=DummyEncoderFactory(), 36 | imitator_encoder_factory=DummyEncoderFactory(), 37 | q_func_factory=q_func_factory, 38 | 
observation_scaler=observation_scaler, 39 | action_scaler=action_scaler, 40 | reward_scaler=reward_scaler, 41 | rl_start_step=0, 42 | ) 43 | bcq = config.create() 44 | algo_tester( 45 | bcq, # type: ignore 46 | observation_shape, 47 | deterministic_best_action=False, 48 | test_policy_copy=False, 49 | ) 50 | 51 | 52 | @pytest.mark.parametrize( 53 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 54 | ) 55 | @pytest.mark.parametrize("n_critics", [1]) 56 | @pytest.mark.parametrize( 57 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 58 | ) 59 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 60 | def test_discrete_bcq( 61 | observation_shape: Shape, 62 | n_critics: int, 63 | q_func_factory: QFunctionFactory, 64 | scalers: Optional[str], 65 | ) -> None: 66 | observation_scaler, _, reward_scaler = create_scaler_tuple( 67 | scalers, observation_shape 68 | ) 69 | config = DiscreteBCQConfig( 70 | encoder_factory=DummyEncoderFactory(), 71 | n_critics=n_critics, 72 | q_func_factory=q_func_factory, 73 | observation_scaler=observation_scaler, 74 | reward_scaler=reward_scaler, 75 | ) 76 | bcq = config.create() 77 | algo_tester( 78 | bcq, # type: ignore 79 | observation_shape, 80 | test_policy_copy=False, 81 | test_policy_optim_copy=False, 82 | ) 83 | -------------------------------------------------------------------------------- /tests/algos/qlearning/test_bear.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos.qlearning.bear import BEARConfig 6 | from d3rlpy.models import ( 7 | MeanQFunctionFactory, 8 | QFunctionFactory, 9 | QRQFunctionFactory, 10 | ) 11 | from d3rlpy.types import Shape 12 | 13 | from ...models.torch.model_test import DummyEncoderFactory 14 | from ...testing_utils import create_scaler_tuple 15 | from .algo_test import algo_tester 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 20 | ) 21 | @pytest.mark.parametrize( 22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 23 | ) 24 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 25 | def test_bear( 26 | observation_shape: Shape, 27 | q_func_factory: QFunctionFactory, 28 | scalers: Optional[str], 29 | ) -> None: 30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple( 31 | scalers, observation_shape 32 | ) 33 | config = BEARConfig( 34 | actor_encoder_factory=DummyEncoderFactory(), 35 | critic_encoder_factory=DummyEncoderFactory(), 36 | imitator_encoder_factory=DummyEncoderFactory(), 37 | q_func_factory=q_func_factory, 38 | observation_scaler=observation_scaler, 39 | action_scaler=action_scaler, 40 | reward_scaler=reward_scaler, 41 | warmup_steps=0, 42 | ) 43 | bear = config.create() 44 | algo_tester( 45 | bear, # type: ignore 46 | observation_shape, 47 | deterministic_best_action=False, 48 | test_policy_copy=False, 49 | ) 50 | -------------------------------------------------------------------------------- /tests/algos/qlearning/test_cal_ql.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos.qlearning.cal_ql import CalQLConfig 6 | from d3rlpy.models import ( 7 | MeanQFunctionFactory, 8 | QFunctionFactory, 9 | QRQFunctionFactory, 10 | ) 11 | from d3rlpy.types import Shape 12 | 13 | from ...models.torch.model_test import DummyEncoderFactory 14 | from ...testing_utils import 
create_scaler_tuple 15 | from .algo_test import algo_tester 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 20 | ) 21 | @pytest.mark.parametrize( 22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 23 | ) 24 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 25 | def test_cal_ql( 26 | observation_shape: Shape, 27 | q_func_factory: QFunctionFactory, 28 | scalers: Optional[str], 29 | ) -> None: 30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple( 31 | scalers, observation_shape 32 | ) 33 | config = CalQLConfig( 34 | actor_encoder_factory=DummyEncoderFactory(), 35 | critic_encoder_factory=DummyEncoderFactory(), 36 | q_func_factory=q_func_factory, 37 | observation_scaler=observation_scaler, 38 | action_scaler=action_scaler, 39 | reward_scaler=reward_scaler, 40 | ) 41 | cal_ql = config.create() 42 | algo_tester(cal_ql, observation_shape) # type: ignore 43 | -------------------------------------------------------------------------------- /tests/algos/qlearning/test_cql.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos.qlearning.cql import CQLConfig, DiscreteCQLConfig 6 | from d3rlpy.models import ( 7 | MeanQFunctionFactory, 8 | QFunctionFactory, 9 | QRQFunctionFactory, 10 | ) 11 | from d3rlpy.types import Shape 12 | 13 | from ...models.torch.model_test import DummyEncoderFactory 14 | from ...testing_utils import create_scaler_tuple 15 | from .algo_test import algo_tester 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 20 | ) 21 | @pytest.mark.parametrize( 22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 23 | ) 24 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 25 | def test_cql( 26 | observation_shape: Shape, 27 | q_func_factory: QFunctionFactory, 28 | scalers: Optional[str], 29 | ) -> None: 30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple( 31 | scalers, observation_shape 32 | ) 33 | config = CQLConfig( 34 | actor_encoder_factory=DummyEncoderFactory(), 35 | critic_encoder_factory=DummyEncoderFactory(), 36 | q_func_factory=q_func_factory, 37 | observation_scaler=observation_scaler, 38 | action_scaler=action_scaler, 39 | reward_scaler=reward_scaler, 40 | ) 41 | cql = config.create() 42 | algo_tester(cql, observation_shape) # type: ignore 43 | 44 | 45 | @pytest.mark.parametrize( 46 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 47 | ) 48 | @pytest.mark.parametrize("n_critics", [1]) 49 | @pytest.mark.parametrize( 50 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 51 | ) 52 | @pytest.mark.parametrize("scalers", [None, None, "min_max"]) 53 | def test_discrete_cql( 54 | observation_shape: Shape, 55 | n_critics: int, 56 | q_func_factory: QFunctionFactory, 57 | scalers: Optional[str], 58 | ) -> None: 59 | observation_scaler, _, reward_scaler = create_scaler_tuple( 60 | scalers, observation_shape 61 | ) 62 | config = DiscreteCQLConfig( 63 | encoder_factory=DummyEncoderFactory(), 64 | n_critics=n_critics, 65 | q_func_factory=q_func_factory, 66 | observation_scaler=observation_scaler, 67 | reward_scaler=reward_scaler, 68 | ) 69 | cql = config.create() 70 | algo_tester( 71 | cql, # type: ignore 72 | observation_shape, 73 | test_policy_copy=False, 74 | test_policy_optim_copy=False, 75 | ) 76 | 
-------------------------------------------------------------------------------- /tests/algos/qlearning/test_crr.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos.qlearning.crr import CRRConfig 6 | from d3rlpy.models import ( 7 | MeanQFunctionFactory, 8 | QFunctionFactory, 9 | QRQFunctionFactory, 10 | ) 11 | from d3rlpy.types import Shape 12 | 13 | from ...models.torch.model_test import DummyEncoderFactory 14 | from ...testing_utils import create_scaler_tuple 15 | from .algo_test import algo_tester 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 20 | ) 21 | @pytest.mark.parametrize( 22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 23 | ) 24 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 25 | @pytest.mark.parametrize("advantage_type", ["mean", "max"]) 26 | @pytest.mark.parametrize("weight_type", ["exp", "binary"]) 27 | @pytest.mark.parametrize("target_update_type", ["hard", "soft"]) 28 | def test_crr( 29 | observation_shape: Shape, 30 | q_func_factory: QFunctionFactory, 31 | scalers: Optional[str], 32 | advantage_type: str, 33 | weight_type: str, 34 | target_update_type: str, 35 | ) -> None: 36 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple( 37 | scalers, observation_shape 38 | ) 39 | config = CRRConfig( 40 | actor_encoder_factory=DummyEncoderFactory(), 41 | critic_encoder_factory=DummyEncoderFactory(), 42 | q_func_factory=q_func_factory, 43 | observation_scaler=observation_scaler, 44 | action_scaler=action_scaler, 45 | reward_scaler=reward_scaler, 46 | advantage_type=advantage_type, 47 | weight_type=weight_type, 48 | target_update_type=target_update_type, 49 | ) 50 | crr = config.create() 51 | algo_tester( 52 | crr, # type: ignore 53 | observation_shape, 54 | deterministic_best_action=False, 55 | test_policy_copy=False, 56 | ) 57 | -------------------------------------------------------------------------------- /tests/algos/qlearning/test_ddpg.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos.qlearning.ddpg import DDPGConfig 6 | from d3rlpy.models import ( 7 | MeanQFunctionFactory, 8 | QFunctionFactory, 9 | QRQFunctionFactory, 10 | ) 11 | from d3rlpy.types import Shape 12 | 13 | from ...models.torch.model_test import DummyEncoderFactory 14 | from ...testing_utils import create_scaler_tuple 15 | from .algo_test import algo_tester 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 20 | ) 21 | @pytest.mark.parametrize( 22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 23 | ) 24 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 25 | def test_ddpg( 26 | observation_shape: Shape, 27 | q_func_factory: QFunctionFactory, 28 | scalers: Optional[str], 29 | ) -> None: 30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple( 31 | scalers, observation_shape 32 | ) 33 | config = DDPGConfig( 34 | actor_encoder_factory=DummyEncoderFactory(), 35 | critic_encoder_factory=DummyEncoderFactory(), 36 | q_func_factory=q_func_factory, 37 | observation_scaler=observation_scaler, 38 | action_scaler=action_scaler, 39 | reward_scaler=reward_scaler, 40 | ) 41 | ddpg = config.create() 42 | algo_tester(ddpg, observation_shape) # type: ignore 43 | 
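The Q-learning algorithm tests from test_awac.py through test_ddpg.py above all follow the same parametrized pattern over vector, image, and tuple observation shapes. scripts/test (shown earlier) runs the entire suite with coverage; when iterating on a single algorithm, a narrower invocation along the following lines can be more convenient. This is a sketch that reuses the TEST_PERFORMANCE gate from tests/__init__.py and the pytest flags from scripts/test; adjust the target path to the test file you are working on.

import os

import pytest

# skip the slow performance tests, mirroring the gate in tests/__init__.py
os.environ["TEST_PERFORMANCE"] = "FALSE"

# run only the DDPG tests with the same flags scripts/test uses
pytest.main(["tests/algos/qlearning/test_ddpg.py", "-p", "no:warnings", "-v"])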
-------------------------------------------------------------------------------- /tests/algos/qlearning/test_dqn.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos.qlearning.dqn import DoubleDQNConfig, DQNConfig 6 | from d3rlpy.models import ( 7 | MeanQFunctionFactory, 8 | QFunctionFactory, 9 | QRQFunctionFactory, 10 | ) 11 | from d3rlpy.types import Shape 12 | 13 | from ...models.torch.model_test import DummyEncoderFactory 14 | from ...testing_utils import create_scaler_tuple 15 | from .algo_test import algo_tester 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 20 | ) 21 | @pytest.mark.parametrize("n_critics", [1]) 22 | @pytest.mark.parametrize( 23 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 24 | ) 25 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 26 | def test_dqn( 27 | observation_shape: Shape, 28 | n_critics: int, 29 | q_func_factory: QFunctionFactory, 30 | scalers: Optional[str], 31 | ) -> None: 32 | observation_scaler, _, reward_scaler = create_scaler_tuple( 33 | scalers, observation_shape 34 | ) 35 | config = DQNConfig( 36 | encoder_factory=DummyEncoderFactory(), 37 | n_critics=n_critics, 38 | q_func_factory=q_func_factory, 39 | observation_scaler=observation_scaler, 40 | reward_scaler=reward_scaler, 41 | ) 42 | dqn = config.create() 43 | algo_tester( 44 | dqn, # type: ignore 45 | observation_shape, 46 | test_policy_copy=False, 47 | test_policy_optim_copy=False, 48 | ) 49 | 50 | 51 | @pytest.mark.parametrize( 52 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 53 | ) 54 | @pytest.mark.parametrize("n_critics", [1]) 55 | @pytest.mark.parametrize( 56 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 57 | ) 58 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 59 | def test_double_dqn( 60 | observation_shape: Shape, 61 | n_critics: int, 62 | q_func_factory: QFunctionFactory, 63 | scalers: Optional[str], 64 | ) -> None: 65 | observation_scaler, _, reward_scaler = create_scaler_tuple( 66 | scalers, observation_shape 67 | ) 68 | config = DoubleDQNConfig( 69 | encoder_factory=DummyEncoderFactory(), 70 | n_critics=n_critics, 71 | q_func_factory=q_func_factory, 72 | observation_scaler=observation_scaler, 73 | reward_scaler=reward_scaler, 74 | ) 75 | double_dqn = config.create() 76 | algo_tester( 77 | double_dqn, # type: ignore 78 | observation_shape, 79 | test_policy_copy=False, 80 | test_policy_optim_copy=False, 81 | ) 82 | -------------------------------------------------------------------------------- /tests/algos/qlearning/test_iql.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos.qlearning.iql import IQLConfig 6 | from d3rlpy.types import Shape 7 | 8 | from ...models.torch.model_test import DummyEncoderFactory 9 | from ...testing_utils import create_scaler_tuple 10 | from .algo_test import algo_tester 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 15 | ) 16 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 17 | def test_iql(observation_shape: Shape, scalers: Optional[str]) -> None: 18 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple( 19 | scalers, observation_shape 20 | ) 21 | config = IQLConfig( 22 | actor_encoder_factory=DummyEncoderFactory(), 
23 | critic_encoder_factory=DummyEncoderFactory(), 24 | value_encoder_factory=DummyEncoderFactory(), 25 | observation_scaler=observation_scaler, 26 | action_scaler=action_scaler, 27 | reward_scaler=reward_scaler, 28 | ) 29 | iql = config.create() 30 | algo_tester(iql, observation_shape) # type: ignore 31 | -------------------------------------------------------------------------------- /tests/algos/qlearning/test_nfq.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos.qlearning.nfq import NFQConfig 6 | from d3rlpy.models import ( 7 | MeanQFunctionFactory, 8 | QFunctionFactory, 9 | QRQFunctionFactory, 10 | ) 11 | from d3rlpy.types import Shape 12 | 13 | from ...models.torch.model_test import DummyEncoderFactory 14 | from ...testing_utils import create_scaler_tuple 15 | from .algo_test import algo_tester 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 20 | ) 21 | @pytest.mark.parametrize("n_critics", [1]) 22 | @pytest.mark.parametrize( 23 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 24 | ) 25 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 26 | def test_nfq( 27 | observation_shape: Shape, 28 | n_critics: int, 29 | q_func_factory: QFunctionFactory, 30 | scalers: Optional[str], 31 | ) -> None: 32 | observation_scaler, _, reward_scaler = create_scaler_tuple( 33 | scalers, observation_shape 34 | ) 35 | config = NFQConfig( 36 | encoder_factory=DummyEncoderFactory(), 37 | n_critics=n_critics, 38 | q_func_factory=q_func_factory, 39 | observation_scaler=observation_scaler, 40 | reward_scaler=reward_scaler, 41 | ) 42 | nfq = config.create() 43 | algo_tester( 44 | nfq, # type: ignore 45 | observation_shape, 46 | test_policy_copy=False, 47 | test_policy_optim_copy=False, 48 | ) 49 | -------------------------------------------------------------------------------- /tests/algos/qlearning/test_plas.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos.qlearning.plas import PLASConfig, PLASWithPerturbationConfig 6 | from d3rlpy.models import ( 7 | MeanQFunctionFactory, 8 | QFunctionFactory, 9 | QRQFunctionFactory, 10 | ) 11 | from d3rlpy.types import Shape 12 | 13 | from ...models.torch.model_test import DummyEncoderFactory 14 | from ...testing_utils import create_scaler_tuple 15 | from .algo_test import algo_tester 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 20 | ) 21 | @pytest.mark.parametrize( 22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 23 | ) 24 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 25 | def test_plas( 26 | observation_shape: Shape, 27 | q_func_factory: QFunctionFactory, 28 | scalers: Optional[str], 29 | ) -> None: 30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple( 31 | scalers, observation_shape 32 | ) 33 | config = PLASConfig( 34 | actor_encoder_factory=DummyEncoderFactory(), 35 | critic_encoder_factory=DummyEncoderFactory(), 36 | imitator_encoder_factory=DummyEncoderFactory(), 37 | q_func_factory=q_func_factory, 38 | observation_scaler=observation_scaler, 39 | action_scaler=action_scaler, 40 | reward_scaler=reward_scaler, 41 | warmup_steps=0, 42 | ) 43 | plas = config.create() 44 | algo_tester(plas, observation_shape, test_policy_copy=False) # type: ignore 45 | 
46 | 47 | @pytest.mark.parametrize( 48 | "observation_shape", [(100,), (4, 84, 84), ((100,), (200,))] 49 | ) 50 | @pytest.mark.parametrize( 51 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 52 | ) 53 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 54 | def test_plas_with_perturbation( 55 | observation_shape: Shape, 56 | q_func_factory: QFunctionFactory, 57 | scalers: Optional[str], 58 | ) -> None: 59 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple( 60 | scalers, observation_shape 61 | ) 62 | config = PLASWithPerturbationConfig( 63 | actor_encoder_factory=DummyEncoderFactory(), 64 | critic_encoder_factory=DummyEncoderFactory(), 65 | imitator_encoder_factory=DummyEncoderFactory(), 66 | q_func_factory=q_func_factory, 67 | observation_scaler=observation_scaler, 68 | action_scaler=action_scaler, 69 | reward_scaler=reward_scaler, 70 | warmup_steps=0, 71 | ) 72 | plas = config.create() 73 | algo_tester(plas, observation_shape, test_policy_copy=False) # type: ignore 74 | -------------------------------------------------------------------------------- /tests/algos/qlearning/test_prdc.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos.qlearning.prdc import PRDCConfig 6 | from d3rlpy.models import ( 7 | MeanQFunctionFactory, 8 | QFunctionFactory, 9 | QRQFunctionFactory, 10 | ) 11 | from d3rlpy.types import Shape 12 | 13 | from ...models.torch.model_test import DummyEncoderFactory 14 | from ...testing_utils import create_scaler_tuple 15 | from .algo_test import algo_tester 16 | 17 | 18 | @pytest.mark.parametrize("observation_shape", [(100,), (17,)]) 19 | @pytest.mark.parametrize( 20 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 21 | ) 22 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 23 | def test_prdc( 24 | observation_shape: Shape, 25 | q_func_factory: QFunctionFactory, 26 | scalers: Optional[str], 27 | ) -> None: 28 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple( 29 | scalers, observation_shape 30 | ) 31 | config = PRDCConfig( 32 | actor_encoder_factory=DummyEncoderFactory(), 33 | critic_encoder_factory=DummyEncoderFactory(), 34 | q_func_factory=q_func_factory, 35 | observation_scaler=observation_scaler, 36 | action_scaler=action_scaler, 37 | reward_scaler=reward_scaler, 38 | ) 39 | prdc = config.create() 40 | algo_tester(prdc, observation_shape) # type: ignore 41 | -------------------------------------------------------------------------------- /tests/algos/qlearning/test_random_policy.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from d3rlpy.algos.qlearning.random_policy import ( 7 | DiscreteRandomPolicyConfig, 8 | RandomPolicyConfig, 9 | ) 10 | 11 | 12 | @pytest.mark.parametrize("distribution", ["uniform", "normal"]) 13 | @pytest.mark.parametrize("action_size", [4]) 14 | @pytest.mark.parametrize("batch_size", [32]) 15 | @pytest.mark.parametrize("observation_shape", [(100,)]) 16 | def test_random_policy( 17 | distribution: str, 18 | action_size: int, 19 | batch_size: int, 20 | observation_shape: Sequence[int], 21 | ) -> None: 22 | config = RandomPolicyConfig(distribution=distribution) 23 | algo = config.create() 24 | algo.create_impl(observation_shape, action_size) 25 | 26 | x = np.random.random((batch_size, *observation_shape)) 27 | 28 | # check predict 
29 | action = algo.predict(x) 30 | assert action.shape == (batch_size, action_size) 31 | 32 | # check sample_action 33 | action = algo.sample_action(x) 34 | assert action.shape == (batch_size, action_size) 35 | 36 | 37 | @pytest.mark.parametrize("action_size", [4]) 38 | @pytest.mark.parametrize("batch_size", [32]) 39 | @pytest.mark.parametrize("observation_shape", [(100,)]) 40 | def test_discrete_random_policy( 41 | action_size: int, batch_size: int, observation_shape: Sequence[int] 42 | ) -> None: 43 | algo = DiscreteRandomPolicyConfig().create() 44 | algo.create_impl(observation_shape, action_size) 45 | 46 | x = np.random.random((batch_size, *observation_shape)) 47 | 48 | # check predict 49 | action = algo.predict(x) 50 | assert action.shape == (batch_size,) 51 | assert np.all(action < action_size) 52 | 53 | # check sample_action 54 | action = algo.sample_action(x) 55 | assert action.shape == (batch_size,) 56 | assert np.all(action < action_size) 57 | -------------------------------------------------------------------------------- /tests/algos/qlearning/test_rebrac.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos.qlearning.rebrac import ReBRACConfig 6 | from d3rlpy.models import ( 7 | MeanQFunctionFactory, 8 | QFunctionFactory, 9 | QRQFunctionFactory, 10 | ) 11 | from d3rlpy.types import Shape 12 | 13 | from ...models.torch.model_test import DummyEncoderFactory 14 | from ...testing_utils import create_scaler_tuple 15 | from .algo_test import algo_tester 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 20 | ) 21 | @pytest.mark.parametrize( 22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 23 | ) 24 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 25 | def test_rebrac( 26 | observation_shape: Shape, 27 | q_func_factory: QFunctionFactory, 28 | scalers: Optional[str], 29 | ) -> None: 30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple( 31 | scalers, observation_shape 32 | ) 33 | config = ReBRACConfig( 34 | actor_encoder_factory=DummyEncoderFactory(), 35 | critic_encoder_factory=DummyEncoderFactory(), 36 | q_func_factory=q_func_factory, 37 | observation_scaler=observation_scaler, 38 | action_scaler=action_scaler, 39 | reward_scaler=reward_scaler, 40 | ) 41 | rebrac = config.create() 42 | algo_tester(rebrac, observation_shape) # type: ignore 43 | -------------------------------------------------------------------------------- /tests/algos/qlearning/test_sac.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos.qlearning.sac import DiscreteSACConfig, SACConfig 6 | from d3rlpy.models import ( 7 | MeanQFunctionFactory, 8 | QFunctionFactory, 9 | QRQFunctionFactory, 10 | ) 11 | from d3rlpy.types import Shape 12 | 13 | from ...models.torch.model_test import DummyEncoderFactory 14 | from ...testing_utils import create_scaler_tuple 15 | from .algo_test import algo_tester 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 20 | ) 21 | @pytest.mark.parametrize( 22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 23 | ) 24 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 25 | def test_sac( 26 | observation_shape: Shape, 27 | q_func_factory: QFunctionFactory, 28 | scalers: Optional[str], 29 | ) 
-> None: 30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple( 31 | scalers, observation_shape 32 | ) 33 | config = SACConfig( 34 | actor_encoder_factory=DummyEncoderFactory(), 35 | critic_encoder_factory=DummyEncoderFactory(), 36 | q_func_factory=q_func_factory, 37 | observation_scaler=observation_scaler, 38 | action_scaler=action_scaler, 39 | reward_scaler=reward_scaler, 40 | ) 41 | sac = config.create() 42 | algo_tester(sac, observation_shape) # type: ignore 43 | 44 | 45 | @pytest.mark.parametrize( 46 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 47 | ) 48 | @pytest.mark.parametrize( 49 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 50 | ) 51 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 52 | def test_discrete_sac( 53 | observation_shape: Shape, 54 | q_func_factory: QFunctionFactory, 55 | scalers: Optional[str], 56 | ) -> None: 57 | observation_scaler, _, reward_scaler = create_scaler_tuple( 58 | scalers, observation_shape 59 | ) 60 | config = DiscreteSACConfig( 61 | actor_encoder_factory=DummyEncoderFactory(), 62 | critic_encoder_factory=DummyEncoderFactory(), 63 | q_func_factory=q_func_factory, 64 | observation_scaler=observation_scaler, 65 | reward_scaler=reward_scaler, 66 | ) 67 | sac = config.create() 68 | algo_tester(sac, observation_shape, action_size=100) # type: ignore 69 | -------------------------------------------------------------------------------- /tests/algos/qlearning/test_td3.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos.qlearning.td3 import TD3Config 6 | from d3rlpy.models import ( 7 | MeanQFunctionFactory, 8 | QFunctionFactory, 9 | QRQFunctionFactory, 10 | ) 11 | from d3rlpy.types import Shape 12 | 13 | from ...models.torch.model_test import DummyEncoderFactory 14 | from ...testing_utils import create_scaler_tuple 15 | from .algo_test import algo_tester 16 | 17 | 18 | @pytest.mark.parametrize( 19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 20 | ) 21 | @pytest.mark.parametrize( 22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 23 | ) 24 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 25 | def test_td3( 26 | observation_shape: Shape, 27 | q_func_factory: QFunctionFactory, 28 | scalers: Optional[str], 29 | ) -> None: 30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple( 31 | scalers, observation_shape 32 | ) 33 | config = TD3Config( 34 | actor_encoder_factory=DummyEncoderFactory(), 35 | critic_encoder_factory=DummyEncoderFactory(), 36 | q_func_factory=q_func_factory, 37 | observation_scaler=observation_scaler, 38 | action_scaler=action_scaler, 39 | reward_scaler=reward_scaler, 40 | ) 41 | td3 = config.create() 42 | algo_tester(td3, observation_shape) # type: ignore 43 | -------------------------------------------------------------------------------- /tests/algos/qlearning/test_td3_plus_bc.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos.qlearning.td3_plus_bc import TD3PlusBCConfig 6 | from d3rlpy.models import ( 7 | MeanQFunctionFactory, 8 | QFunctionFactory, 9 | QRQFunctionFactory, 10 | ) 11 | from d3rlpy.types import Shape 12 | 13 | from ...models.torch.model_test import DummyEncoderFactory 14 | from ...testing_utils import create_scaler_tuple 15 | from .algo_test import algo_tester 16 | 17 | 18 | 
@pytest.mark.parametrize( 19 | "observation_shape", [(100,), (4, 32, 32), ((100,), (200,))] 20 | ) 21 | @pytest.mark.parametrize( 22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 23 | ) 24 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 25 | def test_td3_plus_bc( 26 | observation_shape: Shape, 27 | q_func_factory: QFunctionFactory, 28 | scalers: Optional[str], 29 | ) -> None: 30 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple( 31 | scalers, observation_shape 32 | ) 33 | config = TD3PlusBCConfig( 34 | actor_encoder_factory=DummyEncoderFactory(), 35 | critic_encoder_factory=DummyEncoderFactory(), 36 | q_func_factory=q_func_factory, 37 | observation_scaler=observation_scaler, 38 | action_scaler=action_scaler, 39 | reward_scaler=reward_scaler, 40 | ) 41 | td3 = config.create() 42 | algo_tester(td3, observation_shape) # type: ignore 43 | -------------------------------------------------------------------------------- /tests/algos/qlearning/torch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/algos/qlearning/torch/__init__.py -------------------------------------------------------------------------------- /tests/algos/qlearning/torch/test_utility.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from d3rlpy.algos.qlearning.torch.utility import sample_q_values_with_policy 4 | from d3rlpy.models.builders import ( 5 | create_continuous_q_function, 6 | create_normal_policy, 7 | ) 8 | from d3rlpy.models.q_functions import MeanQFunctionFactory 9 | from d3rlpy.types import Shape 10 | 11 | from ....models.torch.model_test import DummyEncoderFactory 12 | from ....testing_utils import create_torch_observations 13 | 14 | 15 | @pytest.mark.parametrize( 16 | "observation_shape", [(100,), (4, 84, 84), ((100,), (200,))] 17 | ) 18 | @pytest.mark.parametrize("action_size", [4]) 19 | @pytest.mark.parametrize("n_action_samples", [10]) 20 | @pytest.mark.parametrize("batch_size", [256]) 21 | @pytest.mark.parametrize("n_critics", [2]) 22 | def test_sample_q_values_with_policy( 23 | observation_shape: Shape, 24 | action_size: int, 25 | n_action_samples: int, 26 | batch_size: int, 27 | n_critics: int, 28 | ) -> None: 29 | policy = create_normal_policy( 30 | observation_shape=observation_shape, 31 | action_size=action_size, 32 | encoder_factory=DummyEncoderFactory(), 33 | device="cpu:0", 34 | enable_ddp=False, 35 | ) 36 | _, q_func_forwarder = create_continuous_q_function( 37 | observation_shape=observation_shape, 38 | action_size=action_size, 39 | encoder_factory=DummyEncoderFactory(), 40 | q_func_factory=MeanQFunctionFactory(), 41 | n_ensembles=n_critics, 42 | device="cpu:0", 43 | enable_ddp=False, 44 | ) 45 | 46 | observations = create_torch_observations(observation_shape, batch_size) 47 | 48 | q_values, log_probs = sample_q_values_with_policy( 49 | policy=policy, 50 | q_func_forwarder=q_func_forwarder, 51 | policy_observations=observations, 52 | value_observations=observations, 53 | n_action_samples=n_action_samples, 54 | detach_policy_output=False, 55 | ) 56 | assert q_values.shape == (n_critics, batch_size, n_action_samples) 57 | assert log_probs.shape == (1, batch_size, n_action_samples) 58 | -------------------------------------------------------------------------------- /tests/algos/transformer/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/algos/transformer/__init__.py -------------------------------------------------------------------------------- /tests/algos/transformer/test_action_samplers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from d3rlpy.algos import ( 5 | GreedyTransformerActionSampler, 6 | IdentityTransformerActionSampler, 7 | SoftmaxTransformerActionSampler, 8 | ) 9 | 10 | 11 | @pytest.mark.parametrize("action_size", [4]) 12 | def test_identity_transformer_action_sampler(action_size: int) -> None: 13 | action_sampler = IdentityTransformerActionSampler() 14 | 15 | x = np.random.random(action_size) 16 | action = action_sampler(x) 17 | 18 | assert np.all(action == x) 19 | 20 | 21 | @pytest.mark.parametrize("action_size", [10]) 22 | def test_softmax_transformer_action_sampler(action_size: int) -> None: 23 | action_sampler = SoftmaxTransformerActionSampler() 24 | 25 | logits = np.random.random(action_size) 26 | action = action_sampler(logits) 27 | assert isinstance(action, int) 28 | 29 | same_actions = [] 30 | for _ in range(100): 31 | same_actions.append(action == action_sampler(logits)) 32 | assert not all(same_actions) 33 | 34 | 35 | @pytest.mark.parametrize("action_size", [10]) 36 | def test_greedy_transformer_action_sampler(action_size: int) -> None: 37 | action_sampler = GreedyTransformerActionSampler() 38 | 39 | logits = np.random.random(action_size) 40 | action = action_sampler(logits) 41 | assert isinstance(action, int) 42 | assert action == np.argmax(logits) 43 | -------------------------------------------------------------------------------- /tests/algos/transformer/test_decision_transformer.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos import ( 6 | DecisionTransformerConfig, 7 | DiscreteDecisionTransformerConfig, 8 | ) 9 | from d3rlpy.types import Shape 10 | 11 | from ...models.torch.model_test import DummyEncoderFactory 12 | from ...testing_utils import create_scaler_tuple 13 | from .algo_test import algo_tester 14 | 15 | 16 | @pytest.mark.parametrize( 17 | "observation_shape", [(100,), (4, 8, 8), ((100,), (200,))] 18 | ) 19 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 20 | def test_decision_transformer( 21 | observation_shape: Shape, 22 | scalers: Optional[str], 23 | ) -> None: 24 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple( 25 | scalers, observation_shape 26 | ) 27 | config = DecisionTransformerConfig( 28 | encoder_factory=DummyEncoderFactory(), 29 | observation_scaler=observation_scaler, 30 | action_scaler=action_scaler, 31 | reward_scaler=reward_scaler, 32 | ) 33 | dt = config.create() 34 | algo_tester( 35 | dt, # type: ignore 36 | observation_shape, 37 | ) 38 | 39 | 40 | @pytest.mark.parametrize( 41 | "observation_shape", [(100,), (4, 8, 8), ((100,), (200,))] 42 | ) 43 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 44 | def test_discrete_decision_transformer( 45 | observation_shape: Shape, 46 | scalers: Optional[str], 47 | ) -> None: 48 | observation_scaler, _, reward_scaler = create_scaler_tuple( 49 | scalers, observation_shape 50 | ) 51 | config = DiscreteDecisionTransformerConfig( 52 | encoder_factory=DummyEncoderFactory(), 53 | 
observation_scaler=observation_scaler, 54 | reward_scaler=reward_scaler, 55 | num_heads=4, 56 | ) 57 | dt = config.create() 58 | algo_tester( 59 | dt, # type: ignore 60 | observation_shape, 61 | action_size=10, 62 | ) 63 | -------------------------------------------------------------------------------- /tests/algos/transformer/test_tacr.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import pytest 4 | 5 | from d3rlpy.algos import TACRConfig 6 | from d3rlpy.types import Shape 7 | 8 | from ...models.torch.model_test import DummyEncoderFactory 9 | from ...testing_utils import create_scaler_tuple 10 | from .algo_test import algo_tester 11 | 12 | 13 | @pytest.mark.parametrize( 14 | "observation_shape", [(100,), (4, 8, 8), ((100,), (200,))] 15 | ) 16 | @pytest.mark.parametrize("scalers", [None, "min_max"]) 17 | def test_tacr(observation_shape: Shape, scalers: Optional[str]) -> None: 18 | observation_scaler, action_scaler, reward_scaler = create_scaler_tuple( 19 | scalers, observation_shape 20 | ) 21 | config = TACRConfig( 22 | actor_encoder_factory=DummyEncoderFactory(), 23 | critic_encoder_factory=DummyEncoderFactory(), 24 | observation_scaler=observation_scaler, 25 | action_scaler=action_scaler, 26 | reward_scaler=reward_scaler, 27 | ) 28 | tacr = config.create() 29 | algo_tester( 30 | tacr, # type: ignore 31 | observation_shape, 32 | ) 33 | -------------------------------------------------------------------------------- /tests/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/dataset/__init__.py -------------------------------------------------------------------------------- /tests/dataset/test_buffers.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import pytest 4 | 5 | from d3rlpy.dataset import FIFOBuffer, InfiniteBuffer 6 | 7 | from ..testing_utils import create_episode 8 | 9 | 10 | @pytest.mark.parametrize("observation_shape", [(4,)]) 11 | @pytest.mark.parametrize("action_size", [2]) 12 | @pytest.mark.parametrize("length", [100]) 13 | @pytest.mark.parametrize("terminated", [False, True]) 14 | def test_infinite_buffer( 15 | observation_shape: Sequence[int], 16 | action_size: int, 17 | length: int, 18 | terminated: bool, 19 | ) -> None: 20 | buffer = InfiniteBuffer() 21 | 22 | for i in range(10): 23 | episode = create_episode( 24 | observation_shape, action_size, length, terminated=terminated 25 | ) 26 | for j in range(episode.transition_count): 27 | buffer.append(episode, j) 28 | 29 | if terminated: 30 | assert buffer.transition_count == (i + 1) * (length) 31 | else: 32 | assert buffer.transition_count == (i + 1) * (length - 1) 33 | assert len(buffer.episodes) == i + 1 34 | 35 | 36 | @pytest.mark.parametrize("observation_shape", [(4,)]) 37 | @pytest.mark.parametrize("action_size", [2]) 38 | @pytest.mark.parametrize("length", [100]) 39 | @pytest.mark.parametrize("limit", [500]) 40 | @pytest.mark.parametrize("terminated", [False, True]) 41 | def test_fifo_buffer( 42 | observation_shape: Sequence[int], 43 | action_size: int, 44 | length: int, 45 | limit: int, 46 | terminated: bool, 47 | ) -> None: 48 | buffer = FIFOBuffer(limit) 49 | 50 | for i in range(10): 51 | episode = create_episode( 52 | observation_shape, action_size, length, terminated=terminated 53 | ) 54 | for j in 
range(episode.transition_count): 55 | buffer.append(episode, j) 56 | 57 | if i >= 5: 58 | assert buffer.transition_count == limit 59 | if terminated: 60 | assert len(buffer.episodes) == 5 61 | else: 62 | assert len(buffer.episodes) == 6 63 | else: 64 | if terminated: 65 | assert buffer.transition_count == (i + 1) * length 66 | else: 67 | assert buffer.transition_count == (i + 1) * (length - 1) 68 | assert len(buffer.episodes) == i + 1 69 | -------------------------------------------------------------------------------- /tests/dataset/test_compat.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from d3rlpy.dataset import MDPDataset, ReplayBuffer 7 | 8 | from ..testing_utils import create_observation 9 | 10 | 11 | @pytest.mark.parametrize("observation_shape", [(4,)]) 12 | @pytest.mark.parametrize("action_size", [2]) 13 | @pytest.mark.parametrize("length", [100]) 14 | @pytest.mark.parametrize("num_episodes", [10]) 15 | def test_replay_buffer( 16 | observation_shape: Sequence[int], 17 | action_size: int, 18 | length: int, 19 | num_episodes: int, 20 | ) -> None: 21 | observations = [] 22 | actions = [] 23 | rewards = [] 24 | terminals = [] 25 | for _ in range(num_episodes): 26 | for i in range(length): 27 | observations.append(create_observation(observation_shape)) 28 | actions.append(np.random.random(action_size)) 29 | rewards.append(np.random.random()) 30 | terminals.append(float(i == length - 1)) 31 | 32 | dataset = MDPDataset( 33 | observations=np.array(observations), 34 | actions=np.array(actions), 35 | rewards=np.array(rewards), 36 | terminals=np.array(terminals), 37 | ) 38 | 39 | assert isinstance(dataset, ReplayBuffer) 40 | assert len(dataset.episodes) == num_episodes 41 | -------------------------------------------------------------------------------- /tests/dataset/test_episode_generator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from d3rlpy.dataset import EpisodeGenerator 5 | from d3rlpy.types import Float32NDArray, Shape 6 | 7 | from ..testing_utils import create_observations 8 | 9 | 10 | @pytest.mark.parametrize("observation_shape", [(4,), ((4,), (8,))]) 11 | @pytest.mark.parametrize("action_size", [2]) 12 | @pytest.mark.parametrize("length", [1000]) 13 | @pytest.mark.parametrize("terminal", [False, True]) 14 | def test_episode_generator( 15 | observation_shape: Shape, action_size: int, length: int, terminal: bool 16 | ) -> None: 17 | observations = create_observations(observation_shape, length) 18 | actions = np.random.random((length, action_size)) 19 | rewards: Float32NDArray = np.random.random((length, 1)).astype(np.float32) 20 | terminals: Float32NDArray = np.zeros(length, dtype=np.float32) 21 | timeouts: Float32NDArray = np.zeros(length, dtype=np.float32) 22 | for i in range(length // 100): 23 | if terminal: 24 | terminals[(i + 1) * 100 - 1] = 1.0 25 | else: 26 | timeouts[(i + 1) * 100 - 1] = 1.0 27 | 28 | episode_generator = EpisodeGenerator( 29 | observations=observations, 30 | actions=actions, 31 | rewards=rewards, 32 | terminals=terminals, 33 | timeouts=timeouts, 34 | ) 35 | 36 | episodes = episode_generator() 37 | assert len(episodes) == length // 100 38 | 39 | for episode in episodes: 40 | assert len(episode) == 100 41 | if isinstance(observation_shape[0], tuple): 42 | for i, shape in enumerate(observation_shape): 43 | assert isinstance(shape, tuple) 44 | assert 
episode.observations[i].shape == (100, *shape) 45 | else: 46 | assert isinstance(episode.observations, np.ndarray) 47 | assert episode.observations.shape == (100, *observation_shape) 48 | assert episode.actions.shape == (100, action_size) 49 | assert episode.rewards.shape == (100, 1) 50 | assert episode.terminated == terminal 51 | -------------------------------------------------------------------------------- /tests/dataset/test_io.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | from d3rlpy.dataset import Episode, dump, load 7 | from d3rlpy.types import Shape 8 | 9 | from ..testing_utils import create_episode 10 | 11 | 12 | @pytest.mark.parametrize("observation_shape", [(4,), ((4,), (8,))]) 13 | @pytest.mark.parametrize("action_size", [2]) 14 | @pytest.mark.parametrize("length", [100]) 15 | def test_dump_and_load( 16 | observation_shape: Shape, action_size: int, length: int 17 | ) -> None: 18 | episode1 = create_episode(observation_shape, action_size, length) 19 | episode2 = create_episode(observation_shape, action_size, length * 2) 20 | 21 | path = os.path.join("test_data", "data.h5") 22 | 23 | # dump 24 | with open(path, "w+b") as f: 25 | dump([episode1, episode2], f) 26 | 27 | # load 28 | with open(path, "rb") as f: 29 | loaded_episodes = load(Episode, f) 30 | assert len(loaded_episodes) == 2 31 | 32 | for episode, loaded_episode in zip([episode1, episode2], loaded_episodes): 33 | if isinstance(observation_shape[0], tuple): 34 | for i in range(len(observation_shape)): 35 | assert np.all( 36 | episode.observations[i] == loaded_episode.observations[i] 37 | ) 38 | else: 39 | assert np.all(episode.observations == loaded_episode.observations) 40 | assert np.all(episode.actions == loaded_episode.actions) 41 | assert np.all(episode.rewards == loaded_episode.rewards) 42 | assert np.all(episode.terminated == loaded_episode.terminated) 43 | -------------------------------------------------------------------------------- /tests/dummy_env.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import gym 4 | import numpy as np 5 | from gym.spaces import Box, Discrete 6 | 7 | from d3rlpy.types import NDArray 8 | 9 | 10 | class DummyAtari(gym.Env[NDArray, int]): 11 | def __init__(self, grayscale: bool = True, squeeze: bool = False): 12 | if grayscale: 13 | shape = (84, 84) if squeeze else (84, 84, 1) 14 | else: 15 | shape = (84, 84, 3) 16 | self.observation_space = Box( 17 | low=np.zeros(shape), 18 | high=np.zeros(shape) + 255, 19 | dtype=np.uint8, 20 | ) 21 | self.action_space = Discrete(4) 22 | self.t = 1 23 | 24 | def step( 25 | self, action: int 26 | ) -> tuple[NDArray, float, bool, bool, dict[str, Any]]: 27 | observation = self.observation_space.sample() 28 | reward = np.random.random() 29 | self.t += 1  # advance the step counter so the truncation check (t % 80 == 0) can actually trigger 30 | return observation, reward, False, self.t % 80 == 0, {} 31 | 32 | def reset(self, **kwargs: Any) -> tuple[NDArray, dict[str, Any]]: 33 | self.t = 1 34 | return self.observation_space.sample(), {} 35 | -------------------------------------------------------------------------------- /tests/envs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/envs/__init__.py -------------------------------------------------------------------------------- /tests/metrics/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/metrics/__init__.py -------------------------------------------------------------------------------- /tests/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/models/__init__.py -------------------------------------------------------------------------------- /tests/models/test_lr_schedulers.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch.optim import SGD 4 | from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR 5 | 6 | from d3rlpy.optimizers.lr_schedulers import ( 7 | CosineAnnealingLRFactory, 8 | WarmupSchedulerFactory, 9 | ) 10 | 11 | 12 | @pytest.mark.parametrize("module", [torch.nn.Linear(2, 3)]) 13 | def test_warmup_scheduler_factory(module: torch.nn.Module) -> None: 14 | factory = WarmupSchedulerFactory(warmup_steps=1000) 15 | 16 | lr_scheduler = factory.create(SGD(module.parameters(), 1e-4)) 17 | 18 | assert isinstance(lr_scheduler, LambdaLR) 19 | 20 | # check serialization and deserialization 21 | WarmupSchedulerFactory.deserialize(factory.serialize()) 22 | 23 | 24 | @pytest.mark.parametrize("module", [torch.nn.Linear(2, 3)]) 25 | def test_cosine_annealing_lr_factory(module: torch.nn.Module) -> None: 26 | factory = CosineAnnealingLRFactory(T_max=1000) 27 | 28 | lr_scheduler = factory.create(SGD(module.parameters(), 1e-4)) 29 | 30 | assert isinstance(lr_scheduler, CosineAnnealingLR) 31 | 32 | # check serialization and deserialization 33 | CosineAnnealingLRFactory.deserialize(factory.serialize()) 34 | -------------------------------------------------------------------------------- /tests/models/torch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/models/torch/__init__.py -------------------------------------------------------------------------------- /tests/models/torch/q_functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/models/torch/q_functions/__init__.py -------------------------------------------------------------------------------- /tests/models/torch/q_functions/test_utility.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from d3rlpy.models.torch.q_functions.utility import ( 6 | compute_quantile_huber_loss, 7 | pick_quantile_value_by_action, 8 | pick_value_by_action, 9 | ) 10 | 11 | from ..model_test import ref_quantile_huber_loss 12 | 13 | 14 | @pytest.mark.parametrize("batch_size", [32]) 15 | @pytest.mark.parametrize("action_size", [2]) 16 | @pytest.mark.parametrize("keepdims", [True, False]) 17 | def test_pick_value_by_action( 18 | batch_size: int, action_size: int, keepdims: bool 19 | ) -> None: 20 | values = torch.rand(batch_size, action_size) 21 | action = torch.randint(action_size, size=(batch_size,)) 22 | 23 | rets = pick_value_by_action(values, action, keepdims) 24 | 25 | if keepdims: 26 | assert rets.shape == (batch_size, 1) 27 | else: 28 | assert rets.shape == 
(batch_size,) 29 | 30 | rets = rets.view(batch_size, -1) 31 | 32 | for i in range(batch_size): 33 | assert (rets[i] == values[i][action[i]]).all() 34 | 35 | 36 | @pytest.mark.parametrize("batch_size", [32]) 37 | @pytest.mark.parametrize("action_size", [2]) 38 | @pytest.mark.parametrize("n_quantiles", [200]) 39 | @pytest.mark.parametrize("keepdims", [True, False]) 40 | def test_pick_quantile_value_by_action( 41 | batch_size: int, 42 | action_size: int, 43 | n_quantiles: int, 44 | keepdims: bool, 45 | ) -> None: 46 | values = torch.rand(batch_size, action_size, n_quantiles) 47 | action = torch.randint(action_size, size=(batch_size,)) 48 | 49 | rets = pick_quantile_value_by_action(values, action, keepdims) 50 | 51 | if keepdims: 52 | assert rets.shape == (batch_size, 1, n_quantiles) 53 | else: 54 | assert rets.shape == (batch_size, n_quantiles) 55 | 56 | rets = rets.view(batch_size, -1) 57 | 58 | for i in range(batch_size): 59 | assert (rets[i] == values[i][action[i]]).all() 60 | 61 | 62 | @pytest.mark.parametrize("batch_size", [32]) 63 | @pytest.mark.parametrize("n_quantiles", [200]) 64 | def test_compute_quantile_huber_loss(batch_size: int, n_quantiles: int) -> None: 65 | y = np.random.random((batch_size, n_quantiles, 1)) 66 | target = np.random.random((batch_size, 1, n_quantiles)) 67 | taus = np.random.random((1, 1, n_quantiles)) 68 | 69 | ref_loss = ref_quantile_huber_loss(y, target, taus, n_quantiles) 70 | loss = compute_quantile_huber_loss( 71 | torch.tensor(y), torch.tensor(target), torch.tensor(taus) 72 | ) 73 | 74 | assert np.allclose(loss.cpu().detach().numpy(), ref_loss) 75 | -------------------------------------------------------------------------------- /tests/models/torch/test_parameters.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence 2 | 3 | import pytest 4 | import torch 5 | 6 | from d3rlpy.models.torch.parameters import Parameter, get_parameter 7 | 8 | 9 | @pytest.mark.parametrize("shape", [(100,)]) 10 | def test_parameter(shape: Sequence[int]) -> None: 11 | data = torch.rand(shape) 12 | parameter = Parameter(data) 13 | 14 | assert get_parameter(parameter).data.shape == shape 15 | assert torch.all(get_parameter(parameter).data == data) 16 | -------------------------------------------------------------------------------- /tests/models/torch/test_q_functions.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from d3rlpy.models.builders import create_continuous_q_function 5 | from d3rlpy.models.q_functions import ( 6 | MeanQFunctionFactory, 7 | QFunctionFactory, 8 | QRQFunctionFactory, 9 | ) 10 | from d3rlpy.models.torch.q_functions import compute_max_with_n_actions 11 | from d3rlpy.types import Shape 12 | 13 | from ...testing_utils import create_torch_observations 14 | from .model_test import DummyEncoderFactory 15 | 16 | 17 | @pytest.mark.parametrize( 18 | "observation_shape", [(4, 84, 84), (100,), ((100,), (200,))] 19 | ) 20 | @pytest.mark.parametrize("action_size", [3]) 21 | @pytest.mark.parametrize( 22 | "q_func_factory", [MeanQFunctionFactory(), QRQFunctionFactory()] 23 | ) 24 | @pytest.mark.parametrize("n_ensembles", [2]) 25 | @pytest.mark.parametrize("batch_size", [100]) 26 | @pytest.mark.parametrize("n_actions", [10]) 27 | @pytest.mark.parametrize("lam", [0.75]) 28 | def test_compute_max_with_n_actions( 29 | observation_shape: Shape, 30 | action_size: int, 31 | q_func_factory: QFunctionFactory, 32 | n_ensembles: int, 33 | 
batch_size: int, 34 | n_actions: int, 35 | lam: float, 36 | ) -> None: 37 | _, forwarder = create_continuous_q_function( 38 | observation_shape, 39 | action_size, 40 | DummyEncoderFactory(), 41 | q_func_factory, 42 | n_ensembles=n_ensembles, 43 | device="cpu:0", 44 | enable_ddp=False, 45 | ) 46 | x = create_torch_observations(observation_shape, batch_size) 47 | actions = torch.rand(batch_size, n_actions, action_size) 48 | 49 | y = compute_max_with_n_actions(x, actions, forwarder, lam) 50 | 51 | if isinstance(q_func_factory, MeanQFunctionFactory): 52 | assert y.shape == (batch_size, 1) 53 | else: 54 | assert isinstance(q_func_factory, QRQFunctionFactory) 55 | assert y.shape == (batch_size, q_func_factory.n_quantiles) 56 | -------------------------------------------------------------------------------- /tests/models/torch/test_v_functions.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | from d3rlpy.models.torch.v_functions import ( 6 | ValueFunction, 7 | compute_v_function_error, 8 | ) 9 | from d3rlpy.types import Shape 10 | 11 | from ...testing_utils import create_torch_observations 12 | from .model_test import DummyEncoder, check_parameter_updates 13 | 14 | 15 | @pytest.mark.parametrize("observation_shape", [(100,), ((100,), (200,))]) 16 | @pytest.mark.parametrize("batch_size", [32]) 17 | def test_value_function(observation_shape: Shape, batch_size: int) -> None: 18 | encoder = DummyEncoder(observation_shape) 19 | v_func = ValueFunction(encoder, encoder.get_feature_size()) 20 | 21 | # check output shape 22 | x = create_torch_observations(observation_shape, batch_size) 23 | y = v_func(x) 24 | assert y.shape == (batch_size, 1) 25 | 26 | # check compute_error 27 | returns = torch.rand(batch_size, 1) 28 | loss = compute_v_function_error(v_func, x, returns) 29 | assert torch.allclose(loss, F.mse_loss(y, returns)) 30 | 31 | # check layer connections 32 | check_parameter_updates( 33 | v_func, 34 | (x,), 35 | ) 36 | -------------------------------------------------------------------------------- /tests/ope/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/ope/__init__.py -------------------------------------------------------------------------------- /tests/optimizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/optimizers/__init__.py -------------------------------------------------------------------------------- /tests/optimizers/test_lr_schedulers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from torch.optim import SGD 5 | from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR 6 | 7 | from d3rlpy.optimizers.lr_schedulers import ( 8 | CosineAnnealingLRFactory, 9 | WarmupSchedulerFactory, 10 | ) 11 | 12 | 13 | @pytest.mark.parametrize("warmup_steps", [100]) 14 | @pytest.mark.parametrize("lr", [1e-4]) 15 | @pytest.mark.parametrize("module", [torch.nn.Linear(2, 3)]) 16 | def test_warmup_scheduler_factory( 17 | warmup_steps: int, lr: float, module: torch.nn.Module 18 | ) -> None: 19 | factory = WarmupSchedulerFactory(warmup_steps) 20 | 21 | lr_scheduler = factory.create(SGD(module.parameters(), 
lr=lr)) 22 | 23 | assert np.allclose(lr_scheduler.get_lr()[0], lr / warmup_steps) 24 | for _ in range(warmup_steps): 25 | lr_scheduler.step() 26 | assert lr_scheduler.get_lr()[0] == lr 27 | 28 | assert isinstance(lr_scheduler, LambdaLR) 29 | 30 | # check serialization and deserialization 31 | WarmupSchedulerFactory.deserialize(factory.serialize()) 32 | 33 | 34 | @pytest.mark.parametrize("T_max", [100]) 35 | @pytest.mark.parametrize("module", [torch.nn.Linear(2, 3)]) 36 | def test_cosine_annealing_factory(T_max: int, module: torch.nn.Module) -> None: 37 | factory = CosineAnnealingLRFactory(T_max=T_max) 38 | 39 | lr_scheduler = factory.create(SGD(module.parameters())) 40 | 41 | assert isinstance(lr_scheduler, CosineAnnealingLR) 42 | 43 | # check serialization and deserialization 44 | CosineAnnealingLRFactory.deserialize(factory.serialize()) 45 | -------------------------------------------------------------------------------- /tests/preprocessing/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/preprocessing/__init__.py -------------------------------------------------------------------------------- /tests/preprocessing/test_base.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from d3rlpy.preprocessing.base import add_leading_dims, add_leading_dims_numpy 5 | 6 | 7 | def test_add_leading_dims() -> None: 8 | x = torch.rand(3) 9 | target = torch.rand(1, 2, 3) 10 | assert add_leading_dims(x, target).shape == (1, 1, 3) 11 | 12 | 13 | def test_add_leading_dims_numpy() -> None: 14 | x = np.random.random(3) 15 | target = np.random.random((1, 2, 3)) 16 | assert add_leading_dims_numpy(x, target).shape == (1, 1, 3) 17 | -------------------------------------------------------------------------------- /tests/test_dataclass_utils.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | 3 | import torch 4 | 5 | from d3rlpy.dataclass_utils import asdict_as_float, asdict_without_copy 6 | 7 | 8 | @dataclasses.dataclass(frozen=True) 9 | class A: 10 | a: int 11 | 12 | 13 | @dataclasses.dataclass(frozen=True) 14 | class D: 15 | a: A 16 | b: float 17 | c: str 18 | 19 | 20 | def test_asdict_without_any() -> None: 21 | a = A(1) 22 | d = D(a, 2.0, "3") 23 | dict_d = asdict_without_copy(d) 24 | assert dict_d["a"] is a 25 | assert dict_d["b"] == 2.0 26 | assert dict_d["c"] == "3" 27 | 28 | 29 | @dataclasses.dataclass(frozen=True) 30 | class D2: 31 | a: float 32 | b: torch.Tensor 33 | 34 | 35 | def test_asdict_as_float() -> None: 36 | b = torch.rand([], dtype=torch.float32) 37 | d = D2(a=1.0, b=b) 38 | dict_d = asdict_as_float(d) 39 | assert dict_d["a"] == 1.0 40 | assert dict_d["b"] == b.numpy() 41 | -------------------------------------------------------------------------------- /tests/test_datasets.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from d3rlpy.datasets import get_cartpole, get_dataset, get_minari, get_pendulum 4 | 5 | 6 | @pytest.mark.parametrize("dataset_type", ["replay", "random"]) 7 | def test_get_cartpole(dataset_type: str) -> None: 8 | get_cartpole(dataset_type=dataset_type) 9 | 10 | 11 | @pytest.mark.parametrize("dataset_type", ["replay", "random"]) 12 | def test_get_pendulum(dataset_type: str) -> None: 13 | get_pendulum(dataset_type=dataset_type) 14 | 15 | 
16 | @pytest.mark.parametrize( 17 | "env_name", 18 | ["cartpole-random", "pendulum-random"], 19 | ) 20 | def test_get_dataset(env_name: str) -> None: 21 | _, env = get_dataset(env_name) 22 | if env_name == "cartpole-random": 23 | assert env.unwrapped.spec.id == "CartPole-v1" 24 | elif env_name == "pendulum-random": 25 | assert env.unwrapped.spec.id == "Pendulum-v1" 26 | 27 | 28 | @pytest.mark.parametrize( 29 | "dataset_name, env_name", 30 | [ 31 | ("D4RL/door/cloned-v2", "AdroitHandDoor-v1"), 32 | ("D4RL/kitchen/complete-v2", "FrankaKitchen-v1"), 33 | ], 34 | ) 35 | @pytest.mark.parametrize("tuple_observation", [False, True]) 36 | def test_get_minari( 37 | dataset_name: str, env_name: str, tuple_observation: bool 38 | ) -> None: 39 | dataset, env = get_minari(dataset_name, tuple_observation=tuple_observation) 40 | assert env.unwrapped.spec.id == env_name # type: ignore 41 | 42 | if tuple_observation: 43 | # check shape 44 | ep = dataset.episodes[0] 45 | ref_shape0 = ep.observations[0].shape[1:] 46 | ref_shape1 = ep.observations[1].shape[1:] 47 | obs, _ = env.reset() 48 | assert obs[0].shape == ref_shape0 49 | assert obs[1].shape == ref_shape1 50 | obs, _, _, _, _ = env.step(env.action_space.sample()) 51 | assert obs[0].shape == ref_shape0 52 | assert obs[1].shape == ref_shape1 53 | else: 54 | # check shape 55 | ref_shape = dataset.episodes[0].observations.shape[1:] # type: ignore 56 | obs, _ = env.reset() 57 | assert obs.shape == ref_shape 58 | obs, _, _, _, _ = env.step(env.action_space.sample()) 59 | assert obs.shape == ref_shape 60 | -------------------------------------------------------------------------------- /tests/test_itertools.py: -------------------------------------------------------------------------------- 1 | from d3rlpy.itertools import first_flag, last_flag 2 | 3 | 4 | def test_last_flag() -> None: 5 | x = [1, 2, 3, 4, 5] 6 | i = 0 7 | for is_last, value in last_flag(x): 8 | if i == len(x) - 1: 9 | assert is_last 10 | else: 11 | assert not is_last 12 | assert value == x[i] 13 | i += 1 14 | 15 | 16 | def test_first_flag() -> None: 17 | x = [1, 2, 3, 4, 5] 18 | i = 0 19 | for is_first, value in first_flag(x): 20 | if i == 0: 21 | assert is_first 22 | else: 23 | assert not is_first 24 | assert value == x[i] 25 | i += 1 26 | -------------------------------------------------------------------------------- /tests/tokenizers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/takuseno/d3rlpy/4f0956ba8db469ebeeb617bb28060a8401a7dfa6/tests/tokenizers/__init__.py -------------------------------------------------------------------------------- /tests/tokenizers/test_tokenizers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from d3rlpy.tokenizers import FloatTokenizer 4 | from d3rlpy.types import NDArray 5 | 6 | 7 | def test_float_tokenizer() -> None: 8 | tokenizer = FloatTokenizer(num_bins=100, use_mu_law_encode=False) 9 | v: NDArray = (2.0 * np.arange(100) / 100 - 1).astype(np.float32) 10 | tokenized_v = tokenizer(v) 11 | assert np.all(tokenized_v == np.arange(100)) 12 | 13 | # check mu_law_encode 14 | tokenizer = FloatTokenizer(num_bins=100) 15 | v = np.arange(100) - 50 16 | tokenized_v = tokenizer(v) 17 | assert np.all(tokenized_v >= 0) 18 | assert np.all(tokenized_v < 100) 19 | 20 | # check token_offset 21 | tokenizer = FloatTokenizer( 22 | num_bins=100, use_mu_law_encode=False, token_offset=1 23 | ) 24 | v = np.array([-1, 1]) 25 | 
tokenized_v = tokenizer(v) 26 | assert tokenized_v[0] == 1 27 | assert tokenized_v[1] == 100 28 | 29 | # check decode 30 | tokenizer = FloatTokenizer(num_bins=1000000) 31 | v = np.arange(100) - 50 32 | decoded_v = tokenizer.decode(tokenizer(v)) 33 | assert np.allclose(decoded_v, v, atol=1e-3) 34 | 35 | # check decode with multi-dimension 36 | v = np.reshape(v, [5, -1]) 37 | decoded_v = tokenizer.decode(tokenizer(v)) 38 | assert v.shape == decoded_v.shape 39 | assert np.allclose(decoded_v, v, atol=1e-3) 40 | -------------------------------------------------------------------------------- /tests/tokenizers/test_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from d3rlpy.tokenizers import mu_law_decode, mu_law_encode 4 | from d3rlpy.types import NDArray 5 | 6 | 7 | def test_mu_law_encode() -> None: 8 | v = np.arange(100) - 50 9 | encoded_v = mu_law_encode(v, mu=100, basis=256) 10 | assert np.all(encoded_v < 1) 11 | assert np.all(-1 < encoded_v) 12 | 13 | 14 | def test_mu_law_decode() -> None: 15 | v: NDArray = np.array(np.arange(100) - 50, dtype=np.float32) 16 | decoded_v = mu_law_decode( 17 | mu_law_encode(v, mu=100, basis=256), mu=100, basis=256 18 | ) 19 | assert np.allclose(decoded_v, v) 20 | --------------------------------------------------------------------------------