├── .gitignore ├── LICENSE ├── README.md ├── assets └── cheetah.gif ├── rs ├── __init__.py ├── checkpoint │ └── pretrained │ │ └── cheetah_4293_784 ├── environment.py ├── envs │ ├── ant.py │ ├── cheetah.py │ ├── humanoid.py │ └── walker.py ├── models │ ├── LICENSE_ant │ ├── LICENSE_dm_control │ ├── LICENSE_gymnasium │ ├── README.md │ ├── ant.xml │ ├── cheetah.xml │ ├── common │ │ ├── materials.xml │ │ ├── skybox.xml │ │ └── visual.xml │ ├── humanoid.xml │ └── walker.xml ├── policy.py ├── search.py ├── train.py └── utilities.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # MuJoCo 163 | MUJOCO_LOG.TXT 164 | 165 | # macOS 166 | .DS_Store 167 | 168 | # checkpoints 169 | checkpoint/ 170 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Taylor Howell 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Random Search 2 | A simple [JAX](https://github.com/google/jax)-based implementation of [random search](https://arxiv.org/abs/1803.07055) for [locomotion tasks](https://github.com/openai/gym/tree/master/gym/envs/mujoco) using [MuJoCo XLA (MJX)](https://mujoco.readthedocs.io/en/stable/mjx.html). 3 | 4 | ## Installation 5 | Clone the repository: 6 | ```sh 7 | git clone https://github.com/thowell/rs 8 | ``` 9 | 10 | Optionally, create a conda environment: 11 | ```sh 12 | conda create -n rs python=3.10 13 | conda activate rs 14 | ``` 15 | 16 | pip install: 17 | ```sh 18 | pip install -e . 
19 | ``` 20 | 21 | ## Train cheetah 22 | Train cheetah in ~1 minute with [Nvidia RTX 4090](https://www.nvidia.com/en-us/geforce/graphics-cards/40-series/rtx-4090/) on [Ubuntu 22.04.4 LTS](https://releases.ubuntu.com/jammy/). 23 | 24 | drawing 25 | 26 | Run: 27 | ```sh 28 | python rs/train.py --env cheetah --search --visualize --nsample 2048 --ntop 512 --niter 50 --neval 5 --nhorizon_search 200 --nhorizon_eval 1000 --random_step 0.1 --update_step 0.1 29 | ``` 30 | 31 | Output: 32 | ``` 33 | Settings: 34 | environment: cheetah 35 | nsample: 2048 | ntop: 512 36 | niter: 50 | neval: 5 37 | nhorizon_search: 200 | nhorizon_eval: 1000 38 | random_step: 0.1 | update_step: 0.1 39 | nenveval: 128 40 | reward_shift: 0.0 41 | Search: 42 | iteration (10 / 50): reward = 1172.42 +- 1144.11 | time = 17.52 | avg episode length: 1000 / 1000 | global steps: 8232960 | steps/second: 470022 43 | iteration (20 / 50): reward = 2947.71 +- 1237.87 | time = 5.58 | avg episode length: 1000 / 1000 | global steps: 16465920 | steps/second: 1474670 44 | iteration (30 / 50): reward = 3152.07 +- 1401.50 | time = 5.58 | avg episode length: 1000 / 1000 | global steps: 24698880 | steps/second: 1475961 45 | iteration (40 / 50): reward = 4175.49 +- 783.41 | time = 5.59 | avg episode length: 1000 / 1000 | global steps: 32931840 | steps/second: 1472244 46 | iteration (50 / 50): reward = 4293.36 +- 784.80 | time = 5.59 | avg episode length: 1000 / 1000 | global steps: 41164800 | steps/second: 1473380 47 | 48 | total time: 56.43 49 | ``` 50 | 51 | The pretrained policy can be visualized in MuJoCo's passive viewer: 52 | ``` 53 | python train.py --env cheetah --load pretrained/cheetah --visualize 54 | ``` 55 | 56 | ## Environments 57 | Environments available: 58 | 59 | - [Ant](rs/envs/ant.py) 60 | - based on [ant_v5](https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/mujoco/ant_v5.py) 61 | - modified solver settings 62 | - only contact between feet and floor 63 | - no rewards or observations dependent on contact forces 64 | - [Cheetah](rs/envs/cheetah.py) 65 | - based on [half_cheetah_v5](https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/mujoco/half_cheetah_v5.py) 66 | - modified solver settings 67 | - [Humanoid](rs/envs/humanoid.py) 68 | - based on [humanoid_v5](https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/mujoco/humanoid_v5.py) 69 | - modified solver settings 70 | - only contact between feet and floor 71 | - no rewards or observations dependent on contact forces 72 | - [Walker](rs/envs/walker.py) 73 | - based on [walker2d_v5](https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/mujoco/walker2d_v5.py) 74 | - modified solver settings 75 | - only contact between feet and floor 76 | 77 | 78 | ## Usage 79 | **Note**: run multiple times to find good policies. 
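For example, after changing to the `rs/` directory (see below), a few searches with different seeds can be launched using the `--seed` and `--checkpoint` flags documented under command line arguments; the checkpoint names here are only illustrative:
```sh
for seed in 0 1 2; do
  python train.py --env ant --search --seed $seed --checkpoint ant_seed_$seed
done
```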
80 | 81 | First, change to `rs/` directory: 82 | ```sh 83 | cd rs 84 | ``` 85 | 86 | ### Ant 87 | Search: 88 | ```sh 89 | python train.py --env ant --search 90 | ``` 91 | 92 | Visualize policy checkpoint: 93 | ```sh 94 | python train.py --env ant --mode visualize --load pretrained/ant 95 | ``` 96 | 97 | ### Cheetah 98 | Search: 99 | ```sh 100 | python train.py --env cheetah --search 101 | ``` 102 | 103 | Visualize policy checkpoint: 104 | ```sh 105 | python train.py --env cheetah --load pretrained/cheetah --visualize 106 | ``` 107 | 108 | ### Humanoid 109 | Search: 110 | ```sh 111 | python train.py --env humanoid --search 112 | ``` 113 | 114 | Visualize policy checkpoint: 115 | ```sh 116 | python train.py --env humanoid --load pretrained/humanoid --visualize 117 | ``` 118 | 119 | ### Walker 120 | Search: 121 | ```sh 122 | python train.py --env walker --search 123 | ``` 124 | 125 | Visualize policy checkpoint: 126 | ```sh 127 | python train.py --env walker --load pretrained/walker --visualize 128 | ``` 129 | 130 | ### Command line arguments 131 | Setup: 132 | - `--env`: `ant`, `cheetah`, `humanoid`, `walker` 133 | - `--search`: run random search to improve policy 134 | - `--checkpoint`: filename in `checkpoint/` to save policy 135 | - `--load`: provide string in `checkpoint/` 136 | directory to load policy from checkpoint 137 | - `--seed`: int for random number generation 138 | - `--visualize`: visualize policy 139 | 140 | Search settings: 141 | - `--nsample`: number of random directions to sample 142 | - `--ntop`: number of random directions to use for policy update 143 | - `--niter`: number of policy updates 144 | - `--neval`: number of policy evaluations during search 145 | - `--nhorizon_search`: number of environment steps during policy improvement 146 | - `--nhorizon_eval`: number of environment steps during policy evaluation 147 | - `--random_step`: step size for random direction during policy perturbation 148 | - `--update_step`: step size for policy update during policy improvement 149 | - `--nenveval`: number of environments for policy evaluation 150 | - `--reward_shift`: subtract baseline from per-timestep reward 151 | 152 | ## Mapping notation from the paper to code 153 | $\alpha$: `update_step` 154 | 155 | $\nu$: `random_step` 156 | 157 | $N$: `nsample` 158 | 159 | $b$: `ntop` 160 | 161 | ## Notes 162 | - The environments are based on the [v5 MuJoCo Gym environments](https://github.com/Farama-Foundation/Gymnasium/tree/main/gymnasium/envs/mujoco) but may not be exact in all details. 163 | - The search settings are based on [Simple random search provides a competitive approach to reinforcement learning: Table 9](https://arxiv.org/abs/1803.07055) but may not be exact in all details either. 
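To make the notation mapping above concrete, here is a minimal sketch of a single Augmented Random Search update written with the repository's parameter names. It is a simplified approximation of the logic in `rs/search.py` (no observation normalization, no vectorized rollouts), and `evaluate_rollout` is a hypothetical stand-in for a function returning the total reward of a given policy weight matrix:

```python
import jax
import jax.numpy as jnp


def ars_update(weight, evaluate_rollout, rng, nsample=32, ntop=4,
               random_step=0.03, update_step=0.02):
    """One simplified Augmented Random Search update."""
    # sample nsample (N) random directions with the same shape as the policy weights
    deltas = jax.random.normal(rng, (nsample,) + weight.shape)

    # rewards for positively and negatively perturbed policies: W +- nu * delta
    r_pos = jnp.array([evaluate_rollout(weight + random_step * d) for d in deltas])
    r_neg = jnp.array([evaluate_rollout(weight - random_step * d) for d in deltas])

    # keep the ntop (b) directions whose best perturbation earned the highest reward
    top = jnp.argsort(jnp.maximum(r_pos, r_neg), descending=True)[:ntop]

    # standard deviation of the rewards used in the update (sigma_R in the paper)
    sigma = jnp.concatenate([r_pos[top], r_neg[top]]).std()

    # W <- W + alpha / (b * sigma_R) * sum_k (r_pos_k - r_neg_k) * delta_k
    update = jnp.einsum("i,ijk->jk", r_pos[top] - r_neg[top], deltas[top])
    return weight + update_step / (ntop * sigma) * update
```

The actual implementation in `rs/search.py` additionally maintains Welford-style running observation statistics and evaluates all perturbed policies in parallel with `jax.vmap`.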
164 | 
165 | This repository was developed to: 
166 | - understand the [Augmented Random Search](https://arxiv.org/abs/1803.07055) algorithm 
167 | - understand how to compute numerically stable running statistics 
168 | - understand the details of [Gym environments](https://github.com/openai/gym) 
169 | - experiment with code generation tools that are useful for improving development times, including [ChatGPT](https://chatgpt.com/) and [Claude](https://claude.ai/) 
170 | - gain experience with [MuJoCo XLA (MJX)](https://mujoco.readthedocs.io/en/stable/mjx.html) 
171 | - gain experience with [JAX](https://github.com/google/jax) 
172 | 
173 | MuJoCo models use resources from [Gymnasium](https://github.com/Farama-Foundation/Gymnasium/tree/main/gymnasium/envs/mujoco) and [dm_control](https://github.com/google-deepmind/dm_control). 
174 | 
-------------------------------------------------------------------------------- /assets/cheetah.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thowell/rs/f71a8af39ed7c2ab351a614174d8173ab654dd6d/assets/cheetah.gif 
-------------------------------------------------------------------------------- /rs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thowell/rs/f71a8af39ed7c2ab351a614174d8173ab654dd6d/rs/__init__.py 
-------------------------------------------------------------------------------- /rs/checkpoint/pretrained/cheetah_4293_784: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thowell/rs/f71a8af39ed7c2ab351a614174d8173ab654dd6d/rs/checkpoint/pretrained/cheetah_4293_784 
-------------------------------------------------------------------------------- /rs/environment.py: -------------------------------------------------------------------------------- 
1 | from functools import partial 
2 | import jax 
3 | import jax.numpy as jnp 
4 | from mujoco import mjx 
5 | from mujoco.mjx._src import dataclasses, forward 
6 | 
7 | from rs.policy import Policy, policy 
8 | from typing import Callable, Tuple 
9 | 
10 | 
11 | class Environment(dataclasses.PyTreeNode): 
12 | """Learning environment. 
13 | 
14 | Attributes: 
15 | model: MuJoCo model 
16 | reward: function returning per-timestep reward 
17 | observation: function returning per-timestep observation 
18 | reset: function resetting environment, returns mjx.Data 
19 | done: function determining environment termination 
20 | naction: number of actions 
21 | nobservation: number of observations 
22 | ndecimation: number of physics steps per policy step 
23 | """ 
24 | 
25 | model: mjx.Model 
26 | data: mjx.Data 
27 | reward: Callable 
28 | observation: Callable 
29 | reset: Callable 
30 | done: Callable 
31 | naction: int 
32 | nobservation: int 
33 | ndecimation: int 
34 | 
35 | 
36 | @partial(jax.jit, static_argnums=(3,)) 
37 | def multistep( 
38 | model: mjx.Model, data: mjx.Data, action: jax.Array, nstep: int 
39 | ) -> mjx.Data: 
40 | """Multiple physics steps for action. 
41 | 
42 | Args: 
43 | model (mjx.Model) 
44 | data (mjx.Data) 
45 | action (jax.Array) 
46 | nstep (int): number of physics steps.
47 | 48 | Returns: 49 | mjx.Data 50 | """ 51 | # set action 52 | data = data.replace(ctrl=action) 53 | 54 | # step physics 55 | def step(d, _): 56 | # step dynamics 57 | d = forward.step(model, d) 58 | 59 | return d, None 60 | 61 | # next data 62 | data, _ = jax.lax.scan(step, data, None, length=nstep, unroll=nstep) 63 | return data 64 | 65 | def rollout( 66 | env: Environment, 67 | p: Policy, 68 | d: mjx.Data, 69 | shift: float = 0.0, 70 | nhorizon: int = 1000, 71 | ) -> Tuple[float, jax.Array]: 72 | """Simulate environment. 73 | 74 | Args: 75 | env (Environment): simulation environment 76 | p (Policy): affine feedback policy 77 | d (mjx.Data): MuJoCo data 78 | shift (float): subtract value from per-timestep reward 79 | nhorizon: number of environment steps 80 | rng: JAX random number key 81 | 82 | Returns: 83 | Tuple[jax.Array, jax.Array]: per-timestep rewards and observations 84 | """ 85 | # get observation 86 | obs = env.observation(env.model, d) 87 | 88 | # initialize observation statistics 89 | # (mean, var, count) 90 | obs_stats_init = (obs.copy(), jnp.zeros_like(obs), 1) 91 | 92 | # continue physics steps 93 | def continue_step(carry): 94 | # unpack carry 95 | data, total_reward, obs, obs_stats, steps, done = carry 96 | return jnp.logical_and(jnp.where(done, False, True), steps < nhorizon) 97 | 98 | # step 99 | def step(carry): 100 | # unpack carry 101 | data, total_reward, obs, obs_stats, steps, done = carry 102 | 103 | # get action 104 | action = policy(p, obs) 105 | 106 | # step 107 | next_data = multistep(env.model, data, action, env.ndecimation) 108 | 109 | # done 110 | next_done = jnp.logical_or(env.done(env.model, next_data), done) 111 | not_done = jnp.where(next_done, 0, 1) 112 | 113 | # get reward 114 | reward = env.reward( 115 | env.model, data, next_data, env.model.opt.timestep * env.ndecimation 116 | ) 117 | reward -= shift 118 | reward *= not_done 119 | 120 | # get observation 121 | next_obs = env.observation(env.model, next_data) 122 | 123 | # unpack observation statistics 124 | mean, var, count = obs_stats 125 | 126 | # update observation statistics 127 | delta = next_obs - mean 128 | next_count = count + not_done 129 | next_mean = mean + not_done * delta / next_count 130 | next_var = var + not_done * delta * delta * count / next_count 131 | next_obs_stats = (next_mean, next_var, next_count) 132 | 133 | return ( 134 | next_data, 135 | total_reward + reward, 136 | next_obs, 137 | next_obs_stats, 138 | steps + 1, 139 | next_done, 140 | ) 141 | 142 | # loop 143 | carry = jax.lax.while_loop(continue_step, step, (d, 0.0, obs, obs_stats_init, 0, False)) 144 | 145 | total_reward = carry[1] 146 | obs_stats = carry[3] 147 | 148 | return total_reward, obs_stats 149 | 150 | 151 | # jit and vmap rollout 152 | v_rollout = jax.jit(jax.vmap(rollout, in_axes=(None, 0, 0, None, None))) 153 | v_rollout_eval = jax.jit(jax.vmap(rollout, in_axes=(None, None, 0, None, None))) 154 | -------------------------------------------------------------------------------- /rs/envs/ant.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from mujoco import mjx 4 | 5 | from rs.policy import initialize_policy 6 | from rs.environment import Environment 7 | from rs.utilities import load_model, visualize 8 | 9 | 10 | def ant_environment(): 11 | """Return Ant learning environment.""" 12 | 13 | # load MuJoCo model + data for MuJoCo C | MuJoCo XLA 14 | mc, dc, mx, dx = load_model("ant") 15 | 16 | # sizes 17 | naction = mx.nu 18 | 
nobservation = mx.nq - 2 + mx.nv 19 | ndecimation = 5 20 | 21 | # healthy 22 | def is_healthy(m: mjx.Model, d: mjx.Data) -> bool: 23 | z = d.qpos[2] 24 | 25 | min_z, max_z = 0.2, 1.0 26 | 27 | healthy_z = jnp.logical_and(min_z < z, z < max_z) 28 | 29 | return healthy_z 30 | 31 | 32 | def reward(m: mjx.Model, d0: mjx.Data, d1: mjx.Data, dt: float) -> float: 33 | # https://github.com/openai/gym/blob/master/gym/envs/mujoco/walker2d_v4.py 34 | 35 | # healthy 36 | r_healthy = is_healthy(m, d1).astype(float) 37 | 38 | # forward 39 | r_forward = (d1.qpos[0] - d0.qpos[0]) / dt 40 | 41 | # control penalty 42 | r_control = jnp.sum(jnp.square(d1.ctrl)) 43 | 44 | return 1.0 * r_healthy + 1.0 * r_forward - 0.5 * r_control 45 | 46 | 47 | def observation(m: mjx.Model, d: mjx.Data) -> jax.Array: 48 | return jnp.hstack([d.qpos[2:], jnp.clip(d.qvel, -10.0, 10.0)]) 49 | 50 | 51 | def reset(m: mjx.Model, d: mjx.Data, rng) -> mjx.Data: 52 | # key 53 | key_pos, key_vel = jax.random.split(rng) 54 | 55 | # qpos 56 | dqpos = 0.1 * jax.random.uniform( 57 | key_pos, shape=(m.nq,), minval=-1.0, maxval=1.0 58 | ) 59 | qpos = m.qpos0 60 | qpos = qpos.at[:3].set(qpos[:3] + dqpos[:3]) 61 | qpos = qpos.at[7:].set(qpos[7:] + dqpos[7:]) 62 | 63 | # qvel 64 | qvel = jnp.clip(jax.random.normal( 65 | key_vel, 66 | shape=(m.nv,), 67 | ), -0.1, 1.0) 68 | 69 | # update data 70 | d = d.replace(qpos=qpos, qvel=qvel) 71 | 72 | return d 73 | 74 | 75 | def done(m: mjx.Model, d: mjx.Data) -> bool: 76 | return jnp.where(is_healthy(m, d), 0, 1) 77 | 78 | 79 | env = Environment( 80 | model=mx, 81 | data=dx, 82 | reward=reward, 83 | observation=observation, 84 | reset=reset, 85 | done=done, 86 | naction=naction, 87 | nobservation=nobservation, 88 | ndecimation=ndecimation, 89 | ) 90 | 91 | # policy 92 | limits = (mc.actuator_ctrlrange[:, 0], mc.actuator_ctrlrange[:, 1]) 93 | p = initialize_policy(env.naction, env.nobservation, limits) 94 | 95 | # search settings 96 | settings = {"nsample": 60, 97 | "ntop": 20, 98 | "niter": 1000, 99 | "neval": 100, 100 | "nhorizon_search": 1000, 101 | "nhorizon_eval": 1000, 102 | "random_step": 0.025, 103 | "update_step": 0.015, 104 | "nenveval": 128, 105 | "reward_shift": 1.0} 106 | 107 | # tracking camera 108 | def lookat(viewer, data): 109 | viewer.cam.lookat[0] = data.qpos[0] 110 | viewer.cam.lookat[1] = data.qpos[1] 111 | viewer.cam.lookat[2] = data.qpos[2] - 0.35 112 | viewer.cam.distance = 4.0 113 | 114 | # visualize 115 | def vis(p): 116 | visualize(mc, dc, env, p, lookat=lookat) 117 | 118 | return env, p, settings, vis 119 | 120 | 121 | if __name__ == "__main__": 122 | # environment 123 | env, p, settings, vis = ant_environment() 124 | 125 | # visualize 126 | vis(p) 127 | -------------------------------------------------------------------------------- /rs/envs/cheetah.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from mujoco import mjx 4 | 5 | from rs.policy import initialize_policy 6 | from rs.environment import Environment 7 | from rs.utilities import load_model, visualize 8 | 9 | 10 | def cheetah_environment(): 11 | """Return Half Cheetah learning environment.""" 12 | # load MuJoCo model + data for MuJoCo C | MuJoCo XLA 13 | mc, dc, mx, dx = load_model("cheetah") 14 | 15 | # sizes 16 | naction = mx.nu 17 | nobservation = mx.nq - 1 + mx.nv 18 | ndecimation = 5 19 | 20 | def reward(m: mjx.Model, d0: mjx.Data, d1: mjx.Data, dt: float) -> float: 21 | # 
https://github.com/openai/gym/blob/master/gym/envs/mujoco/half_cheetah_v4.py 22 | r_forward = (d1.qpos[0] - d0.qpos[0]) / dt 23 | r_control = -jnp.dot(d1.ctrl, d1.ctrl) 24 | return 1.0 * r_forward + 0.1 * r_control 25 | 26 | def observation(m: mjx.Model, d: mjx.Data) -> jax.Array: 27 | return jnp.hstack([d.qpos[1:], d.qvel]) 28 | 29 | def reset(m: mjx.Model, d: mjx.Data, rng) -> mjx.Data: 30 | # key 31 | key_pos, key_vel = jax.random.split(rng) 32 | 33 | # qpos 34 | is_limited = mx.jnt_limited == 1 35 | lower, upper = mx.jnt_range[is_limited].T 36 | qpos = jax.random.uniform(key_pos, shape=(mx.nq,), minval=-0.1, maxval=0.1) 37 | qclip = jnp.clip(qpos[is_limited], a_min=lower, a_max=upper) 38 | qpos = qpos.at[is_limited].set(qclip) 39 | 40 | # qvel 41 | qvel = jnp.clip( 42 | 0.1 * jax.random.normal(key_vel, shape=(mx.nv,)), a_min=-1.0, a_max=1.0 43 | ) 44 | 45 | # update data 46 | d = d.replace(qpos=qpos, qvel=qvel) 47 | 48 | return d 49 | 50 | def done(m: mjx.Model, d: mjx.Data) -> bool: 51 | return False 52 | 53 | env = Environment( 54 | model=mx, 55 | data=dx, 56 | reward=reward, 57 | observation=observation, 58 | reset=reset, 59 | done=done, 60 | naction=naction, 61 | nobservation=nobservation, 62 | ndecimation=ndecimation, 63 | ) 64 | 65 | # policy 66 | limits = (mc.actuator_ctrlrange[:, 0], mc.actuator_ctrlrange[:, 1]) 67 | p = initialize_policy(env.naction, env.nobservation, limits) 68 | 69 | # search settings 70 | settings = { 71 | "nsample": 32, 72 | "ntop": 4, 73 | "niter": 100, 74 | "neval": 10, 75 | "nhorizon_search": 1000, 76 | "nhorizon_eval": 1000, 77 | "random_step": 0.03, 78 | "update_step": 0.02, 79 | "nenveval": 128, 80 | "reward_shift": 0.0, 81 | } 82 | 83 | # tracking camera 84 | def lookat(viewer, data): 85 | viewer.cam.lookat[0] = data.qpos[0] 86 | viewer.cam.lookat[2] = data.qpos[1] + 0.5 87 | viewer.cam.distance = 4.0 88 | 89 | # visualize 90 | def vis(p): 91 | visualize(mc, dc, env, p, lookat=lookat) 92 | 93 | return env, p, settings, vis 94 | 95 | 96 | if __name__ == "__main__": 97 | # environment 98 | env, p, settings, vis = cheetah_environment() 99 | 100 | # visualize 101 | vis(p) 102 | -------------------------------------------------------------------------------- /rs/envs/humanoid.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from mujoco import mjx 4 | 5 | from rs.policy import initialize_policy 6 | from rs.environment import Environment 7 | from rs.utilities import load_model, visualize 8 | 9 | 10 | def humanoid_environment(): 11 | # load MuJoCo model + data for MuJoCo C | MuJoCo XLA 12 | mc, dc, mx, dx = load_model("humanoid") 13 | 14 | # sizes 15 | naction = mx.nu 16 | nobservation = mx.nq - 2 + mx.nv 17 | nobservation += (mx.nbody - 1) * 10 # d.cinert 18 | nobservation += (mx.nbody - 1) * 6 # d.cvel 19 | nobservation += mx.nv - 6 # d.qfrc_actuator 20 | ndecimation = 5 21 | 22 | # healthy 23 | def is_healthy(m: mjx.Model, d: mjx.Data) -> bool: 24 | z = d.qpos[2] 25 | 26 | min_z, max_z = 1.0, 2.0 27 | 28 | healthy_z = jnp.logical_and(min_z < z, z < max_z) 29 | 30 | return healthy_z 31 | 32 | 33 | def reward(m: mjx.Model, d0: mjx.Data, d1: mjx.Data, dt: float) -> float: 34 | # https://github.com/openai/gym/blob/master/gym/envs/mujoco/walker2d_v4.py 35 | 36 | # healthy 37 | r_healthy = is_healthy(m, d1).astype(float) 38 | 39 | # forward 40 | r_forward = (d1.qpos[0] - d0.qpos[0]) / dt 41 | 42 | # control penalty 43 | r_control = jnp.sum(jnp.square(d1.ctrl)) 44 | 45 | return 
5.0 * r_healthy + 1.25 * r_forward - 0.1 * r_control 46 | 47 | 48 | def observation(m: mjx.Model, d: mjx.Data) -> jax.Array: 49 | pos = d.qpos[2:] 50 | vel = d.qvel 51 | 52 | com_inertia = d.cinert[1:].flatten() 53 | com_velocity = d.cvel[1:].flatten() 54 | actuator_forces = d.qfrc_actuator[6:] 55 | 56 | return jnp.hstack([pos, vel, com_inertia, com_velocity, actuator_forces]) 57 | 58 | 59 | def reset(m: mjx.Model, d: mjx.Data, rng) -> mjx.Data: 60 | # key 61 | key_pos, key_vel = jax.random.split(rng) 62 | 63 | # qpos 64 | qpos = m.qpos0 + 0.01 * jax.random.uniform( 65 | key_pos, shape=(m.nq,), minval=-1.0, maxval=1.0 66 | ) 67 | 68 | # qvel 69 | qvel = 0.01 * jax.random.uniform( 70 | key_vel, 71 | shape=(m.nv,), minval=-1.0, maxval=1.0, 72 | ) 73 | 74 | # update data 75 | d = d.replace(qpos=qpos, qvel=qvel) 76 | 77 | return d 78 | 79 | 80 | def done(m: mjx.Model, d: mjx.Data) -> bool: 81 | return jnp.where(is_healthy(m, d), 0, 1) 82 | 83 | 84 | env = Environment( 85 | model=mx, 86 | data=dx, 87 | reward=reward, 88 | observation=observation, 89 | reset=reset, 90 | done=done, 91 | naction=naction, 92 | nobservation=nobservation, 93 | ndecimation=ndecimation, 94 | ) 95 | 96 | # policy 97 | limits = (mc.actuator_ctrlrange[:, 0], mc.actuator_ctrlrange[:, 1]) 98 | p = initialize_policy(env.naction, env.nobservation, limits) 99 | 100 | # search settings 101 | settings = { 102 | "nsample": 320, 103 | "ntop": 320, 104 | "niter": 1000, 105 | "neval": 100, 106 | "nhorizon_search": 1000, 107 | "nhorizon_eval": 1000, 108 | "random_step": 0.0075, 109 | "update_step": 0.02, 110 | "nenveval": 128, 111 | "reward_shift": 5.0, 112 | } 113 | 114 | # tracking camera 115 | def lookat(viewer, data): 116 | viewer.cam.lookat[0] = data.qpos[0] 117 | viewer.cam.lookat[1] = data.qpos[1] 118 | viewer.cam.lookat[2] = data.qpos[2] 119 | viewer.cam.distance = 4.0 120 | 121 | # visualize 122 | def vis(p): 123 | visualize(mc, dc, env, p, lookat=lookat) 124 | 125 | return env, p, settings, vis 126 | 127 | 128 | if __name__ == "__main__": 129 | # environment 130 | env, p, settings, vis = humanoid_environment() 131 | 132 | # visualize 133 | vis(p) 134 | -------------------------------------------------------------------------------- /rs/envs/walker.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | from mujoco import mjx 4 | 5 | from rs.policy import initialize_policy 6 | from rs.environment import Environment 7 | from rs.utilities import load_model, visualize 8 | 9 | def walker_environment(): 10 | # load MuJoCo model + data for MuJoCo C | MuJoCo XLA 11 | mc, dc, mx, dx = load_model("walker") 12 | 13 | # sizes 14 | naction = mx.nu 15 | nobservation = mx.nq - 1 + mx.nv 16 | ndecimation = 4 17 | 18 | # healthy 19 | def is_healthy(m: mjx.Model, d: mjx.Data) -> bool: 20 | z, angle = d.qpos[1:3] 21 | 22 | min_z, max_z = 0.8, 2.0 23 | min_angle, max_angle = -1.0, 1.0 24 | 25 | healthy_z = jnp.logical_and(min_z < z, z < max_z) 26 | healthy_angle = jnp.logical_and(min_angle < angle, angle < max_angle) 27 | 28 | return jnp.logical_and(healthy_z, healthy_angle) 29 | 30 | 31 | def reward(m: mjx.Model, d0: mjx.Data, d1: mjx.Data, dt: float) -> float: 32 | # https://github.com/openai/gym/blob/master/gym/envs/mujoco/walker2d_v4.py 33 | 34 | # healthy 35 | r_healthy = is_healthy(m, d1).astype(float) 36 | 37 | # forward 38 | r_forward = (d1.qpos[0] - d0.qpos[0]) / dt 39 | 40 | # control penalty 41 | r_control = -jnp.sum(jnp.square(d1.ctrl)) 42 | 43 | return 1.0 
* r_healthy + 1.0 * r_forward + 0.001 * r_control 44 | 45 | 46 | def observation(m: mjx.Model, d: mjx.Data) -> jax.Array: 47 | return jnp.hstack([d.qpos[1:], jnp.clip(d.qvel, -10.0, 10.0)]) 48 | 49 | 50 | def reset(m: mjx.Model, d: mjx.Data, rng) -> mjx.Data: 51 | # key 52 | key_pos, key_vel = jax.random.split(rng) 53 | 54 | # qpos 55 | qpos = m.qpos0 + 5.0e-3 * jax.random.uniform( 56 | key_pos, shape=(m.nq,), minval=-1.0, maxval=1.0 57 | ) 58 | 59 | # qvel 60 | qvel = jnp.zeros(m.nv) + 5.0e-3 * jax.random.uniform( 61 | key_vel, shape=(m.nv,), minval=-1.0, maxval=1.0 62 | ) 63 | 64 | # update data 65 | d = d.replace(qpos=qpos, qvel=qvel) 66 | 67 | return d 68 | 69 | 70 | def done(m: mjx.Model, d: mjx.Data) -> bool: 71 | return jnp.where(is_healthy(m, d), 0, 1) 72 | 73 | 74 | env = Environment( 75 | model=mx, 76 | data=dx, 77 | reward=reward, 78 | observation=observation, 79 | reset=reset, 80 | done=done, 81 | naction=naction, 82 | nobservation=nobservation, 83 | ndecimation=ndecimation, 84 | ) 85 | 86 | # policy 87 | limits = (mc.actuator_ctrlrange[:, 0], mc.actuator_ctrlrange[:, 1]) 88 | p = initialize_policy(env.naction, env.nobservation, limits) 89 | 90 | # search settings 91 | settings = { 92 | "nsample": 40, 93 | "ntop": 30, 94 | "niter": 1000, 95 | "neval": 100, 96 | "nhorizon_search": 1000, 97 | "nhorizon_eval": 1000, 98 | "random_step": 0.025, 99 | "update_step": 0.015, 100 | "nenveval": 128, 101 | "reward_shift": 1.0} 102 | 103 | # tracking camera 104 | def lookat(viewer, data): 105 | viewer.cam.lookat[0] = data.qpos[0] 106 | viewer.cam.lookat[2] = data.qpos[1] 107 | viewer.cam.distance = 4.0 108 | 109 | # visualize 110 | def vis(p): 111 | visualize(mc, dc, env, p, lookat=lookat) 112 | 113 | return env, p, settings, vis 114 | 115 | if __name__ == "__main__": 116 | # environment 117 | env, p, settings, vis = walker_environment() 118 | 119 | # visualize 120 | vis(p) 121 | -------------------------------------------------------------------------------- /rs/models/LICENSE_ant: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Philipp Moritz, The dm_control Authors 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in all 11 | copies or substantial portions of the Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 19 | SOFTWARE. 
-------------------------------------------------------------------------------- /rs/models/LICENSE_dm_control: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 
180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /rs/models/LICENSE_gymnasium: -------------------------------------------------------------------------------- 1 | The MIT License 2 | 3 | Copyright (c) 2016 OpenAI 4 | Copyright (c) 2022 Farama Foundation 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in 14 | all copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 22 | THE SOFTWARE. -------------------------------------------------------------------------------- /rs/models/README.md: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | Models are originally from `Gymnasium` [Release 1.0.0a2 (v5)](https://github.com/Farama-Foundation/Gymnasium/releases/tag/v1.0.0a2). 4 | 5 | Modifications are made for improved performance on accelerators with MuJoCo XLA (MJX). Visualization is changed to match [dm_control](https://github.com/google-deepmind/dm_control). 
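The specific modifications are model-dependent, but solver-related changes of this kind typically live in the MJCF `<option>` element. The snippet below is only a hypothetical illustration of the pattern, not the actual values used by these models:

```xml
<mujoco>
  <!-- hypothetical example: a Newton solver with few iterations tends to run well under MJX -->
  <option timestep="0.01" solver="Newton" iterations="4" ls_iterations="8"/>
</mujoco>
```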
6 | 
-------------------------------------------------------------------------------- /rs/models/ant.xml: -------------------------------------------------------------------------------- (XML markup omitted) 
-------------------------------------------------------------------------------- /rs/models/cheetah.xml: -------------------------------------------------------------------------------- (XML markup omitted) 
-------------------------------------------------------------------------------- /rs/models/common/materials.xml: -------------------------------------------------------------------------------- (XML markup omitted) 
-------------------------------------------------------------------------------- /rs/models/common/skybox.xml: -------------------------------------------------------------------------------- (XML markup omitted) 
-------------------------------------------------------------------------------- /rs/models/common/visual.xml: -------------------------------------------------------------------------------- (XML markup omitted) 
-------------------------------------------------------------------------------- /rs/models/humanoid.xml: -------------------------------------------------------------------------------- (XML markup omitted) 
-------------------------------------------------------------------------------- /rs/models/walker.xml: -------------------------------------------------------------------------------- (XML markup omitted) 
-------------------------------------------------------------------------------- /rs/policy.py: -------------------------------------------------------------------------------- 
1 | import jax 
2 | import jax.numpy as jnp 
3 | from mujoco.mjx._src import dataclasses 
4 | from typing import Tuple 
5 | 
6 | 
7 | class Policy(dataclasses.PyTreeNode): 
8 | """Affine feedback policy.
9 | 10 | output = clip(weight @ ((input - shift) / scale), limits[0], limits[1]) 11 | 12 | Attributes: 13 | weight: feedback matrix 14 | shift: input shift 15 | scale: input scaling 16 | limits: output limits 17 | """ 18 | 19 | weight: jax.Array 20 | shift: jax.Array 21 | scale: jax.Array 22 | limits: Tuple[jax.Array, jax.Array] 23 | 24 | 25 | def initialize_policy( 26 | nact: int, nobs: int, limits: Tuple[jax.Array, jax.Array] 27 | ) -> Policy: 28 | """Initialize policy. 29 | 30 | Args: 31 | nact (int): action dimension 32 | nobs (int): observation dimension 33 | limits (Tuple[jax.Array, jax.Array]): action limits 34 | 35 | Returns: 36 | Policy 37 | """ 38 | return Policy( 39 | weight=jnp.zeros((nact, nobs)), 40 | shift=jnp.zeros(nobs), 41 | scale=jnp.ones(nobs), 42 | limits=limits, 43 | ) 44 | 45 | 46 | def policy(p: Policy, obs: jax.Array) -> jax.Array: 47 | """Evaluate policy. 48 | 49 | Args: 50 | p (Policy) 51 | obs (jax.Array): input to policy 52 | 53 | Returns: 54 | jax.Array: output from policy 55 | """ 56 | return jnp.clip( 57 | p.weight @ ((obs - p.shift) / (p.scale + 1.0e-5)), 58 | a_min=p.limits[0], 59 | a_max=p.limits[1], 60 | ) 61 | 62 | 63 | def noisy_policy(p: Policy, scale: float, rng) -> Tuple[Policy, jax.Array]: 64 | """Sample noisy policy. 65 | 66 | Args: 67 | p (Policy) 68 | scale (float): scaling 69 | rng (jax.Array): JAX random number key 70 | 71 | Returns: 72 | Policy 73 | perturb (jax.Array): perturbation 74 | """ 75 | # sample noise: perturb ~ N(0, I) 76 | perturb = jax.random.normal(rng, shape=p.weight.shape) 77 | 78 | # copy policy 79 | noisy_policy = Policy( 80 | weight=p.weight.copy() + scale * perturb, 81 | shift=p.shift.copy(), 82 | scale=p.scale.copy(), 83 | limits=(p.limits[0], p.limits[1]), 84 | ) 85 | return noisy_policy, perturb 86 | 87 | 88 | # jit and vmap noisy_policy 89 | v_noisy_policy = jax.jit(jax.vmap(noisy_policy, in_axes=(None, 0, 0))) 90 | -------------------------------------------------------------------------------- /rs/search.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Tuple 3 | import jax 4 | import jax.numpy as jnp 5 | from mujoco.mjx._src import dataclasses 6 | import pathlib 7 | from pathlib import Path 8 | import pickle 9 | import time 10 | 11 | from rs.environment import Environment, v_rollout, v_rollout_eval 12 | from rs.policy import Policy, v_noisy_policy 13 | 14 | # running statistics: (mean, var, count) 15 | RunningStatistics = Tuple[jax.Array, jax.Array, int] 16 | 17 | 18 | class Search(dataclasses.PyTreeNode): 19 | """Search settings. 
20 | 21 | Attributes: 22 | nsample: number of samples to evaluate at each iteration (+, -) 23 | ntop: number of top samples used for update 24 | niter: number of iterations 25 | neval: number of evaluations 26 | nhorizon_search: number of steps to simulate policy for improvement with search 27 | nhorizon_eval: number of steps to simulate policy for evaluation 28 | random_step: search size; p +- step * N(0, I) 29 | update_step: parameter update size: W += step * update 30 | step_direction: directions for random_step 31 | nenveval: number of environments for evaluation 32 | reward_shift: value to subtract from per-timestep environment reward 33 | """ 34 | 35 | nsample: int 36 | ntop: int 37 | niter: int 38 | neval: int 39 | nhorizon_search: int 40 | nhorizon_eval: int 41 | random_step: float 42 | update_step: float 43 | step_direction: jax.Array 44 | nenveval: int 45 | reward_shift: float 46 | 47 | 48 | def initialize_search( 49 | nsample: int, 50 | ntop: int, 51 | niter: int, 52 | neval: int, 53 | nhorizon_search: int, 54 | nhorizon_eval: int, 55 | random_step: float, 56 | update_step: float, 57 | nenveval: int = 128, 58 | reward_shift: float = 0.0 59 | ) -> Search: 60 | """Create Search with settings. 61 | 62 | Args: 63 | nsample (int) 64 | ntop (int) 65 | niter (int) 66 | neval (int) 67 | nhorizon_search (int) 68 | nhorizon_eval (int) 69 | random_step (float) 70 | update_step (float) 71 | nenveval (int) 72 | reward_shift (float) 73 | 74 | Returns: 75 | Search 76 | """ 77 | # step direction 78 | step_direction = jnp.concatenate( 79 | [ 80 | random_step * jnp.ones(nsample), 81 | -random_step * jnp.ones(nsample), 82 | ] 83 | ) 84 | 85 | return Search( 86 | nsample=nsample, 87 | ntop=ntop, 88 | niter=niter, 89 | neval=neval, 90 | nhorizon_search=nhorizon_search, 91 | nhorizon_eval=nhorizon_eval, 92 | random_step=random_step, 93 | update_step=update_step, 94 | step_direction=step_direction, 95 | nenveval=nenveval, 96 | reward_shift=reward_shift, 97 | ) 98 | 99 | 100 | @partial(jax.jit, static_argnums=(5,)) 101 | def search( 102 | s: Search, 103 | env: Environment, 104 | p: Policy, 105 | obs_stats: RunningStatistics, 106 | rng: jax.Array, 107 | iter: int = 1, 108 | ) -> Tuple[Policy, RunningStatistics, Tuple[jax.Array, jax.Array]]: 109 | """Improve policy with random search. 
110 | 111 | Returns: 112 | Policy 113 | obs_stats (RunningStatistics) 114 | reward_stats (Tuple[jax.Array, jax.Array, jax.Array]): reward statistics (mean, std, count) 115 | """ 116 | 117 | ## iteration 118 | def iteration(carry, _): 119 | # unpack 120 | rng, p, obs_stats = carry 121 | 122 | # random 123 | keys = jax.random.split(rng, 3 * s.nsample + 1) 124 | rng = keys[-1] 125 | key_perturb = keys[:s.nsample] 126 | key_policy_positive_negative = jnp.concatenate([key_perturb, key_perturb]) 127 | 128 | # noisy policies (and perturbations) 129 | policy_noisy, policy_perturb = v_noisy_policy( 130 | p, s.step_direction, key_policy_positive_negative 131 | ) 132 | 133 | # reset 134 | key_reset = keys[s.nsample:-1] 135 | d_random = jax.vmap(env.reset, in_axes=(None, None, 0))( 136 | env.model, env.data, key_reset 137 | ) 138 | 139 | # rollout noisy policies 140 | rewards, obs_stats_rollout = v_rollout(env, policy_noisy, d_random, s.reward_shift, s.nhorizon_search) 141 | 142 | # collect running statistics from environments 143 | def merge_running_statistics(stats0, stats1): 144 | # https://github.com/a-mitani/welford/blob/b7f96b9ad5e803d6de665c7df1cdcfb2a53bddc8/welford/welford.py#L132 145 | mean0, var0, count0 = stats0 146 | mean1, var1, count1 = stats1 147 | 148 | count = count0 + count1 149 | 150 | delta = mean0 - mean1 151 | delta2 = delta * delta 152 | 153 | mean = (count0 * mean0 + count1 * mean1) / count 154 | var = var0 + var1 + delta2 * count0 * count1 / count 155 | 156 | return (mean, var, count), count1 157 | 158 | obs_stats_updated, counts = jax.lax.scan( 159 | merge_running_statistics, obs_stats, obs_stats_rollout 160 | ) 161 | 162 | # collect reward pairs 163 | rewards_pos_neg = jnp.vstack([rewards[: s.nsample], rewards[s.nsample :]]) 164 | 165 | # reward pair max 166 | rewards_pos_neg_max = jnp.max(rewards_pos_neg, axis=0) 167 | 168 | # sort reward pairs descending, keep first ntop 169 | sort = jnp.argsort(rewards_pos_neg_max, descending=True)[: s.ntop] 170 | 171 | # ntop best pairs 172 | rewards_best = rewards_pos_neg[:, sort] 173 | 174 | # best pair statistics 175 | rewards_best_mean = rewards_best.flatten().mean() 176 | rewards_best_std = rewards_best.flatten().std() 177 | rewards_best_std = jnp.where(rewards_best_std < 1.0e-7, float("inf"), rewards_best_std) 178 | 179 | # new weights 180 | # https://arxiv.org/pdf/1803.07055.pdf: algorithm 2 181 | weight_update = jnp.einsum( 182 | "i,ijk->jk", 183 | jnp.dot(jnp.array([1.0, -1.0]), rewards_best), 184 | policy_perturb[sort], 185 | ) 186 | weight = ( 187 | p.weight 188 | + s.update_step / s.ntop / (rewards_best_std + 1.0e-5) * weight_update 189 | ) 190 | 191 | # update policy 192 | mean, var, count = obs_stats_updated 193 | std = jnp.where(count > 1, jnp.sqrt(var / (count - 1)), 1.0) 194 | std = jnp.where(std < 1.0e-7, float("inf"), std) 195 | p = p.replace( 196 | weight=weight, 197 | shift=mean, 198 | scale=std, 199 | ) 200 | 201 | return (rng, p, obs_stats_updated), ( 202 | rewards_best_mean, 203 | rewards_best_std, 204 | jnp.mean(counts), 205 | ) 206 | 207 | # loop 208 | initial_count = obs_stats[2] 209 | carry, reward_stats = jax.lax.scan( 210 | iteration, 211 | ( 212 | rng, 213 | p, 214 | obs_stats, 215 | ), 216 | None, 217 | length=iter, 218 | ) 219 | policy = carry[1] 220 | obs_stats_update = carry[2] 221 | env_steps = obs_stats_update[2] - initial_count 222 | return policy, obs_stats_update, reward_stats, env_steps 223 | 224 | 225 | def eval_search( 226 | s: Search, 227 | env: Environment, 228 | p: Policy, 229 | seed: int = 
0, 230 | checkpoint: str = None, 231 | ) -> Policy: 232 | """Improve policy with random search and provide evaluation information during training. 233 | 234 | Returns 235 | Policy 236 | """ 237 | print("Search:") 238 | 239 | # vmap reset 240 | v_reset = jax.jit(jax.vmap(env.reset, in_axes=(None, None, 0))) 241 | 242 | # create logs directory for checkpoints 243 | if checkpoint is not None: 244 | # directory 245 | checkpoint_dir = pathlib.Path(__file__).parent / "checkpoint" 246 | 247 | # create directory 248 | Path(str(checkpoint_dir)).mkdir(parents=True, exist_ok=True) 249 | 250 | # create subdirectory 251 | Path(str(checkpoint_dir) + "/" + checkpoint).mkdir(parents=True, exist_ok=True) 252 | 253 | 254 | # start total timer 255 | start_total = time.time() 256 | 257 | # initialize key from seed 258 | key = jax.random.PRNGKey(seed) 259 | 260 | # initialize observation statistics 261 | obs_stats = ( 262 | jnp.zeros(env.nobservation), 263 | jnp.ones(env.nobservation), 264 | jnp.ones(1).astype(int), 265 | ) 266 | 267 | # evaluation iterations 268 | iter_per_eval = int(s.niter / s.neval) 269 | for i in range(s.neval): 270 | # random 271 | rng, key = jax.random.split(key) 272 | 273 | # search 274 | start = time.time() 275 | p, obs_stats, _, search_steps = search( 276 | s, env, p, obs_stats, rng, iter=iter_per_eval 277 | ) 278 | 279 | # stop search timer 280 | search_time = time.time() - start 281 | 282 | ## evaluate 283 | # random 284 | keys = jax.random.split(rng, s.nenveval + 2) 285 | key_eval_rollout = keys[:-2] 286 | rng, key = keys[-2:] 287 | 288 | # environment reset 289 | d_random = v_reset(env.model, env.data, key_eval_rollout) 290 | 291 | # rollout current 292 | rewards_eval, obs_stats_eval = v_rollout_eval(env, p, d_random, 0.0, s.nhorizon_eval) 293 | 294 | # stats 295 | reward_mean = rewards_eval.mean() 296 | reward_std = rewards_eval.std() 297 | avg_episode_len = int(obs_stats_eval[2].mean()) - 1 298 | env_steps = int(obs_stats[2][0]) - 1 299 | print( 300 | f"iteration ({(i + 1) * iter_per_eval} / {s.niter}): reward = {reward_mean:.2f} +- {reward_std:.2f} | time = {search_time:.2f} | avg episode length: {avg_episode_len} / {s.nhorizon_eval} | global steps: {env_steps} | steps/second: {int(search_steps[0]/search_time)}" 301 | ) 302 | 303 | # save checkpoint 304 | if checkpoint is not None: 305 | checkpoint_path = ( 306 | str(checkpoint_dir) 307 | + "/" 308 | + checkpoint 309 | + "/" 310 | + checkpoint 311 | + "_" 312 | + str(i) 313 | + "_" 314 | + "{:.2f}".format(reward_mean) 315 | + "_" 316 | + "{:.2f}".format(reward_std) 317 | ) 318 | with open(checkpoint_path, "wb") as file: 319 | pickle.dump(p, file) 320 | 321 | # total time 322 | print(f"\ntotal time: {time.time() - start_total:.2f}") 323 | 324 | return p 325 | -------------------------------------------------------------------------------- /rs/train.py: -------------------------------------------------------------------------------- 1 | from rs.search import initialize_search, eval_search 2 | from rs.utilities import load_policy, parse 3 | 4 | # import environments 5 | from rs.envs.ant import ant_environment 6 | from rs.envs.cheetah import cheetah_environment 7 | from rs.envs.humanoid import humanoid_environment 8 | from rs.envs.walker import walker_environment 9 | 10 | 11 | def train_settings(env, settings): 12 | """Print training information.""" 13 | print("Settings:") 14 | print(f" environment: {env}") 15 | print(f" nsample: {settings['nsample']} | ntop: {settings['ntop']}") 16 | print(f" niter: {settings['niter']} | neval: 
{settings['neval']}") 17 | print( 18 | f" nhorizon_search: {settings['nhorizon_search']} | nhorizon_eval: {settings['nhorizon_eval']}" 19 | ) 20 | print( 21 | f" random_step: {settings['random_step']} | update_step: {settings['update_step']}" 22 | ) 23 | print(f" nenveval: {settings['nenveval']}") 24 | print(f" reward_shift: {settings['reward_shift']}") 25 | 26 | 27 | def train(): 28 | """Train linear policy with random search.""" 29 | # parse settings 30 | args = parse() 31 | 32 | # environment 33 | env, p, settings, vis = eval(str(args.env + "_environment()")) 34 | 35 | # load policy 36 | if args.load != "": 37 | p = load_policy(p, args.load) 38 | 39 | # search 40 | if args.search: 41 | # update settings with parsed arguments 42 | for k, v in vars(args).items(): 43 | if k in settings and v is not None: 44 | settings[k] = v 45 | 46 | # settings 47 | train_settings(args.env, settings) 48 | 49 | # initialize 50 | s = initialize_search(**settings) 51 | 52 | # search + evaluation 53 | p = eval_search(s, env, p, seed=args.seed, checkpoint=args.checkpoint) 54 | 55 | # visualize policy 56 | if args.visualize: 57 | vis(p) 58 | 59 | 60 | if __name__ == "__main__": 61 | train() 62 |
-------------------------------------------------------------------------------- /rs/utilities.py: --------------------------------------------------------------------------------
1 | from typing import Callable, Tuple 2 | import argparse 3 | import mujoco 4 | import mujoco.viewer 5 | from mujoco import mjx 6 | import pathlib 7 | import pickle 8 | import time 9 | 10 | from rs.environment import Environment 11 | from rs.policy import Policy, policy 12 | 13 | def load_model(model: str) -> Tuple[mujoco.MjModel, mujoco.MjData, mjx.Model, mjx.Data]: 14 | """Load MuJoCo and MuJoCo XLA (MJX) model and data.""" 15 | # path to model 16 | path = pathlib.Path(__file__).parent / str("models/" + model + ".xml") 17 | 18 | # mjc model + data 19 | m = mujoco.MjModel.from_xml_path(str(path)) 20 | d = mujoco.MjData(m) 21 | 22 | # mjx model + data 23 | m_mjx = mjx.put_model(m) 24 | d_mjx = mjx.put_data(m, d) 25 | 26 | return m, d, m_mjx, d_mjx 27 | 28 | 29 | def load_policy(p: Policy, file: str) -> Policy: 30 | """Load policy from checkpoint.""" 31 | checkpoint_dir = pathlib.Path(__file__).parent / "checkpoint" 32 | checkpoint_path = str(checkpoint_dir) + "/" + file 33 | try: 34 | with open(checkpoint_path, 'rb') as f: 35 | p = pickle.load(f) 36 | 37 | print(f"Success: load policy from: {checkpoint_path}") 38 | except Exception as e: 39 | print(f"Failure: load policy from: {checkpoint_path} ({e})") 40 | 41 | return p 42 | 43 | 44 | def visualize( 45 | m: mujoco.MjModel, 46 | d: mujoco.MjData, 47 | env: Environment, 48 | p: Policy, 49 | lookat: Callable = lambda viewer, data: None, 50 | ): 51 | """Visualize learned policy with MuJoCo passive viewer.""" 52 | # visualize policy 53 | with mujoco.viewer.launch_passive(m, d) as viewer: 54 | while viewer.is_running(): 55 | # start timer 56 | t0 = time.time() 57 | 58 | # observations 59 | obs = env.observation(m, d) 60 | 61 | # compute and set actions 62 | d.ctrl = policy(p, obs) 63 | 64 | # simulate 65 | for _ in range(env.ndecimation): 66 | # step physics 67 | mujoco.mj_step(m, d) 68 | 69 | # camera tracking 70 | lookat(viewer, d) 71 | 72 | # sync visualization 73 | viewer.sync() 74 | 75 | # wait 76 | elapsed = time.time() - t0 77 | time.sleep(max(m.opt.timestep * env.ndecimation - elapsed, 0.0)) 78 | 79 | 80 | def parse(): 81 | """Parse command line arguments for environment and search settings.""" 82 | # parser 83 | parser = argparse.ArgumentParser() 84 | 85 | # search setup 86 | parser.add_argument("--env", type=str, choices=["ant", "cheetah", "humanoid", "walker"], default="cheetah", help="Learning environment (default: cheetah)") 87 | parser.add_argument("--search", action="store_true", help="Random search to find policy") 88 | parser.add_argument("--load", type=str, default="", help='Path to saved policy, relative to rs/checkpoint/ (default: "")') 89 | parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint name; evaluated policies are saved to rs/checkpoint/<name>/ (default: None)") 90 | parser.add_argument("--seed", type=int, default=0, help="Random seed (default: 0)") 91 | parser.add_argument("--visualize", action="store_true", help="Visualize policy") 92 | 93 | # search settings 94 | parser.add_argument("--nsample", type=int, default=None, help="Number of random directions to sample") 95 | parser.add_argument("--ntop", type=int, default=None, help="Number of random directions to use for policy update") 96 | parser.add_argument("--niter", type=int, default=None, help="Number of policy updates") 97 | parser.add_argument("--neval", type=int, default=None, help="Number of policy evaluations during search") 98 | parser.add_argument("--nhorizon_search", type=int, default=None, help="Number of environment steps during policy search") 99 | parser.add_argument("--nhorizon_eval", type=int, default=None, help="Number of environment steps during policy evaluation") 100 | parser.add_argument("--random_step", type=float, default=None, help="Step size for random direction") 101 | parser.add_argument("--update_step", type=float, default=None, help="Step size for policy update") 102 | parser.add_argument("--nenveval", type=int, default=None, help="Number of environments for policy evaluation") 103 | parser.add_argument("--reward_shift", type=float, default=None, help="Subtract from per-timestep reward") 104 | 105 | # args 106 | args = parser.parse_args() 107 | 108 | return args 109 |
--------------------------------------------------------------------------------
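The flags defined in `parse` feed the `settings` dictionary that `train` passes to `initialize_search`, so individual search hyperparameters can be overridden from the command line. A hedged example invocation (the flag values here are illustrative, and it assumes the package has been installed with `pip install -e .`):

```sh
# override a few search settings and save checkpoints under rs/checkpoint/walker_run/
python -m rs.train --env walker --search --nsample 512 --ntop 64 --checkpoint walker_run --visualize
```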
/setup.py: --------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="rs", 5 | version="0.1", 6 | packages=find_packages(), 7 | install_requires=[ 8 | "mujoco", 9 | "mujoco-mjx", 10 | "jax==0.4.29", 11 | "jaxlib @ https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.29+cuda12.cudnn91-cp310-cp310-manylinux2014_x86_64.whl", 12 | ], 13 | include_package_data=True, 14 | description="A simple JAX-based implementation of random search for locomotion tasks using MuJoCo XLA (MJX).", 15 | long_description=open("README.md").read(), 16 | url="https://github.com/thowell/rs", 17 | author="Taylor Howell", 18 | author_email="taylor.athaniel.howell@gmail.com", 19 | python_requires=">=3.10", 20 | ) 21 |
--------------------------------------------------------------------------------
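A saved checkpoint is simply a pickled `Policy` (see `eval_search` and `load_policy` above). Below is a minimal sketch, not part of the repository, for loading a checkpoint and querying the linear policy outside of training; the working directory (repository root) and the zero observation are illustrative assumptions, and the pretrained checkpoint path comes from the repository tree — that it unpickles to a `Policy` like the checkpoints written by `eval_search` is assumed here.

```python
# Sketch only (not part of the repository): load a pickled Policy and compute an action.
import pickle

import jax.numpy as jnp

from rs.policy import policy

# assumption: run from the repository root; any checkpoint produced by eval_search works
with open("rs/checkpoint/pretrained/cheetah_4293_784", "rb") as f:
    p = pickle.load(f)

# placeholder observation with the policy's observation dimension;
# during rollouts this would come from env.observation(m, d)
obs = jnp.zeros_like(p.shift)

# action from the linear policy
ctrl = policy(p, obs)
print(ctrl.shape)
```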