├── .github └── workflows │ ├── doc.yml │ ├── release.yml │ └── unit_test.yml ├── .gitignore ├── .style.yapf ├── CHANGELOG ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── README.zh.md ├── assets └── framework.png ├── docs ├── Makefile └── source │ ├── _static │ ├── css │ │ └── style.css │ └── images │ │ └── logo.png │ ├── _templates │ └── layout.html │ ├── all.mk │ ├── api_doc │ ├── agents │ │ └── index.rst │ ├── algorithms │ │ └── index.rst │ ├── datasets │ │ └── index.rst │ ├── generative_models │ │ └── index.rst │ ├── neural_network │ │ └── index.rst │ ├── numerical_methods │ │ └── index.rst │ ├── rl_modules │ │ └── index.rst │ └── utils │ │ └── index.rst │ ├── concepts │ └── index.rst │ ├── conf.py │ ├── diagrams.mk │ ├── graphviz.mk │ ├── index.rst │ ├── tutorials │ ├── installation │ │ └── index.rst │ └── quick_start │ │ └── index.rst │ └── user_guide │ ├── evaluating_agents.rst │ ├── index.rst │ ├── installation.rst │ ├── training_agents.rst │ └── training_generative_models.rst ├── grl ├── __init__.py ├── agents │ ├── __init__.py │ ├── base.py │ ├── gm.py │ ├── idql.py │ ├── qgpo.py │ └── srpo.py ├── algorithms │ ├── __init__.py │ ├── base.py │ ├── gmpg.py │ ├── gmpo.py │ ├── idql.py │ ├── qgpo.py │ └── srpo.py ├── datasets │ ├── __init__.py │ ├── d4rl.py │ ├── gp.py │ ├── minari_dataset.py │ └── qgpo.py ├── generative_models │ ├── __init__.py │ ├── bridge_flow_model │ │ ├── function.py │ │ ├── guided_bridge_conditional_flow_model.py │ │ └── schrodinger_bridge_conditional_flow_model.py │ ├── conditional_flow_model │ │ ├── __init__.py │ │ ├── guided_conditional_flow_model.py │ │ ├── independent_conditional_flow_model.py │ │ └── optimal_transport_conditional_flow_model.py │ ├── diffusion_model │ │ ├── __init__.py │ │ ├── diffusion_model.py │ │ ├── energy_conditional_diffusion_model.py │ │ └── guided_diffusion_model.py │ ├── diffusion_process.py │ ├── discrete_model │ │ ├── __init__.py │ │ └── discrete_flow_matching.py │ ├── intrinsic_model.py │ ├── metric.py │ ├── model_functions │ │ ├── __init__.py │ │ ├── data_prediction_function.py │ │ ├── noise_function.py │ │ ├── score_function.py │ │ └── velocity_function.py │ ├── normalizing_flow │ │ └── flow.py │ ├── random_generator.py │ ├── sro.py │ ├── stochastic_process.py │ └── variational_autoencoder.py ├── neural_network │ ├── __init__.py │ ├── activation.py │ ├── encoders.py │ ├── neural_operator │ │ ├── __init__.py │ │ └── fourier_neural_operator.py │ ├── residual_network.py │ ├── transformers │ │ ├── __init__.py │ │ ├── dit.py │ │ ├── maxvit.py │ │ └── uvit.py │ └── unet │ │ ├── __init__.py │ │ └── unet_2D.py ├── numerical_methods │ ├── __init__.py │ ├── monte_carlo.py │ ├── numerical_solvers │ │ ├── __init__.py │ │ ├── dpm_solver.py │ │ ├── ode_solver.py │ │ └── sde_solver.py │ ├── ode.py │ ├── probability_path.py │ └── sde.py ├── rl_modules │ ├── __init__.py │ ├── policy │ │ ├── __init__.py │ │ └── base.py │ ├── replay_buffer │ │ ├── __init__.py │ │ └── buffer_by_torchrl.py │ ├── simulators │ │ ├── __init__.py │ │ ├── base.py │ │ ├── dm_control_env_simulator.py │ │ └── gym_env_simulator.py │ ├── value_network │ │ ├── __init__.py │ │ ├── one_shot_value_function.py │ │ ├── q_network.py │ │ └── value_network.py │ └── world_model │ │ ├── dynamic_model.py │ │ └── state_prior_dynamic_model.py ├── unittest │ ├── agents │ │ └── functions.py │ ├── generative_models │ │ ├── diffusion_model │ │ │ └── test_diffusion_model.py │ │ └── test_random_generator.py │ ├── neural_network │ │ ├── test_activation.py │ │ └── test_encoder.py │ ├── 
rl_modules │ │ └── replay_buffer │ │ │ └── test_buffer_by_torchrl.py │ └── utils │ │ ├── test_model_utils.py │ │ └── test_plot.py └── utils │ ├── __init__.py │ ├── config.py │ ├── huggingface.py │ ├── log.py │ ├── model_utils.py │ ├── plot.py │ └── statistics.py ├── grl_pipelines ├── __init__.py ├── base.py ├── benchmark │ ├── README.md │ ├── README.zh.md │ ├── generative_policy.png │ ├── gmpg │ │ ├── gvp │ │ │ ├── __init__.py │ │ │ ├── dm_control_suit_cartpole_swing.py │ │ │ ├── dm_control_suit_cheetah_run.py │ │ │ ├── dm_control_suit_finger_turn_hard.py │ │ │ ├── dm_control_suit_fish_swim.py │ │ │ ├── dm_control_suit_humanoid_run.py │ │ │ ├── dm_control_suit_manipulator_insert_ball.py │ │ │ ├── dm_control_suit_manipulator_insert_peg.py │ │ │ ├── dm_control_suit_rodent_gaps.py │ │ │ ├── dm_control_suit_walker_stand.py │ │ │ ├── dm_control_suit_walker_walk.py │ │ │ ├── dmcontrol_suit_humanoid_run.py │ │ │ ├── halfcheetah_medium.py │ │ │ ├── halfcheetah_medium_expert.py │ │ │ ├── halfcheetah_medium_replay.py │ │ │ ├── hopper_medium.py │ │ │ ├── hopper_medium_expert.py │ │ │ ├── hopper_medium_replay.py │ │ │ ├── walker2d_medium.py │ │ │ ├── walker2d_medium_expert.py │ │ │ └── walker2d_medium_replay.py │ │ ├── icfm │ │ │ ├── __init__.py │ │ │ ├── halfcheetah_medium.py │ │ │ ├── halfcheetah_medium_expert.py │ │ │ ├── halfcheetah_medium_replay.py │ │ │ ├── hopper_medium.py │ │ │ ├── hopper_medium_expert.py │ │ │ ├── hopper_medium_replay.py │ │ │ ├── walker2d_medium.py │ │ │ ├── walker2d_medium_expert.py │ │ │ └── walker2d_medium_replay.py │ │ └── vpsde │ │ │ ├── __init__.py │ │ │ ├── antmaze_large_diverse.py │ │ │ ├── antmaze_large_play.py │ │ │ ├── antmaze_medium_diverse.py │ │ │ ├── antmaze_medium_play.py │ │ │ ├── antmaze_umaze.py │ │ │ ├── antmaze_umaze_diverse.py │ │ │ ├── halfcheetah_medium.py │ │ │ ├── halfcheetah_medium_expert.py │ │ │ ├── halfcheetah_medium_replay.py │ │ │ ├── hopper_medium.py │ │ │ ├── hopper_medium_expert.py │ │ │ ├── hopper_medium_replay.py │ │ │ ├── walker2d_medium.py │ │ │ ├── walker2d_medium_expert.py │ │ │ └── walker2d_medium_replay.py │ ├── gmpo │ │ ├── gvp │ │ │ ├── __init__.py │ │ │ ├── antmaze_large_diverse.py │ │ │ ├── antmaze_large_play.py │ │ │ ├── antmaze_medium_diverse.py │ │ │ ├── antmaze_medium_play.py │ │ │ ├── antmaze_umaze.py │ │ │ ├── antmaze_umaze_diverse.py │ │ │ ├── dm_control_suit_cartpole_swing.py │ │ │ ├── dm_control_suit_cheetah_run.py │ │ │ ├── dm_control_suit_finger_turn_hard.py │ │ │ ├── dm_control_suit_fish_swim.py │ │ │ ├── dm_control_suit_humanoid_run.py │ │ │ ├── dm_control_suit_manipulator_insert_ball.py │ │ │ ├── dm_control_suit_manipulator_insert_peg.py │ │ │ ├── dm_control_suit_walker_stand.py │ │ │ ├── dm_control_suit_walker_walk.py │ │ │ ├── halfcheetah_medium.py │ │ │ ├── halfcheetah_medium_expert.py │ │ │ ├── halfcheetah_medium_replay.py │ │ │ ├── hopper_medium.py │ │ │ ├── hopper_medium_expert.py │ │ │ ├── hopper_medium_replay.py │ │ │ ├── walker2d_medium.py │ │ │ ├── walker2d_medium_expert.py │ │ │ └── walker2d_medium_replay.py │ │ ├── icfm │ │ │ ├── __init__.py │ │ │ ├── halfcheetah_medium.py │ │ │ ├── halfcheetah_medium_expert.py │ │ │ ├── halfcheetah_medium_replay.py │ │ │ ├── hopper_medium.py │ │ │ ├── hopper_medium_expert.py │ │ │ ├── hopper_medium_replay.py │ │ │ ├── walker2d_medium.py │ │ │ ├── walker2d_medium_expert.py │ │ │ └── walker2d_medium_replay.py │ │ └── vpsde │ │ │ ├── __init__.py │ │ │ ├── halfcheetah_medium.py │ │ │ ├── halfcheetah_medium_expert.py │ │ │ ├── halfcheetah_medium_replay.py │ │ │ ├── 
hopper_medium.py │ │ │ ├── hopper_medium_expert.py │ │ │ ├── hopper_medium_replay.py │ │ │ ├── walker2d_medium.py │ │ │ ├── walker2d_medium_expert.py │ │ │ └── walker2d_medium_replay.py │ ├── idql │ │ └── vpsde │ │ │ ├── dm_control_suit_cartpole_swingup.py │ │ │ ├── dm_control_suit_cheetah_run.py │ │ │ ├── dm_control_suit_finger_turn_hard.py │ │ │ ├── dm_control_suit_fish_swim.py │ │ │ ├── dm_control_suit_humanoid_run.py │ │ │ ├── dm_control_suit_manipulator_insert_ball.py │ │ │ ├── dm_control_suit_manipulator_insert_peg.py │ │ │ ├── dm_control_suit_walker_stand.py │ │ │ ├── dm_control_suit_walker_walk.py │ │ │ ├── halfcheetah_medium.py │ │ │ ├── halfcheetah_medium_expert.py │ │ │ ├── halfcheetah_medium_replay.py │ │ │ ├── hopper_medium.py │ │ │ ├── hopper_medium_expert.py │ │ │ ├── hopper_medium_replay.py │ │ │ ├── walker2d_medium.py │ │ │ ├── walker2d_medium_expert.py │ │ │ └── walker2d_medium_replay.py │ └── srpo │ │ └── vpsde │ │ ├── dm_control_suit_cartpole_swingup.py │ │ ├── dm_control_suit_cheetah_run.py │ │ ├── dm_control_suit_finger_turn_hard.py │ │ ├── dm_control_suit_fish_swim.py │ │ ├── dm_control_suit_manipulator_insert_ball.py │ │ ├── dm_control_suit_manipulator_insert_peg.py │ │ ├── dm_control_suit_walker_stand.py │ │ ├── dm_control_suit_walker_walk.py │ │ ├── halfcheetah_medium.py │ │ ├── halfcheetah_medium_expert.py │ │ ├── halfcheetah_medium_replay.py │ │ ├── hopper_medium.py │ │ ├── hopper_medium_expert.py │ │ ├── hopper_medium_replay.py │ │ ├── walker2d_medium.py │ │ ├── walker2d_medium_expert.py │ │ └── walker2d_medium_replay.py ├── diffusion_model │ ├── __init__.py │ ├── configurations │ │ ├── __init__.py │ │ ├── d4rl_halfcheetah_qgpo.py │ │ ├── d4rl_walker2d_qgpo.py │ │ └── lunarlander_continuous_qgpo.py │ ├── d4rl_halfcheetah_qgpo.py │ ├── d4rl_walker2d_qgpo.py │ └── lunarlander_continuous_qgpo.py └── tutorials │ ├── README.md │ ├── README.zh.md │ ├── applications │ └── swiss_roll_world_model.py │ ├── generative_models │ ├── swiss_roll_diffusion.py │ ├── swiss_roll_discrete_flow_model.py │ ├── swiss_roll_energy_condition.py │ ├── swiss_roll_icfm.py │ ├── swiss_roll_icfm_with_mask.py │ ├── swiss_roll_otcfm.py │ └── swiss_roll_sf2m.py │ ├── huggingface │ ├── lunarlander_continuous_qgpo_huggingface_pull.py │ ├── lunarlander_continuous_qgpo_huggingface_push.py │ └── modelcard_template.md │ ├── metrics │ └── swiss_roll_likelihood.py │ ├── solvers │ ├── swiss_roll_dpmsolver.py │ └── swiss_roll_sdesolver.py │ └── special_usages │ ├── customized_modules.py │ └── dict_tensor_ode.py ├── requirements-doc.txt └── setup.py /.github/workflows/doc.yml: -------------------------------------------------------------------------------- 1 | # This workflow will check flake style 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: Docs Deploy 5 | 6 | on: 7 | push: 8 | branches: [ main, 'doc/*', 'dev/*' ] 9 | release: 10 | types: [ published ] 11 | 12 | jobs: 13 | doc: 14 | runs-on: ubuntu-latest 15 | permissions: 16 | contents: write # Allows writing to the repository 17 | strategy: 18 | matrix: 19 | python-version: [ 3.9 ] 20 | 21 | services: 22 | plantuml: 23 | image: plantuml/plantuml-server 24 | ports: 25 | - 18080:8080 26 | 27 | steps: 28 | - uses: actions/checkout@v4 29 | - name: Set up Python ${{ matrix.python-version }} 30 | uses: actions/setup-python@v5 31 | with: 32 | python-version: ${{ matrix.python-version }} 33 | - name: Install dependencies 34 | run: | 35 | sudo apt-get update -y 36 | 
sudo apt-get install -y make wget curl cloc graphviz 37 | dot -V 38 | python -m pip install -r requirements-doc.txt 39 | python -m pip install . 40 | - name: Generate 41 | env: 42 | ENV_PROD: 'true' 43 | PLANTUML_HOST: http://localhost:18080 44 | run: | 45 | cd docs 46 | make html 47 | mv build/html ../public 48 | - name: Deploy to Github Page 49 | uses: JamesIves/github-pages-deploy-action@v4 50 | with: 51 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 52 | BRANCH: gh-pages # The branch the action should deploy to. 53 | FOLDER: public # The folder the action should deploy. 54 | CLEAN: true # Automatically remove deleted files from the deploy branch 55 | 56 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | # Trigger this workflow when a new tag is pushed 5 | push: 6 | tags: 7 | - 'v*' 8 | 9 | jobs: 10 | build: 11 | name: Build distribution 📦 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v4 16 | - name: Set up Python 17 | uses: actions/setup-python@v5 18 | with: 19 | python-version: "3.x" 20 | - name: Install pypa/build 21 | run: >- 22 | python3 -m 23 | pip install 24 | build 25 | --user 26 | - name: Build a binary wheel and a source tarball 27 | run: python3 -m build 28 | - name: Store the distribution packages 29 | uses: actions/upload-artifact@v4 30 | with: 31 | name: python-package-distributions 32 | path: dist/ 33 | 34 | publish-to-pypi: 35 | name: >- 36 | Publish Python 🐍 distribution 📦 to PyPI 37 | if: startsWith(github.ref, 'refs/tags/') # only publish to PyPI on tag pushes 38 | needs: 39 | - build 40 | runs-on: ubuntu-latest 41 | environment: 42 | name: pypi 43 | url: https://pypi.org/p/GenerativeRL # Replace with your PyPI project name 44 | permissions: 45 | id-token: write # IMPORTANT: mandatory for trusted publishing 46 | steps: 47 | - name: Download all the dists 48 | uses: actions/download-artifact@v4 49 | with: 50 | name: python-package-distributions 51 | path: dist/ 52 | - name: Publish distribution 📦 to PyPI 53 | uses: pypa/gh-action-pypi-publish@release/v1 54 | 55 | github-release: 56 | name: >- 57 | Sign the Python 🐍 distribution 📦 with Sigstore 58 | and upload them to GitHub Release 59 | needs: 60 | - publish-to-pypi 61 | runs-on: ubuntu-latest 62 | 63 | permissions: 64 | contents: write # IMPORTANT: mandatory for making GitHub Releases 65 | id-token: write # IMPORTANT: mandatory for sigstore 66 | 67 | steps: 68 | - name: Download all the dists 69 | uses: actions/download-artifact@v4 70 | with: 71 | name: python-package-distributions 72 | path: dist/ 73 | - name: Sign the dists with Sigstore 74 | uses: sigstore/gh-action-sigstore-python@v3.0.0 75 | with: 76 | inputs: >- 77 | ./dist/*.tar.gz 78 | ./dist/*.whl 79 | - name: Create GitHub Release 80 | env: 81 | GITHUB_TOKEN: ${{ github.token }} 82 | run: >- 83 | gh release create 84 | '${{ github.ref_name }}' 85 | --repo '${{ github.repository }}' 86 | --notes "" 87 | - name: Upload artifact signatures to GitHub Release 88 | env: 89 | GITHUB_TOKEN: ${{ github.token }} 90 | # Upload to GitHub Release using the `gh` CLI. 91 | # `dist/` contains the built packages, and the 92 | # sigstore-produced signatures and certificates. 
93 | run: >- 94 | gh release upload 95 | '${{ github.ref_name }}' dist/** 96 | --repo '${{ github.repository }}' 97 | -------------------------------------------------------------------------------- /.github/workflows/unit_test.yml: -------------------------------------------------------------------------------- 1 | # This workflow will check pytest 2 | # For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions 3 | 4 | name: unit_test 5 | 6 | on: [push, pull_request] 7 | 8 | jobs: 9 | test_unittest: 10 | runs-on: ubuntu-latest 11 | if: ( !contains(github.event.head_commit.message, 'ci skip') && !contains(github.event.head_commit.message, 'ut skip')) 12 | strategy: 13 | matrix: 14 | python-version: ["3.9", "3.10", "3.11", "3.12"] 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v5 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: do_unittest 23 | timeout-minutes: 40 24 | run: | 25 | python -m pip install --upgrade pip 26 | python -m pip install . 27 | python -m pip install pytest pytest-cov 28 | pytest ./grl/unittest --doctest-modules --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html 29 | - name: Upload coverage to Codecov 30 | uses: codecov/codecov-action@v1 31 | with: 32 | token: ${{ secrets.CODECOV_TOKEN }} 33 | file: ./coverage.xml 34 | flags: unittests 35 | name: codecov-umbrella 36 | fail_ci_if_error: false 37 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 161 | #.idea/ 162 | *.pt 163 | wandb/ 164 | .vscode/ 165 | *.mp4 166 | *.npy 167 | *.png 168 | *.npz -------------------------------------------------------------------------------- /.style.yapf: -------------------------------------------------------------------------------- 1 | [style] 2 | # For explanation and more information: https://github.com/google/yapf 3 | BASED_ON_STYLE=pep8 4 | DEDENT_CLOSING_BRACKETS=True 5 | SPLIT_BEFORE_FIRST_ARGUMENT=True 6 | ALLOW_SPLIT_BEFORE_DICT_VALUE=False 7 | JOIN_MULTIPLE_LINES=False 8 | BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF=True 9 | BLANK_LINES_AROUND_TOP_LEVEL_DEFINITION=2 10 | SPACES_AROUND_POWER_OPERATOR=True -------------------------------------------------------------------------------- /CHANGELOG: -------------------------------------------------------------------------------- 1 | # Changelog 2 | 3 | All notable changes to this project will be documented in this file. 4 | 5 | The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), 6 | and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 7 | 8 | ## [Unreleased] 9 | ### Added 10 | - New feature for better integration with generative models. 11 | 12 | ## [0.0.1] - 2024-10-31 13 | ### Added 14 | - Initial release of `GenerativeRL` supporting Generative Models for RL. 
-------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | [Git Guide](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/git_guide.html) 2 | 3 | [GitHub Cooperation Guide](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/issue_pr.html) 4 | 5 | - [Code Style](https://di-engine-docs.readthedocs.io/en/latest/21_code_style/index.html) 6 | - [Unit Test](https://di-engine-docs.readthedocs.io/en/latest/22_test/index.html) 7 | - [Code Review](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/issue_pr.html#pr-s-code-review) -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/GenerativeRL/77ed957d89843e823e5780399959b41f965ebca7/assets/framework.png -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # You can set these variables from the command line, and also 2 | # from the environment for the first two. 3 | PROJ_DIR ?= ${CURDIR} 4 | SPHINXOPTS ?= 5 | SPHINXBUILD ?= $(shell which sphinx-build) 6 | SPHINXMULTIVERSION ?= $(shell which sphinx-multiversion) 7 | SOURCEDIR ?= ${PROJ_DIR}/source 8 | BUILDDIR ?= ${PROJ_DIR}/build 9 | 10 | # Minimal makefile for Sphinx documentation 11 | ALL_MK := ${SOURCEDIR}/all.mk 12 | ALL := $(MAKE) -f "${ALL_MK}" SOURCE=${SOURCEDIR} 13 | 14 | .EXPORT_ALL_VARIABLES: 15 | 16 | NO_CONTENTS_BUILD = true 17 | 18 | # Catch-all target: route all unknown targets to Sphinx using the new 19 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 20 | # Put it first so that "make" without argument is like "make help". 
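# Example local usage (illustrative, based on the targets defined below):
#   make html   # regenerates the PlantUML/Graphviz contents via all.mk, then builds the Sphinx HTML into ${BUILDDIR}/html
#   make clean  # removes both the generated contents and the Sphinx build directory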
21 | .PHONY: help contents build html prod clean sourcedir builddir Makefile 22 | 23 | help: 24 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 25 | 26 | contents: 27 | @$(ALL) build 28 | build: html 29 | html: contents 30 | @$(SPHINXBUILD) -M html "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 31 | @touch "$(BUILDDIR)/html/.nojekyll" 32 | prod: 33 | @NO_CONTENTS_BUILD='' $(SPHINXMULTIVERSION) "$(SOURCEDIR)" "$(BUILDDIR)/html" $(SPHINXOPTS) $(O) 34 | @cp main_page.html "$(BUILDDIR)/html/index.html" 35 | @touch "$(BUILDDIR)/html/.nojekyll" 36 | 37 | clean: 38 | @$(ALL) clean 39 | @$(SPHINXBUILD) -M clean "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 40 | 41 | sourcedir: 42 | @echo $(shell readlink -f ${SOURCEDIR}) 43 | builddir: 44 | @echo $(shell readlink -f ${BUILDDIR}/html) 45 | -------------------------------------------------------------------------------- /docs/source/_static/css/style.css: -------------------------------------------------------------------------------- 1 | .header-logo { 2 | background-image: url("../images/logo.png"); 3 | background-size: 180px 40px; 4 | height: 40px; 5 | width: 180px; 6 | } 7 | -------------------------------------------------------------------------------- /docs/source/_static/images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/GenerativeRL/77ed957d89843e823e5780399959b41f965ebca7/docs/source/_static/images/logo.png -------------------------------------------------------------------------------- /docs/source/_templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | {% block extrahead %} 3 | 4 | {% endblock %} 5 | -------------------------------------------------------------------------------- /docs/source/all.mk: -------------------------------------------------------------------------------- 1 | PIP := $(shell which pip) 2 | 3 | SPHINXOPTS ?= 4 | SPHINXBUILD ?= $(shell which sphinx-build) 5 | SPHINXMULTIVERSION ?= $(shell which sphinx-multiversion) 6 | SOURCEDIR ?= $(shell readlink -f ${CURDIR}) 7 | BUILDDIR ?= $(shell readlink -f ${CURDIR}/../build) 8 | 9 | DIAGRAMS_MK := ${SOURCEDIR}/diagrams.mk 10 | DIAGRAMS := $(MAKE) -f "${DIAGRAMS_MK}" SOURCE=${SOURCEDIR} 11 | GRAPHVIZ_MK := ${SOURCEDIR}/graphviz.mk 12 | GRAPHVIZ := $(MAKE) -f "${GRAPHVIZ_MK}" SOURCE=${SOURCEDIR} 13 | 14 | _CURRENT_PATH := ${PATH} 15 | _PROJ_DIR := $(shell readlink -f ${SOURCEDIR}/../..) 16 | _LIBS_DIR := $(shell readlink -f ${SOURCEDIR}/_libs) 17 | _SHIMS_DIR := $(shell readlink -f ${SOURCEDIR}/_shims) 18 | 19 | .EXPORT_ALL_VARIABLES: 20 | 21 | PYTHONPATH = ${_PROJ_DIR}:${_LIBS_DIR} 22 | PATH = ${_SHIMS_DIR}:${_CURRENT_PATH} 23 | 24 | .PHONY: all build clean pip 25 | 26 | pip: 27 | @$(PIP) install -r ${_PROJ_DIR}/requirements-doc.txt 28 | 29 | build: 30 | @$(DIAGRAMS) build 31 | @$(GRAPHVIZ) build 32 | 33 | all: build 34 | 35 | clean: 36 | @$(DIAGRAMS) clean 37 | @$(GRAPHVIZ) clean -------------------------------------------------------------------------------- /docs/source/api_doc/agents/index.rst: -------------------------------------------------------------------------------- 1 | grl.agents 2 | ===================== 3 | 4 | .. currentmodule:: grl.agents 5 | 6 | .. automodule:: grl.agents 7 | 8 | 9 | QGPOAgent 10 | ---------- 11 | 12 | .. autoclass:: QGPOAgent 13 | :special-members: __init__ 14 | :members: 15 | 16 | SRPOAgent 17 | ---------- 18 | 19 | .. 
autoclass:: SRPOAgent 20 | :special-members: __init__ 21 | :members: 22 | 23 | GPAgent 24 | ---------- 25 | 26 | .. autoclass:: GPAgent 27 | :special-members: __init__ 28 | :members: 29 | -------------------------------------------------------------------------------- /docs/source/api_doc/algorithms/index.rst: -------------------------------------------------------------------------------- 1 | grl.algorithms 2 | ===================== 3 | 4 | .. currentmodule:: grl.algorithms 5 | 6 | .. automodule:: grl.algorithms 7 | 8 | 9 | QGPOCritic 10 | ------------- 11 | .. autoclass:: QGPOCritic 12 | :special-members: __init__ 13 | :members: 14 | 15 | QGPOPolicy 16 | ------------- 17 | .. autoclass:: QGPOPolicy 18 | :special-members: __init__ 19 | :members: 20 | 21 | QGPOAlgorithm 22 | ------------- 23 | .. autoclass:: QGPOAlgorithm 24 | :special-members: __init__ 25 | :members: 26 | 27 | SRPOCritic 28 | ------------- 29 | .. autoclass:: SRPOCritic 30 | :special-members: __init__ 31 | :members: 32 | 33 | SRPOPolicy 34 | ------------- 35 | .. autoclass:: SRPOPolicy 36 | :special-members: __init__ 37 | :members: 38 | 39 | SRPOAlgorithm 40 | ------------- 41 | .. autoclass:: SRPOAlgorithm 42 | :special-members: __init__ 43 | :members: 44 | 45 | GMPOCritic 46 | ------------- 47 | .. autoclass:: GMPOCritic 48 | :special-members: __init__ 49 | :members: 50 | 51 | GMPOPolicy 52 | ------------- 53 | .. autoclass:: GMPOPolicy 54 | :special-members: __init__ 55 | :members: 56 | 57 | GMPOAlgorithm 58 | ------------- 59 | .. autoclass:: GMPOAlgorithm 60 | :special-members: __init__ 61 | :members: 62 | 63 | GMPGCritic 64 | ------------- 65 | .. autoclass:: GMPGCritic 66 | :special-members: __init__ 67 | :members: 68 | 69 | GMPGPolicy 70 | ------------- 71 | .. autoclass:: GMPGPolicy 72 | :special-members: __init__ 73 | :members: 74 | 75 | GMPGAlgorithm 76 | ------------- 77 | .. autoclass:: GMPGAlgorithm 78 | :special-members: __init__ 79 | :members: 80 | -------------------------------------------------------------------------------- /docs/source/api_doc/datasets/index.rst: -------------------------------------------------------------------------------- 1 | grl.datasets 2 | ===================== 3 | 4 | .. currentmodule:: grl.datasets 5 | 6 | .. automodule:: grl.datasets 7 | 8 | 9 | QGPOD4RLDataset 10 | ------------------ 11 | 12 | .. autoclass:: QGPOD4RLDataset 13 | :special-members: __init__ 14 | :members: 15 | 16 | QGPODataset 17 | ------------------ 18 | 19 | .. autoclass:: QGPODataset 20 | :special-members: __init__ 21 | :members: 22 | 23 | GPD4RLDataset 24 | ------------------ 25 | 26 | .. autoclass:: GPD4RLDataset 27 | :special-members: __init__ 28 | :members: 29 | 30 | GPDataset 31 | ------------------ 32 | 33 | .. autoclass:: GPDataset 34 | :special-members: __init__ 35 | :members: 36 | 37 | -------------------------------------------------------------------------------- /docs/source/api_doc/generative_models/index.rst: -------------------------------------------------------------------------------- 1 | grl.generative_models 2 | ===================== 3 | 4 | .. currentmodule:: grl.generative_models 5 | 6 | .. automodule:: grl.generative_models 7 | 8 | DiffusionModel 9 | -------------- 10 | .. autoclass:: DiffusionModel 11 | :special-members: __init__ 12 | :members: 13 | 14 | EnergyConditionalDiffusionModel 15 | ------------------------------- 16 | .. 
autoclass:: EnergyConditionalDiffusionModel 17 | :special-members: __init__ 18 | :members: 19 | 20 | IndependentConditionalFlowModel 21 | ------------------------------- 22 | .. autoclass:: IndependentConditionalFlowModel 23 | :special-members: __init__ 24 | :members: 25 | 26 | OptimalTransportConditionalFlowModel 27 | ------------------------------------ 28 | .. autoclass:: OptimalTransportConditionalFlowModel 29 | :special-members: __init__ 30 | :members: 31 | -------------------------------------------------------------------------------- /docs/source/api_doc/neural_network/index.rst: -------------------------------------------------------------------------------- 1 | grl.neural_network 2 | ===================== 3 | 4 | .. currentmodule:: grl.neural_network 5 | 6 | .. automodule:: grl.neural_network 7 | 8 | ConcatenateLayer 9 | ----------------- 10 | .. autoclass:: ConcatenateLayer 11 | :special-members: __init__ 12 | :members: 13 | 14 | MultiLayerPerceptron 15 | --------------------- 16 | .. autoclass:: MultiLayerPerceptron 17 | :special-members: __init__ 18 | :members: 19 | 20 | ConcatenateMLP 21 | -------------- 22 | .. autoclass:: ConcatenateMLP 23 | :special-members: __init__ 24 | :members: 25 | 26 | TemporalSpatialResidualNet 27 | --------------------------- 28 | .. autoclass:: TemporalSpatialResidualNet 29 | :special-members: __init__ 30 | :members: 31 | 32 | DiT 33 | --- 34 | .. autoclass:: DiT 35 | :special-members: __init__ 36 | :members: 37 | 38 | DiT1D 39 | ------ 40 | .. autoclass:: DiT1D 41 | :special-members: __init__ 42 | :members: 43 | 44 | DiT2D 45 | ------ 46 | .. autoclass:: DiT2D 47 | :special-members: __init__ 48 | :members: 49 | 50 | 51 | DiT3D 52 | ------ 53 | .. autoclass:: DiT3D 54 | :special-members: __init__ 55 | :members: 56 | -------------------------------------------------------------------------------- /docs/source/api_doc/numerical_methods/index.rst: -------------------------------------------------------------------------------- 1 | grl.numerical_methods 2 | ===================== 3 | 4 | .. currentmodule:: grl.numerical_methods 5 | 6 | .. automodule:: grl.numerical_methods 7 | 8 | ODE 9 | --- 10 | .. autoclass:: ODE 11 | :special-members: __init__ 12 | :members: 13 | 14 | SDE 15 | --- 16 | .. autoclass:: SDE 17 | :special-members: __init__ 18 | :members: 19 | 20 | DPMSolver 21 | --------- 22 | .. autoclass:: DPMSolver 23 | :special-members: __init__ 24 | :members: 25 | 26 | ODESolver 27 | --------- 28 | .. autoclass:: ODESolver 29 | :special-members: __init__ 30 | :members: 31 | 32 | SDESolver 33 | --------- 34 | .. autoclass:: SDESolver 35 | :special-members: __init__ 36 | :members: 37 | 38 | GaussianConditionalProbabilityPath 39 | ----------------------------------- 40 | .. autoclass:: GaussianConditionalProbabilityPath 41 | :special-members: __init__ 42 | :members: 43 | -------------------------------------------------------------------------------- /docs/source/api_doc/rl_modules/index.rst: -------------------------------------------------------------------------------- 1 | grl.rl_modules 2 | ===================== 3 | 4 | .. currentmodule:: grl.rl_modules 5 | 6 | .. automodule:: grl.rl_modules 7 | 8 | 9 | GymEnvSimulator 10 | --------------------- 11 | .. autoclass:: GymEnvSimulator 12 | :special-members: __init__ 13 | :members: 14 | 15 | OneShotValueFunction 16 | --------------------- 17 | .. autoclass:: OneShotValueFunction 18 | :special-members: __init__ 19 | :members: 20 | 21 | VNetwork 22 | --------------------- 23 | .. 
autoclass:: VNetwork 24 | :special-members: __init__ 25 | :members: 26 | 27 | DoubleVNetwork 28 | --------------------- 29 | .. autoclass:: DoubleVNetwork 30 | :special-members: __init__ 31 | :members: 32 | 33 | QNetwork 34 | --------------------- 35 | .. autoclass:: QNetwork 36 | :special-members: __init__ 37 | :members: 38 | 39 | DoubleQNetwork 40 | --------------------- 41 | .. autoclass:: DoubleQNetwork 42 | :special-members: __init__ 43 | :members: 44 | -------------------------------------------------------------------------------- /docs/source/api_doc/utils/index.rst: -------------------------------------------------------------------------------- 1 | grl.utils 2 | ===================== 3 | 4 | .. currentmodule:: grl.utils 5 | 6 | .. automodule:: grl.utils 7 | 8 | 9 | set_seed 10 | -------- 11 | .. autofunction:: grl.utils.set_seed 12 | -------------------------------------------------------------------------------- /docs/source/diagrams.mk: -------------------------------------------------------------------------------- 1 | PLANTUMLCLI ?= $(shell which plantumlcli) 2 | 3 | SOURCE ?= . 4 | PUMLS := $(shell find ${SOURCE} -name *.puml) 5 | PNGS := $(addsuffix .puml.png, $(basename ${PUMLS})) 6 | SVGS := $(addsuffix .puml.svg, $(basename ${PUMLS})) 7 | 8 | %.puml.png: %.puml 9 | $(PLANTUMLCLI) -t png -o "$(shell readlink -f $@)" "$(shell readlink -f $<)" 10 | 11 | %.puml.svg: %.puml 12 | $(PLANTUMLCLI) -t svg -o "$(shell readlink -f $@)" "$(shell readlink -f $<)" 13 | 14 | build: ${SVGS} ${PNGS} 15 | 16 | all: build 17 | 18 | clean: 19 | rm -rf \ 20 | $(shell find ${SOURCE} -name *.puml.svg) \ 21 | $(shell find ${SOURCE} -name *.puml.png) \ -------------------------------------------------------------------------------- /docs/source/graphviz.mk: -------------------------------------------------------------------------------- 1 | DOT := $(shell which dot) 2 | 3 | SOURCE ?= . 4 | GVS := $(shell find ${SOURCE} -name *.gv) 5 | PNGS := $(addsuffix .gv.png, $(basename ${GVS})) 6 | SVGS := $(addsuffix .gv.svg, $(basename ${GVS})) 7 | 8 | %.gv.png: %.gv 9 | $(DOT) -Tpng -o"$(shell readlink -f $@)" "$(shell readlink -f $<)" 10 | 11 | %.gv.svg: %.gv 12 | $(DOT) -Tsvg -o"$(shell readlink -f $@)" "$(shell readlink -f $<)" 13 | 14 | build: ${SVGS} ${PNGS} 15 | 16 | all: build 17 | 18 | clean: 19 | rm -rf \ 20 | $(shell find ${SOURCE} -name *.gv.svg) \ 21 | $(shell find ${SOURCE} -name *.gv.png) \ -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | GenerativeRL Documentation 2 | ========================================================= 3 | 4 | Overview 5 | ------------- 6 | 7 | ``GenerativeRL`` is a Python library for solving reinforcement learning (RL) problems using generative models, such as diffusion models and flow models. 8 | This library aims to provide a framework for combining the power of generative models with the decision-making capabilities of reinforcement learning algorithms. 9 | 10 | .. toctree:: 11 | :maxdepth: 2 12 | :caption: Tutorials 13 | 14 | tutorials/installation/index 15 | tutorials/quick_start/index 16 | 17 | .. toctree:: 18 | :maxdepth: 2 19 | :caption: Concepts 20 | 21 | concepts/index 22 | 23 | .. toctree:: 24 | :maxdepth: 2 25 | :caption: User Guide 26 | 27 | user_guide/index 28 | 29 | .. 
toctree:: 30 | :maxdepth: 2 31 | :caption: API Documentation 32 | 33 | api_doc/agents/index 34 | api_doc/algorithms/index 35 | api_doc/datasets/index 36 | api_doc/generative_models/index 37 | api_doc/neural_network/index 38 | api_doc/numerical_methods/index 39 | api_doc/rl_modules/index 40 | api_doc/utils/index 41 | 42 | .. toctree:: 43 | :maxdepth: 2 44 | :caption: Contributor Guide -------------------------------------------------------------------------------- /docs/source/tutorials/installation/index.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | GenerativeRL can be installed using pip: 5 | 6 | .. code-block:: console 7 | 8 | $ pip install GenerativeRL 9 | 10 | You can also install the latest development version from GitHub: 11 | 12 | .. code-block:: console 13 | 14 | $ pip install git+https://github.com/opendilab/GenerativeRL.git 15 | -------------------------------------------------------------------------------- /docs/source/tutorials/quick_start/index.rst: -------------------------------------------------------------------------------- 1 | Quick Start 2 | =========== 3 | 4 | Generative models in GenerativeRL 5 | --------------------------------- 6 | 7 | GenerativeRL supports easy-to-use APIs for training and deploying generative models. 8 | We provide a simple example of how to train a diffusion model on the swiss roll dataset in `Colab `_. 9 | 10 | More usage examples can be found in the folder `grl_pipelines/tutorials/`. 11 | 12 | Reinforcement Learning 13 | ----------------------- 14 | 15 | GenerativeRL provides a simple and flexible interface for training and deploying reinforcement learning agents powered by generative models. Here's an example of how to use the library to train a Q-guided policy optimization (QGPO) agent on the HalfCheetah environment and deploy it for evaluation. 16 | 17 | .. code-block:: python 18 | 19 | from grl_pipelines.diffusion_model.configurations.d4rl_halfcheetah_qgpo import config 20 | from grl.algorithms import QGPOAlgorithm 21 | from grl.utils.log import log 22 | import gym 23 | 24 | def qgpo_pipeline(config): 25 | qgpo = QGPOAlgorithm(config) 26 | qgpo.train() 27 | 28 | agent = qgpo.deploy() 29 | env = gym.make(config.deploy.env.env_id) 30 | observation = env.reset() 31 | for _ in range(config.deploy.num_deploy_steps): 32 | env.render() 33 | observation, reward, done, _ = env.step(agent.act(observation)) 34 | 35 | if __name__ == '__main__': 36 | log.info("config: \n{}".format(config)) 37 | qgpo_pipeline(config) 38 | 39 | Explanation 40 | ----------- 41 | 42 | 1. First, we import the necessary components from the GenerativeRL library, including the configuration for the HalfCheetah environment and the QGPO algorithm, as well as the logging utility and the OpenAI Gym environment. 43 | 44 | 2. The ``qgpo_pipeline`` function encapsulates the training and deployment process: 45 | 46 | - An instance of the ``QGPOAlgorithm`` is created with the provided configuration. 47 | - The ``qgpo.train()`` method is called to train the QGPO agent on the HalfCheetah environment. 48 | - After training, the ``qgpo.deploy()`` method is called to obtain the trained agent for deployment. 49 | - A new instance of the HalfCheetah environment is created using ``gym.make``. 50 | - The environment is reset to its initial state with ``env.reset()``.
51 | - A loop is executed for the specified number of steps (``config.deploy.num_deploy_steps``), rendering the environment and stepping through it using the agent's ``act`` method. 52 | 53 | 3. In the ``if __name__ == '__main__'`` block, the configuration is printed to the console using the logging utility, and the ``qgpo_pipeline`` function is called with the provided configuration. 54 | 55 | This example demonstrates how to utilize the GenerativeRL library to train a QGPO agent on the HalfCheetah environment and then deploy the trained agent for evaluation within the environment. You can modify the configuration and algorithm as needed to suit your specific use case. 56 | 57 | For more detailed information and advanced usage examples, please refer to the API documentation and other sections of the GenerativeRL documentation. 58 | -------------------------------------------------------------------------------- /docs/source/user_guide/evaluating_agents.rst: -------------------------------------------------------------------------------- 1 | How to evaluate RL agents' performance 2 | ------------------------------------------------- 3 | 4 | In GenerativeRL, the performance of reinforcement learning (RL) agents is evaluated using simulators or environments. 5 | 6 | Each agent is implemented as a class under the ``grl.agents`` module and exposes a unified ``act`` method that takes an observation as input and returns an action. 7 | 8 | Users can evaluate the performance of an agent by running it in a simulator or environment and collecting the rewards. 9 | 10 | .. code-block:: python 11 | 12 | import gym 13 | agent = algorithm.deploy() 14 | env = gym.make(config.deploy.env.env_id) 15 | observation = env.reset() 16 | for _ in range(config.deploy.num_deploy_steps): 17 | env.render() 18 | observation, reward, done, _ = env.step(agent.act(observation)) 19 | 20 | -------------------------------------------------------------------------------- /docs/source/user_guide/index.rst: -------------------------------------------------------------------------------- 1 | User Guide 2 | ================ 3 | 4 | Here is a list of user guide sections: 5 | 6 | .. toctree:: 7 | :maxdepth: 2 8 | :caption: User Guide 9 | 10 | installation 11 | training_agents 12 | training_generative_models 13 | evaluating_agents 14 | 15 | For more detailed information and advanced usage examples, please refer to the API documentation and other sections of the GenerativeRL documentation. 16 | 17 | -------------------------------------------------------------------------------- /docs/source/user_guide/installation.rst: -------------------------------------------------------------------------------- 1 | How to install GenerativeRL and its dependencies 2 | ------------------------------------------------- 3 | 4 | GenerativeRL is a Python library that requires the following dependencies to be installed: 5 | 6 | - Python 3.9 or higher 7 | - PyTorch 2.0.0 or higher 8 | 9 | Install GenerativeRL using the following command: 10 | 11 | .. code-block:: bash 12 | 13 | git clone https://github.com/opendilab/GenerativeRL.git 14 | cd GenerativeRL 15 | pip install -e . 16 | 17 | For solving reinforcement learning problems, you also have to install additional environments and dependencies, such as Gym, PyBullet, MuJoCo, and the DeepMind Control Suite. 18 | You can install these dependencies after installing GenerativeRL, for example: 19 | 20 | .. code-block:: bash 21 | 22 | pip install gym 23 | pip install pybullet 24 | pip install mujoco-py 25 | pip install dm_control 26 | 
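After installing these optional packages, a quick sanity check of which environment backends can actually be imported may help (an illustrative snippet, not part of GenerativeRL itself):

.. code-block:: python

    for pkg in ("gym", "pybullet", "mujoco_py", "dm_control"):
        try:
            __import__(pkg)
            print(f"{pkg}: OK")
        except ImportError as err:
            print(f"{pkg}: not available ({err})")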
27 | Note that some of these dependencies require additional setup and licensing; for example, D4RL requires a specific Gym version to be installed: 28 | 29 | .. code-block:: bash 30 | 31 | pip install 'gym==0.23.1' 32 | 33 | Some environments also require extra system-level setup. MuJoCo, for example, requires the following steps: 34 | 35 | .. code-block:: bash 36 | 37 | sudo apt-get install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev -y 38 | sudo apt-get install swig gcc g++ make locales dnsutils cmake -y 39 | sudo apt-get install build-essential libgl1-mesa-dev libgl1-mesa-glx libglew-dev -y 40 | sudo apt-get install libosmesa6-dev libglfw3 libglfw3-dev libsdl2-dev libsdl2-image-dev -y 41 | sudo apt-get install libglm-dev libfreetype6-dev patchelf ffmpeg -y 42 | mkdir -p /root/.mujoco 43 | wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz 44 | tar -xf mujoco.tar.gz -C /root/.mujoco 45 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mjpro210/bin:/root/.mujoco/mujoco210/bin 46 | git clone https://github.com/Farama-Foundation/D4RL.git 47 | cd D4RL 48 | pip install -e . 49 | pip install lockfile 50 | pip install "Cython<3.0" 51 | 52 | Check whether the installation is successful by running the following command: 53 | 54 | .. code-block:: bash 55 | 56 | python -c "import grl" 57 | 58 | -------------------------------------------------------------------------------- /docs/source/user_guide/training_agents.rst: -------------------------------------------------------------------------------- 1 | How to train and deploy reinforcement learning agents 2 | ----------------------------------------------------- 3 | 4 | In GenerativeRL, RL algorithms are implemented as classes under the ``grl.algorithms`` module, while agents are implemented as classes under the ``grl.agents`` module. 5 | 6 | Every algorithm class has a ``train`` method that takes the environment, dataset, and other hyperparameters as input and returns the trained model. 7 | Every algorithm class also has a ``deploy`` method that copies the trained model and returns the trained agent. 8 | 9 | To train a specific RL algorithm, follow these steps: 10 | 11 | 1. Create an instance of the RL algorithm class. 12 | 13 | .. code-block:: python 14 | 15 | from grl.algorithms.qgpo import QGPOAlgorithm 16 | 17 | 2. Define the hyperparameters for the algorithm in a configurations dictionary. You can use the default configurations provided under the ``grl_pipelines`` module. 18 | 19 | .. code-block:: python 20 | 21 | from grl_pipelines.diffusion_model.configurations.d4rl_halfcheetah_qgpo import config 22 | 23 | 3. Create an instance of the algorithm class with the configurations dictionary. 24 | 25 | .. code-block:: python 26 | 27 | algorithm = QGPOAlgorithm(config) 28 | 29 | 4. Train the algorithm using the ``train`` method. 30 | 31 | .. code-block:: python 32 | 33 | trained_model = algorithm.train() 34 | 35 | 5. Deploy the trained model using the ``deploy`` method. 36 | 37 | .. code-block:: python 38 | 39 | agent = algorithm.deploy() 40 | 41 | 6. Use the trained agent to interact with the environment and evaluate its performance. 42 | 43 | .. 
code-block:: python 44 | 45 | import gym 46 | env = gym.make(config.deploy.env.env_id) 47 | observation = env.reset() 48 | for _ in range(config.deploy.num_deploy_steps): 49 | env.render() 50 | observation, reward, done, _ = env.step(agent.act(observation)) 51 | 52 | For more information on how to train and deploy reinforcement learning agents, please refer to the API documentation and other sections of the GenerativeRL documentation. 53 | -------------------------------------------------------------------------------- /grl/__init__.py: -------------------------------------------------------------------------------- 1 | __TITLE__ = "GenerativeRL" 2 | __VERSION__ = "v0.0.1" 3 | __DESCRIPTION__ = "Do reinforcement learning with generative models." 4 | __AUTHOR__ = "OpenDILab Contributors" 5 | __AUTHOR_EMAIL__ = "opendilab@pjlab.org.cn" 6 | __version__ = __VERSION__ 7 | -------------------------------------------------------------------------------- /grl/agents/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import torch 3 | import numpy as np 4 | from tensordict import TensorDict 5 | 6 | 7 | def obs_transform(obs, device): 8 | 9 | if isinstance(obs, np.ndarray): 10 | obs = torch.from_numpy(obs).float().to(device) 11 | elif isinstance(obs, Dict): 12 | obs = {k: torch.from_numpy(v).float().to(device) for k, v in obs.items()} 13 | elif isinstance(obs, torch.Tensor): 14 | obs = obs.float().to(device) 15 | elif isinstance(obs, TensorDict): 16 | obs = obs.to(device) 17 | else: 18 | raise ValueError("observation must be a dict, torch.Tensor, or np.ndarray") 19 | 20 | return obs 21 | 22 | 23 | def action_transform(action, return_as_torch_tensor: bool = False): 24 | if isinstance(action, Dict): 25 | if return_as_torch_tensor: 26 | action = {k: v.cpu() for k, v in action.items()} 27 | else: 28 | action = {k: v.cpu().numpy() for k, v in action.items()} 29 | elif isinstance(action, torch.Tensor): 30 | if return_as_torch_tensor: 31 | action = action.cpu() 32 | else: 33 | action = action.numpy() 34 | elif isinstance(action, np.ndarray): 35 | pass 36 | else: 37 | raise ValueError("action must be a dict, torch.Tensor, or np.ndarray") 38 | 39 | return action 40 | 41 | 42 | from .base import BaseAgent 43 | from .qgpo import QGPOAgent 44 | from .srpo import SRPOAgent 45 | from .gm import GPAgent 46 | from .idql import IDQLAgent 47 | -------------------------------------------------------------------------------- /grl/agents/base.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | import numpy as np 4 | import torch 5 | from easydict import EasyDict 6 | 7 | from grl.agents import obs_transform, action_transform 8 | 9 | 10 | class BaseAgent: 11 | 12 | def __init__( 13 | self, 14 | config: EasyDict, 15 | model: Union[torch.nn.Module, torch.nn.ModuleDict], 16 | ): 17 | """ 18 | Overview: 19 | Initialize the agent. 20 | Arguments: 21 | config (:obj:`EasyDict`): The configuration. 22 | model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The model. 23 | """ 24 | 25 | self.config = config 26 | self.device = config.device 27 | self.model = model.to(self.device) 28 | 29 | def act( 30 | self, 31 | obs: Union[np.ndarray, torch.Tensor, Dict], 32 | return_as_torch_tensor: bool = False, 33 | ) -> Union[np.ndarray, torch.Tensor, Dict]: 34 | """ 35 | Overview: 36 | Given an observation, return an action. 
37 | Arguments: 38 | obs (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The observation. 39 | return_as_torch_tensor (:obj:`bool`): Whether to return the action as a torch tensor. 40 | Returns: 41 | action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action. 42 | """ 43 | 44 | obs = obs_transform(obs, self.device) 45 | 46 | with torch.no_grad(): 47 | 48 | # --------------------------------------- 49 | # Customized inference code ↓ 50 | # --------------------------------------- 51 | 52 | action = self.model(obs) 53 | 54 | # --------------------------------------- 55 | # Customized inference code ↑ 56 | # --------------------------------------- 57 | 58 | action = action_transform(action, return_as_torch_tensor) 59 | 60 | return action 61 | -------------------------------------------------------------------------------- /grl/agents/gm.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | import numpy as np 4 | import torch 5 | from easydict import EasyDict 6 | 7 | from grl.agents import obs_transform, action_transform 8 | 9 | 10 | class GPAgent: 11 | """ 12 | Overview: 13 | The agent trained for generative policies. 14 | This class is designed to be used with the ``GMPGAlgorithm`` and ``GMPOAlgorithm``. 15 | Interface: 16 | ``__init__``, ``act`` 17 | """ 18 | 19 | def __init__( 20 | self, 21 | config: EasyDict, 22 | model: Union[torch.nn.Module, torch.nn.ModuleDict], 23 | ): 24 | """ 25 | Overview: 26 | Initialize the agent with the configuration and the model. 27 | Arguments: 28 | config (:obj:`EasyDict`): The configuration. 29 | model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The model. 30 | """ 31 | 32 | self.config = config 33 | self.device = config.device 34 | self.model = model.to(self.device) 35 | 36 | def act( 37 | self, 38 | obs: Union[np.ndarray, torch.Tensor, Dict], 39 | return_as_torch_tensor: bool = False, 40 | ) -> Union[np.ndarray, torch.Tensor, Dict]: 41 | """ 42 | Overview: 43 | Given an observation, return an action as a numpy array or a torch tensor. 44 | Arguments: 45 | obs (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The observation. 46 | return_as_torch_tensor (:obj:`bool`): Whether to return the action as a torch tensor. 47 | Returns: 48 | action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action. 49 | """ 50 | 
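# The steps below are shared by the agents in ``grl.agents``: move the
# observation onto the agent's device, sample an action from the generative
# policy without tracking gradients, then convert the action back to a numpy
# array (or a torch tensor) for the caller.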
49 | """ 50 | 51 | obs = obs_transform(obs, self.device) 52 | 53 | with torch.no_grad(): 54 | 55 | # --------------------------------------- 56 | # Customized inference code ↓ 57 | # --------------------------------------- 58 | 59 | obs = obs.unsqueeze(0) 60 | action = ( 61 | self.model.sample( 62 | condition=obs, 63 | t_span=( 64 | torch.linspace(0.0, 1.0, self.config.t_span).to(obs.device) 65 | if self.config.t_span is not None 66 | else None 67 | ), 68 | solver_config=( 69 | self.config.solver_config 70 | if hasattr(self.config, "solver_config") 71 | else None 72 | ), 73 | ) 74 | .squeeze(0) 75 | .cpu() 76 | .detach() 77 | .numpy() 78 | ) 79 | 80 | # --------------------------------------- 81 | # Customized inference code ↑ 82 | # --------------------------------------- 83 | 84 | action = action_transform(action, return_as_torch_tensor) 85 | 86 | return action 87 | -------------------------------------------------------------------------------- /grl/agents/idql.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | import numpy as np 4 | import torch 5 | from easydict import EasyDict 6 | 7 | from grl.agents import obs_transform, action_transform 8 | 9 | 10 | class IDQLAgent: 11 | """ 12 | Overview: 13 | The IDQL agent. 14 | Interface: 15 | ``__init__``, ``action`` 16 | """ 17 | 18 | def __init__( 19 | self, 20 | config: EasyDict, 21 | model: Union[torch.nn.Module, torch.nn.ModuleDict], 22 | ): 23 | """ 24 | Overview: 25 | Initialize the agent. 26 | Arguments: 27 | config (:obj:`EasyDict`): The configuration. 28 | model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The model. 29 | """ 30 | 31 | self.config = config 32 | self.device = config.device 33 | self.model = model.to(self.device) 34 | 35 | def act( 36 | self, 37 | obs: Union[np.ndarray, torch.Tensor, Dict], 38 | return_as_torch_tensor: bool = False, 39 | ) -> Union[np.ndarray, torch.Tensor, Dict]: 40 | """ 41 | Overview: 42 | Given an observation, return an action. 43 | Arguments: 44 | obs (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The observation. 45 | return_as_torch_tensor (:obj:`bool`): Whether to return the action as a torch tensor. 46 | Returns: 47 | action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action. 48 | """ 49 | 50 | obs = obs_transform(obs, self.device) 51 | 52 | with torch.no_grad(): 53 | 54 | # --------------------------------------- 55 | # Customized inference code ↓ 56 | # --------------------------------------- 57 | 58 | obs = obs.unsqueeze(0) 59 | action = ( 60 | self.model["IDQLPolicy"] 61 | .get_action( 62 | state=obs, 63 | ) 64 | .squeeze(0) 65 | .cpu() 66 | .detach() 67 | .numpy() 68 | ) 69 | 70 | action = action_transform(action, return_as_torch_tensor) 71 | 72 | return action 73 | -------------------------------------------------------------------------------- /grl/agents/qgpo.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | import numpy as np 4 | import torch 5 | from easydict import EasyDict 6 | 7 | from grl.agents import obs_transform, action_transform 8 | 9 | 10 | class QGPOAgent: 11 | """ 12 | Overview: 13 | The agent for the QGPO algorithm. 14 | Interface: 15 | ``__init__``, ``action`` 16 | """ 17 | 18 | def __init__( 19 | self, 20 | config: EasyDict, 21 | model: Union[torch.nn.Module, torch.nn.ModuleDict], 22 | ): 23 | """ 24 | Overview: 25 | Initialize the agent. 
-------------------------------------------------------------------------------- /grl/agents/qgpo.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | import numpy as np 4 | import torch 5 | from easydict import EasyDict 6 | 7 | from grl.agents import obs_transform, action_transform 8 | 9 | 10 | class QGPOAgent: 11 | """ 12 | Overview: 13 | The agent for the QGPO algorithm. 14 | Interface: 15 | ``__init__``, ``act`` 16 | """ 17 | 18 | def __init__( 19 | self, 20 | config: EasyDict, 21 | model: Union[torch.nn.Module, torch.nn.ModuleDict], 22 | ): 23 | """ 24 | Overview: 25 | Initialize the agent. 26 | Arguments: 27 | config (:obj:`EasyDict`): The configuration. 28 | model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The model. 29 | """ 30 | 31 | self.config = config 32 | self.device = config.device 33 | self.model = model.to(self.device) 34 | 35 | if hasattr(self.config, "guidance_scale"): 36 | self.guidance_scale = self.config.guidance_scale 37 | else: 38 | self.guidance_scale = 1.0 39 | 40 | def act( 41 | self, 42 | obs: Union[np.ndarray, torch.Tensor, Dict], 43 | return_as_torch_tensor: bool = False, 44 | ) -> Union[np.ndarray, torch.Tensor, Dict]: 45 | """ 46 | Overview: 47 | Given an observation, return an action. 48 | Arguments: 49 | obs (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The observation. 50 | return_as_torch_tensor (:obj:`bool`): Whether to return the action as a torch tensor. 51 | Returns: 52 | action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action. 53 | """ 54 | 55 | obs = obs_transform(obs, self.device) 56 | 57 | with torch.no_grad(): 58 | 59 | # --------------------------------------- 60 | # Customized inference code ↓ 61 | # --------------------------------------- 62 | 63 | obs = obs.unsqueeze(0) 64 | action = ( 65 | self.model["QGPOPolicy"] 66 | .sample( 67 | state=obs, 68 | t_span=( 69 | torch.linspace(0.0, 1.0, self.config.t_span).to(obs.device) 70 | if self.config.t_span is not None 71 | else None 72 | ), 73 | guidance_scale=self.guidance_scale, 74 | ) 75 | .squeeze(0) 76 | .cpu() 77 | .detach() 78 | .numpy() 79 | ) 80 | 81 | # --------------------------------------- 82 | # Customized inference code ↑ 83 | # --------------------------------------- 84 | 85 | action = action_transform(action, return_as_torch_tensor) 86 | 87 | return action 88 | -------------------------------------------------------------------------------- /grl/agents/srpo.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Union 2 | 3 | import numpy as np 4 | import torch 5 | from easydict import EasyDict 6 | 7 | from grl.agents import obs_transform, action_transform 8 | 9 | 10 | class SRPOAgent: 11 | """ 12 | Overview: 13 | The SRPO agent. 14 | Interface: 15 | ``__init__``, ``act`` 16 | """ 17 | 18 | def __init__( 19 | self, 20 | config: EasyDict, 21 | model: Union[torch.nn.Module, torch.nn.ModuleDict], 22 | ): 23 | """ 24 | Overview: 25 | Initialize the agent. 26 | Arguments: 27 | config (:obj:`EasyDict`): The configuration. 28 | model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The model. 29 | """ 30 | 31 | self.config = config 32 | self.device = config.device 33 | self.model = model.to(self.device) 34 | 35 | def act( 36 | self, 37 | obs: Union[np.ndarray, torch.Tensor, Dict], 38 | return_as_torch_tensor: bool = False, 39 | ) -> Union[np.ndarray, torch.Tensor, Dict]: 40 | """ 41 | Overview: 42 | Given an observation, return an action. 43 | Arguments: 44 | obs (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The observation. 45 | return_as_torch_tensor (:obj:`bool`): Whether to return the action as a torch tensor. 46 | Returns: 47 | action (:obj:`Union[np.ndarray, torch.Tensor, Dict]`): The action.
48 | """ 49 | 50 | obs = obs_transform(obs, self.device) 51 | 52 | with torch.no_grad(): 53 | 54 | # --------------------------------------- 55 | # Customized inference code ↓ 56 | # --------------------------------------- 57 | obs = obs.unsqueeze(0) 58 | action = ( 59 | self.model["SRPOPolicy"].policy(obs).squeeze(0).detach().cpu().numpy() 60 | ) 61 | # --------------------------------------- 62 | # Customized inference code ↑ 63 | # --------------------------------------- 64 | 65 | action = action_transform(action, return_as_torch_tensor) 66 | 67 | return action 68 | -------------------------------------------------------------------------------- /grl/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseAlgorithm 2 | from .gmpo import GMPOAlgorithm, GMPOCritic, GMPOPolicy 3 | from .gmpg import GMPGAlgorithm, GMPGCritic, GMPGPolicy 4 | from .idql import IDQLAlgorithm, IDQLCritic, IDQLPolicy 5 | from .qgpo import QGPOAlgorithm, QGPOCritic, QGPOPolicy 6 | from .srpo import SRPOAlgorithm, SRPOCritic, SRPOPolicy 7 | -------------------------------------------------------------------------------- /grl/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .d4rl import D4RLDataset 2 | from .qgpo import ( 3 | QGPODataset, 4 | QGPOD4RLDataset, 5 | QGPOOnlineDataset, 6 | QGPOCustomizedDataset, 7 | QGPOTensorDictDataset, 8 | QGPOD4RLTensorDictDataset, 9 | QGPOCustomizedTensorDictDataset, 10 | QGPODeepMindControlTensorDictDataset, 11 | ) 12 | from .gp import ( 13 | GPDataset, 14 | GPD4RLDataset, 15 | GPOnlineDataset, 16 | GPD4RLOnlineDataset, 17 | GPCustomizedDataset, 18 | GPTensorDictDataset, 19 | GPD4RLTensorDictDataset, 20 | GPCustomizedTensorDictDataset, 21 | GPDeepMindControlTensorDictDataset, 22 | GPDeepMindControlVisualTensorDictDataset, 23 | ) 24 | from .minari_dataset import MinariDataset 25 | 26 | DATASETS = { 27 | "QGPOD4RLDataset".lower(): QGPOD4RLDataset, 28 | "QGPODataset".lower(): QGPODataset, 29 | "D4RLDataset".lower(): D4RLDataset, 30 | "QGPOOnlineDataset".lower(): QGPOOnlineDataset, 31 | "QGPOCustomizedDataset".lower(): QGPOCustomizedDataset, 32 | "QGPOTensorDictDataset".lower(): QGPOTensorDictDataset, 33 | "QGPOD4RLTensorDictDataset".lower(): QGPOD4RLTensorDictDataset, 34 | "QGPOCustomizedTensorDictDataset".lower(): QGPOCustomizedTensorDictDataset, 35 | "QGPODeepMindControlTensorDictDataset".lower(): QGPODeepMindControlTensorDictDataset, 36 | "MinariDataset".lower(): MinariDataset, 37 | "GPDataset".lower(): GPDataset, 38 | "GPD4RLDataset".lower(): GPD4RLDataset, 39 | "GPOnlineDataset".lower(): GPOnlineDataset, 40 | "GPD4RLOnlineDataset".lower(): GPD4RLOnlineDataset, 41 | "GPCustomizedDataset".lower(): GPCustomizedDataset, 42 | "GPTensorDictDataset".lower(): GPTensorDictDataset, 43 | "GPD4RLTensorDictDataset".lower(): GPD4RLTensorDictDataset, 44 | "GPCustomizedTensorDictDataset".lower(): GPCustomizedTensorDictDataset, 45 | "GPDeepMindControlTensorDictDataset".lower(): GPDeepMindControlTensorDictDataset, 46 | "GPDeepMindControlVisualTensorDictDataset".lower(): GPDeepMindControlVisualTensorDictDataset, 47 | } 48 | 49 | 50 | def get_dataset(type: str): 51 | if type.lower() not in DATASETS: 52 | raise KeyError(f"Invalid dataset type: {type}") 53 | return DATASETS[type.lower()] 54 | 55 | 56 | def create_dataset(config, **kwargs): 57 | return get_dataset(config.type)(**config.args, **kwargs) 58 | 
-------------------------------------------------------------------------------- /grl/generative_models/__init__.py: -------------------------------------------------------------------------------- 1 | def get_generative_model(name: str): 2 | if name.lower() not in GENERATIVE_MODELS:  # resolved at call time; the registry is defined below 3 | raise ValueError("Unknown generative model {}".format(name)) 4 | return GENERATIVE_MODELS[name.lower()] 5 | 6 | 7 | from .conditional_flow_model import ( 8 | IndependentConditionalFlowModel, 9 | OptimalTransportConditionalFlowModel, 10 | ) 11 | from .diffusion_model import DiffusionModel, EnergyConditionalDiffusionModel 12 | from .variational_autoencoder import VariationalAutoencoder 13 | 14 | GENERATIVE_MODELS = { 15 | "DiffusionModel".lower(): DiffusionModel, 16 | "EnergyConditionalDiffusionModel".lower(): EnergyConditionalDiffusionModel, 17 | "VariationalAutoencoder".lower(): VariationalAutoencoder, 18 | "IndependentConditionalFlowModel".lower(): IndependentConditionalFlowModel, 19 | "OptimalTransportConditionalFlowModel".lower(): OptimalTransportConditionalFlowModel, 20 | } 21 | -------------------------------------------------------------------------------- /grl/generative_models/conditional_flow_model/__init__.py: -------------------------------------------------------------------------------- 1 | from .independent_conditional_flow_model import IndependentConditionalFlowModel 2 | from .optimal_transport_conditional_flow_model import ( 3 | OptimalTransportConditionalFlowModel, 4 | ) 5 | -------------------------------------------------------------------------------- /grl/generative_models/diffusion_model/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffusion_model import DiffusionModel 2 | from .energy_conditional_diffusion_model import EnergyConditionalDiffusionModel 3 | -------------------------------------------------------------------------------- /grl/generative_models/discrete_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/GenerativeRL/77ed957d89843e823e5780399959b41f965ebca7/grl/generative_models/discrete_model/__init__.py -------------------------------------------------------------------------------- /grl/generative_models/intrinsic_model.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from easydict import EasyDict 6 | from tensordict import TensorDict 7 | 8 | from grl.neural_network import get_module 9 | from grl.neural_network.encoders import get_encoder 10 | 11 | 12 | class IntrinsicModel(nn.Module): 13 | """ 14 | Overview: 15 | Intrinsic model of a generative model, which serves as the backbone of many continuous-time generative models.
16 | Interfaces: 17 | ``__init__``, ``forward`` 18 | """ 19 | 20 | def __init__(self, config: EasyDict): 21 | # TODO 22 | 23 | super().__init__() 24 | 25 | self.config = config 26 | assert hasattr(config, "backbone"), "backbone must be specified in config" 27 | 28 | self.model = torch.nn.ModuleDict() 29 | if hasattr(config, "t_encoder"): 30 | self.model["t_encoder"] = get_encoder(config.t_encoder.type)( 31 | **config.t_encoder.args 32 | ) 33 | else: 34 | self.model["t_encoder"] = torch.nn.Identity() 35 | if hasattr(config, "x_encoder"): 36 | self.model["x_encoder"] = get_encoder(config.x_encoder.type)( 37 | **config.x_encoder.args 38 | ) 39 | else: 40 | self.model["x_encoder"] = torch.nn.Identity() 41 | if hasattr(config, "condition_encoder"): 42 | self.model["condition_encoder"] = get_encoder( 43 | config.condition_encoder.type 44 | )(**config.condition_encoder.args) 45 | else: 46 | self.model["condition_encoder"] = torch.nn.Identity() 47 | 48 | # TODO 49 | # specific backbone network 50 | self.model["backbone"] = get_module(config.backbone.type)( 51 | **config.backbone.args 52 | ) 53 | 54 | def forward( 55 | self, 56 | t: torch.Tensor, 57 | x: Union[torch.Tensor, TensorDict], 58 | condition: Union[torch.Tensor, TensorDict] = None, 59 | ) -> torch.Tensor: 60 | """ 61 | Overview: 62 | Return the output of the model at time t given the input state and an optional condition. 63 | """ 64 | 65 | if condition is not None: 66 | t = self.model["t_encoder"](t) 67 | x = self.model["x_encoder"](x) 68 | condition = self.model["condition_encoder"](condition) 69 | output = self.model["backbone"](t, x, condition) 70 | else: 71 | t = self.model["t_encoder"](t) 72 | x = self.model["x_encoder"](x) 73 | output = self.model["backbone"](t, x) 74 | 75 | return output 76 | -------------------------------------------------------------------------------- /grl/generative_models/model_functions/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_prediction_function import DataPredictionFunction 2 | from .noise_function import NoiseFunction 3 | from .score_function import ScoreFunction 4 | from .velocity_function import VelocityFunction 5 | -------------------------------------------------------------------------------- /grl/generative_models/model_functions/data_prediction_function.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import treetensor 6 | from easydict import EasyDict 7 | from tensordict import TensorDict 8 | 9 | 10 | class DataPredictionFunction: 11 | """ 12 | Overview: 13 | Model of the data prediction function in a score-based generative model. 14 | Interfaces: 15 | ``__init__``, ``forward`` 16 | """ 17 | 18 | def __init__( 19 | self, 20 | model_type: str, 21 | process: object, 22 | ): 23 | """ 24 | Overview: 25 | Initialize the data prediction function. 26 | Arguments: 27 | - model_type (:obj:`str`): The type of the model. 28 | - process (:obj:`object`): The process.
29 | """ 30 | 31 | self.model_type = model_type 32 | self.process = process 33 | # TODO: add more types 34 | assert self.model_type in [ 35 | "data_prediction_function", 36 | "noise_function", 37 | "score_function", 38 | "velocity_function", 39 | "denoiser_function", 40 | ], "Unknown type of ScoreFunction {}".format(type) 41 | 42 | def forward( 43 | self, 44 | model: Union[Callable, nn.Module], 45 | t: torch.Tensor, 46 | x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], 47 | condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, 48 | ) -> torch.Tensor: 49 | """ 50 | Overview: 51 | Return data prediction function of the model at time t given the initial state. 52 | .. math:: 53 | (- \sigma(t) x_t + \sigma^2(t) \nabla_{x_t} \log p_{\theta}(x_t)) / s(t) 54 | 55 | Arguments: 56 | - model (:obj:`Union[Callable, nn.Module]`): The model. 57 | - t (:obj:`torch.Tensor`): The input time. 58 | - x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input state. 59 | - condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition. 60 | """ 61 | 62 | if self.model_type == "noise_function": 63 | return ( 64 | x - self.process.std(t, x) * model(t, x, condition) 65 | ) / self.process.scale(t, x) 66 | elif self.model_type == "score_function": 67 | return ( 68 | -self.process.std(t, x) * x 69 | + self.process.covariance(t, x) * model(t, x, condition) 70 | ) / self.process.scale(t, x) 71 | elif self.model_type == "velocity_function": 72 | return ( 73 | (self.process.drift(t, x) - model(t, x, condition)) 74 | * 2.0 75 | * self.process.covariance(t, x) 76 | / self.process.diffusion_squared(t, x) 77 | + x 78 | ) / self.process.scale(t, x) 79 | elif self.model_type == "data_prediction_function": 80 | return model(t, x, condition) 81 | else: 82 | raise NotImplementedError( 83 | "Unknown type of data prediction function {}".format(type) 84 | ) 85 | -------------------------------------------------------------------------------- /grl/generative_models/model_functions/noise_function.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import treetensor 6 | from easydict import EasyDict 7 | from tensordict import TensorDict 8 | 9 | 10 | class NoiseFunction: 11 | """ 12 | Overview: 13 | Model of noise function in diffusion model. 14 | Interfaces: 15 | ``__init__``, ``forward`` 16 | """ 17 | 18 | def __init__( 19 | self, 20 | model_type: str, 21 | process: object, 22 | ): 23 | """ 24 | Overview: 25 | Initialize the noise function. 26 | Arguments: 27 | - model_type (:obj:`str`): The type of the model. 28 | - process (:obj:`object`): The process. 29 | """ 30 | 31 | self.model_type = model_type 32 | self.process = process 33 | # TODO: add more types 34 | assert self.model_type in [ 35 | "data_prediction_function", 36 | "noise_function", 37 | "score_function", 38 | "velocity_function", 39 | "denoiser_function", 40 | ], "Unknown type of ScoreFunction {}".format(type) 41 | 42 | def forward( 43 | self, 44 | model: Union[Callable, nn.Module], 45 | t: torch.Tensor, 46 | x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], 47 | condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, 48 | ) -> torch.Tensor: 49 | """ 50 | Overview: 51 | Return noise function of the model at time t given the initial state. 52 | .. 
math:: 53 | - \sigma(t) \nabla_{x_t} \log p_{\theta}(x_t) 54 | 55 | Arguments: 56 | - model (:obj:`Union[Callable, nn.Module]`): The model. 57 | - t (:obj:`torch.Tensor`): The input time. 58 | - x (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input state. 59 | - condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition. 60 | """ 61 | 62 | if self.model_type == "noise_function": 63 | return model(t, x, condition) 64 | elif self.model_type == "score_function": 65 | return -model(t, x, condition) * self.process.std(t, x) 66 | elif self.model_type == "velocity_function": 67 | return ( 68 | (model(t, x, condition) - self.process.drift(t, x)) 69 | * 2.0 70 | * self.process.std(t, x) 71 | / self.process.diffusion_squared(t, x) 72 | ) 73 | elif self.model_type == "data_prediction_function": 74 | return ( 75 | x - self.process.scale(t, x) * model(t, x, condition) 76 | ) / self.process.std(t, x) 77 | else: 78 | raise NotImplementedError("Unknown type of noise function {}".format(self.model_type)) 79 | -------------------------------------------------------------------------------- /grl/generative_models/sro.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from easydict import EasyDict 6 | from tensordict import TensorDict 7 | 8 | from grl.generative_models.diffusion_model.diffusion_model import DiffusionModel 9 | 10 | 11 | class SRPOConditionalDiffusionModel(nn.Module): 12 | """ 13 | Overview: 14 | Score regularized policy optimization, which distills a conditional diffusion model into a stochastic or deterministic policy model of some distribution type. 15 | Interfaces: 16 | ``__init__``, ``score_matching_loss``, ``srpo_loss`` 17 | """ 18 | 19 | def __init__( 20 | self, 21 | config: EasyDict, 22 | value_model: Union[torch.nn.Module, torch.nn.ModuleDict], 23 | distribution_model, 24 | ) -> None: 25 | """ 26 | Overview: 27 | Initialization of the SRPOConditionalDiffusionModel. 28 | Arguments: 29 | config (:obj:`EasyDict`): The configuration. 30 | value_model (:obj:`Union[torch.nn.Module, torch.nn.ModuleDict]`): The value model. ``distribution_model`` is the policy model to be optimized. 31 | """ 32 | 33 | super().__init__() 34 | self.config = config 35 | self.diffusion_model = DiffusionModel(config) 36 | self.value_model = value_model 37 | self.distribution_model = distribution_model 38 | self.env_beta = config.beta 39 | 40 | def score_matching_loss( 41 | self, 42 | x: Union[torch.Tensor, TensorDict], 43 | condition: Union[torch.Tensor, TensorDict] = None, 44 | ) -> torch.Tensor: 45 | """ 46 | Overview: 47 | The score matching loss for training the diffusion model. 48 | Arguments: 49 | x (:obj:`Union[torch.Tensor, TensorDict]`): The input. 50 | condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. 51 | """ 52 | 53 | return self.diffusion_model.score_matching_loss(x, condition) 54 | 55 | def srpo_loss( 56 | self, 57 | condition: Union[torch.Tensor, TensorDict], # state 58 | ): 59 | """ 60 | Overview: 61 | The SRPO loss for optimizing the distribution model (policy) against the diffusion model and the value function. 62 | Arguments: 63 | condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition.
64 | """ 65 | x = self.distribution_model(condition) 66 | # TODO: check if this is the right way to sample t_random with extra scaling and shifting 67 | # t_random = torch.rand(x.shape[0], device=x.device) 68 | t_random = torch.rand(x.shape[0], device=x.device) * 0.96 + 0.02 69 | x_t = self.diffusion_model.diffusion_process.direct_sample(t_random, x) 70 | wt = self.diffusion_model.diffusion_process.std(t_random, x) ** 2 71 | with torch.no_grad(): 72 | episilon = self.diffusion_model.noise_function( 73 | t_random, x_t, condition 74 | ).detach() 75 | detach_x = x.detach().requires_grad_(True) 76 | qs = self.value_model.q_target.compute_double_q(detach_x, condition) 77 | q = (qs[0].squeeze() + qs[1].squeeze()) / 2.0 78 | guidance = torch.autograd.grad(torch.sum(q), detach_x)[0].detach() 79 | loss = (episilon * x) * wt - (guidance * x) * self.env_beta 80 | return loss, torch.mean(q) 81 | -------------------------------------------------------------------------------- /grl/generative_models/variational_autoencoder.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import treetensor 7 | from easydict import EasyDict 8 | from tensordict import TensorDict 9 | 10 | from grl.neural_network import get_module 11 | from grl.neural_network.encoders import get_encoder 12 | 13 | 14 | class IntrinsicModel(nn.Module): 15 | """ 16 | Overview: 17 | Intrinsic model of VAE model. 18 | Interfaces: 19 | ``__init__``, ``forward`` 20 | """ 21 | 22 | def __init__(self, config: EasyDict): 23 | """ 24 | Overview: 25 | Initialize the intrinsic model. 26 | Arguments: 27 | config (:obj:`EasyDict`): The configuration. 28 | """ 29 | super().__init__() 30 | 31 | self.config = config 32 | assert hasattr(config, "backbone"), "backbone must be specified in config" 33 | 34 | self.model = torch.nn.ModuleDict() 35 | if hasattr(config, "x_encoder"): 36 | self.model["x_encoder"] = get_encoder(config.x_encoder.type)( 37 | **config.x_encoder.args 38 | ) 39 | else: 40 | self.model["x_encoder"] = torch.nn.Identity() 41 | if hasattr(config, "condition_encoder"): 42 | self.model["condition_encoder"] = get_encoder( 43 | config.condition_encoder.type 44 | )(**config.condition_encoder.args) 45 | else: 46 | self.model["condition_encoder"] = torch.nn.Identity() 47 | 48 | # TODO 49 | # specific backbone network 50 | self.model["backbone"] = get_module(config.backbone.type)( 51 | **config.backbone.args 52 | ) 53 | 54 | def forward( 55 | self, 56 | x: Union[torch.Tensor, TensorDict], 57 | condition: Union[torch.Tensor, TensorDict] = None, 58 | ) -> torch.Tensor: 59 | """ 60 | Overview: 61 | Return the output of the model at time t given the initial state. 62 | Arguments: 63 | x (:obj:`Union[torch.Tensor, TensorDict]`): The input state. 64 | condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. 65 | """ 66 | 67 | if condition is not None: 68 | x = self.model["x_encoder"](x) 69 | condition = self.model["condition_encoder"](condition) 70 | output = self.model["backbone"](x, condition) 71 | else: 72 | x = self.model["x_encoder"](x) 73 | output = self.model["backbone"](x) 74 | 75 | return output 76 | 77 | 78 | class VariationalAutoencoder(nn.Module): 79 | """ 80 | Overview: 81 | Variational Autoencoder model. 82 | This is an in-development model, which is not used in the current version of the codebase. 
83 | Interfaces: 84 | ``__init__``, ``encode``, ``reparameterize``, ``decode``, ``forward`` 85 | """ 86 | 87 | def __init__(self, config: EasyDict): 88 | super().__init__() 89 | 90 | self.device = config.device 91 | self.input_dim = config.input_dim 92 | self.output_dim = config.output_dim 93 | 94 | # Encoder 95 | self.encoder = IntrinsicModel(config.encoder) 96 | 97 | # Decoder 98 | self.decoder = IntrinsicModel(config.decoder) 99 | 100 | def encode( 101 | self, 102 | x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], 103 | condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, 104 | ): 105 | mu, logvar = self.encoder(x, condition) 106 | return mu, logvar 107 | 108 | def reparameterize( 109 | self, 110 | mu: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], 111 | logvar: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], 112 | ): 113 | std = torch.exp(0.5 * logvar) 114 | eps = torch.randn_like(std) 115 | z = mu + eps * std 116 | return z 117 | 118 | def decode( 119 | self, 120 | z: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], 121 | condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, 122 | ): 123 | x = self.decoder(z, condition) 124 | return x 125 | 126 | def forward( 127 | self, 128 | x: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor], 129 | condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None, 130 | ): 131 | mu, logvar = self.encode(x, condition) 132 | z = self.reparameterize(mu, logvar) 133 | x_recon = self.decode(z, condition) 134 | return x_recon, mu, logvar 135 | -------------------------------------------------------------------------------- /grl/neural_network/activation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class Swish(nn.Module): 6 | """ 7 | Overview: 8 | Swish activation function. 9 | Interfaces: 10 | ``__init__``, ``forward`` 11 | """ 12 | 13 | def __init__(self): 14 | super(Swish, self).__init__() 15 | self.beta = nn.Parameter(torch.tensor(1.0)) 16 | 17 | def forward(self, x): 18 | return x * torch.sigmoid(self.beta * x) 19 | 20 | 21 | class Lambda(nn.Module): 22 | """ 23 | Overview: 24 | Lambda activation function. 
25 | Interfaces: 26 | ``__init__``, ``forward`` 27 | """ 28 | 29 | def __init__(self, f): 30 | super(Lambda, self).__init__() 31 | self.f = f 32 | 33 | def forward(self, x): 34 | return self.f(x) 35 | 36 | 37 | ACTIVATIONS = { 38 | "mish": nn.Mish(), 39 | "tanh": nn.Tanh(), 40 | "relu": nn.ReLU(), 41 | "softplus": nn.Softplus(), 42 | "elu": nn.ELU(), 43 | "silu": nn.SiLU(), 44 | "swish": Swish(), 45 | "square": Lambda(lambda x: x**2), 46 | "identity": Lambda(lambda x: x), 47 | } 48 | 49 | 50 | def get_activation(name: str): 51 | if name not in ACTIVATIONS: 52 | raise ValueError("Unknown activation function {}".format(name)) 53 | return ACTIVATIONS[name] 54 | -------------------------------------------------------------------------------- /grl/neural_network/neural_operator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/GenerativeRL/77ed957d89843e823e5780399959b41f965ebca7/grl/neural_network/neural_operator/__init__.py -------------------------------------------------------------------------------- /grl/neural_network/residual_network.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class MLPResNetBlock(nn.Module): 8 | """ 9 | Overview: 10 | MLPResNet block for MLPResNet. 11 | #TODO: add more details about the block. 12 | Interfaces: 13 | ``__init__``, ``forward``. 14 | """ 15 | 16 | def __init__( 17 | self, 18 | hidden_dim: int, 19 | activations: nn.Module, 20 | dropout_rate: float = None, 21 | use_layer_norm: bool = False, 22 | ): 23 | """ 24 | Overview: 25 | Initialize the MLPResNetBlock according to arguments. 26 | Arguments: 27 | hidden_dim (:obj:`int`): The dimension of the hidden layer. 28 | activations (:obj:`nn.Module`): The activation function. 29 | dropout_rate (:obj:`float`, optional): The dropout rate. Default: None. 30 | use_layer_norm (:obj:`bool`, optional): Whether to use layer normalization. Default: False. 31 | """ 32 | 33 | super(MLPResNetBlock, self).__init__() 34 | self.hidden_dim = hidden_dim 35 | self.activations = activations 36 | self.dropout_rate = dropout_rate 37 | self.use_layer_norm = use_layer_norm 38 | 39 | if self.use_layer_norm: 40 | self.layer_norm = nn.LayerNorm(hidden_dim) 41 | 42 | self.fc1 = nn.Linear(hidden_dim, hidden_dim * 4) 43 | self.fc2 = nn.Linear(hidden_dim * 4, hidden_dim) 44 | self.residual = nn.Linear(hidden_dim, hidden_dim) 45 | self.dropout = ( 46 | nn.Dropout(dropout_rate) 47 | if dropout_rate is not None and dropout_rate > 0.0 48 | else None 49 | ) 50 | 51 | def forward(self, x: torch.Tensor): 52 | """ 53 | Overview: 54 | Return the output tensor of the input tensor. 55 | Arguments: 56 | x (:obj:`torch.Tensor`): Input tensor. 57 | Returns: 58 | x (:obj:`torch.Tensor`): Output tensor. 59 | Shapes: 60 | x (:obj:`torch.Tensor`): :math:`(B, D)`, where B is batch size and D is the dimension of the input tensor. 61 | """ 62 | residual = x 63 | if self.dropout is not None: 64 | x = self.dropout(x) 65 | 66 | if self.use_layer_norm: 67 | x = self.layer_norm(x) 68 | 69 | x = self.fc1(x) 70 | x = self.activations(x) 71 | x = self.fc2(x) 72 | 73 | if residual.shape != x.shape: 74 | residual = self.residual(residual) 75 | 76 | return residual + x 77 | 78 | 79 | class MLPResNet(nn.Module): 80 | """ 81 | Overview: 82 | Residual network build with MLP blocks. 83 | #TODO: add more details about the network. 
84 | Interfaces: 85 | ``__init__``, ``forward``. 86 | """ 87 | 88 | def __init__( 89 | self, 90 | num_blocks: int, 91 | input_dim: int, 92 | output_dim: int, 93 | dropout_rate: float = None, 94 | use_layer_norm: bool = False, 95 | hidden_dim: int = 256, 96 | activations: nn.Module = nn.ReLU(), 97 | ): 98 | """ 99 | Overview: 100 | Initialize the MLPResNet. 101 | #TODO: add more details about the network. 102 | Arguments: 103 | num_blocks (:obj:`int`): The number of blocks. 104 | input_dim (:obj:`int`): The dimension of the input tensor. 105 | output_dim (:obj:`int`): The dimension of the output tensor. 106 | dropout_rate (:obj:`float`, optional): The dropout rate. Default: None. 107 | use_layer_norm (:obj:`bool`, optional): Whether to use layer normalization. Default: False. 108 | hidden_dim (:obj:`int`, optional): The dimension of the hidden layer. Default: 256. 109 | activations (:obj:`nn.Module`, optional): The activation function. Default: nn.ReLU(). 110 | """ 111 | super(MLPResNet, self).__init__() 112 | self.num_blocks = num_blocks 113 | self.out_dim = output_dim 114 | self.dropout_rate = dropout_rate 115 | self.use_layer_norm = use_layer_norm 116 | self.hidden_dim = hidden_dim 117 | self.activations = activations 118 | 119 | self.fc = nn.Linear(input_dim + 128, self.hidden_dim) 120 | 121 | self.blocks = nn.ModuleList( 122 | [ 123 | MLPResNetBlock( 124 | self.hidden_dim, 125 | self.activations, 126 | self.dropout_rate, 127 | self.use_layer_norm, 128 | ) 129 | for _ in range(self.num_blocks) 130 | ] 131 | ) 132 | 133 | self.out_fc = nn.Linear(self.hidden_dim, self.out_dim) 134 | 135 | def forward(self, x: torch.Tensor): 136 | """ 137 | Overview: 138 | Return the output tensor of the input tensor. 139 | Arguments: 140 | x (:obj:`torch.Tensor`): Input tensor. 141 | Returns: 142 | x (:obj:`torch.Tensor`): Output tensor. 143 | Shapes: 144 | x (:obj:`torch.Tensor`): :math:`(B, D)`, where B is batch size and D is the dimension of the input tensor. 
145 | """ 146 | x = self.fc(x) 147 | 148 | for block in self.blocks: 149 | x = block(x) 150 | 151 | x = self.activations(x) 152 | x = self.out_fc(x) 153 | 154 | return x 155 | -------------------------------------------------------------------------------- /grl/neural_network/transformers/__init__.py: -------------------------------------------------------------------------------- 1 | from .maxvit import MaxVit 2 | -------------------------------------------------------------------------------- /grl/neural_network/unet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/GenerativeRL/77ed957d89843e823e5780399959b41f965ebca7/grl/neural_network/unet/__init__.py -------------------------------------------------------------------------------- /grl/numerical_methods/__init__.py: -------------------------------------------------------------------------------- 1 | from .numerical_solvers import DPMSolver, ODESolver, SDESolver 2 | from .ode import ODE 3 | from .probability_path import GaussianConditionalProbabilityPath 4 | from .sde import SDE 5 | -------------------------------------------------------------------------------- /grl/numerical_methods/monte_carlo.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Callable 3 | 4 | import matplotlib.pyplot as plt 5 | import torch 6 | import torch.distributions.uniform as uniform 7 | 8 | 9 | class MonteCarloSampler: 10 | """ 11 | Overview: 12 | A class to sample from an unnormalized PDF using Monte Carlo sampling. 13 | Interface: 14 | ``__init__``, ``sample``, ``plot_samples`` 15 | """ 16 | 17 | def __init__( 18 | self, 19 | unnormalized_pdf: Callable, 20 | x_min: torch.Tensor, 21 | x_max: torch.Tensor, 22 | device: torch.device = torch.device("cpu"), 23 | ): 24 | """ 25 | Overview: 26 | Initialize the Monte Carlo sampler. 27 | Arguments: 28 | - unnormalized_pdf (:obj:`Callable`): The unnormalized PDF function. 29 | - x_min (:obj:`torch.Tensor`): The minimum value of the range. 30 | - x_max (:obj:`torch.Tensor`): The maximum value of the range. 31 | """ 32 | self.unnormalized_pdf = unnormalized_pdf 33 | self.x_min = x_min 34 | self.x_max = x_max 35 | self.device = device 36 | self.uniform_dist = uniform.Uniform( 37 | torch.tensor(self.x_min, device=device), 38 | torch.tensor(self.x_max, device=device), 39 | ) 40 | 41 | def sample(self, num_samples: int): 42 | """ 43 | Overview: 44 | Sample from the unnormalized PDF using Monte Carlo sampling. 
45 | """ 46 | 47 | # if the number of accepted samples is less than the number of samples, sample more 48 | samples = torch.tensor([], device=self.device) 49 | sample_ratio = 1.0 50 | while len(samples) < num_samples: 51 | num_to_sample = math.floor((num_samples - len(samples)) * sample_ratio) 52 | # if num_to_sample is larger than INT_MAX, sample no more than INT_MAX samples 53 | if num_to_sample > 2**24: 54 | num_to_sample = 2**24 55 | samples_ = self._sample(num_to_sample) 56 | sample_ratio = num_to_sample / samples_.shape[0] 57 | samples = torch.cat([samples, samples_]) 58 | 59 | # randomly drop samples to get the exact number of samples 60 | samples = samples[:num_samples] 61 | return samples 62 | 63 | def _sample(self, num_samples: int): 64 | 65 | # Normalize the PDF 66 | # x = torch.linspace(self.x_min, self.x_max, eval_num) 67 | # pdf_values = self.unnormalized_pdf(x) 68 | # normalization_constant = torch.trapz(pdf_values, x) 69 | # normalized_pdf = self.unnormalized_pdf(x) / normalization_constant 70 | 71 | random_samples = self.uniform_dist.sample((num_samples,)) 72 | 73 | # Evaluate PDF values 74 | pdf_samples = self.unnormalized_pdf(random_samples) 75 | 76 | # Normalize PDF values 77 | normalized_pdf_samples = pdf_samples / torch.max(pdf_samples) 78 | 79 | # Accept or reject samples 80 | accepted_samples = random_samples[ 81 | torch.rand(num_samples, device=self.device) < normalized_pdf_samples 82 | ] 83 | 84 | return accepted_samples 85 | 86 | def plot_samples(self, samples, num_bins=50): 87 | plt.figure(figsize=(10, 6)) 88 | plt.hist( 89 | samples.detach().cpu().numpy(), 90 | bins=num_bins, 91 | density=True, 92 | alpha=0.5, 93 | label="Monte Carlo samples", 94 | ) 95 | x = torch.linspace(self.x_min, self.x_max, 1000) 96 | normalized_pdf = self.unnormalized_pdf(x) / torch.trapz( 97 | self.unnormalized_pdf(x), x 98 | ) 99 | plt.plot(x, normalized_pdf, color="red", label="Normalized PDF") 100 | plt.xlabel("x") 101 | plt.ylabel("Probability Density") 102 | plt.title("Sampling from an Unnormalized PDF using Monte Carlo") 103 | plt.legend() 104 | plt.show() 105 | 106 | 107 | if __name__ == "__main__": 108 | # Define the unnormalized PDF function 109 | def unnormalized_pdf(x): 110 | return torch.exp(-0.5 * (x - 0.5) ** 2) + 0.5 * torch.sin(2 * torch.pi * x) 111 | 112 | # Define the range [0, 1] 113 | x_min = 0.0 114 | x_max = 1.0 115 | 116 | # Initialize the Monte Carlo sampler 117 | monte_carlo_sampler = MonteCarloSampler(unnormalized_pdf, x_min, x_max) 118 | 119 | # Sample from the unnormalized PDF 120 | num_samples = 10000 121 | samples = monte_carlo_sampler.sample(num_samples) 122 | assert len(samples) == num_samples 123 | 124 | # Plot the samples 125 | monte_carlo_sampler.plot_samples(samples) 126 | -------------------------------------------------------------------------------- /grl/numerical_methods/numerical_solvers/__init__.py: -------------------------------------------------------------------------------- 1 | from .dpm_solver import DPMSolver 2 | from .ode_solver import DictTensorODESolver, ODESolver 3 | from .sde_solver import SDESolver 4 | 5 | 6 | def get_solver(solver_type): 7 | if solver_type.lower() in SOLVERS: 8 | return SOLVERS[solver_type.lower()] 9 | else: 10 | raise ValueError(f"Solver type {solver_type} not recognized") 11 | 12 | 13 | SOLVERS = { 14 | "DPMSolver".lower(): DPMSolver, 15 | "ODESolver".lower(): ODESolver, 16 | "DictTensorODESolver".lower(): DictTensorODESolver, 17 | "SDESolver".lower(): SDESolver, 18 | } 19 | 
-------------------------------------------------------------------------------- /grl/numerical_methods/numerical_solvers/sde_solver.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple, Union 2 | 3 | import torch 4 | import torchsde 5 | from torch import nn 6 | 7 | 8 | class TorchSDE(nn.Module): 9 | """ 10 | Overview: 11 | The SDE wrapper class for the torchsde library, which is an object with methods `f` and `g` representing the drift and diffusion. 12 | The output of `g` should be a single tensor of size (batch_size, d) for diagonal noise SDEs or (batch_size, d, m) for SDEs of other noise types, 13 | where d is the dimensionality of state and m is the dimensionality of Brownian motion. 14 | """ 15 | 16 | def __init__( 17 | self, 18 | drift, 19 | diffusion, 20 | noise_type, 21 | sde_type, 22 | ): 23 | """ 24 | Overview: 25 | Initialize the SDE object. 26 | Arguments: 27 | drift (:obj:`nn.Module`): The function that defines the drift of the SDE. 28 | diffusion (:obj:`nn.Module`): The function that defines the diffusion of the SDE. 29 | noise_type (:obj:`str`): The type of noise of the SDE. It can be 'diagonal', 'general', 'scalar' or 'additive'. 30 | sde_type (:obj:`str`): The type of the SDE. It can be 'ito' or 'stratonovich'. 31 | """ 32 | super().__init__() 33 | self.drift = drift 34 | self.diffusion = diffusion 35 | 36 | self.noise_type = noise_type 37 | self.sde_type = sde_type 38 | 39 | def f(self, t, y): 40 | """ 41 | Overview: 42 | The drift function of the SDE. 43 | """ 44 | return self.drift(t, y) 45 | 46 | def g(self, t, y): 47 | """ 48 | Overview: 49 | The diffusion function of the SDE. 50 | """ 51 | return self.diffusion(t, y) 52 | 53 | 54 | class SDESolver: 55 | 56 | def __init__( 57 | self, 58 | sde_solver="euler", 59 | sde_noise_type="diagonal", 60 | sde_type="ito", 61 | dt=0.001, 62 | atol=1e-5, 63 | rtol=1e-5, 64 | library="torchsde", 65 | **kwargs, 66 | ): 67 | """ 68 | Overview: 69 | Initialize the SDE solver using torchsde library. 70 | Arguments: 71 | sde_solver (:obj:`str`): The SDE solver to use. 72 | sde_noise_type (:obj:`str`): The type of noise of the SDE. It can be 'diagonal', 'general', 'scalar' or 'additive'. 73 | sde_type (:obj:`str`): The type of the SDE. It can be 'ito' or 'stratonovich'. 74 | dt (:obj:`float`): The time step. 75 | atol (:obj:`float`): The absolute tolerance. 76 | rtol (:obj:`float`): The relative tolerance. 77 | library (:obj:`str`): The library to use for the SDE solver. Currently, it supports 'torchsde'. 78 | **kwargs: Additional arguments for the SDE solver. 79 | """ 80 | super().__init__() 81 | self.sde_solver = sde_solver 82 | self.sde_noise_type = sde_noise_type 83 | self.sde_type = sde_type 84 | self.dt = dt 85 | self.atol = atol 86 | self.rtol = rtol 87 | self.nfe_drift = 0 88 | self.nfe_diffusion = 0 89 | self.kwargs = kwargs 90 | self.library = library 91 | 92 | def integrate(self, drift, diffusion, x0, t_span, logqp=False, adaptive=False): 93 | """ 94 | Overview: 95 | Integrate the SDE. 96 | Arguments: 97 | drift (:obj:`nn.Module`): The drift function of the SDE. 98 | diffusion (:obj:`nn.Module`): The diffusion function of the SDE. 99 | x0 (:obj:`torch.Tensor`): The initial state. t_span (:obj:`torch.Tensor`): The time points at which the solution is returned. 100 | """ 101 | 102 | batch_size = x0.shape[0] 103 | data_size = x0.shape[1:] 104 | 105 | self.nfe_drift = 0 106 | self.nfe_diffusion = 0 107 | 108 | def forward_drift(t, x): 109 | self.nfe_drift += 1 110 | x = x.reshape(batch_size, *data_size) 111 | f = drift(t, x) 112 | return f.reshape(batch_size, -1) 113 | 114 | def forward_diffusion(t, x): 115 | self.nfe_diffusion += 1 116 | x = x.reshape(batch_size, *data_size) 117 | g = diffusion(t, x) 118 | return g.reshape(batch_size, -1) 119 | 120 | sde = TorchSDE( 121 | drift=forward_drift, 122 | diffusion=forward_diffusion, 123 | noise_type=self.sde_noise_type, 124 | sde_type=self.sde_type, 125 | ) 126 | 127 | x0 = x0.reshape(batch_size, -1) 128 | 129 | trajectory = torchsde.sdeint( 130 | sde, 131 | x0, 132 | t_span, 133 | method=self.sde_solver, 134 | dt=self.dt, 135 | rtol=self.rtol, 136 | atol=self.atol, 137 | logqp=logqp, 138 | adaptive=adaptive, 139 | **self.kwargs, 140 | ) 141 | 142 | trajectory = trajectory.reshape(t_span.shape[0], batch_size, *data_size) 143 | 144 | return trajectory 145 | -------------------------------------------------------------------------------- /grl/numerical_methods/ode.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Union 2 | 3 | from torch import nn 4 | 5 | 6 | class ODE: 7 | """ 8 | Overview: 9 | Base class for ordinary differential equations. 10 | The ODE is defined as: 11 | 12 | .. math:: 13 | dx = f(x, t)dt 14 | 15 | where f(x, t) is the drift term. 16 | 17 | Interfaces: 18 | ``__init__`` 19 | """ 20 | 21 | def __init__( 22 | self, 23 | drift: Union[nn.Module, Callable] = None, 24 | ): 25 | self.drift = drift 26 | -------------------------------------------------------------------------------- /grl/numerical_methods/sde.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Union 2 | 3 | from torch import nn 4 | 5 | 6 | class SDE: 7 | """ 8 | Overview: 9 | Base class for stochastic differential equations. 10 | The SDE is defined as: 11 | 12 | .. math:: 13 | dx = f(x, t)dt + g(x, t)dW 14 | 15 | where f(x, t) is the drift term, g(x, t) is the diffusion term, and dW is the Wiener process.
16 | 17 | Interfaces: 18 | ``__init__`` 19 | """ 20 | 21 | def __init__( 22 | self, 23 | drift: Union[nn.Module, Callable] = None, 24 | diffusion: Union[nn.Module, Callable] = None, 25 | ): 26 | self.drift = drift 27 | self.diffusion = diffusion 28 | -------------------------------------------------------------------------------- /grl/rl_modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .simulators import GymEnvSimulator 2 | from .value_network import ( 3 | DoubleQNetwork, 4 | DoubleVNetwork, 5 | OneShotValueFunction, 6 | QNetwork, 7 | VNetwork, 8 | ) 9 | -------------------------------------------------------------------------------- /grl/rl_modules/policy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/GenerativeRL/77ed957d89843e823e5780399959b41f965ebca7/grl/rl_modules/policy/__init__.py -------------------------------------------------------------------------------- /grl/rl_modules/policy/base.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/GenerativeRL/77ed957d89843e823e5780399959b41f965ebca7/grl/rl_modules/policy/base.py -------------------------------------------------------------------------------- /grl/rl_modules/replay_buffer/__init__.py: -------------------------------------------------------------------------------- 1 | from .buffer_by_torchrl import GeneralListBuffer, TensorDictBuffer 2 | -------------------------------------------------------------------------------- /grl/rl_modules/simulators/__init__.py: -------------------------------------------------------------------------------- 1 | from .gym_env_simulator import GymEnvSimulator 2 | from .dm_control_env_simulator import ( 3 | DeepMindControlEnvSimulator, 4 | DeepMindControlVisualEnvSimulator, 5 | DeepMindControlVisualEnvSimulator2, 6 | ) 7 | 8 | 9 | def get_simulator(type: str): 10 | if type.lower() not in SIMULATORS: 11 | raise KeyError(f"Invalid simulator type: {type}") 12 | return SIMULATORS[type.lower()] 13 | 14 | 15 | def create_simulator(config): 16 | return get_simulator(config.type)(**config.args) 17 | 18 | 19 | SIMULATORS = { 20 | "GymEnvSimulator".lower(): GymEnvSimulator, 21 | "DeepMindControlEnvSimulator".lower(): DeepMindControlEnvSimulator, 22 | "DeepMindControlVisualEnvSimulator".lower(): DeepMindControlVisualEnvSimulator, 23 | "DeepMindControlVisualEnvSimulator2".lower(): DeepMindControlVisualEnvSimulator2, 24 | } 25 | -------------------------------------------------------------------------------- /grl/rl_modules/simulators/base.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Callable, Dict, List, Tuple, Union 2 | 3 | import torch 4 | from easydict import EasyDict 5 | 6 | 7 | class BaseSimulator: 8 | """ 9 | Overview: 10 | A base class for environment simulator in GenerativeRL. 11 | This class is used to define the interface of environment simulator in GenerativeRL. 12 | Interfaces: 13 | ``__init__``, ``collect_episodes``, ``collect_steps``, ``evaluate`` 14 | """ 15 | 16 | def __init__(self, *args, **kwargs) -> None: 17 | """ 18 | Overview: 19 | Initialize the environment simulator.
20 | """ 21 | pass 22 | 23 | def collect_episodes( 24 | self, 25 | policy: Union[Callable, torch.nn.Module], 26 | num_episodes: int = None, 27 | num_steps: int = None, 28 | ) -> List[Dict]: 29 | """ 30 | Overview: 31 | Collect several episodes using the given policy. The environment will be reset at the beginning of each episode. 32 | No history will be stored in this method. The collected information of steps will be returned as a list of dictionaries. 33 | Arguments: 34 | policy (:obj:`Union[Callable, torch.nn.Module]`): The policy to collect episodes. 35 | num_episodes (:obj:`int`): The number of episodes to collect. 36 | num_steps (:obj:`int`): The number of steps to collect. 37 | """ 38 | 39 | pass 40 | 41 | def collect_steps( 42 | self, 43 | policy: Union[Callable, torch.nn.Module], 44 | num_episodes: int = None, 45 | num_steps: int = None, 46 | ) -> List[Dict]: 47 | """ 48 | Overview: 49 | Collect several steps using the given policy. The environment will not be reset until the end of the episode. 50 | Last observation will be stored in this method. The collected information of steps will be returned as a list of dictionaries. 51 | Arguments: 52 | policy (:obj:`Union[Callable, torch.nn.Module]`): The policy to collect steps. 53 | num_episodes (:obj:`int`): The number of episodes to collect. 54 | num_steps (:obj:`int`): The number of steps to collect. 55 | """ 56 | pass 57 | 58 | def evaluate( 59 | self, 60 | policy: Union[Callable, torch.nn.Module], 61 | num_episodes: int = None, 62 | ) -> List[Dict]: 63 | """ 64 | Overview: 65 | Evaluate the given policy using the environment. The environment will be reset at the beginning of each episode. 66 | No history will be stored in this method. The evaluation resultswill be returned as a list of dictionaries. 67 | """ 68 | pass 69 | 70 | 71 | class BaseEnv: 72 | """ 73 | Overview: 74 | A base class for environment in GenerativeRL. 75 | This class is used to define the interface of environment in GenerativeRL. 76 | Interfaces: 77 | ``__init__``, ``reset``, ``step``, ``render``, ``close`` 78 | """ 79 | 80 | def __init__(self, *args, **kwargs) -> None: 81 | """ 82 | Overview: 83 | Initialize the environment. 84 | """ 85 | pass 86 | 87 | def reset(self) -> Any: 88 | """ 89 | Overview: 90 | Reset the environment and return the initial observation. 91 | """ 92 | pass 93 | 94 | def step(self, action: Any) -> Any: 95 | """ 96 | Overview: 97 | Take an action in the environment and return the next observation, reward, done, and information. 98 | """ 99 | pass 100 | 101 | def render(self) -> None: 102 | """ 103 | Overview: 104 | Render the environment. 105 | """ 106 | pass 107 | 108 | def close(self) -> None: 109 | """ 110 | Overview: 111 | Close the environment. 
112 | """ 113 | pass 114 | -------------------------------------------------------------------------------- /grl/rl_modules/value_network/__init__.py: -------------------------------------------------------------------------------- 1 | from .one_shot_value_function import OneShotValueFunction 2 | from .q_network import DoubleQNetwork, QNetwork 3 | from .value_network import DoubleVNetwork, VNetwork 4 | -------------------------------------------------------------------------------- /grl/rl_modules/value_network/one_shot_value_function.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from easydict import EasyDict 7 | from tensordict import TensorDict 8 | 9 | from grl.rl_modules.value_network.value_network import DoubleVNetwork 10 | 11 | 12 | class OneShotValueFunction(nn.Module): 13 | """ 14 | Overview: 15 | Value network for one-shot cases, which means that no Bellman backup is needed for training. 16 | Interfaces: 17 | ``__init__``, ``forward`` 18 | """ 19 | 20 | def __init__(self, config: EasyDict): 21 | """ 22 | Overview: 23 | Initialization of one-shot value network. 24 | Arguments: 25 | config (:obj:`EasyDict`): The configuration dict. 26 | """ 27 | 28 | super().__init__() 29 | self.config = config 30 | self.v_alpha = config.v_alpha 31 | self.v = DoubleVNetwork(config.DoubleVNetwork) 32 | self.v_target = copy.deepcopy(self.v).requires_grad_(False) 33 | 34 | def forward( 35 | self, 36 | state: Union[torch.Tensor, TensorDict], 37 | condition: Union[torch.Tensor, TensorDict] = None, 38 | ) -> torch.Tensor: 39 | """ 40 | Overview: 41 | Return the output of one-shot value network. 42 | Arguments: 43 | state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. 44 | condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. 45 | """ 46 | 47 | return self.v(state, condition) 48 | 49 | def compute_double_v( 50 | self, 51 | state: Union[torch.Tensor, TensorDict], 52 | condition: Union[torch.Tensor, TensorDict] = None, 53 | ) -> Tuple[torch.Tensor, torch.Tensor]: 54 | """ 55 | Overview: 56 | Return the output of two value networks. 57 | Arguments: 58 | state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. 59 | condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. 60 | Returns: 61 | v1 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the first value network. 62 | v2 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the second value network. 63 | """ 64 | return self.v.compute_double_v(state, condition=condition) 65 | 66 | def v_loss( 67 | self, 68 | state: Union[torch.Tensor, TensorDict], 69 | value: Union[torch.Tensor, TensorDict], 70 | condition: Union[torch.Tensor, TensorDict] = None, 71 | ) -> torch.Tensor: 72 | """ 73 | Overview: 74 | Calculate the v loss. 75 | Arguments: 76 | state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. 77 | value (:obj:`Union[torch.Tensor, TensorDict]`): The input value. 78 | condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. 79 | Returns: 80 | v_loss (:obj:`torch.Tensor`): The v loss. 
81 | """ 82 | 83 | # Update value function 84 | targets = value 85 | v0, v1 = self.v.compute_double_v(state, condition=condition) 86 | v_loss = ( 87 | torch.nn.functional.mse_loss(v0, targets) 88 | + torch.nn.functional.mse_loss(v1, targets) 89 | ) / 2 90 | return v_loss 91 | -------------------------------------------------------------------------------- /grl/rl_modules/value_network/q_network.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from easydict import EasyDict 6 | from tensordict import TensorDict 7 | 8 | from grl.neural_network import get_module 9 | from grl.neural_network.encoders import get_encoder 10 | 11 | 12 | class QNetwork(nn.Module): 13 | """ 14 | Overview: 15 | Q network, which is used to approximate the Q value. 16 | Interfaces: 17 | ``__init__``, ``forward`` 18 | """ 19 | 20 | def __init__(self, config: EasyDict): 21 | """ 22 | Overview: 23 | Initialization of Q network. 24 | Arguments: 25 | config (:obj:`EasyDict`): The configuration dict. 26 | """ 27 | super().__init__() 28 | self.config = config 29 | self.model = torch.nn.ModuleDict() 30 | if hasattr(config, "action_encoder"): 31 | self.model["action_encoder"] = get_encoder(config.action_encoder.type)( 32 | **config.action_encoder.args 33 | ) 34 | else: 35 | self.model["action_encoder"] = torch.nn.Identity() 36 | if hasattr(config, "state_encoder"): 37 | self.model["state_encoder"] = get_encoder(config.state_encoder.type)( 38 | **config.state_encoder.args 39 | ) 40 | else: 41 | self.model["state_encoder"] = torch.nn.Identity() 42 | # TODO 43 | # specific backbone network 44 | self.model["backbone"] = get_module(config.backbone.type)( 45 | **config.backbone.args 46 | ) 47 | 48 | def forward( 49 | self, 50 | action: Union[torch.Tensor, TensorDict], 51 | state: Union[torch.Tensor, TensorDict], 52 | ) -> torch.Tensor: 53 | """ 54 | Overview: 55 | Return output of Q networks. 56 | Arguments: 57 | action (:obj:`Union[torch.Tensor, TensorDict]`): The input action. 58 | state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. 59 | Returns: 60 | q (:obj:`Union[torch.Tensor, TensorDict]`): The output of Q network. 61 | """ 62 | action_embedding = self.model["action_encoder"](action) 63 | state_embedding = self.model["state_encoder"](state) 64 | return self.model["backbone"](action_embedding, state_embedding) 65 | 66 | 67 | class DoubleQNetwork(nn.Module): 68 | """ 69 | Overview: 70 | Double Q network, which has two Q networks. 71 | Interfaces: 72 | ``__init__``, ``forward``, ``compute_double_q``, ``compute_mininum_q`` 73 | """ 74 | 75 | def __init__(self, config: EasyDict): 76 | super().__init__() 77 | 78 | self.model = torch.nn.ModuleDict() 79 | self.model["q1"] = QNetwork(config) 80 | self.model["q2"] = QNetwork(config) 81 | 82 | def compute_double_q( 83 | self, 84 | action: Union[torch.Tensor, TensorDict], 85 | state: Union[torch.Tensor, TensorDict], 86 | ) -> Tuple[torch.Tensor, torch.Tensor]: 87 | """ 88 | Overview: 89 | Return the output of two Q networks. 90 | Arguments: 91 | action (:obj:`Union[torch.Tensor, TensorDict]`): The input action. 92 | state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. 93 | Returns: 94 | q1 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the first Q network. 95 | q2 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the second Q network. 
96 | """ 97 | 98 | return self.model["q1"](action, state), self.model["q2"](action, state) 99 | 100 | def compute_mininum_q( 101 | self, 102 | action: Union[torch.Tensor, TensorDict], 103 | state: Union[torch.Tensor, TensorDict], 104 | ) -> torch.Tensor: 105 | """ 106 | Overview: 107 | Return the minimum output of two Q networks. 108 | Arguments: 109 | action (:obj:`Union[torch.Tensor, TensorDict]`): The input action. 110 | state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. 111 | Returns: 112 | minimum_q (:obj:`Union[torch.Tensor, TensorDict]`): The minimum output of Q network. 113 | """ 114 | 115 | return torch.min(*self.compute_double_q(action, state)) 116 | 117 | def forward( 118 | self, 119 | action: Union[torch.Tensor, TensorDict], 120 | state: Union[torch.Tensor, TensorDict], 121 | ) -> torch.Tensor: 122 | """ 123 | Overview: 124 | Return the minimum output of two Q networks. 125 | Arguments: 126 | action (:obj:`Union[torch.Tensor, TensorDict]`): The input action. 127 | state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. 128 | Returns: 129 | minimum_q (:obj:`Union[torch.Tensor, TensorDict]`): The minimum output of Q network. 130 | """ 131 | 132 | return self.compute_mininum_q(action, state) 133 | -------------------------------------------------------------------------------- /grl/rl_modules/value_network/value_network.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | import torch 4 | import torch.nn as nn 5 | from easydict import EasyDict 6 | from tensordict import TensorDict 7 | 8 | from grl.neural_network import get_module 9 | from grl.neural_network.encoders import get_encoder 10 | 11 | 12 | class VNetwork(nn.Module): 13 | """ 14 | Overview: 15 | Value network, which is used to approximate the value function. 16 | Interfaces: 17 | ``__init__``, ``forward`` 18 | """ 19 | 20 | def __init__(self, config: EasyDict): 21 | """ 22 | Overview: 23 | Initialization of value network. 24 | Arguments: 25 | config (:obj:`EasyDict`): The configuration dict. 26 | """ 27 | super().__init__() 28 | self.config = config 29 | self.model = torch.nn.ModuleDict() 30 | if hasattr(config, "state_encoder"): 31 | self.model["state_encoder"] = get_encoder(config.state_encoder.type)( 32 | **config.state_encoder.args 33 | ) 34 | else: 35 | self.model["state_encoder"] = torch.nn.Identity() 36 | if hasattr(config, "condition_encoder"): 37 | self.model["condition_encoder"] = get_encoder( 38 | config.condition_encoder.type 39 | )(**config.condition_encoder.args) 40 | else: 41 | self.model["condition_encoder"] = torch.nn.Identity() 42 | # TODO 43 | # specific backbone network 44 | self.model["backbone"] = get_module(config.backbone.type)( 45 | **config.backbone.args 46 | ) 47 | 48 | def forward( 49 | self, 50 | state: Union[torch.Tensor, TensorDict], 51 | condition: Union[torch.Tensor, TensorDict] = None, 52 | ) -> torch.Tensor: 53 | """ 54 | Overview: 55 | Return output of value networks. 56 | Arguments: 57 | state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. 58 | condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. 59 | Returns: 60 | value (:obj:`Union[torch.Tensor, TensorDict]`): The output of value network. 
61 | """ 62 | 63 | state_embedding = self.model["state_encoder"](state) 64 | if condition is not None: 65 | condition_encoder_embedding = self.model["condition_encoder"](condition) 66 | return self.model["backbone"](state_embedding, condition_encoder_embedding) 67 | else: 68 | return self.model["backbone"](state_embedding) 69 | 70 | 71 | class DoubleVNetwork(nn.Module): 72 | """ 73 | Overview: 74 | Double value network, which has two value networks. 75 | Interfaces: 76 | ``__init__``, ``forward``, ``compute_double_v``, ``compute_mininum_v`` 77 | """ 78 | 79 | def __init__(self, config: EasyDict): 80 | super().__init__() 81 | 82 | self.model = torch.nn.ModuleDict() 83 | self.model["v1"] = VNetwork(config) 84 | self.model["v2"] = VNetwork(config) 85 | 86 | def compute_double_v( 87 | self, 88 | state: Union[torch.Tensor, TensorDict], 89 | condition: Union[torch.Tensor, TensorDict], 90 | ) -> Tuple[torch.Tensor, torch.Tensor]: 91 | """ 92 | Overview: 93 | Return the output of two value networks. 94 | Arguments: 95 | state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. 96 | condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. 97 | Returns: 98 | v1 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the first value network. 99 | v2 (:obj:`Union[torch.Tensor, TensorDict]`): The output of the second value network. 100 | """ 101 | 102 | return self.model["v1"](state, condition), self.model["v2"](state, condition) 103 | 104 | def compute_mininum_v( 105 | self, 106 | state: Union[torch.Tensor, TensorDict], 107 | condition: Union[torch.Tensor, TensorDict], 108 | ) -> torch.Tensor: 109 | """ 110 | Overview: 111 | Return the minimum output of two value networks. 112 | Arguments: 113 | state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. 114 | condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. 115 | Returns: 116 | minimum_v (:obj:`Union[torch.Tensor, TensorDict]`): The minimum output of value network. 117 | """ 118 | 119 | return torch.min(*self.compute_double_v(state, condition=condition)) 120 | 121 | def forward( 122 | self, 123 | state: Union[torch.Tensor, TensorDict], 124 | condition: Union[torch.Tensor, TensorDict], 125 | ) -> torch.Tensor: 126 | """ 127 | Overview: 128 | Return the minimum output of two value networks. 129 | Arguments: 130 | state (:obj:`Union[torch.Tensor, TensorDict]`): The input state. 131 | condition (:obj:`Union[torch.Tensor, TensorDict]`): The input condition. 132 | Returns: 133 | minimum_v (:obj:`Union[torch.Tensor, TensorDict]`): The minimum output of value network. 134 | """ 135 | 136 | return self.compute_mininum_v(state, condition=condition) 137 | -------------------------------------------------------------------------------- /grl/rl_modules/world_model/dynamic_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from easydict import EasyDict 4 | from torch import nn 5 | 6 | from grl.generative_models import get_generative_model 7 | 8 | 9 | class DynamicModel(nn.Module): 10 | """ 11 | Overview: 12 | General dynamic model. 13 | Interfaces: 14 | ``__init__``, ``forward`` 15 | """ 16 | 17 | def __init__( 18 | self, 19 | config: EasyDict, 20 | ): 21 | """ 22 | Overview: 23 | Initialize the world model. 24 | Arguments: 25 | - config (:obj:`EasyDict`): The configuration. 
26 | """ 27 | super().__init__() 28 | 29 | self.config = config 30 | self.model = get_generative_model(config.model_type)(config.model_config) 31 | 32 | def forward( 33 | self, 34 | condition: torch.Tensor, 35 | ) -> torch.Tensor: 36 | """ 37 | Overview: 38 | Return the next state given the current condition. 39 | Condition usually is a combination of action and state at the current time step or in the past. 40 | Arguments: 41 | - condition (:obj:`torch.Tensor`): The condition. 42 | """ 43 | 44 | return self.model.sample(condition=condition) 45 | 46 | def sample( 47 | self, 48 | condition: torch.Tensor, 49 | ) -> torch.Tensor: 50 | """ 51 | Overview: 52 | Return the next state given the current condition. 53 | Condition usually is a combination of action and state at the current time step or in the past. 54 | Arguments: 55 | - state (:obj:`torch.Tensor`): The current state. 56 | - condition (:obj:`torch.Tensor`): The condition. 57 | """ 58 | 59 | return self.model.sample(condition=condition) 60 | 61 | def log_prob( 62 | self, 63 | next_state: torch.Tensor, 64 | condition: torch.Tensor, 65 | ) -> torch.Tensor: 66 | """ 67 | Overview: 68 | Return the log probability of the next state given current condition. 69 | Condition usually is a combination of action and state at the current time step or in the past. 70 | Arguments: 71 | - next_state (:obj:`torch.Tensor`): The next state. 72 | - condition (:obj:`torch.Tensor`): The condition. 73 | """ 74 | 75 | return self.model.log_prob(x=next_state, condition=condition) 76 | 77 | def dynamic_loss( 78 | self, 79 | next_state: torch.Tensor, 80 | condition: torch.Tensor, 81 | ) -> torch.Tensor: 82 | """ 83 | Overview: 84 | Return the dynamic loss of the next state given current condition. 85 | Condition usually is a combination of action and state at the current time step or in the past. 86 | Arguments: 87 | - next_state (:obj:`torch.Tensor`): The next state. 88 | - condition (:obj:`torch.Tensor`): The condition. 89 | """ 90 | 91 | if self.config.loss_type == "score_matching": 92 | return self.model.score_matching_loss(x=next_state, condition=condition) 93 | elif self.config.loss_type == "flow_matching": 94 | return self.model.flow_matching_loss(x=next_state, condition=condition) 95 | else: 96 | raise ValueError("Invalid loss type: {}".format(self.config.loss_type)) 97 | -------------------------------------------------------------------------------- /grl/rl_modules/world_model/state_prior_dynamic_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from easydict import EasyDict 4 | from torch import nn 5 | 6 | from grl.generative_models import get_generative_model 7 | 8 | 9 | class StatePriorDynamicModel(nn.Module): 10 | """ 11 | Overview: 12 | Dynamic model that use state as sampling prior. 13 | Interfaces: 14 | ``__init__``, ``forward`` 15 | """ 16 | 17 | def __init__( 18 | self, 19 | config: EasyDict, 20 | ): 21 | """ 22 | Overview: 23 | Initialize the world model. 24 | Arguments: 25 | - config (:obj:`EasyDict`): The configuration. 26 | """ 27 | super().__init__() 28 | 29 | self.config = config 30 | self.model = get_generative_model(config.model_type)(config.model_config) 31 | 32 | def forward( 33 | self, 34 | state: torch.Tensor, 35 | condition: torch.Tensor, 36 | ) -> torch.Tensor: 37 | """ 38 | Overview: 39 | Return the next state given the current state and current condition. 
40 | The condition is usually the action at the current time step, or a combination of actions and states from the past. 41 | Arguments: 42 | - state (:obj:`torch.Tensor`): The current state. 43 | - condition (:obj:`torch.Tensor`): The condition. 44 | """ 45 | 46 | return self.model.sample(x0=state, condition=condition) 47 | 48 | def sample( 49 | self, 50 | state: torch.Tensor, 51 | condition: torch.Tensor, 52 | ) -> torch.Tensor: 53 | """ 54 | Overview: 55 | Return the next state given the current state and current condition. 56 | The condition is usually the action at the current time step, or a combination of actions and states from the past. 57 | Arguments: 58 | - state (:obj:`torch.Tensor`): The current state. 59 | - condition (:obj:`torch.Tensor`): The condition. 60 | """ 61 | 62 | return self.model.sample(x0=state, condition=condition) 63 | 64 | def log_prob( 65 | self, 66 | state: torch.Tensor, 67 | next_state: torch.Tensor, 68 | condition: torch.Tensor, 69 | ) -> torch.Tensor: 70 | """ 71 | Overview: 72 | Return the log probability of the next state given the current state and current condition. 73 | The condition is usually the action at the current time step, or a combination of actions and states from the past. 74 | Arguments: 75 | - state (:obj:`torch.Tensor`): The current state. 76 | - next_state (:obj:`torch.Tensor`): The next state. 77 | - condition (:obj:`torch.Tensor`): The condition. 78 | """ 79 | 80 | return self.model.log_prob(x0=state, x1=next_state, condition=condition) 81 | -------------------------------------------------------------------------------- /grl/unittest/agents/functions.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import numpy as np 3 | import torch 4 | from grl.agents import obs_transform, action_transform 5 | 6 | # Unit tests for the observation and action transforms imported above.
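# Behavior exercised below (as inferred from these tests, not a full spec):
# obs_transform moves numpy arrays, dicts of arrays, or tensors onto the given
# device as float tensors; action_transform returns CPU tensors or numpy arrays
# depending on `return_as_torch_tensor`; both raise ValueError on unsupported inputs.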
7 | 8 | 9 | class TestTransforms(unittest.TestCase): 10 | 11 | def setUp(self): 12 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 13 | 14 | def test_obs_transform_numpy(self): 15 | obs = np.array([1, 2, 3], dtype=np.float32) 16 | transformed = obs_transform(obs, self.device) 17 | self.assertIsInstance(transformed, torch.Tensor) 18 | self.assertTrue(transformed.is_floating_point()) 19 | self.assertEqual(transformed.device, self.device) 20 | np.testing.assert_array_equal(transformed.cpu().numpy(), obs) 21 | 22 | def test_obs_transform_dict(self): 23 | obs = { 24 | "a": np.array([1, 2, 3], dtype=np.float32), 25 | "b": np.array([4, 5, 6], dtype=np.float32), 26 | } 27 | transformed = obs_transform(obs, self.device) 28 | self.assertIsInstance(transformed, dict) 29 | for k, v in transformed.items(): 30 | self.assertIsInstance(v, torch.Tensor) 31 | self.assertTrue(v.is_floating_point()) 32 | self.assertEqual(v.device, self.device) 33 | np.testing.assert_array_equal(v.cpu().numpy(), obs[k]) 34 | 35 | def test_obs_transform_tensor(self): 36 | obs = torch.tensor([1, 2, 3], dtype=torch.float32) 37 | transformed = obs_transform(obs, self.device) 38 | self.assertIsInstance(transformed, torch.Tensor) 39 | self.assertTrue(transformed.is_floating_point()) 40 | self.assertEqual(transformed.device, self.device) 41 | self.assertTrue(torch.equal(transformed.cpu(), obs)) 42 | 43 | def test_obs_transform_invalid(self): 44 | obs = [1, 2, 3] 45 | with self.assertRaises(ValueError): 46 | obs_transform(obs, self.device) 47 | 48 | def test_action_transform_dict(self): 49 | action = { 50 | "a": torch.tensor([1, 2, 3], dtype=torch.float32), 51 | "b": torch.tensor([4, 5, 6], dtype=torch.float32), 52 | } 53 | transformed = action_transform(action, return_as_torch_tensor=True) 54 | self.assertIsInstance(transformed, dict) 55 | for k, v in transformed.items(): 56 | self.assertIsInstance(v, torch.Tensor) 57 | self.assertFalse(v.is_cuda) 58 | self.assertTrue(torch.equal(v, action[k].cpu())) 59 | 60 | transformed = action_transform(action, return_as_torch_tensor=False) 61 | self.assertIsInstance(transformed, dict) 62 | for k, v in transformed.items(): 63 | self.assertIsInstance(v, np.ndarray) 64 | np.testing.assert_array_equal(v, action[k].cpu().numpy()) 65 | 66 | def test_action_transform_tensor(self): 67 | action = torch.tensor([1, 2, 3], dtype=torch.float32).to(self.device) 68 | transformed = action_transform(action, return_as_torch_tensor=True) 69 | self.assertIsInstance(transformed, torch.Tensor) 70 | self.assertFalse(transformed.is_cuda) 71 | self.assertTrue(torch.equal(transformed, action.cpu())) 72 | 73 | transformed = action_transform(action, return_as_torch_tensor=False) 74 | self.assertIsInstance(transformed, np.ndarray) 75 | np.testing.assert_array_equal(transformed, action.cpu().numpy()) 76 | 77 | def test_action_transform_numpy(self): 78 | action = np.array([1, 2, 3], dtype=np.float32) 79 | transformed = action_transform(action) 80 | self.assertIsInstance(transformed, np.ndarray) 81 | np.testing.assert_array_equal(transformed, action) 82 | 83 | def test_action_transform_invalid(self): 84 | action = [1, 2, 3] 85 | with self.assertRaises(ValueError): 86 | action_transform(action) 87 | 88 | 89 | if __name__ == "__main__": 90 | unittest.main() 91 | -------------------------------------------------------------------------------- /grl/unittest/neural_network/test_activation.py: -------------------------------------------------------------------------------- 1 | # Test 
grl/neural_network/activation.py 2 | 3 | 4 | def test_activation(): 5 | import torch 6 | from torch import nn 7 | 8 | from grl.neural_network.activation import Swish, get_activation 9 | 10 | assert isinstance(get_activation("mish"), nn.Mish) 11 | assert isinstance(get_activation("tanh"), nn.Tanh) 12 | assert isinstance(get_activation("relu"), nn.ReLU) 13 | assert isinstance(get_activation("softplus"), nn.Softplus) 14 | assert isinstance(get_activation("elu"), nn.ELU) 15 | assert isinstance(get_activation("silu"), nn.SiLU) 16 | assert isinstance(get_activation("swish"), Swish) 17 | assert get_activation("square")(10) == 100 18 | assert get_activation("identity")(100) == 100 19 | 20 | try: 21 | get_activation("unknown") 22 | assert False, "get_activation should raise ValueError for unknown names" 23 | except ValueError as e: 24 | assert str(e) == "Unknown activation function unknown" 25 | -------------------------------------------------------------------------------- /grl/unittest/neural_network/test_encoder.py: -------------------------------------------------------------------------------- 1 | # Test grl/neural_network/encoders.py 2 | 3 | 4 | def test_encoder(): 5 | import torch 6 | from torch import nn 7 | 8 | from grl.neural_network.encoders import ( 9 | ExponentialFourierProjectionTimeEncoder, 10 | GaussianFourierProjectionEncoder, 11 | GaussianFourierProjectionTimeEncoder, 12 | ) 13 | 14 | encoder = GaussianFourierProjectionTimeEncoder(128) 15 | x = torch.randn(100) 16 | output = encoder(x) 17 | assert output.shape == (100, 128) 18 | 19 | encoder = GaussianFourierProjectionEncoder(128, x_shape=(10,), flatten=False) 20 | x = torch.randn(100, 10) 21 | output = encoder(x) 22 | assert output.shape == (100, 10, 128) 23 | 24 | encoder = GaussianFourierProjectionEncoder(128, x_shape=(10,), flatten=True) 25 | x = torch.randn(100, 10) 26 | output = encoder(x) 27 | assert output.shape == (100, 1280) 28 | 29 | encoder = GaussianFourierProjectionEncoder(128, x_shape=(10, 20), flatten=False) 30 | x = torch.randn(100, 10, 20) 31 | output = encoder(x) 32 | assert output.shape == (100, 10, 20, 128) 33 | 34 | encoder = GaussianFourierProjectionEncoder(128, x_shape=(10, 20), flatten=True) 35 | x = torch.randn(100, 10, 20) 36 | output = encoder(x) 37 | assert output.shape == (100, 25600) 38 | 39 | encoder = ExponentialFourierProjectionTimeEncoder(128) 40 | x = torch.randn(100) 41 | output = encoder(x) 42 | assert output.shape == (100, 128) 43 | -------------------------------------------------------------------------------- /grl/unittest/rl_modules/replay_buffer/test_buffer_by_torchrl.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | from easydict import EasyDict 4 | from unittest.mock import MagicMock 5 | import tempfile 6 | from grl.rl_modules.replay_buffer.buffer_by_torchrl import ( 7 | GeneralListBuffer, 8 | TensorDictBuffer, 9 | ) 10 | from tensordict import TensorDict 11 | import torch 12 | 13 | 14 | class TestGeneralListBuffer(unittest.TestCase): 15 | 16 | def setUp(self): 17 | config = EasyDict(size=10, batch_size=2) 18 | self.buffer = GeneralListBuffer(config) 19 | 20 | def test_add_and_length(self): 21 | data = [{"state": 1}, {"state": 2}] 22 | self.buffer.add(data) 23 | self.assertEqual(len(self.buffer), 2) 24 | 25 | def test_sample(self): 26 | data = [{"state": 1}, {"state": 2}] 27 | self.buffer.add(data) 28 | sample = self.buffer.sample(batch_size=1) 29 | self.assertIn(sample[0], data) 30 | 31 | def test_get_item(self): 32 | data = [{"state": 1}, {"state": 2}] 33 | self.buffer.add(data) 34 | self.assertEqual(self.buffer[0],
data[0]) 35 | 36 | 37 | class TestTensorDictBuffer(unittest.TestCase): 38 | 39 | def setUp(self): 40 | config = EasyDict(size=10, batch_size=2) 41 | self.buffer = TensorDictBuffer(config) 42 | 43 | def test_add_and_length(self): 44 | data = TensorDict( 45 | {"state": torch.tensor([[1]]), "action": torch.tensor([[0]])}, 46 | batch_size=[1], 47 | ) 48 | self.buffer.add(data) 49 | self.assertEqual(len(self.buffer), 1) 50 | 51 | def test_sample(self): 52 | data = TensorDict( 53 | {"state": torch.tensor([[1]]), "action": torch.tensor([[0]])}, 54 | batch_size=[1], 55 | ) 56 | self.buffer.add(data) 57 | # TODO: temporarily remove the test for compatibility on GitHub Actions 58 | # sample = self.buffer.sample(batch_size=1) 59 | # self.assertTrue(isinstance(sample, TensorDict)) 60 | 61 | def test_get_item(self): 62 | data = TensorDict( 63 | {"state": torch.tensor([[1]]), "action": torch.tensor([[0]])}, 64 | batch_size=[1], 65 | ) 66 | self.buffer.add(data) 67 | item = self.buffer[0] 68 | self.assertTrue(torch.equal(item["state"], torch.tensor([1]))) 69 | 70 | def test_save_without_path(self): 71 | with self.assertRaises(ValueError): 72 | self.buffer.save() 73 | 74 | def test_load_without_path(self): 75 | with self.assertRaises(ValueError): 76 | self.buffer.load() 77 | 78 | def test_save_and_load_with_path(self): 79 | data = TensorDict( 80 | {"state": torch.tensor([[1]]), "action": torch.tensor([[0]])}, 81 | batch_size=[1], 82 | ) 83 | self.buffer.add(data) 84 | 85 | with tempfile.TemporaryDirectory() as tmpdirname: 86 | path = os.path.join(tmpdirname, "buffer.pkl") 87 | self.buffer.save(path) 88 | buffer_2 = TensorDictBuffer(EasyDict(size=10, batch_size=2)) 89 | buffer_2.load(path) 90 | self.assertEqual(len(buffer_2), 1) 91 | item = buffer_2[0] 92 | self.assertTrue(torch.equal(item["state"], torch.tensor([1]))) 93 | -------------------------------------------------------------------------------- /grl/unittest/utils/test_model_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import shutil 4 | import tempfile 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from grl.utils.model_utils import save_model, load_model 9 | 10 | 11 | class TestModelCheckpointing(unittest.TestCase): 12 | 13 | def setUp(self): 14 | # Create a temporary directory to save/load checkpoints 15 | self.temp_dir = tempfile.mkdtemp() 16 | 17 | # Create a simple model and optimizer for testing 18 | self.model = nn.Linear(10, 2) 19 | self.optimizer = optim.SGD(self.model.parameters(), lr=0.01) 20 | 21 | def tearDown(self): 22 | # Remove the temporary directory after the test 23 | shutil.rmtree(self.temp_dir) 24 | 25 | def test_save_model(self): 26 | # Test saving the model 27 | iteration = 100 28 | save_model(self.temp_dir, self.model, self.optimizer, iteration) 29 | 30 | # Check if the directory was created and torch.save was called correctly 31 | self.assertTrue(os.path.exists(self.temp_dir)) 32 | 33 | def test_load_model(self): 34 | # Create a mock checkpoint file 35 | iteration = 100 36 | checkpoint = { 37 | "model": self.model.state_dict(), 38 | "optimizer": self.optimizer.state_dict(), 39 | "iteration": iteration, 40 | } 41 | 42 | # Save a checkpoint file manually 43 | checkpoint_file = os.path.join(self.temp_dir, f"checkpoint_{iteration}.pt") 44 | torch.save(checkpoint, checkpoint_file) 45 | 46 | # Create a simple model and optimizer for testing 47 | new_model = nn.Linear(10, 2) 48 | new_optimizer = 
optim.SGD(new_model.parameters(), lr=0.01) 49 | 50 | # Test loading the model 51 | loaded_iteration = load_model(self.temp_dir, new_model, new_optimizer) 52 | 53 | # Check if the correct iteration was returned 54 | self.assertEqual(loaded_iteration, iteration) 55 | 56 | # Check if the model and optimizer were loaded correctly 57 | self.assertTrue( 58 | torch.allclose( 59 | new_model.state_dict()["weight"], self.model.state_dict()["weight"] 60 | ) 61 | ) 62 | self.assertTrue( 63 | torch.allclose( 64 | new_model.state_dict()["bias"], self.model.state_dict()["bias"] 65 | ) 66 | ) 67 | self.assertTrue( 68 | torch.allclose( 69 | torch.tensor(new_optimizer.state_dict()["param_groups"][0]["lr"]), 70 | torch.tensor(self.optimizer.state_dict()["param_groups"][0]["lr"]), 71 | ) 72 | ) 73 | self.assertTrue( 74 | torch.allclose( 75 | torch.tensor(new_optimizer.state_dict()["param_groups"][0]["momentum"]), 76 | torch.tensor( 77 | self.optimizer.state_dict()["param_groups"][0]["momentum"] 78 | ), 79 | ) 80 | ) 81 | 82 | def test_load_model_order(self): 83 | # Create mock checkpoint files 84 | iterations = [100, 200, 300] 85 | for iteration in iterations: 86 | checkpoint = { 87 | "model": self.model.state_dict(), 88 | "optimizer": self.optimizer.state_dict(), 89 | "iteration": iteration, 90 | } 91 | checkpoint_file = os.path.join(self.temp_dir, f"checkpoint_{iteration}.pt") 92 | torch.save(checkpoint, checkpoint_file) 93 | 94 | # Create a simple model and optimizer for testing 95 | new_model = nn.Linear(10, 2) 96 | new_optimizer = optim.SGD(new_model.parameters(), lr=0.01) 97 | 98 | # Test loading the model 99 | loaded_iteration = load_model(self.temp_dir, new_model, new_optimizer) 100 | 101 | # Check if the correct iteration was returned 102 | self.assertEqual(loaded_iteration, iterations[-1]) 103 | 104 | def test_load_model_no_files(self): 105 | # Test loading when no checkpoint files exist 106 | loaded_iteration = load_model(self.temp_dir, self.model, self.optimizer) 107 | 108 | # Check that the function returns -1 when no files are found 109 | self.assertEqual(loaded_iteration, -1) 110 | 111 | 112 | if __name__ == "__main__": 113 | unittest.main() 114 | -------------------------------------------------------------------------------- /grl/unittest/utils/test_plot.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import os 3 | import numpy as np 4 | from grl.utils.plot import plot_distribution 5 | 6 | 7 | class TestPlotDistribution(unittest.TestCase): 8 | 9 | def setUp(self): 10 | """ 11 | Set up the test environment. This runs before each test. 12 | """ 13 | # Sample data for testing 14 | self.B = 1000 # Number of samples 15 | self.N = 4 # Number of features 16 | self.data = np.random.randn(self.B, self.N) # Random data for demonstration 17 | self.save_path = "test_distribution_plot.png" # Path to save test plot 18 | 19 | def tearDown(self): 20 | """ 21 | Clean up after the test. This runs after each test. 22 | """ 23 | # Remove the plot file after the test if it was created 24 | if os.path.exists(self.save_path): 25 | os.remove(self.save_path) 26 | 27 | def test_plot_creation(self): 28 | """ 29 | Test if the plot is created and saved to the specified path. 30 | """ 31 | # Call the plot_distribution function 32 | plot_distribution(self.data, self.save_path) 33 | 34 | # Check if the file was created 35 | self.assertTrue( 36 | os.path.exists(self.save_path), "The plot file was not created." 
37 | ) 38 | 39 | # Verify the file is not empty 40 | self.assertGreater( 41 | os.path.getsize(self.save_path), 0, "The plot file is empty." 42 | ) 43 | 44 | def test_plot_size(self): 45 | """ 46 | Test if the plot can be saved with a specified size and DPI. 47 | """ 48 | size = (8, 8) 49 | dpi = 300 50 | 51 | # Call the plot_distribution function with a custom size and DPI 52 | plot_distribution(self.data, self.save_path, size=size, dpi=dpi) 53 | 54 | # Check if the file was created 55 | self.assertTrue( 56 | os.path.exists(self.save_path), "The plot file was not created." 57 | ) 58 | 59 | # Verify the file is not empty 60 | self.assertGreater( 61 | os.path.getsize(self.save_path), 0, "The plot file is empty." 62 | ) 63 | -------------------------------------------------------------------------------- /grl/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def set_seed(seed_value=None, cudnn_deterministic=True, cudnn_benchmark=False): 9 | """ 10 | Overview: 11 | Set the random seed. If no seed value is provided, generate a random seed. 12 | Arguments: 13 | seed_value (:obj:`int`, optional): The random seed to set. If None, a random seed will be generated. 14 | cudnn_deterministic (:obj:`bool`, optional): Whether to make cuDNN operations deterministic. Defaults to True. 15 | cudnn_benchmark (:obj:`bool`, optional): Whether to enable cuDNN benchmarking for convolutional operations. Defaults to False. 16 | Returns: 17 | seed_value (:obj:`int`): The seed value used. 18 | """ 19 | 20 | if seed_value is None: 21 | # Generate a random seed from system randomness 22 | seed_value = int.from_bytes(os.urandom(4), "little") 23 | 24 | random.seed(seed_value) # Set seed for Python's built-in random library 25 | np.random.seed(seed_value) # Set seed for NumPy 26 | torch.manual_seed(seed_value) # Set seed for PyTorch 27 | torch.cuda.manual_seed(seed_value) 28 | torch.cuda.manual_seed_all(seed_value) 29 | 30 | # Set PyTorch cuDNN behavior 31 | torch.backends.cudnn.deterministic = cudnn_deterministic 32 | torch.backends.cudnn.benchmark = cudnn_benchmark 33 | 34 | return seed_value 35 | 36 | 37 | from .config import merge_dict1_into_dict2, merge_two_dicts_into_newone 38 | from .log import log 39 | from .statistics import find_parameters 40 | from .plot import plot_distribution 41 | -------------------------------------------------------------------------------- /grl/utils/config.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Any, Dict, List, Optional, Tuple, Union 3 | 4 | from easydict import EasyDict 5 | 6 | 7 | def merge_dict1_into_dict2( 8 | dict1: Union[Dict, EasyDict], dict2: Union[Dict, EasyDict] 9 | ) -> Union[Dict, EasyDict]: 10 | """ 11 | Overview: 12 | Merge two dictionaries recursively. \ 13 | Update values in dict2 with values in dict1, and add new keys from dict1 to dict2. 14 | Arguments: 15 | - dict1 (:obj:`dict`): The first dictionary. 16 | - dict2 (:obj:`dict`): The second dictionary. 
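Returns:
    - dict2 (:obj:`dict`): The second dictionary, updated in place.

.. note::
    ``dict2`` is mutated by this call; use ``merge_two_dicts_into_newone`` when both inputs should be left untouched.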
17 | """ 18 | for key, value in dict1.items(): 19 | if key in dict2 and isinstance(value, dict) and isinstance(dict2[key], dict): 20 | # Both values are dictionaries, so merge them recursively 21 | merge_dict1_into_dict2(value, dict2[key]) 22 | else: 23 | # Either the key doesn't exist in dict2 or the values are not dictionaries 24 | dict2[key] = value 25 | 26 | return dict2 27 | 28 | 29 | def merge_two_dicts_into_newone( 30 | dict1: Union[Dict, EasyDict], dict2: Union[Dict, EasyDict] 31 | ) -> Union[Dict, EasyDict]: 32 | """ 33 | Overview: 34 | Merge two dictionaries recursively into a new dictionary. \ 35 | Update values in dict2 with values in dict1, and add new keys from dict1 to dict2. 36 | Arguments: 37 | - dict1 (:obj:`dict`): The first dictionary. 38 | - dict2 (:obj:`dict`): The second dictionary. 39 | """ 40 | dict2 = deepcopy(dict2) 41 | return merge_dict1_into_dict2(dict1, dict2) 42 | -------------------------------------------------------------------------------- /grl/utils/log.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | 4 | from rich.logging import RichHandler 5 | 6 | import wandb 7 | 8 | # Silence wandb by using the following line 9 | # os.environ["WANDB_SILENT"] = "True" 10 | # wandb_logger = logging.getLogger("wandb") 11 | # wandb_logger.setLevel(logging.ERROR) 12 | 13 | FORMAT = "%(message)s" 14 | logging.basicConfig( 15 | level="INFO", format=FORMAT, datefmt="[%X]", handlers=[RichHandler()] 16 | ) 17 | 18 | log = logging.getLogger("rich") 19 | -------------------------------------------------------------------------------- /grl/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from grl.utils.log import log 4 | 5 | 6 | def save_model( 7 | path: str, 8 | model: torch.nn.Module, 9 | optimizer: torch.optim.Optimizer, 10 | iteration: int, 11 | prefix="checkpoint", 12 | ): 13 | """ 14 | Overview: 15 | Save model state_dict, optimizer state_dict and training iteration to disk, name as 'prefix_iteration.pt'. 16 | Arguments: 17 | path (:obj:`str`): path to save model 18 | model (:obj:`torch.nn.Module`): model to save 19 | optimizer (:obj:`torch.optim.Optimizer`): optimizer to save 20 | iteration (:obj:`int`): iteration to save 21 | prefix (:obj:`str`): prefix of the checkpoint file 22 | """ 23 | 24 | if not os.path.exists(path): 25 | os.makedirs(path) 26 | torch.save( 27 | dict( 28 | model=model.state_dict(), 29 | optimizer=optimizer.state_dict(), 30 | iteration=iteration, 31 | ), 32 | f=os.path.join(path, f"{prefix}_{iteration}.pt"), 33 | ) 34 | 35 | 36 | def load_model( 37 | path: str, 38 | model: torch.nn.Module, 39 | optimizer: torch.optim.Optimizer = None, 40 | prefix="checkpoint", 41 | ) -> int: 42 | """ 43 | Overview: 44 | Load model state_dict, optimizer state_dict and training iteration from disk, load the latest checkpoint file named as 'prefix_iteration.pt'. 
45 | Arguments: 46 | path (:obj:`str`): path to load model 47 | model (:obj:`torch.nn.Module`): model to load 48 | optimizer (:obj:`torch.optim.Optimizer`): optimizer to load 49 | prefix (:obj:`str`): prefix of the checkpoint file 50 | Returns: 51 | - last_iteration (:obj:`int`): the iteration of the loaded checkpoint 52 | """ 53 | 54 | last_iteration = -1 55 | checkpoint_path = path 56 | if checkpoint_path is not None: 57 | if not os.path.exists(checkpoint_path) or not os.listdir(checkpoint_path): 58 | log.warning(f"Checkpoint path {checkpoint_path} does not exist or is empty") 59 | return last_iteration 60 | 61 | checkpoint_files = sorted( 62 | [ 63 | f 64 | for f in os.listdir(checkpoint_path) 65 | if f.endswith(".pt") and f.startswith(prefix) 66 | ], 67 | key=lambda x: int(x.split("_")[-1].split(".")[0]), 68 | ) 69 | if not checkpoint_files: 70 | log.warning(f"No checkpoint files found in {checkpoint_path}") 71 | return last_iteration 72 | 73 | checkpoint_file = os.path.join(checkpoint_path, checkpoint_files[-1]) 74 | 75 | checkpoint = torch.load(checkpoint_file, map_location="cpu") 76 | last_iteration = checkpoint.get("iteration", -1) 77 | ori_state_dict = { 78 | k.replace("module.", ""): v for k, v in checkpoint["model"].items() 79 | } 80 | ori_state_dict = { 81 | k.replace("_orig_mod.", ""): v for k, v in ori_state_dict.items() 82 | } 83 | model.load_state_dict(ori_state_dict) 84 | if optimizer is not None: 85 | optimizer.load_state_dict(checkpoint["optimizer"]) 86 | log.info(f"Checkpoint at iteration {last_iteration} has been loaded") 87 | return last_iteration 88 | return last_iteration 89 | -------------------------------------------------------------------------------- /grl/utils/plot.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | def plot_distribution(data: np.ndarray, save_path: str, size=None, dpi=500): 6 | """ 7 | Overview: 8 | Plot a grid of N x N subplots where: 9 | - Diagonal contains 1D histograms for each feature. 10 | - Off-diagonal contains 2D histograms (pcolormesh) showing relationships between pairs of features. 11 | - The colorbar of the 2D histograms shows percentages of total data points. 12 | 13 | Parameters: 14 | - data: numpy.ndarray of shape (B, N), where B is the number of samples and N is the number of features. 15 | - save_path: str, path to save the generated figure. 16 | - size: tuple (width, height), optional, size of the figure. 17 | - dpi: int, optional, resolution of the saved figure in dots per inch.
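Example (mirroring the unit test; the file name is illustrative):

    data = np.random.randn(1000, 4)
    plot_distribution(data, "distribution.png", size=(8, 8), dpi=300)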
18 | """ 19 | 20 | B, N = data.shape # B: number of samples, N: number of features 21 | 22 | # Create a figure with N * N subplots 23 | fig, axes = plt.subplots(N, N, figsize=size if size else (4 * N, 4 * N)) 24 | plt.subplots_adjust(wspace=0.4, hspace=0.4) 25 | 26 | # First, calculate the global minimum and maximum for the 2D histograms (normalized as percentages) 27 | hist_range = [ 28 | [np.min(data[:, i]) * 1.02, np.max(data[:, i] * 1.02)] for i in range(N) 29 | ] 30 | global_min, global_max = float("inf"), float("-inf") 31 | 32 | # Loop to calculate the min and max percentage values across all 2D histograms 33 | for i in range(N): 34 | for j in range(N): 35 | if i != j: 36 | hist, xedges, yedges = np.histogram2d( 37 | data[:, j], 38 | data[:, i], 39 | bins=30, 40 | range=[hist_range[j], hist_range[i]], 41 | ) 42 | hist_percentage = hist / B * 100 # Convert counts to percentages 43 | global_min = min(global_min, hist_percentage.min()) 44 | global_max = max(global_max, hist_percentage.max()) 45 | 46 | # Second loop to plot the figures using pcolormesh 47 | for i in range(N): 48 | for j in range(N): 49 | if i == j: 50 | # Diagonal: plot 1D histogram for feature i 51 | if N == 1: 52 | axes.hist(data[:, i], bins=30, color="skyblue", edgecolor="black") 53 | else: 54 | axes[i, j].hist( 55 | data[:, i], bins=30, color="skyblue", edgecolor="black" 56 | ) 57 | # axes[i, j].set_title(f'Hist of Feature {i+1}') 58 | else: 59 | # Off-diagonal: calculate 2D histogram and plot using pcolormesh with unified color scale (as percentage) 60 | hist, xedges, yedges = np.histogram2d( 61 | data[:, j], 62 | data[:, i], 63 | bins=30, 64 | range=[hist_range[j], hist_range[i]], 65 | ) 66 | hist_percentage = hist / B * 100 # Convert to percentage 67 | 68 | # Use pcolormesh to plot the 2D histogram 69 | mesh = axes[i, j].pcolormesh( 70 | xedges, 71 | yedges, 72 | hist_percentage.T, 73 | cmap="Blues", 74 | vmin=global_min, 75 | vmax=global_max, 76 | ) 77 | axes[i, j].set_xlabel(f"Dimension {j+1}") 78 | axes[i, j].set_ylabel(f"Dimension {i+1}") 79 | 80 | if N > 1: 81 | # Add a single colorbar for all pcolormesh plots (showing percentage) 82 | cbar = fig.colorbar( 83 | mesh, ax=axes, orientation="vertical", fraction=0.02, pad=0.04 84 | ) 85 | cbar.set_label("Percentage (%)") 86 | 87 | # Save the figure to the provided path 88 | plt.savefig(save_path, dpi=dpi, bbox_inches="tight") 89 | plt.close(fig) 90 | 91 | 92 | def plot_histogram2d_x_y(x_data, y_data, save_path: str, size=None, dpi=500): 93 | # Set up a figure with 3 subplots: 2D histogram, KDE, and scatter plot 94 | if isinstance(x_data, list): 95 | x_data = np.array(x_data) 96 | if isinstance(y_data, list): 97 | y_data = np.array(y_data) 98 | global_min, global_max = float("inf"), float("-inf") 99 | fig, ax = plt.subplots(figsize=size if size else (8, 6)) 100 | x_max = ((x_data.max() + 99) // 100) * 100 101 | y_max = np.ceil(y_data.max() / 2) * 2 102 | y_min = (y_data.min() // 2) * 2 103 | # 2D Histogram for density 104 | hist2d, xedges, yedges = np.histogram2d( 105 | x_data, y_data, bins=100, range=[[0, x_max], [y_min, y_max]] 106 | ) 107 | hist_percentage = hist2d / hist2d.sum() # Normalize the histogram 108 | global_min = min(global_min, hist_percentage.min()) 109 | global_max = max(global_max, hist_percentage.max()) 110 | # Plot the 2D histogram 111 | mesh = ax.pcolormesh( 112 | xedges, 113 | yedges, 114 | hist_percentage.T, 115 | cmap="Blues", 116 | vmin=global_min, 117 | vmax=global_max, 118 | ) 119 | ax.set_xlabel("Returns") 120 | ax.set_ylabel("LogP") 
121 | ax.set_title("2D Histogram Density Plot") 122 | 123 | # Add colorbar to the 2D histogram 124 | cb = fig.colorbar(mesh, ax=ax, orientation="vertical", fraction=0.02, pad=0.04) 125 | cb.set_label("Fraction of samples") 126 | 127 | # Save the plot 128 | plt.savefig(save_path, dpi=dpi) 129 | plt.close(fig) 130 | -------------------------------------------------------------------------------- /grl/utils/statistics.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | import torch 4 | import torch.nn as nn 5 | 6 | 7 | def sort_files_by_criteria( 8 | folder_path: str, start_string: str = "checkpoint_", end_string: str = ".pt" 9 | ) -> List[str]: 10 | """ 11 | Overview: 12 | Sort the files in the specified folder by the criteria specified in the filename. 13 | If the filenames look like "checkpoint_N_M_..._1.pt", the files are sorted in descending order, comparing the trailing number in the filename first. 14 | Arguments: 15 | - folder_path (:obj:`str`): The path to the folder containing the files. 16 | """ 17 | 18 | files = os.listdir(folder_path) 19 | file_list = [] 20 | 21 | for file in files: 22 | if file.startswith(start_string) and file.endswith(end_string): 23 | parts = file[len(start_string) : -len(end_string)].split( 24 | "_" 25 | ) # Split the filename by "_" and remove "checkpoint_" and ".pt" 26 | try: 27 | values = list(map(int, parts)) # Convert all parts to integers 28 | file_list.append( 29 | tuple(reversed(values)) + (file,) 30 | ) # Append a tuple (N, N-1, ..., 1, filename) to the list 31 | except ValueError: 32 | pass # Ignore files that don't match the expected pattern 33 | 34 | file_list.sort(reverse=True) # Sort the list in descending order 35 | sorted_files = [ 36 | file for values in file_list for file in [values[-1]] 37 | ] # Extract the filenames from the sorted tuples 38 | return sorted_files 39 | 40 | 41 | def find_parameters(module): 42 | 43 | assert isinstance(module, nn.Module) 44 | 45 | # If called within DataParallel, parameters won't appear in module.parameters().
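# Replicas created by torch.nn.parallel.replicate are flagged with `_is_replica`,
# and their weights live as plain tensor attributes rather than registered
# parameters, which is why requires_grad tensors are collected from __dict__ below.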
46 | if getattr(module, "_is_replica", False): 47 | 48 | def find_tensor_attributes(module): 49 | tuples = [ 50 | (k, v) 51 | for k, v in module.__dict__.items() 52 | if torch.is_tensor(v) and v.requires_grad 53 | ] 54 | return tuples 55 | 56 | gen = module._named_members(get_members_fn=find_tensor_attributes) 57 | return [param for _, param in gen] 58 | else: 59 | return list(module.parameters()) 60 | 61 | 62 | def calculate_tensor_memory_size(tensor): 63 | memory_usage_in_bytes = tensor.element_size() * tensor.nelement() 64 | return memory_usage_in_bytes 65 | 66 | 67 | def memory_allocated(device=torch.device("cuda")): 68 | return torch.cuda.memory_allocated(device) / (1024 * 1024 * 1024) 69 | -------------------------------------------------------------------------------- /grl_pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/GenerativeRL/77ed957d89843e823e5780399959b41f965ebca7/grl_pipelines/__init__.py -------------------------------------------------------------------------------- /grl_pipelines/base.py: -------------------------------------------------------------------------------- 1 | from grl.algorithms.base import BaseAlgorithm 2 | from grl.datasets import create_dataset 3 | from grl.rl_modules.simulators import create_simulator 4 | from grl.rl_modules.simulators.base import BaseEnv 5 | from grl.utils.log import log 6 | from grl_pipelines.configurations.base import config 7 | 8 | 9 | def base_pipeline(config): 10 | """ 11 | Overview: 12 | The base pipeline for training and deploying an algorithm. 13 | Arguments: 14 | - config (:obj:`EasyDict`): The configuration, which must contain the following keys: 15 | - train (:obj:`EasyDict`): The training configuration. 16 | - train.simulator (:obj:`EasyDict`): The training environment simulator configuration. 17 | - train.dataset (:obj:`EasyDict`): The training dataset configuration. 18 | - deploy (:obj:`EasyDict`): The deployment configuration. 19 | - deploy.env (:obj:`EasyDict`): The deployment environment configuration. 20 | - deploy.num_deploy_steps (:obj:`int`): The number of deployment steps. 21 | .. note:: 22 | This pipeline is for demonstration purposes only. 
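Examples:
    A minimal config sketch (the component types are borrowed from the benchmark configs shipped with this repo; the exact args depend on the chosen algorithm and are illustrative here)::

        config = EasyDict(dict(
            train=dict(
                simulator=dict(type="GymEnvSimulator", args=dict(env_id="halfcheetah-medium-v2")),
                dataset=dict(type="GPD4RLTensorDictDataset", args=dict(env_id="halfcheetah-medium-v2")),
            ),
            deploy=dict(
                env=dict(env_id="halfcheetah-medium-v2", seed=0),
                num_deploy_steps=1000,
            ),
        ))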
23 | """ 24 | 25 | # --------------------------------------- 26 | # Customized train code ↓ 27 | # --------------------------------------- 28 | simulator = create_simulator(config.train.simulator) 29 | dataset = create_dataset(config.train.dataset) 30 | algo = BaseAlgorithm(simulator=simulator, dataset=dataset) 31 | algo.train(config=config.train) 32 | # --------------------------------------- 33 | # Customized train code ↑ 34 | # --------------------------------------- 35 | 36 | # --------------------------------------- 37 | # Customized deploy code ↓ 38 | # --------------------------------------- 39 | agent = algo.deploy(config=config.deploy) 40 | env = BaseEnv(config.deploy.env) 41 | env.reset() 42 | for _ in range(config.deploy.num_deploy_steps): 43 | env.render() 44 | env.step(agent.act(env.observation)) 45 | # --------------------------------------- 46 | # Customized deploy code ↑ 47 | # --------------------------------------- 48 | 49 | 50 | if __name__ == "__main__": 51 | log.info("config: \n{}".format(config)) 52 | base_pipeline(config) 53 | -------------------------------------------------------------------------------- /grl_pipelines/benchmark/generative_policy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/GenerativeRL/77ed957d89843e823e5780399959b41f965ebca7/grl_pipelines/benchmark/generative_policy.png -------------------------------------------------------------------------------- /grl_pipelines/benchmark/gmpg/gvp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/GenerativeRL/77ed957d89843e823e5780399959b41f965ebca7/grl_pipelines/benchmark/gmpg/gvp/__init__.py -------------------------------------------------------------------------------- /grl_pipelines/benchmark/gmpg/icfm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/GenerativeRL/77ed957d89843e823e5780399959b41f965ebca7/grl_pipelines/benchmark/gmpg/icfm/__init__.py -------------------------------------------------------------------------------- /grl_pipelines/benchmark/gmpg/vpsde/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/GenerativeRL/77ed957d89843e823e5780399959b41f965ebca7/grl_pipelines/benchmark/gmpg/vpsde/__init__.py -------------------------------------------------------------------------------- /grl_pipelines/benchmark/gmpo/gvp/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/GenerativeRL/77ed957d89843e823e5780399959b41f965ebca7/grl_pipelines/benchmark/gmpo/gvp/__init__.py -------------------------------------------------------------------------------- /grl_pipelines/benchmark/gmpo/icfm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/GenerativeRL/77ed957d89843e823e5780399959b41f965ebca7/grl_pipelines/benchmark/gmpo/icfm/__init__.py -------------------------------------------------------------------------------- /grl_pipelines/benchmark/gmpo/vpsde/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/GenerativeRL/77ed957d89843e823e5780399959b41f965ebca7/grl_pipelines/benchmark/gmpo/vpsde/__init__.py 
-------------------------------------------------------------------------------- /grl_pipelines/benchmark/idql/vpsde/halfcheetah_medium.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from easydict import EasyDict 3 | import d4rl 4 | 5 | env_id = "halfcheetah-medium-v2" 6 | action_size = 6 7 | state_size = 17 8 | algorithm = "IDQL" 9 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 10 | 11 | t_embedding_dim = 32 12 | t_encoder = dict( 13 | type="GaussianFourierProjectionTimeEncoder", 14 | args=dict( 15 | embed_dim=t_embedding_dim, 16 | scale=30.0, 17 | ), 18 | ) 19 | 20 | config = EasyDict( 21 | train=dict( 22 | project=f"{env_id}-{algorithm}", 23 | device=device, 24 | simulator=dict( 25 | type="GymEnvSimulator", 26 | args=dict( 27 | env_id=env_id, 28 | ), 29 | ), 30 | dataset=dict( 31 | type="GPD4RLTensorDictDataset", 32 | args=dict( 33 | env_id=env_id, 34 | ), 35 | ), 36 | model=dict( 37 | IDQLPolicy=dict( 38 | device=device, 39 | critic=dict( 40 | device=device, 41 | q_alpha=1.0, 42 | DoubleQNetwork=dict( 43 | backbone=dict( 44 | type="ConcatenateMLP", 45 | args=dict( 46 | hidden_sizes=[action_size + state_size, 256, 256], 47 | output_size=1, 48 | activation="relu", 49 | ), 50 | ), 51 | ), 52 | VNetwork=dict( 53 | backbone=dict( 54 | type="MultiLayerPerceptron", 55 | args=dict( 56 | hidden_sizes=[state_size, 256, 256], 57 | output_size=1, 58 | activation="relu", 59 | ), 60 | ), 61 | ), 62 | ), 63 | diffusion_model=dict( 64 | device=device, 65 | x_size=action_size, 66 | alpha=1.0, 67 | beta=0.1, 68 | solver=dict( 69 | type="DPMSolver", 70 | args=dict( 71 | order=2, 72 | device=device, 73 | steps=17, 74 | ), 75 | ), 76 | path=dict( 77 | type="linear_vp_sde", 78 | beta_0=0.1, 79 | beta_1=20.0, 80 | ), 81 | model=dict( 82 | type="noise_function", 83 | args=dict( 84 | t_encoder=t_encoder, 85 | backbone=dict( 86 | type="TemporalSpatialResidualNet", 87 | args=dict( 88 | hidden_sizes=[512, 256, 128], 89 | output_dim=action_size, 90 | t_dim=t_embedding_dim, 91 | condition_dim=state_size, 92 | condition_hidden_dim=32, 93 | t_condition_hidden_dim=128, 94 | ), 95 | ), 96 | ), 97 | ), 98 | ), 99 | ) 100 | ), 101 | parameter=dict( 102 | behaviour_policy=dict( 103 | batch_size=4096, 104 | learning_rate=3e-4, 105 | epochs=4000, 106 | ), 107 | critic=dict( 108 | batch_size=4096, 109 | epochs=2000, 110 | learning_rate=3e-4, 111 | discount_factor=0.99, 112 | tau=0.7, 113 | update_momentum=0.005, 114 | ), 115 | evaluation=dict( 116 | evaluation_interval=50, 117 | repeat=10, 118 | ), 119 | checkpoint_path=f"./{env_id}-{algorithm}", 120 | ), 121 | ), 122 | deploy=dict( 123 | device=device, 124 | env=dict( 125 | env_id=env_id, 126 | seed=0, 127 | ), 128 | num_deploy_steps=1000, 129 | ), 130 | ) 131 | 132 | if __name__ == "__main__": 133 | 134 | import gym 135 | 136 | from grl.algorithms.idql import IDQLAlgorithm 137 | from grl.utils.log import log 138 | 139 | def idql_pipeline(config): 140 | 141 | idql = IDQLAlgorithm(config) 142 | 143 | # --------------------------------------- 144 | # Customized train code ↓ 145 | # --------------------------------------- 146 | idql.train() 147 | # --------------------------------------- 148 | # Customized train code ↑ 149 | # --------------------------------------- 150 | 151 | # --------------------------------------- 152 | # Customized deploy code ↓ 153 | # --------------------------------------- 154 | agent = idql.deploy() 155 | env = gym.make(config.deploy.env.env_id) 156 | 
env.reset() 157 | for _ in range(config.deploy.num_deploy_steps): 158 | env.render() 159 | env.step(agent.act(env.observation)) 160 | # --------------------------------------- 161 | # Customized deploy code ↑ 162 | # --------------------------------------- 163 | 164 | log.info("config: \n{}".format(config)) 165 | idql_pipeline(config) 166 | -------------------------------------------------------------------------------- /grl_pipelines/benchmark/idql/vpsde/halfcheetah_medium_expert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from easydict import EasyDict 3 | import d4rl 4 | 5 | env_id = "halfcheetah-medium-expert-v2" 6 | action_size = 6 7 | state_size = 17 8 | algorithm = "IDQL" 9 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 10 | 11 | t_embedding_dim = 32 12 | t_encoder = dict( 13 | type="GaussianFourierProjectionTimeEncoder", 14 | args=dict( 15 | embed_dim=t_embedding_dim, 16 | scale=30.0, 17 | ), 18 | ) 19 | 20 | config = EasyDict( 21 | train=dict( 22 | project=f"{env_id}-{algorithm}", 23 | device=device, 24 | simulator=dict( 25 | type="GymEnvSimulator", 26 | args=dict( 27 | env_id=env_id, 28 | ), 29 | ), 30 | dataset=dict( 31 | type="GPD4RLTensorDictDataset", 32 | args=dict( 33 | env_id=env_id, 34 | ), 35 | ), 36 | model=dict( 37 | IDQLPolicy=dict( 38 | device=device, 39 | critic=dict( 40 | device=device, 41 | q_alpha=1.0, 42 | DoubleQNetwork=dict( 43 | backbone=dict( 44 | type="ConcatenateMLP", 45 | args=dict( 46 | hidden_sizes=[action_size + state_size, 256, 256], 47 | output_size=1, 48 | activation="relu", 49 | ), 50 | ), 51 | ), 52 | VNetwork=dict( 53 | backbone=dict( 54 | type="MultiLayerPerceptron", 55 | args=dict( 56 | hidden_sizes=[state_size, 256, 256], 57 | output_size=1, 58 | activation="relu", 59 | ), 60 | ), 61 | ), 62 | ), 63 | diffusion_model=dict( 64 | device=device, 65 | x_size=action_size, 66 | alpha=1.0, 67 | beta=0.1, 68 | solver=dict( 69 | type="DPMSolver", 70 | args=dict( 71 | order=2, 72 | device=device, 73 | steps=17, 74 | ), 75 | ), 76 | path=dict( 77 | type="linear_vp_sde", 78 | beta_0=0.1, 79 | beta_1=20.0, 80 | ), 81 | model=dict( 82 | type="noise_function", 83 | args=dict( 84 | t_encoder=t_encoder, 85 | backbone=dict( 86 | type="TemporalSpatialResidualNet", 87 | args=dict( 88 | hidden_sizes=[512, 256, 128], 89 | output_dim=action_size, 90 | t_dim=t_embedding_dim, 91 | condition_dim=state_size, 92 | condition_hidden_dim=32, 93 | t_condition_hidden_dim=128, 94 | ), 95 | ), 96 | ), 97 | ), 98 | ), 99 | ) 100 | ), 101 | parameter=dict( 102 | behaviour_policy=dict( 103 | batch_size=4096, 104 | learning_rate=3e-4, 105 | epochs=4000, 106 | ), 107 | critic=dict( 108 | batch_size=4096, 109 | epochs=2000, 110 | learning_rate=3e-4, 111 | discount_factor=0.99, 112 | tau=0.7, 113 | update_momentum=0.005, 114 | ), 115 | evaluation=dict( 116 | evaluation_interval=50, 117 | repeat=10, 118 | ), 119 | checkpoint_path=f"./{env_id}-{algorithm}", 120 | ), 121 | ), 122 | deploy=dict( 123 | device=device, 124 | env=dict( 125 | env_id=env_id, 126 | seed=0, 127 | ), 128 | num_deploy_steps=1000, 129 | ), 130 | ) 131 | 132 | if __name__ == "__main__": 133 | 134 | import gym 135 | 136 | from grl.algorithms.idql import IDQLAlgorithm 137 | from grl.utils.log import log 138 | 139 | def idql_pipeline(config): 140 | 141 | idql = IDQLAlgorithm(config) 142 | 143 | # --------------------------------------- 144 | # Customized train code ↓ 145 | # --------------------------------------- 146 | 
idql.train() 147 | # --------------------------------------- 148 | # Customized train code ↑ 149 | # --------------------------------------- 150 | 151 | # --------------------------------------- 152 | # Customized deploy code ↓ 153 | # --------------------------------------- 154 | agent = idql.deploy() 155 | env = gym.make(config.deploy.env.env_id) 156 | env.reset() 157 | for _ in range(config.deploy.num_deploy_steps): 158 | env.render() 159 | env.step(agent.act(env.observation)) 160 | # --------------------------------------- 161 | # Customized deploy code ↑ 162 | # --------------------------------------- 163 | 164 | log.info("config: \n{}".format(config)) 165 | idql_pipeline(config) 166 | -------------------------------------------------------------------------------- /grl_pipelines/benchmark/idql/vpsde/halfcheetah_medium_replay.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from easydict import EasyDict 3 | import d4rl 4 | 5 | env_id = "halfcheetah-medium-replay-v2" 6 | action_size = 6 7 | state_size = 17 8 | algorithm = "IDQL" 9 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 10 | 11 | t_embedding_dim = 32 12 | t_encoder = dict( 13 | type="GaussianFourierProjectionTimeEncoder", 14 | args=dict( 15 | embed_dim=t_embedding_dim, 16 | scale=30.0, 17 | ), 18 | ) 19 | 20 | config = EasyDict( 21 | train=dict( 22 | project=f"{env_id}-{algorithm}", 23 | device=device, 24 | simulator=dict( 25 | type="GymEnvSimulator", 26 | args=dict( 27 | env_id=env_id, 28 | ), 29 | ), 30 | dataset=dict( 31 | type="GPD4RLTensorDictDataset", 32 | args=dict( 33 | env_id=env_id, 34 | ), 35 | ), 36 | model=dict( 37 | IDQLPolicy=dict( 38 | device=device, 39 | critic=dict( 40 | device=device, 41 | q_alpha=1.0, 42 | DoubleQNetwork=dict( 43 | backbone=dict( 44 | type="ConcatenateMLP", 45 | args=dict( 46 | hidden_sizes=[action_size + state_size, 256, 256], 47 | output_size=1, 48 | activation="relu", 49 | ), 50 | ), 51 | ), 52 | VNetwork=dict( 53 | backbone=dict( 54 | type="MultiLayerPerceptron", 55 | args=dict( 56 | hidden_sizes=[state_size, 256, 256], 57 | output_size=1, 58 | activation="relu", 59 | ), 60 | ), 61 | ), 62 | ), 63 | diffusion_model=dict( 64 | device=device, 65 | x_size=action_size, 66 | alpha=1.0, 67 | beta=0.1, 68 | solver=dict( 69 | type="DPMSolver", 70 | args=dict( 71 | order=2, 72 | device=device, 73 | steps=17, 74 | ), 75 | ), 76 | path=dict( 77 | type="linear_vp_sde", 78 | beta_0=0.1, 79 | beta_1=20.0, 80 | ), 81 | model=dict( 82 | type="noise_function", 83 | args=dict( 84 | t_encoder=t_encoder, 85 | backbone=dict( 86 | type="TemporalSpatialResidualNet", 87 | args=dict( 88 | hidden_sizes=[512, 256, 128], 89 | output_dim=action_size, 90 | t_dim=t_embedding_dim, 91 | condition_dim=state_size, 92 | condition_hidden_dim=32, 93 | t_condition_hidden_dim=128, 94 | ), 95 | ), 96 | ), 97 | ), 98 | ), 99 | ) 100 | ), 101 | parameter=dict( 102 | behaviour_policy=dict( 103 | batch_size=4096, 104 | learning_rate=3e-4, 105 | epochs=4000, 106 | ), 107 | critic=dict( 108 | batch_size=4096, 109 | epochs=2000, 110 | learning_rate=3e-4, 111 | discount_factor=0.99, 112 | tau=0.7, 113 | update_momentum=0.005, 114 | ), 115 | evaluation=dict( 116 | evaluation_interval=50, 117 | repeat=10, 118 | ), 119 | checkpoint_path=f"./{env_id}-{algorithm}", 120 | ), 121 | ), 122 | deploy=dict( 123 | device=device, 124 | env=dict( 125 | env_id=env_id, 126 | seed=0, 127 | ), 128 | num_deploy_steps=1000, 129 | ), 130 | ) 131 | 132 | if 
__name__ == "__main__": 133 | 134 | import gym 135 | 136 | from grl.algorithms.idql import IDQLAlgorithm 137 | from grl.utils.log import log 138 | 139 | def idql_pipeline(config): 140 | 141 | idql = IDQLAlgorithm(config) 142 | 143 | # --------------------------------------- 144 | # Customized train code ↓ 145 | # --------------------------------------- 146 | idql.train() 147 | # --------------------------------------- 148 | # Customized train code ↑ 149 | # --------------------------------------- 150 | 151 | # --------------------------------------- 152 | # Customized deploy code ↓ 153 | # --------------------------------------- 154 | agent = idql.deploy() 155 | env = gym.make(config.deploy.env.env_id) 156 | env.reset() 157 | for _ in range(config.deploy.num_deploy_steps): 158 | env.render() 159 | env.step(agent.act(env.observation)) 160 | # --------------------------------------- 161 | # Customized deploy code ↑ 162 | # --------------------------------------- 163 | 164 | log.info("config: \n{}".format(config)) 165 | idql_pipeline(config) 166 | -------------------------------------------------------------------------------- /grl_pipelines/benchmark/idql/vpsde/hopper_medium.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from easydict import EasyDict 3 | import d4rl 4 | 5 | 6 | env_id = "hopper-medium-v2" 7 | action_size = 3 8 | state_size = 11 9 | algorithm = "IDQL" 10 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 11 | 12 | t_embedding_dim = 32 13 | t_encoder = dict( 14 | type="GaussianFourierProjectionTimeEncoder", 15 | args=dict( 16 | embed_dim=t_embedding_dim, 17 | scale=30.0, 18 | ), 19 | ) 20 | 21 | config = EasyDict( 22 | train=dict( 23 | project=f"{env_id}-{algorithm}", 24 | device=device, 25 | simulator=dict( 26 | type="GymEnvSimulator", 27 | args=dict( 28 | env_id=env_id, 29 | ), 30 | ), 31 | dataset=dict( 32 | type="GPD4RLTensorDictDataset", 33 | args=dict( 34 | env_id=env_id, 35 | ), 36 | ), 37 | model=dict( 38 | IDQLPolicy=dict( 39 | device=device, 40 | critic=dict( 41 | device=device, 42 | q_alpha=1.0, 43 | DoubleQNetwork=dict( 44 | backbone=dict( 45 | type="ConcatenateMLP", 46 | args=dict( 47 | hidden_sizes=[action_size + state_size, 256, 256], 48 | output_size=1, 49 | activation="relu", 50 | ), 51 | ), 52 | ), 53 | VNetwork=dict( 54 | backbone=dict( 55 | type="MultiLayerPerceptron", 56 | args=dict( 57 | hidden_sizes=[state_size, 256, 256], 58 | output_size=1, 59 | activation="relu", 60 | ), 61 | ), 62 | ), 63 | ), 64 | diffusion_model=dict( 65 | device=device, 66 | x_size=action_size, 67 | alpha=1.0, 68 | beta=0.1, 69 | solver=dict( 70 | type="DPMSolver", 71 | args=dict( 72 | order=2, 73 | device=device, 74 | steps=17, 75 | ), 76 | ), 77 | path=dict( 78 | type="linear_vp_sde", 79 | beta_0=0.1, 80 | beta_1=20.0, 81 | ), 82 | model=dict( 83 | type="noise_function", 84 | args=dict( 85 | t_encoder=t_encoder, 86 | backbone=dict( 87 | type="TemporalSpatialResidualNet", 88 | args=dict( 89 | hidden_sizes=[512, 256, 128], 90 | output_dim=action_size, 91 | t_dim=t_embedding_dim, 92 | condition_dim=state_size, 93 | condition_hidden_dim=32, 94 | t_condition_hidden_dim=128, 95 | ), 96 | ), 97 | ), 98 | ), 99 | ), 100 | ) 101 | ), 102 | parameter=dict( 103 | behaviour_policy=dict( 104 | batch_size=4096, 105 | learning_rate=3e-4, 106 | epochs=4000, 107 | ), 108 | critic=dict( 109 | batch_size=4096, 110 | epochs=2000, 111 | learning_rate=3e-4, 112 | discount_factor=0.99, 113 | tau=0.7, 114 | 
update_momentum=0.005, 115 | ), 116 | evaluation=dict( 117 | evaluation_interval=50, 118 | repeat=10, 119 | ), 120 | checkpoint_path=f"./{env_id}-{algorithm}", 121 | ), 122 | ), 123 | deploy=dict( 124 | device=device, 125 | env=dict( 126 | env_id=env_id, 127 | seed=0, 128 | ), 129 | num_deploy_steps=1000, 130 | ), 131 | ) 132 | 133 | if __name__ == "__main__": 134 | 135 | import gym 136 | 137 | from grl.algorithms.idql import IDQLAlgorithm 138 | from grl.utils.log import log 139 | 140 | def idql_pipeline(config): 141 | 142 | idql = IDQLAlgorithm(config) 143 | 144 | # --------------------------------------- 145 | # Customized train code ↓ 146 | # --------------------------------------- 147 | idql.train() 148 | # --------------------------------------- 149 | # Customized train code ↑ 150 | # --------------------------------------- 151 | 152 | # --------------------------------------- 153 | # Customized deploy code ↓ 154 | # --------------------------------------- 155 | agent = idql.deploy() 156 | env = gym.make(config.deploy.env.env_id) 157 | env.reset() 158 | for _ in range(config.deploy.num_deploy_steps): 159 | env.render() 160 | env.step(agent.act(env.observation)) 161 | # --------------------------------------- 162 | # Customized deploy code ↑ 163 | # --------------------------------------- 164 | 165 | log.info("config: \n{}".format(config)) 166 | idql_pipeline(config) 167 | -------------------------------------------------------------------------------- /grl_pipelines/benchmark/idql/vpsde/hopper_medium_expert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from easydict import EasyDict 3 | import d4rl 4 | 5 | 6 | env_id = "hopper-medium-expert-v2" 7 | action_size = 3 8 | state_size = 11 9 | algorithm = "IDQL" 10 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 11 | 12 | t_embedding_dim = 32 13 | t_encoder = dict( 14 | type="GaussianFourierProjectionTimeEncoder", 15 | args=dict( 16 | embed_dim=t_embedding_dim, 17 | scale=30.0, 18 | ), 19 | ) 20 | 21 | config = EasyDict( 22 | train=dict( 23 | project=f"{env_id}-{algorithm}", 24 | device=device, 25 | simulator=dict( 26 | type="GymEnvSimulator", 27 | args=dict( 28 | env_id=env_id, 29 | ), 30 | ), 31 | dataset=dict( 32 | type="GPD4RLTensorDictDataset", 33 | args=dict( 34 | env_id=env_id, 35 | ), 36 | ), 37 | model=dict( 38 | IDQLPolicy=dict( 39 | device=device, 40 | critic=dict( 41 | device=device, 42 | q_alpha=1.0, 43 | DoubleQNetwork=dict( 44 | backbone=dict( 45 | type="ConcatenateMLP", 46 | args=dict( 47 | hidden_sizes=[action_size + state_size, 256, 256], 48 | output_size=1, 49 | activation="relu", 50 | ), 51 | ), 52 | ), 53 | VNetwork=dict( 54 | backbone=dict( 55 | type="MultiLayerPerceptron", 56 | args=dict( 57 | hidden_sizes=[state_size, 256, 256], 58 | output_size=1, 59 | activation="relu", 60 | ), 61 | ), 62 | ), 63 | ), 64 | diffusion_model=dict( 65 | device=device, 66 | x_size=action_size, 67 | alpha=1.0, 68 | beta=0.1, 69 | solver=dict( 70 | type="DPMSolver", 71 | args=dict( 72 | order=2, 73 | device=device, 74 | steps=17, 75 | ), 76 | ), 77 | path=dict( 78 | type="linear_vp_sde", 79 | beta_0=0.1, 80 | beta_1=20.0, 81 | ), 82 | model=dict( 83 | type="noise_function", 84 | args=dict( 85 | t_encoder=t_encoder, 86 | backbone=dict( 87 | type="TemporalSpatialResidualNet", 88 | args=dict( 89 | hidden_sizes=[512, 256, 128], 90 | output_dim=action_size, 91 | t_dim=t_embedding_dim, 92 | condition_dim=state_size, 93 | condition_hidden_dim=32, 
94 | t_condition_hidden_dim=128, 95 | ), 96 | ), 97 | ), 98 | ), 99 | ), 100 | ) 101 | ), 102 | parameter=dict( 103 | behaviour_policy=dict( 104 | batch_size=4096, 105 | learning_rate=3e-4, 106 | epochs=4000, 107 | ), 108 | critic=dict( 109 | batch_size=4096, 110 | epochs=2000, 111 | learning_rate=3e-4, 112 | discount_factor=0.99, 113 | tau=0.7, 114 | update_momentum=0.005, 115 | ), 116 | evaluation=dict( 117 | evaluation_interval=50, 118 | repeat=10, 119 | ), 120 | checkpoint_path=f"./{env_id}-{algorithm}", 121 | ), 122 | ), 123 | deploy=dict( 124 | device=device, 125 | env=dict( 126 | env_id=env_id, 127 | seed=0, 128 | ), 129 | num_deploy_steps=1000, 130 | ), 131 | ) 132 | 133 | if __name__ == "__main__": 134 | 135 | import gym 136 | 137 | from grl.algorithms.idql import IDQLAlgorithm 138 | from grl.utils.log import log 139 | 140 | def idql_pipeline(config): 141 | 142 | idql = IDQLAlgorithm(config) 143 | 144 | # --------------------------------------- 145 | # Customized train code ↓ 146 | # --------------------------------------- 147 | idql.train() 148 | # --------------------------------------- 149 | # Customized train code ↑ 150 | # --------------------------------------- 151 | 152 | # --------------------------------------- 153 | # Customized deploy code ↓ 154 | # --------------------------------------- 155 | agent = idql.deploy() 156 | env = gym.make(config.deploy.env.env_id) 157 | observation = env.reset() 158 | for _ in range(config.deploy.num_deploy_steps): 159 | env.render() 160 | observation, reward, done, _ = env.step(agent.act(observation)) 161 | # --------------------------------------- 162 | # Customized deploy code ↑ 163 | # --------------------------------------- 164 | 165 | log.info("config: \n{}".format(config)) 166 | idql_pipeline(config) 167 | -------------------------------------------------------------------------------- /grl_pipelines/benchmark/idql/vpsde/walker2d_medium.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from easydict import EasyDict 3 | import d4rl 4 | 5 | env_id = "walker2d-medium-v2" 6 | action_size = 6 7 | state_size = 17 8 | algorithm = "IDQL" 9 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 10 | 11 | t_embedding_dim = 32 12 | t_encoder = dict( 13 | type="GaussianFourierProjectionTimeEncoder", 14 | args=dict( 15 | embed_dim=t_embedding_dim, 16 | scale=30.0, 17 | ), 18 | ) 19 | 20 | config = EasyDict( 21 | train=dict( 22 | project=f"{env_id}-{algorithm}", 23 | device=device, 24 | simulator=dict( 25 | type="GymEnvSimulator", 26 | args=dict( 27 | env_id=env_id, 28 | ), 29 | ), 30 | dataset=dict( 31 | type="GPD4RLTensorDictDataset", 32 | args=dict( 33 | env_id=env_id, 34 | ), 35 | ), 36 | model=dict( 37 | IDQLPolicy=dict( 38 | device=device, 39 | critic=dict( 40 | device=device, 41 | q_alpha=1.0, 42 | DoubleQNetwork=dict( 43 | backbone=dict( 44 | type="ConcatenateMLP", 45 | args=dict( 46 | hidden_sizes=[action_size + state_size, 256, 256], 47 | output_size=1, 48 | activation="relu", 49 | ), 50 | ), 51 | ), 52 | VNetwork=dict( 53 | backbone=dict( 54 | type="MultiLayerPerceptron", 55 | args=dict( 56 | hidden_sizes=[state_size, 256, 256], 57 | output_size=1, 58 | activation="relu", 59 | ), 60 | ), 61 | ), 62 | ), 63 | diffusion_model=dict( 64 | device=device, 65 | x_size=action_size, 66 | alpha=1.0, 67 | beta=0.1, 68 | solver=dict( 69 | type="DPMSolver", 70 | args=dict( 71 | order=2, 72 | device=device, 73 | steps=17, 74 | ), 75 | ), 76 | path=dict( 77 | type="linear_vp_sde", 78 | beta_0=0.1, 79 |
beta_1=20.0, 80 | ), 81 | model=dict( 82 | type="noise_function", 83 | args=dict( 84 | t_encoder=t_encoder, 85 | backbone=dict( 86 | type="TemporalSpatialResidualNet", 87 | args=dict( 88 | hidden_sizes=[512, 256, 128], 89 | output_dim=action_size, 90 | t_dim=t_embedding_dim, 91 | condition_dim=state_size, 92 | condition_hidden_dim=32, 93 | t_condition_hidden_dim=128, 94 | ), 95 | ), 96 | ), 97 | ), 98 | ), 99 | ) 100 | ), 101 | parameter=dict( 102 | behaviour_policy=dict( 103 | batch_size=4096, 104 | learning_rate=3e-4, 105 | epochs=4000, 106 | ), 107 | critic=dict( 108 | batch_size=4096, 109 | epochs=2000, 110 | learning_rate=3e-4, 111 | discount_factor=0.99, 112 | tau=0.7, 113 | update_momentum=0.005, 114 | ), 115 | evaluation=dict( 116 | evaluation_interval=50, 117 | repeat=10, 118 | ), 119 | checkpoint_path=f"./{env_id}-{algorithm}", 120 | ), 121 | ), 122 | deploy=dict( 123 | device=device, 124 | env=dict( 125 | env_id=env_id, 126 | seed=0, 127 | ), 128 | num_deploy_steps=1000, 129 | ), 130 | ) 131 | 132 | if __name__ == "__main__": 133 | 134 | import gym 135 | 136 | from grl.algorithms.idql import IDQLAlgorithm 137 | from grl.utils.log import log 138 | 139 | def idql_pipeline(config): 140 | 141 | idql = IDQLAlgorithm(config) 142 | 143 | # --------------------------------------- 144 | # Customized train code ↓ 145 | # --------------------------------------- 146 | idql.train() 147 | # --------------------------------------- 148 | # Customized train code ↑ 149 | # --------------------------------------- 150 | 151 | # --------------------------------------- 152 | # Customized deploy code ↓ 153 | # --------------------------------------- 154 | agent = idql.deploy() 155 | env = gym.make(config.deploy.env.env_id) 156 | observation = env.reset() 157 | for _ in range(config.deploy.num_deploy_steps): 158 | env.render() 159 | observation, reward, done, _ = env.step(agent.act(observation)) 160 | # --------------------------------------- 161 | # Customized deploy code ↑ 162 | # --------------------------------------- 163 | 164 | log.info("config: \n{}".format(config)) 165 | idql_pipeline(config) 166 | -------------------------------------------------------------------------------- /grl_pipelines/benchmark/idql/vpsde/walker2d_medium_expert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from easydict import EasyDict 3 | import d4rl 4 | 5 | env_id = "walker2d-medium-expert-v2" 6 | action_size = 6 7 | state_size = 17 8 | algorithm = "IDQL" 9 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 10 | 11 | t_embedding_dim = 32 12 | t_encoder = dict( 13 | type="GaussianFourierProjectionTimeEncoder", 14 | args=dict( 15 | embed_dim=t_embedding_dim, 16 | scale=30.0, 17 | ), 18 | ) 19 | 20 | config = EasyDict( 21 | train=dict( 22 | project=f"{env_id}-{algorithm}", 23 | device=device, 24 | simulator=dict( 25 | type="GymEnvSimulator", 26 | args=dict( 27 | env_id=env_id, 28 | ), 29 | ), 30 | dataset=dict( 31 | type="GPD4RLTensorDictDataset", 32 | args=dict( 33 | env_id=env_id, 34 | ), 35 | ), 36 | model=dict( 37 | IDQLPolicy=dict( 38 | device=device, 39 | critic=dict( 40 | device=device, 41 | q_alpha=1.0, 42 | DoubleQNetwork=dict( 43 | backbone=dict( 44 | type="ConcatenateMLP", 45 | args=dict( 46 | hidden_sizes=[action_size + state_size, 256, 256], 47 | output_size=1, 48 | activation="relu", 49 | ), 50 | ), 51 | ), 52 | VNetwork=dict( 53 | backbone=dict( 54 | type="MultiLayerPerceptron", 55 | args=dict( 56 | hidden_sizes=[state_size, 256, 256], 57 |
output_size=1, 58 | activation="relu", 59 | ), 60 | ), 61 | ), 62 | ), 63 | diffusion_model=dict( 64 | device=device, 65 | x_size=action_size, 66 | alpha=1.0, 67 | beta=0.1, 68 | solver=dict( 69 | type="DPMSolver", 70 | args=dict( 71 | order=2, 72 | device=device, 73 | steps=17, 74 | ), 75 | ), 76 | path=dict( 77 | type="linear_vp_sde", 78 | beta_0=0.1, 79 | beta_1=20.0, 80 | ), 81 | model=dict( 82 | type="noise_function", 83 | args=dict( 84 | t_encoder=t_encoder, 85 | backbone=dict( 86 | type="TemporalSpatialResidualNet", 87 | args=dict( 88 | hidden_sizes=[512, 256, 128], 89 | output_dim=action_size, 90 | t_dim=t_embedding_dim, 91 | condition_dim=state_size, 92 | condition_hidden_dim=32, 93 | t_condition_hidden_dim=128, 94 | ), 95 | ), 96 | ), 97 | ), 98 | ), 99 | ) 100 | ), 101 | parameter=dict( 102 | behaviour_policy=dict( 103 | batch_size=4096, 104 | learning_rate=3e-4, 105 | epochs=4000, 106 | ), 107 | critic=dict( 108 | batch_size=4096, 109 | epochs=2000, 110 | learning_rate=3e-4, 111 | discount_factor=0.99, 112 | tau=0.7, 113 | update_momentum=0.005, 114 | ), 115 | evaluation=dict( 116 | evaluation_interval=50, 117 | repeat=10, 118 | ), 119 | checkpoint_path=f"./{env_id}-{algorithm}", 120 | ), 121 | ), 122 | deploy=dict( 123 | device=device, 124 | env=dict( 125 | env_id=env_id, 126 | seed=0, 127 | ), 128 | num_deploy_steps=1000, 129 | ), 130 | ) 131 | 132 | if __name__ == "__main__": 133 | 134 | import gym 135 | 136 | from grl.algorithms.idql import IDQLAlgorithm 137 | from grl.utils.log import log 138 | 139 | def idql_pipeline(config): 140 | 141 | idql = IDQLAlgorithm(config) 142 | 143 | # --------------------------------------- 144 | # Customized train code ↓ 145 | # --------------------------------------- 146 | idql.train() 147 | # --------------------------------------- 148 | # Customized train code ↑ 149 | # --------------------------------------- 150 | 151 | # --------------------------------------- 152 | # Customized deploy code ↓ 153 | # --------------------------------------- 154 | agent = idql.deploy() 155 | env = gym.make(config.deploy.env.env_id) 156 | observation = env.reset() 157 | for _ in range(config.deploy.num_deploy_steps): 158 | env.render() 159 | observation, reward, done, _ = env.step(agent.act(observation)) 160 | # --------------------------------------- 161 | # Customized deploy code ↑ 162 | # --------------------------------------- 163 | 164 | log.info("config: \n{}".format(config)) 165 | idql_pipeline(config) 166 | -------------------------------------------------------------------------------- /grl_pipelines/benchmark/idql/vpsde/walker2d_medium_replay.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from easydict import EasyDict 3 | import d4rl 4 | 5 | env_id = "walker2d-medium-replay-v2" 6 | action_size = 6 7 | state_size = 17 8 | algorithm = "IDQL" 9 | device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 10 | 11 | t_embedding_dim = 32 12 | t_encoder = dict( 13 | type="GaussianFourierProjectionTimeEncoder", 14 | args=dict( 15 | embed_dim=t_embedding_dim, 16 | scale=30.0, 17 | ), 18 | ) 19 | 20 | config = EasyDict( 21 | train=dict( 22 | project=f"{env_id}-{algorithm}", 23 | device=device, 24 | simulator=dict( 25 | type="GymEnvSimulator", 26 | args=dict( 27 | env_id=env_id, 28 | ), 29 | ), 30 | dataset=dict( 31 | type="GPD4RLTensorDictDataset", 32 | args=dict( 33 | env_id=env_id, 34 | ), 35 | ), 36 | model=dict( 37 | IDQLPolicy=dict( 38 | device=device, 39 | critic=dict( 40 | device=device, 41 | q_alpha=1.0,
42 | DoubleQNetwork=dict( 43 | backbone=dict( 44 | type="ConcatenateMLP", 45 | args=dict( 46 | hidden_sizes=[action_size + state_size, 256, 256], 47 | output_size=1, 48 | activation="relu", 49 | ), 50 | ), 51 | ), 52 | VNetwork=dict( 53 | backbone=dict( 54 | type="MultiLayerPerceptron", 55 | args=dict( 56 | hidden_sizes=[state_size, 256, 256], 57 | output_size=1, 58 | activation="relu", 59 | ), 60 | ), 61 | ), 62 | ), 63 | diffusion_model=dict( 64 | device=device, 65 | x_size=action_size, 66 | alpha=1.0, 67 | beta=0.1, 68 | solver=dict( 69 | type="DPMSolver", 70 | args=dict( 71 | order=2, 72 | device=device, 73 | steps=17, 74 | ), 75 | ), 76 | path=dict( 77 | type="linear_vp_sde", 78 | beta_0=0.1, 79 | beta_1=20.0, 80 | ), 81 | model=dict( 82 | type="noise_function", 83 | args=dict( 84 | t_encoder=t_encoder, 85 | backbone=dict( 86 | type="TemporalSpatialResidualNet", 87 | args=dict( 88 | hidden_sizes=[512, 256, 128], 89 | output_dim=action_size, 90 | t_dim=t_embedding_dim, 91 | condition_dim=state_size, 92 | condition_hidden_dim=32, 93 | t_condition_hidden_dim=128, 94 | ), 95 | ), 96 | ), 97 | ), 98 | ), 99 | ) 100 | ), 101 | parameter=dict( 102 | behaviour_policy=dict( 103 | batch_size=4096, 104 | learning_rate=3e-4, 105 | epochs=4000, 106 | ), 107 | critic=dict( 108 | batch_size=4096, 109 | epochs=2000, 110 | learning_rate=3e-4, 111 | discount_factor=0.99, 112 | tau=0.7, 113 | update_momentum=0.005, 114 | ), 115 | evaluation=dict( 116 | evaluation_interval=50, 117 | repeat=10, 118 | ), 119 | checkpoint_path=f"./{env_id}-{algorithm}", 120 | ), 121 | ), 122 | deploy=dict( 123 | device=device, 124 | env=dict( 125 | env_id=env_id, 126 | seed=0, 127 | ), 128 | num_deploy_steps=1000, 129 | ), 130 | ) 131 | 132 | if __name__ == "__main__": 133 | 134 | import gym 135 | 136 | from grl.algorithms.idql import IDQLAlgorithm 137 | from grl.utils.log import log 138 | 139 | def idql_pipeline(config): 140 | 141 | idql = IDQLAlgorithm(config) 142 | 143 | # --------------------------------------- 144 | # Customized train code ↓ 145 | # --------------------------------------- 146 | idql.train() 147 | # --------------------------------------- 148 | # Customized train code ↑ 149 | # --------------------------------------- 150 | 151 | # --------------------------------------- 152 | # Customized deploy code ↓ 153 | # --------------------------------------- 154 | agent = idql.deploy() 155 | env = gym.make(config.deploy.env.env_id) 156 | observation = env.reset() 157 | for _ in range(config.deploy.num_deploy_steps): 158 | env.render() 159 | observation, reward, done, _ = env.step(agent.act(observation)) 160 | # --------------------------------------- 161 | # Customized deploy code ↑ 162 | # --------------------------------------- 163 | 164 | log.info("config: \n{}".format(config)) 165 | idql_pipeline(config) 166 | -------------------------------------------------------------------------------- /grl_pipelines/diffusion_model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/GenerativeRL/77ed957d89843e823e5780399959b41f965ebca7/grl_pipelines/diffusion_model/__init__.py -------------------------------------------------------------------------------- /grl_pipelines/diffusion_model/configurations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/opendilab/GenerativeRL/77ed957d89843e823e5780399959b41f965ebca7/grl_pipelines/diffusion_model/configurations/__init__.py
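The IDQL benchmark configurations above all share one skeleton and differ only in `env_id` and the state/action dimensions that are baked into the network widths. A minimal sketch of deriving a config for another D4RL task by deep-copying an existing one and overriding those fields follows. Two assumptions here: the benchmark modules are importable from the repository root, and `halfcheetah-medium-v2` exposes 17-dimensional observations and 6-dimensional actions.

```python
# Minimal sketch: derive a halfcheetah IDQL config from the hopper one.
# Assumptions: the benchmark package is importable, and halfcheetah-medium-v2
# has a 17-dim observation space and a 6-dim action space.
from copy import deepcopy

from grl_pipelines.benchmark.idql.vpsde.hopper_medium import config

new_env_id = "halfcheetah-medium-v2"
action_size, state_size = 6, 17

new_config = deepcopy(config)
new_config.train.project = f"{new_env_id}-IDQL"
new_config.train.simulator.args.env_id = new_env_id
new_config.train.dataset.args.env_id = new_env_id
new_config.train.parameter.checkpoint_path = f"./{new_env_id}-IDQL"
new_config.deploy.env.env_id = new_env_id

# The input widths are baked into the nested dicts, so re-size them as well.
policy = new_config.train.model.IDQLPolicy
policy.critic.DoubleQNetwork.backbone.args.hidden_sizes[0] = action_size + state_size
policy.critic.VNetwork.backbone.args.hidden_sizes[0] = state_size
policy.diffusion_model.x_size = action_size
policy.diffusion_model.model.args.backbone.args.output_dim = action_size
policy.diffusion_model.model.args.backbone.args.condition_dim = state_size
```

Because `EasyDict` exposes nested keys as attributes, these overrides mutate only the deep copy; the imported `config` used by the original script is left untouched.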
-------------------------------------------------------------------------------- /grl_pipelines/diffusion_model/d4rl_halfcheetah_qgpo.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | from grl.algorithms.qgpo import QGPOAlgorithm 4 | from grl.utils.log import log 5 | from grl_pipelines.diffusion_model.configurations.d4rl_halfcheetah_qgpo import config 6 | 7 | 8 | def qgpo_pipeline(config): 9 | 10 | qgpo = QGPOAlgorithm(config) 11 | 12 | # --------------------------------------- 13 | # Customized train code ↓ 14 | # --------------------------------------- 15 | qgpo.train() 16 | # --------------------------------------- 17 | # Customized train code ↑ 18 | # --------------------------------------- 19 | 20 | # --------------------------------------- 21 | # Customized deploy code ↓ 22 | # --------------------------------------- 23 | agent = qgpo.deploy() 24 | env = gym.make(config.deploy.env.env_id) 25 | observation = env.reset() 26 | for _ in range(config.deploy.num_deploy_steps): 27 | env.render() 28 | observation, reward, done, _ = env.step(agent.act(observation)) 29 | # --------------------------------------- 30 | # Customized deploy code ↑ 31 | # --------------------------------------- 32 | 33 | 34 | if __name__ == "__main__": 35 | log.info("config: \n{}".format(config)) 36 | qgpo_pipeline(config) 37 | -------------------------------------------------------------------------------- /grl_pipelines/diffusion_model/d4rl_walker2d_qgpo.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | from grl.algorithms.qgpo import QGPOAlgorithm 4 | from grl.utils.log import log 5 | from grl_pipelines.diffusion_model.configurations.d4rl_walker2d_qgpo import config 6 | 7 | 8 | def qgpo_pipeline(config): 9 | 10 | qgpo = QGPOAlgorithm(config) 11 | 12 | # --------------------------------------- 13 | # Customized train code ↓ 14 | # --------------------------------------- 15 | qgpo.train() 16 | # --------------------------------------- 17 | # Customized train code ↑ 18 | # --------------------------------------- 19 | 20 | # --------------------------------------- 21 | # Customized deploy code ↓ 22 | # --------------------------------------- 23 | agent = qgpo.deploy() 24 | env = gym.make(config.deploy.env.env_id) 25 | observation = env.reset() 26 | for _ in range(config.deploy.num_deploy_steps): 27 | env.render() 28 | observation, reward, done, _ = env.step(agent.act(observation)) 29 | # --------------------------------------- 30 | # Customized deploy code ↑ 31 | # --------------------------------------- 32 | 33 | 34 | if __name__ == "__main__": 35 | log.info("config: \n{}".format(config)) 36 | qgpo_pipeline(config) 37 | -------------------------------------------------------------------------------- /grl_pipelines/diffusion_model/lunarlander_continuous_qgpo.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | from grl.algorithms.qgpo import QGPOAlgorithm 4 | from grl.datasets import QGPOCustomizedTensorDictDataset 5 | from grl.utils.log import log 6 | from grl_pipelines.diffusion_model.configurations.lunarlander_continuous_qgpo import ( 7 | config, 8 | ) 9 | 10 | 11 | def qgpo_pipeline(config): 12 | 13 | qgpo = QGPOAlgorithm( 14 | config, 15 | dataset=QGPOCustomizedTensorDictDataset( 16 | numpy_data_path="./data.npz", 17 | action_augment_num=config.train.parameter.action_augment_num, 18 | ), 19 | ) 20 | 21 | # 
--------------------------------------- 22 | # Customized train code ↓ 23 | # --------------------------------------- 24 | qgpo.train() 25 | # --------------------------------------- 26 | # Customized train code ↑ 27 | # --------------------------------------- 28 | 29 | # --------------------------------------- 30 | # Customized deploy code ↓ 31 | # --------------------------------------- 32 | agent = qgpo.deploy() 33 | env = gym.make(config.deploy.env.env_id) 34 | observation = env.reset() 35 | for _ in range(config.deploy.num_deploy_steps): 36 | env.render() 37 | observation, reward, done, _ = env.step(agent.act(observation)) 38 | # --------------------------------------- 39 | # Customized deploy code ↑ 40 | # --------------------------------------- 41 | 42 | 43 | if __name__ == "__main__": 44 | log.info("config: \n{}".format(config)) 45 | qgpo_pipeline(config) 46 | -------------------------------------------------------------------------------- /grl_pipelines/tutorials/README.md: -------------------------------------------------------------------------------- 1 | # GenerativeRL Tutorials 2 | 3 | English | [简体中文(Simplified Chinese)](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/README.zh.md) 4 | 5 | ## Train a generative model 6 | 7 | ### Diffusion model 8 | 9 | We provide a simple Colab notebook to demonstrate how to build a diffusion model using the `GenerativeRL` library. You can access the notebook [here](https://colab.research.google.com/drive/18yHUAmcMh_7xq2U6TBCtcLKX2y4YvNyk#scrollTo=aqtDAvG6cQ1V). 10 | 11 | ### Flow model 12 | 13 | We provide a simple Colab notebook to demonstrate how to build a flow model using the `GenerativeRL` library. You can access the notebook [here](https://colab.research.google.com/drive/1vrxREVXKsSbnsv9G2CnKPVvrbFZleElI?usp=drive_link). 14 | 15 | ## Evaluate a generative model 16 | 17 | ### Sample generation 18 | 19 | We provide a simple Colab notebook to demonstrate how to generate samples from a trained generative model using the `GenerativeRL` library. You can access the notebook [here](https://colab.research.google.com/drive/16jQhf1BDjtToxMZ4lDxB4IwGdRmr074j?usp=sharing). 20 | 21 | ### Density estimation 22 | 23 | We provide a simple Colab notebook to demonstrate how to estimate the density of samples with a trained generative model using the `GenerativeRL` library. You can access the notebook [here](https://colab.research.google.com/drive/1zHsW13n338YqX87AIWG26KLC4uKQL1ZP?usp=sharing). 24 | 25 | ## Tutorials via toy examples 26 | 27 | We provide several toy examples to demonstrate the features of the `GenerativeRL` library. You can access the examples [here](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/).
28 | 29 | ### Diverse generative models 30 | 31 | - [Diffusion Model](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/generative_models/swiss_roll_diffusion.py) 32 | - [Energy-Conditioned Diffusion Model](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/generative_models/swiss_roll_energy_condition.py) 33 | - [Independent Conditional Flow Matching Model](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/generative_models/swiss_roll_icfm.py) 34 | - [Optimal Transport Conditional Flow Matching Model](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/generative_models/swiss_roll_otcfm.py) 35 | - [SF2M](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/generative_models/swiss_roll_otcfm.py) 36 | 37 | ### Generative model applications 38 | 39 | - [World Model](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/applications/swiss_roll_world_model.py) 40 | 41 | ### Generative model evaluation 42 | 43 | - [Likelihood Evaluation](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/metrics/swiss_roll_likelihood.py) 44 | 45 | ### ODE/SDE solver usage 46 | 47 | - [DPM Solver](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/solvers/swiss_roll_dpmsolver.py) 48 | - [SDE Solver](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/solvers/swiss_roll_sdesolver.py) 49 | 50 | ### Special usage in GenerativeRL 51 | 52 | - [Customized Neural Network Modules](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/special_usages/customized_modules.py) 53 | - [Dict-like Structure Data Generation](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/special_usages/dict_tensor_ode.py) 54 | 55 | ## Use the Hugging Face website to push and pull models 56 | 57 | ### Push a model 58 | 59 | We provide an [example](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/huggingface/lunarlander_continuous_qgpo_huggingface_push.py) to push a trained model to the Hugging Face website. 60 | 61 | In this example, we push a trained LunarLanderContinuous model to the Hugging Face website, and automatically generate a model card using the [template](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/huggingface/modelcard_template.md) to showcase the model's [detailed information](https://huggingface.co/OpenDILabCommunity/LunarLanderContinuous-v2-QGPO). 62 | 63 | ### Pull a model 64 | 65 | We provide an [example](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/huggingface/lunarlander_continuous_qgpo_huggingface_pull.py) to pull a model from the Hugging Face website and test its performance in the environment.
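The pull workflow reduces to a few calls. A minimal sketch follows; the complete, runnable version is the `lunarlander_continuous_qgpo_huggingface_pull.py` example later in this listing. Note one assumption: the full example also passes a `QGPOCustomizedTensorDictDataset` to the algorithm, which this sketch omits on the assumption that a dataset is only needed for training, not for deployment-only use.

```python
# Minimal sketch of pulling a trained QGPO policy from the Hugging Face Hub.
# Assumption: QGPOAlgorithm can be constructed without a dataset when it is
# only deployed, not trained; the full example below passes a dataset.
from grl.algorithms.qgpo import QGPOAlgorithm
from grl.utils.huggingface import pull_model_from_hub

policy_state_dict, config = pull_model_from_hub(
    repo_id="OpenDILabCommunity/LunarLanderContinuous-v2-QGPO",
)

qgpo = QGPOAlgorithm(config)
qgpo.model.load_state_dict(policy_state_dict)
agent = qgpo.deploy()  # the agent can now act in gym.make(config.deploy.env.env_id)
```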
66 | -------------------------------------------------------------------------------- /grl_pipelines/tutorials/README.zh.md: -------------------------------------------------------------------------------- 1 | # GenerativeRL 教程 2 | 3 | [英语 (English)](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/README.md) | 简体中文 4 | 5 | ## 训练生成模型 6 | 7 | ### 扩散模型 8 | 9 | 我们提供了一个简单的 colab 笔记本,演示如何使用 `GenerativeRL` 库构建扩散模型。您可以在[这里](https://colab.research.google.com/drive/18yHUAmcMh_7xq2U6TBCtcLKX2y4YvNyk#scrollTo=aqtDAvG6cQ1V)访问笔记本。 10 | 11 | ### 流模型 12 | 13 | 我们提供了一个简单的 colab 笔记本,演示如何使用 `GenerativeRL` 库构建流模型。您可以在[这里](https://colab.research.google.com/drive/1vrxREVXKsSbnsv9G2CnKPVvrbFZleElI?usp=drive_link)访问笔记本。 14 | 15 | ## 评估生成模型 16 | 17 | ### 采样生成 18 | 19 | 我们提供了一个简单的 colab 笔记本,演示如何使用 `GenerativeRL` 库从训练有素的生成模型生成样本。您可以在[这里](https://colab.research.google.com/drive/16jQhf1BDjtToxMZ4lDxB4IwGdRmr074j?usp=sharing)访问笔记本。 20 | 21 | ### 概率密度估计 22 | 23 | 我们提供了一个简单的 colab 笔记本,演示如何使用 `GenerativeRL` 库从训练有素的生成模型估计样本的概率密度。您可以在[这里](https://colab.research.google.com/drive/1zHsW13n338YqX87AIWG26KLC4uKQL1ZP?usp=sharing)访问笔记本。 24 | 25 | ## 玩具示例教程 26 | 27 | 我们提供了几个玩具示例,演示了 `GenerativeRL` 库的特性。您可以在[这里](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/)访问示例。 28 | 29 | ### 多种生成模型 30 | 31 | - [扩散模型](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/generative_models/swiss_roll_diffusion.py) 32 | - [能量条件扩散模型](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/generative_models/swiss_roll_energy_condition.py) 33 | - [独立条件流匹配模型](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/generative_models/swiss_roll_icfm.py) 34 | - [最优输运条件流匹配模型](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/generative_models/swiss_roll_otcfm.py) 35 | - [SF2M](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/generative_models/swiss_roll_otcfm.py) 36 | 37 | ### 生成模型应用 38 | 39 | - [世界模型](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/applications/swiss_roll_world_model.py) 40 | 41 | ### 生成模型评估 42 | 43 | - [似然性评估](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/metrics/swiss_roll_likelihood.py) 44 | 45 | ### ODE/SDE 求解器用法 46 | 47 | - [DPM 求解器](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/solvers/swiss_roll_dpmsolver.py) 48 | - [SDE 求解器](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/solvers/swiss_roll_sdesolver.py) 49 | 50 | ### GenerativeRL 的特殊用法 51 | 52 | - [自定义神经网络模块](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/special_usages/customized_modules.py) 53 | - [类似字典结构的数据生成](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/special_usages/dict_tensor_ode.py) 54 | 55 | ## 使用 Hugging Face 网站上传和下载模型 56 | 57 | ### 上传模型 58 | 我们提供了将训练好的模型上传到 Hugging Face 网站的[示例](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/huggingface/lunarlander_continuous_qgpo_huggingface_push.py)。 59 | 60 | 在这个示例中,我们将训练好的 LunarLanderContinuous 模型上传到 Hugging Face 网站,并通过[模板](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/huggingface/modelcard_template.md)自动生成模型卡片,展示模型的[详细信息](https://huggingface.co/OpenDILabCommunity/LunarLanderContinuous-v2-QGPO)。 61 | 62 | ### 下载模型 63 | 我们提供了从 Hugging Face 
网站下载模型的[示例](https://github.com/opendilab/GenerativeRL/tree/main/grl_pipelines/tutorials/huggingface/lunarlander_continuous_qgpo_huggingface_pull.py)。 64 | 在这个示例中,我们下载了 Hugging Face 网站上的 LunarLanderContinuous 模型,并测试模型在该环境中的性能。 65 | -------------------------------------------------------------------------------- /grl_pipelines/tutorials/huggingface/lunarlander_continuous_qgpo_huggingface_pull.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | from grl.algorithms.qgpo import QGPOAlgorithm 4 | from grl.datasets import QGPOCustomizedTensorDictDataset 5 | 6 | from grl.utils.huggingface import pull_model_from_hub 7 | 8 | 9 | def qgpo_pipeline(): 10 | 11 | policy_state_dict, config = pull_model_from_hub( 12 | repo_id="OpenDILabCommunity/LunarLanderContinuous-v2-QGPO", 13 | ) 14 | 15 | qgpo = QGPOAlgorithm( 16 | config, 17 | dataset=QGPOCustomizedTensorDictDataset( 18 | numpy_data_path="./data.npz", 19 | action_augment_num=config.train.parameter.action_augment_num, 20 | ), 21 | ) 22 | 23 | qgpo.model.load_state_dict(policy_state_dict) 24 | 25 | # --------------------------------------- 26 | # Customized train code ↓ 27 | # --------------------------------------- 28 | # qgpo.train() 29 | # --------------------------------------- 30 | # Customized train code ↑ 31 | # --------------------------------------- 32 | 33 | # --------------------------------------- 34 | # Customized deploy code ↓ 35 | # --------------------------------------- 36 | agent = qgpo.deploy() 37 | env = gym.make(config.deploy.env.env_id) 38 | observation = env.reset() 39 | images = [env.render(mode="rgb_array")] 40 | for _ in range(config.deploy.num_deploy_steps): 41 | observation, reward, done, _ = env.step(agent.act(observation)) 42 | image = env.render(mode="rgb_array") 43 | images.append(image) 44 | # save images into mp4 files 45 | import imageio.v3 as imageio 46 | import numpy as np 47 | 48 | images = np.array(images) 49 | imageio.imwrite("replay.mp4", images, fps=30, quality=8) 50 | # --------------------------------------- 51 | # Customized deploy code ↑ 52 | # --------------------------------------- 53 | 54 | 55 | if __name__ == "__main__": 56 | 57 | qgpo_pipeline() 58 | -------------------------------------------------------------------------------- /grl_pipelines/tutorials/huggingface/lunarlander_continuous_qgpo_huggingface_push.py: -------------------------------------------------------------------------------- 1 | import gym 2 | 3 | from grl.algorithms.qgpo import QGPOAlgorithm 4 | from grl.datasets import QGPOCustomizedTensorDictDataset 5 | from grl.utils.log import log 6 | from grl_pipelines.diffusion_model.configurations.lunarlander_continuous_qgpo import ( 7 | config, 8 | make_config, 9 | ) 10 | from grl.utils.huggingface import push_model_to_hub 11 | 12 | 13 | def qgpo_pipeline(config): 14 | 15 | qgpo = QGPOAlgorithm( 16 | config, 17 | dataset=QGPOCustomizedTensorDictDataset( 18 | numpy_data_path="./data.npz", 19 | action_augment_num=config.train.parameter.action_augment_num, 20 | ), 21 | ) 22 | 23 | # --------------------------------------- 24 | # Customized train code ↓ 25 | # --------------------------------------- 26 | qgpo.train() 27 | # --------------------------------------- 28 | # Customized train code ↑ 29 | # --------------------------------------- 30 | 31 | # --------------------------------------- 32 | # Customized deploy code ↓ 33 | # --------------------------------------- 34 | agent = qgpo.deploy() 35 | env = gym.make(config.deploy.env.env_id) 
36 | observation = env.reset() 37 | images = [env.render(mode="rgb_array")] 38 | for _ in range(config.deploy.num_deploy_steps): 39 | observation, reward, done, _ = env.step(agent.act(observation)) 40 | image = env.render(mode="rgb_array") 41 | images.append(image) 42 | # save images into mp4 files 43 | import imageio.v3 as imageio 44 | import numpy as np 45 | 46 | images = np.array(images) 47 | imageio.imwrite("replay.mp4", images, fps=30, quality=8) 48 | # --------------------------------------- 49 | # Customized deploy code ↑ 50 | # --------------------------------------- 51 | 52 | push_model_to_hub( 53 | model=qgpo.model, 54 | config=make_config(device="cuda"), 55 | env_name="Box2d", 56 | task_name="LunarLanderContinuous-v2", 57 | algo_name="QGPO", 58 | repo_id="OpenDILabCommunity/LunarLanderContinuous-v2-QGPO", 59 | score=200.0, 60 | video_path="replay.mp4", 61 | wandb_url=None, 62 | usage_file="grl_pipelines/tutorials/huggingface/lunarlander_continuous_qgpo_huggingface_pull.py", 63 | train_file="grl_pipelines/tutorials/huggingface/lunarlander_continuous_qgpo_huggingface_push.py", 64 | github_repo_url="https://github.com/opendilab/GenerativeRL/", 65 | github_doc_model_url="https://opendilab.github.io/GenerativeRL/", 66 | github_doc_env_url="https://www.gymlibrary.dev/environments/box2d/lunar_lander/", 67 | model_description=None, 68 | installation_guide="pip3 install gym[box2d]==0.23.1", 69 | platform_info=None, 70 | create_repo=True, 71 | template_path="grl_pipelines/tutorials/huggingface/modelcard_template.md", 72 | ) 73 | 74 | 75 | if __name__ == "__main__": 76 | log.info("config: \n{}".format(config)) 77 | qgpo_pipeline(config) 78 | -------------------------------------------------------------------------------- /grl_pipelines/tutorials/huggingface/modelcard_template.md: -------------------------------------------------------------------------------- 1 | --- 2 | {{ card_data }} 3 | --- 4 | 5 | # Play **{{ task_name | default("[More Information Needed]", true)}}** with **{{ algo_name | default("[More Information Needed]", true)}}** Policy 6 | 7 | ## Model Description 8 | 9 | 10 | This implementation applies **{{ algo_name | default("[More Information Needed]", true)}}** to the {{ benchmark_name | default("[More Information Needed]", true)}} **{{ task_name | default("[More Information Needed]", true)}}** environment using {{ platform_info | default("[GenerativeRL](https://github.com/opendilab/di-engine)", true)}}. 11 | 12 | {{ model_description | default("**GenerativeRL** is a Python library for various generative model based reinforcement learning algorithms and benchmarks. Built on PyTorch, it supports both academic research and prototype applications, offering customization of training pipelines.", false)}} 13 | 14 | ## Model Usage 15 | ### Install the Dependencies 16 |
17 | <details><summary>(Click for Details)</summary> 18 | 19 | ```shell 20 | # install GenerativeRL with huggingface support 21 | pip3 install GenerativeRL[huggingface] 22 | # install environment dependencies if needed 23 | {{ installation_guide | default("", false)}} 24 | ``` 25 | </details>
26 | 27 | ### Download Model from Huggingface and Run the Model 28 | 29 |
30 | <details><summary>(Click for Details)</summary> 31 | 32 | ```shell 33 | # running with trained model 34 | python3 -u run.py 35 | ``` 36 | **run.py** 37 | ```python 38 | {{ usage | default("# [More Information Needed]", true)}} 39 | ``` 40 | </details>
41 | 42 | ## Model Training 43 | 44 | ### Train the Model and Push to Huggingface_hub 45 | 46 |
47 | <details><summary>(Click for Details)</summary> 48 | 49 | ```shell 50 | # Training Your Own Agent 51 | python3 -u train.py 52 | ``` 53 | **train.py** 54 | ```python 55 | {{ python_code_for_train | default("# [More Information Needed]", true)}} 56 | ``` 57 | </details>
58 | 59 | **Configuration** 60 |
61 | <details><summary>(Click for Details)</summary> 62 | 63 | 64 | ```python 65 | {{ python_config | default("# [More Information Needed]", true)}} 66 | ``` 67 | 68 | ```json 69 | {{ json_config | default("# [More Information Needed]", true)}} 70 | ``` 71 | 72 | </details>
73 | 74 | **Training Procedure** 75 | 76 | - **Weights & Biases (wandb):** [monitor link]({{ wandb_url | default("[More Information Needed]", true)}}) 77 | 78 | ## Model Information 79 | 80 | - **Github Repository:** [repo link]({{ github_repo_url | default("[More Information Needed]", true)}}) 81 | - **Doc**: [Algorithm link]({{ github_doc_model_url | default("[More Information Needed]", true)}}) 82 | - **Configuration:** [config link]({{ config_file_url | default("[More Information Needed]", true)}}) 83 | - **Demo:** [video]({{ video_demo_url | default("[More Information Needed]", true)}}) 84 | 85 | - **Parameters total size:** {{ parameters_total_size | default("[More Information Needed]", true)}} 86 | - **Last Update Date:** {{ date | default("[More Information Needed]", true)}} 87 | 88 | ## Environments 89 | 90 | - **Benchmark:** {{ benchmark_name | default("[More Information Needed]", true)}} 91 | - **Task:** {{ task_name | default("[More Information Needed]", true)}} 92 | - **Gym version:** {{ gym_version | default("[More Information Needed]", true)}} 93 | - **GenerativeRL version:** {{ library_version | default("[More Information Needed]", true)}} 94 | - **PyTorch version:** {{ pytorch_version | default("[More Information Needed]", true)}} 95 | - **Doc**: [Environments link]({{ github_doc_env_url | default("[More Information Needed]", true)}}) 96 | 97 | -------------------------------------------------------------------------------- /requirements-doc.txt: -------------------------------------------------------------------------------- 1 | Jinja2>=3.0.0 2 | sphinx<7 3 | sphinx_rtd_theme>=0.4.3 4 | enum_tools 5 | sphinx-toolbox 6 | plantumlcli>=0.0.2 7 | packaging 8 | sphinx-multiversion>=0.2.4 9 | where~=1.0.2 10 | easydict 11 | ipykernel 12 | ipython 13 | m2r2 14 | nbclient 15 | nbformat 16 | nbsphinx 17 | platformdirs 18 | setuptools_scm 19 | sphinx_autodoc_typehints 20 | lxml_html_clean 21 | lxml 22 | -e git+https://github.com/opendilab/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme 23 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | with open('README.md', 'r', encoding="utf-8") as f: 4 | readme = f.read() 5 | 6 | setup( 7 | name='GenerativeRL', 8 | version='0.0.1', 9 | description='PyTorch implementations of generative reinforcement learning algorithms', 10 | long_description=readme, 11 | long_description_content_type='text/markdown', 12 | author='OpenDILab', 13 | author_email="opendilab@pjlab.org.cn", 14 | url="https://github.com/opendilab/GenerativeRL", 15 | 16 | packages=find_packages( 17 | exclude=["*.tests", "*.tests.*", "tests.*", "tests"]), 18 | install_requires=[ 19 | 'gym', 20 | 'numpy<=1.26.4', 21 | 'torch>=2.2.0', 22 | 'opencv-python', 23 | 'tensordict', 24 | 'torchrl', 25 | 'di-treetensor', 26 | 'matplotlib', 27 | 'wandb', 28 | 'rich', 29 | 'easydict', 30 | 'tqdm', 31 | 'torchdyn', 32 | 'torchsde', 33 | 'scipy', 34 | 'POT', 35 | 'beartype', 36 | 'diffusers', 37 | 'av', 38 | 'moviepy', 39 | 'imageio[ffmpeg]', 40 | ], 41 | dependency_links=[ 42 | 'git+https://github.com/rtqichen/torchdiffeq.git#egg=torchdiffeq', 43 | ], 44 | extras_require={ 45 | 'd4rl': [ 46 | 'gym==0.23.1', 47 | 'mujoco_py', 48 | 'Cython<3.0', 49 | ], 50 | 'DI-engine': [ 51 | 'DI-engine', 52 | ], 53 | 'HuggingFace': [ 54 | 'safetensors', 55 | 'huggingface_hub', 56 | ], 57 | 'formatter': [ 58 | 'black', 59 | 'isort', 60 
| ], 61 | }, 62 | classifiers=[ 63 | "Programming Language :: Python :: 3", 64 | "Programming Language :: Python :: 3.9", 65 | "Programming Language :: Python :: 3.10", 66 | "Programming Language :: Python :: 3.11", 67 | "Programming Language :: Python :: 3.12", 68 | "License :: OSI Approved :: Apache Software License", 69 | "Operating System :: OS Independent", 70 | ], 71 | license="Apache-2.0", 72 | ) 73 | --------------------------------------------------------------------------------