├── scale_rl ├── __init__.py ├── agents │ ├── simba │ │ ├── __init__.py │ │ ├── simba_layer.py │ │ ├── simba_network.py │ │ ├── simba_update.py │ │ └── simba_agent.py │ ├── jax_utils │ │ ├── __init__.py │ │ ├── tree_utils.py │ │ └── network.py │ ├── simbaV2 │ │ ├── __init__.py │ │ ├── simbaV2_network.py │ │ ├── simbaV2_layer.py │ │ └── simbaV2_update.py │ ├── wrappers │ │ ├── __init__.py │ │ ├── utils.py │ │ └── normalization.py │ ├── random_agent.py │ ├── __init__.py │ └── base_agent.py ├── envs │ ├── wrappers │ │ ├── __init__.py │ │ └── repeat_action.py │ ├── mujoco.py │ ├── dmc.py │ ├── d4rl.py │ ├── myosuite.py │ ├── __init__.py │ └── humanoid_bench.py ├── common │ ├── __init__.py │ ├── scheduler.py │ └── logger.py ├── buffers │ ├── __init__.py │ ├── base_buffer.py │ ├── utils.py │ └── numpy_buffer.py └── evaluation.py ├── docs ├── README.md ├── .DS_Store ├── images │ ├── header.png │ ├── online.png │ ├── analysis.png │ ├── offline.png │ ├── overview.png │ ├── benchmark.png │ ├── utd_scaling.png │ ├── param_scaling.png │ └── simbav2_architecture.png ├── thumbnails │ ├── 0.png │ ├── 1.png │ ├── 2.png │ ├── 3.png │ ├── 4.png │ ├── 5.png │ ├── 6.png │ └── 7.png ├── dataset │ ├── videos │ │ ├── .DS_Store │ │ ├── dmc │ │ │ ├── simbav2-dog-run.mp4 │ │ │ ├── simbav2-dog-trot.mp4 │ │ │ ├── simbav2-dog-walk.mp4 │ │ │ ├── simbav2-cheetah-run.mp4 │ │ │ ├── simbav2-dog-stand.mp4 │ │ │ ├── simbav2-finger-spin.mp4 │ │ │ ├── simbav2-fish-swim.mp4 │ │ │ ├── simbav2-hopper-hop.mp4 │ │ │ ├── simbav2-walker-run.mp4 │ │ │ ├── simbav2-walker-walk.mp4 │ │ │ ├── simbav2-hopper-stand.mp4 │ │ │ ├── simbav2-humanoid-run.mp4 │ │ │ ├── simbav2-humanoid-walk.mp4 │ │ │ ├── simbav2-quadruped-run.mp4 │ │ │ ├── simbav2-reacher-easy.mp4 │ │ │ ├── simbav2-reacher-hard.mp4 │ │ │ ├── simbav2-walker-stand.mp4 │ │ │ ├── simbav2-acrobot-swingup.mp4 │ │ │ ├── simbav2-cartpole-balance.mp4 │ │ │ ├── simbav2-cartpole-swingup.mp4 │ │ │ ├── simbav2-finger-turn-easy.mp4 │ │ │ ├── simbav2-finger-turn-hard.mp4 │ │ │ ├── simbav2-humanoid-stand.mp4 │ │ │ ├── simbav2-pendulum-swingup.mp4 │ │ │ ├── simbav2-quadruped-walk.mp4 │ │ │ ├── simbav2-ball-in-cup-catch.mp4 │ │ │ ├── simbav2-cartpole-balance-sparse.mp4 │ │ │ └── simbav2-cartpole-swingup-sparse.mp4 │ │ ├── hbench │ │ │ ├── simbav2-h1-crawl.mp4 │ │ │ ├── simbav2-h1-maze.mp4 │ │ │ ├── simbav2-h1-pole.mp4 │ │ │ ├── simbav2-h1-reach.mp4 │ │ │ ├── simbav2-h1-run.mp4 │ │ │ ├── simbav2-h1-slide.mp4 │ │ │ ├── simbav2-h1-stair.mp4 │ │ │ ├── simbav2-h1-stand.mp4 │ │ │ ├── simbav2-h1-walk.mp4 │ │ │ ├── simbav2-h1-hurdle.mp4 │ │ │ ├── simbav2-h1-sit-hard.mp4 │ │ │ ├── simbav2-h1-sit-simple.mp4 │ │ │ ├── simbav2-h1-balance-hard.mp4 │ │ │ └── simbav2-h1-balance-simple.mp4 │ │ ├── mujoco │ │ │ ├── simbav2-ant-v4.mp4 │ │ │ ├── simbav2-cheetah-v4.mp4 │ │ │ ├── simbav2-hopper-v4.mp4 │ │ │ ├── simbav2-walker-v4.mp4 │ │ │ └── simbav2-humanoid-v4.mp4 │ │ └── myosuite │ │ │ ├── simbav2-myo-key-easy.mp4 │ │ │ ├── simbav2-myo-key-hard.mp4 │ │ │ ├── simbav2-myo-obj-easy.mp4 │ │ │ ├── simbav2-myo-obj-hard.mp4 │ │ │ ├── simbav2-myo-pen-easy.mp4 │ │ │ ├── simbav2-myo-pen-hard.mp4 │ │ │ ├── simbav2-myo-pose-easy.mp4 │ │ │ ├── simbav2-myo-pose-hard.mp4 │ │ │ ├── simbav2-myo-reach-easy.mp4 │ │ │ └── simbav2-myo-reach-hard.mp4 │ ├── css │ │ ├── bulma.min.css │ │ ├── bulma-carousel.min.css │ │ └── style.css │ └── js │ │ └── index.js ├── videos │ ├── dmc │ │ ├── simbav2-dog-run.mp4 │ │ ├── simbav2-dog-trot.mp4 │ │ ├── simbav2-humanoid-stand.mp4 │ │ └── simbav2-humanoid-walk.mp4 │ ├── mujoco │ │ ├── 
simbav2-ant-v4.mp4 │ │ ├── simbav2-cheetah-v4.mp4 │ │ ├── simbav2-humanoid-v4.mp4 │ │ └── simbav2-walker-v4.mp4 │ ├── hbench │ │ ├── simbav2-h1-crawl.mp4 │ │ ├── simbav2-h1-stand.mp4 │ │ ├── simbav2-h1-sit-hard.mp4 │ │ └── simbav2-h1-balance-hard.mp4 │ └── myosuite │ │ ├── simbav2-myo-key-hard.mp4 │ │ ├── simbav2-myo-obj-hard.mp4 │ │ ├── simbav2-myo-pen-hard.mp4 │ │ └── simbav2-myo-reach-hard.mp4 └── style.css ├── scripts └── simbaV2 │ ├── test_offline.sh │ └── test_online.sh ├── configs ├── agent │ ├── random.yaml │ ├── simba.yaml │ ├── simbaV2_bc.yaml │ └── simbaV2.yaml ├── buffer │ └── numpy_uniform.yaml ├── env │ ├── d4rl.yaml │ ├── mujoco.yaml │ ├── myosuite.yaml │ ├── dmc.yaml │ └── hb_locomotion.yaml ├── offline_rl.yaml └── online_rl.yaml ├── setup.py ├── deps ├── requirements.txt ├── environment.yaml └── Dockerfile ├── .github └── workflows │ └── pre-commit.yml ├── .pre-commit-config.yaml ├── .gitignore ├── README.md ├── run_offline.py ├── run_online.py └── run_parallel.py /scale_rl/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # simbaV2.github.io -------------------------------------------------------------------------------- /scale_rl/agents/simba/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scale_rl/agents/jax_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scale_rl/agents/simbaV2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/.DS_Store -------------------------------------------------------------------------------- /docs/images/header.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/images/header.png -------------------------------------------------------------------------------- /docs/images/online.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/images/online.png -------------------------------------------------------------------------------- /docs/thumbnails/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/thumbnails/0.png -------------------------------------------------------------------------------- /docs/thumbnails/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/thumbnails/1.png -------------------------------------------------------------------------------- /docs/thumbnails/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/thumbnails/2.png 
-------------------------------------------------------------------------------- /docs/thumbnails/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/thumbnails/3.png -------------------------------------------------------------------------------- /docs/thumbnails/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/thumbnails/4.png -------------------------------------------------------------------------------- /docs/thumbnails/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/thumbnails/5.png -------------------------------------------------------------------------------- /docs/thumbnails/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/thumbnails/6.png -------------------------------------------------------------------------------- /docs/thumbnails/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/thumbnails/7.png -------------------------------------------------------------------------------- /scale_rl/envs/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from scale_rl.envs.wrappers.repeat_action import RepeatAction 2 | -------------------------------------------------------------------------------- /docs/images/analysis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/images/analysis.png -------------------------------------------------------------------------------- /docs/images/offline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/images/offline.png -------------------------------------------------------------------------------- /docs/images/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/images/overview.png -------------------------------------------------------------------------------- /docs/images/benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/images/benchmark.png -------------------------------------------------------------------------------- /docs/images/utd_scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/images/utd_scaling.png -------------------------------------------------------------------------------- /docs/dataset/videos/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/.DS_Store -------------------------------------------------------------------------------- /docs/images/param_scaling.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/images/param_scaling.png -------------------------------------------------------------------------------- /scale_rl/agents/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from scale_rl.agents.wrappers.normalization import ObservationNormalizer, RewardNormalizer -------------------------------------------------------------------------------- /docs/images/simbav2_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/images/simbav2_architecture.png -------------------------------------------------------------------------------- /docs/videos/dmc/simbav2-dog-run.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/dmc/simbav2-dog-run.mp4 -------------------------------------------------------------------------------- /docs/videos/dmc/simbav2-dog-trot.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/dmc/simbav2-dog-trot.mp4 -------------------------------------------------------------------------------- /docs/videos/mujoco/simbav2-ant-v4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/mujoco/simbav2-ant-v4.mp4 -------------------------------------------------------------------------------- /docs/videos/hbench/simbav2-h1-crawl.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/hbench/simbav2-h1-crawl.mp4 -------------------------------------------------------------------------------- /docs/videos/hbench/simbav2-h1-stand.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/hbench/simbav2-h1-stand.mp4 -------------------------------------------------------------------------------- /docs/videos/dmc/simbav2-humanoid-stand.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/dmc/simbav2-humanoid-stand.mp4 -------------------------------------------------------------------------------- /docs/videos/dmc/simbav2-humanoid-walk.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/dmc/simbav2-humanoid-walk.mp4 -------------------------------------------------------------------------------- /docs/videos/hbench/simbav2-h1-sit-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/hbench/simbav2-h1-sit-hard.mp4 -------------------------------------------------------------------------------- /docs/videos/mujoco/simbav2-cheetah-v4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/mujoco/simbav2-cheetah-v4.mp4 -------------------------------------------------------------------------------- 
/docs/videos/mujoco/simbav2-humanoid-v4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/mujoco/simbav2-humanoid-v4.mp4 -------------------------------------------------------------------------------- /docs/videos/mujoco/simbav2-walker-v4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/mujoco/simbav2-walker-v4.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-dog-run.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-dog-run.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-dog-trot.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-dog-trot.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-dog-walk.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-dog-walk.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-cheetah-run.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-cheetah-run.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-dog-stand.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-dog-stand.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-finger-spin.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-finger-spin.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-fish-swim.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-fish-swim.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-hopper-hop.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-hopper-hop.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-walker-run.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-walker-run.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-walker-walk.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-walker-walk.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-crawl.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-crawl.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-maze.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-maze.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-pole.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-pole.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-reach.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-reach.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-run.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-run.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-slide.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-slide.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-stair.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-stair.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-stand.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-stand.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-walk.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-walk.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/mujoco/simbav2-ant-v4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/mujoco/simbav2-ant-v4.mp4 -------------------------------------------------------------------------------- /docs/videos/hbench/simbav2-h1-balance-hard.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/hbench/simbav2-h1-balance-hard.mp4 -------------------------------------------------------------------------------- /docs/videos/myosuite/simbav2-myo-key-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/myosuite/simbav2-myo-key-hard.mp4 -------------------------------------------------------------------------------- /docs/videos/myosuite/simbav2-myo-obj-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/myosuite/simbav2-myo-obj-hard.mp4 -------------------------------------------------------------------------------- /docs/videos/myosuite/simbav2-myo-pen-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/myosuite/simbav2-myo-pen-hard.mp4 -------------------------------------------------------------------------------- /docs/videos/myosuite/simbav2-myo-reach-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/myosuite/simbav2-myo-reach-hard.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-hopper-stand.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-hopper-stand.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-humanoid-run.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-humanoid-run.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-humanoid-walk.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-humanoid-walk.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-quadruped-run.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-quadruped-run.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-reacher-easy.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-reacher-easy.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-reacher-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-reacher-hard.mp4 -------------------------------------------------------------------------------- 
/docs/dataset/videos/dmc/simbav2-walker-stand.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-walker-stand.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-hurdle.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-hurdle.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/mujoco/simbav2-cheetah-v4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/mujoco/simbav2-cheetah-v4.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/mujoco/simbav2-hopper-v4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/mujoco/simbav2-hopper-v4.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/mujoco/simbav2-walker-v4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/mujoco/simbav2-walker-v4.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-acrobot-swingup.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-acrobot-swingup.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-cartpole-balance.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-cartpole-balance.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-cartpole-swingup.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-cartpole-swingup.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-finger-turn-easy.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-finger-turn-easy.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-finger-turn-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-finger-turn-hard.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-humanoid-stand.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-humanoid-stand.mp4 
-------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-pendulum-swingup.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-pendulum-swingup.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-quadruped-walk.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-quadruped-walk.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-sit-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-sit-hard.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-sit-simple.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-sit-simple.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/mujoco/simbav2-humanoid-v4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/mujoco/simbav2-humanoid-v4.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-ball-in-cup-catch.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-ball-in-cup-catch.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-balance-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-balance-hard.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-key-easy.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-key-easy.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-key-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-key-hard.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-obj-easy.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-obj-easy.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-obj-hard.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-obj-hard.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-pen-easy.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-pen-easy.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-pen-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-pen-hard.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-pose-easy.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-pose-easy.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-pose-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-pose-hard.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-balance-simple.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-balance-simple.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-reach-easy.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-reach-easy.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-reach-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-reach-hard.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-cartpole-balance-sparse.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-cartpole-balance-sparse.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-cartpole-swingup-sparse.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-cartpole-swingup-sparse.mp4 -------------------------------------------------------------------------------- /scale_rl/common/__init__.py: -------------------------------------------------------------------------------- 1 | from scale_rl.common.logger import WandbTrainerLogger 2 | from scale_rl.common.scheduler import ( 3 | linear_decay_scheduler, 4 | constant_value_scheduler, 5 | 
cyclic_exponential_decay_scheduler, 6 | ) -------------------------------------------------------------------------------- /scripts/simbaV2/test_offline.sh: -------------------------------------------------------------------------------- 1 | python run_parallel.py \ 2 | --server kaist \ 3 | --group_name final_test \ 4 | --exp_name simbav2_bc \ 5 | --config_name offline_rl \ 6 | --agent_config simbaV2_bc \ 7 | --env_type d4rl_mujoco \ 8 | --device_ids 0 1 2 3 4 5 6 7 \ 9 | --num_seeds 5 \ 10 | --num_exp_per_device 1 \ -------------------------------------------------------------------------------- /scripts/simbaV2/test_online.sh: -------------------------------------------------------------------------------- 1 | python run_parallel.py \ 2 | --server kaist \ 3 | --group_name final_test \ 4 | --exp_name simbaV2_no_hb \ 5 | --agent_config simbaV2 \ 6 | --env_type hb_locomotion \ 7 | --device_ids 0 1 2 3 4 5 6 7 \ 8 | --num_seeds 5 \ 9 | --num_exp_per_device 4 \ 10 | --overrides recording_per_interaction_step=9999999 -------------------------------------------------------------------------------- /configs/agent/random.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # This agent selects actions randomly. (Useful for debugging purposes.) 3 | ################################################################################## 4 | 5 | agent_type: 'random' 6 | 7 | seed: ${seed} 8 | normalize_observation: false 9 | normalize_reward: false -------------------------------------------------------------------------------- /configs/buffer/numpy_uniform.yaml: -------------------------------------------------------------------------------- 1 | buffer_class_type: 'numpy' # [numpy, jax (distributed-friendly)] 2 | buffer_type: 'uniform' # [uniform, prioritized] 3 | 4 | n_step: ${n_step} 5 | gamma: ${gamma} 6 | max_length: 1_000_000 # maximum buffer size. 7 | min_length: 5_000 # minimum buffer size (= number of data to collect before training). 8 | sample_batch_size: 256 # batch size for sampling = training. 
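Note: the buffer settings above map one-to-one onto the keyword arguments of `create_buffer` in `scale_rl/buffers/__init__.py` (shown later in this listing). The snippet below is a minimal usage sketch, not part of the repository: it assumes the config has been loaded with OmegaConf (as the Hydra entry points do), that the `${n_step}` / `${gamma}` interpolations have already been resolved to concrete values, and the environment name is illustrative only.

import gymnasium as gym
from omegaconf import OmegaConf

from scale_rl.buffers import create_buffer

# Values mirror configs/buffer/numpy_uniform.yaml; n_step and gamma are
# normally interpolated from the top-level online_rl.yaml / offline_rl.yaml.
buffer_cfg = OmegaConf.create(
    {
        "buffer_class_type": "numpy",
        "buffer_type": "uniform",
        "n_step": 1,
        "gamma": 0.99,
        "max_length": 1_000_000,
        "min_length": 5_000,
        "sample_batch_size": 256,
    }
)

env = gym.make("Humanoid-v4")  # any Gymnasium env with Box spaces
buffer = create_buffer(
    observation_space=env.observation_space,
    action_space=env.action_space,
    **OmegaConf.to_container(buffer_cfg),
)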
-------------------------------------------------------------------------------- /configs/env/d4rl.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # D4RL 3 | ################################################################################## 4 | 5 | env_type: 'd4rl' 6 | env_name: 'hopper-medium-replay-v2' 7 | num_env_steps: null 8 | episodic: true 9 | 10 | seed: ${seed} 11 | num_train_envs: 1 12 | num_eval_envs: 1 13 | action_repeat: 1 14 | rescale_action: true 15 | max_episode_steps: 1000 -------------------------------------------------------------------------------- /configs/env/mujoco.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # Mujoco 3 | ################################################################################## 4 | 5 | env_type: 'mujoco' 6 | env_name: 'Humanoid-v4' 7 | num_env_steps: 1_000_000 8 | episodic: true 9 | 10 | seed: ${seed} 11 | num_train_envs: 1 12 | num_eval_envs: 1 13 | action_repeat: 1 14 | rescale_action: true 15 | max_episode_steps: 1000 -------------------------------------------------------------------------------- /configs/env/myosuite.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # MyoSuite 3 | ################################################################################## 4 | 5 | env_type: 'myosuite' 6 | env_name: 'myo-reach' 7 | num_env_steps: 1_000_000 8 | episodic: false 9 | 10 | seed: ${seed} 11 | num_train_envs: 1 12 | num_eval_envs: 1 13 | action_repeat: 2 14 | rescale_action: true 15 | max_episode_steps: 100 -------------------------------------------------------------------------------- /configs/env/dmc.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # Deepmind Control Suite with hard tasks (humanoid & dog) 3 | ################################################################################## 4 | 5 | env_type: 'dmc' 6 | env_name: 'acrobot-swingup' 7 | num_env_steps: 1_000_000 8 | episodic: false 9 | 10 | seed: ${seed} 11 | num_train_envs: 1 12 | num_eval_envs: 1 13 | action_repeat: 2 14 | rescale_action: true 15 | max_episode_steps: 1000 -------------------------------------------------------------------------------- /configs/env/hb_locomotion.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # Humanoid Benchmark with Locomotion tasks 3 | ################################################################################## 4 | 5 | env_type: 'humanoid_bench' 6 | env_name: 'h1-walk-v0' 7 | num_env_steps: 1_000_000 8 | episodic: true 9 | 10 | seed: ${seed} 11 | num_train_envs: 1 12 | num_eval_envs: 1 13 | action_repeat: 2 14 | rescale_action: true 15 | max_episode_steps: 1000 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | with open("deps/requirements.txt") as f: 4 | install_requires = f.read() 5 | 6 | 7 | if __name__ == 
"__main__": 8 | setup( 9 | name="scale_rl", 10 | version="0.1.0", 11 | url="https://github.com/dojeon-ai/SimbaV2", 12 | license="Apache License 2.0", 13 | install_requires=install_requires, 14 | packages=find_packages(), 15 | python_requires=">=3.9.0", 16 | zip_safe=True, 17 | ) 18 | -------------------------------------------------------------------------------- /deps/requirements.txt: -------------------------------------------------------------------------------- 1 | dm-control==1.0.20 2 | dotmap==1.3.30 3 | Cython==0.29.37 4 | flax==0.8.4 5 | hydra-core==1.3.2 6 | imageio==2.33.1 7 | jax==0.4.25 8 | jaxlib==0.4.25 9 | jaxopt==0.8.3 10 | matplotlib==3.9.0 11 | mujoco==3.1.6 12 | mujoco-py==2.1.2.14 13 | numpy==1.24.1 14 | omegaconf==2.3.0 15 | optax==0.2.2 16 | pandas==2.1.4 17 | Shimmy==2.0.0 18 | tensorflow-probability==0.24.0 19 | termcolor==2.4.0 20 | tqdm==4.66.1 21 | wandb==0.16.6 22 | moviepy==1.0.3 23 | git+https://github.com/takuseno/D4RL 24 | git+https://github.com/Farama-Foundation/Gymnasium.git@d92e030e9a0b03806deda51122c705169313f158 25 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | # https://pre-commit.com 2 | # This GitHub Action assumes that the repo contains a valid .pre-commit-config.yaml file. 3 | --- 4 | name: pre-commit 5 | on: 6 | pull_request: 7 | push: 8 | branches: [master] 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | pre-commit: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v3 18 | - uses: actions/setup-python@v3 19 | - run: pip install pre-commit 20 | - run: pre-commit --version 21 | - run: pre-commit install 22 | - run: pre-commit run --all-files 23 | -------------------------------------------------------------------------------- /docs/dataset/css/bulma.min.css: -------------------------------------------------------------------------------- 1 | .bibtex-container{ 2 | flex-grow:1; 3 | margin:0 auto 4 | ;position:relative; 5 | width:auto 6 | } 7 | 8 | .bibtex pre{ 9 | -webkit-overflow-scrolling: touch; 10 | overflow-x:auto; 11 | padding:1.25em 1.5em; 12 | white-space:pre; 13 | word-wrap:normal; 14 | font-family: "Courier", monospace; 15 | background-color: #f4f4f4; 16 | text-align: left; 17 | } 18 | 19 | .hero{ 20 | align-items:stretch; 21 | display:flex; 22 | flex-direction:column; 23 | justify-content:space-between 24 | } 25 | 26 | .hero.is-light{ 27 | color:rgba(0,0,0,.7) 28 | } 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | repos: 3 | - repo: https://github.com/codespell-project/codespell 4 | rev: v2.2.5 5 | hooks: 6 | - id: codespell # english spelling check 7 | args: [ 8 | --ignore-words-list=null 9 | ] 10 | exclude: ^(.*\.ipynb|.*\.html)$ 11 | - repo: https://github.com/astral-sh/ruff-pre-commit 12 | rev: v0.2.1 13 | hooks: 14 | - id: ruff # code formatting check. 15 | args: [ 16 | --fix, # applies formatting automatically. 
17 | --select=I, # import-sort (library import is automatically sorted) 18 | ] 19 | exclude: ^.*__init__\.py$ 20 | - id: ruff-format 21 | exclude: ^.*__init__\.py$ 22 | -------------------------------------------------------------------------------- /scale_rl/envs/mujoco.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | 3 | MUJOCO_ALL = [ 4 | "HalfCheetah-v4", 5 | "Hopper-v4", 6 | "Walker2d-v4", 7 | "Ant-v4", 8 | "Humanoid-v4", 9 | ] 10 | 11 | MUJOCO_RANDOM_SCORE = { 12 | "HalfCheetah-v4": -289.415, 13 | "Hopper-v4": 18.791, 14 | "Walker2d-v4": 2.791, 15 | "Ant-v4": -70.288, 16 | "Humanoid-v4": 120.423, 17 | } 18 | 19 | MUJOCO_TD3_SCORE = { 20 | "HalfCheetah-v4": 10574, 21 | "Hopper-v4": 3226, 22 | "Walker2d-v4": 3946, 23 | "Ant-v4": 3942, 24 | "Humanoid-v4": 5165, 25 | } 26 | 27 | 28 | def make_mujoco_env( 29 | env_name: str, 30 | seed: int, 31 | ) -> gym.Env: 32 | env = gym.make(env_name, render_mode="rgb_array") 33 | 34 | return env 35 | -------------------------------------------------------------------------------- /scale_rl/envs/wrappers/repeat_action.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | import numpy as np 3 | 4 | 5 | class RepeatAction(gym.Wrapper): 6 | def __init__(self, env: gym.Env, action_repeat=4): 7 | super().__init__(env) 8 | self._action_repeat = action_repeat 9 | 10 | def step(self, action: np.ndarray): 11 | total_reward = 0.0 12 | terminated = False 13 | truncated = False 14 | combined_info = {} 15 | 16 | for _ in range(self._action_repeat): 17 | obs, reward, terminated, truncated, info = self.env.step(action) 18 | total_reward += float(reward) 19 | combined_info.update(info) 20 | if terminated or truncated: 21 | break 22 | 23 | return obs, total_reward, terminated, truncated, combined_info 24 | -------------------------------------------------------------------------------- /deps/environment.yaml: -------------------------------------------------------------------------------- 1 | name: scale_rl 2 | channels: 3 | - nvidia 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - glew=2.1.0 8 | - glib=2.68.4 9 | - pip=21.0 10 | - python=3.9.0 11 | - pip: 12 | - dm-control==1.0.20 13 | - dotmap==1.3.30 14 | - Cython==0.29.37 15 | - flax==0.8.4 16 | - git+https://github.com/Farama-Foundation/Gymnasium.git@d92e030e9a0b03806deda51122c705169313f158 17 | - git+https://github.com/takuseno/D4RL 18 | - hydra-core==1.3.2 19 | - imageio==2.33.1 20 | - jax==0.4.25 21 | - jaxlib==0.4.25 22 | - jaxopt==0.8.3 23 | - matplotlib==3.9.0 24 | - mujoco==3.1.6 25 | - mujoco-py==2.1.2.14 26 | - numpy==1.24.1 27 | - omegaconf==2.3.0 28 | - optax==0.2.2 29 | - pandas==2.1.4 30 | - Shimmy==2.0.0 31 | - tensorflow-probability==0.24.0 32 | - termcolor==2.4.0 33 | - tqdm==4.66.1 34 | - wandb==0.16.6 35 | - moviepy==1.0.3 -------------------------------------------------------------------------------- /configs/agent/simba.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # Simba 3 | ################################################################################## 4 | 5 | agent_type: 'simba' 6 | 7 | seed: ${seed} 8 | normalize_observation: true 9 | normalize_reward: false 10 | normalized_g_max: 0.0 11 | 12 | load_only_param: true # do not load optimizer 13 | load_param_key: null # load designated key 14 | 
load_observation_normalizer: true 15 | load_reward_normalizer: true 16 | 17 | learning_rate_init: 1e-4 18 | learning_rate_end: 1e-4 19 | learning_rate_decay_rate: 1.0 20 | learning_rate_decay_step: ${eval:'int(${agent.learning_rate_decay_rate} * ${num_interaction_steps} * ${updates_per_interaction_step})'} 21 | weight_decay: 1e-2 22 | 23 | actor_num_blocks: 1 24 | actor_hidden_dim: 128 25 | actor_bc_alpha: 0.0 26 | 27 | critic_use_cdq: ${env.episodic} 28 | critic_num_blocks: 2 29 | critic_hidden_dim: 512 30 | 31 | target_tau: 0.005 32 | 33 | temp_initial_value: 0.01 34 | temp_target_entropy: null # entropy_coef * action_dim 35 | temp_target_entropy_coef: -0.5 36 | 37 | gamma: ${gamma} 38 | n_step: ${n_step} 39 | -------------------------------------------------------------------------------- /scale_rl/buffers/__init__.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from typing import Tuple, Optional 3 | from scale_rl.buffers.base_buffer import BaseBuffer, Batch 4 | from scale_rl.buffers.numpy_buffer import NpyUniformBuffer, NpyPrioritizedBuffer 5 | 6 | 7 | def create_buffer( 8 | buffer_class_type: str, 9 | buffer_type: str, 10 | observation_space: gym.spaces.Space, 11 | action_space: gym.spaces.Space, 12 | n_step: int, 13 | gamma: float, 14 | max_length: int, 15 | min_length: int, 16 | sample_batch_size: int, 17 | **kwargs, 18 | ) -> BaseBuffer: 19 | 20 | if buffer_class_type == 'numpy': 21 | if buffer_type == 'uniform': 22 | buffer = NpyUniformBuffer( 23 | observation_space=observation_space, 24 | action_space=action_space, 25 | n_step=n_step, 26 | gamma=gamma, 27 | max_length=max_length, 28 | min_length=min_length, 29 | sample_batch_size=sample_batch_size, 30 | ) 31 | 32 | else: 33 | raise NotImplementedError 34 | 35 | elif buffer_class_type == 'jax': 36 | raise NotImplementedError 37 | 38 | return buffer 39 | -------------------------------------------------------------------------------- /scale_rl/agents/random_agent.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, TypeVar 2 | 3 | import gymnasium as gym 4 | import numpy as np 5 | 6 | from scale_rl.agents.base_agent import BaseAgent 7 | 8 | Config = TypeVar("Config") 9 | 10 | 11 | class RandomAgent(BaseAgent): 12 | def __init__( 13 | self, 14 | observation_space: gym.spaces.Space, 15 | action_space: gym.spaces.Space, 16 | cfg: Config, 17 | ): 18 | """ 19 | An agent that randomly selects actions without training. 20 | Useful for collecting baseline results and for debugging purposes. 
21 | """ 22 | super(RandomAgent, self).__init__( 23 | observation_space, 24 | action_space, 25 | cfg, 26 | ) 27 | 28 | def sample_actions( 29 | self, 30 | interaction_step: int, 31 | prev_timestep: Dict[str, np.ndarray], 32 | training: bool, 33 | ) -> np.ndarray: 34 | num_envs = prev_timestep["next_observation"].shape[0] 35 | actions = [] 36 | for _ in range(num_envs): 37 | actions.append(self._action_space.sample()) 38 | 39 | actions = np.stack(actions).reshape(-1) 40 | return actions 41 | 42 | def update(self, update_step: int, batch: Dict[str, np.ndarray]) -> Dict: 43 | update_info = {} 44 | return update_info 45 | 46 | def save(self, path: str) -> None: 47 | pass 48 | 49 | def load(self, path: str) -> None: 50 | pass 51 | -------------------------------------------------------------------------------- /configs/offline_rl.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # Common 3 | ################################################################################## 4 | 5 | project_name: 'SimbaV2' 6 | entity_name: 'draftrec' 7 | group_name: 'test' 8 | exp_name: 'test' 9 | seed: 0 10 | server: 'local' 11 | save_path: 'models/${group_name}/${exp_name}/${env.env_name}/${seed}' 12 | load_path: null 13 | 14 | ################################################################################## 15 | # Training 16 | ################################################################################## 17 | 18 | # gamma value is set with a heuristic from TD-MPCv2 19 | eff_episode_len: ${eval:'${env.max_episode_steps} / ${env.action_repeat}'} 20 | gamma: ${eval:'max(min((${eff_episode_len} / 5 - 1) / (${eff_episode_len} / 5), 0.995), 0.95)'} 21 | n_step: 1 22 | 23 | num_epochs: 100 24 | action_repeat: ${env.action_repeat} 25 | 26 | num_interaction_steps: null # update steps 27 | updates_per_interaction_step: 1 # fixed 28 | evaluation_per_interaction_step: 50_000 # evaluation frequency per interaction step. 29 | metrics_per_interaction_step: 50_000 # log metrics per interaction step. 30 | recording_per_interaction_step: -1 # recording is not supported for offline rl yet. 31 | logging_per_interaction_step: 10_000 # logging frequency per interaction step. 
32 | save_checkpoint_per_interaction_step: null 33 | num_eval_episodes: 100 34 | num_record_episodes: 1 35 | 36 | defaults: 37 | - _self_ 38 | - agent: simbaV2_bc 39 | - buffer: numpy_uniform 40 | - env: d4rl 41 | -------------------------------------------------------------------------------- /scale_rl/agents/__init__.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from typing import TypeVar 3 | from omegaconf import OmegaConf 4 | from scale_rl.agents.base_agent import BaseAgent 5 | from scale_rl.agents.wrappers import ObservationNormalizer, RewardNormalizer 6 | 7 | Config = TypeVar('Config') 8 | 9 | 10 | def create_agent( 11 | observation_space: gym.spaces.Space, 12 | action_space: gym.spaces.Space, 13 | cfg: Config, 14 | ) -> BaseAgent: 15 | 16 | cfg = OmegaConf.to_container(cfg, throw_on_missing=True) 17 | agent_type = cfg.pop('agent_type') 18 | 19 | if agent_type == 'random': 20 | from scale_rl.agents.random_agent import RandomAgent 21 | agent = RandomAgent(observation_space, action_space, cfg) 22 | 23 | elif agent_type == 'simba': 24 | from scale_rl.agents.simba.simba_agent import SimbaAgent 25 | agent = SimbaAgent(observation_space, action_space, cfg) 26 | 27 | elif agent_type == 'simbaV2': 28 | from scale_rl.agents.simbaV2.simbaV2_agent import SimbaV2Agent 29 | agent = SimbaV2Agent(observation_space, action_space, cfg) 30 | 31 | else: 32 | raise NotImplementedError 33 | 34 | # observation and reward normalization wrappers 35 | if cfg['normalize_observation']: 36 | agent = ObservationNormalizer( 37 | agent, 38 | load_rms=cfg['load_observation_normalizer'] 39 | ) 40 | if cfg['normalize_reward']: 41 | agent = RewardNormalizer( 42 | agent, 43 | gamma=cfg['gamma'], 44 | g_max=cfg['normalized_g_max'], 45 | load_rms=cfg['load_reward_normalizer'] 46 | ) 47 | 48 | return agent -------------------------------------------------------------------------------- /scale_rl/common/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | EPS = 1e-8 4 | 5 | 6 | def cyclic_exponential_decay_scheduler( 7 | decay_period, initial_value, final_value, reverse=False 8 | ): 9 | if reverse: 10 | initial_value = 1 - initial_value 11 | final_value = 1 - final_value 12 | 13 | start = np.log(initial_value + EPS) 14 | end = np.log(final_value + EPS) 15 | 16 | def scheduler(step): 17 | cycle_length = decay_period 18 | cycle_step = step % cycle_length 19 | 20 | steps_left = decay_period - cycle_step 21 | bonus_frac = steps_left / decay_period 22 | bonus = np.clip(bonus_frac, 0.0, 1.0) 23 | new_value = bonus * (start - end) + end 24 | 25 | new_value = np.exp(new_value) - EPS 26 | if reverse: 27 | new_value = 1 - new_value 28 | return new_value 29 | 30 | return scheduler 31 | 32 | 33 | def linear_decay_scheduler(decay_period, initial_value, final_value): 34 | def scheduler(step): 35 | # Ensure step does not exceed decay_period 36 | step = min(step, decay_period) 37 | 38 | # Calculate the linear interpolation factor 39 | fraction = step / decay_period 40 | new_value = (1 - fraction) * initial_value + fraction * final_value 41 | 42 | return new_value 43 | 44 | return scheduler 45 | 46 | 47 | def constant_value_scheduler(value): 48 | """ 49 | Returns a scheduler function that always returns the same value. 50 | 51 | Args: 52 | value (float): The constant value to return. 53 | 54 | Returns: 55 | function: A scheduler function that always returns `value`. 
56 | """ 57 | 58 | def scheduler(step): 59 | return value 60 | 61 | return scheduler 62 | -------------------------------------------------------------------------------- /configs/online_rl.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # Common 3 | ################################################################################## 4 | 5 | project_name: 'SimbaV2' 6 | entity_name: 'draftrec' 7 | group_name: 'test' 8 | exp_name: 'test' 9 | seed: 0 10 | server: 'local' 11 | save_path: 'models/${group_name}/${exp_name}/${env.env_name}/${seed}' 12 | load_path: null 13 | 14 | ################################################################################## 15 | # Training 16 | ################################################################################## 17 | 18 | # gamma value is set with a heuristic from TD-MPCv2 19 | eff_episode_len: ${eval:'${env.max_episode_steps} / ${env.action_repeat}'} 20 | gamma: ${eval:'max(min((${eff_episode_len} / 5 - 1) / (${eff_episode_len} / 5), 0.995), 0.95)'} 21 | n_step: 1 22 | 23 | num_train_envs: ${env.num_train_envs} 24 | num_env_steps: ${env.num_env_steps} 25 | action_repeat: ${env.action_repeat} 26 | 27 | num_interaction_steps: ${eval:'${num_env_steps} / (${num_train_envs} * ${action_repeat})'} 28 | updates_per_interaction_step: ${action_repeat} # number of updates per interaction step. 29 | evaluation_per_interaction_step: 50_000 # evaluation frequency per interaction step. 30 | metrics_per_interaction_step: 50_000 # log metrics per interaction step. 31 | recording_per_interaction_step: ${num_interaction_steps} # video recording frequency per interaction step. 32 | logging_per_interaction_step: 10_000 # logging frequency per interaction step. 
33 | save_checkpoint_per_interaction_step: ${num_interaction_steps} 34 | save_buffer_per_interaction_step: ${num_interaction_steps} 35 | num_eval_episodes: 10 36 | num_record_episodes: 1 37 | 38 | defaults: 39 | - _self_ 40 | - agent: simbaV2 41 | - buffer: numpy_uniform 42 | - env: dmc -------------------------------------------------------------------------------- /scale_rl/envs/dmc.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from dm_control import suite 3 | from gymnasium import spaces 4 | from gymnasium.wrappers import FlattenObservation 5 | from shimmy import DmControlCompatibilityV0 as DmControltoGymnasium 6 | 7 | # 20 tasks 8 | DMC_EASY_MEDIUM = [ 9 | "acrobot-swingup", 10 | "ball_in_cup-catch", 11 | "cartpole-balance", 12 | "cartpole-balance_sparse", 13 | "cartpole-swingup", 14 | "cartpole-swingup_sparse", 15 | "cheetah-run", 16 | "finger-spin", 17 | "finger-turn_easy", 18 | "finger-turn_hard", 19 | "fish-swim", 20 | "hopper-hop", 21 | "hopper-stand", 22 | "pendulum-swingup", 23 | "quadruped-walk", 24 | "quadruped-run", 25 | "reacher-easy", 26 | "reacher-hard", 27 | "walker-stand", 28 | "walker-walk", 29 | "walker-run", 30 | ] 31 | 32 | # 8 tasks 33 | DMC_SPARSE = [ 34 | "cartpole-balance_sparse", 35 | "cartpole-swingup_sparse", 36 | "ball_in_cup-catch", 37 | "finger-spin", 38 | "finger-turn_easy", 39 | "finger-turn_hard", 40 | "reacher-easy", 41 | "reacher-hard", 42 | ] 43 | 44 | # 7 tasks 45 | DMC_HARD = [ 46 | "humanoid-stand", 47 | "humanoid-walk", 48 | "humanoid-run", 49 | "dog-stand", 50 | "dog-walk", 51 | "dog-run", 52 | "dog-trot", 53 | ] 54 | 55 | 56 | def make_dmc_env( 57 | env_name: str, 58 | seed: int, 59 | flatten: bool = True, 60 | ) -> gym.Env: 61 | domain_name, task_name = env_name.split("-") 62 | env = suite.load( 63 | domain_name=domain_name, 64 | task_name=task_name, 65 | task_kwargs={"random": seed}, 66 | ) 67 | env = DmControltoGymnasium(env, render_mode="rgb_array") 68 | if flatten and isinstance(env.observation_space, spaces.Dict): 69 | env = FlattenObservation(env) 70 | 71 | return env 72 | -------------------------------------------------------------------------------- /configs/agent/simbaV2_bc.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # SAC with Hyper-Simba architecture + Behavioral Cloning Loss 3 | ################################################################################## 4 | 5 | agent_type: 'simbaV2' 6 | 7 | seed: ${seed} 8 | normalize_observation: true 9 | normalize_reward: true 10 | normalized_g_max: 5.0 11 | 12 | load_only_param: true # do not load optimizer 13 | load_param_key: null # load designated key 14 | load_observation_normalizer: true 15 | load_reward_normalizer: true 16 | 17 | learning_rate_init: 1e-4 18 | learning_rate_end: 1e-5 19 | learning_rate_decay_rate: 0.5 20 | learning_rate_decay_step: null 21 | 22 | actor_num_blocks: 1 23 | actor_hidden_dim: 128 24 | actor_c_shift: 3.0 25 | actor_scaler_init: ${eval:'math.sqrt(2 / ${agent.actor_hidden_dim})'} 26 | actor_scaler_scale: ${eval:'math.sqrt(2 / ${agent.actor_hidden_dim})'} 27 | actor_alpha_init: ${eval:'1 / (${agent.actor_num_blocks} + 1)'} 28 | actor_alpha_scale: ${eval:'1 / math.sqrt(${agent.actor_hidden_dim})'} 29 | actor_bc_alpha: 0.1 # offline-rl 30 | 31 | critic_use_cdq: ${env.episodic} 32 | critic_num_blocks: 2 33 | critic_hidden_dim: 512 34 | critic_c_shift: 3.0 35 | 
critic_num_bins: 101 36 | critic_scaler_init: ${eval:'math.sqrt(2 / ${agent.critic_hidden_dim})'} 37 | critic_scaler_scale: ${eval:'math.sqrt(2 / ${agent.critic_hidden_dim})'} 38 | critic_min_v: ${eval:'-${agent.normalized_g_max}'} 39 | critic_max_v: ${eval:'${agent.normalized_g_max}'} 40 | critic_alpha_init: ${eval:'1 / (${agent.critic_num_blocks} + 1)'} 41 | critic_alpha_scale: ${eval:'1 / math.sqrt(${agent.critic_hidden_dim})'} 42 | 43 | target_tau: 0.005 44 | 45 | temp_initial_value: 0.01 46 | temp_target_entropy: null # entropy_coef * action_dim 47 | temp_target_entropy_coef: -0.5 48 | 49 | gamma: ${gamma} 50 | n_step: ${n_step} -------------------------------------------------------------------------------- /scale_rl/agents/wrappers/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This implementation is adapted from the Gymnasium API. 3 | ref: https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/wrappers/utils.py 4 | """ 5 | 6 | import numpy as np 7 | 8 | 9 | class RunningMeanStd: 10 | """Tracks the mean, variance and count of values.""" 11 | 12 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 13 | def __init__(self, epsilon=1e-4, shape=(), dtype=np.float32): 14 | """Tracks the mean, variance and count of values.""" 15 | self.mean = np.zeros(shape, dtype=dtype) 16 | self.var = np.ones(shape, dtype=dtype) 17 | self.count = epsilon 18 | 19 | def update(self, x): 20 | """Updates the mean, var and count from a batch of samples.""" 21 | batch_mean = np.mean(x, axis=0) 22 | batch_var = np.var(x, axis=0) 23 | batch_count = x.shape[0] 24 | self.update_from_moments(batch_mean, batch_var, batch_count) 25 | 26 | def update_from_moments(self, batch_mean, batch_var, batch_count): 27 | """Updates from batch mean, variance and count moments.""" 28 | self.mean, self.var, self.count = update_mean_var_count_from_moments( 29 | self.mean, self.var, self.count, batch_mean, batch_var, batch_count 30 | ) 31 | 32 | 33 | def update_mean_var_count_from_moments( 34 | mean, var, count, batch_mean, batch_var, batch_count 35 | ): 36 | """Updates the mean, var and count using the previous mean, var, count and batch values.""" 37 | delta = batch_mean - mean 38 | tot_count = count + batch_count 39 | ratio = batch_count / tot_count 40 | 41 | new_mean = mean + delta * ratio 42 | m_a = var * count 43 | m_b = batch_var * batch_count 44 | M2 = m_a + m_b + np.square(delta) * count * ratio 45 | new_var = M2 / tot_count 46 | 47 | return new_mean, new_var, tot_count 48 | -------------------------------------------------------------------------------- /configs/agent/simbaV2.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # SAC with Hyper-Simba architecture 3 | ################################################################################## 4 | 5 | agent_type: 'simbaV2' 6 | 7 | seed: ${seed} 8 | normalize_observation: true 9 | normalize_reward: true 10 | normalized_g_max: 5.0 11 | 12 | load_only_param: true # do not load optimizer 13 | load_param_key: null # load designated key 14 | load_observation_normalizer: true 15 | load_reward_normalizer: true 16 | 17 | learning_rate_init: 1e-4 18 | learning_rate_end: 5e-5 19 | learning_rate_decay_rate: 1.0 20 | learning_rate_decay_step: ${eval:'int(${agent.learning_rate_decay_rate} * ${num_interaction_steps} * ${updates_per_interaction_step})'} 21 | 
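# The ${eval:...} entries below are interpolations evaluated when the config is loaded.
# Worked example with the defaults in this file (actor_hidden_dim=128, actor_num_blocks=1,
# critic_hidden_dim=512, critic_num_blocks=2); values rounded, for orientation only:
#   actor_scaler_init = actor_scaler_scale = sqrt(2 / 128) = 0.125
#   actor_alpha_init = 1 / (1 + 1) = 0.5,     actor_alpha_scale = 1 / sqrt(128) ~= 0.088
#   critic_scaler_init = critic_scaler_scale = sqrt(2 / 512) = 0.0625
#   critic_alpha_init = 1 / (2 + 1) ~= 0.333, critic_alpha_scale = 1 / sqrt(512) ~= 0.044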
22 | actor_num_blocks: 1 23 | actor_hidden_dim: 128 24 | actor_c_shift: 3.0 25 | actor_scaler_init: ${eval:'math.sqrt(2 / ${agent.actor_hidden_dim})'} 26 | actor_scaler_scale: ${eval:'math.sqrt(2 / ${agent.actor_hidden_dim})'} 27 | actor_alpha_init: ${eval:'1 / (${agent.actor_num_blocks} + 1)'} 28 | actor_alpha_scale: ${eval:'1 / math.sqrt(${agent.actor_hidden_dim})'} 29 | actor_bc_alpha: 0.0 30 | 31 | critic_use_cdq: ${env.episodic} 32 | critic_num_blocks: 2 33 | critic_hidden_dim: 512 34 | critic_c_shift: 3.0 35 | critic_num_bins: 101 36 | critic_scaler_init: ${eval:'math.sqrt(2 / ${agent.critic_hidden_dim})'} 37 | critic_scaler_scale: ${eval:'math.sqrt(2 / ${agent.critic_hidden_dim})'} 38 | critic_min_v: ${eval:'-${agent.normalized_g_max}'} 39 | critic_max_v: ${eval:'${agent.normalized_g_max}'} 40 | critic_alpha_init: ${eval:'1 / (${agent.critic_num_blocks} + 1)'} 41 | critic_alpha_scale: ${eval:'1 / math.sqrt(${agent.critic_hidden_dim})'} 42 | 43 | target_tau: 0.005 44 | 45 | temp_initial_value: 0.01 46 | temp_target_entropy: null # entropy_coef * action_dim 47 | temp_target_entropy_coef: -0.5 48 | 49 | gamma: ${gamma} 50 | n_step: ${n_step} -------------------------------------------------------------------------------- /scale_rl/buffers/base_buffer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict 3 | 4 | import gymnasium as gym 5 | import numpy as np 6 | 7 | Batch = Dict[str, np.ndarray] 8 | 9 | 10 | class BaseBuffer(ABC): 11 | def __init__( 12 | self, 13 | observation_space: gym.spaces.Space, 14 | action_space: gym.spaces.Space, 15 | n_step: int, 16 | gamma: float, 17 | max_length: int, 18 | min_length: int, 19 | sample_batch_size: int, 20 | ): 21 | """ 22 | A generic buffer class. 23 | 24 | args: 25 | observation_space: observation space of the environment. 26 | action_space: action space of the environment. 27 | max_length: maximum length of the buffer (max number of experiences stored within the buffer state). 28 | min_length: minimum number of experiences saved in the buffer state before we can sample. 29 | n_step, gamma: horizon and discount factor used to construct n-step transitions when sampling. 30 | sample_batch_size: batch size of data that is sampled from a single sampling call.
31 | """ 32 | 33 | self._observation_space = observation_space 34 | self._action_space = action_space 35 | self._max_length = max_length 36 | self._min_length = min_length 37 | self._n_step = n_step 38 | self._gamma = gamma 39 | self._sample_batch_size = sample_batch_size 40 | 41 | def __len__(self): 42 | pass 43 | 44 | @abstractmethod 45 | def reset(self) -> None: 46 | pass 47 | 48 | @abstractmethod 49 | def add(self, timestep: Dict[str, np.ndarray]) -> None: 50 | pass 51 | 52 | @abstractmethod 53 | def can_sample(self) -> bool: 54 | pass 55 | 56 | @abstractmethod 57 | def sample(self) -> Batch: 58 | pass 59 | 60 | @abstractmethod 61 | def save(self, path: str) -> None: 62 | pass 63 | 64 | @abstractmethod 65 | def get_observations(self) -> np.ndarray: 66 | pass 67 | -------------------------------------------------------------------------------- /scale_rl/envs/d4rl.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import d4rl 4 | import gymnasium as gym 5 | import numpy as np 6 | 7 | D4RL_MUJOCO = [ 8 | "hopper-medium-v2", 9 | "hopper-medium-replay-v2", 10 | "hopper-medium-expert-v2", 11 | "halfcheetah-medium-v2", 12 | "halfcheetah-medium-replay-v2", 13 | "halfcheetah-medium-expert-v2", 14 | "walker2d-medium-v2", 15 | "walker2d-medium-replay-v2", 16 | "walker2d-medium-expert-v2", 17 | ] 18 | 19 | 20 | def make_d4rl_env(env_name: str, seed: int) -> gym.Env: 21 | env = gym.make("GymV26Environment-v0", env_id=env_name) 22 | env.reset(seed=seed) 23 | return env 24 | 25 | 26 | def make_d4rl_dataset(env_name: str) -> list[dict[str, Any]]: 27 | env = make_d4rl_env(env_name, seed=0) 28 | 29 | # extract dataset 30 | dataset = env.env.env.gym_env.get_dataset() 31 | timesteps = [] 32 | total_steps = dataset["rewards"].shape[0] 33 | for i in range(total_steps - 1): 34 | obs = dataset["observations"][i] 35 | action = dataset["actions"][i] 36 | reward = dataset["rewards"][i] 37 | terminated = dataset["terminals"][i] 38 | truncated = dataset["timeouts"][i] 39 | if terminated: 40 | next_obs = np.zeros_like(obs) 41 | elif truncated: 42 | continue 43 | else: 44 | next_obs = dataset["observations"][i + 1] 45 | timestep = { 46 | "observation": np.expand_dims(obs, axis=0), 47 | "action": np.expand_dims(action, axis=0), 48 | "reward": np.array([reward]), 49 | "terminated": np.array([terminated]), 50 | "truncated": np.array([truncated]), 51 | "next_observation": np.expand_dims(next_obs, axis=0), 52 | } 53 | timesteps.append(timestep) 54 | 55 | return timesteps 56 | 57 | 58 | def get_d4rl_normalized_score(env_name: str, unnormalized_score: float) -> float: 59 | return 100 * d4rl.get_normalized_score(env_name, unnormalized_score) 60 | -------------------------------------------------------------------------------- /scale_rl/envs/myosuite.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | 3 | MYOSUITE_TASKS = [ 4 | "myo-reach", 5 | "myo-reach-hard", 6 | "myo-pose", 7 | "myo-pose-hard", 8 | "myo-obj-hold", 9 | "myo-obj-hold-hard", 10 | "myo-key-turn", 11 | "myo-key-turn-hard", 12 | "myo-pen-twirl", 13 | "myo-pen-twirl-hard", 14 | ] 15 | 16 | 17 | MYOSUITE_TASKS_DICT = { 18 | "myo-reach": "myoHandReachFixed-v0", 19 | "myo-reach-hard": "myoHandReachRandom-v0", 20 | "myo-pose": "myoHandPoseFixed-v0", 21 | "myo-pose-hard": "myoHandPoseRandom-v0", 22 | "myo-obj-hold": "myoHandObjHoldFixed-v0", 23 | "myo-obj-hold-hard": "myoHandObjHoldRandom-v0", 24 | "myo-key-turn": 
"myoHandKeyTurnFixed-v0", 25 | "myo-key-turn-hard": "myoHandKeyTurnRandom-v0", 26 | "myo-pen-twirl": "myoHandPenTwirlFixed-v0", 27 | "myo-pen-twirl-hard": "myoHandPenTwirlRandom-v0", 28 | } 29 | 30 | 31 | class MyosuiteGymnasiumVersionWrapper(gym.Wrapper): 32 | """ 33 | myosuite originally requires gymnasium==0.15 34 | however, we are currently using gymnasium==1.0.0a2, 35 | hence requiring some minor fix to the 36 | - fix a. 37 | - fix b. 38 | """ 39 | 40 | def __init__(self, env: gym.Env): 41 | super().__init__(env) 42 | self.unwrapped_env = env.unwrapped 43 | 44 | def step(self, action): 45 | obs, reward, terminated, truncated, info = self.env.step(action) 46 | info["success"] = info["solved"] 47 | return obs, reward, terminated, truncated, info 48 | 49 | def render( 50 | self, width: int = 192, height: int = 192, camera_id: str = "hand_side_inter" 51 | ): 52 | return self.unwrapped_env.sim.renderer.render_offscreen( 53 | width=width, 54 | height=height, 55 | camera_id=camera_id, 56 | ) 57 | 58 | 59 | def make_myosuite_env( 60 | env_name: str, 61 | seed: int, 62 | **kwargs, 63 | ) -> gym.Env: 64 | from myosuite.utils import gym as myo_gym 65 | 66 | env = myo_gym.make(MYOSUITE_TASKS_DICT[env_name]) 67 | env = MyosuiteGymnasiumVersionWrapper(env) 68 | 69 | return env 70 | -------------------------------------------------------------------------------- /scale_rl/agents/simba/simba_layer.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import flax.linen as nn 4 | import jax.numpy as jnp 5 | from tensorflow_probability.substrates import jax as tfp 6 | 7 | tfd = tfp.distributions 8 | tfb = tfp.bijectors 9 | 10 | 11 | class PreLNResidualBlock(nn.Module): 12 | hidden_dim: int 13 | expansion: int = 4 14 | 15 | def setup(self): 16 | self.pre_ln = nn.LayerNorm() 17 | self.w1 = nn.Dense( 18 | self.hidden_dim * self.expansion, kernel_init=nn.initializers.he_normal() 19 | ) 20 | self.w2 = nn.Dense(self.hidden_dim, kernel_init=nn.initializers.he_normal()) 21 | 22 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 23 | res = x 24 | x = self.pre_ln(x) 25 | x = self.w1(x) 26 | x = nn.relu(x) 27 | x = self.w2(x) 28 | return res + x 29 | 30 | 31 | class LinearCritic(nn.Module): 32 | def setup(self): 33 | self.w = nn.Dense(1, kernel_init=nn.initializers.orthogonal(1.0)) 34 | 35 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 36 | value = self.w(x).squeeze(-1) 37 | info = {} 38 | return value, info 39 | 40 | 41 | class NormalTanhPolicy(nn.Module): 42 | action_dim: int 43 | log_std_min: float = -10.0 44 | log_std_max: float = 2.0 45 | 46 | def setup(self): 47 | self.mean_w = nn.Dense( 48 | self.action_dim, kernel_init=nn.initializers.orthogonal(1.0) 49 | ) 50 | self.std_w = nn.Dense( 51 | self.action_dim, kernel_init=nn.initializers.orthogonal(1.0) 52 | ) 53 | 54 | def __call__( 55 | self, 56 | x: jnp.ndarray, 57 | temperature: float = 1.0, 58 | ) -> tfd.Distribution: 59 | mean = self.mean_w(x) 60 | log_std = self.std_w(x) 61 | 62 | # normalize log-stds for stability 63 | log_std = self.log_std_min + (self.log_std_max - self.log_std_min) * 0.5 * ( 64 | 1 + nn.tanh(log_std) 65 | ) 66 | 67 | # N(mu, exp(log_sigma)) 68 | dist = tfd.MultivariateNormalDiag( 69 | loc=mean, 70 | scale_diag=jnp.exp(log_std) * temperature, 71 | ) 72 | 73 | # tanh(N(mu, sigma)) 74 | dist = tfd.TransformedDistribution(distribution=dist, bijector=tfb.Tanh()) 75 | 76 | info = {} 77 | return dist, info 78 | 
-------------------------------------------------------------------------------- /docs/dataset/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | var INTERP_BASE = "./static/interpolation/stacked"; 4 | var NUM_INTERP_FRAMES = 240; 5 | 6 | var interp_images = []; 7 | function preloadInterpolationImages() { 8 | for (var i = 0; i < NUM_INTERP_FRAMES; i++) { 9 | var path = INTERP_BASE + '/' + String(i).padStart(6, '0') + '.jpg'; 10 | interp_images[i] = new Image(); 11 | interp_images[i].src = path; 12 | } 13 | } 14 | 15 | function setInterpolationImage(i) { 16 | var image = interp_images[i]; 17 | image.ondragstart = function() { return false; }; 18 | image.oncontextmenu = function() { return false; }; 19 | $('#interpolation-image-wrapper').empty().append(image); 20 | } 21 | 22 | 23 | $(document).ready(function() { 24 | // Check for click events on the navbar burger icon 25 | $(".navbar-burger").click(function() { 26 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu" 27 | $(".navbar-burger").toggleClass("is-active"); 28 | $(".navbar-menu").toggleClass("is-active"); 29 | 30 | }); 31 | 32 | var options = { 33 | slidesToScroll: 1, 34 | slidesToShow: 3, 35 | loop: true, 36 | infinite: true, 37 | autoplay: false, 38 | autoplaySpeed: 3000, 39 | } 40 | 41 | // Initialize all div with carousel class 42 | var carousels = bulmaCarousel.attach('.carousel', options); 43 | 44 | // Loop on each carousel initialized 45 | for(var i = 0; i < carousels.length; i++) { 46 | // Add listener to event 47 | carousels[i].on('before:show', state => { 48 | console.log(state); 49 | }); 50 | } 51 | 52 | // Access to bulmaCarousel instance of an element 53 | var element = document.querySelector('#my-element'); 54 | if (element && element.bulmaCarousel) { 55 | // bulmaCarousel instance is available as element.bulmaCarousel 56 | element.bulmaCarousel.on('before-show', function(state) { 57 | console.log(state); 58 | }); 59 | } 60 | 61 | /*var player = document.getElementById('interpolation-video'); 62 | player.addEventListener('loadedmetadata', function() { 63 | $('#interpolation-slider').on('input', function(event) { 64 | console.log(this.value, player.duration); 65 | player.currentTime = player.duration / 100 * this.value; 66 | }) 67 | }, false);*/ 68 | preloadInterpolationImages(); 69 | 70 | $('#interpolation-slider').on('input', function(event) { 71 | setInterpolationImage(this.value); 72 | }); 73 | setInterpolationImage(0); 74 | $('#interpolation-slider').prop('max', NUM_INTERP_FRAMES - 1); 75 | 76 | bulmaSlider.attach(); 77 | 78 | }) 79 | -------------------------------------------------------------------------------- /scale_rl/agents/jax_utils/tree_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | 7 | def tree_norm(tree): 8 | return jnp.sqrt(sum((x**2).sum() for x in jax.tree_util.tree_leaves(tree))) 9 | 10 | 11 | def tree_map_until_match( 12 | f, tree, target_re, *rest, keep_structure=True, keep_values=False 13 | ): 14 | """ 15 | Similar to `jax.tree_util.tree_map_with_path`, but `is_leaf` is a regex condition. 16 | args: 17 | f: A function to map the discovered nodes (i.e., dict key matches `target_re`). 18 | Inputs to f will be (1) the discovered node and (2) the corresponding nodes in `*rest``. 19 | target_re: A regex string condition that triggers `f`. 
20 | tree: A pytree to be searched by `target_re` and mapped by `f`. 21 | *rest: List of pytrees that are at least 'almost' identical structure to `tree`. 22 | 'Almost', since the substructure of matching nodes don't have to be identical. 23 | i.e., The tree structure of `tree` and `*rest` should be identical only up to the matching nodes. 24 | keep_structure: If false, the returned tree will only contain subtrees that lead to the matching nodes. 25 | keep_values: If false, unmatched leaves will become `None`. Assumes `keep_structure=True`. 26 | """ 27 | 28 | if not isinstance(tree, dict): 29 | return tree if keep_values else None 30 | 31 | ret_tree = {} 32 | for k, v in tree.items(): 33 | v_rest = [r[k] for r in rest] 34 | if re.fullmatch(target_re, k): 35 | ret_tree[k] = f(v, *v_rest) 36 | else: 37 | subtree = tree_map_until_match( 38 | f, 39 | v, 40 | target_re, 41 | *v_rest, 42 | keep_structure=keep_structure, 43 | keep_values=keep_values, 44 | ) 45 | if keep_structure or subtree: 46 | ret_tree[k] = subtree 47 | 48 | return ret_tree 49 | 50 | 51 | def tree_filter(f, tree, target_re="scaler"): 52 | if isinstance(tree, dict): 53 | # Keep only "target_re" keys in the dictionary 54 | filtered_tree = {} 55 | for k, v in tree.items(): 56 | if re.fullmatch(target_re, k): 57 | filtered_tree[k] = tree_filter(f, v, target_re="scaler") 58 | elif isinstance(v, dict): # Recursively check nested dictionaries 59 | filtered_value = tree_filter(f, v, target_re="scaler") 60 | if filtered_value: # Only keep non-empty dictionaries 61 | filtered_tree[k] = filtered_value 62 | return filtered_tree 63 | else: 64 | # If not a dictionary, return the tree as is (typically a leaf node) 65 | return f(tree) 66 | -------------------------------------------------------------------------------- /deps/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.2.2-cudnn8-devel-ubuntu22.04 2 | ENV DEBIAN_FRONTEND=noninteractive 3 | ENV CUDA_HOME=/usr/local/cuda-12.2 4 | ENV PATH=$CUDA_HOME/bin:$PATH 5 | ENV LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH 6 | 7 | # packages 8 | RUN apt-get -y update && \ 9 | apt-get install -y --no-install-recommends \ 10 | build-essential \ 11 | git \ 12 | rsync \ 13 | tree \ 14 | curl \ 15 | wget \ 16 | unzip \ 17 | htop \ 18 | tmux \ 19 | xvfb \ 20 | patchelf \ 21 | ca-certificates \ 22 | bash-completion \ 23 | libjpeg-dev \ 24 | libpng-dev \ 25 | ffmpeg \ 26 | cmake \ 27 | swig \ 28 | libssl-dev \ 29 | libcurl4-openssl-dev \ 30 | libopenmpi-dev \ 31 | python3-dev \ 32 | zlib1g-dev \ 33 | qtbase5-dev \ 34 | qtdeclarative5-dev \ 35 | libglib2.0-0 \ 36 | libglu1-mesa-dev \ 37 | libgl1-mesa-dev \ 38 | libvulkan1 \ 39 | libgl1-mesa-glx \ 40 | libosmesa6 \ 41 | libosmesa6-dev \ 42 | libglew-dev \ 43 | mesa-utils && \ 44 | apt-get clean && \ 45 | apt-get autoremove -y && \ 46 | rm -rf /var/lib/apt/lists/* && \ 47 | mkdir /root/.ssh 48 | 49 | # python 50 | RUN apt-get -y update && \ 51 | apt-get install -y software-properties-common && \ 52 | add-apt-repository ppa:deadsnakes/ppa && \ 53 | apt-get install -y python3.10 python3.10-distutils python3.10-venv 54 | 55 | ## kubernetes authorization 56 | # kubernetes needs a numeric user apparently 57 | # Ensure the user has write permissions 58 | RUN useradd --create-home \ 59 | --shell /bin/bash \ 60 | --base-dir /home \ 61 | --groups dialout,audio,video,plugdev \ 62 | --uid 1000 \ 63 | user 64 | USER root 65 | WORKDIR /home/user 66 | RUN chown -R user:user /home/user && \ 67 | chmod -R 
u+rwx /home/user 68 | USER 1000 69 | 70 | # Install packages 71 | COPY deps/requirements.txt /home/user/requirements.txt 72 | RUN python3.10 -m venv /home/user/venv && \ 73 | . /home/user/venv/bin/activate && \ 74 | pip install -r requirements.txt && \ 75 | pip install -U "jax[cuda12]==0.4.25" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 76 | ENV VIRTUAL_ENV=/home/user/venv 77 | ENV PATH="$VIRTUAL_ENV/bin:$PATH" 78 | ENV XLA_PYTHON_CLIENT_PREALLOCATE=false 79 | 80 | # install mujoco 2.1.0, humanoid-bench and myosuite 81 | ENV MUJOCO_GL egl 82 | ENV LD_LIBRARY_PATH /home/user/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH} 83 | RUN wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz && \ 84 | tar -xzf mujoco210-linux-x86_64.tar.gz && \ 85 | rm mujoco210-linux-x86_64.tar.gz && \ 86 | mkdir /home/user/.mujoco && \ 87 | mv mujoco210 /home/user/.mujoco/mujoco210 && \ 88 | find /home/user/.mujoco -exec chown user:user {} \; && \ 89 | python -c "import mujoco_py" && \ 90 | git clone https://github.com/joonleesky/humanoid-bench /home/user/humanoid-bench && \ 91 | pip install -e /home/user/humanoid-bench && \ 92 | git clone --recursive https://github.com/joonleesky/myosuite /home/user/myosuite && \ 93 | pip install -e /home/user/myosuite -------------------------------------------------------------------------------- /scale_rl/common/logger.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from omegaconf import OmegaConf 4 | 5 | import wandb 6 | 7 | 8 | class WandbTrainerLogger(object): 9 | def __init__(self, cfg: Dict): 10 | self.cfg = cfg 11 | dict_cfg = OmegaConf.to_container(cfg, throw_on_missing=True) 12 | 13 | wandb.init( 14 | project=cfg.project_name, 15 | entity=cfg.entity_name, 16 | group=cfg.group_name, 17 | config=dict_cfg, 18 | ) 19 | 20 | self.reset() 21 | 22 | def update_metric(self, **kwargs) -> None: 23 | for k, v in kwargs.items(): 24 | if isinstance(v, float) or isinstance(v, int): 25 | self.average_meter_dict.update(k, v) 26 | else: 27 | self.media_dict[k] = v 28 | 29 | def log_metric(self, step: int) -> Dict: 30 | log_data = {} 31 | log_data.update(self.average_meter_dict.averages()) 32 | log_data.update(self.media_dict) 33 | wandb.log(log_data, step=step) 34 | 35 | def reset(self) -> None: 36 | self.average_meter_dict = AverageMeterDict() 37 | self.media_dict = {} 38 | 39 | 40 | class AverageMeterDict(object): 41 | """ 42 | Manages a collection of AverageMeter instances, 43 | allowing for grouped tracking and averaging of multiple metrics. 
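    Example (illustrative):
        meters = AverageMeterDict()
        meters.update("loss", 0.5)
        meters.update("loss", 0.3)
        meters.averages()  # -> {"loss": 0.4}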
44 | """ 45 | 46 | def __init__(self, meters=None): 47 | self.meters = meters if meters else {} 48 | 49 | def __getitem__(self, key): 50 | if key not in self.meters: 51 | meter = AverageMeter() 52 | meter.update(0) 53 | return meter 54 | return self.meters[key] 55 | 56 | def update(self, name, value, n=1) -> None: 57 | if name not in self.meters: 58 | self.meters[name] = AverageMeter() 59 | self.meters[name].update(value, n) 60 | 61 | def reset(self) -> None: 62 | for meter in self.meters.values(): 63 | meter.reset() 64 | 65 | def values(self, format_string="{}"): 66 | return { 67 | format_string.format(name): meter.val for name, meter in self.meters.items() 68 | } 69 | 70 | def averages(self, format_string="{}"): 71 | return { 72 | format_string.format(name): meter.avg for name, meter in self.meters.items() 73 | } 74 | 75 | 76 | class AverageMeter(object): 77 | """ 78 | Tracks and calculates the average and current values of a series of numbers. 79 | """ 80 | 81 | def __init__(self): 82 | self.val = 0 83 | self.avg = 0 84 | self.sum = 0 85 | self.count = 0 86 | 87 | def reset(self): 88 | self.val = 0 89 | self.avg = 0 90 | self.sum = 0 91 | self.count = 0 92 | 93 | def update(self, val, n=1): 94 | # TODO: description for using n 95 | self.val = val 96 | self.sum += val * n 97 | self.count += n 98 | self.avg = self.sum / self.count 99 | 100 | def __format__(self, format): 101 | return "{self.val:{format}} ({self.avg:{format}})".format( 102 | self=self, format=format 103 | ) 104 | -------------------------------------------------------------------------------- /docs/dataset/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform 
.3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:0.4rem 1rem;text-align:center}.slider-pagination .slider-page{background:#648ef6;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /scale_rl/agents/simba/simba_network.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | from tensorflow_probability.substrates import jax as tfp 4 | 5 | from scale_rl.agents.simba.simba_layer import ( 6 | LinearCritic, 7 | NormalTanhPolicy, 8 | PreLNResidualBlock, 9 | ) 10 | 11 | tfd = tfp.distributions 12 | tfb = tfp.bijectors 13 | 14 | 15 | class SimbaActor(nn.Module): 16 | num_blocks: int 17 | hidden_dim: int 18 | action_dim: int 19 | 20 | def setup(self): 21 | self.embedder = nn.Dense( 22 | self.hidden_dim, kernel_init=nn.initializers.orthogonal(1.0) 23 | ) 24 | self.encoder = nn.Sequential( 25 | [ 26 | *[PreLNResidualBlock(hidden_dim=self.hidden_dim) 27 | for _ in range(self.num_blocks)], 28 | nn.LayerNorm(), 29 | ] 30 | ) 31 | self.predictor = NormalTanhPolicy(self.action_dim) 32 | 33 | def __call__( 34 | self, 35 | observations: jnp.ndarray, 36 | temperature: float = 1.0, 37 | ) -> tfd.Distribution: 38 | x = observations 39 | y = self.embedder(x) 40 | z = self.encoder(y) 41 | dist, info = self.predictor(z, temperature) 42 | return dist, info 43 | 44 | 45 | class SimbaCritic(nn.Module): 46 | num_blocks: int 47 | hidden_dim: int 48 | 49 | def setup(self): 50 | self.embedder = nn.Dense( 51 | self.hidden_dim, kernel_init=nn.initializers.orthogonal(1.0) 52 | ) 53 | self.encoder = nn.Sequential( 54 | [ 55 | *[PreLNResidualBlock(hidden_dim=self.hidden_dim) 56 | for _ in range(self.num_blocks)], 57 | nn.LayerNorm(), 58 | ] 59 | ) 60 | self.predictor = LinearCritic() 61 | 62 | def __call__( 63 | self, 64 | observations: jnp.ndarray, 65 | 
actions: jnp.ndarray, 66 | ) -> jnp.ndarray: 67 | x = jnp.concatenate((observations, actions), axis=1) 68 | y = self.embedder(x) 69 | z = self.encoder(y) 70 | q, info = self.predictor(z) 71 | return q, info 72 | 73 | 74 | class SimbaDoubleCritic(nn.Module): 75 | """ 76 | Vectorized Double-Q for Clipped Double Q-learning. 77 | https://arxiv.org/pdf/1802.09477v3 78 | """ 79 | 80 | num_blocks: int 81 | hidden_dim: int 82 | 83 | num_qs: int = 2 84 | 85 | @nn.compact 86 | def __call__( 87 | self, 88 | observations: jnp.ndarray, 89 | actions: jnp.ndarray, 90 | ) -> jnp.ndarray: 91 | VmapCritic = nn.vmap( 92 | SimbaCritic, 93 | variable_axes={"params": 0}, 94 | split_rngs={"params": True}, 95 | in_axes=None, 96 | out_axes=0, 97 | axis_size=self.num_qs, 98 | ) 99 | 100 | qs, infos = VmapCritic( 101 | num_blocks=self.num_blocks, 102 | hidden_dim=self.hidden_dim, 103 | )(observations, actions) 104 | 105 | return qs, infos 106 | 107 | 108 | class SimbaTemperature(nn.Module): 109 | initial_value: float = 1.0 110 | 111 | @nn.compact 112 | def __call__(self) -> jnp.ndarray: 113 | log_temp = self.param( 114 | name="log_temp", 115 | init_fn=lambda key: jnp.full( 116 | shape=(), fill_value=jnp.log(self.initial_value) 117 | ), 118 | ) 119 | return jnp.exp(log_temp) 120 | -------------------------------------------------------------------------------- /scale_rl/agents/base_agent.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Dict, TypeVar 3 | 4 | import gymnasium as gym 5 | import numpy as np 6 | 7 | Config = TypeVar("Config") 8 | 9 | 10 | class BaseAgent(ABC): 11 | def __init__( 12 | self, 13 | observation_space: gym.spaces.Space, 14 | action_space: gym.spaces.Space, 15 | cfg: Config, 16 | ): 17 | """ 18 | A generic agent class. 19 | """ 20 | self._observation_space = observation_space 21 | self._action_space = action_space 22 | self._cfg = cfg 23 | 24 | @abstractmethod 25 | def sample_actions( 26 | self, 27 | interaction_step: int, 28 | prev_timestep: Dict[str, np.ndarray], 29 | training: bool, 30 | ) -> np.ndarray: 31 | pass 32 | 33 | @abstractmethod 34 | def update(self, update_step: int, batch: Dict[str, np.ndarray]) -> Dict: 35 | pass 36 | 37 | @abstractmethod 38 | def save(self, path: str) -> None: 39 | pass 40 | 41 | @abstractmethod 42 | def load(self, path: str) -> None: 43 | pass 44 | 45 | # @abstractmethod 46 | def get_metrics( 47 | self, batch: Dict[str, np.ndarray], update_info: Dict[str, Any] 48 | ) -> Dict: 49 | pass 50 | 51 | 52 | class AgentWrapper(BaseAgent): 53 | """Wraps the agent to allow a modular transformation. 54 | 55 | This class is the base class for all wrappers for agent class. 56 | The subclass could override some methods to change the behavior of the original agent 57 | without touching the original code. 58 | 59 | Note: 60 | Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`. 
61 | """ 62 | 63 | def __init__(self, agent: BaseAgent): 64 | self.agent = agent 65 | 66 | # explicitly forward the methods defined in Agent to self.agent 67 | def sample_actions( 68 | self, 69 | interaction_step: int, 70 | prev_timestep: Dict[str, np.ndarray], 71 | training: bool, 72 | ) -> np.ndarray: 73 | return self.agent.sample_actions( 74 | interaction_step=interaction_step, 75 | prev_timestep=prev_timestep, 76 | training=training, 77 | ) 78 | 79 | def update(self, update_step: int, batch: Dict[str, np.ndarray]) -> Dict: 80 | return self.agent.update( 81 | update_step=update_step, 82 | batch=batch, 83 | ) 84 | 85 | def get_metrics( 86 | self, batch: Dict[str, np.ndarray], update_info: Dict[str, Any] 87 | ) -> Dict: 88 | return self.agent.get_metrics( 89 | batch=batch, 90 | update_info=update_info, 91 | ) 92 | 93 | def save(self, path: str) -> None: 94 | self.agent.save(path) 95 | 96 | def load(self, path: str) -> None: 97 | self.agent.load(path) 98 | 99 | def set_attr(self, name, values): 100 | return self.agent.set_attr(name, values) 101 | 102 | # implicitly forward all other methods and attributes to self.env 103 | def __getattr__(self, name): 104 | if name.startswith("_"): 105 | raise AttributeError(f"attempted to get missing private attribute '{name}'") 106 | """ 107 | logger.warn( 108 | f"env.{name} to get variables from other wrappers is deprecated and will be removed in v1.0, " 109 | f"to get this variable you can do `env.unwrapped.{name}` for environment variables." 110 | ) 111 | """ 112 | return getattr(self.agent, name) 113 | 114 | @property 115 | def unwrapped(self): 116 | return self.agent.unwrapped 117 | 118 | def __repr__(self): 119 | return f"<{self.__class__.__name__}, {self.agent}>" 120 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | 142 | # pytype static type analyzer 143 | .pytype/ 144 | 145 | # Cython debug symbols 146 | cython_debug/ 147 | 148 | # PyCharm 149 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 150 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 151 | # and can be added to the global gitignore or merged into this file. For a more nuclear 152 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
153 | #.idea/ 154 | 155 | # log 156 | wandb/** 157 | tmp/** 158 | 159 | # cache 160 | *.ipynb_checkpoints 161 | *.nfs* 162 | 163 | # saved checkpoints 164 | models/** -------------------------------------------------------------------------------- /scale_rl/envs/__init__.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from typing import Tuple, Any 3 | from gymnasium.vector import SyncVectorEnv, AsyncVectorEnv, VectorEnv 4 | from gymnasium.wrappers import RescaleAction, TimeLimit 5 | 6 | from scale_rl.envs.dmc import make_dmc_env 7 | from scale_rl.envs.mujoco import make_mujoco_env 8 | from scale_rl.envs.humanoid_bench import make_humanoid_env 9 | from scale_rl.envs.myosuite import make_myosuite_env 10 | from scale_rl.envs.d4rl import make_d4rl_env, make_d4rl_dataset, get_d4rl_normalized_score 11 | from scale_rl.envs.wrappers import RepeatAction 12 | 13 | 14 | def create_envs( 15 | env_type: str, 16 | seed: int, 17 | env_name: str, 18 | num_train_envs: int, 19 | num_eval_envs: int, 20 | rescale_action: bool, 21 | action_repeat: int, 22 | max_episode_steps: int, 23 | **kwargs, 24 | )-> Tuple[VectorEnv, VectorEnv]: 25 | 26 | train_env = create_vec_env( 27 | env_type=env_type, 28 | env_name=env_name, 29 | seed=seed, 30 | num_envs=num_train_envs, 31 | action_repeat=action_repeat, 32 | rescale_action=rescale_action, 33 | max_episode_steps=max_episode_steps, 34 | ) 35 | eval_env = create_vec_env( 36 | env_type=env_type, 37 | env_name=env_name, 38 | seed=seed, 39 | num_envs=num_eval_envs, 40 | action_repeat=action_repeat, 41 | rescale_action=rescale_action, 42 | max_episode_steps=max_episode_steps, 43 | ) 44 | 45 | return train_env, eval_env 46 | 47 | 48 | def create_vec_env( 49 | env_type: str, 50 | env_name: str, 51 | num_envs: int, 52 | seed: int, 53 | rescale_action: bool = True, 54 | action_repeat: int = 1, 55 | max_episode_steps: int = 1000, 56 | ) -> VectorEnv: 57 | 58 | def make_one_env( 59 | env_type: str, 60 | env_name:str, 61 | seed:int, 62 | rescale_action:bool, 63 | action_repeat:int, 64 | max_episode_steps: int, 65 | **kwargs 66 | ) -> gym.Env: 67 | 68 | if env_type == 'dmc': 69 | env = make_dmc_env(env_name, seed, **kwargs) 70 | elif env_type == 'mujoco': 71 | env = make_mujoco_env(env_name, seed, **kwargs) 72 | elif env_type == 'humanoid_bench': 73 | env = make_humanoid_env(env_name, seed, **kwargs) 74 | elif env_type == 'myosuite': 75 | env = make_myosuite_env(env_name, seed, **kwargs) 76 | elif env_type == "d4rl": 77 | env = make_d4rl_env(env_name, seed, **kwargs) 78 | else: 79 | raise NotImplementedError 80 | 81 | if rescale_action: 82 | env = RescaleAction(env, -1.0, 1.0) 83 | 84 | # limit max_steps before action_repeat. 
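        # (Since RepeatAction wraps TimeLimit below, the limit is counted in raw env steps;
        #  assuming RepeatAction steps the inner env `action_repeat` times per call, an episode
        #  spans roughly max_episode_steps / action_repeat interaction steps, which matches
        #  `eff_episode_len` in configs/online_rl.yaml.)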
85 | env = TimeLimit(env, max_episode_steps) 86 | 87 | if action_repeat > 1: 88 | env = RepeatAction(env, action_repeat) 89 | 90 | env.observation_space.seed(seed) 91 | env.action_space.seed(seed) 92 | 93 | return env 94 | 95 | env_fns = [ 96 | ( 97 | lambda i=i: make_one_env( 98 | env_type=env_type, 99 | env_name=env_name, 100 | seed=seed + i, 101 | rescale_action=rescale_action, 102 | action_repeat=action_repeat, 103 | max_episode_steps=max_episode_steps, 104 | ) 105 | ) 106 | for i in range(num_envs) 107 | ] 108 | if len(env_fns) > 1: 109 | envs = AsyncVectorEnv(env_fns, autoreset_mode='SameStep') 110 | else: 111 | envs = SyncVectorEnv(env_fns, autoreset_mode='SameStep') 112 | 113 | return envs 114 | 115 | 116 | def create_dataset(env_type: str, env_name: str) -> list[dict[str, Any]]: 117 | if env_type == 'd4rl': 118 | dataset = make_d4rl_dataset(env_name) 119 | else: 120 | raise NotImplementedError 121 | return dataset 122 | 123 | 124 | def get_normalized_score(env_type: str, env_name: str, unnormalized_score: float) -> float: 125 | if env_type == "d4rl": 126 | score = get_d4rl_normalized_score(env_name, unnormalized_score) 127 | else: 128 | raise NotImplementedError 129 | return score 130 | -------------------------------------------------------------------------------- /scale_rl/evaluation.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import gymnasium as gym 4 | import numpy as np 5 | from gymnasium.vector import VectorEnv 6 | 7 | import wandb 8 | 9 | 10 | def evaluate( 11 | agent, 12 | env: VectorEnv, 13 | num_episodes: int, 14 | ) -> Dict[str, float]: 15 | n = env.num_envs 16 | 17 | assert num_episodes % n == 0, "num_episodes must be divisible by env.num_envs" 18 | num_eval_episodes_per_env = num_episodes // n 19 | 20 | total_returns = [] 21 | total_successes = [] 22 | total_lengths = [] 23 | 24 | for _ in range(num_eval_episodes_per_env): 25 | returns = np.zeros(n) 26 | lengths = np.zeros(n) 27 | successes = np.zeros(n) 28 | 29 | observations, infos = env.reset() 30 | 31 | prev_timestep = {"next_observation": observations} 32 | 33 | dones = np.zeros(n) 34 | while np.sum(dones) < n: 35 | actions = agent.sample_actions( 36 | interaction_step=0, 37 | prev_timestep=prev_timestep, 38 | training=False, 39 | ) 40 | next_observations, rewards, terminateds, truncateds, infos = env.step( 41 | actions 42 | ) 43 | 44 | prev_timestep = {"next_observation": next_observations} 45 | 46 | returns += rewards * (1 - dones) 47 | lengths += 1 - dones 48 | 49 | if "success" in infos: 50 | successes += infos["success"].astype("float") * (1 - dones) 51 | 52 | elif "final_info" in infos: 53 | final_successes = np.zeros(n) 54 | for idx in range(n): 55 | final_info = infos["final_info"] 56 | 57 | if "success" in final_info: 58 | try: 59 | final_successes[idx] = final_info["success"][idx].astype( 60 | "float" 61 | ) 62 | except: 63 | final_successes[idx] = np.array( 64 | final_info["success"][idx] 65 | ).astype("float") 66 | successes += final_successes * (1 - dones) 67 | 68 | else: 69 | pass 70 | 71 | # once an episode is done in a sub-environment, we assume it to be done. 72 | # also, we assume to be done whether it is terminated or truncated during evaluation. 
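            # `dones` also acts as a mask: once a sub-environment finishes, the (1 - dones)
            # factors above stop accumulating its return, length, and success for the rest
            # of the rollout.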
73 | dones = np.maximum(dones, terminateds) 74 | dones = np.maximum(dones, truncateds) 75 | 76 | # proceed 77 | observations = next_observations 78 | 79 | for env_idx in range(n): 80 | total_returns.append(returns[env_idx]) 81 | total_lengths.append(lengths[env_idx]) 82 | total_successes.append(successes[env_idx].astype("bool").astype("float")) 83 | 84 | eval_info = { 85 | "avg_return": np.mean(total_returns), 86 | "avg_length": np.mean(total_lengths), 87 | "avg_success": np.mean(total_successes), 88 | } 89 | 90 | return eval_info 91 | 92 | 93 | def record_video( 94 | agent, 95 | env: VectorEnv, 96 | num_episodes: int, 97 | video_length: int = 100, 98 | ) -> Dict[str, float]: 99 | n = env.num_envs 100 | assert num_episodes % n == 0, "num_episodes must be divisible by env.num_envs" 101 | num_eval_episodes_per_env = num_episodes // n 102 | 103 | total_videos = [] 104 | 105 | for _ in range(num_eval_episodes_per_env): 106 | videos = [] 107 | 108 | observations, infos = env.reset() 109 | prev_timestep = {"next_observation": observations} 110 | images = env.call("render") 111 | dones = np.zeros(n) 112 | while np.sum(dones) < n: 113 | actions = agent.sample_actions( 114 | interaction_step=0, 115 | prev_timestep=prev_timestep, 116 | training=False, 117 | ) 118 | next_observations, rewards, terminateds, truncateds, infos = env.step( 119 | actions 120 | ) 121 | 122 | prev_timestep = {"next_observation": next_observations} 123 | 124 | # once an episode is done in a sub-environment, we assume it to be done. 125 | dones = np.maximum(dones, terminateds) 126 | dones = np.maximum(dones, truncateds) 127 | 128 | # proceed 129 | videos.append(images) 130 | images = env.call("render") 131 | observations = next_observations 132 | 133 | total_videos.append(np.stack(videos, axis=1)) # (n, t, c, h, w) 134 | 135 | total_videos = np.concatenate(total_videos, axis=0) # (b, t, h, w, c) 136 | total_videos = total_videos[:, :video_length] 137 | total_videos = total_videos.transpose(0, 1, 4, 2, 3) # (b, t, c, h, w) 138 | 139 | video_info = {"video": wandb.Video(total_videos, fps=10, format="gif")} 140 | 141 | return video_info 142 | -------------------------------------------------------------------------------- /scale_rl/buffers/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def fast_uniform_sample(max_size: int, num_samples: int): 5 | """ 6 | for speed comparison of uniform sampling, refer to analysis/benchmark_buffer_speed.ipynb 7 | """ 8 | interval = max_size // num_samples 9 | if max_size % num_samples == 0: 10 | return np.arange(0, max_size, interval) + np.random.randint( 11 | 0, interval, size=num_samples 12 | ) 13 | else: 14 | return np.arange(0, max_size, interval)[:-1] + np.random.randint( 15 | 0, interval, size=num_samples 16 | ) 17 | 18 | 19 | # Segment tree data structure where parent node values are sum/max of children node values 20 | class SegmentTree: 21 | def __init__(self, size): 22 | self._index = 0 23 | self._size = size 24 | self._full = False # Used to track actual capacity 25 | self._tree_start_idx = ( 26 | 2 ** (size - 1).bit_length() - 1 27 | ) # Put all used node leaves on last tree level 28 | self._sum_tree = np.zeros( 29 | (self._tree_start_idx + self._size,), dtype=np.float32 30 | ) 31 | self._max = ( 32 | 1 # Initial max value to return (1 = 1^ω), default priority is set to max 33 | ) 34 | 35 | # Updates nodes values from current tree 36 | def _update_nodes(self, indices): 37 | children_indices = indices * 2 + 
np.expand_dims([1, 2], axis=1) 38 | self._sum_tree[indices] = np.sum(self._sum_tree[children_indices], axis=0) 39 | 40 | # Propagates changes up tree given tree indices 41 | def _propagate(self, indices): 42 | parents = (indices - 1) // 2 43 | unique_parents = np.unique(parents) 44 | self._update_nodes(unique_parents) 45 | if parents[0] != 0: 46 | self._propagate(parents) 47 | 48 | # Propagates single value up tree given a tree index for efficiency 49 | def _propagate_index(self, index): 50 | parent = (index - 1) // 2 51 | left, right = 2 * parent + 1, 2 * parent + 2 52 | self._sum_tree[parent] = self._sum_tree[left] + self._sum_tree[right] 53 | if parent != 0: 54 | self._propagate_index(parent) 55 | 56 | # Updates values given tree indices 57 | def update(self, indices, values): 58 | self._sum_tree[indices] = values # Set new values 59 | self._propagate(indices) # Propagate values 60 | current_max_value = np.max(values) 61 | self._max = max(current_max_value, self._max) 62 | 63 | # Updates single value given a tree index for efficiency 64 | def _update_index(self, index, value): 65 | self._sum_tree[index] = value # Set new value 66 | self._propagate_index(index) # Propagate value 67 | self._max = max(value, self._max) 68 | 69 | def add(self, value): 70 | self._update_index(self._index + self._tree_start_idx, value) # Update tree 71 | self._index = (self._index + 1) % self._size # Update index 72 | self._full = self._full or self._index == 0 # Save when capacity reached 73 | self._max = max(value, self._max) 74 | 75 | # Searches for the location of values in sum tree 76 | def _retrieve(self, indices, values): 77 | children_indices = indices * 2 + np.expand_dims( 78 | [1, 2], axis=1 79 | ) # Make matrix of children indices 80 | # If indices correspond to leaf nodes, return them 81 | if children_indices[0, 0] >= self._sum_tree.shape[0]: 82 | return indices 83 | # If children indices correspond to leaf nodes, bound rare outliers in case total slightly overshoots 84 | elif children_indices[0, 0] >= self._tree_start_idx: 85 | children_indices = np.minimum(children_indices, self._sum_tree.shape[0] - 1) 86 | left_children_values = self._sum_tree[children_indices[0]] 87 | successor_choices = np.greater(values, left_children_values).astype( 88 | np.int32 89 | ) # Classify which values are in left or right branches 90 | successor_indices = children_indices[ 91 | successor_choices, np.arange(indices.size) 92 | ] # Use classification to index into the indices matrix 93 | successor_values = ( 94 | values - successor_choices * left_children_values 95 | ) # Subtract the left branch values when searching in the right branch 96 | return self._retrieve(successor_indices, successor_values) 97 | 98 | # Searches for values in sum tree and returns values, data indices and tree indices 99 | def find(self, values): 100 | indices = self._retrieve(np.zeros(values.shape, dtype=np.int32), values) 101 | data_index = indices - self._tree_start_idx 102 | return ( 103 | data_index, 104 | indices, 105 | self._sum_tree[indices], 106 | ) # Return values, data indices, tree indices 107 | 108 | @property 109 | def total(self): 110 | return self._sum_tree[0] 111 | 112 | @property 113 | def max(self): 114 | return self._max 115 | -------------------------------------------------------------------------------- /scale_rl/agents/simbaV2/simbaV2_network.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | from 
tensorflow_probability.substrates import jax as tfp 4 | 5 | from scale_rl.agents.simbaV2.simbaV2_layer import ( 6 | HyperCategoricalValue, 7 | HyperEmbedder, 8 | HyperLERPBlock, 9 | HyperNormalTanhPolicy, 10 | ) 11 | 12 | tfd = tfp.distributions 13 | tfb = tfp.bijectors 14 | 15 | 16 | class SimbaV2Actor(nn.Module): 17 | num_blocks: int 18 | hidden_dim: int 19 | action_dim: int 20 | scaler_init: float 21 | scaler_scale: float 22 | alpha_init: float 23 | alpha_scale: float 24 | c_shift: float 25 | 26 | def setup(self): 27 | self.embedder = HyperEmbedder( 28 | hidden_dim=self.hidden_dim, 29 | scaler_init=self.scaler_init, 30 | scaler_scale=self.scaler_scale, 31 | c_shift=self.c_shift, 32 | ) 33 | self.encoder = nn.Sequential( 34 | [ 35 | HyperLERPBlock( 36 | hidden_dim=self.hidden_dim, 37 | scaler_init=self.scaler_init, 38 | scaler_scale=self.scaler_scale, 39 | alpha_init=self.alpha_init, 40 | alpha_scale=self.alpha_scale, 41 | ) 42 | for _ in range(self.num_blocks) 43 | ] 44 | ) 45 | self.predictor = HyperNormalTanhPolicy( 46 | hidden_dim=self.hidden_dim, 47 | action_dim=self.action_dim, 48 | scaler_init=1.0, 49 | scaler_scale=1.0, 50 | ) 51 | 52 | def __call__( 53 | self, 54 | observations: jnp.ndarray, 55 | temperature: float = 1.0, 56 | ) -> tfd.Distribution: 57 | x = observations 58 | y = self.embedder(x) 59 | z = self.encoder(y) 60 | dist, info = self.predictor(z, temperature) 61 | 62 | return dist, info 63 | 64 | 65 | class SimbaV2Critic(nn.Module): 66 | num_blocks: int 67 | hidden_dim: int 68 | scaler_init: float 69 | scaler_scale: float 70 | alpha_init: float 71 | alpha_scale: float 72 | c_shift: float 73 | num_bins: int 74 | min_v: float 75 | max_v: float 76 | 77 | def setup(self): 78 | self.embedder = HyperEmbedder( 79 | hidden_dim=self.hidden_dim, 80 | scaler_init=self.scaler_init, 81 | scaler_scale=self.scaler_scale, 82 | c_shift=self.c_shift, 83 | ) 84 | self.encoder = nn.Sequential( 85 | [ 86 | HyperLERPBlock( 87 | hidden_dim=self.hidden_dim, 88 | scaler_init=self.scaler_init, 89 | scaler_scale=self.scaler_scale, 90 | alpha_init=self.alpha_init, 91 | alpha_scale=self.alpha_scale, 92 | ) 93 | for _ in range(self.num_blocks) 94 | ] 95 | ) 96 | 97 | self.predictor = HyperCategoricalValue( 98 | hidden_dim=self.hidden_dim, 99 | num_bins=self.num_bins, 100 | min_v=self.min_v, 101 | max_v=self.max_v, 102 | scaler_init=1.0, 103 | scaler_scale=1.0, 104 | ) 105 | 106 | def __call__( 107 | self, 108 | observations: jnp.ndarray, 109 | actions: jnp.ndarray, 110 | ) -> jnp.ndarray: 111 | x = jnp.concatenate((observations, actions), axis=1) 112 | y = self.embedder(x) 113 | z = self.encoder(y) 114 | q, info = self.predictor(z) 115 | return q, info 116 | 117 | 118 | class SimbaV2DoubleCritic(nn.Module): 119 | """ 120 | Vectorized Double-Q for Clipped Double Q-learning. 
121 | https://arxiv.org/pdf/1802.09477v3 122 | """ 123 | 124 | num_blocks: int 125 | hidden_dim: int 126 | scaler_init: float 127 | scaler_scale: float 128 | alpha_init: float 129 | alpha_scale: float 130 | c_shift: float 131 | num_bins: int 132 | min_v: float 133 | max_v: float 134 | 135 | num_qs: int = 2 136 | 137 | @nn.compact 138 | def __call__( 139 | self, 140 | observations: jnp.ndarray, 141 | actions: jnp.ndarray, 142 | ) -> jnp.ndarray: 143 | VmapCritic = nn.vmap( 144 | SimbaV2Critic, 145 | variable_axes={"params": 0}, 146 | split_rngs={"params": True}, 147 | in_axes=None, 148 | out_axes=0, 149 | axis_size=self.num_qs, 150 | ) 151 | 152 | qs, infos = VmapCritic( 153 | num_blocks=self.num_blocks, 154 | hidden_dim=self.hidden_dim, 155 | scaler_init=self.scaler_init, 156 | scaler_scale=self.scaler_scale, 157 | alpha_init=self.alpha_init, 158 | alpha_scale=self.alpha_scale, 159 | c_shift=self.c_shift, 160 | num_bins=self.num_bins, 161 | min_v=self.min_v, 162 | max_v=self.max_v, 163 | )(observations, actions) 164 | 165 | return qs, infos 166 | 167 | 168 | class SimbaV2Temperature(nn.Module): 169 | initial_value: float = 0.01 170 | 171 | @nn.compact 172 | def __call__(self) -> jnp.ndarray: 173 | log_temp = self.param( 174 | name="log_temp", 175 | init_fn=lambda key: jnp.full( 176 | shape=(), fill_value=jnp.log(self.initial_value) 177 | ), 178 | ) 179 | return jnp.exp(log_temp) 180 | -------------------------------------------------------------------------------- /scale_rl/envs/humanoid_bench.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | 3 | ################################# 4 | # 5 | # Original Humanoid 6 | # 7 | ################################# 8 | 9 | # 14 tasks 10 | HB_LOCOMOTION = [ 11 | "h1hand-walk-v0", 12 | "h1hand-stand-v0", 13 | "h1hand-run-v0", 14 | "h1hand-reach-v0", 15 | "h1hand-hurdle-v0", 16 | "h1hand-crawl-v0", 17 | "h1hand-maze-v0", 18 | "h1hand-sit_simple-v0", 19 | "h1hand-sit_hard-v0", 20 | "h1hand-balance_simple-v0", 21 | "h1hand-balance_hard-v0", 22 | "h1hand-stair-v0", 23 | "h1hand-slide-v0", 24 | "h1hand-pole-v0", 25 | ] 26 | 27 | # 17 tasks 28 | HB_MANIPULATION = [ 29 | "h1hand-push-v0", 30 | "h1hand-cabinet-v0", 31 | "h1strong-highbar_hard-v0", # Make hands stronger to be able to hang from the high bar 32 | "h1hand-door-v0", 33 | "h1hand-truck-v0", 34 | "h1hand-cube-v0", 35 | "h1hand-bookshelf_simple-v0", 36 | "h1hand-bookshelf_hard-v0", 37 | "h1hand-basketball-v0", 38 | "h1hand-window-v0", 39 | "h1hand-spoon-v0", 40 | "h1hand-kitchen-v0", 41 | "h1hand-package-v0", 42 | "h1hand-powerlift-v0", 43 | "h1hand-room-v0", 44 | "h1hand-insert_small-v0", 45 | "h1hand-insert_normal-v0", 46 | ] 47 | 48 | 49 | ################################# 50 | # 51 | # No Hand Humanoid 52 | # 53 | ################################# 54 | 55 | HB_LOCOMOTION_NOHAND = [ 56 | "h1-walk-v0", 57 | "h1-stand-v0", 58 | "h1-run-v0", 59 | "h1-reach-v0", 60 | "h1-hurdle-v0", 61 | "h1-crawl-v0", 62 | "h1-maze-v0", 63 | "h1-sit_simple-v0", 64 | "h1-sit_hard-v0", 65 | "h1-balance_simple-v0", 66 | "h1-balance_hard-v0", 67 | "h1-stair-v0", 68 | "h1-slide-v0", 69 | "h1-pole-v0", 70 | ] 71 | 72 | HB_LOCOMOTION_NOHAND_MINI = [ 73 | "h1-walk-v0", 74 | "h1-run-v0", 75 | "h1-sit_hard-v0", 76 | "h1-balance_simple-v0", 77 | "h1-stair-v0", 78 | ] 79 | 80 | ################################# 81 | # 82 | # Task Success scores 83 | # 84 | ################################# 85 | 86 | # 10 seeds, 10 eval envs per 1 seed 87 | HB_RANDOM_SCORE = { 88 | 
"h1-walk-v0": 2.377, 89 | "h1-stand-v0": 10.545, 90 | "h1-run-v0": 2.02, 91 | "h1-reach-v0": 260.302, 92 | "h1-hurdle-v0": 2.214, 93 | "h1-crawl-v0": 272.658, 94 | "h1-maze-v0": 106.441, 95 | "h1-sit_simple-v0": 9.393, 96 | "h1-sit_hard-v0": 2.448, 97 | "h1-balance_simple-v0": 9.391, 98 | "h1-balance_hard-v0": 9.044, 99 | "h1-stair-v0": 3.112, 100 | "h1-slide-v0": 3.191, 101 | "h1-pole-v0": 20.09, 102 | "h1hand-push-v0": -526.8, 103 | "h1hand-cabinet-v0": 37.733, 104 | "h1strong-highbar-hard-v0": 0.178, 105 | "h1hand-door-v0": 2.771, 106 | "h1hand-truck-v0": 562.419, 107 | "h1hand-cube-v0": 4.787, 108 | "h1hand-bookshelf_simple-v0": 16.777, 109 | "h1hand-bookshelf_hard-v0": 14.848, 110 | "h1hand-basketball-v0": 8.979, 111 | "h1hand-window-v0": 2.713, 112 | "h1hand-spoon-v0": 4.661, 113 | "h1hand-kitchen-v0": 0.0, 114 | "h1hand-package-v0": -10040.932, 115 | "h1hand-powerlift-v0": 17.638, 116 | "h1hand-room-v0": 3.018, 117 | "h1hand-insert_small-v0": 1.653, 118 | "h1hand-insert_normal-v0": 1.673, 119 | "h1hand-walk-v0": 2.505, 120 | "h1hand-stand-v0": 11.973, 121 | "h1hand-run-v0": 1.927, 122 | "h1hand-reach-v0": -50.024, 123 | "h1hand-hurdle-v0": 2.371, 124 | "h1hand-crawl-v0": 278.868, 125 | "h1hand-maze-v0": 106.233, 126 | "h1hand-sit_simple-v0": 10.768, 127 | "h1hand-sit_hard-v0": 2.477, 128 | "h1hand-balance_simple-v0": 10.17, 129 | "h1hand-balance_hard-v0": 10.032, 130 | "h1hand-stair-v0": 3.161, 131 | "h1hand-slide-v0": 3.142, 132 | "h1hand-pole-v0": 19.721, 133 | } 134 | 135 | HB_SUCCESS_SCORE = { 136 | "h1-walk-v0": 700.0, 137 | "h1-stand-v0": 800.0, 138 | "h1-run-v0": 700.0, 139 | "h1-reach-v0": 12000.0, 140 | "h1-hurdle-v0": 700.0, 141 | "h1-crawl-v0": 700.0, 142 | "h1-maze-v0": 1200.0, 143 | "h1-sit_simple-v0": 750.0, 144 | "h1-sit_hard-v0": 750.0, 145 | "h1-balance_simple-v0": 800.0, 146 | "h1-balance_hard-v0": 800.0, 147 | "h1-stair_v0": 700.0, 148 | "h1-slide_v0": 700.0, 149 | "h1-pole_v0": 700.0, 150 | "h1-push_v0": 700.0, 151 | "h1-cabinet_v0": 2500.0, 152 | "h1-highbar_v0": 750.0, 153 | "h1-door_v0": 600.0, 154 | "h1-truck_v0": 3000.0, 155 | "h1-cube_v0": 370.0, 156 | "h1-bookshelf_simple_v0": 2000.0, 157 | "h1-bookshelf_hard_v0": 2000.0, 158 | "h1-basketball_v0": 1200.0, 159 | "h1-window_v0": 650.0, 160 | "h1-spoon_v0": 650.0, 161 | "h1-kitchen_v0": 4.0, 162 | "h1-package_v0": 1500.0, 163 | "h1-powerlift_v0": 800.0, 164 | "h1-room_v0": 400.0, 165 | "h1-insert_small_v0": 350.0, 166 | "h1-insert_normal_v0": 350.0, 167 | } 168 | 169 | 170 | class HBGymnasiumVersionWrapper(gym.Wrapper): 171 | """ 172 | humanoid bench originally requires gymnasium==0.29.1 173 | however, we are currently using gymnasium==1.0.0a2, 174 | hence requiring some minor fix to the rendering function 175 | """ 176 | 177 | def __init__(self, env: gym.Env): 178 | super().__init__(env) 179 | self.task = env.unwrapped.task 180 | 181 | def render(self): 182 | return self.task._env.mujoco_renderer.render(self.task._env.render_mode) 183 | 184 | 185 | def make_humanoid_env( 186 | env_name: str, 187 | seed: int, 188 | ) -> gym.Env: 189 | import humanoid_bench 190 | 191 | additional_kwargs = {} 192 | if env_name == "h1hand-package-v0": 193 | additional_kwargs = {"policy_path": None} 194 | env = gym.make(env_name, **additional_kwargs) 195 | env = HBGymnasiumVersionWrapper(env) 196 | 197 | return env 198 | -------------------------------------------------------------------------------- /scale_rl/agents/simba/simba_update.py: -------------------------------------------------------------------------------- 1 | from 
typing import Any, Dict, Tuple 2 | 3 | import flax 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from scale_rl.agents.jax_utils.network import Network, PRNGKey 8 | from scale_rl.buffers import Batch 9 | 10 | 11 | def update_actor( 12 | key: PRNGKey, 13 | actor: Network, 14 | critic: Network, 15 | temperature: Network, 16 | batch: Batch, 17 | use_cdq: bool, 18 | bc_alpha: float, 19 | ) -> Tuple[Network, Dict[str, float]]: 20 | def actor_loss_fn( 21 | actor_params: flax.core.FrozenDict[str, Any], 22 | ) -> Tuple[jnp.ndarray, Dict[str, float]]: 23 | dist, _ = actor.apply( 24 | variables={"params": actor_params}, 25 | observations=batch["observation"], 26 | ) 27 | actions = dist.sample(seed=key) 28 | log_probs = dist.log_prob(actions) 29 | 30 | if use_cdq: 31 | # qs: (2, n) 32 | qs, q_infos = critic(observations=batch["observation"], actions=actions) 33 | q = jnp.minimum(qs[0], qs[1]) 34 | else: 35 | q, _ = critic(observations=batch["observation"], actions=actions) 36 | 37 | actor_loss = (log_probs * temperature() - q).mean() 38 | 39 | if bc_alpha > 0: 40 | # https://arxiv.org/abs/2306.02451 41 | q_abs = jax.lax.stop_gradient(jnp.abs(q).mean()) 42 | bc_loss = ((actions - batch["action"]) ** 2).mean() 43 | actor_loss = actor_loss + bc_alpha * q_abs * bc_loss 44 | 45 | actor_info = { 46 | "actor/loss": actor_loss, 47 | "actor/entropy": -log_probs.mean(), 48 | "actor/mean_action": jnp.mean(actions), 49 | } 50 | return actor_loss, actor_info 51 | 52 | actor, info = actor.apply_gradient(actor_loss_fn) 53 | 54 | return actor, info 55 | 56 | 57 | def update_critic( 58 | key: PRNGKey, 59 | actor: Network, 60 | critic: Network, 61 | target_critic: Network, 62 | temperature: Network, 63 | batch: Batch, 64 | use_cdq: bool, 65 | gamma: float, 66 | n_step: int, 67 | ) -> Tuple[Network, Dict[str, float]]: 68 | # compute the target q-value 69 | next_dist, _ = actor(observations=batch["next_observation"]) 70 | next_actions = next_dist.sample(seed=key) 71 | next_actor_log_probs = next_dist.log_prob(next_actions) 72 | next_actor_entropy = temperature() * next_actor_log_probs 73 | 74 | if use_cdq: 75 | # next_qs: (2, n) 76 | next_qs, next_q_infos = target_critic( 77 | observations=batch["next_observation"], actions=next_actions 78 | ) 79 | next_q = jnp.minimum(next_qs[0], next_qs[1]) 80 | else: 81 | next_q, next_q_info = target_critic( 82 | observations=batch["next_observation"], 83 | actions=next_actions, 84 | ) 85 | 86 | # compute the td-target, incorporating the n-step accumulated reward 87 | # https://gymnasium.farama.org/tutorials/gymnasium_basics/handling_time_limits/ 88 | target_q = batch["reward"] + (gamma**n_step) * (1 - batch["terminated"]) * ( 89 | next_q - next_actor_entropy 90 | ) 91 | 92 | def critic_loss_fn( 93 | critic_params: flax.core.FrozenDict[str, Any], 94 | ) -> Tuple[jnp.ndarray, Dict[str, float]]: 95 | # compute predicted q-value 96 | if use_cdq: 97 | pred_qs, pred_q_infos = critic.apply( 98 | variables={"params": critic_params}, 99 | observations=batch["observation"], 100 | actions=batch["action"], 101 | ) 102 | loss_1 = (pred_qs[0] - target_q) ** 2 103 | loss_2 = (pred_qs[1] - target_q) ** 2 104 | critic_loss = (loss_1 + loss_2).mean() 105 | else: 106 | pred_q, _ = critic.apply( 107 | variables={"params": critic_params}, 108 | observations=batch["observation"], 109 | actions=batch["action"], 110 | ) 111 | critic_loss = ((pred_q - target_q) ** 2).mean() 112 | 113 | critic_info = { 114 | "critic/loss": critic_loss, 115 | "critic/batch_rew_min": batch["reward"].min(), 116 | 
"critic/batch_rew_mean": batch["reward"].mean(), 117 | "critic/batch_rew_max": batch["reward"].max(), 118 | } 119 | 120 | return critic_loss, critic_info 121 | 122 | critic, info = critic.apply_gradient(critic_loss_fn) 123 | 124 | return critic, info 125 | 126 | 127 | def update_target_network( 128 | network: Network, 129 | target_network: Network, 130 | target_tau: bool, 131 | ) -> Tuple[Network, Dict[str, float]]: 132 | new_target_params = jax.tree_map( 133 | lambda p, tp: p * target_tau + tp * (1 - target_tau), 134 | network.params, 135 | target_network.params, 136 | ) 137 | target_network = target_network.replace(params=new_target_params) 138 | 139 | info = {} 140 | 141 | return target_network, info 142 | 143 | 144 | def update_temperature( 145 | temperature: Network, entropy: float, target_entropy: float 146 | ) -> Tuple[Network, Dict[str, float]]: 147 | def temperature_loss_fn( 148 | temperature_params: flax.core.FrozenDict[str, Any], 149 | ) -> Tuple[jnp.ndarray, Dict[str, float]]: 150 | temperature_value = temperature.apply({"params": temperature_params}) 151 | temperature_loss = temperature_value * (entropy - target_entropy).mean() 152 | temperature_info = { 153 | "temperature/value": temperature_value, 154 | "temperature/loss": temperature_loss, 155 | } 156 | 157 | return temperature_loss, temperature_info 158 | 159 | temperature, info = temperature.apply_gradient(temperature_loss_fn) 160 | 161 | return temperature, info 162 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SimbaV2 2 | 3 | > **DAVIAN Robotics, KAIST AI & Sony AI** 4 | > ICML 2025 (spotlight). 5 | 6 | ## Introduction 7 | 8 | SimbaV2 is a reinforcement learning architecture designed to stabilize training via hyperspherical normalization. By increasing model capacity and compute, SimbaV2 achieves state-of-the-art results on 57 continuous control tasks from MuJoCo, DMControl, MyoSuite, and Humanoid-bench. 9 | 10 |
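If you want to experiment with the networks in `scale_rl/agents/simbaV2/simbaV2_network.py` directly, a rough usage sketch follows. The constructor arguments here are placeholder values, not the tuned hyperparameters (those live in `configs/agent/simbaV2.yaml`), and the observation/action sizes are arbitrary.

```python
import jax
import jax.numpy as jnp

from scale_rl.agents.simbaV2.simbaV2_network import (
    SimbaV2Actor,
    SimbaV2DoubleCritic,
)

obs_dim, action_dim, batch = 17, 6, 32
observations = jnp.zeros((batch, obs_dim))
actions = jnp.zeros((batch, action_dim))
rng = jax.random.PRNGKey(0)

# Placeholder hyperparameters, chosen only to make the example run.
actor = SimbaV2Actor(
    num_blocks=1,
    hidden_dim=128,
    action_dim=action_dim,
    scaler_init=1.0,
    scaler_scale=1.0,
    alpha_init=0.5,
    alpha_scale=0.5,
    c_shift=3.0,
)
actor_params = actor.init(rng, observations)["params"]
dist, _ = actor.apply({"params": actor_params}, observations)
sampled_actions = dist.sample(seed=rng)  # tanh-squashed policy actions

critic = SimbaV2DoubleCritic(
    num_blocks=1,
    hidden_dim=128,
    scaler_init=1.0,
    scaler_scale=1.0,
    alpha_init=0.5,
    alpha_scale=0.5,
    c_shift=3.0,
    num_bins=51,
    min_v=-10.0,
    max_v=10.0,
)
critic_params = critic.init(rng, observations, actions)["params"]
qs, _ = critic.apply({"params": critic_params}, observations, actions)
print(qs.shape)  # expected (2, batch): one scalar Q estimate per head
```

The double critic vmaps two independent `SimbaV2Critic` instances, so `qs` carries a leading axis of size `num_qs=2`; the update rule then takes the element-wise minimum of the two heads for clipped double Q-learning, as in the `jnp.minimum(qs[0], qs[1])` calls in the update code above.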