├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── cheetah.gif
├── rs
│   ├── __init__.py
│   ├── checkpoint
│   │   └── pretrained
│   │       └── cheetah_4293_784
│   ├── environment.py
│   ├── envs
│   │   ├── ant.py
│   │   ├── cheetah.py
│   │   ├── humanoid.py
│   │   └── walker.py
│   ├── models
│   │   ├── LICENSE_ant
│   │   ├── LICENSE_dm_control
│   │   ├── LICENSE_gymnasium
│   │   ├── README.md
│   │   ├── ant.xml
│   │   ├── cheetah.xml
│   │   ├── common
│   │   │   ├── materials.xml
│   │   │   ├── skybox.xml
│   │   │   └── visual.xml
│   │   ├── humanoid.xml
│   │   └── walker.xml
│   ├── policy.py
│   ├── search.py
│   ├── train.py
│   └── utilities.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | # MuJoCo
163 | MUJOCO_LOG.TXT
164 |
165 | # macOS
166 | .DS_Store
167 |
168 | # checkpoints
169 | checkpoint/
170 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Taylor Howell
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Random Search
2 | A simple [JAX](https://github.com/google/jax)-based implementation of [random search](https://arxiv.org/abs/1803.07055) for [locomotion tasks](https://github.com/openai/gym/tree/master/gym/envs/mujoco) using [MuJoCo XLA (MJX)](https://mujoco.readthedocs.io/en/stable/mjx.html).
3 |
4 | ## Installation
5 | Clone the repository:
6 | ```sh
7 | git clone https://github.com/thowell/rs
8 | ```
9 |
10 | Optionally, create a conda environment:
11 | ```sh
12 | conda create -n rs python=3.10
13 | conda activate rs
14 | ```
15 |
16 | pip install:
17 | ```sh
18 | pip install -e .
19 | ```
20 |
21 | ## Train cheetah
22 | Train cheetah in ~1 minute with [Nvidia RTX 4090](https://www.nvidia.com/en-us/geforce/graphics-cards/40-series/rtx-4090/) on [Ubuntu 22.04.4 LTS](https://releases.ubuntu.com/jammy/).
23 |
24 |
25 | ![cheetah](assets/cheetah.gif)
26 | Run:
27 | ```sh
28 | python rs/train.py --env cheetah --search --visualize --nsample 2048 --ntop 512 --niter 50 --neval 5 --nhorizon_search 200 --nhorizon_eval 1000 --random_step 0.1 --update_step 0.1
29 | ```
30 |
31 | Output:
32 | ```
33 | Settings:
34 | environment: cheetah
35 | nsample: 2048 | ntop: 512
36 | niter: 50 | neval: 5
37 | nhorizon_search: 200 | nhorizon_eval: 1000
38 | random_step: 0.1 | update_step: 0.1
39 | nenveval: 128
40 | reward_shift: 0.0
41 | Search:
42 | iteration (10 / 50): reward = 1172.42 +- 1144.11 | time = 17.52 | avg episode length: 1000 / 1000 | global steps: 8232960 | steps/second: 470022
43 | iteration (20 / 50): reward = 2947.71 +- 1237.87 | time = 5.58 | avg episode length: 1000 / 1000 | global steps: 16465920 | steps/second: 1474670
44 | iteration (30 / 50): reward = 3152.07 +- 1401.50 | time = 5.58 | avg episode length: 1000 / 1000 | global steps: 24698880 | steps/second: 1475961
45 | iteration (40 / 50): reward = 4175.49 +- 783.41 | time = 5.59 | avg episode length: 1000 / 1000 | global steps: 32931840 | steps/second: 1472244
46 | iteration (50 / 50): reward = 4293.36 +- 784.80 | time = 5.59 | avg episode length: 1000 / 1000 | global steps: 41164800 | steps/second: 1473380
47 |
48 | total time: 56.43
49 | ```
50 |
51 | The pretrained policy can be visualized in MuJoCo's passive viewer:
52 | ```sh
53 | python train.py --env cheetah --load pretrained/cheetah --visualize
54 | ```
55 |
56 | ## Environments
57 | Environments available:
58 |
59 | - [Ant](rs/envs/ant.py)
60 | - based on [ant_v5](https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/mujoco/ant_v5.py)
61 | - modified solver settings
62 | - only contact between feet and floor
63 | - no rewards or observations dependent on contact forces
64 | - [Cheetah](rs/envs/cheetah.py)
65 | - based on [half_cheetah_v5](https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/mujoco/half_cheetah_v5.py)
66 | - modified solver settings
67 | - [Humanoid](rs/envs/humanoid.py)
68 | - based on [humanoid_v5](https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/mujoco/humanoid_v5.py)
69 | - modified solver settings
70 | - only contact between feet and floor
71 | - no rewards or observations dependent on contact forces
72 | - [Walker](rs/envs/walker.py)
73 | - based on [walker2d_v5](https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/envs/mujoco/walker2d_v5.py)
74 | - modified solver settings
75 | - only contact between feet and floor
76 |
77 |
78 | ## Usage
79 | **Note**: random search is stochastic; run training multiple times to find good policies.
80 |
81 | First, change to `rs/` directory:
82 | ```sh
83 | cd rs
84 | ```
85 |
86 | ### Ant
87 | Search:
88 | ```sh
89 | python train.py --env ant --search
90 | ```
91 |
92 | Visualize policy checkpoint:
93 | ```sh
94 | python train.py --env ant --load pretrained/ant --visualize
95 | ```
96 |
97 | ### Cheetah
98 | Search:
99 | ```sh
100 | python train.py --env cheetah --search
101 | ```
102 |
103 | Visualize policy checkpoint:
104 | ```sh
105 | python train.py --env cheetah --load pretrained/cheetah --visualize
106 | ```
107 |
108 | ### Humanoid
109 | Search:
110 | ```sh
111 | python train.py --env humanoid --search
112 | ```
113 |
114 | Visualize policy checkpoint:
115 | ```sh
116 | python train.py --env humanoid --load pretrained/humanoid --visualize
117 | ```
118 |
119 | ### Walker
120 | Search:
121 | ```sh
122 | python train.py --env walker --search
123 | ```
124 |
125 | Visualize policy checkpoint:
126 | ```sh
127 | python train.py --env walker --load pretrained/walker --visualize
128 | ```
129 |
130 | ### Command line arguments
131 | Setup:
132 | - `--env`: `ant`, `cheetah`, `humanoid`, `walker`
133 | - `--search`: run random search to improve policy
134 | - `--checkpoint`: filename in `checkpoint/` to save policy
135 | - `--load`: filename in the `checkpoint/`
136 | directory of a saved policy to load
137 | - `--seed`: int for random number generation
138 | - `--visualize`: visualize policy
139 |
140 | Search settings:
141 | - `--nsample`: number of random directions to sample
142 | - `--ntop`: number of random directions to use for policy update
143 | - `--niter`: number of policy updates
144 | - `--neval`: number of policy evaluations during search
145 | - `--nhorizon_search`: number of environment steps during policy improvement
146 | - `--nhorizon_eval`: number of environment steps during policy evaluation
147 | - `--random_step`: step size for random direction during policy perturbation
148 | - `--update_step`: step size for policy update during policy improvement
149 | - `--nenveval`: number of environments for policy evaluation
150 | - `--reward_shift`: subtract baseline from per-timestep reward
151 |
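For example, a walker search run might combine these flags as follows (the checkpoint name `walker_run` is arbitrary; the numeric values mirror the defaults in [rs/envs/walker.py](rs/envs/walker.py), and the command assumes you are in the `rs/` directory as above):

```sh
python train.py --env walker --search --seed 0 --checkpoint walker_run --nsample 40 --ntop 30 --niter 1000 --neval 100 --nhorizon_search 1000 --nhorizon_eval 1000 --random_step 0.025 --update_step 0.015
```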
152 | ## Mapping notation from the paper to code
153 | $\alpha$: `update_step`
154 |
155 | $\nu$: `random_step`
156 |
157 | $N$: `nsample`
158 |
159 | $b$: `ntop`
160 |
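For reference, a minimal sketch of the update in this notation, assuming a generic `evaluate` function that rolls out a policy weight matrix and returns its total reward (the actual implementation in [rs/search.py](rs/search.py) additionally maintains running observation statistics and is written with JAX transformations):

```python
import numpy as np

def ars_update(weight, evaluate, nsample, ntop, random_step, update_step):
    """One basic random search update of the (naction, nobservation) weight matrix."""
    # sample N random directions
    deltas = np.random.randn(nsample, *weight.shape)
    # evaluate +/- nu-scaled perturbations for each direction
    r_pos = np.array([evaluate(weight + random_step * d) for d in deltas])
    r_neg = np.array([evaluate(weight - random_step * d) for d in deltas])
    # keep the b directions with the largest max(r+, r-)
    top = np.argsort(np.maximum(r_pos, r_neg))[::-1][:ntop]
    # reward standard deviation over the selected pairs
    sigma = np.concatenate([r_pos[top], r_neg[top]]).std() + 1e-5
    # weight += alpha / (b * sigma) * sum_i (r+_i - r-_i) * delta_i
    update = np.einsum("i,ijk->jk", r_pos[top] - r_neg[top], deltas[top])
    return weight + update_step / (ntop * sigma) * update
```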
161 | ## Notes
162 | - The environments are based on the [v5 MuJoCo Gym environments](https://github.com/Farama-Foundation/Gymnasium/tree/main/gymnasium/envs/mujoco) but may not be exact in all details.
163 | - The search settings are based on [Simple random search provides a competitive approach to reinforcement learning: Table 9](https://arxiv.org/abs/1803.07055) but may not be exact in all details either.
164 |
165 | This repository was developed to:
166 | - understand the [Augmented Random Search](https://arxiv.org/abs/1803.07055) algorithm
167 | - understand how to compute numerically stable running statistics
168 | - understand the details of [Gym environments](https://github.com/openai/gym)
169 | - experiment with code generation tools that are useful for improving development times, including: [ChatGPT](https://chatgpt.com/) and [Claude](https://claude.ai/)
170 | - gain experience with [MuJoCo XLA (MJX)](https://mujoco.readthedocs.io/en/stable/mjx.html)
171 | - gain experience with [JAX](https://github.com/google/jax)
172 |
173 | MuJoCo models use resources from [Gymnasium](https://github.com/Farama-Foundation/Gymnasium/tree/main/gymnasium/envs/mujoco) and [dm_control](https://github.com/google-deepmind/dm_control).
174 |
--------------------------------------------------------------------------------
/assets/cheetah.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thowell/rs/f71a8af39ed7c2ab351a614174d8173ab654dd6d/assets/cheetah.gif
--------------------------------------------------------------------------------
/rs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thowell/rs/f71a8af39ed7c2ab351a614174d8173ab654dd6d/rs/__init__.py
--------------------------------------------------------------------------------
/rs/checkpoint/pretrained/cheetah_4293_784:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thowell/rs/f71a8af39ed7c2ab351a614174d8173ab654dd6d/rs/checkpoint/pretrained/cheetah_4293_784
--------------------------------------------------------------------------------
/rs/environment.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | import jax
3 | import jax.numpy as jnp
4 | from mujoco import mjx
5 | from mujoco.mjx._src import dataclasses, forward
6 |
7 | from rs.policy import Policy, policy
8 | from typing import Callable, Tuple
9 |
10 |
11 | class Environment(dataclasses.PyTreeNode):
12 | """Learning environment.
13 |
14 | Attributes:
15 | model: MuJoCo model; data: MuJoCo data
16 | reward: function returning per-timestep reward
17 | observation: function returning per-timestep observation
18 | reset: function resetting environment, returns mjx.Data
19 | done: function determining environment termination
20 | naction: number of actions
21 | nobservation: number of observations
22 | ndecimation: number of physics steps per policy step
23 | """
24 |
25 | model: mjx.Model
26 | data: mjx.Data
27 | reward: Callable
28 | observation: Callable
29 | reset: Callable
30 | done: Callable
31 | naction: int
32 | nobservation: int
33 | ndecimation: int
34 |
35 |
36 | @partial(jax.jit, static_argnums=(3,))
37 | def multistep(
38 | model: mjx.Model, data: mjx.Data, action: jax.Array, nstep: int
39 | ) -> mjx.Data:
40 | """Multiple physics steps for action.
41 |
42 | Args:
43 | model (mjx.Model)
44 | data (mjx.Data)
45 | action (jax.Array)
46 | nstep (int): number of physics steps.
47 |
48 | Returns:
49 | mjx.Data
50 | """
51 | # set action
52 | data = data.replace(ctrl=action)
53 |
54 | # step physics
55 | def step(d, _):
56 | # step dynamics
57 | d = forward.step(model, d)
58 |
59 | return d, None
60 |
61 | # next data
62 | data, _ = jax.lax.scan(step, data, None, length=nstep, unroll=nstep)
63 | return data
64 |
65 | def rollout(
66 | env: Environment,
67 | p: Policy,
68 | d: mjx.Data,
69 | shift: float = 0.0,
70 | nhorizon: int = 1000,
71 | ) -> Tuple[float, Tuple[jax.Array, jax.Array, jax.Array]]:
72 | """Simulate environment.
73 |
74 | Args:
75 | env (Environment): simulation environment
76 | p (Policy): affine feedback policy
77 | d (mjx.Data): MuJoCo data
78 | shift (float): subtract value from per-timestep reward
79 | nhorizon (int): number of environment steps
80 |
81 | Returns:
82 | Tuple[float, Tuple[jax.Array, jax.Array, jax.Array]]: total reward and
83 | observation statistics (mean, var, count)
84 | """
85 | # get observation
86 | obs = env.observation(env.model, d)
87 |
88 | # initialize observation statistics
89 | # (mean, var, count)
90 | obs_stats_init = (obs.copy(), jnp.zeros_like(obs), 1)
91 |
92 | # continue physics steps
93 | def continue_step(carry):
94 | # unpack carry
95 | data, total_reward, obs, obs_stats, steps, done = carry
96 | return jnp.logical_and(jnp.where(done, False, True), steps < nhorizon)
97 |
98 | # step
99 | def step(carry):
100 | # unpack carry
101 | data, total_reward, obs, obs_stats, steps, done = carry
102 |
103 | # get action
104 | action = policy(p, obs)
105 |
106 | # step
107 | next_data = multistep(env.model, data, action, env.ndecimation)
108 |
109 | # done
110 | next_done = jnp.logical_or(env.done(env.model, next_data), done)
111 | not_done = jnp.where(next_done, 0, 1)
112 |
113 | # get reward
114 | reward = env.reward(
115 | env.model, data, next_data, env.model.opt.timestep * env.ndecimation
116 | )
117 | reward -= shift
118 | reward *= not_done
119 |
120 | # get observation
121 | next_obs = env.observation(env.model, next_data)
122 |
123 | # unpack observation statistics
124 | mean, var, count = obs_stats
125 |
126 | # update observation statistics
127 | delta = next_obs - mean
128 | next_count = count + not_done
129 | next_mean = mean + not_done * delta / next_count
130 | next_var = var + not_done * delta * delta * count / next_count
131 | next_obs_stats = (next_mean, next_var, next_count)
132 |
133 | return (
134 | next_data,
135 | total_reward + reward,
136 | next_obs,
137 | next_obs_stats,
138 | steps + 1,
139 | next_done,
140 | )
141 |
142 | # loop
143 | carry = jax.lax.while_loop(continue_step, step, (d, 0.0, obs, obs_stats_init, 0, False))
144 |
145 | total_reward = carry[1]
146 | obs_stats = carry[3]
147 |
148 | return total_reward, obs_stats
149 |
150 |
151 | # jit and vmap rollout: v_rollout maps over (policy, data) pairs (perturbed policies); v_rollout_eval maps a single policy over a batch of data (evaluation)
152 | v_rollout = jax.jit(jax.vmap(rollout, in_axes=(None, 0, 0, None, None)))
153 | v_rollout_eval = jax.jit(jax.vmap(rollout, in_axes=(None, None, 0, None, None)))
154 |
--------------------------------------------------------------------------------
/rs/envs/ant.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from mujoco import mjx
4 |
5 | from rs.policy import initialize_policy
6 | from rs.environment import Environment
7 | from rs.utilities import load_model, visualize
8 |
9 |
10 | def ant_environment():
11 | """Return Ant learning environment."""
12 |
13 | # load MuJoCo model + data for MuJoCo C | MuJoCo XLA
14 | mc, dc, mx, dx = load_model("ant")
15 |
16 | # sizes
17 | naction = mx.nu
18 | nobservation = mx.nq - 2 + mx.nv
19 | ndecimation = 5
20 |
21 | # healthy
22 | def is_healthy(m: mjx.Model, d: mjx.Data) -> bool:
23 | z = d.qpos[2]
24 |
25 | min_z, max_z = 0.2, 1.0
26 |
27 | healthy_z = jnp.logical_and(min_z < z, z < max_z)
28 |
29 | return healthy_z
30 |
31 |
32 | def reward(m: mjx.Model, d0: mjx.Data, d1: mjx.Data, dt: float) -> float:
33 | # https://github.com/openai/gym/blob/master/gym/envs/mujoco/ant_v4.py
34 |
35 | # healthy
36 | r_healthy = is_healthy(m, d1).astype(float)
37 |
38 | # forward
39 | r_forward = (d1.qpos[0] - d0.qpos[0]) / dt
40 |
41 | # control penalty
42 | r_control = jnp.sum(jnp.square(d1.ctrl))
43 |
44 | return 1.0 * r_healthy + 1.0 * r_forward - 0.5 * r_control
45 |
46 |
47 | def observation(m: mjx.Model, d: mjx.Data) -> jax.Array:
48 | return jnp.hstack([d.qpos[2:], jnp.clip(d.qvel, -10.0, 10.0)])
49 |
50 |
51 | def reset(m: mjx.Model, d: mjx.Data, rng) -> mjx.Data:
52 | # key
53 | key_pos, key_vel = jax.random.split(rng)
54 |
55 | # qpos
56 | dqpos = 0.1 * jax.random.uniform(
57 | key_pos, shape=(m.nq,), minval=-1.0, maxval=1.0
58 | )
59 | qpos = m.qpos0
60 | qpos = qpos.at[:3].set(qpos[:3] + dqpos[:3])
61 | qpos = qpos.at[7:].set(qpos[7:] + dqpos[7:])
62 |
63 | # qvel
64 | qvel = jnp.clip(jax.random.normal(
65 | key_vel,
66 | shape=(m.nv,),
67 | ), -0.1, 1.0)
68 |
69 | # update data
70 | d = d.replace(qpos=qpos, qvel=qvel)
71 |
72 | return d
73 |
74 |
75 | def done(m: mjx.Model, d: mjx.Data) -> bool:
76 | return jnp.where(is_healthy(m, d), 0, 1)
77 |
78 |
79 | env = Environment(
80 | model=mx,
81 | data=dx,
82 | reward=reward,
83 | observation=observation,
84 | reset=reset,
85 | done=done,
86 | naction=naction,
87 | nobservation=nobservation,
88 | ndecimation=ndecimation,
89 | )
90 |
91 | # policy
92 | limits = (mc.actuator_ctrlrange[:, 0], mc.actuator_ctrlrange[:, 1])
93 | p = initialize_policy(env.naction, env.nobservation, limits)
94 |
95 | # search settings
96 | settings = {"nsample": 60,
97 | "ntop": 20,
98 | "niter": 1000,
99 | "neval": 100,
100 | "nhorizon_search": 1000,
101 | "nhorizon_eval": 1000,
102 | "random_step": 0.025,
103 | "update_step": 0.015,
104 | "nenveval": 128,
105 | "reward_shift": 1.0}
106 |
107 | # tracking camera
108 | def lookat(viewer, data):
109 | viewer.cam.lookat[0] = data.qpos[0]
110 | viewer.cam.lookat[1] = data.qpos[1]
111 | viewer.cam.lookat[2] = data.qpos[2] - 0.35
112 | viewer.cam.distance = 4.0
113 |
114 | # visualize
115 | def vis(p):
116 | visualize(mc, dc, env, p, lookat=lookat)
117 |
118 | return env, p, settings, vis
119 |
120 |
121 | if __name__ == "__main__":
122 | # environment
123 | env, p, settings, vis = ant_environment()
124 |
125 | # visualize
126 | vis(p)
127 |
--------------------------------------------------------------------------------
/rs/envs/cheetah.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from mujoco import mjx
4 |
5 | from rs.policy import initialize_policy
6 | from rs.environment import Environment
7 | from rs.utilities import load_model, visualize
8 |
9 |
10 | def cheetah_environment():
11 | """Return Half Cheetah learning environment."""
12 | # load MuJoCo model + data for MuJoCo C | MuJoCo XLA
13 | mc, dc, mx, dx = load_model("cheetah")
14 |
15 | # sizes
16 | naction = mx.nu
17 | nobservation = mx.nq - 1 + mx.nv
18 | ndecimation = 5
19 |
20 | def reward(m: mjx.Model, d0: mjx.Data, d1: mjx.Data, dt: float) -> float:
21 | # https://github.com/openai/gym/blob/master/gym/envs/mujoco/half_cheetah_v4.py
22 | r_forward = (d1.qpos[0] - d0.qpos[0]) / dt
23 | r_control = -jnp.dot(d1.ctrl, d1.ctrl)
24 | return 1.0 * r_forward + 0.1 * r_control
25 |
26 | def observation(m: mjx.Model, d: mjx.Data) -> jax.Array:
27 | return jnp.hstack([d.qpos[1:], d.qvel])
28 |
29 | def reset(m: mjx.Model, d: mjx.Data, rng) -> mjx.Data:
30 | # key
31 | key_pos, key_vel = jax.random.split(rng)
32 |
33 | # qpos
34 | is_limited = mx.jnt_limited == 1
35 | lower, upper = mx.jnt_range[is_limited].T
36 | qpos = jax.random.uniform(key_pos, shape=(mx.nq,), minval=-0.1, maxval=0.1)
37 | qclip = jnp.clip(qpos[is_limited], a_min=lower, a_max=upper)
38 | qpos = qpos.at[is_limited].set(qclip)
39 |
40 | # qvel
41 | qvel = jnp.clip(
42 | 0.1 * jax.random.normal(key_vel, shape=(mx.nv,)), a_min=-1.0, a_max=1.0
43 | )
44 |
45 | # update data
46 | d = d.replace(qpos=qpos, qvel=qvel)
47 |
48 | return d
49 |
50 | def done(m: mjx.Model, d: mjx.Data) -> bool:
51 | return False
52 |
53 | env = Environment(
54 | model=mx,
55 | data=dx,
56 | reward=reward,
57 | observation=observation,
58 | reset=reset,
59 | done=done,
60 | naction=naction,
61 | nobservation=nobservation,
62 | ndecimation=ndecimation,
63 | )
64 |
65 | # policy
66 | limits = (mc.actuator_ctrlrange[:, 0], mc.actuator_ctrlrange[:, 1])
67 | p = initialize_policy(env.naction, env.nobservation, limits)
68 |
69 | # search settings
70 | settings = {
71 | "nsample": 32,
72 | "ntop": 4,
73 | "niter": 100,
74 | "neval": 10,
75 | "nhorizon_search": 1000,
76 | "nhorizon_eval": 1000,
77 | "random_step": 0.03,
78 | "update_step": 0.02,
79 | "nenveval": 128,
80 | "reward_shift": 0.0,
81 | }
82 |
83 | # tracking camera
84 | def lookat(viewer, data):
85 | viewer.cam.lookat[0] = data.qpos[0]
86 | viewer.cam.lookat[2] = data.qpos[1] + 0.5
87 | viewer.cam.distance = 4.0
88 |
89 | # visualize
90 | def vis(p):
91 | visualize(mc, dc, env, p, lookat=lookat)
92 |
93 | return env, p, settings, vis
94 |
95 |
96 | if __name__ == "__main__":
97 | # environment
98 | env, p, settings, vis = cheetah_environment()
99 |
100 | # visualize
101 | vis(p)
102 |
--------------------------------------------------------------------------------
/rs/envs/humanoid.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from mujoco import mjx
4 |
5 | from rs.policy import initialize_policy
6 | from rs.environment import Environment
7 | from rs.utilities import load_model, visualize
8 |
9 |
10 | def humanoid_environment():
11 | # load MuJoCo model + data for MuJoCo C | MuJoCo XLA
12 | mc, dc, mx, dx = load_model("humanoid")
13 |
14 | # sizes
15 | naction = mx.nu
16 | nobservation = mx.nq - 2 + mx.nv
17 | nobservation += (mx.nbody - 1) * 10 # d.cinert
18 | nobservation += (mx.nbody - 1) * 6 # d.cvel
19 | nobservation += mx.nv - 6 # d.qfrc_actuator
20 | ndecimation = 5
21 |
22 | # healthy
23 | def is_healthy(m: mjx.Model, d: mjx.Data) -> bool:
24 | z = d.qpos[2]
25 |
26 | min_z, max_z = 1.0, 2.0
27 |
28 | healthy_z = jnp.logical_and(min_z < z, z < max_z)
29 |
30 | return healthy_z
31 |
32 |
33 | def reward(m: mjx.Model, d0: mjx.Data, d1: mjx.Data, dt: float) -> float:
34 | # https://github.com/openai/gym/blob/master/gym/envs/mujoco/humanoid_v4.py
35 |
36 | # healthy
37 | r_healthy = is_healthy(m, d1).astype(float)
38 |
39 | # forward
40 | r_forward = (d1.qpos[0] - d0.qpos[0]) / dt
41 |
42 | # control penalty
43 | r_control = jnp.sum(jnp.square(d1.ctrl))
44 |
45 | return 5.0 * r_healthy + 1.25 * r_forward - 0.1 * r_control
46 |
47 |
48 | def observation(m: mjx.Model, d: mjx.Data) -> jax.Array:
49 | pos = d.qpos[2:]
50 | vel = d.qvel
51 |
52 | com_inertia = d.cinert[1:].flatten()
53 | com_velocity = d.cvel[1:].flatten()
54 | actuator_forces = d.qfrc_actuator[6:]
55 |
56 | return jnp.hstack([pos, vel, com_inertia, com_velocity, actuator_forces])
57 |
58 |
59 | def reset(m: mjx.Model, d: mjx.Data, rng) -> mjx.Data:
60 | # key
61 | key_pos, key_vel = jax.random.split(rng)
62 |
63 | # qpos
64 | qpos = m.qpos0 + 0.01 * jax.random.uniform(
65 | key_pos, shape=(m.nq,), minval=-1.0, maxval=1.0
66 | )
67 |
68 | # qvel
69 | qvel = 0.01 * jax.random.uniform(
70 | key_vel,
71 | shape=(m.nv,), minval=-1.0, maxval=1.0,
72 | )
73 |
74 | # update data
75 | d = d.replace(qpos=qpos, qvel=qvel)
76 |
77 | return d
78 |
79 |
80 | def done(m: mjx.Model, d: mjx.Data) -> bool:
81 | return jnp.where(is_healthy(m, d), 0, 1)
82 |
83 |
84 | env = Environment(
85 | model=mx,
86 | data=dx,
87 | reward=reward,
88 | observation=observation,
89 | reset=reset,
90 | done=done,
91 | naction=naction,
92 | nobservation=nobservation,
93 | ndecimation=ndecimation,
94 | )
95 |
96 | # policy
97 | limits = (mc.actuator_ctrlrange[:, 0], mc.actuator_ctrlrange[:, 1])
98 | p = initialize_policy(env.naction, env.nobservation, limits)
99 |
100 | # search settings
101 | settings = {
102 | "nsample": 320,
103 | "ntop": 320,
104 | "niter": 1000,
105 | "neval": 100,
106 | "nhorizon_search": 1000,
107 | "nhorizon_eval": 1000,
108 | "random_step": 0.0075,
109 | "update_step": 0.02,
110 | "nenveval": 128,
111 | "reward_shift": 5.0,
112 | }
113 |
114 | # tracking camera
115 | def lookat(viewer, data):
116 | viewer.cam.lookat[0] = data.qpos[0]
117 | viewer.cam.lookat[1] = data.qpos[1]
118 | viewer.cam.lookat[2] = data.qpos[2]
119 | viewer.cam.distance = 4.0
120 |
121 | # visualize
122 | def vis(p):
123 | visualize(mc, dc, env, p, lookat=lookat)
124 |
125 | return env, p, settings, vis
126 |
127 |
128 | if __name__ == "__main__":
129 | # environment
130 | env, p, settings, vis = humanoid_environment()
131 |
132 | # visualize
133 | vis(p)
134 |
--------------------------------------------------------------------------------
/rs/envs/walker.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from mujoco import mjx
4 |
5 | from rs.policy import initialize_policy
6 | from rs.environment import Environment
7 | from rs.utilities import load_model, visualize
8 |
9 | def walker_environment():
10 | # load MuJoCo model + data for MuJoCo C | MuJoCo XLA
11 | mc, dc, mx, dx = load_model("walker")
12 |
13 | # sizes
14 | naction = mx.nu
15 | nobservation = mx.nq - 1 + mx.nv
16 | ndecimation = 4
17 |
18 | # healthy
19 | def is_healthy(m: mjx.Model, d: mjx.Data) -> bool:
20 | z, angle = d.qpos[1:3]
21 |
22 | min_z, max_z = 0.8, 2.0
23 | min_angle, max_angle = -1.0, 1.0
24 |
25 | healthy_z = jnp.logical_and(min_z < z, z < max_z)
26 | healthy_angle = jnp.logical_and(min_angle < angle, angle < max_angle)
27 |
28 | return jnp.logical_and(healthy_z, healthy_angle)
29 |
30 |
31 | def reward(m: mjx.Model, d0: mjx.Data, d1: mjx.Data, dt: float) -> float:
32 | # https://github.com/openai/gym/blob/master/gym/envs/mujoco/walker2d_v4.py
33 |
34 | # healthy
35 | r_healthy = is_healthy(m, d1).astype(float)
36 |
37 | # forward
38 | r_forward = (d1.qpos[0] - d0.qpos[0]) / dt
39 |
40 | # control penalty
41 | r_control = -jnp.sum(jnp.square(d1.ctrl))
42 |
43 | return 1.0 * r_healthy + 1.0 * r_forward + 0.001 * r_control
44 |
45 |
46 | def observation(m: mjx.Model, d: mjx.Data) -> jax.Array:
47 | return jnp.hstack([d.qpos[1:], jnp.clip(d.qvel, -10.0, 10.0)])
48 |
49 |
50 | def reset(m: mjx.Model, d: mjx.Data, rng) -> mjx.Data:
51 | # key
52 | key_pos, key_vel = jax.random.split(rng)
53 |
54 | # qpos
55 | qpos = m.qpos0 + 5.0e-3 * jax.random.uniform(
56 | key_pos, shape=(m.nq,), minval=-1.0, maxval=1.0
57 | )
58 |
59 | # qvel
60 | qvel = jnp.zeros(m.nv) + 5.0e-3 * jax.random.uniform(
61 | key_vel, shape=(m.nv,), minval=-1.0, maxval=1.0
62 | )
63 |
64 | # update data
65 | d = d.replace(qpos=qpos, qvel=qvel)
66 |
67 | return d
68 |
69 |
70 | def done(m: mjx.Model, d: mjx.Data) -> bool:
71 | return jnp.where(is_healthy(m, d), 0, 1)
72 |
73 |
74 | env = Environment(
75 | model=mx,
76 | data=dx,
77 | reward=reward,
78 | observation=observation,
79 | reset=reset,
80 | done=done,
81 | naction=naction,
82 | nobservation=nobservation,
83 | ndecimation=ndecimation,
84 | )
85 |
86 | # policy
87 | limits = (mc.actuator_ctrlrange[:, 0], mc.actuator_ctrlrange[:, 1])
88 | p = initialize_policy(env.naction, env.nobservation, limits)
89 |
90 | # search settings
91 | settings = {
92 | "nsample": 40,
93 | "ntop": 30,
94 | "niter": 1000,
95 | "neval": 100,
96 | "nhorizon_search": 1000,
97 | "nhorizon_eval": 1000,
98 | "random_step": 0.025,
99 | "update_step": 0.015,
100 | "nenveval": 128,
101 | "reward_shift": 1.0}
102 |
103 | # tracking camera
104 | def lookat(viewer, data):
105 | viewer.cam.lookat[0] = data.qpos[0]
106 | viewer.cam.lookat[2] = data.qpos[1]
107 | viewer.cam.distance = 4.0
108 |
109 | # visualize
110 | def vis(p):
111 | visualize(mc, dc, env, p, lookat=lookat)
112 |
113 | return env, p, settings, vis
114 |
115 | if __name__ == "__main__":
116 | # environment
117 | env, p, settings, vis = walker_environment()
118 |
119 | # visualize
120 | vis(p)
121 |
--------------------------------------------------------------------------------
/rs/models/LICENSE_ant:
--------------------------------------------------------------------------------
1 | Copyright (c) 2020 Philipp Moritz, The dm_control Authors
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
--------------------------------------------------------------------------------
/rs/models/LICENSE_dm_control:
--------------------------------------------------------------------------------
1 |
2 | Apache License
3 | Version 2.0, January 2004
4 | http://www.apache.org/licenses/
5 |
6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7 |
8 | 1. Definitions.
9 |
10 | "License" shall mean the terms and conditions for use, reproduction,
11 | and distribution as defined by Sections 1 through 9 of this document.
12 |
13 | "Licensor" shall mean the copyright owner or entity authorized by
14 | the copyright owner that is granting the License.
15 |
16 | "Legal Entity" shall mean the union of the acting entity and all
17 | other entities that control, are controlled by, or are under common
18 | control with that entity. For the purposes of this definition,
19 | "control" means (i) the power, direct or indirect, to cause the
20 | direction or management of such entity, whether by contract or
21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
22 | outstanding shares, or (iii) beneficial ownership of such entity.
23 |
24 | "You" (or "Your") shall mean an individual or Legal Entity
25 | exercising permissions granted by this License.
26 |
27 | "Source" form shall mean the preferred form for making modifications,
28 | including but not limited to software source code, documentation
29 | source, and configuration files.
30 |
31 | "Object" form shall mean any form resulting from mechanical
32 | transformation or translation of a Source form, including but
33 | not limited to compiled object code, generated documentation,
34 | and conversions to other media types.
35 |
36 | "Work" shall mean the work of authorship, whether in Source or
37 | Object form, made available under the License, as indicated by a
38 | copyright notice that is included in or attached to the work
39 | (an example is provided in the Appendix below).
40 |
41 | "Derivative Works" shall mean any work, whether in Source or Object
42 | form, that is based on (or derived from) the Work and for which the
43 | editorial revisions, annotations, elaborations, or other modifications
44 | represent, as a whole, an original work of authorship. For the purposes
45 | of this License, Derivative Works shall not include works that remain
46 | separable from, or merely link (or bind by name) to the interfaces of,
47 | the Work and Derivative Works thereof.
48 |
49 | "Contribution" shall mean any work of authorship, including
50 | the original version of the Work and any modifications or additions
51 | to that Work or Derivative Works thereof, that is intentionally
52 | submitted to Licensor for inclusion in the Work by the copyright owner
53 | or by an individual or Legal Entity authorized to submit on behalf of
54 | the copyright owner. For the purposes of this definition, "submitted"
55 | means any form of electronic, verbal, or written communication sent
56 | to the Licensor or its representatives, including but not limited to
57 | communication on electronic mailing lists, source code control systems,
58 | and issue tracking systems that are managed by, or on behalf of, the
59 | Licensor for the purpose of discussing and improving the Work, but
60 | excluding communication that is conspicuously marked or otherwise
61 | designated in writing by the copyright owner as "Not a Contribution."
62 |
63 | "Contributor" shall mean Licensor and any individual or Legal Entity
64 | on behalf of whom a Contribution has been received by Licensor and
65 | subsequently incorporated within the Work.
66 |
67 | 2. Grant of Copyright License. Subject to the terms and conditions of
68 | this License, each Contributor hereby grants to You a perpetual,
69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70 | copyright license to reproduce, prepare Derivative Works of,
71 | publicly display, publicly perform, sublicense, and distribute the
72 | Work and such Derivative Works in Source or Object form.
73 |
74 | 3. Grant of Patent License. Subject to the terms and conditions of
75 | this License, each Contributor hereby grants to You a perpetual,
76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77 | (except as stated in this section) patent license to make, have made,
78 | use, offer to sell, sell, import, and otherwise transfer the Work,
79 | where such license applies only to those patent claims licensable
80 | by such Contributor that are necessarily infringed by their
81 | Contribution(s) alone or by combination of their Contribution(s)
82 | with the Work to which such Contribution(s) was submitted. If You
83 | institute patent litigation against any entity (including a
84 | cross-claim or counterclaim in a lawsuit) alleging that the Work
85 | or a Contribution incorporated within the Work constitutes direct
86 | or contributory patent infringement, then any patent licenses
87 | granted to You under this License for that Work shall terminate
88 | as of the date such litigation is filed.
89 |
90 | 4. Redistribution. You may reproduce and distribute copies of the
91 | Work or Derivative Works thereof in any medium, with or without
92 | modifications, and in Source or Object form, provided that You
93 | meet the following conditions:
94 |
95 | (a) You must give any other recipients of the Work or
96 | Derivative Works a copy of this License; and
97 |
98 | (b) You must cause any modified files to carry prominent notices
99 | stating that You changed the files; and
100 |
101 | (c) You must retain, in the Source form of any Derivative Works
102 | that You distribute, all copyright, patent, trademark, and
103 | attribution notices from the Source form of the Work,
104 | excluding those notices that do not pertain to any part of
105 | the Derivative Works; and
106 |
107 | (d) If the Work includes a "NOTICE" text file as part of its
108 | distribution, then any Derivative Works that You distribute must
109 | include a readable copy of the attribution notices contained
110 | within such NOTICE file, excluding those notices that do not
111 | pertain to any part of the Derivative Works, in at least one
112 | of the following places: within a NOTICE text file distributed
113 | as part of the Derivative Works; within the Source form or
114 | documentation, if provided along with the Derivative Works; or,
115 | within a display generated by the Derivative Works, if and
116 | wherever such third-party notices normally appear. The contents
117 | of the NOTICE file are for informational purposes only and
118 | do not modify the License. You may add Your own attribution
119 | notices within Derivative Works that You distribute, alongside
120 | or as an addendum to the NOTICE text from the Work, provided
121 | that such additional attribution notices cannot be construed
122 | as modifying the License.
123 |
124 | You may add Your own copyright statement to Your modifications and
125 | may provide additional or different license terms and conditions
126 | for use, reproduction, or distribution of Your modifications, or
127 | for any such Derivative Works as a whole, provided Your use,
128 | reproduction, and distribution of the Work otherwise complies with
129 | the conditions stated in this License.
130 |
131 | 5. Submission of Contributions. Unless You explicitly state otherwise,
132 | any Contribution intentionally submitted for inclusion in the Work
133 | by You to the Licensor shall be under the terms and conditions of
134 | this License, without any additional terms or conditions.
135 | Notwithstanding the above, nothing herein shall supersede or modify
136 | the terms of any separate license agreement you may have executed
137 | with Licensor regarding such Contributions.
138 |
139 | 6. Trademarks. This License does not grant permission to use the trade
140 | names, trademarks, service marks, or product names of the Licensor,
141 | except as required for reasonable and customary use in describing the
142 | origin of the Work and reproducing the content of the NOTICE file.
143 |
144 | 7. Disclaimer of Warranty. Unless required by applicable law or
145 | agreed to in writing, Licensor provides the Work (and each
146 | Contributor provides its Contributions) on an "AS IS" BASIS,
147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148 | implied, including, without limitation, any warranties or conditions
149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150 | PARTICULAR PURPOSE. You are solely responsible for determining the
151 | appropriateness of using or redistributing the Work and assume any
152 | risks associated with Your exercise of permissions under this License.
153 |
154 | 8. Limitation of Liability. In no event and under no legal theory,
155 | whether in tort (including negligence), contract, or otherwise,
156 | unless required by applicable law (such as deliberate and grossly
157 | negligent acts) or agreed to in writing, shall any Contributor be
158 | liable to You for damages, including any direct, indirect, special,
159 | incidental, or consequential damages of any character arising as a
160 | result of this License or out of the use or inability to use the
161 | Work (including but not limited to damages for loss of goodwill,
162 | work stoppage, computer failure or malfunction, or any and all
163 | other commercial damages or losses), even if such Contributor
164 | has been advised of the possibility of such damages.
165 |
166 | 9. Accepting Warranty or Additional Liability. While redistributing
167 | the Work or Derivative Works thereof, You may choose to offer,
168 | and charge a fee for, acceptance of support, warranty, indemnity,
169 | or other liability obligations and/or rights consistent with this
170 | License. However, in accepting such obligations, You may act only
171 | on Your own behalf and on Your sole responsibility, not on behalf
172 | of any other Contributor, and only if You agree to indemnify,
173 | defend, and hold each Contributor harmless for any liability
174 | incurred by, or claims asserted against, such Contributor by reason
175 | of your accepting any such warranty or additional liability.
176 |
177 | END OF TERMS AND CONDITIONS
178 |
179 | APPENDIX: How to apply the Apache License to your work.
180 |
181 | To apply the Apache License to your work, attach the following
182 | boilerplate notice, with the fields enclosed by brackets "[]"
183 | replaced with your own identifying information. (Don't include
184 | the brackets!) The text should be enclosed in the appropriate
185 | comment syntax for the file format. We also recommend that a
186 | file or class name and description of purpose be included on the
187 | same "printed page" as the copyright notice for easier
188 | identification within third-party archives.
189 |
190 | Copyright [yyyy] [name of copyright owner]
191 |
192 | Licensed under the Apache License, Version 2.0 (the "License");
193 | you may not use this file except in compliance with the License.
194 | You may obtain a copy of the License at
195 |
196 | http://www.apache.org/licenses/LICENSE-2.0
197 |
198 | Unless required by applicable law or agreed to in writing, software
199 | distributed under the License is distributed on an "AS IS" BASIS,
200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201 | See the License for the specific language governing permissions and
202 | limitations under the License.
--------------------------------------------------------------------------------
/rs/models/LICENSE_gymnasium:
--------------------------------------------------------------------------------
1 | The MIT License
2 |
3 | Copyright (c) 2016 OpenAI
4 | Copyright (c) 2022 Farama Foundation
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in
14 | all copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22 | THE SOFTWARE.
--------------------------------------------------------------------------------
/rs/models/README.md:
--------------------------------------------------------------------------------
1 | # Models
2 |
3 | Models are originally from `Gymnasium` [Release 1.0.0a2 (v5)](https://github.com/Farama-Foundation/Gymnasium/releases/tag/v1.0.0a2).
4 |
5 | Modifications are made for improved performance on accelerators with MuJoCo XLA (MJX). Visualization is changed to match [dm_control](https://github.com/google-deepmind/dm_control).
6 |
--------------------------------------------------------------------------------
/rs/models/ant.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/rs/models/cheetah.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/rs/models/common/materials.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/rs/models/common/skybox.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/rs/models/common/visual.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/rs/models/humanoid.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/rs/models/walker.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/rs/policy.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 | from mujoco.mjx._src import dataclasses
4 | from typing import Tuple
5 |
6 |
7 | class Policy(dataclasses.PyTreeNode):
8 | """Affine feedback policy.
9 |
10 | output = clip(weight @ ((input - shift) / scale), limits[0], limits[1])
11 |
12 | Attributes:
13 | weight: feedback matrix
14 | shift: input shift
15 | scale: input scaling
16 | limits: output limits
17 | """
18 |
19 | weight: jax.Array
20 | shift: jax.Array
21 | scale: jax.Array
22 | limits: Tuple[jax.Array, jax.Array]
23 |
24 |
25 | def initialize_policy(
26 | nact: int, nobs: int, limits: Tuple[jax.Array, jax.Array]
27 | ) -> Policy:
28 | """Initialize policy.
29 |
30 | Args:
31 | nact (int): action dimension
32 | nobs (int): observation dimension
33 | limits (Tuple[jax.Array, jax.Array]): action limits
34 |
35 | Returns:
36 | Policy
37 | """
38 | return Policy(
39 | weight=jnp.zeros((nact, nobs)),
40 | shift=jnp.zeros(nobs),
41 | scale=jnp.ones(nobs),
42 | limits=limits,
43 | )
44 |
45 |
46 | def policy(p: Policy, obs: jax.Array) -> jax.Array:
47 | """Evaluate policy.
48 |
49 | Args:
50 | p (Policy)
51 | obs (jax.Array): input to policy
52 |
53 | Returns:
54 | jax.Array: output from policy
55 | """
56 | return jnp.clip(
57 | p.weight @ ((obs - p.shift) / (p.scale + 1.0e-5)),
58 | a_min=p.limits[0],
59 | a_max=p.limits[1],
60 | )
61 |
62 |
63 | def noisy_policy(p: Policy, scale: float, rng) -> Tuple[Policy, jax.Array]:
64 | """Sample noisy policy.
65 |
66 | Args:
67 | p (Policy)
68 | scale (float): scaling
69 | rng (jax.Array): JAX random number key
70 |
71 | Returns:
72 | Policy
73 | perturb (jax.Array): perturbation
74 | """
75 | # sample noise: perturb ~ N(0, I)
76 | perturb = jax.random.normal(rng, shape=p.weight.shape)
77 |
78 | # copy policy
79 | noisy_policy = Policy(
80 | weight=p.weight.copy() + scale * perturb,
81 | shift=p.shift.copy(),
82 | scale=p.scale.copy(),
83 | limits=(p.limits[0], p.limits[1]),
84 | )
85 | return noisy_policy, perturb
86 |
87 |
88 | # jit and vmap noisy_policy
89 | v_noisy_policy = jax.jit(jax.vmap(noisy_policy, in_axes=(None, 0, 0)))
90 |
--------------------------------------------------------------------------------
/rs/search.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Tuple
3 | import jax
4 | import jax.numpy as jnp
5 | from mujoco.mjx._src import dataclasses
6 | import pathlib
7 | from pathlib import Path
8 | import pickle
9 | import time
10 |
11 | from rs.environment import Environment, v_rollout, v_rollout_eval
12 | from rs.policy import Policy, v_noisy_policy
13 |
14 | # running statistics: (mean, var, count)
15 | RunningStatistics = Tuple[jax.Array, jax.Array, int]
16 |
17 |
18 | class Search(dataclasses.PyTreeNode):
19 | """Search settings.
20 |
21 | Attributes:
22 | nsample: number of samples to evaluate at each iteration (+, -)
23 | ntop: number of top samples used for update
24 | niter: number of iterations
25 | neval: number of evaluations
26 | nhorizon_search: number of steps to simulate policy for improvement with search
27 | nhorizon_eval: number of steps to simulate policy for evaluation
28 | random_step: search size; p +- step * N(0, I)
29 | update_step: parameter update size: W += step * update
30 | step_direction: directions for random_step
31 | nenveval: number of environments for evaluation
32 | reward_shift: value to subtract from per-timestep environment reward
33 | """
34 |
35 | nsample: int
36 | ntop: int
37 | niter: int
38 | neval: int
39 | nhorizon_search: int
40 | nhorizon_eval: int
41 | random_step: float
42 | update_step: float
43 | step_direction: jax.Array
44 | nenveval: int
45 | reward_shift: float
46 |
47 |
48 | def initialize_search(
49 | nsample: int,
50 | ntop: int,
51 | niter: int,
52 | neval: int,
53 | nhorizon_search: int,
54 | nhorizon_eval: int,
55 | random_step: float,
56 | update_step: float,
57 | nenveval: int = 128,
58 | reward_shift: float = 0.0
59 | ) -> Search:
60 | """Create Search with settings.
61 |
62 | Args:
63 | nsample (int)
64 | ntop (int)
65 | niter (int)
66 | neval (int)
67 | nhorizon_search (int)
68 | nhorizon_eval (int)
69 | random_step (float)
70 | update_step (float)
71 | nenveval (int)
72 | reward_shift (float)
73 |
74 | Returns:
75 | Search
76 | """
77 | # step direction
78 | step_direction = jnp.concatenate(
79 | [
80 | random_step * jnp.ones(nsample),
81 | -random_step * jnp.ones(nsample),
82 | ]
83 | )
84 |
85 | return Search(
86 | nsample=nsample,
87 | ntop=ntop,
88 | niter=niter,
89 | neval=neval,
90 | nhorizon_search=nhorizon_search,
91 | nhorizon_eval=nhorizon_eval,
92 | random_step=random_step,
93 | update_step=update_step,
94 | step_direction=step_direction,
95 | nenveval=nenveval,
96 | reward_shift=reward_shift,
97 | )
98 |
99 |
100 | @partial(jax.jit, static_argnums=(5,))
101 | def search(
102 | s: Search,
103 | env: Environment,
104 | p: Policy,
105 | obs_stats: RunningStatistics,
106 | rng: jax.Array,
107 | iter: int = 1,
108 | ) -> Tuple[Policy, RunningStatistics, Tuple[jax.Array, jax.Array, jax.Array], jax.Array]:
109 | """Improve policy with random search.
110 |
111 | Returns:
112 | Policy
113 | obs_stats (RunningStatistics)
114 | reward_stats (Tuple[jax.Array, jax.Array, jax.Array]): reward statistics (mean, std, count); also env_steps (jax.Array): number of environment steps
115 | """
116 |
117 | ## iteration
118 | def iteration(carry, _):
119 | # unpack
120 | rng, p, obs_stats = carry
121 |
122 | # random
123 | keys = jax.random.split(rng, 3 * s.nsample + 1)
124 | rng = keys[-1]
125 | key_perturb = keys[:s.nsample]
126 | key_policy_positive_negative = jnp.concatenate([key_perturb, key_perturb])
127 |
128 | # noisy policies (and perturbations)
129 | policy_noisy, policy_perturb = v_noisy_policy(
130 | p, s.step_direction, key_policy_positive_negative
131 | )
132 |
133 | # reset
134 | key_reset = keys[s.nsample:-1]
135 | d_random = jax.vmap(env.reset, in_axes=(None, None, 0))(
136 | env.model, env.data, key_reset
137 | )
138 |
139 | # rollout noisy policies
140 | rewards, obs_stats_rollout = v_rollout(env, policy_noisy, d_random, s.reward_shift, s.nhorizon_search)
141 |
142 | # collect running statistics from environments
143 | def merge_running_statistics(stats0, stats1):
144 | # https://github.com/a-mitani/welford/blob/b7f96b9ad5e803d6de665c7df1cdcfb2a53bddc8/welford/welford.py#L132
145 | mean0, var0, count0 = stats0
146 | mean1, var1, count1 = stats1
147 |
148 | count = count0 + count1
149 |
150 | delta = mean0 - mean1
151 | delta2 = delta * delta
152 |
153 | mean = (count0 * mean0 + count1 * mean1) / count
154 | var = var0 + var1 + delta2 * count0 * count1 / count
155 |
156 | return (mean, var, count), count1
157 |
158 | obs_stats_updated, counts = jax.lax.scan(
159 | merge_running_statistics, obs_stats, obs_stats_rollout
160 | )
161 |
162 | # collect reward pairs
163 | rewards_pos_neg = jnp.vstack([rewards[: s.nsample], rewards[s.nsample :]])
164 |
165 | # reward pair max
166 | rewards_pos_neg_max = jnp.max(rewards_pos_neg, axis=0)
167 |
168 | # sort reward pairs descending, keep first ntop
169 | sort = jnp.argsort(rewards_pos_neg_max, descending=True)[: s.ntop]
170 |
171 | # ntop best pairs
172 | rewards_best = rewards_pos_neg[:, sort]
173 |
174 | # best pair statistics
175 | rewards_best_mean = rewards_best.flatten().mean()
176 | rewards_best_std = rewards_best.flatten().std()
177 | rewards_best_std = jnp.where(rewards_best_std < 1.0e-7, float("inf"), rewards_best_std)
178 |
179 |         # new weights: sum over the top directions of (reward+ - reward-) * perturbation,
180 |         # scaled by update_step / (ntop * reward std); https://arxiv.org/pdf/1803.07055.pdf: algorithm 2
181 | weight_update = jnp.einsum(
182 | "i,ijk->jk",
183 | jnp.dot(jnp.array([1.0, -1.0]), rewards_best),
184 | policy_perturb[sort],
185 | )
186 | weight = (
187 | p.weight
188 | + s.update_step / s.ntop / (rewards_best_std + 1.0e-5) * weight_update
189 | )
190 |
191 | # update policy
192 | mean, var, count = obs_stats_updated
193 | std = jnp.where(count > 1, jnp.sqrt(var / (count - 1)), 1.0)
194 | std = jnp.where(std < 1.0e-7, float("inf"), std)
195 | p = p.replace(
196 | weight=weight,
197 | shift=mean,
198 | scale=std,
199 | )
200 |
201 | return (rng, p, obs_stats_updated), (
202 | rewards_best_mean,
203 | rewards_best_std,
204 | jnp.mean(counts),
205 | )
206 |
207 | # loop
208 | initial_count = obs_stats[2]
209 | carry, reward_stats = jax.lax.scan(
210 | iteration,
211 | (
212 | rng,
213 | p,
214 | obs_stats,
215 | ),
216 | None,
217 | length=iter,
218 | )
219 | policy = carry[1]
220 | obs_stats_update = carry[2]
221 | env_steps = obs_stats_update[2] - initial_count
222 | return policy, obs_stats_update, reward_stats, env_steps
223 |
224 |
225 | def eval_search(
226 | s: Search,
227 | env: Environment,
228 | p: Policy,
229 | seed: int = 0,
230 | checkpoint: str = None,
231 | ) -> Policy:
232 | """Improve policy with random search and provide evaluation information during training.
233 |
234 |     Returns:
235 | Policy
236 | """
237 | print("Search:")
238 |
239 | # vmap reset
240 | v_reset = jax.jit(jax.vmap(env.reset, in_axes=(None, None, 0)))
241 |
242 | # create logs directory for checkpoints
243 | if checkpoint is not None:
244 | # directory
245 | checkpoint_dir = pathlib.Path(__file__).parent / "checkpoint"
246 |
247 |         # create directory and per-run subdirectory
248 |         # (parents=True also creates the parent checkpoint directory if needed)
249 |         (checkpoint_dir / checkpoint).mkdir(parents=True, exist_ok=True)
250 | 
251 | 
254 | # start total timer
255 | start_total = time.time()
256 |
257 | # initialize key from seed
258 | key = jax.random.PRNGKey(seed)
259 |
260 | # initialize observation statistics
261 | obs_stats = (
262 | jnp.zeros(env.nobservation),
263 | jnp.ones(env.nobservation),
264 | jnp.ones(1).astype(int),
265 | )
266 |
267 | # evaluation iterations
268 | iter_per_eval = int(s.niter / s.neval)
269 | for i in range(s.neval):
270 | # random
271 | rng, key = jax.random.split(key)
272 |
273 | # search
274 | start = time.time()
275 | p, obs_stats, _, search_steps = search(
276 | s, env, p, obs_stats, rng, iter=iter_per_eval
277 | )
278 |
279 | # stop search timer
280 | search_time = time.time() - start
281 |
282 | ## evaluate
283 | # random
284 | keys = jax.random.split(rng, s.nenveval + 2)
285 | key_eval_rollout = keys[:-2]
286 | rng, key = keys[-2:]
287 |
288 | # environment reset
289 | d_random = v_reset(env.model, env.data, key_eval_rollout)
290 |
291 | # rollout current
292 | rewards_eval, obs_stats_eval = v_rollout_eval(env, p, d_random, 0.0, s.nhorizon_eval)
293 |
294 | # stats
295 | reward_mean = rewards_eval.mean()
296 | reward_std = rewards_eval.std()
297 | avg_episode_len = int(obs_stats_eval[2].mean()) - 1
298 | env_steps = int(obs_stats[2][0]) - 1
299 | print(
300 | f"iteration ({(i + 1) * iter_per_eval} / {s.niter}): reward = {reward_mean:.2f} +- {reward_std:.2f} | time = {search_time:.2f} | avg episode length: {avg_episode_len} / {s.nhorizon_eval} | global steps: {env_steps} | steps/second: {int(search_steps[0]/search_time)}"
301 | )
302 |
303 | # save checkpoint
304 | if checkpoint is not None:
305 |             checkpoint_path = (
306 |                 checkpoint_dir
307 |                 / checkpoint
308 |                 / f"{checkpoint}_{i}_{reward_mean:.2f}_{reward_std:.2f}"
309 |             )
318 | with open(checkpoint_path, "wb") as file:
319 | pickle.dump(p, file)
320 |
321 | # total time
322 | print(f"\ntotal time: {time.time() - start_total:.2f}")
323 |
324 | return p
325 |
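326 | 
327 | # Illustrative sketch (not used by the code above): the core weight update from
328 | # `search`, written out for a single iteration. The argument names and shapes are
329 | # assumptions for illustration: `rewards_best` is (2, ntop) with positive/negative
330 | # rollout rewards for the selected directions, `perturb_best` is (ntop, naction,
331 | # nobservation) with the corresponding perturbations.
332 | def _example_weight_update(weight, rewards_best, perturb_best, update_step, ntop):
333 |     # per-direction reward difference: r+ - r-
334 |     reward_diff = rewards_best[0] - rewards_best[1]
335 | 
336 |     # reward standard deviation over the selected pairs (scales the step size)
337 |     reward_std = rewards_best.flatten().std() + 1.0e-5
338 | 
339 |     # reward-weighted sum of perturbations (https://arxiv.org/abs/1803.07055, algorithm 2)
340 |     update = jnp.einsum("i,ijk->jk", reward_diff, perturb_best)
341 | 
342 |     return weight + update_step / ntop / reward_std * update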
--------------------------------------------------------------------------------
/rs/train.py:
--------------------------------------------------------------------------------
1 | from rs.search import initialize_search, eval_search
2 | from rs.utilities import load_policy, parse
3 |
4 | # import environments
5 | from rs.envs.ant import ant_environment
6 | from rs.envs.cheetah import cheetah_environment
7 | from rs.envs.humanoid import humanoid_environment
8 | from rs.envs.walker import walker_environment
9 |
10 |
11 | def train_settings(env, settings):
12 | """Print training information."""
13 | print("Settings:")
14 | print(f" environment: {env}")
15 | print(f" nsample: {settings['nsample']} | ntop: {settings['ntop']}")
16 | print(f" niter: {settings['niter']} | neval: {settings['neval']}")
17 | print(
18 | f" nhorizon_search: {settings['nhorizon_search']} | nhorizon_eval: {settings['nhorizon_eval']}"
19 | )
20 | print(
21 | f" random_step: {settings['random_step']} | update_step: {settings['update_step']}"
22 | )
23 | print(f" nenveval: {settings['nenveval']}")
24 | print(f" reward_shift: {settings['reward_shift']}")
25 |
26 |
27 | def train():
28 | """Train linear policy with random search."""
29 | # parse settings
30 | args = parse()
31 |
32 |     # environment (select by name rather than eval)
33 |     env, p, settings, vis = {"ant": ant_environment, "cheetah": cheetah_environment, "humanoid": humanoid_environment, "walker": walker_environment}[args.env]()
34 |
35 | # load policy
36 | if args.load != "":
37 | p = load_policy(p, args.load)
38 |
39 | # search
40 | if args.search:
41 | # update settings with parsed arguments
42 | for k, v in vars(args).items():
43 | if k in settings and v is not None:
44 | settings[k] = v
45 |
46 | # settings
47 | train_settings(args.env, settings)
48 |
49 | # initialize
50 | s = initialize_search(**settings)
51 |
52 | # search + evaluation
53 | p = eval_search(s, env, p, seed=args.seed, checkpoint=args.checkpoint)
54 |
55 | # visualize policy
56 | if args.visualize:
57 | vis(p)
58 |
59 |
60 | if __name__ == "__main__":
61 | train()
62 |
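63 | # Possible invocations (assuming the package is installed, e.g. `pip install -e .`):
64 | #   python rs/train.py --env cheetah --search --checkpoint cheetah
65 | #   python rs/train.py --env cheetah --load pretrained/cheetah_4293_784 --visualize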
--------------------------------------------------------------------------------
/rs/utilities.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Tuple
2 | import argparse
3 | import mujoco
4 | import mujoco.viewer
5 | from mujoco import mjx
6 | import pathlib
7 | import pickle
8 | import time
9 |
10 | from rs.environment import Environment
11 | from rs.policy import Policy, policy
12 |
13 | def load_model(model: str) -> Tuple[mujoco.MjModel, mujoco.MjData, mjx.Model, mjx.Data]:
14 |     """Load MuJoCo model and data, plus MuJoCo XLA (MJX) model and data."""
15 | # path to model
16 |     path = pathlib.Path(__file__).parent / "models" / f"{model}.xml"
17 |
18 | # mjc model + data
19 | m = mujoco.MjModel.from_xml_path(str(path))
20 | d = mujoco.MjData(m)
21 |
22 | # mjx model + data
23 | m_mjx = mjx.put_model(m)
24 | d_mjx = mjx.put_data(m, d)
25 |
26 | return m, d, m_mjx, d_mjx
27 |
28 |
29 | def load_policy(p: Policy, file: str) -> Policy:
30 | """Load policy from checkpoint."""
31 | checkpoint_dir = pathlib.Path(__file__).parent / "checkpoint"
32 | checkpoint_path = str(checkpoint_dir) + "/" + file
33 | try:
34 | with open(checkpoint_path, 'rb') as f:
35 | p = pickle.load(f)
36 |
37 | print(f"Success: load policy from: {checkpoint_path}")
38 | except Exception as e:
39 |         print(f"Failure: load policy from: {checkpoint_path} ({e})")
40 |
41 | return p
42 |
43 |
44 | def visualize(
45 | m: mujoco.MjModel,
46 | d: mujoco.MjData,
47 | env: Environment,
48 | p: Policy,
49 | lookat: Callable = lambda viewer, data: None,
50 | ):
51 | """Visualize learned policy with MuJoCo passive viewer."""
52 | # visualize policy
53 | with mujoco.viewer.launch_passive(m, d) as viewer:
54 | while viewer.is_running():
55 | # start timer
56 | t0 = time.time()
57 |
58 | # observations
59 | obs = env.observation(m, d)
60 |
61 | # compute and set actions
62 | d.ctrl = policy(p, obs)
63 |
64 | # simulate
65 | for _ in range(env.ndecimation):
66 | # step physics
67 | mujoco.mj_step(m, d)
68 |
69 | # camera tracking
70 | lookat(viewer, d)
71 |
72 | # sync visualization
73 | viewer.sync()
74 |
75 | # wait
76 | elapsed = time.time() - t0
77 | time.sleep(max(m.opt.timestep * env.ndecimation - elapsed, 0.0))
78 |
79 |
80 | def parse():
81 | """Parse command line arguments for environment and search settings."""
82 | # parser
83 | parser = argparse.ArgumentParser()
84 |
85 | # search setup
86 | parser.add_argument("--env", type=str, choices=["ant", "cheetah", "humanoid", "walker"], default="cheetah", help="Learning environment (default: cheetah)")
87 | parser.add_argument("--search", action="store_true", help="Random search to find policy")
88 |     parser.add_argument("--load", type=str, default="", help="Saved policy file, relative to the rs/checkpoint directory (default: '')")
89 |     parser.add_argument("--checkpoint", type=str, default=None, help="Name used to save policy checkpoints during search (default: None)")
90 | parser.add_argument("--seed", type=int, default=0, help="Random seed (default: 0)")
91 | parser.add_argument("--visualize", action="store_true", help="Visualize policy")
92 |
93 | # search settings
94 | parser.add_argument("--nsample", type=int, default=None, help="Number of random directions to sample")
95 | parser.add_argument("--ntop", type=int, default=None, help="Number of random directions to use for policy update")
96 | parser.add_argument("--niter", type=int, default=None, help="Number of policy updates")
97 | parser.add_argument("--neval", type=int, default=None, help="Number of policy evaluations during search")
98 | parser.add_argument("--nhorizon_search", type=int, default=None, help="Number of environment steps during policy search")
99 | parser.add_argument("--nhorizon_eval", type=int, default=None, help="Number of environment steps during policy evaluation")
100 | parser.add_argument("--random_step", type=float, default=None, help="Step size for random direction")
101 | parser.add_argument("--update_step", type=float, default=None, help="Step size for policy update")
102 | parser.add_argument("--nenveval", type=int, default=None, help="Number of environments for policy evaluation")
103 | parser.add_argument("--reward_shift", type=float, default=None, help="Subtract from per-timestep reward")
104 |
105 | # args
106 | args = parser.parse_args()
107 |
108 | return args
109 |
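110 | # Example usage (illustrative): `load_model` returns the native MuJoCo model and
111 | # data (useful for the passive viewer) together with their MJX counterparts:
112 | #   m, d, m_mjx, d_mjx = load_model("cheetah")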
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name="rs",
5 | version="0.1",
6 | packages=find_packages(),
7 | install_requires=[
8 | "mujoco",
9 | "mujoco-mjx",
10 | "jax==0.4.29",
11 | "jaxlib @ https://storage.googleapis.com/jax-releases/cuda12/jaxlib-0.4.29+cuda12.cudnn91-cp310-cp310-manylinux2014_x86_64.whl",
12 | ],
13 | include_package_data=True,
14 | description="A simple JAX-based implementation of random search for locomotion tasks using MuJoCo XLA (MJX).",
15 | long_description=open("README.md").read(),
16 | url="https://github.com/thowell/rs",
17 | author="Taylor Howell",
18 | author_email="taylor.athaniel.howell@gmail.com",
19 | python_requires=">=3.10",
20 | )
21 |
--------------------------------------------------------------------------------