├── .dockerignore ├── docs ├── source │ ├── _static │ │ └── .gitkeepme │ ├── image │ │ └── logo.png │ ├── utils │ │ ├── plotter.rst │ │ ├── model.rst │ │ ├── math.rst │ │ ├── distributed.rst │ │ ├── tools.rst │ │ └── config.rst │ ├── envs │ │ ├── mujoco_env.rst │ │ ├── custom.rst │ │ ├── safety_gymnasium.rst │ │ ├── core.rst │ │ ├── adapter.rst │ │ └── wrapper.rst │ ├── common │ │ ├── logger.rst │ │ ├── normalizer.rst │ │ ├── exp_grid.rst │ │ ├── simmer_agent.rst │ │ ├── offline_data.rst │ │ ├── lagrange.rst │ │ └── buffer.rst │ ├── model │ │ ├── offline.rst │ │ ├── critic.rst │ │ ├── actor_critic.rst │ │ ├── modelbased_planner.rst │ │ ├── modelbased_model.rst │ │ └── actor.rst │ ├── baserlapi │ │ ├── model_based.rst │ │ ├── off_policy.rst │ │ └── on_policy.rst │ ├── saferlapi │ │ ├── first_order.rst │ │ ├── penalty_function.rst │ │ ├── second_order.rst │ │ ├── model_based.rst │ │ └── lagrange.rst │ └── start │ │ ├── installation.rst │ │ ├── env.rst │ │ └── algo.md ├── requirements.txt ├── Makefile └── make.bat ├── MANIFEST.in ├── tutorials ├── requirements.txt ├── images │ └── omnisafe.jpg └── README.md ├── tests ├── requirements.txt ├── saved_source │ ├── Test-v0.npz │ ├── PPO-{SafetyPointGoal1-v0} │ │ └── seed-000-2023-03-16-12-08-52 │ │ │ ├── torch_save │ │ │ └── epoch-0.pt │ │ │ └── config.json │ ├── test_statistics_tools │ │ ├── SafetyAntVelocity-v1---556c9cedab7db813a6ea3860f5921d7ccbc176d70900e709065fc2604d02b9a6 │ │ │ ├── NaturalPG-{SafetyAntVelocity-v1} │ │ │ │ └── seed-000-2023-04-14-00-42-56 │ │ │ │ │ ├── torch_save │ │ │ │ │ └── epoch-0.pt │ │ │ │ │ ├── tb │ │ │ │ │ └── events.out.tfevents.1681404176.0839c7628832.57350.0 │ │ │ │ │ ├── progress.csv │ │ │ │ │ └── config.json │ │ │ └── exps_config.json │ │ ├── SafetyAntVelocity-v1---bc8c9e4f26a8072e52cf3eddefc51eae641946d3314efe537e918e0851f83aa5 │ │ │ ├── PolicyGradient-{SafetyAntVelocity-v1} │ │ │ │ └── seed-000-2023-04-14-00-42-56 │ │ │ │ │ ├── torch_save │ │ │ │ │ └── epoch-0.pt │ │ │ │ │ ├── tb │ │ │ 
│ │ └── events.out.tfevents.1681404176.0839c7628832.57349.0 │ │ │ │ │ ├── progress.csv │ │ │ │ │ └── config.json │ │ │ └── exps_config.json │ │ └── grid_config.json │ ├── train_config.yaml │ └── benchmark_config.yaml ├── .coveragerc ├── distribution_train.py ├── test_registry.py ├── test_statistics_tools.py ├── test_normalizer.py ├── helpers.py ├── test_ensemble.py └── simple_env.py ├── images ├── logo.png └── train.gif ├── .gitattributes ├── .github ├── ISSUE_TEMPLATE │ ├── config.yml │ ├── questions.yml │ └── feature-request.yml ├── PULL_REQUEST_TEMPLATE.md └── workflows │ ├── lint.yml │ └── test.yml ├── requirements.txt ├── codecov.yml ├── conftest.py ├── .editorconfig ├── .readthedocs.yaml ├── omnisafe ├── utils │ └── __init__.py ├── models │ ├── offline │ │ └── __init__.py │ ├── critic │ │ └── __init__.py │ ├── actor_critic │ │ └── __init__.py │ ├── actor │ │ ├── __init__.py │ │ └── gaussian_actor.py │ └── __init__.py ├── envs │ ├── classic_control │ │ └── __init__.py │ └── __init__.py ├── algorithms │ ├── on_policy │ │ ├── primal │ │ │ ├── __init__.py │ │ │ └── crpo.py │ │ ├── first_order │ │ │ └── __init__.py │ │ ├── second_order │ │ │ └── __init__.py │ │ ├── penalty_function │ │ │ ├── __init__.py │ │ │ └── ipo.py │ │ ├── pid_lagrange │ │ │ └── __init__.py │ │ ├── saute │ │ │ ├── __init__.py │ │ │ ├── ppo_saute.py │ │ │ └── trpo_saute.py │ │ ├── simmer │ │ │ ├── __init__.py │ │ │ ├── ppo_simmer_pid.py │ │ │ └── trpo_simmer_pid.py │ │ ├── early_terminated │ │ │ ├── __init__.py │ │ │ ├── trpo_early_terminated.py │ │ │ └── ppo_early_terminated.py │ │ ├── base │ │ │ ├── __init__.py │ │ │ └── ppo.py │ │ ├── naive_lagrange │ │ │ └── __init__.py │ │ └── __init__.py │ ├── model_based │ │ ├── base │ │ │ └── __init__.py │ │ ├── __init__.py │ │ └── planner │ │ │ └── __init__.py │ ├── offline │ │ └── __init__.py │ ├── off_policy │ │ └── __init__.py │ ├── registry.py │ ├── __init__.py │ └── base_algo.py ├── common │ ├── offline │ │ └── __init__.py │ ├── __init__.py │ └── 
buffer │ │ └── __init__.py ├── __init__.py ├── adapter │ └── __init__.py ├── typing.py ├── version.py └── configs │ └── offline │ ├── VAEBC.yaml │ ├── CRR.yaml │ ├── BCQ.yaml │ └── CCRR.yaml ├── examples ├── benchmarks │ ├── example_cli_benchmark_config.yaml │ └── run_experiment_grid.py ├── train_from_yaml.py ├── collect_offline_data.py ├── train_from_custom_dict.py ├── analyze_experiment_results.py ├── evaluate_saved_policy.py ├── plot.py ├── train_policy.py └── train_from_custom_env.py ├── setup.py ├── .flake8 ├── conda-recipe.yaml └── Dockerfile /.dockerignore: -------------------------------------------------------------------------------- 1 | .gitignore -------------------------------------------------------------------------------- /docs/source/_static/.gitkeepme: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | recursive-include omnisafe/configs/ *.yaml 2 | -------------------------------------------------------------------------------- /tutorials/requirements.txt: -------------------------------------------------------------------------------- 1 | ipywidgets 2 | ipykernel 3 | -------------------------------------------------------------------------------- /tests/requirements.txt: -------------------------------------------------------------------------------- 1 | pytest 2 | pytest-cov 3 | pytest-xdist 4 | -------------------------------------------------------------------------------- /images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Alignment/omnisafe/HEAD/images/logo.png -------------------------------------------------------------------------------- /images/train.gif: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PKU-Alignment/omnisafe/HEAD/images/train.gif -------------------------------------------------------------------------------- /docs/source/image/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Alignment/omnisafe/HEAD/docs/source/image/logo.png -------------------------------------------------------------------------------- /tests/saved_source/Test-v0.npz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Alignment/omnisafe/HEAD/tests/saved_source/Test-v0.npz -------------------------------------------------------------------------------- /tutorials/images/omnisafe.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Alignment/omnisafe/HEAD/tutorials/images/omnisafe.jpg -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | * text eol=lf 2 | *.bat eol=crlf 3 | *.ipynb linguist-detectable=false 4 | 5 | *.png binary 6 | *.jpg binary 7 | *.jpeg binary 8 | *.gif binary 9 | *.pdf binary 10 | *.stl binary 11 | -------------------------------------------------------------------------------- /tests/saved_source/PPO-{SafetyPointGoal1-v0}/seed-000-2023-03-16-12-08-52/torch_save/epoch-0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Alignment/omnisafe/HEAD/tests/saved_source/PPO-{SafetyPointGoal1-v0}/seed-000-2023-03-16-12-08-52/torch_save/epoch-0.pt -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false 2 | contact_links: 3 | - name: 💬 Start a discussion 
4 | url: https://github.com/PKU-Alignment/omnisafe/discussions/new 5 | about: Please ask and answer questions here if unsure. 6 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==7.1.2 2 | sphinx-copybutton 3 | sphinx-design 4 | sphinx-press-theme 5 | sphinx-markdown-tables 6 | recommonmark 7 | sphinx-autoapi 8 | sphinx-autobuild 9 | sphinx-autodoc-typehints 10 | furo 11 | sphinxcontrib-spelling 12 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # Sync with project.dependencies 2 | safety-gymnasium >= 0.1.0 3 | torch >= 1.10.0 4 | numpy >= 1.20.0 5 | tensorboard >= 2.8.0 6 | wandb >= 0.13.0 7 | pyyaml >= 6.0 8 | moviepy >= 1.0.0 9 | typing-extensions >= 4.0.0 10 | typer[all] >= 0.7.0 11 | seaborn >= 0.12.2 12 | pandas >= 1.5.3 13 | matplotlib >= 3.7.1 14 | gdown >= 4.6.0 15 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | coverage: 2 | precision: 2 3 | round: nearest 4 | status: 5 | project: 6 | default: 7 | target: auto 8 | threshold: 0.05% 9 | branches: 10 | - main 11 | - dev 12 | patch: 13 | default: 14 | target: 100% 15 | informational: true 16 | branches: 17 | - main 18 | - dev 19 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | try: 5 | from metadrive import SafeMetaDriveEnv 6 | 7 | meta_drive_env_available = True 8 | except ImportError: 9 | meta_drive_env_available = False 10 | 11 | 12 | def pytest_ignore_collect(path, config): 13 | if os.path.basename(path) == 'meta_drive_env.py' and not 
meta_drive_env_available: 14 | return True 15 | return False 16 | -------------------------------------------------------------------------------- /docs/source/utils/plotter.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Plotter 2 | ================ 3 | 4 | .. currentmodule:: omnisafe.utils.plotter 5 | 6 | Plotter 7 | ------- 8 | 9 | .. card:: 10 | :class-header: sd-bg-success sd-text-white 11 | :class-card: sd-outline-success sd-rounded-1 12 | 13 | Documentation 14 | ^^^ 15 | 16 | .. autoclass:: Plotter 17 | :members: 18 | :private-members: 19 | -------------------------------------------------------------------------------- /docs/source/envs/mujoco_env.rst: -------------------------------------------------------------------------------- 1 | Mujoco Environment 2 | ================== 3 | 4 | .. currentmodule:: omnisafe.envs.mujoco_env 5 | 6 | MujocoEnv Interface 7 | ------------------- 8 | 9 | .. card:: 10 | :class-header: sd-bg-success sd-text-white 11 | :class-card: sd-outline-success sd-rounded-1 12 | 13 | Documentation 14 | ^^^ 15 | 16 | .. autoclass:: MujocoEnv 17 | :members: 18 | :private-members: 19 | -------------------------------------------------------------------------------- /docs/source/common/logger.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Logger 2 | =============== 3 | 4 | .. currentmodule:: omnisafe.common.logger 5 | 6 | .. autosummary:: 7 | 8 | Logger 9 | 10 | Logger 11 | ------ 12 | 13 | .. card:: 14 | :class-header: sd-bg-success sd-text-white 15 | :class-card: sd-outline-success sd-rounded-1 16 | 17 | Documentation 18 | ^^^ 19 | 20 | .. 
autoclass:: Logger 21 | :members: 22 | :private-members: 23 | -------------------------------------------------------------------------------- /docs/source/model/offline.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Offline Model 2 | ====================== 3 | 4 | .. currentmodule:: omnisafe.models.offline.dice 5 | 6 | Observation Encoder 7 | ------------------- 8 | 9 | .. card:: 10 | :class-header: sd-bg-success sd-text-white 11 | :class-card: sd-outline-success sd-rounded-1 12 | 13 | Documentation 14 | ^^^ 15 | 16 | .. autoclass:: ObsEncoder 17 | :members: 18 | :private-members: 19 | -------------------------------------------------------------------------------- /docs/source/common/normalizer.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Normalizer 2 | =================== 3 | 4 | .. currentmodule:: omnisafe.common.normalizer 5 | 6 | .. autosummary:: Normalizer 7 | 8 | Normalizer 9 | ---------- 10 | 11 | .. card:: 12 | :class-header: sd-bg-success sd-text-white 13 | :class-card: sd-outline-success sd-rounded-1 14 | 15 | Documentation 16 | ^^^ 17 | 18 | .. 
autoclass:: Normalizer 19 | :members: 20 | :private-members: 21 | -------------------------------------------------------------------------------- /tests/saved_source/test_statistics_tools/SafetyAntVelocity-v1---556c9cedab7db813a6ea3860f5921d7ccbc176d70900e709065fc2604d02b9a6/NaturalPG-{SafetyAntVelocity-v1}/seed-000-2023-04-14-00-42-56/torch_save/epoch-0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Alignment/omnisafe/HEAD/tests/saved_source/test_statistics_tools/SafetyAntVelocity-v1---556c9cedab7db813a6ea3860f5921d7ccbc176d70900e709065fc2604d02b9a6/NaturalPG-{SafetyAntVelocity-v1}/seed-000-2023-04-14-00-42-56/torch_save/epoch-0.pt -------------------------------------------------------------------------------- /tests/saved_source/test_statistics_tools/SafetyAntVelocity-v1---bc8c9e4f26a8072e52cf3eddefc51eae641946d3314efe537e918e0851f83aa5/PolicyGradient-{SafetyAntVelocity-v1}/seed-000-2023-04-14-00-42-56/torch_save/epoch-0.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Alignment/omnisafe/HEAD/tests/saved_source/test_statistics_tools/SafetyAntVelocity-v1---bc8c9e4f26a8072e52cf3eddefc51eae641946d3314efe537e918e0851f83aa5/PolicyGradient-{SafetyAntVelocity-v1}/seed-000-2023-04-14-00-42-56/torch_save/epoch-0.pt -------------------------------------------------------------------------------- /docs/source/common/exp_grid.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Experiment Grid 2 | ======================== 3 | 4 | .. currentmodule:: omnisafe.common.experiment_grid 5 | 6 | .. autosummary:: 7 | 8 | ExperimentGrid 9 | 10 | Experiment Grid 11 | --------------- 12 | 13 | .. card:: 14 | :class-header: sd-bg-success sd-text-white 15 | :class-card: sd-outline-success sd-rounded-1 16 | 17 | Documentation 18 | ^^^ 19 | 20 | .. 
autoclass:: ExperimentGrid 21 | :members: 22 | :private-members: 23 | -------------------------------------------------------------------------------- /docs/source/envs/custom.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Customization Interface of Environments 2 | ================================================ 3 | 4 | .. currentmodule:: omnisafe.envs.custom_env 5 | 6 | .. autosummary:: 7 | 8 | CustomEnv 9 | 10 | CustomEnv 11 | --------- 12 | 13 | .. card:: 14 | :class-header: sd-bg-success sd-text-white 15 | :class-card: sd-outline-success sd-rounded-1 16 | 17 | Documentation 18 | ^^^ 19 | 20 | .. autoclass:: CustomEnv 21 | :members: 22 | :private-members: 23 | -------------------------------------------------------------------------------- /tests/saved_source/test_statistics_tools/SafetyAntVelocity-v1---556c9cedab7db813a6ea3860f5921d7ccbc176d70900e709065fc2604d02b9a6/NaturalPG-{SafetyAntVelocity-v1}/seed-000-2023-04-14-00-42-56/tb/events.out.tfevents.1681404176.0839c7628832.57350.0: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-Alignment/omnisafe/HEAD/tests/saved_source/test_statistics_tools/SafetyAntVelocity-v1---556c9cedab7db813a6ea3860f5921d7ccbc176d70900e709065fc2604d02b9a6/NaturalPG-{SafetyAntVelocity-v1}/seed-000-2023-04-14-00-42-56/tb/events.out.tfevents.1681404176.0839c7628832.57350.0 -------------------------------------------------------------------------------- /tests/saved_source/test_statistics_tools/SafetyAntVelocity-v1---bc8c9e4f26a8072e52cf3eddefc51eae641946d3314efe537e918e0851f83aa5/PolicyGradient-{SafetyAntVelocity-v1}/seed-000-2023-04-14-00-42-56/tb/events.out.tfevents.1681404176.0839c7628832.57349.0: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PKU-Alignment/omnisafe/HEAD/tests/saved_source/test_statistics_tools/SafetyAntVelocity-v1---bc8c9e4f26a8072e52cf3eddefc51eae641946d3314efe537e918e0851f83aa5/PolicyGradient-{SafetyAntVelocity-v1}/seed-000-2023-04-14-00-42-56/tb/events.out.tfevents.1681404176.0839c7628832.57349.0 -------------------------------------------------------------------------------- /tests/saved_source/test_statistics_tools/SafetyAntVelocity-v1---556c9cedab7db813a6ea3860f5921d7ccbc176d70900e709065fc2604d02b9a6/exps_config.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "seed": 0, 4 | "algo_cfgs": { 5 | "steps_per_epoch": 2048 6 | }, 7 | "train_cfgs": { 8 | "total_steps": 4096, 9 | "torch_threads": 1, 10 | "vector_env_nums": 2 11 | }, 12 | "logger_cfgs": { 13 | "use_wandb": false, 14 | "log_dir": "" 15 | }, 16 | "env_id": "SafetyAntVelocity-v1", 17 | "algo": "NaturalPG" 18 | } 19 | -------------------------------------------------------------------------------- /tests/saved_source/test_statistics_tools/SafetyAntVelocity-v1---bc8c9e4f26a8072e52cf3eddefc51eae641946d3314efe537e918e0851f83aa5/exps_config.json: -------------------------------------------------------------------------------- 1 | 2 | { 3 | "seed": 0, 4 | "algo_cfgs": { 5 | "steps_per_epoch": 2048 6 | }, 7 | "train_cfgs": { 8 | "total_steps": 4096, 9 | "torch_threads": 1, 10 | "vector_env_nums": 2 11 | }, 12 | "logger_cfgs": { 13 | "use_wandb": false, 14 | "log_dir": "" 15 | }, 16 | "env_id": "SafetyAntVelocity-v1", 17 | "algo": "PolicyGradient" 18 | } 19 | -------------------------------------------------------------------------------- /docs/source/utils/model.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Model Utils 2 | ==================== 3 | 4 | .. currentmodule:: omnisafe.utils.model 5 | 6 | .. 
autosummary:: 7 | 8 | initialize_layer 9 | get_activation 10 | build_mlp_network 11 | 12 | Model Building Utils 13 | -------------------- 14 | 15 | .. card:: 16 | :class-header: sd-bg-success sd-text-white 17 | :class-card: sd-outline-success sd-rounded-1 18 | 19 | Documentation 20 | ^^^ 21 | 22 | .. autofunction:: initialize_layer 23 | .. autofunction:: get_activation 24 | .. autofunction:: build_mlp_network 25 | -------------------------------------------------------------------------------- /tests/saved_source/test_statistics_tools/grid_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "algo": [ 3 | "PolicyGradient", 4 | "NaturalPG" 5 | ], 6 | "env_id": [ 7 | "SafetyAntVelocity-v1" 8 | ], 9 | "logger_cfgs:use_wandb": [ 10 | false 11 | ], 12 | "train_cfgs:vector_env_nums": [ 13 | 2 14 | ], 15 | "train_cfgs:torch_threads": [ 16 | 1 17 | ], 18 | "train_cfgs:total_steps": [ 19 | 4096 20 | ], 21 | "algo_cfgs:steps_per_epoch": [ 22 | 2048 23 | ], 24 | "seed": [ 25 | 0 26 | ] 27 | } 28 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # https://editorconfig.org/ 2 | 3 | root = true 4 | 5 | [*] 6 | charset = utf-8 7 | end_of_line = lf 8 | indent_style = space 9 | indent_size = 4 10 | trim_trailing_whitespace = true 11 | insert_final_newline = true 12 | 13 | [*.py] 14 | indent_size = 4 15 | src_paths=omnisafe,tests,examples 16 | 17 | [*.{yaml,yml,xml}] 18 | indent_size = 2 19 | 20 | [*.md] 21 | indent_size = 2 22 | x-soft-wrap-text = true 23 | 24 | [*.rst] 25 | indent_size = 4 26 | x-soft-wrap-text = true 27 | 28 | [Makefile] 29 | indent_style = tab 30 | 31 | [*.sh] 32 | indent_style = tab 33 | 34 | [*.bat] 35 | end_of_line = crlf 36 | indent_style = tab 37 | -------------------------------------------------------------------------------- /.readthedocs.yaml: 
-------------------------------------------------------------------------------- 1 | # Required 2 | version: 2 3 | 4 | # Set the version of Python and other tools you might need 5 | build: 6 | os: ubuntu-20.04 7 | tools: 8 | python: "3.9" 9 | jobs: 10 | post_install: 11 | - python -m pip install --upgrade pip setuptools 12 | - python -m pip install --no-build-isolation --editable . 13 | - python -m pip install -r docs/requirements.txt 14 | 15 | # Build documentation in the docs/ directory with Sphinx 16 | sphinx: 17 | configuration: docs/source/conf.py 18 | fail_on_warning: True 19 | 20 | # If using Sphinx, optionally build your docs in additional formats such as PDF 21 | # formats: 22 | # - pdf 23 | -------------------------------------------------------------------------------- /docs/source/baserlapi/model_based.rst: -------------------------------------------------------------------------------- 1 | Base Model-based Algorithms 2 | =========================== 3 | 4 | .. currentmodule:: omnisafe.algorithms.model_based.base 5 | 6 | 7 | LOOP 8 | ---- 9 | 10 | .. card:: 11 | :class-header: sd-bg-success sd-text-white 12 | :class-card: sd-outline-success sd-rounded-1 13 | 14 | Documentation 15 | ^^^ 16 | 17 | .. autoclass:: LOOP 18 | :members: 19 | :private-members: 20 | 21 | PETS 22 | ---- 23 | 24 | .. card:: 25 | :class-header: sd-bg-success sd-text-white 26 | :class-card: sd-outline-success sd-rounded-1 27 | 28 | Documentation 29 | ^^^ 30 | 31 | .. autoclass:: PETS 32 | :members: 33 | :private-members: 34 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 
6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/source/saferlapi/first_order.rst: -------------------------------------------------------------------------------- 1 | First Order Algorithms 2 | ====================== 3 | 4 | .. currentmodule:: omnisafe.algorithms.on_policy 5 | 6 | .. autosummary:: 7 | 8 | FOCOPS 9 | CUP 10 | 11 | .. _focopsapi: 12 | 13 | FOCOPS 14 | ------ 15 | 16 | .. card:: 17 | :class-header: sd-bg-success sd-text-white 18 | :class-card: sd-outline-success sd-rounded-1 19 | 20 | Documentation 21 | ^^^ 22 | 23 | .. autoclass:: FOCOPS 24 | :members: 25 | :private-members: 26 | 27 | CUP 28 | --- 29 | 30 | .. card:: 31 | :class-header: sd-bg-success sd-text-white 32 | :class-card: sd-outline-success sd-rounded-1 33 | 34 | Documentation 35 | ^^^ 36 | 37 | .. autoclass:: CUP 38 | :members: 39 | :private-members: 40 | -------------------------------------------------------------------------------- /omnisafe/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Utility functions for OmniSafe.""" 16 | -------------------------------------------------------------------------------- /docs/source/common/simmer_agent.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Simmer Agent 2 | ===================== 3 | 4 | .. currentmodule:: omnisafe.common.simmer_agent 5 | 6 | .. autosummary:: 7 | 8 | BaseSimmerAgent 9 | SimmerPIDAgent 10 | 11 | Base Simmer Agent 12 | ----------------- 13 | 14 | .. card:: 15 | :class-header: sd-bg-success sd-text-white 16 | :class-card: sd-outline-success sd-rounded-1 17 | 18 | Documentation 19 | ^^^ 20 | 21 | .. autoclass:: BaseSimmerAgent 22 | :members: 23 | :private-members: 24 | 25 | Simmer PID Agent 26 | ---------------- 27 | 28 | .. card:: 29 | :class-header: sd-bg-success sd-text-white 30 | :class-card: sd-outline-success sd-rounded-1 31 | 32 | Documentation 33 | ^^^ 34 | 35 | .. 
autoclass:: SimmerPIDAgent 36 | :members: 37 | :private-members: 38 | -------------------------------------------------------------------------------- /tests/.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | omit = 3 | ../omnisafe/version.py 4 | ../docs/* 5 | ../examples/* 6 | ../tutorials/* 7 | ../omnisafe/common/control_barrier_function/crabs/* 8 | ../omnisafe/envs/classic_control/* 9 | ../omnisafe/algorithms/off_policy/crabs.py 10 | ../omnisafe/adapter/crabs_adapter.py 11 | ../omnisafe/envs/crabs_env.py 12 | ../omnisafe/envs/custom_env.py 13 | ../omnisafe/envs/safety_isaac_gym_env.py 14 | ../omnisafe/utils/isaac_gym_utils.py 15 | ../omnisafe/envs/meta_drive_env.py 16 | ../omnisafe/evaluator.py 17 | 18 | [report] 19 | exclude_lines = 20 | pragma: no cover 21 | raise NotImplementedError 22 | raise FileNotFoundError 23 | class .*\bProtocol\): 24 | @(abc\.)?abstractmethod 25 | if __name__ == ('__main__'|"__main__"): 26 | if TYPE_CHECKING: 27 | -------------------------------------------------------------------------------- /docs/source/common/offline_data.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Offline Data 2 | ===================== 3 | 4 | .. currentmodule:: omnisafe.common.offline 5 | 6 | Data Collector 7 | -------------- 8 | 9 | .. card:: 10 | :class-header: sd-bg-success sd-text-white 11 | :class-card: sd-outline-success sd-rounded-1 12 | 13 | Documentation 14 | ^^^ 15 | 16 | .. autoclass:: OfflineDataCollector 17 | :members: 18 | :private-members: 19 | 20 | Offline Dataset 21 | --------------- 22 | 23 | .. card:: 24 | :class-header: sd-bg-success sd-text-white 25 | :class-card: sd-outline-success sd-rounded-1 26 | 27 | Documentation 28 | ^^^ 29 | 30 | .. autoclass:: OfflineDataset 31 | :members: 32 | :private-members: 33 | 34 | .. 
autoclass:: OfflineDatasetWithInit 35 | :members: 36 | :private-members: 37 | -------------------------------------------------------------------------------- /omnisafe/models/offline/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """model used in offline RL algorithms.""" 16 | 17 | from omnisafe.models.offline.dice import ObsEncoder 18 | -------------------------------------------------------------------------------- /omnisafe/envs/classic_control/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Environment implementations from papers.""" 16 | from omnisafe.envs.classic_control import envs_from_crabs 17 | -------------------------------------------------------------------------------- /docs/source/saferlapi/penalty_function.rst: -------------------------------------------------------------------------------- 1 | Penalty Function Algorithms 2 | =========================== 3 | 4 | .. currentmodule:: omnisafe.algorithms.on_policy 5 | 6 | .. autosummary:: 7 | 8 | P3O 9 | IPO 10 | 11 | Penalized Proximal Policy Optimization 12 | -------------------------------------- 13 | 14 | .. card:: 15 | :class-header: sd-bg-success sd-text-white 16 | :class-card: sd-outline-success sd-rounded-1 17 | 18 | Documentation 19 | ^^^ 20 | 21 | .. autoclass:: P3O 22 | :members: 23 | :private-members: 24 | 25 | 26 | 27 | Interior-point Policy Optimization 28 | ---------------------------------- 29 | 30 | .. card:: 31 | :class-header: sd-bg-success sd-text-white 32 | :class-card: sd-outline-success sd-rounded-1 33 | 34 | Documentation 35 | ^^^ 36 | 37 | .. autoclass:: IPO 38 | :members: 39 | :private-members: 40 | -------------------------------------------------------------------------------- /docs/source/envs/safety_gymnasium.rst: -------------------------------------------------------------------------------- 1 | Safety Gymnasium Environment 2 | ============================ 3 | 4 | .. currentmodule:: omnisafe.envs.safety_gymnasium_env 5 | 6 | Safety Gymnasium Interface 7 | -------------------------- 8 | 9 | .. card:: 10 | :class-header: sd-bg-success sd-text-white 11 | :class-card: sd-outline-success sd-rounded-1 12 | 13 | Documentation 14 | ^^^ 15 | 16 | .. autoclass:: SafetyGymnasiumEnv 17 | :members: 18 | :private-members: 19 | 20 | .. currentmodule:: omnisafe.envs.safety_gymnasium_modelbased 21 | 22 | Safety Gymnasium World Model 23 | ---------------------------- 24 | 25 | .. 
card:: 26 | :class-header: sd-bg-success sd-text-white 27 | :class-card: sd-outline-success sd-rounded-1 28 | 29 | Documentation 30 | ^^^ 31 | 32 | .. autoclass:: SafetyGymnasiumModelBased 33 | :members: 34 | :private-members: 35 | -------------------------------------------------------------------------------- /docs/source/common/lagrange.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Lagrange Multiplier 2 | ============================ 3 | 4 | .. currentmodule:: omnisafe.common.lagrange 5 | 6 | .. autosummary:: 7 | 8 | Lagrange 9 | 10 | Lagrange Multiplier 11 | ------------------- 12 | 13 | .. card:: 14 | :class-header: sd-bg-success sd-text-white 15 | :class-card: sd-outline-success sd-rounded-1 16 | 17 | Documentation 18 | ^^^ 19 | 20 | .. autoclass:: Lagrange 21 | :members: 22 | :private-members: 23 | 24 | 25 | .. currentmodule:: omnisafe.common.pid_lagrange 26 | 27 | .. autosummary:: 28 | 29 | PIDLagrangian 30 | 31 | PIDLagrange 32 | ----------- 33 | 34 | .. card:: 35 | :class-header: sd-bg-success sd-text-white 36 | :class-card: sd-outline-success sd-rounded-1 37 | 38 | Documentation 39 | ^^^ 40 | 41 | .. autoclass:: PIDLagrangian 42 | :members: 43 | :private-members: 44 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/primal/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Primal algorithms.""" 16 | 17 | from omnisafe.algorithms.on_policy.primal.crpo import OnCRPO 18 | 19 | 20 | __all__ = [ 21 | 'OnCRPO', 22 | ] 23 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | %SPHINXBUILD% >NUL 2>NUL 14 | if errorlevel 9009 ( 15 | echo. 16 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 17 | echo.installed, then set the SPHINXBUILD environment variable to point 18 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 19 | echo.may add the Sphinx directory to PATH. 20 | echo. 
21 | echo.If you don't have Sphinx installed, grab it from 22 | echo.https://www.sphinx-doc.org/ 23 | exit /b 1 24 | ) 25 | 26 | if "%1" == "" goto help 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /tests/saved_source/train_config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | algo: 17 | PPOLag 18 | env_id: 19 | SafetyAntVelocity-v1 20 | train_cfgs: 21 | total_steps: 22 | 2048 23 | vector_env_nums: 1 24 | algo_cfgs: 25 | steps_per_epoch: 26 | 1024 27 | -------------------------------------------------------------------------------- /docs/source/saferlapi/second_order.rst: -------------------------------------------------------------------------------- 1 | Second Order Algorithms 2 | ======================= 3 | 4 | .. currentmodule:: omnisafe.algorithms.on_policy 5 | 6 | .. autosummary:: 7 | 8 | CPO 9 | PCPO 10 | 11 | 12 | .. _cpoapi: 13 | 14 | 15 | Constraint Policy Optimization 16 | ------------------------------ 17 | 18 | .. 
card:: 19 | :class-header: sd-bg-success sd-text-white 20 | :class-card: sd-outline-success sd-rounded-1 21 | 22 | Documentation 23 | ^^^ 24 | 25 | .. autoclass:: CPO 26 | :members: 27 | :private-members: 28 | 29 | 30 | 31 | 32 | .. _pcpoapi: 33 | 34 | 35 | Projection Based Constraint Policy Optimization 36 | ----------------------------------------------- 37 | 38 | .. card:: 39 | :class-header: sd-bg-success sd-text-white 40 | :class-card: sd-outline-success sd-rounded-1 41 | 42 | Documentation 43 | ^^^ 44 | 45 | .. autoclass:: PCPO 46 | :members: 47 | :private-members: 48 | -------------------------------------------------------------------------------- /omnisafe/common/offline/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Useful Tools for Offline Algorithms.""" 16 | 17 | 18 | from omnisafe.common.offline.data_collector import OfflineDataCollector 19 | from omnisafe.common.offline.dataset import OfflineDataset, OfflineDatasetWithInit, OfflineMeta 20 | -------------------------------------------------------------------------------- /omnisafe/algorithms/model_based/base/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Basic Model Based Reinforcement Learning algorithms.""" 16 | 17 | from omnisafe.algorithms.model_based.base.loop import LOOP 18 | from omnisafe.algorithms.model_based.base.pets import PETS 19 | 20 | 21 | __all__ = ['LOOP', 'PETS'] 22 | -------------------------------------------------------------------------------- /omnisafe/common/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Common utilities for OmniSafe.""" 16 | 17 | from omnisafe.common.lagrange import Lagrange 18 | from omnisafe.common.logger import Logger 19 | from omnisafe.common.normalizer import Normalizer 20 | from omnisafe.common.pid_lagrange import PIDLagrangian 21 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/first_order/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
14 | # ============================================================================== 15 | """First-order algorithms.""" 16 | 17 | from omnisafe.algorithms.on_policy.first_order.cup import CUP 18 | from omnisafe.algorithms.on_policy.first_order.focops import FOCOPS 19 | 20 | 21 | __all__ = [ 22 | 'CUP', 23 | 'FOCOPS', 24 | ] 25 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/second_order/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Second-order algorithms.""" 16 | 17 | from omnisafe.algorithms.on_policy.second_order.cpo import CPO 18 | from omnisafe.algorithms.on_policy.second_order.pcpo import PCPO 19 | 20 | 21 | __all__ = [ 22 | 'CPO', 23 | 'PCPO', 24 | ] 25 | -------------------------------------------------------------------------------- /omnisafe/models/critic/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """The abstract interfaces of Critic networks for the Actor-Critic algorithm.""" 16 | 17 | from omnisafe.models.critic.critic_builder import CriticBuilder 18 | from omnisafe.models.critic.q_critic import QCritic 19 | from omnisafe.models.critic.v_critic import VCritic 20 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/penalty_function/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Penalty function algorithms.""" 16 | 17 | from omnisafe.algorithms.on_policy.penalty_function.ipo import IPO 18 | from omnisafe.algorithms.on_policy.penalty_function.p3o import P3O 19 | 20 | 21 | __all__ = [ 22 | 'P3O', 23 | 'IPO', 24 | ] 25 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/pid_lagrange/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """PID Lagrange algorithms.""" 16 | 17 | from omnisafe.algorithms.on_policy.pid_lagrange.cppo_pid import CPPOPID 18 | from omnisafe.algorithms.on_policy.pid_lagrange.trpo_pid import TRPOPID 19 | 20 | 21 | __all__ = [ 22 | 'CPPOPID', 23 | 'TRPOPID', 24 | ] 25 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/saute/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Saute Reinforcement Learning algorithms.""" 16 | 17 | from omnisafe.algorithms.on_policy.saute.ppo_saute import PPOSaute 18 | from omnisafe.algorithms.on_policy.saute.trpo_saute import TRPOSaute 19 | 20 | 21 | __all__ = [ 22 | 'TRPOSaute', 23 | 'PPOSaute', 24 | ] 25 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/simmer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Simmer Reinforcement Learning algorithms.""" 16 | 17 | from omnisafe.algorithms.on_policy.simmer.ppo_simmer_pid import PPOSimmerPID 18 | from omnisafe.algorithms.on_policy.simmer.trpo_simmer_pid import TRPOSimmerPID 19 | 20 | 21 | __all__ = [ 22 | 'TRPOSimmerPID', 23 | 'PPOSimmerPID', 24 | ] 25 | -------------------------------------------------------------------------------- /tests/saved_source/benchmark_config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | algo: 17 | ['PolicyGradient', 'NaturalPG'] 18 | env_id: 19 | ['SafetyAntVelocity-v1'] 20 | logger_cfgs:use_wandb: 21 | [False] 22 | train_cfgs:vector_env_nums: 23 | [2] 24 | train_cfgs:torch_threads: 25 | [1] 26 | train_cfgs:total_steps: 27 | 4096 28 | algo_cfgs:steps_per_epoch: 29 | 2048 30 | seed: 31 | [0] 32 | -------------------------------------------------------------------------------- /docs/source/utils/math.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Math 2 | ============= 3 | 4 | .. currentmodule:: omnisafe.utils.math 5 | 6 | .. 
autosummary:: 7 | 8 | get_transpose 9 | get_diagonal 10 | discount_cumsum 11 | conjugate_gradients 12 | SafeTanhTransformer 13 | TanhNormal 14 | 15 | 16 | Tensor Operations 17 | ----------------- 18 | 19 | .. card:: 20 | :class-header: sd-bg-success sd-text-white 21 | :class-card: sd-outline-success sd-rounded-1 22 | 23 | Documentation 24 | ^^^ 25 | 26 | .. autofunction:: get_transpose 27 | .. autofunction:: get_diagonal 28 | .. autofunction:: discount_cumsum 29 | .. autofunction:: conjugate_gradients 30 | 31 | 32 | Distribution Operations 33 | ----------------------- 34 | 35 | .. card:: 36 | :class-header: sd-bg-success sd-text-white 37 | :class-card: sd-outline-success sd-rounded-1 38 | 39 | Documentation 40 | ^^^ 41 | 42 | .. autoclass:: SafeTanhTransformer 43 | :members: 44 | :private-members: 45 | 46 | .. autoclass:: TanhNormal 47 | :members: 48 | :private-members: 49 | -------------------------------------------------------------------------------- /examples/benchmarks/example_cli_benchmark_config.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | algo: 17 | ['PolicyGradient', 'NaturalPG'] 18 | env_id: 19 | ['SafetyAntVelocity-v1'] 20 | logger_cfgs:use_wandb: 21 | [False] 22 | train_cfgs:vector_env_nums: 23 | [2] 24 | train_cfgs:torch_threads: 25 | [1] 26 | train_cfgs:total_steps: 27 | 1024 28 | algo_cfgs:steps_per_epoch: 29 | 512 30 | seed: 31 | [0] 32 | -------------------------------------------------------------------------------- /docs/source/envs/core.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Core Environment 2 | ========================= 3 | 4 | .. currentmodule:: omnisafe.envs.core 5 | 6 | .. autosummary:: 7 | 8 | CMDP 9 | Wrapper 10 | EnvRegister 11 | 12 | CMDP 13 | ---- 14 | 15 | .. card:: 16 | :class-header: sd-bg-success sd-text-white 17 | :class-card: sd-outline-success sd-rounded-1 18 | 19 | Documentation 20 | ^^^ 21 | 22 | .. autoclass:: CMDP 23 | :members: 24 | :private-members: 25 | 26 | Wrapper 27 | ------- 28 | 29 | .. card:: 30 | :class-header: sd-bg-success sd-text-white 31 | :class-card: sd-outline-success sd-rounded-1 32 | 33 | Documentation 34 | ^^^ 35 | 36 | .. autoclass:: Wrapper 37 | :members: 38 | :private-members: 39 | 40 | Make an Environment 41 | ------------------- 42 | 43 | .. card:: 44 | :class-header: sd-bg-success sd-text-white 45 | :class-card: sd-outline-success sd-rounded-1 46 | 47 | Documentation 48 | ^^^ 49 | 50 | .. autoclass:: EnvRegister 51 | :members: 52 | :private-members: 53 | 54 | .. autofunction:: make 55 | -------------------------------------------------------------------------------- /omnisafe/models/actor_critic/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implementation of ActorCritic.""" 16 | 17 | from omnisafe.models.actor_critic.actor_critic import ActorCritic 18 | from omnisafe.models.actor_critic.actor_q_critic import ActorQCritic 19 | from omnisafe.models.actor_critic.constraint_actor_critic import ConstraintActorCritic 20 | from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic 21 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/early_terminated/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Early Terminated Safe Reinforcement Learning algorithms.""" 16 | 17 | from omnisafe.algorithms.on_policy.early_terminated.ppo_early_terminated import PPOEarlyTerminated 18 | from omnisafe.algorithms.on_policy.early_terminated.trpo_early_terminated import TRPOEarlyTerminated 19 | 20 | 21 | __all__ = [ 22 | 'TRPOEarlyTerminated', 23 | 'PPOEarlyTerminated', 24 | ] 25 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | import re 3 | import sys 4 | 5 | from setuptools import setup 6 | 7 | 8 | HERE = pathlib.Path(__file__).absolute().parent 9 | VERSION_FILE = HERE / 'omnisafe' / 'version.py' 10 | 11 | sys.path.insert(0, str(VERSION_FILE.parent)) # Put omnisafe/ on sys.path so version.py can be imported below as the top-level module `version`. 12 | import version # noqa 13 | 14 | 15 | VERSION_CONTENT = None # Original version.py text while the file is temporarily rewritten; the `finally` block restores it when not None. 16 | 17 | try: 18 | if not version.__release__: # Non-release build: temporarily replace the __version__ assignment in version.py with repr(version.__version__). 19 | try: 20 | VERSION_CONTENT = VERSION_FILE.read_text(encoding='utf-8') 21 | VERSION_FILE.write_text( 22 | data=re.sub( 23 | r"""__version__\s*=\s*('[^']+'|"[^"]+")""", 24 | f'__version__ = {version.__version__!r}', 25 | string=VERSION_CONTENT, 26 | ), 27 | encoding='utf-8', 28 | ) 29 | except OSError: # NOTE(review): write_text truncates version.py before writing; if it fails part-way, resetting VERSION_CONTENT here skips the restore in `finally` and may leave version.py damaged - consider restoring instead. 30 | VERSION_CONTENT = None 31 | 32 | setup( 33 | name='omnisafe', 34 | version=version.__version__, 35 | ) 36 | finally: # Restore the original version.py text (if it was rewritten above), even if setup() raised. 37 | if VERSION_CONTENT is not None: 38 | with VERSION_FILE.open(mode='wt', encoding='utf-8', newline='') as file: 39 | file.write(VERSION_CONTENT) 40 | -------------------------------------------------------------------------------- /docs/source/baserlapi/off_policy.rst: -------------------------------------------------------------------------------- 1 | Base Off-policy Algorithms 2 | ========================== 3 | 4 | .. currentmodule:: omnisafe.algorithms.off_policy 5 | 6 | ..
autosummary:: 7 | 8 | DDPG 9 | TD3 10 | SAC 11 | 12 | Deep Deterministic Policy Gradient 13 | ---------------------------------- 14 | 15 | .. card:: 16 | :class-header: sd-bg-success sd-text-white 17 | :class-card: sd-outline-success sd-rounded-1 18 | 19 | Documentation 20 | ^^^ 21 | 22 | .. autoclass:: DDPG 23 | :members: 24 | :private-members: 25 | 26 | 27 | Twin Delayed DDPG 28 | ----------------- 29 | 30 | .. card:: 31 | :class-header: sd-bg-success sd-text-white 32 | :class-card: sd-outline-success sd-rounded-1 33 | 34 | Documentation 35 | ^^^ 36 | 37 | .. autoclass:: TD3 38 | :members: 39 | :private-members: 40 | 41 | Soft Actor-Critic 42 | ----------------- 43 | 44 | .. card:: 45 | :class-header: sd-bg-success sd-text-white 46 | :class-card: sd-outline-success sd-rounded-1 47 | 48 | Documentation 49 | ^^^ 50 | 51 | .. autoclass:: SAC 52 | :members: 53 | :private-members: 54 | -------------------------------------------------------------------------------- /examples/train_from_yaml.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Example of training a policy from default config yaml with OmniSafe.""" 16 | import omnisafe 17 | 18 | 19 | if __name__ == '__main__': 20 | env_id = 'SafetyPointGoal1-v0' 21 | 22 | agent = omnisafe.Agent('PPOLag', env_id) # No custom config is passed, so the algorithm's default config yaml is used. 23 | agent.learn() 24 | 25 | agent.plot(smooth=1) # Post-training helpers: plot, render 1 episode as 'rgb_array' frames at 256x256, then evaluate 1 episode. 26 | agent.render(num_episodes=1, render_mode='rgb_array', width=256, height=256) 27 | agent.evaluate(num_episodes=1) 28 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/questions.yml: -------------------------------------------------------------------------------- 1 | name: 🤔 Questions / Help / Support 2 | description: Do you need support? 3 | title: "[Question] " 4 | labels: [question] 5 | body: 6 | - type: checkboxes 7 | id: steps 8 | attributes: 9 | label: Required prerequisites 10 | description: Make sure you've completed the following steps before submitting your issue -- thank you! 11 | options: 12 | - label: I have read the documentation . 13 | required: true 14 | - label: I have searched the [Issue Tracker](https://github.com/PKU-Alignment/omnisafe/issues) and [Discussions](https://github.com/PKU-Alignment/omnisafe/discussions) that this hasn't already been reported. (+1 or comment there if it has.) 15 | required: true 16 | - label: Consider asking first in a [Discussion](https://github.com/PKU-Alignment/omnisafe/discussions/new). 17 | required: false 18 | 19 | - type: textarea 20 | id: questions 21 | attributes: 22 | label: Questions 23 | description: Describe your questions with relevant resources such as snippets, links, images, etc. 24 | validations: 25 | required: true 26 | -------------------------------------------------------------------------------- /omnisafe/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """OmniSafe: A comprehensive and reliable benchmark for safe reinforcement learning.""" 16 | 17 | from contextlib import suppress 18 | 19 | 20 | with suppress(ImportError): 21 | from isaacgym import gymutil 22 | 23 | from omnisafe import algorithms 24 | from omnisafe.algorithms import ALGORITHMS 25 | from omnisafe.algorithms.algo_wrapper import AlgoWrapper as Agent 26 | from omnisafe.evaluator import Evaluator 27 | from omnisafe.version import __version__ 28 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/base/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Basic Reinforcement Learning algorithms.""" 16 | 17 | from omnisafe.algorithms.on_policy.base.natural_pg import NaturalPG 18 | from omnisafe.algorithms.on_policy.base.policy_gradient import PolicyGradient 19 | from omnisafe.algorithms.on_policy.base.ppo import PPO 20 | from omnisafe.algorithms.on_policy.base.trpo import TRPO 21 | 22 | 23 | __all__ = [ 24 | 'NaturalPG', 25 | 'PolicyGradient', 26 | 'PPO', 27 | 'TRPO', 28 | ] 29 | -------------------------------------------------------------------------------- /omnisafe/algorithms/offline/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Offline algorithms.""" 16 | 17 | from omnisafe.algorithms.offline.bcq import BCQ 18 | from omnisafe.algorithms.offline.bcq_lag import BCQLag 19 | from omnisafe.algorithms.offline.c_crr import CCRR 20 | from omnisafe.algorithms.offline.coptidice import COptiDICE 21 | from omnisafe.algorithms.offline.crr import CRR 22 | from omnisafe.algorithms.offline.vae_bc import VAEBC 23 | 24 | 25 | __all__ = ['BCQ', 'BCQLag', 'CCRR', 'CRR', 'COptiDICE', 'VAEBC'] 26 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/naive_lagrange/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Naive Lagrange algorithms.""" 16 | 17 | from omnisafe.algorithms.on_policy.naive_lagrange.pdo import PDO 18 | from omnisafe.algorithms.on_policy.naive_lagrange.ppo_lag import PPOLag 19 | from omnisafe.algorithms.on_policy.naive_lagrange.rcpo import RCPO 20 | from omnisafe.algorithms.on_policy.naive_lagrange.trpo_lag import TRPOLag 21 | 22 | 23 | __all__ = [ 24 | 'RCPO', 25 | 'PDO', 26 | 'PPOLag', 27 | 'TRPOLag', 28 | ] 29 | -------------------------------------------------------------------------------- /tests/saved_source/test_statistics_tools/SafetyAntVelocity-v1---bc8c9e4f26a8072e52cf3eddefc51eae641946d3314efe537e918e0851f83aa5/PolicyGradient-{SafetyAntVelocity-v1}/seed-000-2023-04-14-00-42-56/progress.csv: -------------------------------------------------------------------------------- 1 | Metrics/EpRet,Metrics/EpCost,Metrics/EpLen,Train/Epoch,Train/Entropy,Train/KL,Train/StopIter,Train/PolicyRatio,Train/LR,Train/PolicyStd,TotalEnvSteps,Loss/Loss_pi,Loss/Loss_pi/Delta,Value/Adv,Loss/Loss_reward_critic,Loss/Loss_reward_critic/Delta,Value/reward,Time/Total,Time/Rollout,Time/Update,Time/Epoch,Time/FPS 2 | -178.79351806640625,0.2800000011920929,55.119998931884766,0.0,1.4139022827148438,0.03477708250284195,3.0,0.995556652545929,0.0001500000071246177,0.994982898235321,2048.0,-0.0314510278403759,-0.0314510278403759,-0.12468741834163666,2110.103759765625,2110.103759765625,-0.14448364078998566,4.811440944671631,4.10282039642334,0.7073924541473389,4.810637474060059,425.72332763671875 3 | -179.4430389404297,0.35483869910240173,56.0,1.0,1.4029006958007812,0.027002619579434395,5.0,0.9962325096130371,0.0,0.9841020703315735,4096.0,-0.028531571850180626,0.0029194559901952744,-0.23381422460079193,2087.865966796875,-22.23779296875,-1.619920253753662,10.351533889770508,4.4213104248046875,1.0958952903747559,5.517416954040527,371.18829345703125 4 | 
-------------------------------------------------------------------------------- /docs/source/start/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | 4 | Installation Guide 5 | ------------------ 6 | 7 | You can install OmniSafe from the Python Package Index (PyPI) using pip: 8 | 9 | .. code-block:: bash 10 | 11 | $ conda create -n omnisafe python=3.8 12 | $ conda activate omnisafe 13 | $ pip install omnisafe 14 | 15 | You can also install OmniSafe from source: 16 | 17 | .. code-block:: bash 18 | 19 | $ git clone https://github.com/PKU-Alignment/omnisafe.git 20 | $ cd omnisafe 21 | $ conda create -n omnisafe python=3.8 22 | $ conda activate omnisafe 23 | $ pip install -e . 24 | 25 | Installation Video 26 | ------------------ 27 | 28 | Here we provide a video tutorial for installing OmniSafe on Ubuntu. 29 | 30 | .. admonition:: Video Tutorial 31 | :class: hint 32 | 33 | Install OmniSafe From PyPI 34 | 35 | .. raw:: html 36 | 37 | 38 | 39 | Install OmniSafe From Source 40 | 41 | .. raw:: html 42 | 43 | 44 | -------------------------------------------------------------------------------- /docs/source/saferlapi/model_based.rst: -------------------------------------------------------------------------------- 1 | Model-based Algorithms 2 | ====================== 3 | 4 | .. currentmodule:: omnisafe.algorithms.model_based 5 | 6 | CAPPETS 7 | ------- 8 | 9 | .. card:: 10 | :class-header: sd-bg-success sd-text-white 11 | :class-card: sd-outline-success sd-rounded-1 12 | 13 | Documentation 14 | ^^^ 15 | 16 | .. autoclass:: CAPPETS 17 | :members: 18 | :private-members: 19 | 20 | CCEPETS 21 | ------- 22 | 23 | .. card:: 24 | :class-header: sd-bg-success sd-text-white 25 | :class-card: sd-outline-success sd-rounded-1 26 | 27 | Documentation 28 | ^^^ 29 | 30 | .. autoclass:: CCEPETS 31 | :members: 32 | :private-members: 33 | 34 | RCEPETS 35 | ------- 36 | 37 | .. 
card:: 38 | :class-header: sd-bg-success sd-text-white 39 | :class-card: sd-outline-success sd-rounded-1 40 | 41 | Documentation 42 | ^^^ 43 | 44 | .. autoclass:: RCEPETS 45 | :members: 46 | :private-members: 47 | 48 | Safe LOOP 49 | --------- 50 | 51 | .. card:: 52 | :class-header: sd-bg-success sd-text-white 53 | :class-card: sd-outline-success sd-rounded-1 54 | 55 | Documentation 56 | ^^^ 57 | 58 | .. autoclass:: SafeLOOP 59 | :members: 60 | :private-members: 61 | -------------------------------------------------------------------------------- /omnisafe/models/actor/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """The abstract interfaces of Actor networks for the Actor-Critic algorithm.""" 16 | 17 | from omnisafe.models.actor.actor_builder import ActorBuilder 18 | from omnisafe.models.actor.gaussian_actor import GaussianActor 19 | from omnisafe.models.actor.gaussian_learning_actor import GaussianLearningActor 20 | from omnisafe.models.actor.gaussian_sac_actor import GaussianSACActor 21 | from omnisafe.models.actor.mlp_actor import MLPActor 22 | from omnisafe.models.actor.perturbation_actor import PerturbationActor 23 | from omnisafe.models.actor.vae_actor import VAE 24 | -------------------------------------------------------------------------------- /omnisafe/algorithms/model_based/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Model-Based algorithms.""" 16 | 17 | from omnisafe.algorithms.model_based import base 18 | from omnisafe.algorithms.model_based.base import LOOP, PETS 19 | from omnisafe.algorithms.model_based.cap_pets import CAPPETS 20 | from omnisafe.algorithms.model_based.cce_pets import CCEPETS 21 | from omnisafe.algorithms.model_based.rce_pets import RCEPETS 22 | from omnisafe.algorithms.model_based.safeloop import SafeLOOP 23 | 24 | 25 | __all__ = [ 26 | *base.__all__, 27 | 'CAPPETS', 28 | 'CCEPETS', 29 | 'SafeLOOP', 30 | 'RCEPETS', 31 | ] 32 | -------------------------------------------------------------------------------- /omnisafe/common/buffer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Implementation of Buffer.""" 16 | 17 | from omnisafe.common.buffer.base import BaseBuffer 18 | from omnisafe.common.buffer.offpolicy_buffer import OffPolicyBuffer 19 | from omnisafe.common.buffer.onpolicy_buffer import OnPolicyBuffer 20 | from omnisafe.common.buffer.vector_offpolicy_buffer import VectorOffPolicyBuffer 21 | from omnisafe.common.buffer.vector_onpolicy_buffer import VectorOnPolicyBuffer 22 | 23 | 24 | __all__ = [ 25 | 'BaseBuffer', 26 | 'OffPolicyBuffer', 27 | 'OnPolicyBuffer', 28 | 'VectorOffPolicyBuffer', 29 | 'VectorOnPolicyBuffer', 30 | ] 31 | -------------------------------------------------------------------------------- /omnisafe/adapter/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Adapter for the environment and the algorithm.""" 16 | 17 | from omnisafe.adapter.early_terminated_adapter import EarlyTerminatedAdapter 18 | from omnisafe.adapter.modelbased_adapter import ModelBasedAdapter 19 | from omnisafe.adapter.offline_adapter import OfflineAdapter 20 | from omnisafe.adapter.offpolicy_adapter import OffPolicyAdapter 21 | from omnisafe.adapter.online_adapter import OnlineAdapter 22 | from omnisafe.adapter.onpolicy_adapter import OnPolicyAdapter 23 | from omnisafe.adapter.saute_adapter import SauteAdapter 24 | from omnisafe.adapter.simmer_adapter import SimmerAdapter 25 | -------------------------------------------------------------------------------- /omnisafe/envs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Environment API for OmniSafe.""" 16 | 17 | from omnisafe.envs import classic_control 18 | from omnisafe.envs.core import CMDP, env_register, make, support_envs 19 | from omnisafe.envs.crabs_env import CRABSEnv 20 | from omnisafe.envs.custom_env import CustomEnv 21 | from omnisafe.envs.meta_drive_env import SafetyMetaDriveEnv 22 | from omnisafe.envs.mujoco_env import MujocoEnv 23 | from omnisafe.envs.safety_gymnasium_env import SafetyGymnasiumEnv 24 | from omnisafe.envs.safety_gymnasium_modelbased import SafetyGymnasiumModelBased 25 | from omnisafe.envs.safety_isaac_gym_env import SafetyIsaacGymEnv 26 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | max-doc-length = 100 4 | select = B,C,E,F,W,Y,SIM 5 | ignore = 6 | # E203: whitespace before ':' 7 | # W503: line break before binary operator 8 | # W504: line break after binary operator 9 | # format by black 10 | E203,W503,W504, 11 | # E501: line too long 12 | # W505: doc line too long 13 | # too long docstring due to long example blocks 14 | E501,W505, 15 | per-file-ignores = 16 | # F401: module imported but unused 17 | # intentionally unused imports 18 | __init__.py: F401 19 | # E301: expected 1 blank line 20 | # E302: expected 2 blank lines 21 | # E305: expected 2 blank lines after class or function definition 22 | # E701: multiple statements on one line (colon) 23 | # E704: multiple statements on one line (def) 24 | # format by black 25 | *.pyi: E301,E302,E305,E701,E704 26 | # B008: do not perform function calls in argument defaults 27 | # typer requires this 28 | omnisafe/utils/command_app.py: B008 29 | # SIM113: use enumerate 30 | omnisafe/algorithms/on_policy/base/policy_gradient.py: SIM113 31 | exclude = 32 | .git, 33 | .vscode, 34 | venv, 35 | 
third-party, 36 | __pycache__, 37 | docs/source/conf.py, 38 | build, 39 | dist, 40 | examples, 41 | tests, 42 | conftest.py 43 | -------------------------------------------------------------------------------- /omnisafe/algorithms/model_based/planner/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Model-based planner.""" 16 | 17 | from omnisafe.algorithms.model_based.planner.arc import ARCPlanner 18 | from omnisafe.algorithms.model_based.planner.cap import CAPPlanner 19 | from omnisafe.algorithms.model_based.planner.cce import CCEPlanner 20 | from omnisafe.algorithms.model_based.planner.cem import CEMPlanner 21 | from omnisafe.algorithms.model_based.planner.rce import RCEPlanner 22 | from omnisafe.algorithms.model_based.planner.safe_arc import SafeARCPlanner 23 | 24 | 25 | __all__ = [ 26 | 'CEMPlanner', 27 | 'CCEPlanner', 28 | 'ARCPlanner', 29 | 'SafeARCPlanner', 30 | 'RCEPlanner', 31 | 'CAPPlanner', 32 | ] 33 | -------------------------------------------------------------------------------- 
/tests/saved_source/test_statistics_tools/SafetyAntVelocity-v1---556c9cedab7db813a6ea3860f5921d7ccbc176d70900e709065fc2604d02b9a6/NaturalPG-{SafetyAntVelocity-v1}/seed-000-2023-04-14-00-42-56/progress.csv: -------------------------------------------------------------------------------- 1 | Metrics/EpRet,Metrics/EpCost,Metrics/EpLen,Train/Epoch,Train/Entropy,Train/KL,Train/StopIter,Train/PolicyRatio,Train/LR,Train/PolicyStd,TotalEnvSteps,Loss/Loss_pi,Loss/Loss_pi/Delta,Value/Adv,Loss/Loss_reward_critic,Loss/Loss_reward_critic/Delta,Value/reward,Time/Total,Time/Rollout,Time/Update,Time/Epoch,Time/FPS,Misc/Alpha,Misc/FinalStepNorm,Misc/gradient_norm,Misc/xHx,Misc/H_inv_g 2 | -178.79351806640625,0.2800000011920929,55.119998931884766,0.0,1.378043293952942,0.0,10.0,0.9962406158447266,0.0,0.9600533843040466,2048.0,-0.08913429081439972,-0.08913429081439972,2.528540790081024e-07,1686.935546875,1686.935546875,-0.14448364078998566,5.198263168334961,4.091709136962891,1.1056640148162842,5.197443962097168,394.03997802734375,0.2244206666946411,0.33427420258522034,0.29370883107185364,0.39710402488708496,1.4894983768463135 3 | -237.6625518798828,1.1515151262283325,78.1515121459961,1.0,1.3339864015579224,0.0,10.0,1.002729058265686,0.0,0.9186481833457947,4096.0,-0.09461010247468948,-0.005475811660289764,-1.3690441846847534e-07,1352.0322265625,-334.9033203125,-14.09510612487793,10.670058250427246,4.426172256469727,1.0115702152252197,5.437805652618408,376.62261962890625,0.22050604224205017,0.3292768597602844,0.30748632550239563,0.4113287031650543,1.4932781457901 4 | -------------------------------------------------------------------------------- /docs/source/model/critic.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Critic 2 | =============== 3 | 4 | .. currentmodule:: omnisafe.models.base 5 | 6 | .. autosummary:: 7 | 8 | Critic 9 | 10 | 11 | Base Critic 12 | ----------- 13 | 14 | .. 
card:: 15 | :class-header: sd-bg-success sd-text-white 16 | :class-card: sd-outline-success sd-rounded-1 17 | 18 | Documentation 19 | ^^^ 20 | 21 | .. autoclass:: Critic 22 | :members: 23 | :private-members: 24 | 25 | .. currentmodule:: omnisafe.models.critic 26 | 27 | .. autosummary:: 28 | 29 | CriticBuilder 30 | QCritic 31 | VCritic 32 | 33 | Critic Builder 34 | -------------- 35 | .. card:: 36 | :class-header: sd-bg-success sd-text-white 37 | :class-card: sd-outline-success sd-rounded-1 38 | 39 | Documentation 40 | ^^^ 41 | 42 | .. autoclass:: CriticBuilder 43 | :members: 44 | :private-members: 45 | 46 | Q Critic 47 | -------- 48 | .. card:: 49 | :class-header: sd-bg-success sd-text-white 50 | :class-card: sd-outline-success sd-rounded-1 51 | 52 | Documentation 53 | ^^^ 54 | 55 | .. autoclass:: QCritic 56 | :members: 57 | :private-members: 58 | 59 | V Critic 60 | -------- 61 | .. card:: 62 | :class-header: sd-bg-success sd-text-white 63 | :class-card: sd-outline-success sd-rounded-1 64 | 65 | Documentation 66 | ^^^ 67 | 68 | .. autoclass:: VCritic 69 | :members: 70 | :private-members: 71 | -------------------------------------------------------------------------------- /tests/distribution_train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Example of training a policy with OmniSafe.""" 16 | 17 | import omnisafe 18 | import simple_env # noqa: F401 19 | 20 | 21 | if __name__ == '__main__': 22 | algo = 'NaturalPG' 23 | env_id = 'Test-v0' 24 | custom_cfgs = { 25 | 'train_cfgs': { 26 | 'total_steps': 4096, 27 | 'vector_env_nums': 1, 28 | }, 29 | 'algo_cfgs': { 30 | 'steps_per_epoch': 1024, 31 | 'update_iters': 2, 32 | }, 33 | 'logger_cfgs': { 34 | 'use_wandb': False, 35 | }, 36 | } 37 | train_terminal_cfgs = { 38 | 'parallel': 2, 39 | } 40 | agent = omnisafe.Agent(algo, env_id, train_terminal_cfgs, custom_cfgs) 41 | agent.learn() 42 | -------------------------------------------------------------------------------- /docs/source/model/actor_critic.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Actor Critic 2 | ===================== 3 | 4 | .. currentmodule:: omnisafe.models.actor_critic 5 | 6 | .. autosummary:: 7 | 8 | ActorCritic 9 | ActorQCritic 10 | ConstraintActorCritic 11 | ConstraintActorQCritic 12 | 13 | Actor Critic 14 | ------------ 15 | 16 | .. card:: 17 | :class-header: sd-bg-success sd-text-white 18 | :class-card: sd-outline-success sd-rounded-1 19 | 20 | Documentation 21 | ^^^ 22 | 23 | .. autoclass:: ActorCritic 24 | :members: 25 | :private-members: 26 | 27 | Actor Q Critic 28 | -------------- 29 | 30 | .. card:: 31 | :class-header: sd-bg-success sd-text-white 32 | :class-card: sd-outline-success sd-rounded-1 33 | 34 | Documentation 35 | ^^^ 36 | 37 | .. autoclass:: ActorQCritic 38 | :members: 39 | :private-members: 40 | 41 | Constraint Actor Critic 42 | ----------------------- 43 | 44 | .. card:: 45 | :class-header: sd-bg-success sd-text-white 46 | :class-card: sd-outline-success sd-rounded-1 47 | 48 | Documentation 49 | ^^^ 50 | 51 | .. 
autoclass:: ConstraintActorCritic 52 | :members: 53 | :private-members: 54 | 55 | Constraint Actor Q Critic 56 | ------------------------- 57 | 58 | .. card:: 59 | :class-header: sd-bg-success sd-text-white 60 | :class-card: sd-outline-success sd-rounded-1 61 | 62 | Documentation 63 | ^^^ 64 | 65 | .. autoclass:: ConstraintActorQCritic 66 | :members: 67 | :private-members: 68 | -------------------------------------------------------------------------------- /examples/collect_offline_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Example of collecting offline data with OmniSafe.""" 16 | 17 | from omnisafe.common.offline.data_collector import OfflineDataCollector 18 | 19 | 20 | # please change agent path and env name 21 | # also, please make sure you have run: 22 | # python train_policy.py --algo PPO --env ENVID 23 | # where ENVID is the environment from which you want to collect data. 24 | # The `PATH_TO_AGENT` is the directory path containing the `torch_save`. 
25 | 26 | env_name = 'SafetyAntVelocity-v1' 27 | size = 1_000_000 28 | agents = [ 29 | ('PATH_TO_AGENT', 'epoch-500.pt', 1_000_000), 30 | ] 31 | save_dir = './data' 32 | 33 | if __name__ == '__main__': 34 | col = OfflineDataCollector(size, env_name) 35 | for agent, model_name, num in agents: 36 | col.register_agent(agent, model_name, num) 37 | col.collect(save_dir) 38 | -------------------------------------------------------------------------------- /examples/train_from_custom_dict.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Example of training a policy from custom dict with OmniSafe.""" 16 | 17 | import omnisafe 18 | 19 | 20 | if __name__ == '__main__': 21 | env_id = 'SafetyPointGoal1-v0' 22 | custom_cfgs = { 23 | 'train_cfgs': { 24 | 'total_steps': 1024000, 25 | 'vector_env_nums': 1, 26 | 'parallel': 1, 27 | }, 28 | 'algo_cfgs': { 29 | 'steps_per_epoch': 2048, 30 | 'update_iters': 1, 31 | }, 32 | 'logger_cfgs': { 33 | 'use_wandb': False, 34 | }, 35 | } 36 | 37 | agent = omnisafe.Agent('PPOLag', env_id, custom_cfgs=custom_cfgs) 38 | agent.learn() 39 | 40 | agent.plot(smooth=1) 41 | agent.render(num_episodes=1, render_mode='rgb_array', width=256, height=256) 42 | agent.evaluate(num_episodes=1) 43 | -------------------------------------------------------------------------------- /examples/analyze_experiment_results.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Example of analyzing policies trained by exp-x with OmniSafe.""" 16 | 17 | from omnisafe.common.statistics_tools import StatisticsTools 18 | 19 | 20 | # just fill in the path in which experiment grid runs. 
21 | PATH = '' 22 | if __name__ == '__main__': 23 | st = StatisticsTools() 24 | st.load_source(PATH) 25 | # Fill in the name of the parameter whose values you want to compare. 26 | # Then either specify the values of that parameter you want to compare, 27 | # or specify the maximum number of values to compare in a single graph, 28 | # in which case the function automatically generates graphs for all possible combinations. 29 | # The two modes cannot be used at the same time. 30 | st.draw_graph(parameter='', values=None, compare_num=2, cost_limit=None, show_image=True) 31 | -------------------------------------------------------------------------------- /omnisafe/models/actor/gaussian_actor.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """This module contains the base normal-distribution actor for the models.""" 16 | 17 | from abc import ABC, abstractmethod 18 | 19 | from omnisafe.models.base import Actor 20 | 21 | 22 | class GaussianActor(Actor, ABC): 23 | """An abstract class for normal-distribution actors. 24 | 25 | A GaussianActor inherits from Actor and uses a normal distribution to approximate the policy function. 26 | 27 | ..
note:: 28 | You can use this class to implement your own actor by inheriting from it. 29 | """ 30 | 31 | @property 32 | @abstractmethod 33 | def std(self) -> float: 34 | """Get the standard deviation of the normal distribution.""" 35 | 36 | @std.setter 37 | @abstractmethod 38 | def std(self, std: float) -> None: 39 | """Set the standard deviation of the normal distribution.""" 40 | -------------------------------------------------------------------------------- /docs/source/baserlapi/on_policy.rst: -------------------------------------------------------------------------------- 1 | Base On-policy Algorithms 2 | ========================= 3 | 4 | .. currentmodule:: omnisafe.algorithms.on_policy 5 | 6 | .. autosummary:: 7 | 8 | PolicyGradient 9 | NaturalPG 10 | TRPO 11 | PPO 12 | 13 | Policy Gradient 14 | --------------- 15 | 16 | .. card:: 17 | :class-header: sd-bg-success sd-text-white 18 | :class-card: sd-outline-success sd-rounded-1 19 | 20 | Documentation 21 | ^^^ 22 | 23 | .. autoclass:: PolicyGradient 24 | :members: 25 | :private-members: 26 | 27 | 28 | Natural Policy Gradient 29 | ----------------------- 30 | 31 | .. card:: 32 | :class-header: sd-bg-success sd-text-white 33 | :class-card: sd-outline-success sd-rounded-1 34 | 35 | Documentation 36 | ^^^ 37 | 38 | .. autoclass:: NaturalPG 39 | :members: 40 | :private-members: 41 | 42 | .. _trpoapi: 43 | 44 | Trust Region Policy Optimization 45 | -------------------------------- 46 | 47 | .. card:: 48 | :class-header: sd-bg-success sd-text-white 49 | :class-card: sd-outline-success sd-rounded-1 50 | 51 | Documentation 52 | ^^^ 53 | 54 | .. autoclass:: TRPO 55 | :members: 56 | :private-members: 57 | 58 | 59 | 60 | 61 | .. _ppoapi: 62 | 63 | 64 | Proximal Policy Optimization 65 | ---------------------------- 66 | 67 | .. card:: 68 | :class-header: sd-bg-success sd-text-white 69 | :class-card: sd-outline-success sd-rounded-1 70 | 71 | Documentation 72 | ^^^ 73 | 74 | ..
autoclass:: PPO 75 | :members: 76 | :private-members: 77 | -------------------------------------------------------------------------------- /omnisafe/algorithms/off_policy/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Off-policy algorithms.""" 16 | 17 | from omnisafe.algorithms.off_policy.crabs import CRABS 18 | from omnisafe.algorithms.off_policy.ddpg import DDPG 19 | from omnisafe.algorithms.off_policy.ddpg_lag import DDPGLag 20 | from omnisafe.algorithms.off_policy.ddpg_pid import DDPGPID 21 | from omnisafe.algorithms.off_policy.sac import SAC 22 | from omnisafe.algorithms.off_policy.sac_lag import SACLag 23 | from omnisafe.algorithms.off_policy.sac_pid import SACPID 24 | from omnisafe.algorithms.off_policy.td3 import TD3 25 | from omnisafe.algorithms.off_policy.td3_lag import TD3Lag 26 | from omnisafe.algorithms.off_policy.td3_pid import TD3PID 27 | 28 | 29 | __all__ = [ 30 | 'DDPG', 31 | 'TD3', 32 | 'SAC', 33 | 'DDPGLag', 34 | 'TD3Lag', 35 | 'SACLag', 36 | 'DDPGPID', 37 | 'TD3PID', 38 | 'SACPID', 39 | 'CRABS', 40 | ] 41 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: 
-------------------------------------------------------------------------------- 1 | # Description 2 | 3 | Describe your changes in detail. 4 | 5 | ## Motivation and Context 6 | 7 | Why is this change required? What problem does it solve? 8 | If it fixes an open issue, please link to the issue here. 9 | You can use the syntax `close #15213` if this solves the issue #15213 10 | 11 | - [ ] I have raised an issue to propose this change ([required](https://github.com/PKU-Alignment/omnisafe/issues) for new features and bug fixes) 12 | 13 | ## Types of changes 14 | 15 | What types of changes does your code introduce? Put an `x` in all the boxes that apply: 16 | 17 | - [ ] Bug fix (non-breaking change which fixes an issue) 18 | - [ ] New feature (non-breaking change which adds core functionality) 19 | - [ ] Breaking change (fix or feature that would cause existing functionality to change) 20 | - [ ] Documentation (update in the documentation) 21 | 22 | ## Checklist 23 | 24 | Go over all the following points, and put an `x` in all the boxes that apply. 25 | If you are unsure about any of these, don't hesitate to ask. We are here to help! 26 | 27 | - [ ] I have read the [CONTRIBUTION](https://github.com/PKU-Alignment/omnisafe/blob/HEAD/CONTRIBUTING.md) guide. (**required**) 28 | - [ ] My change requires a change to the documentation. 29 | - [ ] I have updated the tests accordingly. (*required for a bug fix or a new feature*) 30 | - [ ] I have updated the documentation accordingly. 31 | - [ ] I have reformatted the code using `make format`. (**required**) 32 | - [ ] I have checked the code using `make lint`. (**required**) 33 | - [ ] I have ensured `make test` pass. (**required**) 34 | -------------------------------------------------------------------------------- /docs/source/utils/distributed.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Distributed 2 | ==================== 3 | 4 | .. 
currentmodule:: omnisafe.utils.distributed 5 | 6 | .. autosummary:: 7 | 8 | setup_distributed 9 | get_rank 10 | world_size 11 | fork 12 | avg_tensor 13 | avg_grads 14 | sync_params 15 | avg_params 16 | dist_avg 17 | dist_sum 18 | dist_max 19 | dist_min 20 | dist_op 21 | dist_statistics_scalar 22 | 23 | Set up distributed training 24 | --------------------------- 25 | 26 | .. card:: 27 | :class-header: sd-bg-success sd-text-white 28 | :class-card: sd-outline-success sd-rounded-1 29 | 30 | Documentation 31 | ^^^ 32 | 33 | .. autofunction:: setup_distributed 34 | .. autofunction:: get_rank 35 | .. autofunction:: world_size 36 | .. autofunction:: fork 37 | 38 | Tensor Operations 39 | ----------------- 40 | 41 | .. card:: 42 | :class-header: sd-bg-success sd-text-white 43 | :class-card: sd-outline-success sd-rounded-1 44 | 45 | Documentation 46 | ^^^ 47 | 48 | .. autofunction:: avg_tensor 49 | .. autofunction:: avg_grads 50 | .. autofunction:: sync_params 51 | .. autofunction:: avg_params 52 | 53 | Distributed Operations 54 | ---------------------- 55 | 56 | .. card:: 57 | :class-header: sd-bg-success sd-text-white 58 | :class-card: sd-outline-success sd-rounded-1 59 | 60 | Documentation 61 | ^^^ 62 | 63 | .. autofunction:: dist_avg 64 | .. autofunction:: dist_sum 65 | .. autofunction:: dist_max 66 | .. autofunction:: dist_min 67 | .. autofunction:: dist_op 68 | .. autofunction:: dist_statistics_scalar 69 | -------------------------------------------------------------------------------- /examples/evaluate_saved_policy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Example of rendering and evaluating every saved ``.pt`` checkpoint of a trained policy.""" 16 | 17 | import os 18 | 19 | import omnisafe 20 | 21 | 22 | # Fill in your experiment's log directory here; it must contain a 'torch_save' subdirectory. 23 | # For example: ~/omnisafe/examples/runs/PPOLag-{SafetyPointGoal1-v0}/seed-000-2023-03-07-20-25-48 24 | LOG_DIR = '' 25 | if __name__ == '__main__': 26 | evaluator = omnisafe.Evaluator(render_mode='rgb_array') 27 | scan_dir = os.scandir(os.path.join(LOG_DIR, 'torch_save')) 28 | for item in scan_dir: 29 | if item.is_file() and item.name.split('.')[-1] == 'pt': 30 | evaluator.load_saved( 31 | save_dir=LOG_DIR, 32 | model_name=item.name, 33 | camera_name='track', 34 | width=256, 35 | height=256, 36 | ) 37 | evaluator.render(num_episodes=1) 38 | evaluator.evaluate(num_episodes=1) 39 | scan_dir.close() 40 | -------------------------------------------------------------------------------- /docs/source/utils/tools.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Tools 2 | ============== 3 | 4 | .. currentmodule:: omnisafe.utils.tools 5 | 6 | .. autosummary:: 7 | 8 | get_flat_params_from 9 | get_flat_gradients_from 10 | set_param_values_to_model 11 | custom_cfgs_to_dict 12 | update_dict 13 | load_yaml 14 | recursive_check_config 15 | seed_all 16 | 17 | Algorithms Tools 18 | ---------------- 19 | 20 | .. 
card:: 21 | :class-header: sd-bg-success sd-text-white 22 | :class-card: sd-outline-success sd-rounded-1 23 | 24 | Documentation 25 | ^^^ 26 | 27 | .. autofunction:: get_flat_params_from 28 | .. autofunction:: get_flat_gradients_from 29 | .. autofunction:: set_param_values_to_model 30 | 31 | Config Tools 32 | ---------------- 33 | 34 | .. card:: 35 | :class-header: sd-bg-success sd-text-white 36 | :class-card: sd-outline-success sd-rounded-1 37 | 38 | Documentation 39 | ^^^ 40 | .. autofunction:: custom_cfgs_to_dict 41 | .. autofunction:: update_dict 42 | .. autofunction:: load_yaml 43 | .. autofunction:: recursive_check_config 44 | 45 | Seed Tools 46 | ---------------- 47 | 48 | .. card:: 49 | :class-header: sd-bg-success sd-text-white 50 | :class-card: sd-outline-success sd-rounded-1 51 | 52 | Documentation 53 | ^^^ 54 | .. autofunction:: seed_all 55 | 56 | .. currentmodule:: omnisafe.utils.exp_grid_tools 57 | 58 | Experiment Grid Tools 59 | --------------------- 60 | 61 | .. card:: 62 | :class-header: sd-bg-success sd-text-white 63 | :class-card: sd-outline-success sd-rounded-1 64 | 65 | Documentation 66 | ^^^ 67 | .. autofunction:: all_bools 68 | .. autofunction:: valid_str 69 | -------------------------------------------------------------------------------- /tests/test_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Test registry.""" 16 | 17 | import pytest 18 | 19 | from omnisafe.algorithms.registry import Registry 20 | from omnisafe.envs.core import CMDP, env_register, env_unregister 21 | 22 | 23 | class TestRegistry: 24 | name: str = 'test' 25 | idx: int = 0 26 | 27 | 28 | def test_with_error() -> None: 29 | registry = Registry('test') 30 | TestRegistry() 31 | with pytest.raises(TypeError): 32 | registry.register('test') 33 | with pytest.raises(KeyError): 34 | registry.get('test') 35 | with pytest.raises(TypeError): 36 | 37 | @env_register 38 | class TestEnv: 39 | name: str = 'test' 40 | idx: int = 0 41 | 42 | with pytest.raises(ValueError): 43 | 44 | @env_register 45 | class CustomEnv(CMDP): 46 | pass 47 | 48 | @env_register 49 | @env_unregister 50 | @env_unregister 51 | class CustomEnv(CMDP): # noqa 52 | _support_envs = ['Simple-v0'] # noqa 53 | -------------------------------------------------------------------------------- /tests/test_statistics_tools.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Test analyzing policies trained by exp-x with OmniSafe.""" 16 | 17 | import pytest 18 | 19 | from omnisafe.common.statistics_tools import StatisticsTools 20 | 21 | 22 | def test_statistics_tools(): 23 | # Path to the saved experiment-grid output used as the test fixture. 24 | path = './saved_source/test_statistics_tools' 25 | st = StatisticsTools() 26 | st.load_source(path) 27 | # Pass the name of the parameter whose values you want to compare. 28 | # Then either list the specific values of that parameter to compare, 29 | # or give the maximum number of values to compare in a single graph, 30 | # in which case all possible combinations of graphs are generated automatically. 31 | # The two modes cannot be used at the same time. 32 | st.draw_graph('algo', None, 1) 33 | st.draw_graph('algo', ['PolicyGradient'], None) 34 | not_a_path = 'not_a_path' 35 | with pytest.raises(SystemExit): 36 | st.load_source(not_a_path) 37 | -------------------------------------------------------------------------------- /tests/test_normalizer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Test normalizer.""" 16 | 17 | import torch 18 | 19 | import helpers 20 | from omnisafe.common.normalizer import Normalizer 21 | 22 | 23 | @helpers.parametrize( 24 | shape=[(), (10,), (10, 10)], 25 | ) 26 | def test_normalizer(shape: tuple): 27 | norm = Normalizer(shape) 28 | 29 | assert norm.mean.shape == shape 30 | 31 | data_lst = [] 32 | for _ in range(1000): 33 | data = torch.randn(shape) 34 | data_lst.append(data) 35 | norm(data) 36 | 37 | data = torch.stack(data_lst) 38 | assert torch.allclose(data.mean(dim=0), norm.mean, atol=1e-2) 39 | assert torch.allclose(data.std(dim=0), norm.std, atol=1e-2) 40 | 41 | norm = Normalizer(shape) 42 | 43 | data_lst = [] 44 | for _ in range(1000): 45 | data = torch.randn(10, *shape) 46 | data_lst.append(data) 47 | norm(data) 48 | 49 | data = torch.cat(data_lst) 50 | assert torch.allclose(data.mean(dim=0), norm.mean, atol=1e-2) 51 | assert torch.allclose(data.std(dim=0), norm.std, atol=1e-2) 52 | -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - dev 8 | pull_request: 9 | # Allow to trigger the workflow manually 10 | workflow_dispatch: 11 | 12 | permissions: 13 | contents: read 14 | 15 | concurrency: 16 | group: "${{ github.workflow }}-${{ github.ref }}" 17 | cancel-in-progress: ${{ github.event_name == 'pull_request' }} 18 | 19 | jobs: 20 | lint: 21 | runs-on: ubuntu-latest 22 | timeout-minutes: 30 23 | steps: 24 | - name: Checkout 25 | uses: actions/checkout@v3 26 | with: 27 | submodules: "recursive" 28 | fetch-depth: 1 29 | 30 | - name: Set up Python 31 | uses: actions/setup-python@v4 32 | with: 33 | python-version: "3.8" 34 | update-environment: true 35 | 36 | - name: Upgrade pip 37 | run: | 38 | python -m pip install --upgrade 
pip setuptools 39 | 40 | - name: Install OmniSafe 41 | run: | 42 | python -m pip install -vvv --editable '.[lint]' 43 | 44 | - name: pre-commit 45 | run: | 46 | make pre-commit 47 | 48 | - name: flake8 49 | run: | 50 | make flake8 51 | 52 | - name: pylint 53 | run: | 54 | make pylint 55 | 56 | - name: isort and black 57 | run: | 58 | make py-format 59 | 60 | - name: addlicense 61 | run: | 62 | make addlicense 63 | 64 | - name: mypy 65 | run: | 66 | make mypy 67 | 68 | - name: Install dependencies 69 | run: | 70 | python -m pip install -r docs/requirements.txt 71 | 72 | - name: docstyle 73 | run: | 74 | make docstyle 75 | 76 | - name: spelling 77 | run: | 78 | make spelling 79 | -------------------------------------------------------------------------------- /.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - dev 8 | pull_request: 9 | paths: 10 | - setup.py 11 | - pyproject.toml 12 | - tests/** 13 | - omnisafe/** 14 | - .github/workflows/test.yml # this workflow file itself 15 | # Allow the workflow to be triggered manually 16 | workflow_dispatch: 17 | 18 | permissions: 19 | contents: read 20 | 21 | concurrency: 22 | group: "${{ github.workflow }}-${{ github.ref }}" 23 | cancel-in-progress: ${{ github.event_name == 'pull_request' }} 24 | 25 | jobs: 26 | test: 27 | name: Test for Python ${{ matrix.python-version }} on ${{ matrix.os }} 28 | runs-on: ${{ matrix.os }} 29 | strategy: 30 | matrix: 31 | os: [ubuntu-latest] 32 | python-version: ["3.8", "3.9", "3.10"] 33 | fail-fast: false 34 | timeout-minutes: 30 35 | steps: 36 | - name: Checkout 37 | uses: actions/checkout@v3 38 | with: 39 | submodules: "recursive" 40 | fetch-depth: 1 41 | 42 | - name: Set up Python 43 | uses: actions/setup-python@v4 44 | with: 45 | python-version: ${{ matrix.python-version }} 46 | update-environment: true 47 | 48 | - name: Upgrade pip 49 | run: | 50 | python -m pip install --upgrade pip 
setuptools 51 | 52 | - name: Install OmniSafe 53 | run: | 54 | python -m pip install -vvv -e '.[test]' 55 | 56 | - name: Test with pytest 57 | run: | 58 | make test 59 | 60 | - name: Upload coverage reports to Codecov 61 | if: ${{ matrix.python-version == '3.8'}} 62 | run: | 63 | curl -Os https://uploader.codecov.io/latest/linux/codecov 64 | chmod +x codecov 65 | ./codecov -t ${CODECOV_TOKEN=634594d3-0416-4632-ab6a-3bf34a8c0af3} 66 | -------------------------------------------------------------------------------- /omnisafe/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """This module contains the model for all methods.""" 16 | 17 | from omnisafe.models.actor import ActorBuilder 18 | from omnisafe.models.actor.gaussian_actor import GaussianActor 19 | from omnisafe.models.actor.gaussian_learning_actor import GaussianLearningActor 20 | from omnisafe.models.actor.gaussian_sac_actor import GaussianSACActor 21 | from omnisafe.models.actor.mlp_actor import MLPActor 22 | from omnisafe.models.actor.perturbation_actor import PerturbationActor 23 | from omnisafe.models.actor.vae_actor import VAE 24 | from omnisafe.models.actor_critic.actor_critic import ActorCritic 25 | from omnisafe.models.actor_critic.actor_q_critic import ActorQCritic 26 | from omnisafe.models.actor_critic.constraint_actor_critic import ConstraintActorCritic 27 | from omnisafe.models.actor_critic.constraint_actor_q_critic import ConstraintActorQCritic 28 | from omnisafe.models.base import Actor, Critic 29 | from omnisafe.models.critic import CriticBuilder 30 | from omnisafe.models.critic.q_critic import QCritic 31 | from omnisafe.models.critic.v_critic import VCritic 32 | from omnisafe.models.offline.dice import ObsEncoder 33 | -------------------------------------------------------------------------------- /conda-recipe.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | # 16 | # Create the virtual environment with the command: 17 | # 18 | # $ CONDA_OVERRIDE_CUDA=11.7 conda env create --file conda-recipe.yaml 19 | # 20 | 21 | name: omnisafe 22 | 23 | channels: 24 | - pytorch 25 | - nvidia/label/cuda-11.7.1 26 | - defaults 27 | - conda-forge 28 | 29 | dependencies: 30 | - python = 3.10 31 | - pip 32 | 33 | # Learning 34 | - pytorch::pytorch >= 1.10.0 # sync with project.dependencies 35 | - pytorch::torchvision 36 | - pytorch::pytorch-mutex = *=*cuda* 37 | - pip: 38 | - safety-gymnasium >= 0.1 39 | - tensorboard 40 | - wandb 41 | 42 | # Other dependencies 43 | - numpy 44 | - moviepy 45 | - pyyaml 46 | - typer 47 | - seaborn 48 | - pandas 49 | - matplotlib-base 50 | - typing-extensions 51 | 52 | # Device select 53 | - nvidia/label/cuda-11.7.1::cuda-toolkit = 11.7 54 | 55 | # Documentation 56 | - sphinx 57 | - sphinx-autobuild 58 | - sphinx-copybutton 59 | - sphinxcontrib-spelling 60 | - sphinxcontrib-bibtex 61 | - sphinx-autodoc-typehints 62 | - pyenchant 63 | - hunspell-en 64 | - myst-nb 65 | - ipykernel 66 | - pandoc 67 | - docutils 68 | -------------------------------------------------------------------------------- /tutorials/README.md: -------------------------------------------------------------------------------- 1 | # Tutorial 2 | ![logo](./images/omnisafe.jpg) 3 | 4 | At present, our tutorials are available in the following languages, and you can access them via Colab: 5 | 6 | ## English 7 | - [Getting Started](https://colab.research.google.com/github/PKU-Alignment/omnisafe/blob/main/tutorials/English/1.Getting_Started.ipynb) Introduces the basic usage of OmniSafe so that users can get started quickly. 
8 | - [CLI Command](https://colab.research.google.com/github/PKU-Alignment/omnisafe/blob/main/tutorials/English/2.CLI_Command.ipynb) Introduce how to use the CLI tool of OmniSafe. 9 | - [Environment Customization From Scratch](https://colab.research.google.com/github/PKU-Alignment/omnisafe/blob/main/tutorials/English/3.Environment%20Customization%20from%20Scratch.ipynb) Introduce how to complete and integrate a customized environment from scratch. 10 | - [Environment Customization From Community](https://colab.research.google.com/github/PKU-Alignment/omnisafe/blob/main/tutorials/English/4.Environment%20Customization%20from%20Community.ipynb) Introduce how to integrate the community environment, e.g., Gymnasium, into OmniSafe. 11 | 12 | ## Zh-CN 13 | - [Getting Started](https://colab.research.google.com/github/PKU-Alignment/omnisafe/blob/main/tutorials/zh-cn/1.Getting%20Started.ipynb) 介绍OmniSafe的基本用法,使用户能够快速上手。 14 | - [CLI Command](https://colab.research.google.com/github/PKU-Alignment/omnisafe/blob/main/tutorials/zh-cn/2.CLI%20Command.ipynb) 介绍如何使用OmniSafe的命令行工具。 15 | - [Environment Customization From Scratch](https://colab.research.google.com/github/PKU-Alignment/omnisafe/blob/main/tutorials/zh-cn/3.Environment%20Customization%20from%20Scratch.ipynb) 介绍如何从零开始构建并嵌入定制化环境。 16 | - [Environment Customization From Community](https://colab.research.google.com/github/PKU-Alignment/omnisafe/blob/main/tutorials/zh-cn/4.Environment%20Customization%20from%20Community.ipynb) 以Gymnasium为例介绍如何将社区环境嵌入OmniSafe。 17 | -------------------------------------------------------------------------------- /docs/source/common/buffer.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Buffer 2 | =============== 3 | 4 | .. currentmodule:: omnisafe.common.buffer 5 | 6 | .. 
autosummary:: 7 | 8 | BaseBuffer 9 | OnPolicyBuffer 10 | OffPolicyBuffer 11 | VectorOffPolicyBuffer 12 | VectorOnPolicyBuffer 13 | 14 | 15 | Base Buffer 16 | ----------- 17 | 18 | .. card:: 19 | :class-header: sd-bg-success sd-text-white 20 | :class-card: sd-outline-success sd-rounded-1 21 | 22 | Documentation 23 | ^^^ 24 | 25 | .. autoclass:: BaseBuffer 26 | :members: 27 | :private-members: 28 | 29 | 30 | 31 | On Policy Buffer 32 | ---------------- 33 | 34 | .. card:: 35 | :class-header: sd-bg-success sd-text-white 36 | :class-card: sd-outline-success sd-rounded-1 37 | 38 | Documentation 39 | ^^^ 40 | 41 | .. autoclass:: OnPolicyBuffer 42 | :members: 43 | :private-members: 44 | 45 | 46 | 47 | Off Policy Buffer 48 | ----------------- 49 | 50 | .. card:: 51 | :class-header: sd-bg-success sd-text-white 52 | :class-card: sd-outline-success sd-rounded-1 53 | 54 | Documentation 55 | ^^^ 56 | 57 | .. autoclass:: OffPolicyBuffer 58 | :members: 59 | :private-members: 60 | 61 | 62 | 63 | Vector On Policy Buffer 64 | ----------------------- 65 | 66 | .. card:: 67 | :class-header: sd-bg-success sd-text-white 68 | :class-card: sd-outline-success sd-rounded-1 69 | 70 | Documentation 71 | ^^^ 72 | 73 | .. autoclass:: VectorOnPolicyBuffer 74 | :members: 75 | :private-members: 76 | 77 | 78 | 79 | Vector Off Policy Buffer 80 | ------------------------ 81 | 82 | .. card:: 83 | :class-header: sd-bg-success sd-text-white 84 | :class-card: sd-outline-success sd-rounded-1 85 | 86 | Documentation 87 | ^^^ 88 | 89 | .. autoclass:: VectorOffPolicyBuffer 90 | :members: 91 | :private-members: 92 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: ✨ Feature Request 2 | description: Suggest an idea for this project. 
3 | title: "[Feature Request] " 4 | labels: [enhancement] 5 | body: 6 | - type: checkboxes 7 | id: steps 8 | attributes: 9 | label: Required prerequisites 10 | description: Make sure you've completed the following steps before submitting your issue -- thank you! 11 | options: 12 | - label: I have searched the [Issue Tracker](https://github.com/PKU-Alignment/omnisafe/issues) and [Discussions](https://github.com/PKU-Alignment/omnisafe/discussions) that this hasn't already been reported. (+1 or comment there if it has.) 13 | required: true 14 | - label: Consider asking first in a [Discussion](https://github.com/PKU-Alignment/omnisafe/discussions/new). 15 | required: false 16 | 17 | - type: textarea 18 | id: motivation 19 | attributes: 20 | label: Motivation 21 | description: Outline the motivation for the proposal. 22 | value: | 23 | 26 | validations: 27 | required: true 28 | 29 | - type: textarea 30 | id: solution 31 | attributes: 32 | label: Solution 33 | description: Provide a clear and concise description of what you want to happen. 34 | 35 | - type: textarea 36 | id: alternatives 37 | attributes: 38 | label: Alternatives 39 | description: A clear and concise description of any alternative solutions or features you've considered. 40 | 41 | - type: textarea 42 | id: additional-context 43 | attributes: 44 | label: Additional context 45 | description: Add any other context about the problem here. Screenshots may also be helpful. 46 | -------------------------------------------------------------------------------- /docs/source/model/modelbased_planner.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Model-based Planner 2 | ============================ 3 | 4 | .. currentmodule:: omnisafe.algorithms.model_based.planner 5 | 6 | ARC Planner 7 | ----------- 8 | 9 | .. card:: 10 | :class-header: sd-bg-success sd-text-white 11 | :class-card: sd-outline-success sd-rounded-1 12 | 13 | Documentation 14 | ^^^ 15 | 16 | .. 
autoclass:: ARCPlanner 17 | :members: 18 | :private-members: 19 | 20 | CAP Planner 21 | ----------- 22 | 23 | .. card:: 24 | :class-header: sd-bg-success sd-text-white 25 | :class-card: sd-outline-success sd-rounded-1 26 | 27 | Documentation 28 | ^^^ 29 | 30 | .. autoclass:: CAPPlanner 31 | :members: 32 | :private-members: 33 | 34 | CCE Planner 35 | ----------- 36 | 37 | .. card:: 38 | :class-header: sd-bg-success sd-text-white 39 | :class-card: sd-outline-success sd-rounded-1 40 | 41 | Documentation 42 | ^^^ 43 | 44 | .. autoclass:: CCEPlanner 45 | :members: 46 | :private-members: 47 | 48 | CEM Planner 49 | ----------- 50 | 51 | .. card:: 52 | :class-header: sd-bg-success sd-text-white 53 | :class-card: sd-outline-success sd-rounded-1 54 | 55 | Documentation 56 | ^^^ 57 | 58 | .. autoclass:: CEMPlanner 59 | :members: 60 | :private-members: 61 | 62 | RCE Planner 63 | ----------- 64 | 65 | .. card:: 66 | :class-header: sd-bg-success sd-text-white 67 | :class-card: sd-outline-success sd-rounded-1 68 | 69 | Documentation 70 | ^^^ 71 | 72 | .. autoclass:: RCEPlanner 73 | :members: 74 | :private-members: 75 | 76 | 77 | SafeARC Planner 78 | --------------- 79 | 80 | .. card:: 81 | :class-header: sd-bg-success sd-text-white 82 | :class-card: sd-outline-success sd-rounded-1 83 | 84 | Documentation 85 | ^^^ 86 | 87 | .. autoclass:: SafeARCPlanner 88 | :members: 89 | :private-members: 90 | -------------------------------------------------------------------------------- /docs/source/model/modelbased_model.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Model-based Model 2 | ========================== 3 | 4 | .. currentmodule:: omnisafe.algorithms.model_based.base.ensemble 5 | 6 | Standard Scaler 7 | --------------- 8 | 9 | .. card:: 10 | :class-header: sd-bg-success sd-text-white 11 | :class-card: sd-outline-success sd-rounded-1 12 | 13 | Documentation 14 | ^^^ 15 | 16 | .. 
autoclass:: StandardScaler 17 | :members: 18 | :private-members: 19 | 20 | Initialize Weight 21 | ----------------- 22 | 23 | .. card:: 24 | :class-header: sd-bg-success sd-text-white 25 | :class-card: sd-outline-success sd-rounded-1 26 | 27 | Documentation 28 | ^^^ 29 | 30 | .. autofunction:: init_weights 31 | 32 | Unbatched Forward 33 | ----------------- 34 | 35 | .. card:: 36 | :class-header: sd-bg-success sd-text-white 37 | :class-card: sd-outline-success sd-rounded-1 38 | 39 | Documentation 40 | ^^^ 41 | 42 | .. autofunction:: unbatched_forward 43 | 44 | Ensemble Fully-Connected Layer 45 | ------------------------------ 46 | 47 | .. card:: 48 | :class-header: sd-bg-success sd-text-white 49 | :class-card: sd-outline-success sd-rounded-1 50 | 51 | Documentation 52 | ^^^ 53 | 54 | .. autoclass:: EnsembleFC 55 | :members: 56 | :private-members: 57 | 58 | Ensemble Model 59 | -------------- 60 | 61 | .. card:: 62 | :class-header: sd-bg-success sd-text-white 63 | :class-card: sd-outline-success sd-rounded-1 64 | 65 | Documentation 66 | ^^^ 67 | 68 | .. autoclass:: EnsembleModel 69 | :members: 70 | :private-members: 71 | 72 | 73 | Ensemble Dynamics Model 74 | ----------------------- 75 | 76 | .. card:: 77 | :class-header: sd-bg-success sd-text-white 78 | :class-card: sd-outline-success sd-rounded-1 79 | 80 | Documentation 81 | ^^^ 82 | 83 | .. autoclass:: EnsembleDynamicsModel 84 | :members: 85 | :private-members: 86 | -------------------------------------------------------------------------------- /omnisafe/typing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Typing utilities.""" 16 | 17 | from typing import ( 18 | Any, 19 | Callable, 20 | Dict, 21 | List, 22 | Literal, 23 | NamedTuple, 24 | Optional, 25 | Sequence, 26 | Tuple, 27 | TypeVar, 28 | Union, 29 | ) 30 | 31 | import torch 32 | from gymnasium.spaces import Box, Discrete 33 | from torch.types import Device 34 | 35 | 36 | RenderFrame = TypeVar('RenderFrame') 37 | OmnisafeSpace = Union[Box, Discrete] 38 | Activation = Literal['identity', 'relu', 'sigmoid', 'softplus', 'tanh'] 39 | AdvatageEstimator = Literal['gae', 'gae-rtg', 'vtrace', 'plain'] 40 | InitFunction = Literal['kaiming_uniform', 'xavier_normal', 'glorot', 'xavier_uniform', 'orthogonal'] 41 | CriticType = Literal['v', 'q'] 42 | ActorType = Literal['gaussian_learning', 'gaussian_sac', 'mlp', 'vae', 'perturbation'] 43 | DEVICE_CPU = torch.device('cpu') 44 | 45 | 46 | __all__ = [ 47 | 'Activation', 48 | 'AdvatageEstimator', 49 | 'InitFunction', 50 | 'Callable', 51 | 'List', 52 | 'Optional', 53 | 'Sequence', 54 | 'Tuple', 55 | 'TypeVar', 56 | 'Union', 57 | 'Dict', 58 | 'NamedTuple', 59 | 'Any', 60 | 'OmnisafeSpace', 61 | 'RenderFrame', 62 | 'Device', 63 | 'DEVICE_CPU', 64 | ] 65 | -------------------------------------------------------------------------------- /omnisafe/version.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """OmniSafe: A comprehensive and reliable benchmark for safe reinforcement learning.""" 16 | 17 | __version__ = '0.5.0' 18 | __license__ = 'Apache License, Version 2.0' 19 | __author__ = 'OmniSafe Contributors' 20 | __release__ = False 21 | 22 | if not __release__: 23 | import os 24 | import subprocess 25 | 26 | try: 27 | prefix, sep, suffix = ( 28 | subprocess.check_output( 29 | ['git', 'describe', '--abbrev=7'], # noqa: S603,S607 30 | cwd=os.path.dirname(os.path.abspath(__file__)), 31 | stderr=subprocess.DEVNULL, 32 | text=True, 33 | ) 34 | .strip() 35 | .lstrip('v') 36 | .replace('-', '.dev', 1) 37 | .replace('-', '+', 1) 38 | .partition('.dev') 39 | ) 40 | if sep: 41 | version_prefix, dot, version_tail = prefix.rpartition('.') 42 | prefix = f'{version_prefix}{dot}{int(version_tail) + 1}' 43 | __version__ = sep.join((prefix, suffix)) 44 | del version_prefix, dot, version_tail 45 | else: 46 | __version__ = prefix 47 | del prefix, sep, suffix 48 | except (OSError, subprocess.CalledProcessError): 49 | pass 50 | 51 | del os, subprocess 52 | -------------------------------------------------------------------------------- /tests/helpers.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Helpers""" 16 | 17 | import itertools 18 | 19 | import numpy as np 20 | import pytest 21 | import torch 22 | import torch.types 23 | 24 | 25 | def dtype_numpy2torch(dtype: np.dtype) -> torch.dtype: 26 | """Convert numpy dtype to torch dtype""" 27 | return torch.tensor(np.zeros(1, dtype=dtype)).dtype 28 | 29 | 30 | def dtype_torch2numpy(dtype: torch.dtype) -> np.dtype: 31 | """Convert torch dtype to numpy dtype""" 32 | return torch.zeros(1, dtype=dtype).numpy().dtype 33 | 34 | 35 | def parametrize(**argvalues) -> pytest.mark.parametrize: 36 | """Test with multiple parameters""" 37 | arguments = list(argvalues) 38 | 39 | if 'dtype' in argvalues: 40 | dtypes = argvalues['dtype'] 41 | argvalues['dtype'] = dtypes[:1] 42 | arguments.remove('dtype') 43 | arguments.insert(0, 'dtype') 44 | 45 | argvalues = list(itertools.product(*tuple(map(argvalues.get, arguments)))) 46 | first_product = argvalues[0] 47 | argvalues.extend((dtype,) + first_product[1:] for dtype in dtypes[1:]) 48 | else: 49 | argvalues = list(itertools.product(*tuple(map(argvalues.get, arguments)))) 50 | 51 | ids = tuple( 52 | '-'.join(f'{arg}({val})' for arg, val in zip(arguments, values)) for values in argvalues 53 | ) 54 | 55 | return pytest.mark.parametrize(arguments, argvalues, ids=ids) 56 | 
-------------------------------------------------------------------------------- /examples/plot.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Example of plotting training curves with :class:`omnisafe.utils.plotter.Plotter`.""" 16 | 17 | 18 | import argparse 19 | 20 | from omnisafe.utils.plotter import Plotter 21 | 22 | 23 | # For example, run the following command to plot the training curve: 24 | # python plot.py --logdir omnisafe/examples/runs/PPOLag-{SafetyAntVelocity-v1} 25 | # after training the policy with the following command: 26 | # python train_policy.py --algo PPOLag --env-id SafetyAntVelocity-v1 27 | if __name__ == '__main__': 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--logdir', nargs='*') 30 | parser.add_argument('--legend', '-l', nargs='*') 31 | parser.add_argument('--xaxis', '-x', default='Steps') 32 | parser.add_argument('--value', '-y', default='Rewards', nargs='*') # NOTE(review): the default is a str but nargs='*' yields a list when -y is passed; confirm Plotter.make_plots accepts both 33 | parser.add_argument('--count', action='store_true') 34 | parser.add_argument('--smooth', '-s', type=int, default=1) 35 | parser.add_argument('--select', nargs='*') 36 | parser.add_argument('--exclude', nargs='*') 37 | parser.add_argument('--estimator', default='mean') 38 | args = parser.parse_args() 39 | 40 | plotter =
Plotter() 41 | plotter.make_plots( 42 | args.logdir, 43 | args.legend, 44 | args.xaxis, 45 | args.value, 46 | args.count, 47 | smooth=args.smooth, 48 | select=args.select, 49 | exclude=args.exclude, 50 | estimator=args.estimator, 51 | ) 52 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | # ============================================================================== 15 | """On-policy algorithms.""" 16 | 17 | from omnisafe.algorithms.on_policy import ( # submodules are imported so their __all__ lists can be aggregated below 18 | base, 19 | early_terminated, 20 | first_order, 21 | naive_lagrange, 22 | penalty_function, 23 | pid_lagrange, 24 | primal, 25 | saute, 26 | second_order, 27 | simmer, 28 | ) 29 | from omnisafe.algorithms.on_policy.base import PPO, TRPO, NaturalPG, PolicyGradient 30 | from omnisafe.algorithms.on_policy.early_terminated import PPOEarlyTerminated, TRPOEarlyTerminated 31 | from omnisafe.algorithms.on_policy.first_order import CUP, FOCOPS 32 | from omnisafe.algorithms.on_policy.naive_lagrange import PDO, RCPO, PPOLag, TRPOLag 33 | from omnisafe.algorithms.on_policy.penalty_function import IPO, P3O 34 | from omnisafe.algorithms.on_policy.pid_lagrange import CPPOPID, TRPOPID 35 | from omnisafe.algorithms.on_policy.primal import OnCRPO 36 | from omnisafe.algorithms.on_policy.saute import PPOSaute, TRPOSaute 37 | from omnisafe.algorithms.on_policy.second_order import CPO, PCPO 38 | from omnisafe.algorithms.on_policy.simmer import PPOSimmerPID, TRPOSimmerPID 39 | 40 | 41 | __all__ = [ # union of every submodule's public names 42 | *base.__all__, 43 | *early_terminated.__all__, 44 | *first_order.__all__, 45 | *naive_lagrange.__all__, 46 | *primal.__all__, 47 | *penalty_function.__all__, 48 | *pid_lagrange.__all__, 49 | *saute.__all__, 50 | *second_order.__all__, 51 | *simmer.__all__, 52 | ] 53 | -------------------------------------------------------------------------------- /docs/source/saferlapi/lagrange.rst: -------------------------------------------------------------------------------- 1 | Lagrange Algorithms 2 | =================== 3 | 4 | .. currentmodule:: omnisafe.algorithms.on_policy 5 | 6 | .. autosummary:: 7 | 8 | PPOLag 9 | TRPOLag 10 | 11 | .. _ppolagapi: 12 | 13 | PPOLag 14 | ------ 15 | 16 | ..
card:: 17 | :class-header: sd-bg-success sd-text-white 18 | :class-card: sd-outline-success sd-rounded-1 19 | 20 | Documentation 21 | ^^^ 22 | 23 | .. autoclass:: PPOLag 24 | :members: 25 | :private-members: 26 | 27 | 28 | 29 | TRPOLag 30 | ------- 31 | 32 | .. card:: 33 | :class-header: sd-bg-success sd-text-white 34 | :class-card: sd-outline-success sd-rounded-1 35 | 36 | Documentation 37 | ^^^ 38 | 39 | .. autoclass:: TRPOLag 40 | :members: 41 | :private-members: 42 | 43 | CRPO 44 | ---- 45 | 46 | .. card:: 47 | :class-header: sd-bg-success sd-text-white 48 | :class-card: sd-outline-success sd-rounded-1 49 | 50 | Documentation 51 | ^^^ 52 | 53 | .. autoclass:: OnCRPO 54 | :members: 55 | :private-members: 56 | 57 | .. currentmodule:: omnisafe.algorithms.off_policy 58 | 59 | .. autosummary:: 60 | 61 | DDPGLag 62 | TD3Lag 63 | SACLag 64 | 65 | DDPGLag 66 | ------- 67 | 68 | .. card:: 69 | :class-header: sd-bg-success sd-text-white 70 | :class-card: sd-outline-success sd-rounded-1 71 | 72 | Documentation 73 | ^^^ 74 | 75 | .. autoclass:: DDPGLag 76 | :members: 77 | :private-members: 78 | 79 | SACLag 80 | ------ 81 | 82 | .. card:: 83 | :class-header: sd-bg-success sd-text-white 84 | :class-card: sd-outline-success sd-rounded-1 85 | 86 | Documentation 87 | ^^^ 88 | 89 | .. autoclass:: SACLag 90 | :members: 91 | :private-members: 92 | 93 | TD3Lag 94 | ------ 95 | 96 | .. card:: 97 | :class-header: sd-bg-success sd-text-white 98 | :class-card: sd-outline-success sd-rounded-1 99 | 100 | Documentation 101 | ^^^ 102 | 103 | .. 
autoclass:: TD3Lag 104 | :members: 105 | :private-members: 106 | -------------------------------------------------------------------------------- /tests/saved_source/PPO-{SafetyPointGoal1-v0}/seed-000-2023-03-16-12-08-52/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 0, 3 | "train_cfgs": { 4 | "device": "cpu", 5 | "torch_threads": 16, 6 | "vector_env_nums": 1, 7 | "parallel": 1, 8 | "total_steps": 1024000, 9 | "algo": "PPO", 10 | "env_id": "SafetyPointGoal1-v0", 11 | "epochs": 1000 12 | }, 13 | "algo_cfgs": { 14 | "steps_per_epoch": 1024, 15 | "update_iters": 40, 16 | "batch_size": 64, 17 | "target_kl": 0.02, 18 | "entropy_coef": 0.0, 19 | "reward_normalize": true, 20 | "cost_normalize": true, 21 | "obs_normalize": true, 22 | "kl_early_stop": true, 23 | "use_max_grad_norm": true, 24 | "max_grad_norm": 40.0, 25 | "use_critic_norm": true, 26 | "critic_norm_coef": 0.001, 27 | "gamma": 0.99, 28 | "cost_gamma": 0.99, 29 | "lam": 0.95, 30 | "lam_c": 0.95, 31 | "clip": 0.2, 32 | "adv_estimation_method": "gae", 33 | "standardized_rew_adv": true, 34 | "standardized_cost_adv": true, 35 | "penalty_coef": 0.0, 36 | "use_cost": false 37 | }, 38 | "logger_cfgs": { 39 | "use_wandb": false, 40 | "wandb_project": "omnisafe", 41 | "use_tensorboard": true, 42 | "save_model_freq": 100, 43 | "log_dir": "./runs", 44 | "window_lens": 100 45 | }, 46 | "model_cfgs": { 47 | "weight_initialization_mode": "kaiming_uniform", 48 | "actor_type": "gaussian_learning", 49 | "linear_lr_decay": true, 50 | "exploration_noise_anneal": false, 51 | "std_range": [ 52 | 0.5, 53 | 0.1 54 | ], 55 | "actor": { 56 | "hidden_sizes": [ 57 | 64, 58 | 64 59 | ], 60 | "activation": "tanh", 61 | "lr": 0.0003 62 | }, 63 | "critic": { 64 | "hidden_sizes": [ 65 | 64, 66 | 64 67 | ], 68 | "activation": "tanh", 69 | "lr": 0.0003 70 | } 71 | }, 72 | "exp_name": "PPO-(SafetyPointGoal1-v0)", 73 | "env_id": "SafetyPointGoal1-v0", 74 | "algo": "PPO" 75 | } 76 | 
-------------------------------------------------------------------------------- /omnisafe/configs/offline/VAEBC.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | defaults: 17 | # seed for random number generator 18 | seed: 0 19 | # training configurations 20 | train_cfgs: 21 | # device to use for training, options: cpu, cuda, cuda:0, cuda:0,1, etc. 
22 | device: cpu 23 | # number of threads for torch 24 | torch_threads: 16 25 | # total number of steps to train 26 | total_steps: 1000000 27 | # dataset name 28 | dataset: SafetyPointCircle1-v0_mixed_0.5 29 | # number of evaluation episodes (note: the key is spelled "evaluate_epoisodes") 30 | evaluate_epoisodes: 10 31 | # parallel, offline algorithms only support 1 32 | parallel: 1 33 | # vector_env_nums, offline algorithms only support 1 34 | vector_env_nums: 1 35 | # algorithm configurations 36 | algo_cfgs: 37 | # batch size 38 | batch_size: 256 39 | # steps per epoch; the algorithm logs and evaluates once every epoch 40 | steps_per_epoch: 1000 41 | # logger configurations 42 | logger_cfgs: 43 | # use wandb for logging 44 | use_wandb: False 45 | # wandb project name 46 | wandb_project: omnisafe 47 | # use tensorboard for logging 48 | use_tensorboard: True 49 | # save model frequency 50 | save_model_freq: 100 51 | # directory for saving logs 52 | log_dir: "./runs" 53 | # model configurations 54 | model_cfgs: 55 | # The mode used to initialize the network weights, choosing from "kaiming_uniform", "xavier_normal", "glorot", "xavier_uniform" and "orthogonal". 56 | weight_initialization_mode: "kaiming_uniform" 57 | # Sizes of the hidden layers 58 | hidden_sizes: [750, 750] 59 | # Type of activation function, choosing from "tanh", "relu", "sigmoid", "identity", "softplus" 60 | activation: relu 61 | # Learning rate of the model 62 | learning_rate: 0.001 63 | -------------------------------------------------------------------------------- /omnisafe/algorithms/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Registry for algorithms.""" 16 | 17 | from __future__ import annotations 18 | 19 | import inspect 20 | from typing import Any 21 | 22 | 23 | class Registry: 24 | """A registry mapping class names (``cls.__name__``) to the registered classes. 25 | 26 | Args: 27 | name (str): Registry name. 28 | """ 29 | 30 | def __init__(self, name: str) -> None: 31 | """Initialize an instance of :class:`Registry`.""" 32 | self._name: str = name 33 | self._module_dict: dict[str, type] = {} 34 | 35 | @property 36 | def name(self) -> str: 37 | """Return the name of the registry.""" 38 | return self._name 39 | 40 | def get(self, key: str) -> Any: 41 | """Return the class registered under ``key``; raise ``KeyError`` if it is not registered.""" 42 | res = self._module_dict.get(key) 43 | if res is None: 44 | raise KeyError(f'{key} is not in the {self.name} registry') 45 | return res 46 | 47 | def _register_module(self, module_class: type) -> None: 48 | """Register a module under its class name (``module_class.__name__``). 49 | 50 | Args: 51 | module_class (type): Module to be registered. Raises ``TypeError`` if it is not a class and ``KeyError`` if its name is already registered.
52 | """ 53 | if not inspect.isclass(module_class): 54 | raise TypeError(f'module must be a class, but got {type(module_class)}') 55 | module_name = module_class.__name__ 56 | if module_name in self._module_dict: 57 | raise KeyError(f'{module_name} is already registered in {self.name}') 58 | self._module_dict[module_name] = module_class 59 | 60 | def register(self, cls: type) -> type: 61 | """Register ``cls`` under its class name and return it unchanged, so this can be used as a class decorator.""" 62 | self._register_module(cls) 63 | return cls 64 | 65 | 66 | REGISTRY = Registry('OmniSafe') 67 | 68 | 69 | register = REGISTRY.register # module-level shortcuts, e.g. ``@registry.register`` 70 | get = REGISTRY.get 71 | -------------------------------------------------------------------------------- /docs/source/model/actor.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Actor 2 | ============== 3 | 4 | .. currentmodule:: omnisafe.models.base 5 | 6 | Base Actor 7 | ----------- 8 | 9 | .. card:: 10 | :class-header: sd-bg-success sd-text-white 11 | :class-card: sd-outline-success sd-rounded-1 12 | 13 | Documentation 14 | ^^^ 15 | 16 | .. autoclass:: Actor 17 | :members: 18 | :private-members: 19 | 20 | 21 | .. currentmodule:: omnisafe.models.actor 22 | 23 | .. autosummary:: 24 | 25 | ActorBuilder 26 | GaussianActor 27 | GaussianLearningActor 28 | GaussianSACActor 29 | 30 | Actor Builder 31 | ------------- 32 | .. card:: 33 | :class-header: sd-bg-success sd-text-white 34 | :class-card: sd-outline-success sd-rounded-1 35 | 36 | Documentation 37 | ^^^ 38 | 39 | .. autoclass:: ActorBuilder 40 | :members: 41 | :private-members: 42 | 43 | Gaussian Actor 44 | -------------- 45 | 46 | .. card:: 47 | :class-header: sd-bg-success sd-text-white 48 | :class-card: sd-outline-success sd-rounded-1 49 | 50 | Documentation 51 | ^^^ 52 | 53 | .. autoclass:: GaussianActor 54 | :members: 55 | :private-members: 56 | 57 | Gaussian Learning Actor 58 | ----------------------- 59 | 60 | ..
card:: 61 | :class-header: sd-bg-success sd-text-white 62 | :class-card: sd-outline-success sd-rounded-1 63 | 64 | Documentation 65 | ^^^ 66 | 67 | .. autoclass:: GaussianLearningActor 68 | :members: 69 | :private-members: 70 | 71 | Gaussian SAC Actor 72 | ----------------------- 73 | 74 | .. card:: 75 | :class-header: sd-bg-success sd-text-white 76 | :class-card: sd-outline-success sd-rounded-1 77 | 78 | Documentation 79 | ^^^ 80 | 81 | .. autoclass:: GaussianSACActor 82 | :members: 83 | :private-members: 84 | 85 | Perturbation Actor 86 | ------------------ 87 | 88 | .. card:: 89 | :class-header: sd-bg-success sd-text-white 90 | :class-card: sd-outline-success sd-rounded-1 91 | 92 | Documentation 93 | ^^^ 94 | 95 | .. autoclass:: PerturbationActor 96 | :members: 97 | :private-members: 98 | 99 | VAE Actor 100 | --------- 101 | 102 | .. card:: 103 | :class-header: sd-bg-success sd-text-white 104 | :class-card: sd-outline-success sd-rounded-1 105 | 106 | Documentation 107 | ^^^ 108 | 109 | .. autoclass:: VAE 110 | :members: 111 | :private-members: 112 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/early_terminated/trpo_early_terminated.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Implementation of the Early terminated version of the TRPO algorithm.""" 16 | 17 | 18 | from omnisafe.adapter.early_terminated_adapter import EarlyTerminatedAdapter 19 | from omnisafe.algorithms import registry 20 | from omnisafe.algorithms.on_policy.base.trpo import TRPO 21 | from omnisafe.utils import distributed 22 | 23 | 24 | @registry.register 25 | class TRPOEarlyTerminated(TRPO): 26 | """The Early terminated version of the TRPO algorithm. 27 | 28 | References: 29 | - Title: Safe Exploration by Solving Early Terminated MDP. 30 | - Authors: Hao Sun, Ziping Xu, Meng Fang, Zhenghao Peng, Jiadong Guo, Bo Dai, Bolei Zhou. 31 | - URL: `TRPOEarlyTerminated `_ 32 | """ 33 | 34 | def _init_env(self) -> None: 35 | """Initialize the environment. 36 | 37 | OmniSafe uses :class:`omnisafe.adapter.EarlyTerminatedAdapter` to adapt the environment to 38 | the algorithm. 39 | 40 | Users can customize the environment by overriding this method. 41 | 42 | Examples: 43 | >>> def _init_env(self) -> None: 44 | ... self._env = CustomAdapter() 45 | """ 46 | self._env: EarlyTerminatedAdapter = EarlyTerminatedAdapter( 47 | self._env_id, 48 | self._cfgs.train_cfgs.vector_env_nums, 49 | self._seed, 50 | self._cfgs, 51 | ) 52 | assert (self._cfgs.algo_cfgs.steps_per_epoch) % ( 53 | distributed.world_size() * self._cfgs.train_cfgs.vector_env_nums 54 | ) == 0, 'The number of steps per epoch is not divisible by the number of environments.' # steps must split evenly over distributed processes x vectorized envs 55 | self._steps_per_epoch: int = ( # rollout steps per process and per env 56 | self._cfgs.algo_cfgs.steps_per_epoch 57 | // distributed.world_size() 58 | // self._cfgs.train_cfgs.vector_env_nums 59 | ) 60 | -------------------------------------------------------------------------------- /docs/source/envs/adapter.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Adapter 2 | ================ 3 | 4 | ..
currentmodule:: omnisafe.adapter 5 | 6 | OmniSafe provides a set of adapters to adapt the environment to the framework. 7 | 8 | .. autosummary:: 9 | 10 | OnlineAdapter 11 | OnPolicyAdapter 12 | OffPolicyAdapter 13 | SauteAdapter 14 | SimmerAdapter 15 | ModelBasedAdapter 16 | 17 | Online Adapter 18 | -------------- 19 | 20 | .. card:: 21 | :class-header: sd-bg-success sd-text-white 22 | :class-card: sd-outline-success sd-rounded-1 23 | 24 | Documentation 25 | ^^^ 26 | 27 | .. autoclass:: OnlineAdapter 28 | :members: 29 | :private-members: 30 | 31 | Offline Adapter 32 | --------------- 33 | 34 | .. card:: 35 | :class-header: sd-bg-success sd-text-white 36 | :class-card: sd-outline-success sd-rounded-1 37 | 38 | Documentation 39 | ^^^ 40 | 41 | .. autoclass:: OfflineAdapter 42 | :members: 43 | :private-members: 44 | 45 | 46 | On Policy Adapter 47 | ----------------- 48 | 49 | .. card:: 50 | :class-header: sd-bg-success sd-text-white 51 | :class-card: sd-outline-success sd-rounded-1 52 | 53 | Documentation 54 | ^^^ 55 | 56 | .. autoclass:: OnPolicyAdapter 57 | :members: 58 | :private-members: 59 | 60 | Off Policy Adapter 61 | ------------------ 62 | 63 | .. card:: 64 | :class-header: sd-bg-success sd-text-white 65 | :class-card: sd-outline-success sd-rounded-1 66 | 67 | Documentation 68 | ^^^ 69 | 70 | .. autoclass:: OffPolicyAdapter 71 | :members: 72 | :private-members: 73 | 74 | Saute Adapter 75 | ------------- 76 | 77 | .. card:: 78 | :class-header: sd-bg-success sd-text-white 79 | :class-card: sd-outline-success sd-rounded-1 80 | 81 | Documentation 82 | ^^^ 83 | 84 | .. autoclass:: SauteAdapter 85 | :members: 86 | :private-members: 87 | 88 | Simmer Adapter 89 | -------------- 90 | 91 | .. card:: 92 | :class-header: sd-bg-success sd-text-white 93 | :class-card: sd-outline-success sd-rounded-1 94 | 95 | Documentation 96 | ^^^ 97 | 98 | .. 
autoclass:: SimmerAdapter 99 | :members: 100 | :private-members: 101 | 102 | Model-based Adapter 103 | ------------------- 104 | 105 | .. card:: 106 | :class-header: sd-bg-success sd-text-white 107 | :class-card: sd-outline-success sd-rounded-1 108 | 109 | Documentation 110 | ^^^ 111 | 112 | .. autoclass:: ModelBasedAdapter 113 | :members: 114 | :private-members: 115 | -------------------------------------------------------------------------------- /omnisafe/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Safe Reinforcement Learning algorithms.""" 16 | 17 | import itertools 18 | from types import MappingProxyType 19 | 20 | from omnisafe.algorithms import model_based, off_policy, offline, on_policy 21 | from omnisafe.algorithms.base_algo import BaseAlgo 22 | 23 | # Model-based Safe 24 | from omnisafe.algorithms.model_based import CAPPETS, CCEPETS, LOOP, PETS, RCEPETS, SafeLOOP 25 | 26 | # Off-Policy Safe 27 | from omnisafe.algorithms.off_policy import ( 28 | CRABS, 29 | DDPG, 30 | DDPGPID, 31 | SAC, 32 | SACPID, 33 | TD3, 34 | TD3PID, 35 | DDPGLag, 36 | SACLag, 37 | TD3Lag, 38 | ) 39 | 40 | # Offline Safe 41 | from omnisafe.algorithms.offline import BCQ, CCRR, CRR, VAEBC, BCQLag, COptiDICE 42 | 43 | # On-Policy Safe 44 | from omnisafe.algorithms.on_policy import ( 45 | CPO, 46 | CPPOPID, 47 | CUP, 48 | FOCOPS, 49 | PCPO, 50 | PDO, 51 | PPO, 52 | RCPO, 53 | TRPO, 54 | TRPOPID, 55 | NaturalPG, 56 | OnCRPO, 57 | PolicyGradient, 58 | PPOEarlyTerminated, 59 | PPOLag, 60 | PPOSaute, 61 | PPOSimmerPID, 62 | TRPOEarlyTerminated, 63 | TRPOLag, 64 | TRPOSaute, 65 | TRPOSimmerPID, 66 | ) 67 | 68 | 69 | ALGORITHMS = { 70 | 'on-policy': tuple(on_policy.__all__), 71 | 'off-policy': tuple(off_policy.__all__), 72 | 'model-based': tuple(model_based.__all__), 73 | 'offline': tuple(offline.__all__), 74 | } 75 | 76 | ALGORITHM2TYPE = { # reverse lookup: algorithm name -> algorithm type 77 | algo: algo_type for algo_type, algorithms in ALGORITHMS.items() for algo in algorithms 78 | } 79 | 80 | __all__ = ALGORITHMS['all'] = tuple(itertools.chain.from_iterable(ALGORITHMS.values())) # flat tuple of every name, also exposed as ALGORITHMS['all'] 81 | 82 | assert len(ALGORITHM2TYPE) == len(__all__), 'Duplicate algorithm names found.' # a name listed under two types collapses into a single ALGORITHM2TYPE key
83 | 84 | ALGORITHMS = MappingProxyType(ALGORITHMS) # make this immutable 85 | ALGORITHM2TYPE = MappingProxyType(ALGORITHM2TYPE) # make this immutable 86 | 87 | del itertools, MappingProxyType 88 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/early_terminated/ppo_early_terminated.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implementation of the Early terminated version of the PPO algorithm.""" 16 | 17 | 18 | from omnisafe.adapter.early_terminated_adapter import EarlyTerminatedAdapter 19 | from omnisafe.algorithms import registry 20 | from omnisafe.algorithms.on_policy.base.ppo import PPO 21 | from omnisafe.utils import distributed 22 | 23 | 24 | @registry.register 25 | class PPOEarlyTerminated(PPO): 26 | """The Early terminated version of the PPO algorithm. 27 | 28 | A simple combination of the Early terminated RL and the Proximal Policy Optimization algorithm. 29 | 30 | References: 31 | - Title: Safe Exploration by Solving Early Terminated MDP. 32 | - Authors: Hao Sun, Ziping Xu, Meng Fang, Zhenghao Peng, Jiadong Guo, Bo Dai, Bolei Zhou.
33 | - URL: `PPOEarlyTerminated `_ 34 | """ 35 | 36 | def _init_env(self) -> None: 37 | """Initialize the environment. 38 | 39 | OmniSafe uses :class:`omnisafe.adapter.EarlyTerminatedAdapter` to adapt the environment to 40 | the algorithm. 41 | 42 | Users can customize the environment by overriding this method. 43 | 44 | Examples: 45 | >>> def _init_env(self) -> None: 46 | ... self._env = CustomAdapter() 47 | """ 48 | self._env: EarlyTerminatedAdapter = EarlyTerminatedAdapter( 49 | self._env_id, 50 | self._cfgs.train_cfgs.vector_env_nums, 51 | self._seed, 52 | self._cfgs, 53 | ) 54 | assert (self._cfgs.algo_cfgs.steps_per_epoch) % ( 55 | distributed.world_size() * self._cfgs.train_cfgs.vector_env_nums 56 | ) == 0, 'The number of steps per epoch is not divisible by the number of environments.' # steps must split evenly over distributed processes x vectorized envs 57 | self._steps_per_epoch: int = ( # rollout steps per process and per env 58 | self._cfgs.algo_cfgs.steps_per_epoch 59 | // distributed.world_size() 60 | // self._cfgs.train_cfgs.vector_env_nums 61 | ) 62 | -------------------------------------------------------------------------------- /tests/saved_source/test_statistics_tools/SafetyAntVelocity-v1---bc8c9e4f26a8072e52cf3eddefc51eae641946d3314efe537e918e0851f83aa5/PolicyGradient-{SafetyAntVelocity-v1}/seed-000-2023-04-14-00-42-56/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 0, 3 | "train_cfgs": { 4 | "device": "cpu", 5 | "torch_threads": 1, 6 | "vector_env_nums": 2, 7 | "parallel": 1, 8 | "total_steps": 4096, 9 | "epochs": 2 10 | }, 11 | "algo_cfgs": { 12 | "steps_per_epoch": 2048, 13 | "update_iters": 10, 14 | "batch_size": 64, 15 | "target_kl": 0.02, 16 | "entropy_coef": 0.0, 17 | "reward_normalize": false, 18 | "cost_normalize": false, 19 | "obs_normalize": true, 20 | "kl_early_stop": true, 21 | "use_max_grad_norm": true, 22 | "max_grad_norm": 40.0, 23 | "use_critic_norm": true, 24 | "critic_norm_coef": 0.001, 25 | "gamma": 0.99, 26 | "cost_gamma": 0.99, 27 | "lam": 0.95, 28 | "lam_c": 0.95, 29 |
"adv_estimation_method": "gae", 30 | "standardized_rew_adv": true, 31 | "standardized_cost_adv": true, 32 | "penalty_coef": 0.0, 33 | "use_cost": false 34 | }, 35 | "logger_cfgs": { 36 | "use_wandb": false, 37 | "wandb_project": "omnisafe", 38 | "use_tensorboard": true, 39 | "save_model_freq": 100, 40 | "log_dir": "", 41 | "window_lens": 100 42 | }, 43 | "model_cfgs": { 44 | "weight_initialization_mode": "kaiming_uniform", 45 | "actor_type": "gaussian_learning", 46 | "linear_lr_decay": true, 47 | "exploration_noise_anneal": false, 48 | "std_range": [ 49 | 0.5, 50 | 0.1 51 | ], 52 | "actor": { 53 | "hidden_sizes": [ 54 | 64, 55 | 64 56 | ], 57 | "activation": "tanh", 58 | "lr": 0.0003 59 | }, 60 | "critic": { 61 | "hidden_sizes": [ 62 | 64, 63 | 64 64 | ], 65 | "activation": "tanh", 66 | "lr": 0.0003 67 | } 68 | }, 69 | "exp_increment_cfgs": { 70 | "seed": 0, 71 | "algo_cfgs": { 72 | "steps_per_epoch": 2048 73 | }, 74 | "train_cfgs": { 75 | "total_steps": 4096, 76 | "torch_threads": 1, 77 | "vector_env_nums": 2 78 | }, 79 | "logger_cfgs": { 80 | "use_wandb": false, 81 | "log_dir": "" 82 | } 83 | }, 84 | "exp_name": "PolicyGradient-{SafetyAntVelocity-v1}", 85 | "env_id": "SafetyAntVelocity-v1", 86 | "algo": "PolicyGradient" 87 | } 88 | -------------------------------------------------------------------------------- /tests/test_ensemble.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Test ensemble.""" 16 | 17 | from __future__ import annotations 18 | 19 | import numpy as np 20 | import torch 21 | import torch.nn as nn 22 | 23 | from omnisafe.algorithms.model_based.base.ensemble import ( 24 | EnsembleFC, 25 | EnsembleModel, 26 | StandardScaler, 27 | unbatched_forward, 28 | ) 29 | 30 | 31 | def test_standard_scaler(): 32 | standard_scaler = StandardScaler(device='cpu') 33 | torch_input = torch.rand(10, 10) 34 | assert isinstance(standard_scaler.transform(torch_input), torch.Tensor) 35 | 36 | 37 | def test_unbatched_forward(): 38 | layer = nn.Linear(10, 10) 39 | torch_input = torch.rand(10, 10) 40 | assert isinstance(unbatched_forward(layer, torch_input, 0), torch.Tensor) 41 | 42 | 43 | def test_ensemble_fc(): 44 | ensemble_fc = EnsembleFC( 45 | in_features=10, 46 | out_features=10, 47 | ensemble_size=10, 48 | weight_decay=0.0, 49 | bias=False, 50 | ) 51 | torch_input = torch.rand(10, 10, 10) 52 | assert isinstance(ensemble_fc(torch_input), torch.Tensor) 53 | 54 | 55 | def test_enemble_model(): 56 | ensemble_model = EnsembleModel( 57 | device='cpu', 58 | state_size=5, 59 | action_size=5, 60 | reward_size=1, 61 | cost_size=1, 62 | ensemble_size=10, 63 | predict_reward=True, 64 | predict_cost=True, 65 | ) 66 | numpy_state = np.random.rand(10, 10, 10) 67 | mean, var = ensemble_model(numpy_state) 68 | assert isinstance(mean, torch.Tensor) 69 | assert isinstance(var, torch.Tensor) 70 | mean, log_var = ensemble_model(numpy_state, ret_log_var=True) 71 | assert isinstance(mean, torch.Tensor) 72 | assert isinstance(log_var, torch.Tensor) 73 | numpy_state = np.random.rand(1, 10, 10) 74 | mean, var = ensemble_model.forward_idx(numpy_state, idx_model=0) 75 | assert isinstance(mean, torch.Tensor) 76 | assert isinstance(var, torch.Tensor) 77 | mean, log_var 
= ensemble_model.forward_idx(numpy_state, idx_model=0, ret_log_var=True) 78 | assert isinstance(mean, torch.Tensor) 79 | assert isinstance(log_var, torch.Tensor) 80 | -------------------------------------------------------------------------------- /tests/saved_source/test_statistics_tools/SafetyAntVelocity-v1---556c9cedab7db813a6ea3860f5921d7ccbc176d70900e709065fc2604d02b9a6/NaturalPG-{SafetyAntVelocity-v1}/seed-000-2023-04-14-00-42-56/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "seed": 0, 3 | "train_cfgs": { 4 | "device": "cpu", 5 | "torch_threads": 1, 6 | "vector_env_nums": 2, 7 | "parallel": 1, 8 | "total_steps": 4096, 9 | "epochs": 2 10 | }, 11 | "algo_cfgs": { 12 | "steps_per_epoch": 2048, 13 | "update_iters": 10, 14 | "batch_size": 128, 15 | "target_kl": 0.01, 16 | "entropy_coef": 0.0, 17 | "reward_normalize": false, 18 | "cost_normalize": false, 19 | "obs_normalize": true, 20 | "kl_early_stop": false, 21 | "use_max_grad_norm": true, 22 | "max_grad_norm": 40.0, 23 | "use_critic_norm": true, 24 | "critic_norm_coef": 0.001, 25 | "gamma": 0.99, 26 | "cost_gamma": 0.99, 27 | "lam": 0.95, 28 | "lam_c": 0.95, 29 | "clip": 0.2, 30 | "adv_estimation_method": "gae", 31 | "standardized_rew_adv": true, 32 | "standardized_cost_adv": true, 33 | "penalty_coef": 0.0, 34 | "use_cost": false, 35 | "cg_damping": 0.1, 36 | "cg_iters": 15, 37 | "fvp_obs": "None", 38 | "fvp_sample_freq": 1 39 | }, 40 | "logger_cfgs": { 41 | "use_wandb": false, 42 | "wandb_project": "omnisafe", 43 | "use_tensorboard": true, 44 | "save_model_freq": 100, 45 | "log_dir": "", 46 | "window_lens": 100 47 | }, 48 | "model_cfgs": { 49 | "weight_initialization_mode": "kaiming_uniform", 50 | "actor_type": "gaussian_learning", 51 | "linear_lr_decay": false, 52 | "exploration_noise_anneal": false, 53 | "std_range": [ 54 | 0.5, 55 | 0.1 56 | ], 57 | "actor": { 58 | "hidden_sizes": [ 59 | 64, 60 | 64 61 | ], 62 | "activation": "tanh", 63 | "lr": null 
64 | }, 65 | "critic": { 66 | "hidden_sizes": [ 67 | 64, 68 | 64 69 | ], 70 | "activation": "tanh", 71 | "lr": 0.001 72 | } 73 | }, 74 | "exp_increment_cfgs": { 75 | "seed": 0, 76 | "algo_cfgs": { 77 | "steps_per_epoch": 2048 78 | }, 79 | "train_cfgs": { 80 | "total_steps": 4096, 81 | "torch_threads": 1, 82 | "vector_env_nums": 2 83 | }, 84 | "logger_cfgs": { 85 | "use_wandb": false, 86 | "log_dir": "" 87 | } 88 | }, 89 | "exp_name": "NaturalPG-{SafetyAntVelocity-v1}", 90 | "env_id": "SafetyAntVelocity-v1", 91 | "algo": "NaturalPG" 92 | } 93 | -------------------------------------------------------------------------------- /omnisafe/algorithms/base_algo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Implementation of the Base algorithms.""" 16 | 17 | from __future__ import annotations 18 | 19 | from abc import ABC, abstractmethod 20 | 21 | import torch 22 | 23 | from omnisafe.common.logger import Logger 24 | from omnisafe.utils import distributed 25 | from omnisafe.utils.config import Config 26 | from omnisafe.utils.tools import get_device, seed_all 27 | 28 | 29 | class BaseAlgo(ABC): # pylint: disable=too-few-public-methods 30 | """Base class for all algorithms.""" 31 | 32 | _logger: Logger 33 | 34 | def __init__(self, env_id: str, cfgs: Config) -> None: 35 | """Initialize an instance of algorithm.""" 36 | self._env_id: str = env_id 37 | self._cfgs: Config = cfgs 38 | 39 | assert hasattr(cfgs, 'seed'), 'Please specify the seed in the config file.' 40 | self._seed: int = int(cfgs.seed) + distributed.get_rank() * 1000 41 | seed_all(self._seed) 42 | 43 | assert hasattr(cfgs.train_cfgs, 'device'), 'Please specify the device in the config file.' 
44 | self._device: torch.device = get_device(self._cfgs.train_cfgs.device) 45 | 46 | distributed.setup_distributed() 47 | 48 | self._init_env() 49 | self._init_model() 50 | 51 | self._init() 52 | 53 | self._init_log() 54 | 55 | @property 56 | def logger(self) -> Logger: 57 | """Get the logger.""" 58 | return self._logger # pylint: disable=no-member 59 | 60 | @property 61 | def cost_limit(self) -> float | None: 62 | """Get the cost limit.""" 63 | return getattr(self._cfgs.algo_cfgs, '_cost_limit', None) 64 | 65 | @abstractmethod 66 | def _init(self) -> None: 67 | """Initialize the algorithm.""" 68 | 69 | @abstractmethod 70 | def _init_env(self) -> None: 71 | """Initialize the environment.""" 72 | 73 | @abstractmethod 74 | def _init_model(self) -> None: 75 | """Initialize the model.""" 76 | 77 | @abstractmethod 78 | def _init_log(self) -> None: 79 | """Initialize the logger.""" 80 | 81 | @abstractmethod 82 | def learn(self) -> tuple[float, float, float]: 83 | """Learn the policy.""" 84 | -------------------------------------------------------------------------------- /docs/source/envs/wrapper.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Wrapper 2 | ================ 3 | 4 | .. currentmodule:: omnisafe.envs.wrapper 5 | 6 | .. autosummary:: 7 | 8 | TimeLimit 9 | AutoReset 10 | ObsNormalize 11 | RewardNormalize 12 | CostNormalize 13 | ActionScale 14 | Unsqueeze 15 | 16 | Time Limit Wrapper 17 | ------------------ 18 | 19 | .. card:: 20 | :class-header: sd-bg-success sd-text-white 21 | :class-card: sd-outline-success sd-rounded-1 22 | 23 | Documentation 24 | ^^^ 25 | 26 | .. autoclass:: TimeLimit 27 | :members: 28 | :private-members: 29 | 30 | Auto Reset Wrapper 31 | ------------------ 32 | 33 | .. card:: 34 | :class-header: sd-bg-success sd-text-white 35 | :class-card: sd-outline-success sd-rounded-1 36 | 37 | Documentation 38 | ^^^ 39 | 40 | .. 
autoclass:: AutoReset 41 | :members: 42 | :private-members: 43 | 44 | 45 | Observation Normalization Wrapper 46 | --------------------------------- 47 | 48 | .. card:: 49 | :class-header: sd-bg-success sd-text-white 50 | :class-card: sd-outline-success sd-rounded-1 51 | 52 | Documentation 53 | ^^^ 54 | 55 | .. autoclass:: ObsNormalize 56 | :members: 57 | :private-members: 58 | 59 | Reward Normalization Wrapper 60 | ---------------------------- 61 | 62 | .. card:: 63 | :class-header: sd-bg-success sd-text-white 64 | :class-card: sd-outline-success sd-rounded-1 65 | 66 | Documentation 67 | ^^^ 68 | 69 | .. autoclass:: RewardNormalize 70 | :members: 71 | :private-members: 72 | 73 | Cost Normalization Wrapper 74 | -------------------------- 75 | 76 | .. card:: 77 | :class-header: sd-bg-success sd-text-white 78 | :class-card: sd-outline-success sd-rounded-1 79 | 80 | Documentation 81 | ^^^ 82 | 83 | .. autoclass:: CostNormalize 84 | :members: 85 | :private-members: 86 | 87 | Action Scale 88 | ------------ 89 | 90 | .. card:: 91 | :class-header: sd-bg-success sd-text-white 92 | :class-card: sd-outline-success sd-rounded-1 93 | 94 | Documentation 95 | ^^^ 96 | 97 | .. autoclass:: ActionScale 98 | :members: 99 | :private-members: 100 | 101 | Action Repeat 102 | ------------- 103 | 104 | .. card:: 105 | :class-header: sd-bg-success sd-text-white 106 | :class-card: sd-outline-success sd-rounded-1 107 | 108 | Documentation 109 | ^^^ 110 | 111 | .. autoclass:: ActionRepeat 112 | :members: 113 | :private-members: 114 | 115 | Unsqueeze Wrapper 116 | ----------------- 117 | 118 | .. card:: 119 | :class-header: sd-bg-success sd-text-white 120 | :class-card: sd-outline-success sd-rounded-1 121 | 122 | Documentation 123 | ^^^ 124 | 125 | .. 
autoclass:: Unsqueeze 126 | :members: 127 | :private-members: 128 | -------------------------------------------------------------------------------- /examples/train_policy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Example of training a policy with OmniSafe.""" 16 | 17 | import argparse 18 | 19 | import omnisafe 20 | from omnisafe.utils.tools import custom_cfgs_to_dict, update_dict 21 | 22 | 23 | if __name__ == '__main__': 24 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 25 | parser.add_argument( 26 | '--algo', 27 | type=str, 28 | metavar='ALGO', 29 | default='PPOLag', 30 | help='algorithm to train', 31 | choices=omnisafe.ALGORITHMS['all'], 32 | ) 33 | parser.add_argument( 34 | '--env-id', 35 | type=str, 36 | metavar='ENV', 37 | default='SafetyPointGoal1-v0', 38 | help='the name of test environment', 39 | ) 40 | parser.add_argument( 41 | '--parallel', 42 | default=1, 43 | type=int, 44 | metavar='N', 45 | help='number of paralleled progress for calculations.', 46 | ) 47 | parser.add_argument( 48 | '--total-steps', 49 | type=int, 50 | default=10000000, 51 | metavar='STEPS', 52 | help='total number of steps to train for algorithm', 53 | ) 54 | 
parser.add_argument( 55 | '--device', 56 | type=str, 57 | default='cpu', 58 | metavar='DEVICES', 59 | help='device to use for training', 60 | ) 61 | parser.add_argument( 62 | '--vector-env-nums', 63 | type=int, 64 | default=1, 65 | metavar='VECTOR-ENV', 66 | help='number of vector envs to use for training', 67 | ) 68 | parser.add_argument( 69 | '--torch-threads', 70 | type=int, 71 | default=16, 72 | metavar='THREADS', 73 | help='number of threads to use for torch', 74 | ) 75 | args, unparsed_args = parser.parse_known_args() 76 | keys = [k[2:] for k in unparsed_args[0::2]] 77 | values = list(unparsed_args[1::2]) 78 | unparsed_args = dict(zip(keys, values)) 79 | 80 | custom_cfgs = {} 81 | for k, v in unparsed_args.items(): 82 | update_dict(custom_cfgs, custom_cfgs_to_dict(k, v)) 83 | 84 | agent = omnisafe.Agent( 85 | args.algo, 86 | args.env_id, 87 | train_terminal_cfgs=vars(args), 88 | custom_cfgs=custom_cfgs, 89 | ) 90 | agent.learn() 91 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | # Dockerfile for OmniSafe 2 | # 3 | # $ docker build --target base --tag omnisafe:latest . 4 | # 5 | # or 6 | # 7 | # $ docker build --target devel --tag omnisafe-devel:latest . 
8 | # 9 | 10 | ARG cuda_docker_tag="11.7.1-cudnn8-devel-ubuntu22.04" 11 | FROM nvidia/cuda:"${cuda_docker_tag}" AS builder 12 | 13 | ENV DEBIAN_FRONTEND=noninteractive 14 | SHELL ["/bin/bash", "-c"] 15 | 16 | # Install packages 17 | RUN apt-get update && \ 18 | apt-get install -y sudo ca-certificates openssl \ 19 | git ssh build-essential gcc g++ cmake make \ 20 | python3-dev python3-venv python3-opengl libosmesa6-dev && \ 21 | rm -rf /var/lib/apt/lists/* 22 | 23 | ENV LANG C.UTF-8 24 | ENV MUJOCO_GL osmesa 25 | ENV PYOPENGL_PLATFORM osmesa 26 | ENV CC=gcc CXX=g++ 27 | 28 | # Add a new user 29 | RUN useradd -m -s /bin/bash omnisafe && \ 30 | echo "omnisafe ALL=(ALL) NOPASSWD:ALL" >> /etc/sudoers 31 | USER omnisafe 32 | RUN echo "export PS1='[\[\e[1;33m\]\u\[\e[0m\]:\[\e[1;35m\]\w\[\e[0m\]]\$ '" >> ~/.bashrc 33 | 34 | # Setup virtual environment 35 | RUN /usr/bin/python3 -m venv --upgrade-deps ~/venv && rm -rf ~/.pip/cache 36 | RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cu$(echo "${CUDA_VERSION}" | cut -d'.' 
-f-2 | tr -d '.')" && \ 37 | echo "export PIP_EXTRA_INDEX_URL='${PIP_EXTRA_INDEX_URL}'" >> ~/venv/bin/activate && \ 38 | echo "source /home/omnisafe/venv/bin/activate" >> ~/.bashrc 39 | 40 | # Install dependencies 41 | WORKDIR /home/omnisafe/omnisafe 42 | COPY --chown=omnisafe requirements.txt requirements.txt 43 | RUN source ~/venv/bin/activate && \ 44 | python -m pip install -r requirements.txt && \ 45 | rm -rf ~/.pip/cache ~/.cache/pip 46 | 47 | #################################################################################################### 48 | 49 | FROM builder AS devel-builder 50 | 51 | # Install extra dependencies 52 | RUN sudo apt-get update && \ 53 | sudo apt-get install -y golang && \ 54 | sudo chown -R "$(whoami):$(whoami)" "$(realpath /usr/lib/go)" && \ 55 | sudo rm -rf /var/lib/apt/lists/* 56 | 57 | # Install addlicense 58 | ENV GOROOT="/usr/lib/go" 59 | ENV GOBIN="${GOROOT}/bin" 60 | ENV PATH="${GOBIN}:${PATH}" 61 | RUN go install github.com/google/addlicense@latest 62 | 63 | # Install extra PyPI dependencies 64 | COPY --chown=omnisafe tests/requirements.txt tests/requirements.txt 65 | RUN source ~/venv/bin/activate && \ 66 | python -m pip install -r tests/requirements.txt && \ 67 | rm -rf ~/.pip/cache ~/.cache/pip 68 | 69 | #################################################################################################### 70 | 71 | FROM builder AS base 72 | 73 | COPY --chown=omnisafe . . 74 | 75 | # Install omnisafe 76 | RUN source ~/venv/bin/activate && \ 77 | make install-editable && \ 78 | rm -rf .eggs *.egg-info ~/.pip/cache ~/.cache/pip 79 | 80 | ENTRYPOINT [ "/bin/bash", "--login" ] 81 | 82 | #################################################################################################### 83 | 84 | FROM devel-builder AS devel 85 | 86 | COPY --from=base /home/omnisafe/omnisafe . 
87 | -------------------------------------------------------------------------------- /omnisafe/configs/offline/CRR.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | defaults: 17 | # seed for random number generator 18 | seed: 0 19 | # training configurations 20 | train_cfgs: 21 | # device to use for training, options: cpu, cuda, cuda:0, cuda:0,1, etc. 
22 | device: cpu 23 | # number of threads for torch 24 | torch_threads: 16 25 | # total number of steps to train 26 | total_steps: 1000000 27 | # dataset name 28 | dataset: SafetyPointCircle1-v0_mixed_0.5 29 | # number of episodes to run at each evaluation 30 | evaluate_epoisodes: 10 31 | # parallel, offline only supports 1 32 | parallel: 1 33 | # vector_env_nums, offline only supports 1 34 | vector_env_nums: 1 35 | # algorithm configurations 36 | algo_cfgs: 37 | # gamma used in RL 38 | gamma: 0.99 39 | # beta in CRR, the advantage weight is f := exp(A(s, a) / beta) 40 | beta: 1 41 | # batch size 42 | batch_size: 256 43 | # steps per epoch, the algorithm logs and evaluates once every epoch 44 | steps_per_epoch: 1000 45 | # phi used in BCQ 46 | phi: 0.05 47 | # number of actions sampled when updating the critic 48 | sampled_action_num: 10 49 | # The soft update coefficient 50 | polyak: 0.005 51 | # logger configurations 52 | logger_cfgs: 53 | # use wandb for logging 54 | use_wandb: False 55 | # wandb project name 56 | wandb_project: omnisafe 57 | # use tensorboard for logging 58 | use_tensorboard: True 59 | # save model frequency 60 | save_model_freq: 100 61 | # save logger path 62 | log_dir: "./runs" 63 | # model configurations 64 | model_cfgs: 65 | # The mode used to initialize the network weights, choosing from "kaiming_uniform", "xavier_normal", "glorot" and "orthogonal".
66 | weight_initialization_mode: "kaiming_uniform" 67 | # actor's cfgs 68 | actor: 69 | # Size of hidden layers 70 | hidden_sizes: [256, 256, 256] 71 | # Type of activation function, choosing from "tanh", "relu", "sigmoid", "identity", "softplus" 72 | activation: relu 73 | # Learning rate of model 74 | lr: 0.001 75 | # critic's cfgs 76 | critic: 77 | # Size of hidden layers 78 | hidden_sizes: [256, 256, 256] 79 | # Type of activation function, choosing from "tanh", "relu", "sigmoid", "identity", "softplus" 80 | activation: relu 81 | # Learning rate of model 82 | lr: 0.001 83 | -------------------------------------------------------------------------------- /omnisafe/configs/offline/BCQ.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | defaults: 17 | # seed for random number generator 18 | seed: 0 19 | # training configurations 20 | train_cfgs: 21 | # device to use for training, options: cpu, cuda, cuda:0, cuda:0,1, etc. 
22 | device: cpu 23 | # number of threads for torch 24 | torch_threads: 16 25 | # total number of steps to train 26 | total_steps: 1000000 27 | # dataset name 28 | dataset: SafetyPointCircle1-v0_mixed_0.5 29 | # number of episodes to run at each evaluation 30 | evaluate_epoisodes: 10 31 | # parallel, offline only supports 1 32 | parallel: 1 33 | # vector_env_nums, offline only supports 1 34 | vector_env_nums: 1 35 | # algorithm configurations 36 | algo_cfgs: 37 | # gamma used in RL 38 | gamma: 0.99 39 | # batch size 40 | batch_size: 256 41 | # steps per epoch, the algorithm logs and evaluates once every epoch 42 | steps_per_epoch: 1000 43 | # phi used in BCQ 44 | phi: 0.05 45 | # number of actions sampled when updating the critic 46 | sampled_action_num: 10 47 | # minimum weighting used when computing Q, Q = w * min(q1, q2) + (1 - w) * max(q1, q2) 48 | minimum_weighting: 0.75 49 | # The soft update coefficient 50 | polyak: 0.005 51 | # logger configurations 52 | logger_cfgs: 53 | # use wandb for logging 54 | use_wandb: False 55 | # wandb project name 56 | wandb_project: omnisafe 57 | # use tensorboard for logging 58 | use_tensorboard: True 59 | # save model frequency 60 | save_model_freq: 100 61 | # save logger path 62 | log_dir: "./runs" 63 | # model configurations 64 | model_cfgs: 65 | # The mode used to initialize the network weights, choosing from "kaiming_uniform", "xavier_normal", "glorot" and "orthogonal".
66 | weight_initialization_mode: "kaiming_uniform" 67 | # actor's cfgs 68 | actor: 69 | # Size of hidden layers 70 | hidden_sizes: [256, 256, 256] 71 | # Type of activation function, choosing from "tanh", "relu", "sigmoid", "identity", "softplus" 72 | activation: relu 73 | # Learning rate of model 74 | lr: 0.001 75 | # critic's cfgs 76 | critic: 77 | # Size of hidden layers 78 | hidden_sizes: [256, 256, 256] 79 | # Type of activation function, choosing from "tanh", "relu", "sigmoid", "identity", "softplus" 80 | activation: relu 81 | # Learning rate of model 82 | lr: 0.001 83 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/penalty_function/ipo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implementation of IPO algorithm.""" 16 | 17 | import torch 18 | 19 | from omnisafe.algorithms import registry 20 | from omnisafe.algorithms.on_policy.base.ppo import PPO 21 | 22 | 23 | @registry.register 24 | class IPO(PPO): 25 | """The Implementation of the IPO algorithm. 26 | 27 | References: 28 | - Title: IPO: Interior-point Policy Optimization under Constraints 29 | - Authors: Yongshuai Liu, Jiaxin Ding, Xin Liu. 
30 | - URL: `IPO `_ 31 | """ 32 | 33 | def _init_log(self) -> None: 34 | """Log the IPO specific information. 35 | 36 | +---------------+--------------------------+ 37 | | Things to log | Description | 38 | +===============+==========================+ 39 | | Misc/Penalty | The penalty coefficient. | 40 | +---------------+--------------------------+ 41 | """ 42 | super()._init_log() 43 | self._logger.register_key('Misc/Penalty') 44 | 45 | def _compute_adv_surrogate(self, adv_r: torch.Tensor, adv_c: torch.Tensor) -> torch.Tensor: 46 | r"""Compute surrogate loss. 47 | 48 | IPO uses the following surrogate loss: 49 | 50 | .. math:: 51 | 52 | L = -\underset{s_t \sim \rho_{\theta}}{\mathbb{E}} \left[ 53 | \frac{\pi_{\theta}^{'} (a_t|s_t)}{\pi_{\theta} (a_t|s_t)} A (s_t, a_t) 54 | - \kappa \frac{J^{C}_{\pi_{\theta}} (s_t, a_t)}{C - J^{C}_{\pi_{\theta}} (s_t, a_t) + \epsilon} 55 | \right] 56 | 57 | Where :math:`\kappa` is the penalty coefficient, :math:`C` is the cost limit, 58 | and :math:`\epsilon` is a small number to avoid division by zero. 59 | 60 | Args: 61 | adv_r (torch.Tensor): The ``reward_advantage`` sampled from buffer. 62 | adv_c (torch.Tensor): The ``cost_advantage`` sampled from buffer. 63 | 64 | Returns: 65 | The advantage function combined with reward and cost. 66 | """ 67 | Jc = self._logger.get_stats('Metrics/EpCost')[0] 68 | penalty = self._cfgs.algo_cfgs.kappa / (self._cfgs.algo_cfgs.cost_limit - Jc + 1e-8) 69 | if penalty < 0 or penalty > self._cfgs.algo_cfgs.penalty_max: 70 | penalty = self._cfgs.algo_cfgs.penalty_max 71 | 72 | self._logger.store({'Misc/Penalty': penalty}) 73 | 74 | return (adv_r - penalty * adv_c) / (1 + penalty) 75 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/saute/ppo_saute.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implementation of the Saute version of the PPO algorithm.""" 16 | 17 | 18 | from omnisafe.adapter.saute_adapter import SauteAdapter 19 | from omnisafe.algorithms import registry 20 | from omnisafe.algorithms.on_policy.base.ppo import PPO 21 | from omnisafe.utils import distributed 22 | 23 | 24 | @registry.register 25 | class PPOSaute(PPO): 26 | """The Saute version of the PPO algorithm. 27 | 28 | A simple combination of the Saute RL and the Proximal Policy Optimization algorithm. 29 | 30 | References: 31 | - Title: Saute RL: Almost Surely Safe Reinforcement Learning Using State Augmentation 32 | - Authors: Aivar Sootla, Alexander I. Cowen-Rivers, Taher Jafferjee, Ziyan Wang, 33 | David Mguni, Jun Wang, Haitham Bou-Ammar. 34 | - URL: `PPOSaute `_ 35 | """ 36 | 37 | def _init_env(self) -> None: 38 | """Initialize the environment. 39 | 40 | OmniSafe uses :class:`omnisafe.adapter.SauteAdapter` to adapt the environment to the algorithm. 41 | 42 | User can customize the environment by inheriting this method. 43 | 44 | Examples: 45 | >>> def _init_env(self) -> None: 46 | ... 
self._env = CustomAdapter() 47 | """ 48 | self._env: SauteAdapter = SauteAdapter( 49 | self._env_id, 50 | self._cfgs.train_cfgs.vector_env_nums, 51 | self._seed, 52 | self._cfgs, 53 | ) 54 | assert (self._cfgs.algo_cfgs.steps_per_epoch) % ( 55 | distributed.world_size() * self._cfgs.train_cfgs.vector_env_nums 56 | ) == 0, 'The number of steps per epoch is not divisible by the number of environments.' 57 | self._steps_per_epoch: int = ( 58 | self._cfgs.algo_cfgs.steps_per_epoch 59 | // distributed.world_size() 60 | // self._cfgs.train_cfgs.vector_env_nums 61 | ) 62 | 63 | def _init_log(self) -> None: 64 | """Log the PPOSaute specific information. 65 | 66 | +------------------+-----------------------------------+ 67 | | Things to log | Description | 68 | +==================+===================================+ 69 | | Metrics/EpBudget | The safety budget of the episode. | 70 | +------------------+-----------------------------------+ 71 | """ 72 | super()._init_log() 73 | self._logger.register_key('Metrics/EpBudget') 74 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/saute/trpo_saute.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Implementation of the Saute version of the TRPO algorithm.""" 16 | 17 | 18 | from omnisafe.adapter.saute_adapter import SauteAdapter 19 | from omnisafe.algorithms import registry 20 | from omnisafe.algorithms.on_policy.base.trpo import TRPO 21 | from omnisafe.utils import distributed 22 | 23 | 24 | @registry.register 25 | class TRPOSaute(TRPO): 26 | """The Saute version of the TRPO algorithm. 27 | 28 | A simple combination of the Saute RL and the Trust Region Policy Optimization algorithm. 29 | 30 | References: 31 | - Title: Saute RL: Almost Surely Safe Reinforcement Learning Using State Augmentation 32 | - Authors: Aivar Sootla, Alexander I. Cowen-Rivers, Taher Jafferjee, Ziyan Wang, 33 | David Mguni, Jun Wang, Haitham Bou-Ammar. 34 | - URL: `TRPOSaute `_ 35 | """ 36 | 37 | def _init_env(self) -> None: 38 | """Initialize the environment. 39 | 40 | OmniSafe uses :class:`omnisafe.adapter.SauteAdapter` to adapt the environment to the algorithm. 41 | 42 | User can customize the environment by inheriting this method. 43 | 44 | Examples: 45 | >>> def _init_env(self) -> None: 46 | ... self._env = CustomAdapter() 47 | """ 48 | self._env: SauteAdapter = SauteAdapter( 49 | self._env_id, 50 | self._cfgs.train_cfgs.vector_env_nums, 51 | self._seed, 52 | self._cfgs, 53 | ) 54 | assert (self._cfgs.algo_cfgs.steps_per_epoch) % ( 55 | distributed.world_size() * self._cfgs.train_cfgs.vector_env_nums 56 | ) == 0, 'The number of steps per epoch is not divisible by the number of environments.' 57 | self._steps_per_epoch: int = ( 58 | self._cfgs.algo_cfgs.steps_per_epoch 59 | // distributed.world_size() 60 | // self._cfgs.train_cfgs.vector_env_nums 61 | ) 62 | 63 | def _init_log(self) -> None: 64 | """Log the TRPOSaute specific information. 
65 | 66 | +------------------+-----------------------------------+ 67 | | Things to log | Description | 68 | +==================+===================================+ 69 | | Metrics/EpBudget | The safety budget of the episode. | 70 | +------------------+-----------------------------------+ 71 | """ 72 | super()._init_log() 73 | self._logger.register_key('Metrics/EpBudget') 74 | -------------------------------------------------------------------------------- /docs/source/start/env.rst: -------------------------------------------------------------------------------- 1 | Environments Customization 2 | =========================== 3 | 4 | OmniSafe supports a flexible environment customization interface. Users only need to make minimal 5 | interface adaptations within the simplest template provided by OmniSafe to complete the environment 6 | customization. 7 | 8 | .. note:: 9 | The highlight of OmniSafe's environment customization is that **users only need to modify the code at the environment layer** to enjoy OmniSafe's complete set of training, saving, and data logging mechanisms. This allows users who install from PyPI to use it easily and focus only on the dynamics of the environment. 10 | 11 | 12 | Get Started with the Simplest Template 13 | -------------------------------------- 14 | 15 | OmniSafe offers a minimal implementation of an environment template as an example of a customized 16 | environment, :doc:`../envs/custom`. 17 | We recommend reading this template in detail and using it as the starting point for your own environment. 18 | 19 | .. card:: 20 | :class-header: sd-bg-success sd-text-white 21 | :class-card: sd-outline-success sd-rounded-1 22 | 23 | Frequently Asked Questions 24 | ^^^ 25 | 1. What changes are necessary to embed the environment into OmniSafe? 26 | 2. My environment requires specific parameters; can these be integrated into OmniSafe's parameter mechanism? 27 | 3. I need to log information during training; how can I achieve this? 28 | 4.
After embedding the environment, how do I run the algorithms in OmniSafe for training? 29 | 30 | For the above questions, we provide a complete Jupyter Notebook example (please see our tutorial on 31 | the GitHub page). We will demonstrate how to start from the most common environments in 32 | `Gymnasium <https://gymnasium.farama.org/>`_ style, implement 33 | environment customization, and complete the training process. 34 | 35 | 36 | Customization of Your Environments 37 | ----------------------------------- 38 | 39 | From Source Code 40 | ^^^^^^^^^^^^^^^^ 41 | 42 | If you are installing from the source code, you can follow the steps below: 43 | 44 | .. card:: 45 | :class-header: sd-bg-success sd-text-white 46 | :class-card: sd-outline-success sd-rounded-1 47 | 48 | Build from Source Code 49 | ^^^ 50 | 1. Create a new file under ``omnisafe/envs/``, for example, ``omnisafe/envs/my_env.py``. 51 | 2. Customize the environment in ``omnisafe/envs/my_env.py``. Assume the class name is ``MyEnv`` and the environment name is ``MyEnv-v0``. 52 | 3. Add ``from .my_env import MyEnv`` in ``omnisafe/envs/__init__.py``. 53 | 4. Run the following command in the ``examples`` folder of the repository: 54 | 55 | .. code-block:: bash 56 | :linenos: 57 | 58 | python train_policy.py --algo PPOLag --env-id MyEnv-v0 59 | 60 | From PyPI 61 | ^^^^^^^^^ 62 | 63 | .. card:: 64 | :class-header: sd-bg-success sd-text-white 65 | :class-card: sd-outline-success sd-rounded-1 66 | 67 | Build from PyPI 68 | ^^^ 69 | 1. Customize the environment in any folder. Assume the class name is ``MyEnv`` and the environment name is ``MyEnv-v0``. 70 | 2. Import OmniSafe and the environment registration decorator. 71 | 3. Run the training. 72 | 73 | For a short but detailed example, please see ``examples/train_from_custom_env.py``. 74 | -------------------------------------------------------------------------------- /examples/train_from_custom_env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 OmniSafe Team.
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Example and template for environment customization.""" 16 | 17 | from __future__ import annotations 18 | 19 | import random 20 | from typing import Any, ClassVar 21 | 22 | import torch 23 | from gymnasium import spaces 24 | 25 | import omnisafe 26 | from omnisafe.envs.core import CMDP, env_register 27 | 28 | 29 | # first, define the environment class. 30 | # the most important thing is to add the `env_register` decorator. 31 | @env_register 32 | class CustomExampleEnv(CMDP): 33 | 34 | # define what tasks the environment support. 35 | _support_envs: ClassVar[list[str]] = ['Custom-v0'] 36 | 37 | # automatically reset when `terminated` or `truncated` 38 | need_auto_reset_wrapper = True 39 | # set `truncated=True` when the total steps exceed the time limit. 
40 | need_time_limit_wrapper = True 41 | 42 | def __init__(self, env_id: str, **kwargs: dict[str, Any]) -> None: 43 | self._count = 0 44 | self._num_envs = 1 45 | self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,)) 46 | self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,)) 47 | 48 | def step( 49 | self, 50 | action: torch.Tensor, 51 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict]: 52 | self._count += 1 53 | obs = torch.as_tensor(self._observation_space.sample()) 54 | reward = 2 * torch.as_tensor(random.random()) # noqa 55 | cost = 2 * torch.as_tensor(random.random()) # noqa 56 | terminated = torch.as_tensor(random.random() > 0.9) # noqa 57 | truncated = torch.as_tensor(self._count > self.max_episode_steps) 58 | return obs, reward, cost, terminated, truncated, {'final_observation': obs} 59 | 60 | @property 61 | def max_episode_steps(self) -> int: 62 | """The max steps per episode.""" 63 | return 10 64 | 65 | def reset( 66 | self, 67 | seed: int | None = None, 68 | options: dict[str, Any] | None = None, 69 | ) -> tuple[torch.Tensor, dict]: 70 | self.set_seed(seed) 71 | obs = torch.as_tensor(self._observation_space.sample()) 72 | self._count = 0 73 | return obs, {} 74 | 75 | def set_seed(self, seed: int) -> None: 76 | random.seed(seed) 77 | 78 | def close(self) -> None: 79 | pass 80 | 81 | def render(self) -> Any: 82 | pass 83 | 84 | 85 | # Then you can use it like this: 86 | agent = omnisafe.Agent( 87 | 'PPOLag', 88 | 'Custom-v0', 89 | ) 90 | agent.learn() 91 | -------------------------------------------------------------------------------- /examples/benchmarks/run_experiment_grid.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Example of training a policy from exp-x config with OmniSafe.""" 16 | 17 | import warnings 18 | 19 | import torch 20 | 21 | from omnisafe.common.experiment_grid import ExperimentGrid 22 | from omnisafe.utils.exp_grid_tools import train 23 | 24 | 25 | if __name__ == '__main__': 26 | eg = ExperimentGrid(exp_name='Benchmark_Safety_Velocity') 27 | 28 | # Set the algorithms. 29 | base_policy = ['PolicyGradient', 'NaturalPG', 'TRPO', 'PPO'] 30 | naive_lagrange_policy = ['PPOLag', 'TRPOLag', 'RCPO', 'OnCRPO', 'PDO'] 31 | first_order_policy = ['CUP', 'FOCOPS', 'P3O'] 32 | second_order_policy = ['CPO', 'PCPO'] 33 | 34 | # Set the environments. 35 | mujoco_envs = [ 36 | 'SafetyAntVelocity-v1', 37 | 'SafetyHopperVelocity-v1', 38 | 'SafetyHumanoidVelocity-v1', 39 | 'SafetyWalker2dVelocity-v1', 40 | 'SafetyHalfCheetahVelocity-v1', 41 | 'SafetySwimmerVelocity-v1', 42 | ] 43 | eg.add('env_id', mujoco_envs) 44 | 45 | # Set the device. 
46 | avaliable_gpus = list(range(torch.cuda.device_count())) 47 | gpu_id = [0, 1, 2, 3] 48 | # if you want to use CPU, please set gpu_id = None 49 | # gpu_id = None 50 | 51 | if gpu_id and not set(gpu_id).issubset(avaliable_gpus): 52 | warnings.warn('The GPU ID is not available, use CPU instead.', stacklevel=1) 53 | gpu_id = None 54 | 55 | eg.add('algo', base_policy + naive_lagrange_policy + first_order_policy + second_order_policy) 56 | eg.add('logger_cfgs:use_wandb', [False]) 57 | eg.add('train_cfgs:vector_env_nums', [4]) 58 | eg.add('train_cfgs:torch_threads', [1]) 59 | eg.add('algo_cfgs:steps_per_epoch', [20000]) 60 | eg.add('train_cfgs:total_steps', [10000000]) 61 | eg.add('seed', [0]) 62 | # total experiment num must can be divided by num_pool 63 | # meanwhile, users should decide this value according to their machine 64 | eg.run(train, num_pool=12, gpu_id=gpu_id) 65 | 66 | # just fill in the name of the parameter of which value you want to compare. 67 | # then you can specify the value of the parameter you want to compare, 68 | # or you can just specify how many values you want to compare in single graph at most, 69 | # and the function will automatically generate all possible combinations of the graph. 70 | # but the two mode can not be used at the same time. 71 | eg.analyze(parameter='env_id', values=None, compare_num=6, cost_limit=25) 72 | eg.render(num_episodes=1, render_mode='rgb_array', width=256, height=256) 73 | eg.evaluate(num_episodes=1) 74 | -------------------------------------------------------------------------------- /tests/simple_env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Simplest environment for testing.""" 16 | 17 | from __future__ import annotations 18 | 19 | import random 20 | from typing import Any, ClassVar 21 | 22 | import numpy as np 23 | import torch 24 | from gymnasium import spaces 25 | 26 | from omnisafe.envs.core import CMDP, env_register 27 | from omnisafe.typing import OmnisafeSpace 28 | 29 | 30 | @env_register 31 | class TestEnv(CMDP): 32 | """Simplest environment for testing.""" 33 | 34 | _support_envs: ClassVar[list[str]] = ['Test-v0'] 35 | metadata: ClassVar[dict[str, int]] = {'render_fps': 30} 36 | need_auto_reset_wrapper = True 37 | need_time_limit_wrapper = True 38 | _num_envs = 1 39 | _coordinate_observation_space: OmnisafeSpace 40 | 41 | def __init__(self, env_id: str, **kwargs) -> None: 42 | self._count = 0 43 | self._observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,)) 44 | self._action_space = spaces.Box(low=-1.0, high=1.0, shape=(2,)) 45 | self._coordinate_observation_space = spaces.Box(low=-1.0, high=1.0, shape=(3,)) 46 | 47 | @property 48 | def get_cost_from_obs_tensor(self) -> None: 49 | return None 50 | 51 | @property 52 | def coordinate_observation_space(self) -> OmnisafeSpace: 53 | return self._coordinate_observation_space 54 | 55 | def step( 56 | self, 57 | action: torch.Tensor, 58 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict]: 59 | self._count += 1 60 | obs = torch.as_tensor(self._observation_space.sample()) 61 | 
reward = 10000 * torch.as_tensor(random.random()) 62 | cost = 10000 * torch.as_tensor(random.random()) 63 | terminated = torch.as_tensor(random.random() > 0.9) 64 | truncated = torch.as_tensor(self._count > self.max_episode_steps) 65 | return obs, reward, cost, terminated, truncated, {'final_observation': obs} 66 | 67 | @property 68 | def max_episode_steps(self) -> int: 69 | """The max steps per episode.""" 70 | return 10 71 | 72 | def reset( 73 | self, 74 | seed: int | None = None, 75 | options: dict[str, Any] | None = None, 76 | ) -> tuple[torch.Tensor, dict]: 77 | if seed is not None: 78 | self.set_seed(seed) 79 | obs = torch.as_tensor(self._observation_space.sample()) 80 | self._count = 0 81 | return obs, {} 82 | 83 | def set_seed(self, seed: int) -> None: 84 | random.seed(seed) 85 | 86 | def render(self) -> Any: 87 | return np.zeros((100, 100, 3), dtype=np.uint8) 88 | 89 | def close(self) -> None: 90 | pass 91 | -------------------------------------------------------------------------------- /docs/source/utils/config.rst: -------------------------------------------------------------------------------- 1 | OmniSafe Config 2 | =============== 3 | 4 | .. currentmodule:: omnisafe.utils.config 5 | 6 | .. autosummary:: 7 | 8 | Config 9 | ModelConfig 10 | get_default_kwargs_yaml 11 | check_all_configs 12 | __check_algo_configs 13 | __check_logger_configs 14 | 15 | 16 | Config 17 | ------ 18 | 19 | OmniSafe uses yaml file to store all the configurations. The configuration file 20 | is stored in ``omnisafe/configs``. The configuration file is divided into 21 | several parts. 22 | 23 | Take ``PPOLag`` as an example, the configuration file is as follows: 24 | 25 | .. list-table:: 26 | 27 | * - Config 28 | - Description 29 | * - ``train_cfgs`` 30 | - Training configurations. 
31 | * - ``algo_cfgs`` 32 | - Algorithm configurations 33 | * - ``logger_cfgs`` 34 | - Logger configurations 35 | * - ``model_cfgs`` 36 | - Model configurations 37 | * - ``lagrange_cfgs`` 38 | - Lagrange configurations 39 | 40 | Specifically, the ``train_cfgs`` section is as follows: 41 | 42 | .. list-table:: 43 | 44 | * - Config 45 | - Description 46 | - Value 47 | * - ``device`` 48 | - Device to use. 49 | - ``cuda`` or ``cpu`` 50 | * - ``torch_threads`` 51 | - Number of threads to use. 52 | - 16 53 | * - ``vector_env_nums`` 54 | - Number of vectorized environments. 55 | - 1 56 | * - ``parallel`` 57 | - Number of parallel agents, similar to A3C. 58 | - 1 59 | * - ``total_steps`` 60 | - Total number of training steps. 61 | - 1000000 62 | 63 | Other configurations are similar to ``train_cfgs``. Refer to the ``omnisafe/configs`` directory for more details. 64 | 65 | .. card:: 66 | :class-header: sd-bg-success sd-text-white 67 | :class-card: sd-outline-success sd-rounded-1 68 | 69 | Documentation 70 | ^^^ 71 | 72 | .. autoclass:: Config 73 | :members: 74 | :private-members: 75 | 76 | Model Config 77 | ------------ 78 | 79 | .. card:: 80 | :class-header: sd-bg-success sd-text-white 81 | :class-card: sd-outline-success sd-rounded-1 82 | 83 | Documentation 84 | ^^^ 85 | 86 | .. autoclass:: ModelConfig 87 | :members: 88 | :private-members: 89 | 90 | 91 | 92 | Common Method 93 | ------------- 94 | 95 | .. card:: 96 | :class-header: sd-bg-success sd-text-white 97 | :class-card: sd-outline-success sd-rounded-1 98 | 99 | Documentation 100 | ^^^ 101 | 102 | .. autofunction:: get_default_kwargs_yaml 103 | 104 | .. card:: 105 | :class-header: sd-bg-success sd-text-white 106 | :class-card: sd-outline-success sd-rounded-1 107 | 108 | Documentation 109 | ^^^ 110 | 111 | .. autofunction:: check_all_configs 112 | 113 | .. card:: 114 | :class-header: sd-bg-success sd-text-white 115 | :class-card: sd-outline-success sd-rounded-1 116 | 117 | Documentation 118 | ^^^ 119 | 120 | ..
autofunction:: __check_algo_configs 121 | 122 | .. card:: 123 | :class-header: sd-bg-success sd-text-white 124 | :class-card: sd-outline-success sd-rounded-1 125 | 126 | Documentation 127 | ^^^ 128 | 129 | .. autofunction:: __check_logger_configs 130 | -------------------------------------------------------------------------------- /docs/source/start/algo.md: -------------------------------------------------------------------------------- 1 | # Supported Algorithms 2 | 3 | OmniSafe offers a highly modular framework that integrates an extensive collection of algorithms specifically designed for Safe Reinforcement Learning (SafeRL) in various domains. The `Adapter` module in OmniSafe allows for easily expanding different types of SafeRL algorithms. 4 | 5 | 6 | 7 | 8 | 9 | 33 | 34 | 35 | 36 |
37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 |
DomainsTypesAlgorithms Registry
On PolicyPrimal-DualTRPOLag; PPOLag; PDO; RCPO
TRPOPID; CPPOPID
Convex OptimizationCPO; PCPO; FOCOPS; CUP
Penalty FunctionIPO; P3O
PrimalOnCRPO
Off PolicyPrimal-DualDDPGLag; TD3Lag; SACLag
DDPGPID; TD3PID; SACPID
Control Barrier FunctionDDPGCBF; SACRCBF; CRABS
Model-basedOnline PlanSafeLOOP; CCEPETS; RCEPETS
Pessimistic EstimateCAPPETS
OfflineQ-Learning BasedBCQLag; C-CRR
DICE BasedCOptDICE
Other Formulation MDPET-MDPPPOEarlyTerminated; TRPOEarlyTerminated
SauteRLPPOSaute; TRPOSaute
SimmerRLPPOSimmerPID; TRPOSimmerPID
109 |
110 | 111 |

Table 1: OmniSafe supports a wide variety of SafeRL algorithms. From the perspective of classic RL, OmniSafe includes on-policy, off-policy, offline, and model-based algorithms; from the perspective of the SafeRL learning paradigm, OmniSafe supports primal-dual, projection, penalty function, primal, and other methods.

112 | -------------------------------------------------------------------------------- /omnisafe/configs/offline/CCRR.yaml: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | defaults: 17 | # seed for random number generator 18 | seed: 0 19 | # training configurations 20 | train_cfgs: 21 | # device to use for training, options: cpu, cuda, cuda:0, cuda:0,1, etc. 
22 | device: cpu 23 | # number of threads for torch 24 | torch_threads: 16 25 | # total number of steps to train 26 | total_steps: 1000000 27 | # dataset name 28 | dataset: SafetyPointCircle1-v0_mixed_0.5 29 | # number of episodes to run for each evaluation (the key keeps its original `evaluate_epoisodes` spelling) 30 | evaluate_epoisodes: 10 31 | # parallel, offline only supports 1 32 | parallel: 1 33 | # vector_env_nums, offline only supports 1 34 | vector_env_nums: 1 35 | # algorithm configurations 36 | algo_cfgs: 37 | # gamma used in RL 38 | gamma: 0.99 39 | # beta in CRR, f := exp(A(s, a) / beta) 40 | beta: 1 41 | # batch size 42 | batch_size: 256 43 | # steps per epoch, the algorithm will log and evaluate every epoch 44 | steps_per_epoch: 1000 45 | # phi used in BCQ 46 | phi: 0.05 47 | # number of actions sampled when updating the critic 48 | sampled_action_num: 10 49 | # polyak (soft update) coefficient for the target networks 50 | polyak: 0.005 51 | # step at which the lagrange multiplier update starts 52 | lagrange_start_step: 150000 53 | # logger configurations 54 | logger_cfgs: 55 | # use wandb for logging 56 | use_wandb: False 57 | # wandb project name 58 | wandb_project: omnisafe 59 | # use tensorboard for logging 60 | use_tensorboard: True 61 | # save model frequency 62 | save_model_freq: 100 63 | # directory for saving logs 64 | log_dir: "./runs" 65 | # model configurations 66 | model_cfgs: 67 | # The mode to initialize the network weights, choosing from "kaiming_uniform", "xavier_normal", "glorot" and "orthogonal".
68 | weight_initialization_mode: "kaiming_uniform" 69 | # actor's cfgs 70 | actor: 71 | # Size of hidden layers 72 | hidden_sizes: [256, 256, 256] 73 | # Type of activation function, choosing from "tanh", "relu", "sigmoid", "identity", "softplus" 74 | activation: relu 75 | # Learning rate of model 76 | lr: 0.001 77 | # critic's cfgs 78 | critic: 79 | # Size of hidden layers 80 | hidden_sizes: [256, 256, 256] 81 | # Type of activation function, choosing from "tanh", "relu", "sigmoid", "identity", "softplus" 82 | activation: relu 83 | # Learning rate of model 84 | lr: 0.001 85 | # lagrangian configurations 86 | lagrange_cfgs: 87 | # Tolerance of constraint violation 88 | cost_limit: 25.0 89 | # Initial value of lagrangian multiplier 90 | lagrangian_multiplier_init: 1.0 91 | # Learning rate of lagrangian multiplier 92 | lambda_lr: 0.001 93 | # Type of lagrangian optimizer 94 | lambda_optimizer: "Adam" 95 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/primal/crpo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Implementation of the on-policy CRPO algorithm.""" 16 | 17 | import torch 18 | 19 | from omnisafe.algorithms import registry 20 | from omnisafe.algorithms.on_policy.base.trpo import TRPO 21 | from omnisafe.utils.config import Config 22 | 23 | 24 | @registry.register 25 | class OnCRPO(TRPO): 26 | """The on-policy CRPO algorithm. 27 | 28 | References: 29 | - Title: CRPO: A New Approach for Safe Reinforcement Learning with Convergence Guarantee. 30 | - Authors: Tengyu Xu, Yingbin Liang, Guanghui Lan. 31 | - URL: `CRPO `_. 32 | """ 33 | 34 | def __init__(self, env_id: str, cfgs: Config) -> None: 35 | """Initialize an instance of :class:`OnCRPO`.""" 36 | super().__init__(env_id, cfgs) 37 | self._rew_update: int = 0 38 | self._cost_update: int = 0 39 | 40 | def _init_log(self) -> None: 41 | """Log the CRPO specific information. 42 | 43 | +-----------------+--------------------------------------------+ 44 | | Things to log | Description | 45 | +=================+============================================+ 46 | | Misc/RewUpdate | The number of times the reward is updated. | 47 | +-----------------+--------------------------------------------+ 48 | | Misc/CostUpdate | The number of times the cost is updated. | 49 | +-----------------+--------------------------------------------+ 50 | """ 51 | super()._init_log() 52 | self._logger.register_key('Misc/RewUpdate') 53 | self._logger.register_key('Misc/CostUpdate') 54 | 55 | def _compute_adv_surrogate(self, adv_r: torch.Tensor, adv_c: torch.Tensor) -> torch.Tensor: 56 | """Compute the advantage surrogate. 57 | 58 | In CRPO algorithm, we first judge whether the cost is within the limit. If the cost is 59 | within the limit, we use the advantage of the policy. Otherwise, we use the advantage of the 60 | cost. 61 | 62 | Args: 63 | adv_r (torch.Tensor): The ``reward_advantage`` sampled from buffer. 
64 | adv_c (torch.Tensor): The ``cost_advantage`` sampled from buffer. 65 | 66 | Returns: 67 | The advantage function chosen from reward and cost. 68 | """ 69 | Jc = self._logger.get_stats('Metrics/EpCost')[0] 70 | if Jc <= self._cfgs.algo_cfgs.cost_limit + self._cfgs.algo_cfgs.distance: 71 | self._rew_update += 1 72 | return adv_r 73 | self._cost_update += 1 74 | self._logger.store( 75 | { 76 | 'Misc/RewUpdate': self._rew_update, 77 | 'Misc/CostUpdate': self._cost_update, 78 | }, 79 | ) 80 | return -adv_c 81 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/simmer/ppo_simmer_pid.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implementation of the Simmer version of the PPO algorithm.""" 16 | 17 | import torch 18 | 19 | from omnisafe.adapter.simmer_adapter import SimmerAdapter 20 | from omnisafe.algorithms import registry 21 | from omnisafe.algorithms.on_policy.base.ppo import PPO 22 | from omnisafe.utils import distributed 23 | 24 | 25 | @registry.register 26 | class PPOSimmerPID(PPO): 27 | """The Simmer version(based on PID controller) of the PPO algorithm. 
28 | 29 | A simple combination of the Simmer RL and the Proximal Policy Optimization algorithm. 30 | 31 | References: 32 | - Title: Effects of Safety State Augmentation on Safe Exploration. 33 | - Authors: Aivar Sootla, Alexander I. Cowen-Rivers, Jun Wang, Haitham Bou Ammar. 34 | - URL: `PPOSimmerPID `_ 35 | """ 36 | 37 | def _init_env(self) -> None: 38 | """Initialize the environment. 39 | 40 | OmniSafe uses :class:`omnisafe.adapter.SimmerAdapter` to adapt the environment to the algorithm. 41 | 42 | User can customize the environment by inheriting this method. 43 | 44 | Examples: 45 | >>> def _init_env(self) -> None: 46 | ... self._env = CustomAdapter() 47 | """ 48 | self._env: SimmerAdapter = SimmerAdapter( 49 | self._env_id, 50 | self._cfgs.train_cfgs.vector_env_nums, 51 | self._seed, 52 | self._cfgs, 53 | ) 54 | assert (self._cfgs.algo_cfgs.steps_per_epoch) % ( 55 | distributed.world_size() * self._cfgs.train_cfgs.vector_env_nums 56 | ) == 0, 'The number of steps per epoch is not divisible by the number of environments.' 57 | self._steps_per_epoch: int = ( 58 | self._cfgs.algo_cfgs.steps_per_epoch 59 | // distributed.world_size() 60 | // self._cfgs.train_cfgs.vector_env_nums 61 | ) 62 | 63 | def _init_log(self) -> None: 64 | """Log the PPOSimmerPID specific information. 65 | 66 | +------------------+-----------------------------------+ 67 | | Things to log | Description | 68 | +==================+===================================+ 69 | | Metrics/EpBudget | The safety budget of the episode. 
| 70 | +------------------+-----------------------------------+ 71 | """ 72 | super()._init_log() 73 | self._logger.register_key('Metrics/EpBudget') 74 | 75 | def _update(self) -> None: 76 | """Update actor, critic, as we used in the :class:`PolicyGradient` algorithm.""" 77 | Jc = self._logger.get_stats('Metrics/EpCost')[0] 78 | self._env.control_budget(torch.as_tensor(Jc, dtype=torch.float32, device=self._device)) 79 | super()._update() 80 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/simmer/trpo_simmer_pid.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implementation of the Simmer version of the TRPO algorithm.""" 16 | 17 | import torch 18 | 19 | from omnisafe.adapter.simmer_adapter import SimmerAdapter 20 | from omnisafe.algorithms import registry 21 | from omnisafe.algorithms.on_policy.base.trpo import TRPO 22 | from omnisafe.utils import distributed 23 | 24 | 25 | @registry.register 26 | class TRPOSimmerPID(TRPO): 27 | """The Simmer version(based on PID controller) of the TRPO algorithm. 28 | 29 | A simple combination of the Simmer RL and the Trust Region Policy Optimization algorithm. 
30 | 31 | References: 32 | - Title: Effects of Safety State Augmentation on Safe Exploration. 33 | - Authors: Aivar Sootla, Alexander I. Cowen-Rivers, Jun Wang, Haitham Bou Ammar. 34 | - URL: `TRPOSimmerPID `_ 35 | """ 36 | 37 | def _init_env(self) -> None: 38 | """Initialize the environment. 39 | 40 | OmniSafe uses :class:`omnisafe.adapter.SimmerAdapter` to adapt the environment to the algorithm. 41 | 42 | User can customize the environment by inheriting this method. 43 | 44 | Examples: 45 | >>> def _init_env(self) -> None: 46 | ... self._env = CustomAdapter() 47 | """ 48 | self._env: SimmerAdapter = SimmerAdapter( 49 | self._env_id, 50 | self._cfgs.train_cfgs.vector_env_nums, 51 | self._seed, 52 | self._cfgs, 53 | ) 54 | assert (self._cfgs.algo_cfgs.steps_per_epoch) % ( 55 | distributed.world_size() * self._cfgs.train_cfgs.vector_env_nums 56 | ) == 0, 'The number of steps per epoch is not divisible by the number of environments.' 57 | self._steps_per_epoch: int = ( 58 | self._cfgs.algo_cfgs.steps_per_epoch 59 | // distributed.world_size() 60 | // self._cfgs.train_cfgs.vector_env_nums 61 | ) 62 | 63 | def _init_log(self) -> None: 64 | """Log the TRPOSimmerPID specific information. 65 | 66 | +------------------+-----------------------------------+ 67 | | Things to log | Description | 68 | +==================+===================================+ 69 | | Metrics/EpBudget | The safety budget of the episode. 
| 70 | +------------------+-----------------------------------+ 71 | """ 72 | super()._init_log() 73 | self._logger.register_key('Metrics/EpBudget') 74 | 75 | def _update(self) -> None: 76 | """Update actor, critic, as we used in the :class:`PolicyGradient` algorithm.""" 77 | Jc = self._logger.get_stats('Metrics/EpCost')[0] 78 | self._env.control_budget(torch.as_tensor(Jc, dtype=torch.float32, device=self._device)) 79 | super()._update() 80 | -------------------------------------------------------------------------------- /omnisafe/algorithms/on_policy/base/ppo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 OmniSafe Team. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Implementation of the PPO algorithm.""" 16 | 17 | from __future__ import annotations 18 | 19 | import torch 20 | 21 | from omnisafe.algorithms import registry 22 | from omnisafe.algorithms.on_policy.base.policy_gradient import PolicyGradient 23 | 24 | 25 | @registry.register 26 | class PPO(PolicyGradient): 27 | """The Proximal Policy Optimization (PPO) algorithm. 28 | 29 | References: 30 | - Title: Proximal Policy Optimization Algorithms 31 | - Authors: John Schulman, Filip Wolski, Prafulla Dhariwal, Alec Radford, Oleg Klimov. 
32 | - URL: `PPO `_ 33 | """ 34 | 35 | def _loss_pi( 36 | self, 37 | obs: torch.Tensor, 38 | act: torch.Tensor, 39 | logp: torch.Tensor, 40 | adv: torch.Tensor, 41 | ) -> torch.Tensor: 42 | r"""Computing pi/actor loss. 43 | 44 | In Proximal Policy Optimization, the loss is defined as: 45 | 46 | .. math:: 47 | 48 | L^{CLIP} = \underset{s_t \sim \rho_{\theta}}{\mathbb{E}} \left[ 49 | \min ( r_t A^{R}_{\pi_{\theta}} (s_t, a_t) , \text{clip} (r_t, 1 - \epsilon, 1 + \epsilon) 50 | A^{R}_{\pi_{\theta}} (s_t, a_t) 51 | \right] 52 | 53 | where :math:`r_t = \frac{\pi_{\theta}^{'} (a_t|s_t)}{\pi_{\theta} (a_t|s_t)}`, 54 | :math:`\epsilon` is the clip parameter, and :math:`A^{R}_{\pi_{\theta}} (s_t, a_t)` is the 55 | advantage. 56 | 57 | Args: 58 | obs (torch.Tensor): The ``observation`` sampled from buffer. 59 | act (torch.Tensor): The ``action`` sampled from buffer. 60 | logp (torch.Tensor): The ``log probability`` of action sampled from buffer. 61 | adv (torch.Tensor): The ``advantage`` processed. ``reward_advantage`` here. 62 | 63 | Returns: 64 | The loss of pi/actor. 65 | """ 66 | distribution = self._actor_critic.actor(obs) 67 | logp_ = self._actor_critic.actor.log_prob(act) 68 | std = self._actor_critic.actor.std 69 | ratio = torch.exp(logp_ - logp) 70 | ratio_cliped = torch.clamp( 71 | ratio, 72 | 1 - self._cfgs.algo_cfgs.clip, 73 | 1 + self._cfgs.algo_cfgs.clip, 74 | ) 75 | loss = -torch.min(ratio * adv, ratio_cliped * adv).mean() 76 | loss -= self._cfgs.algo_cfgs.entropy_coef * distribution.entropy().mean() 77 | # useful extra info 78 | entropy = distribution.entropy().mean().item() 79 | self._logger.store( 80 | { 81 | 'Train/Entropy': entropy, 82 | 'Train/PolicyRatio': ratio, 83 | 'Train/PolicyStd': std, 84 | 'Loss/Loss_pi': loss.mean().item(), 85 | }, 86 | ) 87 | return loss 88 | --------------------------------------------------------------------------------