├── scale_rl ├── __init__.py ├── agents │ ├── simba │ │ ├── __init__.py │ │ ├── simba_layer.py │ │ ├── simba_network.py │ │ ├── simba_update.py │ │ └── simba_agent.py │ ├── jax_utils │ │ ├── __init__.py │ │ ├── tree_utils.py │ │ └── network.py │ ├── simbaV2 │ │ ├── __init__.py │ │ ├── simbaV2_network.py │ │ ├── simbaV2_layer.py │ │ └── simbaV2_update.py │ ├── wrappers │ │ ├── __init__.py │ │ ├── utils.py │ │ └── normalization.py │ ├── random_agent.py │ ├── __init__.py │ └── base_agent.py ├── envs │ ├── wrappers │ │ ├── __init__.py │ │ └── repeat_action.py │ ├── mujoco.py │ ├── dmc.py │ ├── d4rl.py │ ├── myosuite.py │ ├── __init__.py │ └── humanoid_bench.py ├── common │ ├── __init__.py │ ├── scheduler.py │ └── logger.py ├── buffers │ ├── __init__.py │ ├── base_buffer.py │ ├── utils.py │ └── numpy_buffer.py └── evaluation.py ├── docs ├── README.md ├── .DS_Store ├── images │ ├── header.png │ ├── online.png │ ├── analysis.png │ ├── offline.png │ ├── overview.png │ ├── benchmark.png │ ├── utd_scaling.png │ ├── param_scaling.png │ └── simbav2_architecture.png ├── thumbnails │ ├── 0.png │ ├── 1.png │ ├── 2.png │ ├── 3.png │ ├── 4.png │ ├── 5.png │ ├── 6.png │ └── 7.png ├── dataset │ ├── videos │ │ ├── .DS_Store │ │ ├── dmc │ │ │ ├── simbav2-dog-run.mp4 │ │ │ ├── simbav2-dog-trot.mp4 │ │ │ ├── simbav2-dog-walk.mp4 │ │ │ ├── simbav2-cheetah-run.mp4 │ │ │ ├── simbav2-dog-stand.mp4 │ │ │ ├── simbav2-finger-spin.mp4 │ │ │ ├── simbav2-fish-swim.mp4 │ │ │ ├── simbav2-hopper-hop.mp4 │ │ │ ├── simbav2-walker-run.mp4 │ │ │ ├── simbav2-walker-walk.mp4 │ │ │ ├── simbav2-hopper-stand.mp4 │ │ │ ├── simbav2-humanoid-run.mp4 │ │ │ ├── simbav2-humanoid-walk.mp4 │ │ │ ├── simbav2-quadruped-run.mp4 │ │ │ ├── simbav2-reacher-easy.mp4 │ │ │ ├── simbav2-reacher-hard.mp4 │ │ │ ├── simbav2-walker-stand.mp4 │ │ │ ├── simbav2-acrobot-swingup.mp4 │ │ │ ├── simbav2-cartpole-balance.mp4 │ │ │ ├── simbav2-cartpole-swingup.mp4 │ │ │ ├── simbav2-finger-turn-easy.mp4 │ │ │ ├── simbav2-finger-turn-hard.mp4 │ │ │ ├── simbav2-humanoid-stand.mp4 │ │ │ ├── simbav2-pendulum-swingup.mp4 │ │ │ ├── simbav2-quadruped-walk.mp4 │ │ │ ├── simbav2-ball-in-cup-catch.mp4 │ │ │ ├── simbav2-cartpole-balance-sparse.mp4 │ │ │ └── simbav2-cartpole-swingup-sparse.mp4 │ │ ├── hbench │ │ │ ├── simbav2-h1-crawl.mp4 │ │ │ ├── simbav2-h1-maze.mp4 │ │ │ ├── simbav2-h1-pole.mp4 │ │ │ ├── simbav2-h1-reach.mp4 │ │ │ ├── simbav2-h1-run.mp4 │ │ │ ├── simbav2-h1-slide.mp4 │ │ │ ├── simbav2-h1-stair.mp4 │ │ │ ├── simbav2-h1-stand.mp4 │ │ │ ├── simbav2-h1-walk.mp4 │ │ │ ├── simbav2-h1-hurdle.mp4 │ │ │ ├── simbav2-h1-sit-hard.mp4 │ │ │ ├── simbav2-h1-sit-simple.mp4 │ │ │ ├── simbav2-h1-balance-hard.mp4 │ │ │ └── simbav2-h1-balance-simple.mp4 │ │ ├── mujoco │ │ │ ├── simbav2-ant-v4.mp4 │ │ │ ├── simbav2-cheetah-v4.mp4 │ │ │ ├── simbav2-hopper-v4.mp4 │ │ │ ├── simbav2-walker-v4.mp4 │ │ │ └── simbav2-humanoid-v4.mp4 │ │ └── myosuite │ │ │ ├── simbav2-myo-key-easy.mp4 │ │ │ ├── simbav2-myo-key-hard.mp4 │ │ │ ├── simbav2-myo-obj-easy.mp4 │ │ │ ├── simbav2-myo-obj-hard.mp4 │ │ │ ├── simbav2-myo-pen-easy.mp4 │ │ │ ├── simbav2-myo-pen-hard.mp4 │ │ │ ├── simbav2-myo-pose-easy.mp4 │ │ │ ├── simbav2-myo-pose-hard.mp4 │ │ │ ├── simbav2-myo-reach-easy.mp4 │ │ │ └── simbav2-myo-reach-hard.mp4 │ ├── css │ │ ├── bulma.min.css │ │ ├── bulma-carousel.min.css │ │ └── style.css │ └── js │ │ └── index.js ├── videos │ ├── dmc │ │ ├── simbav2-dog-run.mp4 │ │ ├── simbav2-dog-trot.mp4 │ │ ├── simbav2-humanoid-stand.mp4 │ │ └── simbav2-humanoid-walk.mp4 │ ├── mujoco │ │ ├── 
simbav2-ant-v4.mp4 │ │ ├── simbav2-cheetah-v4.mp4 │ │ ├── simbav2-humanoid-v4.mp4 │ │ └── simbav2-walker-v4.mp4 │ ├── hbench │ │ ├── simbav2-h1-crawl.mp4 │ │ ├── simbav2-h1-stand.mp4 │ │ ├── simbav2-h1-sit-hard.mp4 │ │ └── simbav2-h1-balance-hard.mp4 │ └── myosuite │ │ ├── simbav2-myo-key-hard.mp4 │ │ ├── simbav2-myo-obj-hard.mp4 │ │ ├── simbav2-myo-pen-hard.mp4 │ │ └── simbav2-myo-reach-hard.mp4 └── style.css ├── scripts └── simbaV2 │ ├── test_offline.sh │ └── test_online.sh ├── configs ├── agent │ ├── random.yaml │ ├── simba.yaml │ ├── simbaV2_bc.yaml │ └── simbaV2.yaml ├── buffer │ └── numpy_uniform.yaml ├── env │ ├── d4rl.yaml │ ├── mujoco.yaml │ ├── myosuite.yaml │ ├── dmc.yaml │ └── hb_locomotion.yaml ├── offline_rl.yaml └── online_rl.yaml ├── setup.py ├── deps ├── requirements.txt ├── environment.yaml └── Dockerfile ├── .github └── workflows │ └── pre-commit.yml ├── .pre-commit-config.yaml ├── .gitignore ├── README.md ├── run_offline.py ├── run_online.py └── run_parallel.py /scale_rl/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/README.md: -------------------------------------------------------------------------------- 1 | # simbaV2.github.io -------------------------------------------------------------------------------- /scale_rl/agents/simba/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scale_rl/agents/jax_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /scale_rl/agents/simbaV2/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/.DS_Store -------------------------------------------------------------------------------- /docs/images/header.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/images/header.png -------------------------------------------------------------------------------- /docs/images/online.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/images/online.png -------------------------------------------------------------------------------- /docs/thumbnails/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/thumbnails/0.png -------------------------------------------------------------------------------- /docs/thumbnails/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/thumbnails/1.png -------------------------------------------------------------------------------- /docs/thumbnails/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/thumbnails/2.png 
-------------------------------------------------------------------------------- /docs/thumbnails/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/thumbnails/3.png -------------------------------------------------------------------------------- /docs/thumbnails/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/thumbnails/4.png -------------------------------------------------------------------------------- /docs/thumbnails/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/thumbnails/5.png -------------------------------------------------------------------------------- /docs/thumbnails/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/thumbnails/6.png -------------------------------------------------------------------------------- /docs/thumbnails/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/thumbnails/7.png -------------------------------------------------------------------------------- /scale_rl/envs/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from scale_rl.envs.wrappers.repeat_action import RepeatAction 2 | -------------------------------------------------------------------------------- /docs/images/analysis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/images/analysis.png -------------------------------------------------------------------------------- /docs/images/offline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/images/offline.png -------------------------------------------------------------------------------- /docs/images/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/images/overview.png -------------------------------------------------------------------------------- /docs/images/benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/images/benchmark.png -------------------------------------------------------------------------------- /docs/images/utd_scaling.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/images/utd_scaling.png -------------------------------------------------------------------------------- /docs/dataset/videos/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/.DS_Store -------------------------------------------------------------------------------- /docs/images/param_scaling.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/images/param_scaling.png -------------------------------------------------------------------------------- /scale_rl/agents/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | from scale_rl.agents.wrappers.normalization import ObservationNormalizer, RewardNormalizer -------------------------------------------------------------------------------- /docs/images/simbav2_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/images/simbav2_architecture.png -------------------------------------------------------------------------------- /docs/videos/dmc/simbav2-dog-run.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/dmc/simbav2-dog-run.mp4 -------------------------------------------------------------------------------- /docs/videos/dmc/simbav2-dog-trot.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/dmc/simbav2-dog-trot.mp4 -------------------------------------------------------------------------------- /docs/videos/mujoco/simbav2-ant-v4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/mujoco/simbav2-ant-v4.mp4 -------------------------------------------------------------------------------- /docs/videos/hbench/simbav2-h1-crawl.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/hbench/simbav2-h1-crawl.mp4 -------------------------------------------------------------------------------- /docs/videos/hbench/simbav2-h1-stand.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/hbench/simbav2-h1-stand.mp4 -------------------------------------------------------------------------------- /docs/videos/dmc/simbav2-humanoid-stand.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/dmc/simbav2-humanoid-stand.mp4 -------------------------------------------------------------------------------- /docs/videos/dmc/simbav2-humanoid-walk.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/dmc/simbav2-humanoid-walk.mp4 -------------------------------------------------------------------------------- /docs/videos/hbench/simbav2-h1-sit-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/hbench/simbav2-h1-sit-hard.mp4 -------------------------------------------------------------------------------- /docs/videos/mujoco/simbav2-cheetah-v4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/mujoco/simbav2-cheetah-v4.mp4 -------------------------------------------------------------------------------- 
/docs/videos/mujoco/simbav2-humanoid-v4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/mujoco/simbav2-humanoid-v4.mp4 -------------------------------------------------------------------------------- /docs/videos/mujoco/simbav2-walker-v4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/mujoco/simbav2-walker-v4.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-dog-run.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-dog-run.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-dog-trot.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-dog-trot.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-dog-walk.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-dog-walk.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-cheetah-run.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-cheetah-run.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-dog-stand.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-dog-stand.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-finger-spin.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-finger-spin.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-fish-swim.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-fish-swim.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-hopper-hop.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-hopper-hop.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-walker-run.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-walker-run.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-walker-walk.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-walker-walk.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-crawl.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-crawl.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-maze.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-maze.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-pole.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-pole.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-reach.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-reach.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-run.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-run.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-slide.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-slide.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-stair.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-stair.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-stand.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-stand.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-walk.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-walk.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/mujoco/simbav2-ant-v4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/mujoco/simbav2-ant-v4.mp4 -------------------------------------------------------------------------------- /docs/videos/hbench/simbav2-h1-balance-hard.mp4: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/hbench/simbav2-h1-balance-hard.mp4 -------------------------------------------------------------------------------- /docs/videos/myosuite/simbav2-myo-key-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/myosuite/simbav2-myo-key-hard.mp4 -------------------------------------------------------------------------------- /docs/videos/myosuite/simbav2-myo-obj-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/myosuite/simbav2-myo-obj-hard.mp4 -------------------------------------------------------------------------------- /docs/videos/myosuite/simbav2-myo-pen-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/myosuite/simbav2-myo-pen-hard.mp4 -------------------------------------------------------------------------------- /docs/videos/myosuite/simbav2-myo-reach-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/videos/myosuite/simbav2-myo-reach-hard.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-hopper-stand.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-hopper-stand.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-humanoid-run.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-humanoid-run.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-humanoid-walk.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-humanoid-walk.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-quadruped-run.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-quadruped-run.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-reacher-easy.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-reacher-easy.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-reacher-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-reacher-hard.mp4 -------------------------------------------------------------------------------- 
/docs/dataset/videos/dmc/simbav2-walker-stand.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-walker-stand.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-hurdle.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-hurdle.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/mujoco/simbav2-cheetah-v4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/mujoco/simbav2-cheetah-v4.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/mujoco/simbav2-hopper-v4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/mujoco/simbav2-hopper-v4.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/mujoco/simbav2-walker-v4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/mujoco/simbav2-walker-v4.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-acrobot-swingup.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-acrobot-swingup.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-cartpole-balance.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-cartpole-balance.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-cartpole-swingup.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-cartpole-swingup.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-finger-turn-easy.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-finger-turn-easy.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-finger-turn-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-finger-turn-hard.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-humanoid-stand.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-humanoid-stand.mp4 
-------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-pendulum-swingup.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-pendulum-swingup.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-quadruped-walk.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-quadruped-walk.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-sit-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-sit-hard.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-sit-simple.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-sit-simple.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/mujoco/simbav2-humanoid-v4.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/mujoco/simbav2-humanoid-v4.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-ball-in-cup-catch.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-ball-in-cup-catch.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-balance-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-balance-hard.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-key-easy.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-key-easy.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-key-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-key-hard.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-obj-easy.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-obj-easy.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-obj-hard.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-obj-hard.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-pen-easy.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-pen-easy.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-pen-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-pen-hard.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-pose-easy.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-pose-easy.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-pose-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-pose-hard.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/hbench/simbav2-h1-balance-simple.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/hbench/simbav2-h1-balance-simple.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-reach-easy.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-reach-easy.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/myosuite/simbav2-myo-reach-hard.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/myosuite/simbav2-myo-reach-hard.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-cartpole-balance-sparse.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-cartpole-balance-sparse.mp4 -------------------------------------------------------------------------------- /docs/dataset/videos/dmc/simbav2-cartpole-swingup-sparse.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DAVIAN-Robotics/SimbaV2/HEAD/docs/dataset/videos/dmc/simbav2-cartpole-swingup-sparse.mp4 -------------------------------------------------------------------------------- /scale_rl/common/__init__.py: -------------------------------------------------------------------------------- 1 | from scale_rl.common.logger import WandbTrainerLogger 2 | from scale_rl.common.scheduler import ( 3 | linear_decay_scheduler, 4 | constant_value_scheduler, 5 | 
cyclic_exponential_decay_scheduler, 6 | ) -------------------------------------------------------------------------------- /scripts/simbaV2/test_offline.sh: -------------------------------------------------------------------------------- 1 | python run_parallel.py \ 2 | --server kaist \ 3 | --group_name final_test \ 4 | --exp_name simbav2_bc \ 5 | --config_name offline_rl \ 6 | --agent_config simbaV2_bc \ 7 | --env_type d4rl_mujoco \ 8 | --device_ids 0 1 2 3 4 5 6 7 \ 9 | --num_seeds 5 \ 10 | --num_exp_per_device 1 \ -------------------------------------------------------------------------------- /scripts/simbaV2/test_online.sh: -------------------------------------------------------------------------------- 1 | python run_parallel.py \ 2 | --server kaist \ 3 | --group_name final_test \ 4 | --exp_name simbaV2_no_hb \ 5 | --agent_config simbaV2 \ 6 | --env_type hb_locomotion \ 7 | --device_ids 0 1 2 3 4 5 6 7 \ 8 | --num_seeds 5 \ 9 | --num_exp_per_device 4 \ 10 | --overrides recording_per_interaction_step=9999999 -------------------------------------------------------------------------------- /configs/agent/random.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # This agent selects actions randomly. (Useful for debugging purposes.) 3 | ################################################################################## 4 | 5 | agent_type: 'random' 6 | 7 | seed: ${seed} 8 | normalize_observation: false 9 | normalize_reward: false -------------------------------------------------------------------------------- /configs/buffer/numpy_uniform.yaml: -------------------------------------------------------------------------------- 1 | buffer_class_type: 'numpy' # [numpy, jax (distributed-friendly)] 2 | buffer_type: 'uniform' # [uniform, prioritized] 3 | 4 | n_step: ${n_step} 5 | gamma: ${gamma} 6 | max_length: 1_000_000 # maximum buffer size. 7 | min_length: 5_000 # minimum buffer size (= number of data to collect before training). 8 | sample_batch_size: 256 # batch size for sampling = training. 
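The buffer settings above map one-to-one onto the keyword arguments of the create_buffer factory defined in scale_rl/buffers/__init__.py (reproduced further down in this dump). The sketch below is illustrative only: it assumes a Gymnasium environment has already been constructed and hard-codes the values that Hydra would normally fill in for ${n_step} and ${gamma}.

import gymnasium as gym
from omegaconf import OmegaConf

from scale_rl.buffers import create_buffer

# Stand-in for the resolved config tree of configs/buffer/numpy_uniform.yaml,
# with the ${n_step} and ${gamma} interpolations already applied.
buffer_cfg = OmegaConf.create(
    {
        "buffer_class_type": "numpy",
        "buffer_type": "uniform",
        "n_step": 1,
        "gamma": 0.99,
        "max_length": 1_000_000,
        "min_length": 5_000,
        "sample_batch_size": 256,
    }
)

env = gym.make("Humanoid-v4")  # any Gymnasium env with Box spaces works here
buffer = create_buffer(
    observation_space=env.observation_space,
    action_space=env.action_space,
    **OmegaConf.to_container(buffer_cfg),
)

The actual entry points (run_online.py / run_offline.py) may assemble this differently; the snippet only shows how the config keys line up with the factory signature.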
-------------------------------------------------------------------------------- /configs/env/d4rl.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # D4RL 3 | ################################################################################## 4 | 5 | env_type: 'd4rl' 6 | env_name: 'hopper-medium-replay-v2' 7 | num_env_steps: null 8 | episodic: true 9 | 10 | seed: ${seed} 11 | num_train_envs: 1 12 | num_eval_envs: 1 13 | action_repeat: 1 14 | rescale_action: true 15 | max_episode_steps: 1000 -------------------------------------------------------------------------------- /configs/env/mujoco.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # Mujoco 3 | ################################################################################## 4 | 5 | env_type: 'mujoco' 6 | env_name: 'Humanoid-v4' 7 | num_env_steps: 1_000_000 8 | episodic: true 9 | 10 | seed: ${seed} 11 | num_train_envs: 1 12 | num_eval_envs: 1 13 | action_repeat: 1 14 | rescale_action: true 15 | max_episode_steps: 1000 -------------------------------------------------------------------------------- /configs/env/myosuite.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # MyoSuite 3 | ################################################################################## 4 | 5 | env_type: 'myosuite' 6 | env_name: 'myo-reach' 7 | num_env_steps: 1_000_000 8 | episodic: false 9 | 10 | seed: ${seed} 11 | num_train_envs: 1 12 | num_eval_envs: 1 13 | action_repeat: 2 14 | rescale_action: true 15 | max_episode_steps: 100 -------------------------------------------------------------------------------- /configs/env/dmc.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # Deepmind Control Suite with hard tasks (humanoid & dog) 3 | ################################################################################## 4 | 5 | env_type: 'dmc' 6 | env_name: 'acrobot-swingup' 7 | num_env_steps: 1_000_000 8 | episodic: false 9 | 10 | seed: ${seed} 11 | num_train_envs: 1 12 | num_eval_envs: 1 13 | action_repeat: 2 14 | rescale_action: true 15 | max_episode_steps: 1000 -------------------------------------------------------------------------------- /configs/env/hb_locomotion.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # Humanoid Benchmark with Locomotion tasks 3 | ################################################################################## 4 | 5 | env_type: 'humanoid_bench' 6 | env_name: 'h1-walk-v0' 7 | num_env_steps: 1_000_000 8 | episodic: true 9 | 10 | seed: ${seed} 11 | num_train_envs: 1 12 | num_eval_envs: 1 13 | action_repeat: 2 14 | rescale_action: true 15 | max_episode_steps: 1000 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | with open("deps/requirements.txt") as f: 4 | install_requires = f.read() 5 | 6 | 7 | if __name__ == 
"__main__": 8 | setup( 9 | name="scale_rl", 10 | version="0.1.0", 11 | url="https://github.com/dojeon-ai/SimbaV2", 12 | license="Apache License 2.0", 13 | install_requires=install_requires, 14 | packages=find_packages(), 15 | python_requires=">=3.9.0", 16 | zip_safe=True, 17 | ) 18 | -------------------------------------------------------------------------------- /deps/requirements.txt: -------------------------------------------------------------------------------- 1 | dm-control==1.0.20 2 | dotmap==1.3.30 3 | Cython==0.29.37 4 | flax==0.8.4 5 | hydra-core==1.3.2 6 | imageio==2.33.1 7 | jax==0.4.25 8 | jaxlib==0.4.25 9 | jaxopt==0.8.3 10 | matplotlib==3.9.0 11 | mujoco==3.1.6 12 | mujoco-py==2.1.2.14 13 | numpy==1.24.1 14 | omegaconf==2.3.0 15 | optax==0.2.2 16 | pandas==2.1.4 17 | Shimmy==2.0.0 18 | tensorflow-probability==0.24.0 19 | termcolor==2.4.0 20 | tqdm==4.66.1 21 | wandb==0.16.6 22 | moviepy==1.0.3 23 | git+https://github.com/takuseno/D4RL 24 | git+https://github.com/Farama-Foundation/Gymnasium.git@d92e030e9a0b03806deda51122c705169313f158 25 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | # https://pre-commit.com 2 | # This GitHub Action assumes that the repo contains a valid .pre-commit-config.yaml file. 3 | --- 4 | name: pre-commit 5 | on: 6 | pull_request: 7 | push: 8 | branches: [master] 9 | 10 | permissions: 11 | contents: read 12 | 13 | jobs: 14 | pre-commit: 15 | runs-on: ubuntu-latest 16 | steps: 17 | - uses: actions/checkout@v3 18 | - uses: actions/setup-python@v3 19 | - run: pip install pre-commit 20 | - run: pre-commit --version 21 | - run: pre-commit install 22 | - run: pre-commit run --all-files 23 | -------------------------------------------------------------------------------- /docs/dataset/css/bulma.min.css: -------------------------------------------------------------------------------- 1 | .bibtex-container{ 2 | flex-grow:1; 3 | margin:0 auto 4 | ;position:relative; 5 | width:auto 6 | } 7 | 8 | .bibtex pre{ 9 | -webkit-overflow-scrolling: touch; 10 | overflow-x:auto; 11 | padding:1.25em 1.5em; 12 | white-space:pre; 13 | word-wrap:normal; 14 | font-family: "Courier", monospace; 15 | background-color: #f4f4f4; 16 | text-align: left; 17 | } 18 | 19 | .hero{ 20 | align-items:stretch; 21 | display:flex; 22 | flex-direction:column; 23 | justify-content:space-between 24 | } 25 | 26 | .hero.is-light{ 27 | color:rgba(0,0,0,.7) 28 | } 29 | 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | repos: 3 | - repo: https://github.com/codespell-project/codespell 4 | rev: v2.2.5 5 | hooks: 6 | - id: codespell # english spelling check 7 | args: [ 8 | --ignore-words-list=null 9 | ] 10 | exclude: ^(.*\.ipynb|.*\.html)$ 11 | - repo: https://github.com/astral-sh/ruff-pre-commit 12 | rev: v0.2.1 13 | hooks: 14 | - id: ruff # code formatting check. 15 | args: [ 16 | --fix, # applies formatting automatically. 
17 | --select=I, # import-sort (library import is automatically sorted) 18 | ] 19 | exclude: ^.*__init__\.py$ 20 | - id: ruff-format 21 | exclude: ^.*__init__\.py$ 22 | -------------------------------------------------------------------------------- /scale_rl/envs/mujoco.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | 3 | MUJOCO_ALL = [ 4 | "HalfCheetah-v4", 5 | "Hopper-v4", 6 | "Walker2d-v4", 7 | "Ant-v4", 8 | "Humanoid-v4", 9 | ] 10 | 11 | MUJOCO_RANDOM_SCORE = { 12 | "HalfCheetah-v4": -289.415, 13 | "Hopper-v4": 18.791, 14 | "Walker2d-v4": 2.791, 15 | "Ant-v4": -70.288, 16 | "Humanoid-v4": 120.423, 17 | } 18 | 19 | MUJOCO_TD3_SCORE = { 20 | "HalfCheetah-v4": 10574, 21 | "Hopper-v4": 3226, 22 | "Walker2d-v4": 3946, 23 | "Ant-v4": 3942, 24 | "Humanoid-v4": 5165, 25 | } 26 | 27 | 28 | def make_mujoco_env( 29 | env_name: str, 30 | seed: int, 31 | ) -> gym.Env: 32 | env = gym.make(env_name, render_mode="rgb_array") 33 | 34 | return env 35 | -------------------------------------------------------------------------------- /scale_rl/envs/wrappers/repeat_action.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | import numpy as np 3 | 4 | 5 | class RepeatAction(gym.Wrapper): 6 | def __init__(self, env: gym.Env, action_repeat=4): 7 | super().__init__(env) 8 | self._action_repeat = action_repeat 9 | 10 | def step(self, action: np.ndarray): 11 | total_reward = 0.0 12 | terminated = False 13 | truncated = False 14 | combined_info = {} 15 | 16 | for _ in range(self._action_repeat): 17 | obs, reward, terminated, truncated, info = self.env.step(action) 18 | total_reward += float(reward) 19 | combined_info.update(info) 20 | if terminated or truncated: 21 | break 22 | 23 | return obs, total_reward, terminated, truncated, combined_info 24 | -------------------------------------------------------------------------------- /deps/environment.yaml: -------------------------------------------------------------------------------- 1 | name: scale_rl 2 | channels: 3 | - nvidia 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - glew=2.1.0 8 | - glib=2.68.4 9 | - pip=21.0 10 | - python=3.9.0 11 | - pip: 12 | - dm-control==1.0.20 13 | - dotmap==1.3.30 14 | - Cython==0.29.37 15 | - flax==0.8.4 16 | - git+https://github.com/Farama-Foundation/Gymnasium.git@d92e030e9a0b03806deda51122c705169313f158 17 | - git+https://github.com/takuseno/D4RL 18 | - hydra-core==1.3.2 19 | - imageio==2.33.1 20 | - jax==0.4.25 21 | - jaxlib==0.4.25 22 | - jaxopt==0.8.3 23 | - matplotlib==3.9.0 24 | - mujoco==3.1.6 25 | - mujoco-py==2.1.2.14 26 | - numpy==1.24.1 27 | - omegaconf==2.3.0 28 | - optax==0.2.2 29 | - pandas==2.1.4 30 | - Shimmy==2.0.0 31 | - tensorflow-probability==0.24.0 32 | - termcolor==2.4.0 33 | - tqdm==4.66.1 34 | - wandb==0.16.6 35 | - moviepy==1.0.3 -------------------------------------------------------------------------------- /configs/agent/simba.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # Simba 3 | ################################################################################## 4 | 5 | agent_type: 'simba' 6 | 7 | seed: ${seed} 8 | normalize_observation: true 9 | normalize_reward: false 10 | normalized_g_max: 0.0 11 | 12 | load_only_param: true # do not load optimizer 13 | load_param_key: null # load designated key 14 | 
load_observation_normalizer: true 15 | load_reward_normalizer: true 16 | 17 | learning_rate_init: 1e-4 18 | learning_rate_end: 1e-4 19 | learning_rate_decay_rate: 1.0 20 | learning_rate_decay_step: ${eval:'int(${agent.learning_rate_decay_rate} * ${num_interaction_steps} * ${updates_per_interaction_step})'} 21 | weight_decay: 1e-2 22 | 23 | actor_num_blocks: 1 24 | actor_hidden_dim: 128 25 | actor_bc_alpha: 0.0 26 | 27 | critic_use_cdq: ${env.episodic} 28 | critic_num_blocks: 2 29 | critic_hidden_dim: 512 30 | 31 | target_tau: 0.005 32 | 33 | temp_initial_value: 0.01 34 | temp_target_entropy: null # entropy_coef * action_dim 35 | temp_target_entropy_coef: -0.5 36 | 37 | gamma: ${gamma} 38 | n_step: ${n_step} 39 | -------------------------------------------------------------------------------- /scale_rl/buffers/__init__.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from typing import Tuple, Optional 3 | from scale_rl.buffers.base_buffer import BaseBuffer, Batch 4 | from scale_rl.buffers.numpy_buffer import NpyUniformBuffer, NpyPrioritizedBuffer 5 | 6 | 7 | def create_buffer( 8 | buffer_class_type: str, 9 | buffer_type: str, 10 | observation_space: gym.spaces.Space, 11 | action_space: gym.spaces.Space, 12 | n_step: int, 13 | gamma: float, 14 | max_length: int, 15 | min_length: int, 16 | sample_batch_size: int, 17 | **kwargs, 18 | ) -> BaseBuffer: 19 | 20 | if buffer_class_type == 'numpy': 21 | if buffer_type == 'uniform': 22 | buffer = NpyUniformBuffer( 23 | observation_space=observation_space, 24 | action_space=action_space, 25 | n_step=n_step, 26 | gamma=gamma, 27 | max_length=max_length, 28 | min_length=min_length, 29 | sample_batch_size=sample_batch_size, 30 | ) 31 | 32 | else: 33 | raise NotImplementedError 34 | 35 | elif buffer_class_type == 'jax': 36 | raise NotImplementedError 37 | 38 | return buffer 39 | -------------------------------------------------------------------------------- /scale_rl/agents/random_agent.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, TypeVar 2 | 3 | import gymnasium as gym 4 | import numpy as np 5 | 6 | from scale_rl.agents.base_agent import BaseAgent 7 | 8 | Config = TypeVar("Config") 9 | 10 | 11 | class RandomAgent(BaseAgent): 12 | def __init__( 13 | self, 14 | observation_space: gym.spaces.Space, 15 | action_space: gym.spaces.Space, 16 | cfg: Config, 17 | ): 18 | """ 19 | An agent that randomly selects actions without training. 20 | Useful for collecting baseline results and for debugging purposes. 
21 | """ 22 | super(RandomAgent, self).__init__( 23 | observation_space, 24 | action_space, 25 | cfg, 26 | ) 27 | 28 | def sample_actions( 29 | self, 30 | interaction_step: int, 31 | prev_timestep: Dict[str, np.ndarray], 32 | training: bool, 33 | ) -> np.ndarray: 34 | num_envs = prev_timestep["next_observation"].shape[0] 35 | actions = [] 36 | for _ in range(num_envs): 37 | actions.append(self._action_space.sample()) 38 | 39 | actions = np.stack(actions).reshape(-1) 40 | return actions 41 | 42 | def update(self, update_step: int, batch: Dict[str, np.ndarray]) -> Dict: 43 | update_info = {} 44 | return update_info 45 | 46 | def save(self, path: str) -> None: 47 | pass 48 | 49 | def load(self, path: str) -> None: 50 | pass 51 | -------------------------------------------------------------------------------- /configs/offline_rl.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # Common 3 | ################################################################################## 4 | 5 | project_name: 'SimbaV2' 6 | entity_name: 'draftrec' 7 | group_name: 'test' 8 | exp_name: 'test' 9 | seed: 0 10 | server: 'local' 11 | save_path: 'models/${group_name}/${exp_name}/${env.env_name}/${seed}' 12 | load_path: null 13 | 14 | ################################################################################## 15 | # Training 16 | ################################################################################## 17 | 18 | # gamma value is set with a heuristic from TD-MPCv2 19 | eff_episode_len: ${eval:'${env.max_episode_steps} / ${env.action_repeat}'} 20 | gamma: ${eval:'max(min((${eff_episode_len} / 5 - 1) / (${eff_episode_len} / 5), 0.995), 0.95)'} 21 | n_step: 1 22 | 23 | num_epochs: 100 24 | action_repeat: ${env.action_repeat} 25 | 26 | num_interaction_steps: null # update steps 27 | updates_per_interaction_step: 1 # fixed 28 | evaluation_per_interaction_step: 50_000 # evaluation frequency per interaction step. 29 | metrics_per_interaction_step: 50_000 # log metrics per interaction step. 30 | recording_per_interaction_step: -1 # recording is not supported for offline rl yet. 31 | logging_per_interaction_step: 10_000 # logging frequency per interaction step. 
32 | save_checkpoint_per_interaction_step: null 33 | num_eval_episodes: 100 34 | num_record_episodes: 1 35 | 36 | defaults: 37 | - _self_ 38 | - agent: simbaV2_bc 39 | - buffer: numpy_uniform 40 | - env: d4rl 41 | -------------------------------------------------------------------------------- /scale_rl/agents/__init__.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from typing import TypeVar 3 | from omegaconf import OmegaConf 4 | from scale_rl.agents.base_agent import BaseAgent 5 | from scale_rl.agents.wrappers import ObservationNormalizer, RewardNormalizer 6 | 7 | Config = TypeVar('Config') 8 | 9 | 10 | def create_agent( 11 | observation_space: gym.spaces.Space, 12 | action_space: gym.spaces.Space, 13 | cfg: Config, 14 | ) -> BaseAgent: 15 | 16 | cfg = OmegaConf.to_container(cfg, throw_on_missing=True) 17 | agent_type = cfg.pop('agent_type') 18 | 19 | if agent_type == 'random': 20 | from scale_rl.agents.random_agent import RandomAgent 21 | agent = RandomAgent(observation_space, action_space, cfg) 22 | 23 | elif agent_type == 'simba': 24 | from scale_rl.agents.simba.simba_agent import SimbaAgent 25 | agent = SimbaAgent(observation_space, action_space, cfg) 26 | 27 | elif agent_type == 'simbaV2': 28 | from scale_rl.agents.simbaV2.simbaV2_agent import SimbaV2Agent 29 | agent = SimbaV2Agent(observation_space, action_space, cfg) 30 | 31 | else: 32 | raise NotImplementedError 33 | 34 | # observation and reward normalization wrappers 35 | if cfg['normalize_observation']: 36 | agent = ObservationNormalizer( 37 | agent, 38 | load_rms=cfg['load_observation_normalizer'] 39 | ) 40 | if cfg['normalize_reward']: 41 | agent = RewardNormalizer( 42 | agent, 43 | gamma=cfg['gamma'], 44 | g_max=cfg['normalized_g_max'], 45 | load_rms=cfg['load_reward_normalizer'] 46 | ) 47 | 48 | return agent -------------------------------------------------------------------------------- /scale_rl/common/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | EPS = 1e-8 4 | 5 | 6 | def cyclic_exponential_decay_scheduler( 7 | decay_period, initial_value, final_value, reverse=False 8 | ): 9 | if reverse: 10 | initial_value = 1 - initial_value 11 | final_value = 1 - final_value 12 | 13 | start = np.log(initial_value + EPS) 14 | end = np.log(final_value + EPS) 15 | 16 | def scheduler(step): 17 | cycle_length = decay_period 18 | cycle_step = step % cycle_length 19 | 20 | steps_left = decay_period - cycle_step 21 | bonus_frac = steps_left / decay_period 22 | bonus = np.clip(bonus_frac, 0.0, 1.0) 23 | new_value = bonus * (start - end) + end 24 | 25 | new_value = np.exp(new_value) - EPS 26 | if reverse: 27 | new_value = 1 - new_value 28 | return new_value 29 | 30 | return scheduler 31 | 32 | 33 | def linear_decay_scheduler(decay_period, initial_value, final_value): 34 | def scheduler(step): 35 | # Ensure step does not exceed decay_period 36 | step = min(step, decay_period) 37 | 38 | # Calculate the linear interpolation factor 39 | fraction = step / decay_period 40 | new_value = (1 - fraction) * initial_value + fraction * final_value 41 | 42 | return new_value 43 | 44 | return scheduler 45 | 46 | 47 | def constant_value_scheduler(value): 48 | """ 49 | Returns a scheduler function that always returns the same value. 50 | 51 | Args: 52 | value (float): The constant value to return. 53 | 54 | Returns: 55 | function: A scheduler function that always returns `value`. 
56 | """ 57 | 58 | def scheduler(step): 59 | return value 60 | 61 | return scheduler 62 | -------------------------------------------------------------------------------- /configs/online_rl.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # Common 3 | ################################################################################## 4 | 5 | project_name: 'SimbaV2' 6 | entity_name: 'draftrec' 7 | group_name: 'test' 8 | exp_name: 'test' 9 | seed: 0 10 | server: 'local' 11 | save_path: 'models/${group_name}/${exp_name}/${env.env_name}/${seed}' 12 | load_path: null 13 | 14 | ################################################################################## 15 | # Training 16 | ################################################################################## 17 | 18 | # gamma value is set with a heuristic from TD-MPCv2 19 | eff_episode_len: ${eval:'${env.max_episode_steps} / ${env.action_repeat}'} 20 | gamma: ${eval:'max(min((${eff_episode_len} / 5 - 1) / (${eff_episode_len} / 5), 0.995), 0.95)'} 21 | n_step: 1 22 | 23 | num_train_envs: ${env.num_train_envs} 24 | num_env_steps: ${env.num_env_steps} 25 | action_repeat: ${env.action_repeat} 26 | 27 | num_interaction_steps: ${eval:'${num_env_steps} / (${num_train_envs} * ${action_repeat})'} 28 | updates_per_interaction_step: ${action_repeat} # number of updates per interaction step. 29 | evaluation_per_interaction_step: 50_000 # evaluation frequency per interaction step. 30 | metrics_per_interaction_step: 50_000 # log metrics per interaction step. 31 | recording_per_interaction_step: ${num_interaction_steps} # video recording frequency per interaction step. 32 | logging_per_interaction_step: 10_000 # logging frequency per interaction step. 
33 | save_checkpoint_per_interaction_step: ${num_interaction_steps} 34 | save_buffer_per_interaction_step: ${num_interaction_steps} 35 | num_eval_episodes: 10 36 | num_record_episodes: 1 37 | 38 | defaults: 39 | - _self_ 40 | - agent: simbaV2 41 | - buffer: numpy_uniform 42 | - env: dmc -------------------------------------------------------------------------------- /scale_rl/envs/dmc.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from dm_control import suite 3 | from gymnasium import spaces 4 | from gymnasium.wrappers import FlattenObservation 5 | from shimmy import DmControlCompatibilityV0 as DmControltoGymnasium 6 | 7 | # 20 tasks 8 | DMC_EASY_MEDIUM = [ 9 | "acrobot-swingup", 10 | "ball_in_cup-catch", 11 | "cartpole-balance", 12 | "cartpole-balance_sparse", 13 | "cartpole-swingup", 14 | "cartpole-swingup_sparse", 15 | "cheetah-run", 16 | "finger-spin", 17 | "finger-turn_easy", 18 | "finger-turn_hard", 19 | "fish-swim", 20 | "hopper-hop", 21 | "hopper-stand", 22 | "pendulum-swingup", 23 | "quadruped-walk", 24 | "quadruped-run", 25 | "reacher-easy", 26 | "reacher-hard", 27 | "walker-stand", 28 | "walker-walk", 29 | "walker-run", 30 | ] 31 | 32 | # 8 tasks 33 | DMC_SPARSE = [ 34 | "cartpole-balance_sparse", 35 | "cartpole-swingup_sparse", 36 | "ball_in_cup-catch", 37 | "finger-spin", 38 | "finger-turn_easy", 39 | "finger-turn_hard", 40 | "reacher-easy", 41 | "reacher-hard", 42 | ] 43 | 44 | # 7 tasks 45 | DMC_HARD = [ 46 | "humanoid-stand", 47 | "humanoid-walk", 48 | "humanoid-run", 49 | "dog-stand", 50 | "dog-walk", 51 | "dog-run", 52 | "dog-trot", 53 | ] 54 | 55 | 56 | def make_dmc_env( 57 | env_name: str, 58 | seed: int, 59 | flatten: bool = True, 60 | ) -> gym.Env: 61 | domain_name, task_name = env_name.split("-") 62 | env = suite.load( 63 | domain_name=domain_name, 64 | task_name=task_name, 65 | task_kwargs={"random": seed}, 66 | ) 67 | env = DmControltoGymnasium(env, render_mode="rgb_array") 68 | if flatten and isinstance(env.observation_space, spaces.Dict): 69 | env = FlattenObservation(env) 70 | 71 | return env 72 | -------------------------------------------------------------------------------- /configs/agent/simbaV2_bc.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # SAC with Hyper-Simba architecture + Behavioral Cloning Loss 3 | ################################################################################## 4 | 5 | agent_type: 'simbaV2' 6 | 7 | seed: ${seed} 8 | normalize_observation: true 9 | normalize_reward: true 10 | normalized_g_max: 5.0 11 | 12 | load_only_param: true # do not load optimizer 13 | load_param_key: null # load designated key 14 | load_observation_normalizer: true 15 | load_reward_normalizer: true 16 | 17 | learning_rate_init: 1e-4 18 | learning_rate_end: 1e-5 19 | learning_rate_decay_rate: 0.5 20 | learning_rate_decay_step: null 21 | 22 | actor_num_blocks: 1 23 | actor_hidden_dim: 128 24 | actor_c_shift: 3.0 25 | actor_scaler_init: ${eval:'math.sqrt(2 / ${agent.actor_hidden_dim})'} 26 | actor_scaler_scale: ${eval:'math.sqrt(2 / ${agent.actor_hidden_dim})'} 27 | actor_alpha_init: ${eval:'1 / (${agent.actor_num_blocks} + 1)'} 28 | actor_alpha_scale: ${eval:'1 / math.sqrt(${agent.actor_hidden_dim})'} 29 | actor_bc_alpha: 0.1 # offline-rl 30 | 31 | critic_use_cdq: ${env.episodic} 32 | critic_num_blocks: 2 33 | critic_hidden_dim: 512 34 | critic_c_shift: 3.0 35 | 
critic_num_bins: 101 36 | critic_scaler_init: ${eval:'math.sqrt(2 / ${agent.critic_hidden_dim})'} 37 | critic_scaler_scale: ${eval:'math.sqrt(2 / ${agent.critic_hidden_dim})'} 38 | critic_min_v: ${eval:'-${agent.normalized_g_max}'} 39 | critic_max_v: ${eval:'${agent.normalized_g_max}'} 40 | critic_alpha_init: ${eval:'1 / (${agent.critic_num_blocks} + 1)'} 41 | critic_alpha_scale: ${eval:'1 / math.sqrt(${agent.critic_hidden_dim})'} 42 | 43 | target_tau: 0.005 44 | 45 | temp_initial_value: 0.01 46 | temp_target_entropy: null # entropy_coef * action_dim 47 | temp_target_entropy_coef: -0.5 48 | 49 | gamma: ${gamma} 50 | n_step: ${n_step} -------------------------------------------------------------------------------- /scale_rl/agents/wrappers/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | This implementation is adapted from the Gymnasium API. 3 | ref: https://github.com/Farama-Foundation/Gymnasium/blob/main/gymnasium/wrappers/utils.py 4 | """ 5 | 6 | import numpy as np 7 | 8 | 9 | class RunningMeanStd: 10 | """Tracks the mean, variance and count of values.""" 11 | 12 | # https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm 13 | def __init__(self, epsilon=1e-4, shape=(), dtype=np.float32): 14 | """Tracks the mean, variance and count of values.""" 15 | self.mean = np.zeros(shape, dtype=dtype) 16 | self.var = np.ones(shape, dtype=dtype) 17 | self.count = epsilon 18 | 19 | def update(self, x): 20 | """Updates the mean, var and count from a batch of samples.""" 21 | batch_mean = np.mean(x, axis=0) 22 | batch_var = np.var(x, axis=0) 23 | batch_count = x.shape[0] 24 | self.update_from_moments(batch_mean, batch_var, batch_count) 25 | 26 | def update_from_moments(self, batch_mean, batch_var, batch_count): 27 | """Updates from batch mean, variance and count moments.""" 28 | self.mean, self.var, self.count = update_mean_var_count_from_moments( 29 | self.mean, self.var, self.count, batch_mean, batch_var, batch_count 30 | ) 31 | 32 | 33 | def update_mean_var_count_from_moments( 34 | mean, var, count, batch_mean, batch_var, batch_count 35 | ): 36 | """Updates the mean, var and count using the previous mean, var, count and batch values.""" 37 | delta = batch_mean - mean 38 | tot_count = count + batch_count 39 | ratio = batch_count / tot_count 40 | 41 | new_mean = mean + delta * ratio 42 | m_a = var * count 43 | m_b = batch_var * batch_count 44 | M2 = m_a + m_b + np.square(delta) * count * ratio 45 | new_var = M2 / tot_count 46 | 47 | return new_mean, new_var, tot_count 48 | -------------------------------------------------------------------------------- /configs/agent/simbaV2.yaml: -------------------------------------------------------------------------------- 1 | ################################################################################## 2 | # SAC with Hyper-Simba architecture 3 | ################################################################################## 4 | 5 | agent_type: 'simbaV2' 6 | 7 | seed: ${seed} 8 | normalize_observation: true 9 | normalize_reward: true 10 | normalized_g_max: 5.0 11 | 12 | load_only_param: true # do not load optimizer 13 | load_param_key: null # load designated key 14 | load_observation_normalizer: true 15 | load_reward_normalizer: true 16 | 17 | learning_rate_init: 1e-4 18 | learning_rate_end: 5e-5 19 | learning_rate_decay_rate: 1.0 20 | learning_rate_decay_step: ${eval:'int(${agent.learning_rate_decay_rate} * ${num_interaction_steps} * ${updates_per_interaction_step})'} 21 | 
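# Illustrative resolution of the interpolation above (run length assumed): with num_interaction_steps=500_000
# and updates_per_interaction_step=2, learning_rate_decay_step = int(1.0 * 500_000 * 2) = 1_000_000,
# i.e. the decay horizon for annealing the learning rate from 1e-4 to 5e-5 spans every update of the run.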
22 | actor_num_blocks: 1 23 | actor_hidden_dim: 128 24 | actor_c_shift: 3.0 25 | actor_scaler_init: ${eval:'math.sqrt(2 / ${agent.actor_hidden_dim})'} 26 | actor_scaler_scale: ${eval:'math.sqrt(2 / ${agent.actor_hidden_dim})'} 27 | actor_alpha_init: ${eval:'1 / (${agent.actor_num_blocks} + 1)'} 28 | actor_alpha_scale: ${eval:'1 / math.sqrt(${agent.actor_hidden_dim})'} 29 | actor_bc_alpha: 0.0 30 | 31 | critic_use_cdq: ${env.episodic} 32 | critic_num_blocks: 2 33 | critic_hidden_dim: 512 34 | critic_c_shift: 3.0 35 | critic_num_bins: 101 36 | critic_scaler_init: ${eval:'math.sqrt(2 / ${agent.critic_hidden_dim})'} 37 | critic_scaler_scale: ${eval:'math.sqrt(2 / ${agent.critic_hidden_dim})'} 38 | critic_min_v: ${eval:'-${agent.normalized_g_max}'} 39 | critic_max_v: ${eval:'${agent.normalized_g_max}'} 40 | critic_alpha_init: ${eval:'1 / (${agent.critic_num_blocks} + 1)'} 41 | critic_alpha_scale: ${eval:'1 / math.sqrt(${agent.critic_hidden_dim})'} 42 | 43 | target_tau: 0.005 44 | 45 | temp_initial_value: 0.01 46 | temp_target_entropy: null # entropy_coef * action_dim 47 | temp_target_entropy_coef: -0.5 48 | 49 | gamma: ${gamma} 50 | n_step: ${n_step} -------------------------------------------------------------------------------- /scale_rl/buffers/base_buffer.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Dict 3 | 4 | import gymnasium as gym 5 | import numpy as np 6 | 7 | Batch = Dict[str, np.ndarray] 8 | 9 | 10 | class BaseBuffer(ABC): 11 | def __init__( 12 | self, 13 | observation_space: gym.spaces.Space, 14 | action_space: gym.spaces.Space, 15 | n_step: int, 16 | gamma: float, 17 | max_length: int, 18 | min_length: int, 19 | sample_batch_size: int, 20 | ): 21 | """ 22 | A generic buffer class. 23 | 24 | args: 25 | observation_space 26 | action_space 27 | max_length: maximum length of buffer (max number of experiences stored within the state). 28 | min_length: minimum number of experiences saved in the buffer state before we can sample. 29 | add_sequences: indicator of whether we will be adding data in sequences to the buffer. 30 | sample_batch_size: batch size of data that is sampled from a single sampling call.
31 | """ 32 | 33 | self._observation_space = observation_space 34 | self._action_space = action_space 35 | self._max_length = max_length 36 | self._min_length = min_length 37 | self._n_step = n_step 38 | self._gamma = gamma 39 | self._sample_batch_size = sample_batch_size 40 | 41 | def __len__(self): 42 | pass 43 | 44 | @abstractmethod 45 | def reset(self) -> None: 46 | pass 47 | 48 | @abstractmethod 49 | def add(self, timestep: Dict[str, np.ndarray]) -> None: 50 | pass 51 | 52 | @abstractmethod 53 | def can_sample(self) -> bool: 54 | pass 55 | 56 | @abstractmethod 57 | def sample(self) -> Batch: 58 | pass 59 | 60 | @abstractmethod 61 | def save(self, path: str) -> None: 62 | pass 63 | 64 | @abstractmethod 65 | def get_observations(self) -> np.ndarray: 66 | pass 67 | -------------------------------------------------------------------------------- /scale_rl/envs/d4rl.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import d4rl 4 | import gymnasium as gym 5 | import numpy as np 6 | 7 | D4RL_MUJOCO = [ 8 | "hopper-medium-v2", 9 | "hopper-medium-replay-v2", 10 | "hopper-medium-expert-v2", 11 | "halfcheetah-medium-v2", 12 | "halfcheetah-medium-replay-v2", 13 | "halfcheetah-medium-expert-v2", 14 | "walker2d-medium-v2", 15 | "walker2d-medium-replay-v2", 16 | "walker2d-medium-expert-v2", 17 | ] 18 | 19 | 20 | def make_d4rl_env(env_name: str, seed: int) -> gym.Env: 21 | env = gym.make("GymV26Environment-v0", env_id=env_name) 22 | env.reset(seed=seed) 23 | return env 24 | 25 | 26 | def make_d4rl_dataset(env_name: str) -> list[dict[str, Any]]: 27 | env = make_d4rl_env(env_name, seed=0) 28 | 29 | # extract dataset 30 | dataset = env.env.env.gym_env.get_dataset() 31 | timesteps = [] 32 | total_steps = dataset["rewards"].shape[0] 33 | for i in range(total_steps - 1): 34 | obs = dataset["observations"][i] 35 | action = dataset["actions"][i] 36 | reward = dataset["rewards"][i] 37 | terminated = dataset["terminals"][i] 38 | truncated = dataset["timeouts"][i] 39 | if terminated: 40 | next_obs = np.zeros_like(obs) 41 | elif truncated: 42 | continue 43 | else: 44 | next_obs = dataset["observations"][i + 1] 45 | timestep = { 46 | "observation": np.expand_dims(obs, axis=0), 47 | "action": np.expand_dims(action, axis=0), 48 | "reward": np.array([reward]), 49 | "terminated": np.array([terminated]), 50 | "truncated": np.array([truncated]), 51 | "next_observation": np.expand_dims(next_obs, axis=0), 52 | } 53 | timesteps.append(timestep) 54 | 55 | return timesteps 56 | 57 | 58 | def get_d4rl_normalized_score(env_name: str, unnormalized_score: float) -> float: 59 | return 100 * d4rl.get_normalized_score(env_name, unnormalized_score) 60 | -------------------------------------------------------------------------------- /scale_rl/envs/myosuite.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | 3 | MYOSUITE_TASKS = [ 4 | "myo-reach", 5 | "myo-reach-hard", 6 | "myo-pose", 7 | "myo-pose-hard", 8 | "myo-obj-hold", 9 | "myo-obj-hold-hard", 10 | "myo-key-turn", 11 | "myo-key-turn-hard", 12 | "myo-pen-twirl", 13 | "myo-pen-twirl-hard", 14 | ] 15 | 16 | 17 | MYOSUITE_TASKS_DICT = { 18 | "myo-reach": "myoHandReachFixed-v0", 19 | "myo-reach-hard": "myoHandReachRandom-v0", 20 | "myo-pose": "myoHandPoseFixed-v0", 21 | "myo-pose-hard": "myoHandPoseRandom-v0", 22 | "myo-obj-hold": "myoHandObjHoldFixed-v0", 23 | "myo-obj-hold-hard": "myoHandObjHoldRandom-v0", 24 | "myo-key-turn": 
"myoHandKeyTurnFixed-v0", 25 | "myo-key-turn-hard": "myoHandKeyTurnRandom-v0", 26 | "myo-pen-twirl": "myoHandPenTwirlFixed-v0", 27 | "myo-pen-twirl-hard": "myoHandPenTwirlRandom-v0", 28 | } 29 | 30 | 31 | class MyosuiteGymnasiumVersionWrapper(gym.Wrapper): 32 | """ 33 | myosuite originally requires gymnasium==0.15 34 | however, we are currently using gymnasium==1.0.0a2, 35 | hence requiring some minor fix to the 36 | - fix a. 37 | - fix b. 38 | """ 39 | 40 | def __init__(self, env: gym.Env): 41 | super().__init__(env) 42 | self.unwrapped_env = env.unwrapped 43 | 44 | def step(self, action): 45 | obs, reward, terminated, truncated, info = self.env.step(action) 46 | info["success"] = info["solved"] 47 | return obs, reward, terminated, truncated, info 48 | 49 | def render( 50 | self, width: int = 192, height: int = 192, camera_id: str = "hand_side_inter" 51 | ): 52 | return self.unwrapped_env.sim.renderer.render_offscreen( 53 | width=width, 54 | height=height, 55 | camera_id=camera_id, 56 | ) 57 | 58 | 59 | def make_myosuite_env( 60 | env_name: str, 61 | seed: int, 62 | **kwargs, 63 | ) -> gym.Env: 64 | from myosuite.utils import gym as myo_gym 65 | 66 | env = myo_gym.make(MYOSUITE_TASKS_DICT[env_name]) 67 | env = MyosuiteGymnasiumVersionWrapper(env) 68 | 69 | return env 70 | -------------------------------------------------------------------------------- /scale_rl/agents/simba/simba_layer.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import flax.linen as nn 4 | import jax.numpy as jnp 5 | from tensorflow_probability.substrates import jax as tfp 6 | 7 | tfd = tfp.distributions 8 | tfb = tfp.bijectors 9 | 10 | 11 | class PreLNResidualBlock(nn.Module): 12 | hidden_dim: int 13 | expansion: int = 4 14 | 15 | def setup(self): 16 | self.pre_ln = nn.LayerNorm() 17 | self.w1 = nn.Dense( 18 | self.hidden_dim * self.expansion, kernel_init=nn.initializers.he_normal() 19 | ) 20 | self.w2 = nn.Dense(self.hidden_dim, kernel_init=nn.initializers.he_normal()) 21 | 22 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 23 | res = x 24 | x = self.pre_ln(x) 25 | x = self.w1(x) 26 | x = nn.relu(x) 27 | x = self.w2(x) 28 | return res + x 29 | 30 | 31 | class LinearCritic(nn.Module): 32 | def setup(self): 33 | self.w = nn.Dense(1, kernel_init=nn.initializers.orthogonal(1.0)) 34 | 35 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 36 | value = self.w(x).squeeze(-1) 37 | info = {} 38 | return value, info 39 | 40 | 41 | class NormalTanhPolicy(nn.Module): 42 | action_dim: int 43 | log_std_min: float = -10.0 44 | log_std_max: float = 2.0 45 | 46 | def setup(self): 47 | self.mean_w = nn.Dense( 48 | self.action_dim, kernel_init=nn.initializers.orthogonal(1.0) 49 | ) 50 | self.std_w = nn.Dense( 51 | self.action_dim, kernel_init=nn.initializers.orthogonal(1.0) 52 | ) 53 | 54 | def __call__( 55 | self, 56 | x: jnp.ndarray, 57 | temperature: float = 1.0, 58 | ) -> tfd.Distribution: 59 | mean = self.mean_w(x) 60 | log_std = self.std_w(x) 61 | 62 | # normalize log-stds for stability 63 | log_std = self.log_std_min + (self.log_std_max - self.log_std_min) * 0.5 * ( 64 | 1 + nn.tanh(log_std) 65 | ) 66 | 67 | # N(mu, exp(log_sigma)) 68 | dist = tfd.MultivariateNormalDiag( 69 | loc=mean, 70 | scale_diag=jnp.exp(log_std) * temperature, 71 | ) 72 | 73 | # tanh(N(mu, sigma)) 74 | dist = tfd.TransformedDistribution(distribution=dist, bijector=tfb.Tanh()) 75 | 76 | info = {} 77 | return dist, info 78 | 
-------------------------------------------------------------------------------- /docs/dataset/js/index.js: -------------------------------------------------------------------------------- 1 | window.HELP_IMPROVE_VIDEOJS = false; 2 | 3 | var INTERP_BASE = "./static/interpolation/stacked"; 4 | var NUM_INTERP_FRAMES = 240; 5 | 6 | var interp_images = []; 7 | function preloadInterpolationImages() { 8 | for (var i = 0; i < NUM_INTERP_FRAMES; i++) { 9 | var path = INTERP_BASE + '/' + String(i).padStart(6, '0') + '.jpg'; 10 | interp_images[i] = new Image(); 11 | interp_images[i].src = path; 12 | } 13 | } 14 | 15 | function setInterpolationImage(i) { 16 | var image = interp_images[i]; 17 | image.ondragstart = function() { return false; }; 18 | image.oncontextmenu = function() { return false; }; 19 | $('#interpolation-image-wrapper').empty().append(image); 20 | } 21 | 22 | 23 | $(document).ready(function() { 24 | // Check for click events on the navbar burger icon 25 | $(".navbar-burger").click(function() { 26 | // Toggle the "is-active" class on both the "navbar-burger" and the "navbar-menu" 27 | $(".navbar-burger").toggleClass("is-active"); 28 | $(".navbar-menu").toggleClass("is-active"); 29 | 30 | }); 31 | 32 | var options = { 33 | slidesToScroll: 1, 34 | slidesToShow: 3, 35 | loop: true, 36 | infinite: true, 37 | autoplay: false, 38 | autoplaySpeed: 3000, 39 | } 40 | 41 | // Initialize all div with carousel class 42 | var carousels = bulmaCarousel.attach('.carousel', options); 43 | 44 | // Loop on each carousel initialized 45 | for(var i = 0; i < carousels.length; i++) { 46 | // Add listener to event 47 | carousels[i].on('before:show', state => { 48 | console.log(state); 49 | }); 50 | } 51 | 52 | // Access to bulmaCarousel instance of an element 53 | var element = document.querySelector('#my-element'); 54 | if (element && element.bulmaCarousel) { 55 | // bulmaCarousel instance is available as element.bulmaCarousel 56 | element.bulmaCarousel.on('before-show', function(state) { 57 | console.log(state); 58 | }); 59 | } 60 | 61 | /*var player = document.getElementById('interpolation-video'); 62 | player.addEventListener('loadedmetadata', function() { 63 | $('#interpolation-slider').on('input', function(event) { 64 | console.log(this.value, player.duration); 65 | player.currentTime = player.duration / 100 * this.value; 66 | }) 67 | }, false);*/ 68 | preloadInterpolationImages(); 69 | 70 | $('#interpolation-slider').on('input', function(event) { 71 | setInterpolationImage(this.value); 72 | }); 73 | setInterpolationImage(0); 74 | $('#interpolation-slider').prop('max', NUM_INTERP_FRAMES - 1); 75 | 76 | bulmaSlider.attach(); 77 | 78 | }) 79 | -------------------------------------------------------------------------------- /scale_rl/agents/jax_utils/tree_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | import jax 4 | import jax.numpy as jnp 5 | 6 | 7 | def tree_norm(tree): 8 | return jnp.sqrt(sum((x**2).sum() for x in jax.tree_util.tree_leaves(tree))) 9 | 10 | 11 | def tree_map_until_match( 12 | f, tree, target_re, *rest, keep_structure=True, keep_values=False 13 | ): 14 | """ 15 | Similar to `jax.tree_util.tree_map_with_path`, but `is_leaf` is a regex condition. 16 | args: 17 | f: A function to map the discovered nodes (i.e., dict key matches `target_re`). 18 | Inputs to f will be (1) the discovered node and (2) the corresponding nodes in `*rest``. 19 | target_re: A regex string condition that triggers `f`. 
20 | tree: A pytree to be searched by `target_re` and mapped by `f`. 21 | *rest: List of pytrees that are at least 'almost' identical structure to `tree`. 22 | 'Almost', since the substructure of matching nodes don't have to be identical. 23 | i.e., The tree structure of `tree` and `*rest` should be identical only up to the matching nodes. 24 | keep_structure: If false, the returned tree will only contain subtrees that lead to the matching nodes. 25 | keep_values: If false, unmatched leaves will become `None`. Assumes `keep_structure=True`. 26 | """ 27 | 28 | if not isinstance(tree, dict): 29 | return tree if keep_values else None 30 | 31 | ret_tree = {} 32 | for k, v in tree.items(): 33 | v_rest = [r[k] for r in rest] 34 | if re.fullmatch(target_re, k): 35 | ret_tree[k] = f(v, *v_rest) 36 | else: 37 | subtree = tree_map_until_match( 38 | f, 39 | v, 40 | target_re, 41 | *v_rest, 42 | keep_structure=keep_structure, 43 | keep_values=keep_values, 44 | ) 45 | if keep_structure or subtree: 46 | ret_tree[k] = subtree 47 | 48 | return ret_tree 49 | 50 | 51 | def tree_filter(f, tree, target_re="scaler"): 52 | if isinstance(tree, dict): 53 | # Keep only "target_re" keys in the dictionary 54 | filtered_tree = {} 55 | for k, v in tree.items(): 56 | if re.fullmatch(target_re, k): 57 | filtered_tree[k] = tree_filter(f, v, target_re="scaler") 58 | elif isinstance(v, dict): # Recursively check nested dictionaries 59 | filtered_value = tree_filter(f, v, target_re="scaler") 60 | if filtered_value: # Only keep non-empty dictionaries 61 | filtered_tree[k] = filtered_value 62 | return filtered_tree 63 | else: 64 | # If not a dictionary, return the tree as is (typically a leaf node) 65 | return f(tree) 66 | -------------------------------------------------------------------------------- /deps/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.2.2-cudnn8-devel-ubuntu22.04 2 | ENV DEBIAN_FRONTEND=noninteractive 3 | ENV CUDA_HOME=/usr/local/cuda-12.2 4 | ENV PATH=$CUDA_HOME/bin:$PATH 5 | ENV LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH 6 | 7 | # packages 8 | RUN apt-get -y update && \ 9 | apt-get install -y --no-install-recommends \ 10 | build-essential \ 11 | git \ 12 | rsync \ 13 | tree \ 14 | curl \ 15 | wget \ 16 | unzip \ 17 | htop \ 18 | tmux \ 19 | xvfb \ 20 | patchelf \ 21 | ca-certificates \ 22 | bash-completion \ 23 | libjpeg-dev \ 24 | libpng-dev \ 25 | ffmpeg \ 26 | cmake \ 27 | swig \ 28 | libssl-dev \ 29 | libcurl4-openssl-dev \ 30 | libopenmpi-dev \ 31 | python3-dev \ 32 | zlib1g-dev \ 33 | qtbase5-dev \ 34 | qtdeclarative5-dev \ 35 | libglib2.0-0 \ 36 | libglu1-mesa-dev \ 37 | libgl1-mesa-dev \ 38 | libvulkan1 \ 39 | libgl1-mesa-glx \ 40 | libosmesa6 \ 41 | libosmesa6-dev \ 42 | libglew-dev \ 43 | mesa-utils && \ 44 | apt-get clean && \ 45 | apt-get autoremove -y && \ 46 | rm -rf /var/lib/apt/lists/* && \ 47 | mkdir /root/.ssh 48 | 49 | # python 50 | RUN apt-get -y update && \ 51 | apt-get install -y software-properties-common && \ 52 | add-apt-repository ppa:deadsnakes/ppa && \ 53 | apt-get install -y python3.10 python3.10-distutils python3.10-venv 54 | 55 | ## kubernetes authorization 56 | # kubernetes needs a numeric user apparently 57 | # Ensure the user has write permissions 58 | RUN useradd --create-home \ 59 | --shell /bin/bash \ 60 | --base-dir /home \ 61 | --groups dialout,audio,video,plugdev \ 62 | --uid 1000 \ 63 | user 64 | USER root 65 | WORKDIR /home/user 66 | RUN chown -R user:user /home/user && \ 67 | chmod -R 
u+rwx /home/user 68 | USER 1000 69 | 70 | # Install packages 71 | COPY deps/requirements.txt /home/user/requirements.txt 72 | RUN python3.10 -m venv /home/user/venv && \ 73 | . /home/user/venv/bin/activate && \ 74 | pip install -r requirements.txt && \ 75 | pip install -U "jax[cuda12]==0.4.25" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 76 | ENV VIRTUAL_ENV=/home/user/venv 77 | ENV PATH="$VIRTUAL_ENV/bin:$PATH" 78 | ENV XLA_PYTHON_CLIENT_PREALLOCATE=false 79 | 80 | # install mujoco 2.1.0, humanoid-bench and myosuite 81 | ENV MUJOCO_GL egl 82 | ENV LD_LIBRARY_PATH /home/user/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH} 83 | RUN wget https://github.com/deepmind/mujoco/releases/download/2.1.0/mujoco210-linux-x86_64.tar.gz && \ 84 | tar -xzf mujoco210-linux-x86_64.tar.gz && \ 85 | rm mujoco210-linux-x86_64.tar.gz && \ 86 | mkdir /home/user/.mujoco && \ 87 | mv mujoco210 /home/user/.mujoco/mujoco210 && \ 88 | find /home/user/.mujoco -exec chown user:user {} \; && \ 89 | python -c "import mujoco_py" && \ 90 | git clone https://github.com/joonleesky/humanoid-bench /home/user/humanoid-bench && \ 91 | pip install -e /home/user/humanoid-bench && \ 92 | git clone --recursive https://github.com/joonleesky/myosuite /home/user/myosuite && \ 93 | pip install -e /home/user/myosuite -------------------------------------------------------------------------------- /scale_rl/common/logger.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | from omegaconf import OmegaConf 4 | 5 | import wandb 6 | 7 | 8 | class WandbTrainerLogger(object): 9 | def __init__(self, cfg: Dict): 10 | self.cfg = cfg 11 | dict_cfg = OmegaConf.to_container(cfg, throw_on_missing=True) 12 | 13 | wandb.init( 14 | project=cfg.project_name, 15 | entity=cfg.entity_name, 16 | group=cfg.group_name, 17 | config=dict_cfg, 18 | ) 19 | 20 | self.reset() 21 | 22 | def update_metric(self, **kwargs) -> None: 23 | for k, v in kwargs.items(): 24 | if isinstance(v, float) or isinstance(v, int): 25 | self.average_meter_dict.update(k, v) 26 | else: 27 | self.media_dict[k] = v 28 | 29 | def log_metric(self, step: int) -> Dict: 30 | log_data = {} 31 | log_data.update(self.average_meter_dict.averages()) 32 | log_data.update(self.media_dict) 33 | wandb.log(log_data, step=step) 34 | 35 | def reset(self) -> None: 36 | self.average_meter_dict = AverageMeterDict() 37 | self.media_dict = {} 38 | 39 | 40 | class AverageMeterDict(object): 41 | """ 42 | Manages a collection of AverageMeter instances, 43 | allowing for grouped tracking and averaging of multiple metrics. 
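Example (illustrative): meters = AverageMeterDict(); meters.update('loss', 0.5); meters.update('loss', 1.5); meters.averages() -> {'loss': 1.0}.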
44 | """ 45 | 46 | def __init__(self, meters=None): 47 | self.meters = meters if meters else {} 48 | 49 | def __getitem__(self, key): 50 | if key not in self.meters: 51 | meter = AverageMeter() 52 | meter.update(0) 53 | return meter 54 | return self.meters[key] 55 | 56 | def update(self, name, value, n=1) -> None: 57 | if name not in self.meters: 58 | self.meters[name] = AverageMeter() 59 | self.meters[name].update(value, n) 60 | 61 | def reset(self) -> None: 62 | for meter in self.meters.values(): 63 | meter.reset() 64 | 65 | def values(self, format_string="{}"): 66 | return { 67 | format_string.format(name): meter.val for name, meter in self.meters.items() 68 | } 69 | 70 | def averages(self, format_string="{}"): 71 | return { 72 | format_string.format(name): meter.avg for name, meter in self.meters.items() 73 | } 74 | 75 | 76 | class AverageMeter(object): 77 | """ 78 | Tracks and calculates the average and current values of a series of numbers. 79 | """ 80 | 81 | def __init__(self): 82 | self.val = 0 83 | self.avg = 0 84 | self.sum = 0 85 | self.count = 0 86 | 87 | def reset(self): 88 | self.val = 0 89 | self.avg = 0 90 | self.sum = 0 91 | self.count = 0 92 | 93 | def update(self, val, n=1): 94 | # TODO: description for using n 95 | self.val = val 96 | self.sum += val * n 97 | self.count += n 98 | self.avg = self.sum / self.count 99 | 100 | def __format__(self, format): 101 | return "{self.val:{format}} ({self.avg:{format}})".format( 102 | self=self, format=format 103 | ) 104 | -------------------------------------------------------------------------------- /docs/dataset/css/bulma-carousel.min.css: -------------------------------------------------------------------------------- 1 | @-webkit-keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}@keyframes spinAround{from{-webkit-transform:rotate(0);transform:rotate(0)}to{-webkit-transform:rotate(359deg);transform:rotate(359deg)}}.slider{position:relative;width:100%}.slider-container{display:flex;flex-wrap:nowrap;flex-direction:row;overflow:hidden;-webkit-transform:translate3d(0,0,0);transform:translate3d(0,0,0);min-height:100%}.slider-container.is-vertical{flex-direction:column}.slider-container .slider-item{flex:none}.slider-container .slider-item .image.is-covered img{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.slider-container .slider-item .video-container{height:0;padding-bottom:0;padding-top:56.25%;margin:0;position:relative}.slider-container .slider-item .video-container.is-1by1,.slider-container .slider-item .video-container.is-square{padding-top:100%}.slider-container .slider-item .video-container.is-4by3{padding-top:75%}.slider-container .slider-item .video-container.is-21by9{padding-top:42.857143%}.slider-container .slider-item .video-container embed,.slider-container .slider-item .video-container iframe,.slider-container .slider-item .video-container object{position:absolute;top:0;left:0;width:100%!important;height:100%!important}.slider-navigation-next,.slider-navigation-previous{display:flex;justify-content:center;align-items:center;position:absolute;width:42px;height:42px;background:#fff center center no-repeat;background-size:20px 20px;border:1px solid #fff;border-radius:25091983px;box-shadow:0 2px 5px #3232321a;top:50%;margin-top:-20px;left:0;cursor:pointer;transition:opacity .3s,-webkit-transform .3s;transition:transform .3s,opacity .3s;transition:transform 
.3s,opacity .3s,-webkit-transform .3s}.slider-navigation-next:hover,.slider-navigation-previous:hover{-webkit-transform:scale(1.2);transform:scale(1.2)}.slider-navigation-next.is-hidden,.slider-navigation-previous.is-hidden{display:none;opacity:0}.slider-navigation-next svg,.slider-navigation-previous svg{width:25%}.slider-navigation-next{left:auto;right:0;background:#fff center center no-repeat;background-size:20px 20px}.slider-pagination{display:none;justify-content:center;align-items:center;position:absolute;bottom:0;left:0;right:0;padding:0.4rem 1rem;text-align:center}.slider-pagination .slider-page{background:#648ef6;width:10px;height:10px;border-radius:25091983px;display:inline-block;margin:0 3px;box-shadow:0 2px 5px #3232321a;transition:-webkit-transform .3s;transition:transform .3s;transition:transform .3s,-webkit-transform .3s;cursor:pointer}.slider-pagination .slider-page.is-active,.slider-pagination .slider-page:hover{-webkit-transform:scale(1.4);transform:scale(1.4)}@media screen and (min-width:800px){.slider-pagination{display:flex}}.hero.has-carousel{position:relative}.hero.has-carousel+.hero-body,.hero.has-carousel+.hero-footer,.hero.has-carousel+.hero-head{z-index:10;overflow:hidden}.hero.has-carousel .hero-carousel{position:absolute;top:0;left:0;bottom:0;right:0;height:auto;border:none;margin:auto;padding:0;z-index:0}.hero.has-carousel .hero-carousel .slider{width:100%;max-width:100%;overflow:hidden;height:100%!important;max-height:100%;z-index:0}.hero.has-carousel .hero-carousel .slider .has-background{max-height:100%}.hero.has-carousel .hero-carousel .slider .has-background .is-background{-o-object-fit:cover;object-fit:cover;-o-object-position:center center;object-position:center center;height:100%;width:100%}.hero.has-carousel .hero-body{margin:0 3rem;z-index:10} -------------------------------------------------------------------------------- /scale_rl/agents/simba/simba_network.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | from tensorflow_probability.substrates import jax as tfp 4 | 5 | from scale_rl.agents.simba.simba_layer import ( 6 | LinearCritic, 7 | NormalTanhPolicy, 8 | PreLNResidualBlock, 9 | ) 10 | 11 | tfd = tfp.distributions 12 | tfb = tfp.bijectors 13 | 14 | 15 | class SimbaActor(nn.Module): 16 | num_blocks: int 17 | hidden_dim: int 18 | action_dim: int 19 | 20 | def setup(self): 21 | self.embedder = nn.Dense( 22 | self.hidden_dim, kernel_init=nn.initializers.orthogonal(1.0) 23 | ) 24 | self.encoder = nn.Sequential( 25 | [ 26 | *[PreLNResidualBlock(hidden_dim=self.hidden_dim) 27 | for _ in range(self.num_blocks)], 28 | nn.LayerNorm(), 29 | ] 30 | ) 31 | self.predictor = NormalTanhPolicy(self.action_dim) 32 | 33 | def __call__( 34 | self, 35 | observations: jnp.ndarray, 36 | temperature: float = 1.0, 37 | ) -> tfd.Distribution: 38 | x = observations 39 | y = self.embedder(x) 40 | z = self.encoder(y) 41 | dist, info = self.predictor(z, temperature) 42 | return dist, info 43 | 44 | 45 | class SimbaCritic(nn.Module): 46 | num_blocks: int 47 | hidden_dim: int 48 | 49 | def setup(self): 50 | self.embedder = nn.Dense( 51 | self.hidden_dim, kernel_init=nn.initializers.orthogonal(1.0) 52 | ) 53 | self.encoder = nn.Sequential( 54 | [ 55 | *[PreLNResidualBlock(hidden_dim=self.hidden_dim) 56 | for _ in range(self.num_blocks)], 57 | nn.LayerNorm(), 58 | ] 59 | ) 60 | self.predictor = LinearCritic() 61 | 62 | def __call__( 63 | self, 64 | observations: jnp.ndarray, 65 | 
actions: jnp.ndarray, 66 | ) -> jnp.ndarray: 67 | x = jnp.concatenate((observations, actions), axis=1) 68 | y = self.embedder(x) 69 | z = self.encoder(y) 70 | q, info = self.predictor(z) 71 | return q, info 72 | 73 | 74 | class SimbaDoubleCritic(nn.Module): 75 | """ 76 | Vectorized Double-Q for Clipped Double Q-learning. 77 | https://arxiv.org/pdf/1802.09477v3 78 | """ 79 | 80 | num_blocks: int 81 | hidden_dim: int 82 | 83 | num_qs: int = 2 84 | 85 | @nn.compact 86 | def __call__( 87 | self, 88 | observations: jnp.ndarray, 89 | actions: jnp.ndarray, 90 | ) -> jnp.ndarray: 91 | VmapCritic = nn.vmap( 92 | SimbaCritic, 93 | variable_axes={"params": 0}, 94 | split_rngs={"params": True}, 95 | in_axes=None, 96 | out_axes=0, 97 | axis_size=self.num_qs, 98 | ) 99 | 100 | qs, infos = VmapCritic( 101 | num_blocks=self.num_blocks, 102 | hidden_dim=self.hidden_dim, 103 | )(observations, actions) 104 | 105 | return qs, infos 106 | 107 | 108 | class SimbaTemperature(nn.Module): 109 | initial_value: float = 1.0 110 | 111 | @nn.compact 112 | def __call__(self) -> jnp.ndarray: 113 | log_temp = self.param( 114 | name="log_temp", 115 | init_fn=lambda key: jnp.full( 116 | shape=(), fill_value=jnp.log(self.initial_value) 117 | ), 118 | ) 119 | return jnp.exp(log_temp) 120 | -------------------------------------------------------------------------------- /scale_rl/agents/base_agent.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from typing import Any, Dict, TypeVar 3 | 4 | import gymnasium as gym 5 | import numpy as np 6 | 7 | Config = TypeVar("Config") 8 | 9 | 10 | class BaseAgent(ABC): 11 | def __init__( 12 | self, 13 | observation_space: gym.spaces.Space, 14 | action_space: gym.spaces.Space, 15 | cfg: Config, 16 | ): 17 | """ 18 | A generic agent class. 19 | """ 20 | self._observation_space = observation_space 21 | self._action_space = action_space 22 | self._cfg = cfg 23 | 24 | @abstractmethod 25 | def sample_actions( 26 | self, 27 | interaction_step: int, 28 | prev_timestep: Dict[str, np.ndarray], 29 | training: bool, 30 | ) -> np.ndarray: 31 | pass 32 | 33 | @abstractmethod 34 | def update(self, update_step: int, batch: Dict[str, np.ndarray]) -> Dict: 35 | pass 36 | 37 | @abstractmethod 38 | def save(self, path: str) -> None: 39 | pass 40 | 41 | @abstractmethod 42 | def load(self, path: str) -> None: 43 | pass 44 | 45 | # @abstractmethod 46 | def get_metrics( 47 | self, batch: Dict[str, np.ndarray], update_info: Dict[str, Any] 48 | ) -> Dict: 49 | pass 50 | 51 | 52 | class AgentWrapper(BaseAgent): 53 | """Wraps the agent to allow a modular transformation. 54 | 55 | This class is the base class for all wrappers for agent class. 56 | The subclass could override some methods to change the behavior of the original agent 57 | without touching the original code. 58 | 59 | Note: 60 | Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`. 
61 | """ 62 | 63 | def __init__(self, agent: BaseAgent): 64 | self.agent = agent 65 | 66 | # explicitly forward the methods defined in Agent to self.agent 67 | def sample_actions( 68 | self, 69 | interaction_step: int, 70 | prev_timestep: Dict[str, np.ndarray], 71 | training: bool, 72 | ) -> np.ndarray: 73 | return self.agent.sample_actions( 74 | interaction_step=interaction_step, 75 | prev_timestep=prev_timestep, 76 | training=training, 77 | ) 78 | 79 | def update(self, update_step: int, batch: Dict[str, np.ndarray]) -> Dict: 80 | return self.agent.update( 81 | update_step=update_step, 82 | batch=batch, 83 | ) 84 | 85 | def get_metrics( 86 | self, batch: Dict[str, np.ndarray], update_info: Dict[str, Any] 87 | ) -> Dict: 88 | return self.agent.get_metrics( 89 | batch=batch, 90 | update_info=update_info, 91 | ) 92 | 93 | def save(self, path: str) -> None: 94 | self.agent.save(path) 95 | 96 | def load(self, path: str) -> None: 97 | self.agent.load(path) 98 | 99 | def set_attr(self, name, values): 100 | return self.agent.set_attr(name, values) 101 | 102 | # implicitly forward all other methods and attributes to self.env 103 | def __getattr__(self, name): 104 | if name.startswith("_"): 105 | raise AttributeError(f"attempted to get missing private attribute '{name}'") 106 | """ 107 | logger.warn( 108 | f"env.{name} to get variables from other wrappers is deprecated and will be removed in v1.0, " 109 | f"to get this variable you can do `env.unwrapped.{name}` for environment variables." 110 | ) 111 | """ 112 | return getattr(self.agent, name) 113 | 114 | @property 115 | def unwrapped(self): 116 | return self.agent.unwrapped 117 | 118 | def __repr__(self): 119 | return f"<{self.__class__.__name__}, {self.agent}>" 120 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | 142 | # pytype static type analyzer 143 | .pytype/ 144 | 145 | # Cython debug symbols 146 | cython_debug/ 147 | 148 | # PyCharm 149 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 150 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 151 | # and can be added to the global gitignore or merged into this file. For a more nuclear 152 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
153 | #.idea/ 154 | 155 | # log 156 | wandb/** 157 | tmp/** 158 | 159 | # cache 160 | *.ipynb_checkpoints 161 | *.nfs* 162 | 163 | # saved checkpoints 164 | models/** -------------------------------------------------------------------------------- /scale_rl/envs/__init__.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | from typing import Tuple, Any 3 | from gymnasium.vector import SyncVectorEnv, AsyncVectorEnv, VectorEnv 4 | from gymnasium.wrappers import RescaleAction, TimeLimit 5 | 6 | from scale_rl.envs.dmc import make_dmc_env 7 | from scale_rl.envs.mujoco import make_mujoco_env 8 | from scale_rl.envs.humanoid_bench import make_humanoid_env 9 | from scale_rl.envs.myosuite import make_myosuite_env 10 | from scale_rl.envs.d4rl import make_d4rl_env, make_d4rl_dataset, get_d4rl_normalized_score 11 | from scale_rl.envs.wrappers import RepeatAction 12 | 13 | 14 | def create_envs( 15 | env_type: str, 16 | seed: int, 17 | env_name: str, 18 | num_train_envs: int, 19 | num_eval_envs: int, 20 | rescale_action: bool, 21 | action_repeat: int, 22 | max_episode_steps: int, 23 | **kwargs, 24 | )-> Tuple[VectorEnv, VectorEnv]: 25 | 26 | train_env = create_vec_env( 27 | env_type=env_type, 28 | env_name=env_name, 29 | seed=seed, 30 | num_envs=num_train_envs, 31 | action_repeat=action_repeat, 32 | rescale_action=rescale_action, 33 | max_episode_steps=max_episode_steps, 34 | ) 35 | eval_env = create_vec_env( 36 | env_type=env_type, 37 | env_name=env_name, 38 | seed=seed, 39 | num_envs=num_eval_envs, 40 | action_repeat=action_repeat, 41 | rescale_action=rescale_action, 42 | max_episode_steps=max_episode_steps, 43 | ) 44 | 45 | return train_env, eval_env 46 | 47 | 48 | def create_vec_env( 49 | env_type: str, 50 | env_name: str, 51 | num_envs: int, 52 | seed: int, 53 | rescale_action: bool = True, 54 | action_repeat: int = 1, 55 | max_episode_steps: int = 1000, 56 | ) -> VectorEnv: 57 | 58 | def make_one_env( 59 | env_type: str, 60 | env_name:str, 61 | seed:int, 62 | rescale_action:bool, 63 | action_repeat:int, 64 | max_episode_steps: int, 65 | **kwargs 66 | ) -> gym.Env: 67 | 68 | if env_type == 'dmc': 69 | env = make_dmc_env(env_name, seed, **kwargs) 70 | elif env_type == 'mujoco': 71 | env = make_mujoco_env(env_name, seed, **kwargs) 72 | elif env_type == 'humanoid_bench': 73 | env = make_humanoid_env(env_name, seed, **kwargs) 74 | elif env_type == 'myosuite': 75 | env = make_myosuite_env(env_name, seed, **kwargs) 76 | elif env_type == "d4rl": 77 | env = make_d4rl_env(env_name, seed, **kwargs) 78 | else: 79 | raise NotImplementedError 80 | 81 | if rescale_action: 82 | env = RescaleAction(env, -1.0, 1.0) 83 | 84 | # limit max_steps before action_repeat. 
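# (illustrative) assuming RepeatAction steps the wrapped env action_repeat times per call, e.g. with
# max_episode_steps=1000 and action_repeat=2, TimeLimit truncates after 1000 raw env steps, i.e. 500 outer steps.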
85 | env = TimeLimit(env, max_episode_steps) 86 | 87 | if action_repeat > 1: 88 | env = RepeatAction(env, action_repeat) 89 | 90 | env.observation_space.seed(seed) 91 | env.action_space.seed(seed) 92 | 93 | return env 94 | 95 | env_fns = [ 96 | ( 97 | lambda i=i: make_one_env( 98 | env_type=env_type, 99 | env_name=env_name, 100 | seed=seed + i, 101 | rescale_action=rescale_action, 102 | action_repeat=action_repeat, 103 | max_episode_steps=max_episode_steps, 104 | ) 105 | ) 106 | for i in range(num_envs) 107 | ] 108 | if len(env_fns) > 1: 109 | envs = AsyncVectorEnv(env_fns, autoreset_mode='SameStep') 110 | else: 111 | envs = SyncVectorEnv(env_fns, autoreset_mode='SameStep') 112 | 113 | return envs 114 | 115 | 116 | def create_dataset(env_type: str, env_name: str) -> list[dict[str, Any]]: 117 | if env_type == 'd4rl': 118 | dataset = make_d4rl_dataset(env_name) 119 | else: 120 | raise NotImplementedError 121 | return dataset 122 | 123 | 124 | def get_normalized_score(env_type: str, env_name: str, unnormalized_score: float) -> float: 125 | if env_type == "d4rl": 126 | score = get_d4rl_normalized_score(env_name, unnormalized_score) 127 | else: 128 | raise NotImplementedError 129 | return score 130 | -------------------------------------------------------------------------------- /scale_rl/evaluation.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import gymnasium as gym 4 | import numpy as np 5 | from gymnasium.vector import VectorEnv 6 | 7 | import wandb 8 | 9 | 10 | def evaluate( 11 | agent, 12 | env: VectorEnv, 13 | num_episodes: int, 14 | ) -> Dict[str, float]: 15 | n = env.num_envs 16 | 17 | assert num_episodes % n == 0, "num_episodes must be divisible by env.num_envs" 18 | num_eval_episodes_per_env = num_episodes // n 19 | 20 | total_returns = [] 21 | total_successes = [] 22 | total_lengths = [] 23 | 24 | for _ in range(num_eval_episodes_per_env): 25 | returns = np.zeros(n) 26 | lengths = np.zeros(n) 27 | successes = np.zeros(n) 28 | 29 | observations, infos = env.reset() 30 | 31 | prev_timestep = {"next_observation": observations} 32 | 33 | dones = np.zeros(n) 34 | while np.sum(dones) < n: 35 | actions = agent.sample_actions( 36 | interaction_step=0, 37 | prev_timestep=prev_timestep, 38 | training=False, 39 | ) 40 | next_observations, rewards, terminateds, truncateds, infos = env.step( 41 | actions 42 | ) 43 | 44 | prev_timestep = {"next_observation": next_observations} 45 | 46 | returns += rewards * (1 - dones) 47 | lengths += 1 - dones 48 | 49 | if "success" in infos: 50 | successes += infos["success"].astype("float") * (1 - dones) 51 | 52 | elif "final_info" in infos: 53 | final_successes = np.zeros(n) 54 | for idx in range(n): 55 | final_info = infos["final_info"] 56 | 57 | if "success" in final_info: 58 | try: 59 | final_successes[idx] = final_info["success"][idx].astype( 60 | "float" 61 | ) 62 | except: 63 | final_successes[idx] = np.array( 64 | final_info["success"][idx] 65 | ).astype("float") 66 | successes += final_successes * (1 - dones) 67 | 68 | else: 69 | pass 70 | 71 | # once an episode is done in a sub-environment, we assume it to be done. 72 | # also, we assume to be done whether it is terminated or truncated during evaluation. 
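# dones therefore latches to 1 the first time a sub-env reports terminated or truncated, and the
# (1 - dones) factors above mask out any further rewards/lengths from that sub-env.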
73 | dones = np.maximum(dones, terminateds) 74 | dones = np.maximum(dones, truncateds) 75 | 76 | # proceed 77 | observations = next_observations 78 | 79 | for env_idx in range(n): 80 | total_returns.append(returns[env_idx]) 81 | total_lengths.append(lengths[env_idx]) 82 | total_successes.append(successes[env_idx].astype("bool").astype("float")) 83 | 84 | eval_info = { 85 | "avg_return": np.mean(total_returns), 86 | "avg_length": np.mean(total_lengths), 87 | "avg_success": np.mean(total_successes), 88 | } 89 | 90 | return eval_info 91 | 92 | 93 | def record_video( 94 | agent, 95 | env: VectorEnv, 96 | num_episodes: int, 97 | video_length: int = 100, 98 | ) -> Dict[str, float]: 99 | n = env.num_envs 100 | assert num_episodes % n == 0, "num_episodes must be divisible by env.num_envs" 101 | num_eval_episodes_per_env = num_episodes // n 102 | 103 | total_videos = [] 104 | 105 | for _ in range(num_eval_episodes_per_env): 106 | videos = [] 107 | 108 | observations, infos = env.reset() 109 | prev_timestep = {"next_observation": observations} 110 | images = env.call("render") 111 | dones = np.zeros(n) 112 | while np.sum(dones) < n: 113 | actions = agent.sample_actions( 114 | interaction_step=0, 115 | prev_timestep=prev_timestep, 116 | training=False, 117 | ) 118 | next_observations, rewards, terminateds, truncateds, infos = env.step( 119 | actions 120 | ) 121 | 122 | prev_timestep = {"next_observation": next_observations} 123 | 124 | # once an episode is done in a sub-environment, we assume it to be done. 125 | dones = np.maximum(dones, terminateds) 126 | dones = np.maximum(dones, truncateds) 127 | 128 | # proceed 129 | videos.append(images) 130 | images = env.call("render") 131 | observations = next_observations 132 | 133 | total_videos.append(np.stack(videos, axis=1)) # (n, t, c, h, w) 134 | 135 | total_videos = np.concatenate(total_videos, axis=0) # (b, t, h, w, c) 136 | total_videos = total_videos[:, :video_length] 137 | total_videos = total_videos.transpose(0, 1, 4, 2, 3) # (b, t, c, h, w) 138 | 139 | video_info = {"video": wandb.Video(total_videos, fps=10, format="gif")} 140 | 141 | return video_info 142 | -------------------------------------------------------------------------------- /scale_rl/buffers/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def fast_uniform_sample(max_size: int, num_samples: int): 5 | """ 6 | for speed comparison of uniform sampling, refer to analysis/benchmark_buffer_speed.ipynb 7 | """ 8 | interval = max_size // num_samples 9 | if max_size % num_samples == 0: 10 | return np.arange(0, max_size, interval) + np.random.randint( 11 | 0, interval, size=num_samples 12 | ) 13 | else: 14 | return np.arange(0, max_size, interval)[:-1] + np.random.randint( 15 | 0, interval, size=num_samples 16 | ) 17 | 18 | 19 | # Segment tree data structure where parent node values are sum/max of children node values 20 | class SegmentTree: 21 | def __init__(self, size): 22 | self._index = 0 23 | self._size = size 24 | self._full = False # Used to track actual capacity 25 | self._tree_start_idx = ( 26 | 2 ** (size - 1).bit_length() - 1 27 | ) # Put all used node leaves on last tree level 28 | self._sum_tree = np.zeros( 29 | (self._tree_start_idx + self._size,), dtype=np.float32 30 | ) 31 | self._max = ( 32 | 1 # Initial max value to return (1 = 1^ω), default priority is set to max 33 | ) 34 | 35 | # Updates nodes values from current tree 36 | def _update_nodes(self, indices): 37 | children_indices = indices * 2 + 
np.expand_dims([1, 2], axis=1) 38 | self._sum_tree[indices] = np.sum(self._sum_tree[children_indices], axis=0) 39 | 40 | # Propagates changes up tree given tree indices 41 | def _propagate(self, indices): 42 | parents = (indices - 1) // 2 43 | unique_parents = np.unique(parents) 44 | self._update_nodes(unique_parents) 45 | if parents[0] != 0: 46 | self._propagate(parents) 47 | 48 | # Propagates single value up tree given a tree index for efficiency 49 | def _propagate_index(self, index): 50 | parent = (index - 1) // 2 51 | left, right = 2 * parent + 1, 2 * parent + 2 52 | self._sum_tree[parent] = self._sum_tree[left] + self._sum_tree[right] 53 | if parent != 0: 54 | self._propagate_index(parent) 55 | 56 | # Updates values given tree indices 57 | def update(self, indices, values): 58 | self._sum_tree[indices] = values # Set new values 59 | self._propagate(indices) # Propagate values 60 | current_max_value = np.max(values) 61 | self._max = max(current_max_value, self._max) 62 | 63 | # Updates single value given a tree index for efficiency 64 | def _update_index(self, index, value): 65 | self._sum_tree[index] = value # Set new value 66 | self._propagate_index(index) # Propagate value 67 | self._max = max(value, self._max) 68 | 69 | def add(self, value): 70 | self._update_index(self._index + self._tree_start_idx, value) # Update tree 71 | self._index = (self._index + 1) % self._size # Update index 72 | self._full = self._full or self._index == 0 # Save when capacity reached 73 | self._max = max(value, self._max) 74 | 75 | # Searches for the location of values in sum tree 76 | def _retrieve(self, indices, values): 77 | children_indices = indices * 2 + np.expand_dims( 78 | [1, 2], axis=1 79 | ) # Make matrix of children indices 80 | # If indices correspond to leaf nodes, return them 81 | if children_indices[0, 0] >= self._sum_tree.shape[0]: 82 | return indices 83 | # If children indices correspond to leaf nodes, bound rare outliers in case total slightly overshoots 84 | elif children_indices[0, 0] >= self._tree_start_idx: 85 | children_indices = np.minimum(children_indices, self._sum_tree.shape[0] - 1) 86 | left_children_values = self._sum_tree[children_indices[0]] 87 | successor_choices = np.greater(values, left_children_values).astype( 88 | np.int32 89 | ) # Classify which values are in left or right branches 90 | successor_indices = children_indices[ 91 | successor_choices, np.arange(indices.size) 92 | ] # Use classification to index into the indices matrix 93 | successor_values = ( 94 | values - successor_choices * left_children_values 95 | ) # Subtract the left branch values when searching in the right branch 96 | return self._retrieve(successor_indices, successor_values) 97 | 98 | # Searches for values in sum tree and returns values, data indices and tree indices 99 | def find(self, values): 100 | indices = self._retrieve(np.zeros(values.shape, dtype=np.int32), values) 101 | data_index = indices - self._tree_start_idx 102 | return ( 103 | data_index, 104 | indices, 105 | self._sum_tree[indices], 106 | ) # Return values, data indices, tree indices 107 | 108 | @property 109 | def total(self): 110 | return self._sum_tree[0] 111 | 112 | @property 113 | def max(self): 114 | return self._max 115 | -------------------------------------------------------------------------------- /scale_rl/agents/simbaV2/simbaV2_network.py: -------------------------------------------------------------------------------- 1 | import flax.linen as nn 2 | import jax.numpy as jnp 3 | from 
tensorflow_probability.substrates import jax as tfp 4 | 5 | from scale_rl.agents.simbaV2.simbaV2_layer import ( 6 | HyperCategoricalValue, 7 | HyperEmbedder, 8 | HyperLERPBlock, 9 | HyperNormalTanhPolicy, 10 | ) 11 | 12 | tfd = tfp.distributions 13 | tfb = tfp.bijectors 14 | 15 | 16 | class SimbaV2Actor(nn.Module): 17 | num_blocks: int 18 | hidden_dim: int 19 | action_dim: int 20 | scaler_init: float 21 | scaler_scale: float 22 | alpha_init: float 23 | alpha_scale: float 24 | c_shift: float 25 | 26 | def setup(self): 27 | self.embedder = HyperEmbedder( 28 | hidden_dim=self.hidden_dim, 29 | scaler_init=self.scaler_init, 30 | scaler_scale=self.scaler_scale, 31 | c_shift=self.c_shift, 32 | ) 33 | self.encoder = nn.Sequential( 34 | [ 35 | HyperLERPBlock( 36 | hidden_dim=self.hidden_dim, 37 | scaler_init=self.scaler_init, 38 | scaler_scale=self.scaler_scale, 39 | alpha_init=self.alpha_init, 40 | alpha_scale=self.alpha_scale, 41 | ) 42 | for _ in range(self.num_blocks) 43 | ] 44 | ) 45 | self.predictor = HyperNormalTanhPolicy( 46 | hidden_dim=self.hidden_dim, 47 | action_dim=self.action_dim, 48 | scaler_init=1.0, 49 | scaler_scale=1.0, 50 | ) 51 | 52 | def __call__( 53 | self, 54 | observations: jnp.ndarray, 55 | temperature: float = 1.0, 56 | ) -> tfd.Distribution: 57 | x = observations 58 | y = self.embedder(x) 59 | z = self.encoder(y) 60 | dist, info = self.predictor(z, temperature) 61 | 62 | return dist, info 63 | 64 | 65 | class SimbaV2Critic(nn.Module): 66 | num_blocks: int 67 | hidden_dim: int 68 | scaler_init: float 69 | scaler_scale: float 70 | alpha_init: float 71 | alpha_scale: float 72 | c_shift: float 73 | num_bins: int 74 | min_v: float 75 | max_v: float 76 | 77 | def setup(self): 78 | self.embedder = HyperEmbedder( 79 | hidden_dim=self.hidden_dim, 80 | scaler_init=self.scaler_init, 81 | scaler_scale=self.scaler_scale, 82 | c_shift=self.c_shift, 83 | ) 84 | self.encoder = nn.Sequential( 85 | [ 86 | HyperLERPBlock( 87 | hidden_dim=self.hidden_dim, 88 | scaler_init=self.scaler_init, 89 | scaler_scale=self.scaler_scale, 90 | alpha_init=self.alpha_init, 91 | alpha_scale=self.alpha_scale, 92 | ) 93 | for _ in range(self.num_blocks) 94 | ] 95 | ) 96 | 97 | self.predictor = HyperCategoricalValue( 98 | hidden_dim=self.hidden_dim, 99 | num_bins=self.num_bins, 100 | min_v=self.min_v, 101 | max_v=self.max_v, 102 | scaler_init=1.0, 103 | scaler_scale=1.0, 104 | ) 105 | 106 | def __call__( 107 | self, 108 | observations: jnp.ndarray, 109 | actions: jnp.ndarray, 110 | ) -> jnp.ndarray: 111 | x = jnp.concatenate((observations, actions), axis=1) 112 | y = self.embedder(x) 113 | z = self.encoder(y) 114 | q, info = self.predictor(z) 115 | return q, info 116 | 117 | 118 | class SimbaV2DoubleCritic(nn.Module): 119 | """ 120 | Vectorized Double-Q for Clipped Double Q-learning. 
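(Clipped double Q-learning forms targets from the element-wise minimum over the critic ensemble's predictions, counteracting value overestimation; see the reference below.)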
121 | https://arxiv.org/pdf/1802.09477v3 122 | """ 123 | 124 | num_blocks: int 125 | hidden_dim: int 126 | scaler_init: float 127 | scaler_scale: float 128 | alpha_init: float 129 | alpha_scale: float 130 | c_shift: float 131 | num_bins: int 132 | min_v: float 133 | max_v: float 134 | 135 | num_qs: int = 2 136 | 137 | @nn.compact 138 | def __call__( 139 | self, 140 | observations: jnp.ndarray, 141 | actions: jnp.ndarray, 142 | ) -> jnp.ndarray: 143 | VmapCritic = nn.vmap( 144 | SimbaV2Critic, 145 | variable_axes={"params": 0}, 146 | split_rngs={"params": True}, 147 | in_axes=None, 148 | out_axes=0, 149 | axis_size=self.num_qs, 150 | ) 151 | 152 | qs, infos = VmapCritic( 153 | num_blocks=self.num_blocks, 154 | hidden_dim=self.hidden_dim, 155 | scaler_init=self.scaler_init, 156 | scaler_scale=self.scaler_scale, 157 | alpha_init=self.alpha_init, 158 | alpha_scale=self.alpha_scale, 159 | c_shift=self.c_shift, 160 | num_bins=self.num_bins, 161 | min_v=self.min_v, 162 | max_v=self.max_v, 163 | )(observations, actions) 164 | 165 | return qs, infos 166 | 167 | 168 | class SimbaV2Temperature(nn.Module): 169 | initial_value: float = 0.01 170 | 171 | @nn.compact 172 | def __call__(self) -> jnp.ndarray: 173 | log_temp = self.param( 174 | name="log_temp", 175 | init_fn=lambda key: jnp.full( 176 | shape=(), fill_value=jnp.log(self.initial_value) 177 | ), 178 | ) 179 | return jnp.exp(log_temp) 180 | -------------------------------------------------------------------------------- /scale_rl/envs/humanoid_bench.py: -------------------------------------------------------------------------------- 1 | import gymnasium as gym 2 | 3 | ################################# 4 | # 5 | # Original Humanoid 6 | # 7 | ################################# 8 | 9 | # 14 tasks 10 | HB_LOCOMOTION = [ 11 | "h1hand-walk-v0", 12 | "h1hand-stand-v0", 13 | "h1hand-run-v0", 14 | "h1hand-reach-v0", 15 | "h1hand-hurdle-v0", 16 | "h1hand-crawl-v0", 17 | "h1hand-maze-v0", 18 | "h1hand-sit_simple-v0", 19 | "h1hand-sit_hard-v0", 20 | "h1hand-balance_simple-v0", 21 | "h1hand-balance_hard-v0", 22 | "h1hand-stair-v0", 23 | "h1hand-slide-v0", 24 | "h1hand-pole-v0", 25 | ] 26 | 27 | # 17 tasks 28 | HB_MANIPULATION = [ 29 | "h1hand-push-v0", 30 | "h1hand-cabinet-v0", 31 | "h1strong-highbar_hard-v0", # Make hands stronger to be able to hang from the high bar 32 | "h1hand-door-v0", 33 | "h1hand-truck-v0", 34 | "h1hand-cube-v0", 35 | "h1hand-bookshelf_simple-v0", 36 | "h1hand-bookshelf_hard-v0", 37 | "h1hand-basketball-v0", 38 | "h1hand-window-v0", 39 | "h1hand-spoon-v0", 40 | "h1hand-kitchen-v0", 41 | "h1hand-package-v0", 42 | "h1hand-powerlift-v0", 43 | "h1hand-room-v0", 44 | "h1hand-insert_small-v0", 45 | "h1hand-insert_normal-v0", 46 | ] 47 | 48 | 49 | ################################# 50 | # 51 | # No Hand Humanoid 52 | # 53 | ################################# 54 | 55 | HB_LOCOMOTION_NOHAND = [ 56 | "h1-walk-v0", 57 | "h1-stand-v0", 58 | "h1-run-v0", 59 | "h1-reach-v0", 60 | "h1-hurdle-v0", 61 | "h1-crawl-v0", 62 | "h1-maze-v0", 63 | "h1-sit_simple-v0", 64 | "h1-sit_hard-v0", 65 | "h1-balance_simple-v0", 66 | "h1-balance_hard-v0", 67 | "h1-stair-v0", 68 | "h1-slide-v0", 69 | "h1-pole-v0", 70 | ] 71 | 72 | HB_LOCOMOTION_NOHAND_MINI = [ 73 | "h1-walk-v0", 74 | "h1-run-v0", 75 | "h1-sit_hard-v0", 76 | "h1-balance_simple-v0", 77 | "h1-stair-v0", 78 | ] 79 | 80 | ################################# 81 | # 82 | # Task Success scores 83 | # 84 | ################################# 85 | 86 | # 10 seeds, 10 eval envs per 1 seed 87 | HB_RANDOM_SCORE = { 88 | 
"h1-walk-v0": 2.377, 89 | "h1-stand-v0": 10.545, 90 | "h1-run-v0": 2.02, 91 | "h1-reach-v0": 260.302, 92 | "h1-hurdle-v0": 2.214, 93 | "h1-crawl-v0": 272.658, 94 | "h1-maze-v0": 106.441, 95 | "h1-sit_simple-v0": 9.393, 96 | "h1-sit_hard-v0": 2.448, 97 | "h1-balance_simple-v0": 9.391, 98 | "h1-balance_hard-v0": 9.044, 99 | "h1-stair-v0": 3.112, 100 | "h1-slide-v0": 3.191, 101 | "h1-pole-v0": 20.09, 102 | "h1hand-push-v0": -526.8, 103 | "h1hand-cabinet-v0": 37.733, 104 | "h1strong-highbar-hard-v0": 0.178, 105 | "h1hand-door-v0": 2.771, 106 | "h1hand-truck-v0": 562.419, 107 | "h1hand-cube-v0": 4.787, 108 | "h1hand-bookshelf_simple-v0": 16.777, 109 | "h1hand-bookshelf_hard-v0": 14.848, 110 | "h1hand-basketball-v0": 8.979, 111 | "h1hand-window-v0": 2.713, 112 | "h1hand-spoon-v0": 4.661, 113 | "h1hand-kitchen-v0": 0.0, 114 | "h1hand-package-v0": -10040.932, 115 | "h1hand-powerlift-v0": 17.638, 116 | "h1hand-room-v0": 3.018, 117 | "h1hand-insert_small-v0": 1.653, 118 | "h1hand-insert_normal-v0": 1.673, 119 | "h1hand-walk-v0": 2.505, 120 | "h1hand-stand-v0": 11.973, 121 | "h1hand-run-v0": 1.927, 122 | "h1hand-reach-v0": -50.024, 123 | "h1hand-hurdle-v0": 2.371, 124 | "h1hand-crawl-v0": 278.868, 125 | "h1hand-maze-v0": 106.233, 126 | "h1hand-sit_simple-v0": 10.768, 127 | "h1hand-sit_hard-v0": 2.477, 128 | "h1hand-balance_simple-v0": 10.17, 129 | "h1hand-balance_hard-v0": 10.032, 130 | "h1hand-stair-v0": 3.161, 131 | "h1hand-slide-v0": 3.142, 132 | "h1hand-pole-v0": 19.721, 133 | } 134 | 135 | HB_SUCCESS_SCORE = { 136 | "h1-walk-v0": 700.0, 137 | "h1-stand-v0": 800.0, 138 | "h1-run-v0": 700.0, 139 | "h1-reach-v0": 12000.0, 140 | "h1-hurdle-v0": 700.0, 141 | "h1-crawl-v0": 700.0, 142 | "h1-maze-v0": 1200.0, 143 | "h1-sit_simple-v0": 750.0, 144 | "h1-sit_hard-v0": 750.0, 145 | "h1-balance_simple-v0": 800.0, 146 | "h1-balance_hard-v0": 800.0, 147 | "h1-stair_v0": 700.0, 148 | "h1-slide_v0": 700.0, 149 | "h1-pole_v0": 700.0, 150 | "h1-push_v0": 700.0, 151 | "h1-cabinet_v0": 2500.0, 152 | "h1-highbar_v0": 750.0, 153 | "h1-door_v0": 600.0, 154 | "h1-truck_v0": 3000.0, 155 | "h1-cube_v0": 370.0, 156 | "h1-bookshelf_simple_v0": 2000.0, 157 | "h1-bookshelf_hard_v0": 2000.0, 158 | "h1-basketball_v0": 1200.0, 159 | "h1-window_v0": 650.0, 160 | "h1-spoon_v0": 650.0, 161 | "h1-kitchen_v0": 4.0, 162 | "h1-package_v0": 1500.0, 163 | "h1-powerlift_v0": 800.0, 164 | "h1-room_v0": 400.0, 165 | "h1-insert_small_v0": 350.0, 166 | "h1-insert_normal_v0": 350.0, 167 | } 168 | 169 | 170 | class HBGymnasiumVersionWrapper(gym.Wrapper): 171 | """ 172 | humanoid bench originally requires gymnasium==0.29.1 173 | however, we are currently using gymnasium==1.0.0a2, 174 | hence requiring some minor fix to the rendering function 175 | """ 176 | 177 | def __init__(self, env: gym.Env): 178 | super().__init__(env) 179 | self.task = env.unwrapped.task 180 | 181 | def render(self): 182 | return self.task._env.mujoco_renderer.render(self.task._env.render_mode) 183 | 184 | 185 | def make_humanoid_env( 186 | env_name: str, 187 | seed: int, 188 | ) -> gym.Env: 189 | import humanoid_bench 190 | 191 | additional_kwargs = {} 192 | if env_name == "h1hand-package-v0": 193 | additional_kwargs = {"policy_path": None} 194 | env = gym.make(env_name, **additional_kwargs) 195 | env = HBGymnasiumVersionWrapper(env) 196 | 197 | return env 198 | -------------------------------------------------------------------------------- /scale_rl/agents/simba/simba_update.py: -------------------------------------------------------------------------------- 1 | from 
typing import Any, Dict, Tuple 2 | 3 | import flax 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from scale_rl.agents.jax_utils.network import Network, PRNGKey 8 | from scale_rl.buffers import Batch 9 | 10 | 11 | def update_actor( 12 | key: PRNGKey, 13 | actor: Network, 14 | critic: Network, 15 | temperature: Network, 16 | batch: Batch, 17 | use_cdq: bool, 18 | bc_alpha: float, 19 | ) -> Tuple[Network, Dict[str, float]]: 20 | def actor_loss_fn( 21 | actor_params: flax.core.FrozenDict[str, Any], 22 | ) -> Tuple[jnp.ndarray, Dict[str, float]]: 23 | dist, _ = actor.apply( 24 | variables={"params": actor_params}, 25 | observations=batch["observation"], 26 | ) 27 | actions = dist.sample(seed=key) 28 | log_probs = dist.log_prob(actions) 29 | 30 | if use_cdq: 31 | # qs: (2, n) 32 | qs, q_infos = critic(observations=batch["observation"], actions=actions) 33 | q = jnp.minimum(qs[0], qs[1]) 34 | else: 35 | q, _ = critic(observations=batch["observation"], actions=actions) 36 | 37 | actor_loss = (log_probs * temperature() - q).mean() 38 | 39 | if bc_alpha > 0: 40 | # https://arxiv.org/abs/2306.02451 41 | q_abs = jax.lax.stop_gradient(jnp.abs(q).mean()) 42 | bc_loss = ((actions - batch["action"]) ** 2).mean() 43 | actor_loss = actor_loss + bc_alpha * q_abs * bc_loss 44 | 45 | actor_info = { 46 | "actor/loss": actor_loss, 47 | "actor/entropy": -log_probs.mean(), 48 | "actor/mean_action": jnp.mean(actions), 49 | } 50 | return actor_loss, actor_info 51 | 52 | actor, info = actor.apply_gradient(actor_loss_fn) 53 | 54 | return actor, info 55 | 56 | 57 | def update_critic( 58 | key: PRNGKey, 59 | actor: Network, 60 | critic: Network, 61 | target_critic: Network, 62 | temperature: Network, 63 | batch: Batch, 64 | use_cdq: bool, 65 | gamma: float, 66 | n_step: int, 67 | ) -> Tuple[Network, Dict[str, float]]: 68 | # compute the target q-value 69 | next_dist, _ = actor(observations=batch["next_observation"]) 70 | next_actions = next_dist.sample(seed=key) 71 | next_actor_log_probs = next_dist.log_prob(next_actions) 72 | next_actor_entropy = temperature() * next_actor_log_probs 73 | 74 | if use_cdq: 75 | # next_qs: (2, n) 76 | next_qs, next_q_infos = target_critic( 77 | observations=batch["next_observation"], actions=next_actions 78 | ) 79 | next_q = jnp.minimum(next_qs[0], next_qs[1]) 80 | else: 81 | next_q, next_q_info = target_critic( 82 | observations=batch["next_observation"], 83 | actions=next_actions, 84 | ) 85 | 86 | # compute the td-target, incorporating the n-step accumulated reward 87 | # https://gymnasium.farama.org/tutorials/gymnasium_basics/handling_time_limits/ 88 | target_q = batch["reward"] + (gamma**n_step) * (1 - batch["terminated"]) * ( 89 | next_q - next_actor_entropy 90 | ) 91 | 92 | def critic_loss_fn( 93 | critic_params: flax.core.FrozenDict[str, Any], 94 | ) -> Tuple[jnp.ndarray, Dict[str, float]]: 95 | # compute predicted q-value 96 | if use_cdq: 97 | pred_qs, pred_q_infos = critic.apply( 98 | variables={"params": critic_params}, 99 | observations=batch["observation"], 100 | actions=batch["action"], 101 | ) 102 | loss_1 = (pred_qs[0] - target_q) ** 2 103 | loss_2 = (pred_qs[1] - target_q) ** 2 104 | critic_loss = (loss_1 + loss_2).mean() 105 | else: 106 | pred_q, _ = critic.apply( 107 | variables={"params": critic_params}, 108 | observations=batch["observation"], 109 | actions=batch["action"], 110 | ) 111 | critic_loss = ((pred_q - target_q) ** 2).mean() 112 | 113 | critic_info = { 114 | "critic/loss": critic_loss, 115 | "critic/batch_rew_min": batch["reward"].min(), 116 | 
"critic/batch_rew_mean": batch["reward"].mean(), 117 | "critic/batch_rew_max": batch["reward"].max(), 118 | } 119 | 120 | return critic_loss, critic_info 121 | 122 | critic, info = critic.apply_gradient(critic_loss_fn) 123 | 124 | return critic, info 125 | 126 | 127 | def update_target_network( 128 | network: Network, 129 | target_network: Network, 130 | target_tau: bool, 131 | ) -> Tuple[Network, Dict[str, float]]: 132 | new_target_params = jax.tree_map( 133 | lambda p, tp: p * target_tau + tp * (1 - target_tau), 134 | network.params, 135 | target_network.params, 136 | ) 137 | target_network = target_network.replace(params=new_target_params) 138 | 139 | info = {} 140 | 141 | return target_network, info 142 | 143 | 144 | def update_temperature( 145 | temperature: Network, entropy: float, target_entropy: float 146 | ) -> Tuple[Network, Dict[str, float]]: 147 | def temperature_loss_fn( 148 | temperature_params: flax.core.FrozenDict[str, Any], 149 | ) -> Tuple[jnp.ndarray, Dict[str, float]]: 150 | temperature_value = temperature.apply({"params": temperature_params}) 151 | temperature_loss = temperature_value * (entropy - target_entropy).mean() 152 | temperature_info = { 153 | "temperature/value": temperature_value, 154 | "temperature/loss": temperature_loss, 155 | } 156 | 157 | return temperature_loss, temperature_info 158 | 159 | temperature, info = temperature.apply_gradient(temperature_loss_fn) 160 | 161 | return temperature, info 162 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SimbaV2 2 | 3 | > **DAVIAN Robotics, KAIST AI & Sony AI** 4 | > ICML 2025 (spotlight). 5 | 6 | ## Introduction 7 | 8 | SimbaV2 is a reinforcement learning architecture designed to stabilize training via hyperspherical normalization. By increasing model capacity and compute, SimbaV2 achieves state-of-the-art results on 57 continuous control tasks from MuJoCo, DMControl, MyoSuite, and Humanoid-bench. 9 | 10 |

11 | 12 |

13 | 14 | Hojoon Lee\*  15 | Youngdo Lee\*  16 | Takuma Seno  17 | Donghu Kim  18 | Peter Stone  19 | Jaegul Choo  20 | 21 | [[Website]](https://dojeon-ai.github.io/SimbaV2/) [[Paper]](https://arxiv.org/abs/2502.15280) [[Dataset]](https://dojeon-ai.github.io/SimbaV2/dataset/) 22 | 23 | ## Result 24 | 25 | We compare SimbaV2 to the original Simba by tracking: 26 | - (a) Average normalized return across tasks. 27 | - (b) Weighted sum of $\ell_2$-norms of all intermediate features in critics. 28 | - (c) Weighted sum of $\ell_2$-norms of all critic parameters. 29 | - (d) Weighted sum of $\ell_2$-norms of all gradients in critics. 30 | - (e) Effective learning rate (ELR) of the critics. 31 | 32 | SimbaV2 consistently maintains stable norms and ELR, while Simba shows divergent fluctuations. 33 | 34 | 35 |
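For intuition on (b)-(d) above, each of those curves reduces a pytree of critic features, parameters, or gradients to an $\ell_2$-norm. The snippet below is an illustrative sketch of that reduction only; the per-layer weighting used for the plots in the paper is not reproduced here, and `global_l2_norm` is a hypothetical helper rather than a function from this repository.

```
import jax
import jax.numpy as jnp

def global_l2_norm(tree) -> jnp.ndarray:
    # Flatten a parameter (or gradient) pytree and reduce it to a single l2-norm.
    leaves = jax.tree_util.tree_leaves(tree)
    return jnp.sqrt(sum(jnp.sum(jnp.square(leaf)) for leaf in leaves))

# Illustrative usage on a tiny stand-in for critic parameters.
params = {"dense": {"kernel": jnp.ones((4, 4)), "bias": jnp.zeros(4)}}
print(global_l2_norm(params))  # sqrt(16) = 4.0
```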

36 | 37 |

38 | 39 | We scale model parameters by increasing critic width and scale compute via the update-to-data (UTD) ratio. We also explore resetting vs. non-resetting training: 40 | - DMC-Hard (7 tasks): $\texttt{dog}$ and $\texttt{humanoid}$ embodiments. 41 | - HBench-Hard (5 tasks): $\texttt{run}$, $\texttt{balance-simple}$, $\texttt{sit-hard}$, $\texttt{stair}$, $\texttt{walk}$. 42 | 43 | On these challenging subsets, SimbaV2 benefits from increasing model size and UTD, while Simba plateaus. Notably, SimbaV2 scales smoothly with UTD even without resets, and resetting can degrade its performance. 44 | 45 |
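The UTD ratio corresponds to the `updates_per_interaction_step` setting consumed by `run_online.py` (included later in this repository): an accumulator grows by that value every environment step, and a gradient update fires each time it reaches one, so fractional ratios are also supported. Below is a minimal sketch of that scheduling logic with illustrative values, not the repository's defaults.

```
# Fractional update-to-data (UTD) scheduling, mirroring the accumulator in run_online.py.
updates_per_interaction_step = 0.5   # e.g. one gradient update every two environment steps
update_counter = 0.0
update_step = 0

for interaction_step in range(1, 9):   # stand-in for the environment interaction loop
    update_counter += updates_per_interaction_step
    while update_counter >= 1:         # a UTD ratio > 1 triggers several updates per step
        update_step += 1               # a real run would call agent.update(update_step, batch) here
        update_counter -= 1

print(update_step)  # 4 updates over 8 interaction steps at UTD = 0.5
```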

46 | 47 | 48 |

49 | 50 | SimbaV2 outperforms competing RL algorithms, with performance improving as compute increases. 51 | 52 |

53 | 54 |

55 | 56 | 57 | ## Getting started 58 | 59 | We use the Gymnasium 1.0 API, which provides seamless integration with diverse RL environments. 60 | 61 | ### Docker 62 | 63 | We provide a `Dockerfile` for easy installation. You can build the Docker image by running: 64 | 65 | ``` 66 | docker build -t scale_rl . 67 | docker run --gpus all -v .:/home/user/scale_rl -it scale_rl /bin/bash 68 | ``` 69 | 70 | ### Pip/Conda 71 | 72 | If you prefer to install dependencies manually, install them with pip or conda as follows: 73 | ``` 74 | # Use pip 75 | pip install -r deps/requirements.txt 76 | 77 | # Or use conda 78 | conda env create -f deps/environment.yaml 79 | ``` 80 | 81 | #### JAX for GPU 82 | ``` 83 | pip install -U "jax[cuda12]==0.4.25" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html 84 | # If you want to execute multiple runs on a single GPU, we recommend setting this variable. 85 | export XLA_PYTHON_CLIENT_PREALLOCATE=false 86 | ``` 87 | 88 | #### MuJoCo 89 | Please see the installation instructions at [MuJoCo](https://github.com/google-deepmind/mujoco). 90 | ``` 91 | # Additional environment variables for headless rendering 92 | export MUJOCO_GL="egl" 93 | export MUJOCO_EGL_DEVICE_ID="0" 94 | export MKL_SERVICE_FORCE_INTEL="0" 95 | ``` 96 | 97 | #### Humanoid Bench 98 | 99 | ``` 100 | git clone https://github.com/joonleesky/humanoid-bench 101 | cd humanoid-bench 102 | pip install -e . 103 | ``` 104 | 105 | #### MyoSuite 106 | ``` 107 | git clone --recursive https://github.com/joonleesky/myosuite 108 | cd myosuite 109 | pip install -e . 110 | ``` 111 | 112 | 113 | ## Example usage 114 | 115 | We provide examples of how to train SAC agents with the SimBa architecture. 116 | 117 | To run a single online RL experiment: 118 | ``` 119 | python run_online.py 120 | ``` 121 | 122 | To run a single offline RL experiment: 123 | ``` 124 | python run_offline.py 125 | ``` 126 | 127 | To benchmark the algorithm across all environments: 128 | ``` 129 | python run_parallel.py \ 130 | --env_type all \ 131 | --device_ids \ 132 | --num_seeds \ 133 | --num_exp_per_device 134 | ``` 135 | 136 | 137 | ## Analysis 138 | 139 | Please refer to `/analysis` to visualize the experimental results provided in the paper. 140 | 141 | 142 | ## License 143 | This project is released under the [Apache 2.0 license](/LICENSE).
144 | 145 | ## Citation 146 | 147 | If you find our work useful, please consider citing our paper as follows: 148 | 149 | ``` 150 | @article{lee2025hyperspherical, 151 | title={Hyperspherical Normalization for Scalable Deep Reinforcement Learning}, 152 | author={Lee, Hojoon and Lee, Youngdo and Seno, Takuma and Kim, Donghu and Stone, Peter and Choo, Jaegul}, 153 | journal={arXiv preprint arXiv:2502.15280}, 154 | year={2025} 155 | } 156 | ``` 157 | -------------------------------------------------------------------------------- /scale_rl/agents/jax_utils/network.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from functools import partial 3 | from typing import Any, Optional, Tuple 4 | 5 | import flax 6 | import flax.linen as nn 7 | import jax 8 | import jax.numpy as jnp 9 | import optax 10 | from flax.training import checkpoints 11 | 12 | PRNGKey = jnp.ndarray 13 | 14 | 15 | @flax.struct.dataclass 16 | class Network: 17 | network_def: nn.Module = flax.struct.field(pytree_node=False) 18 | params: flax.core.FrozenDict[str, Any] 19 | tx: Optional[optax.GradientTransformation] = flax.struct.field(pytree_node=False) 20 | opt_state: Optional[optax.OptState] = None 21 | update_step: int = 0 22 | """ 23 | dataclass decorator makes custom class to be passed safely to Jax. 24 | https://flax.readthedocs.io/en/latest/api_reference/flax.struct.html 25 | 26 | Network class wraps network & optimizer to easily optimize the network under the hood. 27 | 28 | args: 29 | network_def: flax.nn style of network definition. 30 | params: network parameters. 31 | tx: optimizer (e.g., optax.Adam). 32 | opt_state: current state of the optimizer (e.g., beta_1 in Adam). 33 | update_step: number of update step so far. 34 | """ 35 | 36 | @classmethod 37 | def create( 38 | cls, 39 | network_def: nn.Module, 40 | network_inputs: flax.core.FrozenDict[str, jnp.ndarray], 41 | tx: Optional[optax.GradientTransformation] = None, 42 | ) -> "Network": 43 | variables = network_def.init(**network_inputs) 44 | params = variables.pop("params") 45 | 46 | if tx is not None: 47 | opt_state = tx.init(params) 48 | else: 49 | opt_state = None 50 | 51 | network = cls( 52 | network_def=network_def, 53 | params=params, 54 | tx=tx, 55 | opt_state=opt_state, 56 | ) 57 | 58 | return network 59 | 60 | def __call__(self, *args, **kwargs): 61 | return self.network_def.apply({"params": self.params}, *args, **kwargs) 62 | 63 | def save(self, path: str, step: int = 0, keep: int = 1) -> None: 64 | """ 65 | Save parameters, optimizer state, and other metadata to the given path. 66 | """ 67 | ckpt = { 68 | "params": self.params, 69 | "opt_state": self.opt_state, 70 | "update_step": self.update_step, 71 | } 72 | checkpoints.save_checkpoint( 73 | ckpt_dir=path, 74 | target=ckpt, 75 | step=step, 76 | overwrite=True, 77 | keep=keep, 78 | ) 79 | 80 | def load(self, path: str, param_key: str = None, only_param: bool = False) -> None: 81 | """ 82 | Load parameters, optimizer state, and other metadata from the given path. 83 | args: 84 | path (str): The path to the checkpoint directory. 85 | param_key (str): If specified, only the subset of parameters is loaded. 86 | only_param (bool): If True, only the parameters are loaded. 87 | """ 88 | ckpt = checkpoints.restore_checkpoint(ckpt_dir=path, target=None) 89 | 90 | def _key_exists(d, key): 91 | """ 92 | Recursively check if key exists in dictionary d. 
93 | """ 94 | if key in d: 95 | return True 96 | return any(_key_exists(v, key) for v in d.values() if isinstance(v, dict)) 97 | 98 | def _recursive_replace(source, target, key): 99 | """ 100 | Recursively replace the value of key in source from target. 101 | """ 102 | for k, v in source.items(): 103 | if k == key: 104 | source[k] = target[k] 105 | elif ( 106 | isinstance(v, dict) and k in source and isinstance(target[k], dict) 107 | ): 108 | _recursive_replace(source[k], target[k], key) 109 | return source 110 | 111 | if param_key: 112 | if not _key_exists(self.params, param_key): 113 | raise ValueError(f"The key '{param_key}' is missing") 114 | new_params = copy.deepcopy(self.params) 115 | new_params = _recursive_replace(new_params, ckpt["params"], param_key) 116 | else: 117 | new_params = ckpt["params"] 118 | 119 | if only_param: 120 | network = self.replace(params=new_params) 121 | else: 122 | # self.opt_state: named_tuple 123 | # ckpt['opt_state]: dictionary 124 | new_opt_state = jax.tree_util.tree_unflatten( 125 | jax.tree_util.tree_structure(self.opt_state), 126 | jax.tree_util.tree_leaves(ckpt["opt_state"]), 127 | ) 128 | 129 | network = self.replace( 130 | params=new_params, 131 | opt_state=new_opt_state, 132 | update_step=ckpt["update_step"], 133 | ) 134 | 135 | return network 136 | 137 | def apply(self, *args, **kwargs): 138 | return self.network_def.apply(*args, **kwargs) 139 | 140 | def apply_gradient(self, loss_fn, get_info=True) -> Tuple[Any, "Network"]: 141 | grad_fn = jax.grad(loss_fn, has_aux=True) 142 | grads, info = grad_fn(self.params) 143 | info["_grads"] = grads 144 | is_fin = True 145 | 146 | updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params) 147 | new_params = optax.apply_updates(self.params, updates) 148 | 149 | network = self.replace( 150 | params=jax.tree_util.tree_map( 151 | partial(jnp.where, is_fin), new_params, self.params 152 | ), 153 | opt_state=jax.tree_util.tree_map( 154 | partial(jnp.where, is_fin), new_opt_state, self.opt_state 155 | ), 156 | update_step=self.update_step + 1, 157 | ) 158 | 159 | return network, info 160 | -------------------------------------------------------------------------------- /run_offline.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import math 4 | import random 5 | 6 | import hydra 7 | import numpy as np 8 | import omegaconf 9 | import tqdm 10 | from dotmap import DotMap 11 | 12 | from scale_rl.agents import create_agent 13 | from scale_rl.buffers import create_buffer 14 | from scale_rl.common import WandbTrainerLogger 15 | from scale_rl.envs import create_dataset, create_envs, get_normalized_score 16 | from scale_rl.evaluation import evaluate, record_video 17 | 18 | 19 | def run(args): 20 | ############################### 21 | # configs 22 | ############################### 23 | 24 | args = DotMap(args) 25 | config_path = args.config_path 26 | config_name = args.config_name 27 | overrides = args.overrides 28 | 29 | hydra.initialize(version_base=None, config_path=config_path) 30 | cfg = hydra.compose(config_name=config_name, overrides=overrides) 31 | 32 | def eval_resolver(s: str): 33 | return eval(s) 34 | 35 | omegaconf.OmegaConf.register_new_resolver("eval", eval_resolver) 36 | omegaconf.OmegaConf.resolve(cfg) 37 | 38 | np.random.seed(cfg.seed) 39 | random.seed(cfg.seed) 40 | 41 | ############################# 42 | # envs 43 | ############################# 44 | train_env, eval_env = create_envs(**cfg.env) 45 | 
observation_space = train_env.observation_space 46 | action_space = train_env.action_space 47 | 48 | dataset = create_dataset(cfg.env.env_type, cfg.env.env_name) 49 | 50 | ############################# 51 | # buffer 52 | ############################# 53 | cfg.buffer.max_length = len(dataset) 54 | buffer = create_buffer( 55 | observation_space=observation_space, action_space=action_space, **cfg.buffer 56 | ) 57 | buffer.reset() 58 | 59 | ############################# 60 | # fill buffer 61 | ############################# 62 | 63 | for i, timestep in tqdm.tqdm( 64 | list(enumerate(dataset)), desc="Filling buffer with dataset" 65 | ): 66 | buffer.add(timestep) 67 | 68 | ############################# 69 | # agent 70 | ############################# 71 | 72 | batch_size = cfg.buffer.sample_batch_size 73 | cfg.num_interaction_steps = int((len(dataset) / batch_size) * cfg.num_epochs) 74 | cfg.save_checkpoint_per_interaction_step = cfg.num_interaction_steps 75 | cfg.agent.learning_rate_decay_step = int( 76 | cfg.agent.learning_rate_decay_rate 77 | * cfg.num_interaction_steps 78 | * cfg.updates_per_interaction_step 79 | ) 80 | 81 | agent = create_agent( 82 | observation_space=observation_space, 83 | action_space=action_space, 84 | cfg=cfg.agent, 85 | ) 86 | 87 | # iterate over buffer to update normalizers 88 | num_batches = int(np.floor(len(dataset) / batch_size)) 89 | for batch_num in tqdm.tqdm(range(num_batches), desc="updating normalizers"): 90 | start_idx = batch_num * batch_size 91 | end_idx = start_idx + batch_size 92 | batch_indices = np.arange(start_idx, end_idx) 93 | batch = buffer.sample(sample_idxs=batch_indices) 94 | 95 | # update normalizers 96 | agent.sample_actions(i, copy.deepcopy(batch), training=True) 97 | 98 | ############################# 99 | # train offline 100 | ############################# 101 | 102 | logger = WandbTrainerLogger(cfg) 103 | 104 | # initial evaluation 105 | eval_info = evaluate(agent, eval_env, cfg.num_eval_episodes) 106 | eval_info["avg_normalized_return"] = get_normalized_score( 107 | cfg.env.env_type, cfg.env.env_name, eval_info["avg_return"] 108 | ) 109 | logger.update_metric(**eval_info) 110 | logger.log_metric(step=0) 111 | logger.reset() 112 | 113 | # start training 114 | update_step = 0 115 | for interaction_step in tqdm.tqdm( 116 | range(1, int(cfg.num_interaction_steps + 1)), smoothing=0.1 117 | ): 118 | # update network 119 | batch = buffer.sample() 120 | update_info = agent.update(update_step, batch) 121 | logger.update_metric(**update_info) 122 | update_step += 1 123 | 124 | # evaluation 125 | if interaction_step % cfg.evaluation_per_interaction_step == 0: 126 | eval_info = evaluate(agent, eval_env, cfg.num_eval_episodes) 127 | eval_info["avg_normalized_return"] = get_normalized_score( 128 | cfg.env.env_type, cfg.env.env_name, eval_info["avg_return"] 129 | ) 130 | logger.update_metric(**eval_info) 131 | 132 | # metrics 133 | if interaction_step % cfg.metrics_per_interaction_step == 0: 134 | batch = buffer.sample() 135 | metrics_info = agent.get_metrics(batch, update_info) 136 | if metrics_info: 137 | logger.update_metric(**metrics_info) 138 | 139 | # TODO Support video recording 140 | # # video recording 141 | # if offline_step % cfg.offline.recording_per_offline_step == 0: 142 | # video_info = record_video(agent, eval_env, cfg.num_record_episodes) 143 | # logger.update_metric(**video_info) 144 | 145 | # logging 146 | if interaction_step % cfg.logging_per_interaction_step == 0: 147 | logger.log_metric(step=interaction_step) 148 | logger.reset() 149 
| 150 | # final evaluation 151 | eval_info = evaluate(agent, eval_env, cfg.num_eval_episodes) 152 | eval_info["avg_normalized_return"] = get_normalized_score( 153 | cfg.env.env_type, cfg.env.env_name, eval_info["avg_return"] 154 | ) 155 | logger.update_metric(**eval_info) 156 | logger.log_metric(step=interaction_step) 157 | logger.reset() 158 | 159 | train_env.close() 160 | eval_env.close() 161 | 162 | 163 | if __name__ == "__main__": 164 | parser = argparse.ArgumentParser(allow_abbrev=False) 165 | parser.add_argument("--config_path", type=str, default="./configs") 166 | parser.add_argument("--config_name", type=str, default="offline_rl") 167 | parser.add_argument("--overrides", action="append", default=[]) 168 | args = parser.parse_args() 169 | 170 | run(vars(args)) 171 | -------------------------------------------------------------------------------- /scale_rl/agents/simbaV2/simbaV2_layer.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import flax.linen as nn 4 | import jax 5 | import jax.numpy as jnp 6 | from tensorflow_probability.substrates import jax as tfp 7 | 8 | from scale_rl.agents.simbaV2.simbaV2_update import l2normalize 9 | 10 | tfd = tfp.distributions 11 | tfb = tfp.bijectors 12 | 13 | 14 | class Scaler(nn.Module): 15 | dim: int 16 | init: float = 1.0 17 | scale: float = 1.0 18 | 19 | def setup(self): 20 | self.scaler = self.param( 21 | "scaler", 22 | nn.initializers.constant(1.0 * self.scale), 23 | self.dim, 24 | ) 25 | self.forward_scaler = self.init / self.scale 26 | 27 | def __call__(self, x): 28 | return self.scaler * self.forward_scaler * x 29 | 30 | 31 | class HyperDense(nn.Module): 32 | hidden_dim: int 33 | 34 | def setup(self): 35 | self.w = nn.Dense( 36 | name="hyper_dense", 37 | features=self.hidden_dim, 38 | kernel_init=nn.initializers.orthogonal(scale=1.0, column_axis=0), 39 | use_bias=False, # important! 40 | ) 41 | 42 | def __call__(self, x): 43 | return self.w(x) 44 | 45 | 46 | class HyperMLP(nn.Module): 47 | hidden_dim: int 48 | out_dim: int 49 | scaler_init: float 50 | scaler_scale: float 51 | eps: float = 1e-8 52 | 53 | def setup(self): 54 | self.w1 = HyperDense(self.hidden_dim) 55 | self.scaler = Scaler(self.hidden_dim, self.scaler_init, self.scaler_scale) 56 | self.w2 = HyperDense(self.out_dim) 57 | 58 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 59 | x = self.w1(x) 60 | x = self.scaler(x) 61 | # `eps` is required to prevent zero vector. 
62 | x = nn.relu(x) + self.eps 63 | x = self.w2(x) 64 | x = l2normalize(x, axis=-1) 65 | return x 66 | 67 | 68 | class HyperEmbedder(nn.Module): 69 | hidden_dim: int 70 | scaler_init: float 71 | scaler_scale: float 72 | c_shift: float 73 | 74 | def setup(self): 75 | self.w = HyperDense(self.hidden_dim) 76 | self.scaler = Scaler(self.hidden_dim, self.scaler_init, self.scaler_scale) 77 | 78 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 79 | new_axis = jnp.ones((x.shape[:-1] + (1,))) * self.c_shift 80 | x = jnp.concatenate([x, new_axis], axis=-1) 81 | x = l2normalize(x, axis=-1) 82 | x = self.w(x) 83 | x = self.scaler(x) 84 | x = l2normalize(x, axis=-1) 85 | 86 | return x 87 | 88 | 89 | class HyperLERPBlock(nn.Module): 90 | hidden_dim: int 91 | scaler_init: float 92 | scaler_scale: float 93 | alpha_init: float 94 | alpha_scale: float 95 | 96 | expansion: int = 4 97 | 98 | def setup(self): 99 | self.mlp = HyperMLP( 100 | hidden_dim=self.hidden_dim * self.expansion, 101 | out_dim=self.hidden_dim, 102 | scaler_init=self.scaler_init / math.sqrt(self.expansion), 103 | scaler_scale=self.scaler_scale / math.sqrt(self.expansion), 104 | ) 105 | self.alpha_scaler = Scaler( 106 | self.hidden_dim, 107 | init=self.alpha_init, 108 | scale=self.alpha_scale, 109 | ) 110 | 111 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 112 | residual = x 113 | x = self.mlp(x) 114 | x = residual + self.alpha_scaler(x - residual) 115 | x = l2normalize(x, axis=-1) 116 | 117 | return x 118 | 119 | 120 | class HyperNormalTanhPolicy(nn.Module): 121 | hidden_dim: int 122 | action_dim: int 123 | scaler_init: float 124 | scaler_scale: float 125 | log_std_min: float = -10.0 126 | log_std_max: float = 2.0 127 | 128 | def setup(self): 129 | self.mean_w1 = HyperDense(self.hidden_dim) 130 | self.mean_scaler = Scaler(self.hidden_dim, self.scaler_init, self.scaler_scale) 131 | self.mean_w2 = HyperDense(self.action_dim) 132 | self.mean_bias = self.param( 133 | "mean_bias", nn.initializers.zeros, (self.action_dim,) 134 | ) 135 | 136 | self.std_w1 = HyperDense(self.hidden_dim) 137 | self.std_scaler = Scaler(self.hidden_dim, self.scaler_init, self.scaler_scale) 138 | self.std_w2 = HyperDense(self.action_dim) 139 | self.std_bias = self.param( 140 | "std_bias", nn.initializers.zeros, (self.action_dim,) 141 | ) 142 | 143 | def __call__( 144 | self, 145 | x: jnp.ndarray, 146 | temperature: float = 1.0, 147 | ) -> tfd.Distribution: 148 | mean = self.mean_w1(x) 149 | mean = self.mean_scaler(mean) 150 | mean = self.mean_w2(mean) + self.mean_bias 151 | 152 | log_std = self.std_w1(x) 153 | log_std = self.std_scaler(log_std) 154 | log_std = self.std_w2(log_std) + self.std_bias 155 | 156 | # normalize log-stds for stability 157 | log_std = self.log_std_min + (self.log_std_max - self.log_std_min) * 0.5 * ( 158 | 1 + nn.tanh(log_std) 159 | ) 160 | 161 | # N(mu, exp(log_sigma)) 162 | dist = tfd.MultivariateNormalDiag( 163 | loc=mean, 164 | scale_diag=jnp.exp(log_std) * temperature, 165 | ) 166 | 167 | # tanh(N(mu, sigma)) 168 | dist = tfd.TransformedDistribution(distribution=dist, bijector=tfb.Tanh()) 169 | 170 | info = {} 171 | return dist, info 172 | 173 | 174 | class HyperCategoricalValue(nn.Module): 175 | hidden_dim: int 176 | num_bins: int 177 | min_v: float 178 | max_v: float 179 | scaler_init: float 180 | scaler_scale: float 181 | 182 | def setup(self): 183 | self.w1 = HyperDense(self.hidden_dim) 184 | self.scaler = Scaler(self.hidden_dim, self.scaler_init, self.scaler_scale) 185 | self.w2 = HyperDense(self.num_bins) 186 | self.bias = 
self.param("value_bias", nn.initializers.zeros, (self.num_bins,)) 187 | self.bin_values = jnp.linspace( 188 | start=self.min_v, stop=self.max_v, num=self.num_bins 189 | ).reshape(1, -1) 190 | 191 | def __call__( 192 | self, 193 | x: jnp.ndarray, 194 | ) -> jnp.ndarray: 195 | value = self.w1(x) 196 | value = self.scaler(value) 197 | value = self.w2(value) + self.bias 198 | 199 | # return log probability of bins 200 | log_prob = nn.log_softmax(value, axis=1) 201 | value = jnp.sum(jnp.exp(log_prob) * self.bin_values, axis=1) 202 | 203 | info = {"log_prob": log_prob} 204 | return value, info 205 | -------------------------------------------------------------------------------- /run_online.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import random 5 | import sys 6 | 7 | import hydra 8 | import numpy as np 9 | import omegaconf 10 | import tqdm 11 | from dotmap import DotMap 12 | 13 | from scale_rl.agents import create_agent 14 | from scale_rl.buffers import create_buffer 15 | from scale_rl.common import WandbTrainerLogger 16 | from scale_rl.envs import create_envs 17 | from scale_rl.evaluation import evaluate, record_video 18 | 19 | 20 | def run(args): 21 | ############################### 22 | # configs 23 | ############################### 24 | 25 | args = DotMap(args) 26 | config_path = args.config_path 27 | config_name = args.config_name 28 | overrides = args.overrides 29 | 30 | hydra.initialize(version_base=None, config_path=config_path) 31 | cfg = hydra.compose(config_name=config_name, overrides=overrides) 32 | 33 | def eval_resolver(s: str): 34 | return eval(s) 35 | 36 | omegaconf.OmegaConf.register_new_resolver("eval", eval_resolver) 37 | omegaconf.OmegaConf.resolve(cfg) 38 | 39 | np.random.seed(cfg.seed) 40 | random.seed(cfg.seed) 41 | 42 | ############################# 43 | # envs 44 | ############################# 45 | train_env, eval_env = create_envs(**cfg.env) 46 | 47 | observation_space = train_env.observation_space 48 | action_space = train_env.action_space 49 | 50 | ############################# 51 | # buffer 52 | ############################# 53 | buffer = create_buffer( 54 | observation_space=observation_space, action_space=action_space, **cfg.buffer 55 | ) 56 | buffer.reset() 57 | 58 | ############################# 59 | # agent 60 | ############################# 61 | 62 | # Since the network architecture is typically tied to the learning algorithm, 63 | # we opted not to fully modularize the network for the sake of readability. 64 | # Therefore, for each algorithm, the network is implemented within its respective directory. 
65 | 66 | agent = create_agent( 67 | observation_space=observation_space, 68 | action_space=action_space, 69 | cfg=cfg.agent, 70 | ) 71 | 72 | ############################# 73 | # train 74 | ############################# 75 | 76 | logger = WandbTrainerLogger(cfg) 77 | 78 | # load model if given 79 | script_dir = os.path.dirname(os.path.abspath(sys.argv[0])) 80 | save_path = script_dir + "/" + cfg.save_path 81 | if cfg.load_path: 82 | load_path = script_dir + "/" + cfg.load_path 83 | agent.load(load_path) 84 | 85 | # initial evaluation 86 | eval_info = evaluate(agent, eval_env, cfg.num_eval_episodes) 87 | logger.update_metric(**eval_info) 88 | logger.log_metric(step=0) 89 | logger.reset() 90 | 91 | # start training 92 | update_step = 0 93 | update_counter = 0 94 | observations, env_infos = train_env.reset() 95 | timestep = None 96 | for interaction_step in tqdm.tqdm( 97 | range(1, int(cfg.num_interaction_steps + 1)), smoothing=0.1 98 | ): 99 | # collect data 100 | # While using random actions until buffer.can_sample(), 101 | # we feed data into agent to compute statistics within a wrapper. 102 | if timestep: 103 | actions = agent.sample_actions( 104 | interaction_step, prev_timestep=timestep, training=True 105 | ) 106 | if buffer.can_sample() is False: 107 | actions = train_env.action_space.sample() 108 | 109 | next_observations, rewards, terminateds, truncateds, env_infos = train_env.step( 110 | actions 111 | ) 112 | next_buffer_observations = next_observations.copy() 113 | for env_idx in range(cfg.num_train_envs): 114 | if terminateds[env_idx] or truncateds[env_idx]: 115 | next_buffer_observations[env_idx] = env_infos["final_obs"][env_idx] 116 | 117 | timestep = { 118 | "observation": observations, 119 | "action": actions, 120 | "reward": rewards, 121 | "terminated": terminateds, 122 | "truncated": truncateds, 123 | "next_observation": next_buffer_observations, 124 | } 125 | buffer.add(timestep) 126 | timestep["next_observation"] = next_observations 127 | observations = next_observations 128 | 129 | if buffer.can_sample(): 130 | # update network 131 | # updates_per_interaction_step can be below 1.0 132 | update_counter += cfg.updates_per_interaction_step 133 | while update_counter >= 1: 134 | batch = buffer.sample() 135 | update_info = agent.update(update_step, batch) 136 | logger.update_metric(**update_info) 137 | update_counter -= 1 138 | update_step += 1 139 | 140 | # evaluation 141 | if interaction_step % cfg.evaluation_per_interaction_step == 0: 142 | eval_info = evaluate(agent, eval_env, cfg.num_eval_episodes) 143 | logger.update_metric(**eval_info) 144 | 145 | # metrics 146 | if interaction_step % cfg.metrics_per_interaction_step == 0: 147 | batch = buffer.sample() 148 | metrics_info = agent.get_metrics(batch, update_info) 149 | if metrics_info: 150 | logger.update_metric(**metrics_info) 151 | 152 | # video recording 153 | if interaction_step % cfg.recording_per_interaction_step == 0: 154 | video_info = record_video(agent, eval_env, cfg.num_record_episodes) 155 | logger.update_metric(**video_info) 156 | 157 | # logging 158 | if interaction_step % cfg.logging_per_interaction_step == 0: 159 | # using env steps simplifies the comparison with the performance reported in the paper. 
160 | env_step = interaction_step * cfg.action_repeat * cfg.num_train_envs 161 | logger.log_metric(step=env_step) 162 | logger.reset() 163 | 164 | # checkpointing 165 | if interaction_step % cfg.save_checkpoint_per_interaction_step == 0: 166 | agent.save(save_path) 167 | 168 | # save buffer 169 | if interaction_step % cfg.save_buffer_per_interaction_step == 0: 170 | buffer.save(save_path) 171 | 172 | train_env.close() 173 | eval_env.close() 174 | 175 | 176 | if __name__ == "__main__": 177 | parser = argparse.ArgumentParser(allow_abbrev=False) 178 | parser.add_argument("--config_path", type=str, default="./configs") 179 | parser.add_argument("--config_name", type=str, default="online_rl") 180 | parser.add_argument("--overrides", action="append", default=[]) 181 | args = parser.parse_args() 182 | 183 | run(vars(args)) 184 | -------------------------------------------------------------------------------- /run_parallel.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import multiprocessing as mp 4 | import os 5 | import time 6 | 7 | import numpy as np 8 | 9 | 10 | def run_with_device(server, device_id, config_path, config_name, overrides): 11 | os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) 12 | os.environ["MUJOCO_EGL_DEVICE_ID"] = str(device_id) 13 | os.environ["OMP_NUM_THREADS"] = "2" 14 | 15 | # Now import the main script 16 | if config_name == "online_rl": 17 | from run_online import run 18 | elif config_name == "offline_rl": 19 | from run_offline import run 20 | else: 21 | raise NotImplementedError 22 | 23 | args = { 24 | "config_path": config_path, 25 | "config_name": config_name, 26 | "overrides": overrides, 27 | } 28 | run(args) 29 | 30 | 31 | if __name__ == "__main__": 32 | parser = argparse.ArgumentParser(allow_abbrev=False) 33 | parser.add_argument("--config_path", type=str, default="./configs") 34 | parser.add_argument("--config_name", type=str, default="online_rl") 35 | parser.add_argument("--agent_config", type=str, default="simbaV2") 36 | parser.add_argument("--env_type", type=str, default="dmc_hard") 37 | parser.add_argument("--device_ids", default=[0], nargs="+") 38 | parser.add_argument("--num_seeds", type=int, default=1) 39 | parser.add_argument("--num_exp_per_device", type=int, default=1) 40 | parser.add_argument("--server", type=str, default="local") 41 | parser.add_argument("--group_name", type=str, default="test") 42 | parser.add_argument("--exp_name", type=str, default="test") 43 | parser.add_argument("--overrides", action="append", default=[]) 44 | 45 | args = vars(parser.parse_args()) 46 | seeds = (np.arange(args.pop("num_seeds")) * 1000).tolist() 47 | device_ids = args.pop("device_ids") 48 | os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, device_ids)) 49 | 50 | num_devices = len(device_ids) 51 | num_exp_per_device = args.pop("num_exp_per_device") 52 | pool_size = num_devices * num_exp_per_device 53 | 54 | # create configurations for child run 55 | experiments = [] 56 | config_path = args.pop("config_path") 57 | config_name = args.pop("config_name") 58 | server = args.pop("server") 59 | group_name = args.pop("group_name") 60 | exp_name = args.pop("exp_name") 61 | agent_config = args.pop("agent_config") 62 | 63 | # import library after CUDA_VISIBLE_DEVICES operation 64 | from scale_rl.envs.d4rl import D4RL_MUJOCO 65 | from scale_rl.envs.dmc import DMC_EASY_MEDIUM, DMC_HARD 66 | from scale_rl.envs.humanoid_bench import HB_LOCOMOTION_NOHAND 67 | from scale_rl.envs.mujoco import 
MUJOCO_ALL 68 | from scale_rl.envs.myosuite import MYOSUITE_TASKS 69 | 70 | env_type = args.pop("env_type") 71 | 72 | ################### 73 | # offline 74 | if env_type == "d4rl_mujoco": 75 | envs = D4RL_MUJOCO 76 | env_configs = ["d4rl"] * len(envs) 77 | 78 | ################### 79 | # online 80 | elif env_type == "mujoco": 81 | envs = MUJOCO_ALL 82 | env_configs = [env_type] * len(envs) 83 | 84 | elif env_type == "dmc_em": 85 | envs = DMC_EASY_MEDIUM 86 | env_configs = ["dmc"] * len(envs) 87 | 88 | elif env_type == "dmc_hard": 89 | envs = DMC_HARD 90 | env_configs = ["dmc"] * len(envs) 91 | 92 | elif env_type == "myosuite": 93 | envs = MYOSUITE_TASKS 94 | env_configs = [env_type] * len(envs) 95 | 96 | elif env_type == "hb_locomotion": 97 | envs = HB_LOCOMOTION_NOHAND 98 | env_configs = [env_type] * len(envs) 99 | 100 | elif env_type == "all": 101 | envs = ( 102 | MUJOCO_ALL 103 | + DMC_EASY_MEDIUM 104 | + DMC_HARD 105 | + MYOSUITE_TASKS 106 | + HB_LOCOMOTION_NOHAND 107 | ) 108 | env_configs = ( 109 | ["mujoco"] * len(MUJOCO_ALL) 110 | + ["dmc"] * len(DMC_EASY_MEDIUM) 111 | + ["dmc"] * len(DMC_HARD) 112 | + ["myosuite"] * len(MYOSUITE_TASKS) 113 | + ["hb_locomotion"] * len(HB_LOCOMOTION_NOHAND) 114 | ) 115 | 116 | else: 117 | raise NotImplementedError 118 | 119 | for seed in seeds: 120 | for idx, env_name in enumerate(envs): 121 | exp = copy.deepcopy(args) # copy overriding arguments 122 | exp["config_path"] = config_path 123 | exp["config_name"] = config_name 124 | 125 | exp["overrides"].append("agent=" + agent_config) 126 | exp["overrides"].append("env=" + env_configs[idx]) 127 | exp["overrides"].append("env.env_name=" + env_name) 128 | 129 | exp["overrides"].append("server=" + server) 130 | exp["overrides"].append("group_name=" + group_name) 131 | exp["overrides"].append("exp_name=" + exp_name) 132 | exp["overrides"].append("seed=" + str(seed)) 133 | 134 | experiments.append(exp) 135 | print(exp) 136 | 137 | # run parallel experiments 138 | # https://docs.python.org/3.5/library/multiprocessing.html#contexts-and-start-methods 139 | mp.set_start_method("spawn") 140 | available_gpus = device_ids 141 | process_dict = {gpu_id: [] for gpu_id in device_ids} 142 | 143 | for exp in experiments: 144 | wait = True 145 | # wait until there exists a finished process 146 | while wait: 147 | # Find all finished processes and register available GPU 148 | for gpu_id, processes in process_dict.items(): 149 | for process in processes: 150 | if not process.is_alive(): 151 | print(f"Process {process.pid} on GPU {gpu_id} finished.") 152 | processes.remove(process) 153 | if gpu_id not in available_gpus: 154 | available_gpus.append(gpu_id) 155 | 156 | for gpu_id, processes in process_dict.items(): 157 | if len(processes) < num_exp_per_device: 158 | wait = False 159 | gpu_id, processes = min( 160 | process_dict.items(), key=lambda x: len(x[1]) 161 | ) 162 | break 163 | 164 | time.sleep(10) 165 | 166 | # get running processes in the gpu 167 | processes = process_dict[gpu_id] 168 | exp["device_id"] = str(gpu_id) 169 | process = mp.Process( 170 | target=run_with_device, 171 | args=( 172 | server, 173 | exp["device_id"], 174 | exp["config_path"], 175 | exp["config_name"], 176 | exp["overrides"], 177 | ), 178 | ) 179 | process.start() 180 | processes.append(process) 181 | print(f"Process {process.pid} on GPU {gpu_id} started.") 182 | 183 | # check if the GPU has reached its maximum number of processes 184 | if len(processes) == num_exp_per_device: 185 | available_gpus.remove(gpu_id) 186 | 
-------------------------------------------------------------------------------- /docs/style.css: -------------------------------------------------------------------------------- 1 | body {background-color: #fdfdfd; color: rgb(54, 54, 54); margin: 0; font-size: 1em;} 2 | h1, h2, h3, h4, h5 {text-align: center;} 3 | h1 {margin-bottom: 0.5em; font-size: 2.2em;} 4 | h2 {margin-top: 1em; font-weight: bold !important;} 5 | h1, h2, h3, h4, h5, a, p, span, body {font-weight: normal; font-family: "Google Sans", sans-serif;} 6 | .header {background-color: #fdfdfd; width: 100%; padding-top: 48px; padding-bottom: 8px;} 7 | .header-menu {background-color: #efeff3; width: 100%; padding: 16px 0;} 8 | .header-menu-content {max-width: 960px; margin: auto;} 9 | .header-menu-item {display: inline-block; margin-left: 16px; margin-right: 16px; font-size: 1.2em;} 10 | .links {width: 100%; margin: auto; text-align: center; padding-top: 8px;} 11 | .links a {margin-left: 8px;} 12 | .content {max-width: 960px; margin: auto; margin-top: 48px; margin-bottom: 64px;} 13 | a, h2 {color: rgb(100, 142, 246); text-decoration: none;} 14 | a:hover {color: #fa6d6d;} 15 | .column-left {float: left; width: 50%;} 16 | .column-right {float: right; width: 50%; margin-top: 1em;} 17 | @media (max-width: 920px) { 18 | .column-left { float: none; width: 55%; margin: auto;} 19 | .column-right { float: none; width: 80%; margin: auto; margin-top: -4em;} 20 | } 21 | @media (max-width: 600px) { 22 | .column-left { float: none; width: 75%; margin: auto;} 23 | .column-right { float: none; width: 80%; margin: auto; margin-top: -4em;} 24 | } 25 | 26 | .scaling-column-left {float: left; width: 50%;} 27 | .scaling-column-right {float: right; width: 50%;} 28 | @media (max-width: 920px) { 29 | .scaling-column-left { float: none; width: 75%; margin: auto;} 30 | .scaling-column-right { float: none; width: 75%; margin: auto; margin-top: -4em;} 31 | } 32 | @media (max-width: 600px) { 33 | .scaling-column-left { float: none; width: 75%; margin: auto;} 34 | .scaling-column-right { float: none; width: 75%; margin: auto; margin-top: -4em;} 35 | } 36 | 37 | .nobreak {white-space: nowrap;} 38 | .hr {width: 100%; height: 1px; margin: 48px 0; background-color: #d6dbdf;} 39 | p {line-height: 1.4em; text-align: justify;} 40 | .abstract {max-width: 90%; margin: auto;} 41 | .citation {max-width: 95%; margin: auto; margin-bottom: 1em;} 42 | .math {font-family: "Computer Modern Sans", sans-serif; font-style: italic;} 43 | sub, sup {line-height: 0;} 44 | .figure {width: 100%; min-height: 120px; margin: 2em 0; background-repeat: no-repeat; background-position: center; background-size: contain;} 45 | .figure-caption {margin: auto; max-width: 95%; margin-top: 24px; margin-bottom: 24px;} 46 | .youtube-container {background-color: #000; margin-top: 32px;} 47 | .youtube {display: block; margin: auto; width: 960px; padding-top: 20px; padding-bottom: 20px;} 48 | .content-video {width: 100%; margin: 0; text-align: center;} 49 | .content-video-container {width: 100%; max-width: 860px; margin: auto} 50 | .content-video-frame {display: inline-block; width: 24.5%; text-align: center;} 51 | .content-video-frame.medium {width: 30%;} 52 | .content-video-frame.large {width: 36%;} 53 | .content-video-frame.huge {width: 76%;} 54 | .content-video-frame span {display: inline-block; margin-bottom: 12px; font-weight: bold; font-size: 1.2em;} 55 | .hidden-content {display: none; background-color: #f3f3f6; border: 0; border-radius: 12px; margin-top: -18px; margin-bottom: 32px; padding-top: 
50px; padding-bottom: 32px;} 56 | .legend {display: inline-block; border: 1px solid #966565; padding: 8px; margin-top: 12px; text-align: center;} 57 | .legend-item {display: inline-block; margin-left: 6px; margin-right: 6px; font-size: 12px;} 58 | .legend-symbol {font-weight: bold; margin-right: 6px; font-size: 20px;} 59 | .page {display: inline-block; width: 84px; height: 108px; border: 1px solid #bbb; margin: 2px; background-repeat: no-repeat; background-position: center; background-size: contain;} 60 | table.authors {width: 100%; max-width: 700px; margin: auto; margin-bottom: 16px; text-align: center;} 61 | table.authors a {padding: 6px 0; display: inline-block; font-weight: normal; font-size: 1.3em;} 62 | table.authors .authors-affiliation {display: block; font-size: 1.2em;} 63 | table.models {width: 100%; max-width: 800px; margin: auto; font-size: 1em; border-collapse: collapse;} 64 | tr.models {background-color: #fff; border: 1px solid #eaeaea;} 65 | tr.models td {padding: 6px 0; text-align: center;} 66 | a.btn {display: inline-block; min-width: 70px; font-family: "Google Sans", sans-serif; background-color: rgb(47, 47, 47); color: white; padding: 8px 18px; font-size: 1.1em; font-weight: normal; border-radius: 32px;} 67 | a.btn:hover {background-color: rgb(54, 54, 54);} 68 | a.btn-disabled {background-color: rgb(174, 174, 174) !important; color: rgb(234, 234, 234) !important;} 69 | a.btn-blue {background-color: #648ef6; color: white;} 70 | a.btn-blue:hover {background-color: #fa6d6d;} 71 | .header-menu-item a.disabled {color: rgb(47, 47, 47) !important; pointer-events: none; text-decoration: underline;} 72 | a.link-disabled {color: #ccc; pointer-events: none;} 73 | .noselect {-webkit-touch-callout: none; -webkit-user-select: none; -khtml-user-select: none; -moz-user-select: none; -ms-user-select: none; user-select: none;} 74 | .bold {font-weight: bold;} 75 | .vbold {font-weight: bolder;} 76 | .italic {font-style: italic;} 77 | .underline {text-decoration: underline;} 78 | .red {color: #D62727;} 79 | .simbav2 {color: rgb(100, 142, 246)} 80 | .simba {color: rgb(233, 153, 19)} 81 | .tdmpc2 {color: rgb(146, 189, 110)} 82 | .bro {color: rgb(245, 150, 162)} 83 | .highlight {font-weight: bolder; font-style: italic; color: #D62727;} 84 | .default-color {color: rgb(54, 54, 54) !important;} 85 | .tldr-hr {width: 100%; height: 1px; margin: 36px 0; background-color: #d6dbdf;} 86 | .tldr-container {width: 100%; max-width: 960px; margin-left: auto; margin-right: auto; text-align: center; padding: 5px 5px; box-sizing: border-box;} 87 | .tldr-container h2 {font-size: 2em; margin: 0 0 10px 0;} 88 | .tldr-container .tldr-content {font-size: 1.3em; width: 100%; display: inline-block; text-align: center; margin: 0;} 89 | footer {background-color: #efeff3; width: 100%; margin-top: 32px; padding-top: 16px; padding-bottom: 16px; text-align: center;} 90 | 91 | .wrap-collabsible { 92 | margin-bottom: 1.2rem 0; 93 | } 94 | input[type='checkbox'] { 95 | display: none; 96 | } 97 | .lbl-toggle { 98 | color: rgb(100, 142, 246); 99 | display: block; 100 | /* font-weight: bold; */ 101 | font-size: 1.2rem; 102 | text-align: center; 103 | cursor: pointer; 104 | transition: all 0.35s ease-out; 105 | padding: 1rem; 106 | } 107 | .lbl-toggle:hover { 108 | color: #fa6d6d; 109 | } 110 | .lbl-toggle::before { 111 | content: ' '; 112 | display: inline-block; 113 | border-top: 5px solid transparent; 114 | border-bottom: 5px solid transparent; 115 | border-left: 5px solid currentColor; 116 | vertical-align: middle; 117 | 
margin-right: .7rem; 118 | transform: translateY(-2px); 119 | transition: transform .2s ease-out; 120 | } 121 | .toggle:checked + .lbl-toggle::before { 122 | transform: rotate(90deg) translateX(-3px); 123 | } 124 | .collapsible-online-content { 125 | max-height: 0px; 126 | overflow: hidden; 127 | transition: max-height 0.4s ease-in-out; 128 | } 129 | .collapsible-online-content .content-inner { 130 | margin-top: -4.0rem; 131 | } 132 | .toggle:checked + .lbl-toggle + .collapsible-online-content { 133 | max-height: 100vh; 134 | } 135 | .toggle:checked + .lbl-toggle { 136 | border-bottom-right-radius: 0; 137 | border-bottom-left-radius: 0; 138 | } 139 | -------------------------------------------------------------------------------- /scale_rl/agents/wrappers/normalization.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import numpy as np 4 | from flax.training import checkpoints 5 | 6 | from scale_rl.agents.base_agent import AgentWrapper, BaseAgent 7 | from scale_rl.agents.wrappers.utils import RunningMeanStd 8 | 9 | 10 | class ObservationNormalizer(AgentWrapper): 11 | """ 12 | This wrapper will normalize observations s.t. each coordinate is centered with unit variance. 13 | 14 | Observation statistics is updated only on sample_actions with training==True 15 | """ 16 | 17 | def __init__(self, agent: BaseAgent, load_rms: bool = True, epsilon: float = 1e-8): 18 | AgentWrapper.__init__(self, agent) 19 | 20 | self.obs_rms = RunningMeanStd( 21 | shape=(1,) + self.agent._observation_space.shape[1:], 22 | dtype=np.float32, 23 | ) 24 | self.load_rms = load_rms 25 | self.epsilon = epsilon 26 | 27 | def _normalize(self, observations): 28 | return (observations - self.obs_rms.mean) / np.sqrt( 29 | self.obs_rms.var + self.epsilon 30 | ) 31 | 32 | def sample_actions( 33 | self, 34 | interaction_step: int, 35 | prev_timestep: Dict[str, np.ndarray], 36 | training: bool, 37 | ) -> np.ndarray: 38 | """ 39 | Defines the sample action function with normalized observation. 40 | """ 41 | observations = prev_timestep["next_observation"] 42 | if training: 43 | self.obs_rms.update(observations) 44 | prev_timestep["next_observation"] = self._normalize(observations) 45 | 46 | return self.agent.sample_actions( 47 | interaction_step=interaction_step, 48 | prev_timestep=prev_timestep, 49 | training=training, 50 | ) 51 | 52 | def update(self, update_step: int, batch: Dict[str, np.ndarray]): 53 | batch["observation"] = self._normalize(batch["observation"]) 54 | batch["next_observation"] = self._normalize(batch["next_observation"]) 55 | return self.agent.update( 56 | update_step=update_step, 57 | batch=batch, 58 | ) 59 | 60 | def save(self, path: str) -> None: 61 | """ 62 | Save both the wrapped agent and this wrapper's running statistics. 63 | """ 64 | # 1. Save the underlying agent’s checkpoint 65 | self.agent.save(path) 66 | 67 | # 2. Save the wrapper’s statistics in a separate file 68 | ckpt = { 69 | "obs_rms_mean": self.obs_rms.mean, 70 | "obs_rms_var": self.obs_rms.var, 71 | "obs_rms_count": self.obs_rms.count, 72 | } 73 | checkpoints.save_checkpoint( 74 | ckpt_dir=path + "/obs_norm", 75 | target=ckpt, 76 | step=0, 77 | overwrite=True, 78 | keep=1, 79 | ) 80 | 81 | def load(self, path: str): 82 | """ 83 | Load both the wrapped agent and the wrapper’s running statistics. 84 | """ 85 | # 1. Load the underlying agent 86 | self.agent.load(path) 87 | 88 | # 2. 
Load the wrapper’s statistics 89 | if self.load_rms: 90 | ckpt = checkpoints.restore_checkpoint( 91 | ckpt_dir=path + "/obs_norm", target=None 92 | ) 93 | self.obs_rms.mean = ckpt["obs_rms_mean"] 94 | self.obs_rms.var = ckpt["obs_rms_var"] 95 | self.obs_rms.count = ckpt["obs_rms_count"] 96 | 97 | 98 | class RewardNormalizer(AgentWrapper): 99 | """ 100 | This wrapper will scale rewards using the variance of a running estimate of the discounted returns. 101 | In policy gradient methods, the update rule often involves the term ∇log ⁡π(a|s)⋅G_t, where G_t is the return from time t. 102 | Scaling G_t to have unit variance can be an effective variance reduction technique. 103 | 104 | Return statistics is updated only on sample_actions with training == True 105 | """ 106 | 107 | def __init__( 108 | self, 109 | agent: BaseAgent, 110 | gamma: float, 111 | g_max: float = 10.0, 112 | load_rms: bool = True, 113 | epsilon: float = 1e-8, 114 | ): 115 | AgentWrapper.__init__(self, agent) 116 | self.G = 0.0 # running estimate of the discounted return 117 | self.G_rms = RunningMeanStd( 118 | shape=1, 119 | dtype=np.float32, 120 | ) 121 | self.G_r_max = 0.0 # running-max 122 | self.gamma = gamma 123 | self.g_max = g_max 124 | self.load_rms = load_rms 125 | self.epsilon = epsilon 126 | 127 | def _scale_reward(self, rewards): 128 | var_denominator = np.sqrt(self.G_rms.var + self.epsilon) 129 | min_required_denominator = self.G_r_max / self.g_max 130 | denominator = max(var_denominator, min_required_denominator) 131 | 132 | return rewards / denominator 133 | 134 | def sample_actions( 135 | self, 136 | interaction_step: int, 137 | prev_timestep: Dict[str, np.ndarray], 138 | training: bool, 139 | ) -> np.ndarray: 140 | """ 141 | Defines the sample action function with updating statistics. 142 | """ 143 | if training: 144 | reward = prev_timestep["reward"] 145 | terminated = prev_timestep["terminated"] 146 | truncated = prev_timestep["truncated"] 147 | done = np.logical_or(terminated, truncated) 148 | self.G = self.gamma * (1 - done) * self.G + reward 149 | self.G_rms.update(self.G) 150 | self.G_r_max = max(self.G_r_max, max(abs(self.G))) 151 | 152 | return self.agent.sample_actions( 153 | interaction_step=interaction_step, 154 | prev_timestep=prev_timestep, 155 | training=training, 156 | ) 157 | 158 | def update(self, update_step: int, batch: Dict[str, np.ndarray]): 159 | batch["reward"] = self._scale_reward(batch["reward"]) 160 | return self.agent.update( 161 | update_step=update_step, 162 | batch=batch, 163 | ) 164 | 165 | def save(self, path: str) -> None: 166 | """ 167 | Save both the wrapped agent and this wrapper's running statistics. 168 | """ 169 | # 1. Save the underlying agent’s checkpoint 170 | self.agent.save(path) 171 | 172 | # 2. Save the wrapper’s statistics in a separate file 173 | ckpt = { 174 | "G": self.G, 175 | "G_rms_mean": self.G_rms.mean, 176 | "G_rms_var": self.G_rms.var, 177 | "G_rms_count": self.G_rms.count, 178 | "G_r_max": self.G_r_max, 179 | } 180 | checkpoints.save_checkpoint( 181 | ckpt_dir=path + "/rew_norm", 182 | target=ckpt, 183 | step=0, 184 | overwrite=True, 185 | keep=1, 186 | ) 187 | 188 | def load(self, path: str): 189 | """ 190 | Load both the wrapped agent and the wrapper’s running statistics. 191 | """ 192 | # 1. Load the underlying agent 193 | self.agent.load(path) 194 | 195 | # 2. 
Load the wrapper’s statistics 196 | if self.load_rms: 197 | ckpt = checkpoints.restore_checkpoint( 198 | ckpt_dir=path + "/rew_norm", target=None 199 | ) 200 | self.G = ckpt["G"] 201 | self.G_rms.mean = ckpt["G_rms_mean"] 202 | self.G_rms.var = ckpt["G_rms_var"] 203 | self.G_rms.count = ckpt["G_rms_count"] 204 | self.G_r_max = ckpt["G_r_max"] 205 | -------------------------------------------------------------------------------- /docs/dataset/css/style.css: -------------------------------------------------------------------------------- 1 | body {background-color: #fdfdfd; color: rgb(54, 54, 54); margin: 0; font-size: 1em;} 2 | h2, h3, h4, h5 {text-align: center;} 3 | h2 {margin-top: 48px; font-weight: bold !important;} 4 | h2, h3, h4, h5, a, p, span, body {font-weight: normal; font-family: "Google Sans", sans-serif;} 5 | .header {background-color: #fdfdfd; width: 100%; padding-top: 48px; padding-bottom: 8px;} 6 | .header-menu {background-color: #efeff3; width: 100%; padding: 16px 0;} 7 | .header-menu-content {max-width: 960px; margin: auto;} 8 | .header-menu-item {display: inline-block; margin-left: 16px; margin-right: 16px; font-size: 1.2em;} 9 | 10 | .links {width: 120%; margin: auto; text-align: center; padding-top: 8px; margin-left: -10%;} 11 | .links a {margin-left: 8px;} 12 | .links br {display: none;} 13 | 14 | @media (max-width: 900px) { 15 | .links {width: 100%; margin: auto; text-align: center; padding-top: 8px;} 16 | .links br {display: block;} 17 | .links a {margin-top: 8px; } 18 | } 19 | 20 | 21 | .content {max-width: 960px; margin: auto; margin-top: 48px; margin-bottom: 64px;} 22 | a, h2 {color: rgb(100, 142, 246); text-decoration: none;} 23 | a:hover {color: #fa6d6d;} 24 | .nobreak {white-space: nowrap;} 25 | .hr {width: 100%; height: 1px; margin: 48px 0; background-color: #d6dbdf;} 26 | p {line-height: 1.4em; text-align: justify;} 27 | .math {font-family: "Computer Modern Sans", sans-serif; font-style: italic;} 28 | sub, sup {line-height: 0;} 29 | .figure {width: 100%; min-height: 120px; margin: 36px 0; background-repeat: no-repeat; background-position: center; background-size: contain;} 30 | .figure-caption {margin: auto; max-width: 95%; margin-top: 24px; margin-bottom: 24px;} 31 | .content-video {width: 100%; margin: 0; text-align: center;} 32 | .content-video-container {width: 100%; max-width: 860px; margin: auto} 33 | .content-video-frame {display: inline-block; width: 24.5%; text-align: center;} 34 | .content-video-frame.medium {width: 30%;} 35 | .content-video-frame.large {width: 36%;} 36 | .content-video-frame.huge {width: 76%;} 37 | .content-video-frame span {display: inline-block; margin-bottom: 12px; font-weight: bold; font-size: 1.2em;} 38 | .hidden-content {display: none; background-color: #f3f3f6; border: 0; border-radius: 12px; margin-top: -18px; margin-bottom: 32px; padding-top: 50px; padding-bottom: 32px;} 39 | .legend {display: inline-block; border: 1px solid #966565; padding: 8px; margin-top: 12px; text-align: center;} 40 | .legend-item {display: inline-block; margin-left: 6px; margin-right: 6px; font-size: 12px;} 41 | .legend-symbol {font-weight: bold; margin-right: 6px; font-size: 20px;} 42 | .page {display: inline-block; width: 84px; height: 108px; border: 1px solid #bbb; margin: 2px; background-repeat: no-repeat; background-position: center; background-size: contain;} 43 | table.authors {width: 100%; max-width: 700px; margin: auto; margin-bottom: 16px; text-align: center;} 44 | table.authors a {padding: 6px 0; display: inline-block; font-weight: 
normal; font-size: 1.3em;} 45 | table.authors .authors-affiliation {display: block; font-size: 1.2em;} 46 | table.models {width: 100%; max-width: 800px; margin: auto; font-size: 1em; border-collapse: collapse;} 47 | tr.models {background-color: #fff; border: 1px solid #eaeaea;} 48 | tr.models td {padding: 6px 0; text-align: center;} 49 | a.btn {display: inline-block; min-width: 1em; font-family: "Google Sans", sans-serif; background-color: rgb(47, 47, 47); color: white; padding: 8px 18px; font-size: 1em; font-weight: normal; border-radius: 1.2em;} 50 | a.btn:hover {background-color: rgb(54, 54, 54);} 51 | a.btn-disabled {background-color: rgb(174, 174, 174) !important; color: rgb(234, 234, 234) !important;} 52 | a.btn-blue {background-color: #648ef6; color: white;} 53 | a.btn-blue:hover {background-color: #fa6d6d;} 54 | .header-menu-item a.disabled {color: rgb(47, 47, 47) !important; pointer-events: none; text-decoration: underline;} 55 | a.link-disabled {color: #ccc; pointer-events: none;} 56 | .bibtexsection {padding: 4px 16px; font-family: "Courier", monospace; font-size: 15px; white-space: pre; background-color: #f4f4f4; text-align: left;} 57 | .noselect {-webkit-touch-callout: none; -webkit-user-select: none; -khtml-user-select: none; -moz-user-select: none; -ms-user-select: none; user-select: none;} 58 | .bold {font-weight: bold;} 59 | .vbold {font-weight: bolder;} 60 | .italic {font-style: italic;} 61 | .underline {text-decoration: underline;} 62 | .red {color: #D62727;} 63 | .simbav2 {color: rgb(100, 142, 246)} 64 | .simba {color: rgb(233, 153, 19)} 65 | .tdmpc2 {color: rgb(146, 189, 110)} 66 | .bro {color: rgb(245, 150, 162)} 67 | .highlight {font-weight: bolder; font-style: italic; color: #D62727;} 68 | .default-color {color: rgb(54, 54, 54) !important;} 69 | .tldr-hr {width: 100%; height: 1px; margin: 36px 0; background-color: #d6dbdf;} 70 | .tldr-container {width: 100%; max-width: 960px; margin-left: auto; margin-right: auto; text-align: center; padding: 5px 5px; box-sizing: border-box;} 71 | .tldr-container h2 {font-size: 2em; margin: 0 0 10px 0;} 72 | .tldr-container .tldr-content {font-size: 1.3em; width: 100%; display: inline-block; text-align: center; margin: 0;} 73 | footer {background-color: #efeff3; width: 100%; margin-top: 32px; padding-top: 16px; padding-bottom: 16px; text-align: center;} 74 | 75 | 76 | 77 | .wrap-collabsible { 78 | margin-bottom: 1.2rem 0; 79 | } 80 | input[type='checkbox'] { 81 | display: none; 82 | } 83 | .lbl-toggle { 84 | color: rgb(100, 142, 246); 85 | display: block; 86 | /* font-weight: bold; */ 87 | font-size: 1.2rem; 88 | text-align: center; 89 | cursor: pointer; 90 | transition: all 0.35s ease-out; 91 | padding: 1rem; 92 | } 93 | .lbl-toggle:hover { 94 | color: #fa6d6d; 95 | } 96 | .lbl-toggle::before { 97 | content: ' '; 98 | display: inline-block; 99 | border-top: 5px solid transparent; 100 | border-bottom: 5px solid transparent; 101 | border-left: 5px solid currentColor; 102 | vertical-align: middle; 103 | margin-right: .7rem; 104 | transform: translateY(-2px); 105 | transition: transform .2s ease-out; 106 | } 107 | .toggle:checked + .lbl-toggle::before { 108 | transform: rotate(90deg) translateX(-3px); 109 | } 110 | /* MuJoCo Table */ 111 | .collapsible-mujoco-content { 112 | max-height: 0px; 113 | overflow: hidden; 114 | transition: max-height 0.4s ease-in-out; 115 | } 116 | .collapsible-mujoco-content .content-inner { 117 | margin-top: 0.0rem; 118 | } 119 | .toggle:checked + .lbl-toggle + .collapsible-mujoco-content { 120 | 
max-height: 1000vh; 121 | } 122 | /* DMC Table */ 123 | .collapsible-dmc-content { 124 | max-height: 0px; 125 | overflow: hidden; 126 | transition: max-height 0.4s ease-in-out; 127 | } 128 | .collapsible-dmc-content .content-inner { 129 | margin-top: 0.0rem; 130 | } 131 | .toggle:checked + .lbl-toggle + .collapsible-dmc-content { 132 | max-height: 1000vh; 133 | } 134 | 135 | /* MyoSuite Table */ 136 | .collapsible-myosuite-content { 137 | max-height: 0px; 138 | overflow: hidden; 139 | transition: max-height 0.4s ease-in-out; 140 | } 141 | .collapsible-myosuite-content .content-inner { 142 | margin-top: 0.0rem; 143 | } 144 | .toggle:checked + .lbl-toggle + .collapsible-myosuite-content { 145 | max-height: 1000vh; 146 | } 147 | /* HBench Table */ 148 | .collapsible-hbench-content { 149 | max-height: 0px; 150 | overflow: hidden; 151 | transition: max-height 0.4s ease-in-out; 152 | } 153 | .collapsible-hbench-content .content-inner { 154 | margin-top: 0.0rem; 155 | } 156 | .toggle:checked + .lbl-toggle + .collapsible-hbench-content { 157 | max-height: 1000vh; 158 | } 159 | .toggle:checked + .lbl-toggle { 160 | border-bottom-right-radius: 0; 161 | border-bottom-left-radius: 0; 162 | } 163 | 164 | 165 | .results-carousel { 166 | overflow: hidden; 167 | width: 100%; 168 | } 169 | 170 | .results-carousel .item { 171 | margin: 5px; 172 | overflow: hidden; 173 | border: 0px solid #bbb; 174 | border-radius: 10px; 175 | padding: 0; 176 | font-size: 1em; 177 | } 178 | .results-carousel video { 179 | margin: 0; 180 | border-radius: 10px; 181 | } 182 | .results-carousel p { 183 | margin-top: 0.8em; 184 | text-align: center; 185 | } 186 | 187 | -------------------------------------------------------------------------------- /scale_rl/agents/simbaV2/simbaV2_update.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Tuple 2 | 3 | import flax 4 | import jax 5 | import jax.numpy as jnp 6 | 7 | from scale_rl.agents.jax_utils.network import Network, PRNGKey 8 | from scale_rl.agents.jax_utils.tree_utils import tree_map_until_match 9 | from scale_rl.buffers import Batch 10 | 11 | EPS = 1e-8 12 | 13 | 14 | def l2normalize( 15 | x: jnp.ndarray, 16 | axis: int, 17 | ) -> jnp.ndarray: 18 | l2norm = jnp.linalg.norm(x, ord=2, axis=axis, keepdims=True) 19 | x = x / jnp.maximum(l2norm, EPS) 20 | 21 | return x 22 | 23 | 24 | def l2normalize_layer(tree): 25 | """ 26 | apply l2-normalization to the all leaf nodes 27 | """ 28 | if len(tree["kernel"].shape) == 2: 29 | axis = 0 30 | elif len(tree["kernel"].shape) == 3: 31 | axis = 1 32 | else: 33 | raise ValueError 34 | return jax.tree.map(f=lambda x: l2normalize(x, axis=axis), tree=tree) 35 | 36 | 37 | def l2normalize_network( 38 | network: Network, 39 | regex: str = "hyper_dense", 40 | ) -> Network: 41 | params = network.params 42 | new_params = tree_map_until_match( 43 | f=lambda x: l2normalize_layer(x), tree=params, target_re=regex, keep_values=True 44 | ) 45 | network = network.replace(params=new_params) 46 | return network 47 | 48 | 49 | def update_actor( 50 | key: PRNGKey, 51 | actor: Network, 52 | critic: Network, 53 | temperature: Network, 54 | batch: Batch, 55 | use_cdq: bool, 56 | bc_alpha: float, 57 | ) -> Tuple[Network, Dict[str, float]]: 58 | def actor_loss_fn( 59 | actor_params: flax.core.FrozenDict[str, Any], 60 | ) -> Tuple[jnp.ndarray, Dict[str, float]]: 61 | dist, _ = actor.apply( 62 | variables={"params": actor_params}, 63 | observations=batch["observation"], 64 | ) 65 | actions = 
dist.sample(seed=key) 66 | log_probs = dist.log_prob(actions) 67 | 68 | if use_cdq: 69 | # qs: (2, n) 70 | qs, q_infos = critic(observations=batch["observation"], actions=actions) 71 | q = jnp.minimum(qs[0], qs[1]) 72 | else: 73 | q, _ = critic(observations=batch["observation"], actions=actions) 74 | 75 | actor_loss = (log_probs * temperature() - q).mean() 76 | 77 | if bc_alpha > 0: 78 | # https://arxiv.org/abs/2306.02451 79 | q_abs = jax.lax.stop_gradient(jnp.abs(q).mean()) 80 | bc_loss = ((actions - batch["action"]) ** 2).mean() 81 | actor_loss = actor_loss + bc_alpha * q_abs * bc_loss 82 | 83 | actor_info = { 84 | "actor/loss": actor_loss, 85 | "actor/entropy": -log_probs.mean(), 86 | "actor/mean_action": jnp.mean(actions), 87 | } 88 | return actor_loss, actor_info 89 | 90 | actor, info = actor.apply_gradient(actor_loss_fn) 91 | actor = l2normalize_network(actor) 92 | 93 | return actor, info 94 | 95 | 96 | def categorical_td_loss( 97 | pred_log_probs: jnp.ndarray, # (n, num_bins) 98 | target_log_probs: jnp.ndarray, # (n, num_bins) 99 | reward: jnp.ndarray, # (n,) 100 | done: jnp.ndarray, # (n,) 101 | actor_entropy: jnp.ndarray, # (n,) 102 | gamma: float, 103 | num_bins: int, 104 | min_v: float, 105 | max_v: float, 106 | ) -> jnp.ndarray: 107 | reward = reward.reshape(-1, 1) 108 | done = done.reshape(-1, 1) 109 | actor_entropy = actor_entropy.reshape(-1, 1) 110 | 111 | # compute target value buckets 112 | # target_bin_values: (n, num_bins) 113 | bin_values = jnp.linspace(start=min_v, stop=max_v, num=num_bins).reshape(1, -1) 114 | target_bin_values = reward + gamma * (bin_values - actor_entropy) * (1.0 - done) 115 | target_bin_values = jnp.clip(target_bin_values, min_v, max_v) # (B, num_bins) 116 | 117 | # update indices 118 | b = (target_bin_values - min_v) / ((max_v - min_v) / (num_bins - 1)) 119 | l = jnp.floor(b) 120 | l_mask = jax.nn.one_hot(l.reshape(-1), num_bins).reshape((-1, num_bins, num_bins)) 121 | u = jnp.ceil(b) 122 | u_mask = jax.nn.one_hot(u.reshape(-1), num_bins).reshape((-1, num_bins, num_bins)) 123 | 124 | # target label 125 | _target_probs = jnp.exp(target_log_probs) 126 | m_l = (_target_probs * (u + (l == u).astype(jnp.float32) - b)).reshape( 127 | -1, num_bins, 1 128 | ) 129 | m_u = (_target_probs * (b - l)).reshape((-1, num_bins, 1)) 130 | target_probs = jax.lax.stop_gradient(jnp.sum(m_l * l_mask + m_u * u_mask, axis=1)) 131 | 132 | # cross entropy loss 133 | loss = -jnp.mean(jnp.sum(target_probs * pred_log_probs, axis=1)) 134 | 135 | return loss 136 | 137 | 138 | def update_critic( 139 | key: PRNGKey, 140 | actor: Network, 141 | critic: Network, 142 | target_critic: Network, 143 | temperature: Network, 144 | batch: Batch, 145 | use_cdq: bool, 146 | min_v: float, 147 | max_v: float, 148 | num_bins: int, 149 | gamma: float, 150 | n_step: int, 151 | ) -> Tuple[Network, Dict[str, float]]: 152 | # compute the target q-value 153 | next_dist, _ = actor(observations=batch["next_observation"]) 154 | next_actions = next_dist.sample(seed=key) 155 | next_actor_log_probs = next_dist.log_prob(next_actions) 156 | next_actor_entropy = temperature() * next_actor_log_probs 157 | 158 | if use_cdq: 159 | # next_qs: (2, n) 160 | # next_q_infos['log_prob]: (2, n, num_bins) 161 | # next_q_log_probs: (n, num_bins) 162 | next_qs, next_q_infos = target_critic( 163 | observations=batch["next_observation"], actions=next_actions 164 | ) 165 | min_indices = next_qs.argmin(axis=0) 166 | next_q_log_probs = jax.vmap( 167 | lambda log_prob, idx: log_prob[idx], in_axes=(1, 0) 168 | 
)(next_q_infos["log_prob"], min_indices) 169 | else: 170 | next_q, next_q_info = target_critic( 171 | observations=batch["next_observation"], 172 | actions=next_actions, 173 | ) 174 | next_q_log_probs = next_q_info["log_prob"] 175 | 176 | def critic_loss_fn( 177 | critic_params: flax.core.FrozenDict[str, Any], 178 | ) -> Tuple[jnp.ndarray, Dict[str, float]]: 179 | if use_cdq: 180 | # compute predicted q-value 181 | pred_qs, pred_q_infos = critic.apply( 182 | variables={"params": critic_params}, 183 | observations=batch["observation"], 184 | actions=batch["action"], 185 | ) 186 | loss_1 = categorical_td_loss( 187 | pred_log_probs=pred_q_infos["log_prob"][0], 188 | target_log_probs=next_q_log_probs, 189 | reward=batch["reward"], 190 | done=batch["terminated"], 191 | actor_entropy=next_actor_entropy, 192 | gamma=gamma**n_step, 193 | num_bins=num_bins, 194 | min_v=min_v, 195 | max_v=max_v, 196 | ) 197 | loss_2 = categorical_td_loss( 198 | pred_log_probs=pred_q_infos["log_prob"][1], 199 | target_log_probs=next_q_log_probs, 200 | reward=batch["reward"], 201 | done=batch["terminated"], 202 | actor_entropy=next_actor_entropy, 203 | gamma=gamma**n_step, 204 | num_bins=num_bins, 205 | min_v=min_v, 206 | max_v=max_v, 207 | ) 208 | critic_loss = (loss_1 + loss_2).mean() 209 | 210 | else: 211 | pred_q, pred_q_info = critic.apply( 212 | variables={"params": critic_params}, 213 | observations=batch["observation"], 214 | actions=batch["action"], 215 | ) 216 | loss = categorical_td_loss( 217 | pred_log_probs=pred_q_info["log_prob"], 218 | target_log_probs=next_q_log_probs, 219 | reward=batch["reward"], 220 | done=batch["terminated"], 221 | actor_entropy=next_actor_entropy, 222 | gamma=gamma**n_step, 223 | num_bins=num_bins, 224 | min_v=min_v, 225 | max_v=max_v, 226 | ) 227 | critic_loss = loss.mean() 228 | 229 | critic_info = { 230 | "critic/loss": critic_loss, 231 | "critic/batch_rew_min": batch["reward"].min(), 232 | "critic/batch_rew_mean": batch["reward"].mean(), 233 | "critic/batch_rew_max": batch["reward"].max(), 234 | } 235 | 236 | return critic_loss, critic_info 237 | 238 | critic, info = critic.apply_gradient(critic_loss_fn) 239 | critic = l2normalize_network(critic) 240 | 241 | return critic, info 242 | 243 | 244 | def update_target_network( 245 | network: Network, 246 | target_network: Network, 247 | target_tau: bool, 248 | ) -> Tuple[Network, Dict[str, float]]: 249 | new_target_params = jax.tree_map( 250 | lambda p, tp: p * target_tau + tp * (1 - target_tau), 251 | network.params, 252 | target_network.params, 253 | ) 254 | target_network = target_network.replace(params=new_target_params) 255 | 256 | info = {} 257 | 258 | return target_network, info 259 | 260 | 261 | def update_temperature( 262 | temperature: Network, entropy: float, target_entropy: float 263 | ) -> Tuple[Network, Dict[str, float]]: 264 | def temperature_loss_fn( 265 | temperature_params: flax.core.FrozenDict[str, Any], 266 | ) -> Tuple[jnp.ndarray, Dict[str, float]]: 267 | temperature_value = temperature.apply({"params": temperature_params}) 268 | temperature_loss = temperature_value * (entropy - target_entropy).mean() 269 | temperature_info = { 270 | "temperature/value": temperature_value, 271 | "temperature/loss": temperature_loss, 272 | } 273 | 274 | return temperature_loss, temperature_info 275 | 276 | temperature, info = temperature.apply_gradient(temperature_loss_fn) 277 | 278 | return temperature, info 279 | -------------------------------------------------------------------------------- 
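A minimal standalone sketch, not a repository file: categorical_td_loss above applies a two-hot projection to every atom of the discounted target distribution at once, and the function below shows the same projection for a single scalar value so the indexing is easier to follow. The names here (two_hot, target_value) are illustrative and do not exist in scale_rl.

import numpy as np

def two_hot(target_value: float, min_v: float, max_v: float, num_bins: int) -> np.ndarray:
    # fractional bin index of the clipped target value
    bin_width = (max_v - min_v) / (num_bins - 1)
    b = (np.clip(target_value, min_v, max_v) - min_v) / bin_width
    l, u = int(np.floor(b)), int(np.ceil(b))
    probs = np.zeros(num_bins)
    if l == u:
        probs[l] = 1.0        # the target sits exactly on a bin
    else:
        probs[l] = u - b      # mass assigned to the lower neighbour
        probs[u] = b - l      # mass assigned to the upper neighbour
    return probs

# Example: with bins at [-1, 0, 1], a target of 0.25 projects to [0.0, 0.75, 0.25],
# whose expectation over the bin values is again 0.25.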
/scale_rl/buffers/numpy_buffer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import deque 4 | 5 | import gymnasium as gym 6 | import numpy as np 7 | 8 | from scale_rl.buffers.base_buffer import BaseBuffer, Batch 9 | from scale_rl.buffers.utils import SegmentTree 10 | 11 | 12 | class NpyUniformBuffer(BaseBuffer): 13 | def __init__( 14 | self, 15 | observation_space: gym.spaces.Space, 16 | action_space: gym.spaces.Space, 17 | n_step: int, 18 | gamma: float, 19 | max_length: int, 20 | min_length: int, 21 | sample_batch_size: int, 22 | ): 23 | super(NpyUniformBuffer, self).__init__( 24 | observation_space, 25 | action_space, 26 | n_step, 27 | gamma, 28 | max_length, 29 | min_length, 30 | sample_batch_size, 31 | ) 32 | 33 | self._current_idx = 0 34 | 35 | def __len__(self): 36 | return self._num_in_buffer 37 | 38 | def reset(self) -> None: 39 | m = self._max_length 40 | 41 | # for pixel-based environments, we would prefer uint8 dtype. 42 | observation_shape = (self._observation_space.shape[-1],) 43 | observation_dtype = self._observation_space.dtype 44 | 45 | action_shape = (self._action_space.shape[-1],) 46 | action_dtype = self._action_space.dtype 47 | 48 | # for float64, we enforce it to be float32 49 | if observation_dtype == "float64": 50 | observation_dtype = np.float32 51 | 52 | if action_dtype == "float64": 53 | action_dtype = np.float32 54 | 55 | self._observations = np.empty((m,) + observation_shape, dtype=observation_dtype) 56 | self._actions = np.empty((m,) + action_shape, dtype=action_dtype) 57 | self._rewards = np.empty((m,), dtype=np.float32) 58 | self._terminateds = np.empty((m,), dtype=np.float32) 59 | self._truncateds = np.empty((m,), dtype=np.float32) 60 | self._next_observations = np.empty( 61 | (m,) + observation_shape, dtype=observation_dtype 62 | ) 63 | 64 | self._n_step_transitions = deque(maxlen=self._n_step) 65 | self._num_in_buffer = 0 66 | 67 | def _get_n_step_prev_timestep(self) -> Batch: 68 | """ 69 | This method processes a n_step_transitions to compute and update the 70 | n-step return, the done status, and the next observation. 71 | """ 72 | # pop n-step previous timestep 73 | n_step_prev_timestep = self._n_step_transitions[0] 74 | cur_timestep = self._n_step_transitions[-1] 75 | 76 | # copy (np.array(,) generates copy version of array) last timestep. 
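        # These copies seed the backward recursion below, which folds the buffered
        # transitions into an n-step target via G <- r_t + gamma * G * (1 - done_t),
        # and resets the bootstrap targets (terminated / truncated / next_observation)
        # wherever an episode ended inside the window.
        # Illustrative example (not taken from the repository): with n_step = 3,
        # rewards [1, 2, 3], and no episode end, the stored n-step reward is
        # 1 + gamma * 2 + gamma**2 * 3.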
77 | n_step_reward = np.array(cur_timestep["reward"]) 78 | n_step_terminated = np.array(cur_timestep["terminated"]) 79 | n_step_truncated = np.array(cur_timestep["truncated"]) 80 | n_step_next_observation = np.array(cur_timestep["next_observation"]) 81 | 82 | for n_step_idx in reversed(range(self._n_step - 1)): 83 | transition = self._n_step_transitions[n_step_idx] 84 | reward = transition["reward"] # (n, ) 85 | terminated = transition["terminated"] # (n, ) 86 | truncated = transition["truncated"] # (n, ) 87 | next_observation = transition["next_observation"] # (n, *obs_shape) 88 | 89 | # compute n-step return 90 | done = (terminated.astype(bool) | truncated.astype(bool)).astype(np.float32) 91 | n_step_reward = reward + self._gamma * n_step_reward * (1 - done) 92 | 93 | # assign next observation starting from done 94 | done_mask = done.astype(bool) 95 | n_step_terminated[done_mask] = terminated[done_mask] 96 | n_step_truncated[done_mask] = truncated[done_mask] 97 | n_step_next_observation[done_mask] = next_observation[done_mask] 98 | 99 | n_step_prev_timestep["reward"] = n_step_reward 100 | n_step_prev_timestep["terminated"] = n_step_terminated 101 | n_step_prev_timestep["truncated"] = n_step_truncated 102 | n_step_prev_timestep["next_observation"] = n_step_next_observation 103 | 104 | return n_step_prev_timestep 105 | 106 | def add(self, timestep: Batch) -> None: 107 | # temporarily hold current timestep to the buffer 108 | self._n_step_transitions.append( 109 | {key: np.array(value) for key, value in timestep.items()} 110 | ) 111 | 112 | if len(self._n_step_transitions) >= self._n_step: 113 | n_step_prev_timestep = self._get_n_step_prev_timestep() 114 | 115 | add_batch_size = len(n_step_prev_timestep["observation"]) 116 | 117 | # add samples to the buffer 118 | add_idxs = np.arange(add_batch_size) + self._current_idx 119 | add_idxs = add_idxs % self._max_length 120 | 121 | self._observations[add_idxs] = n_step_prev_timestep["observation"] 122 | self._actions[add_idxs] = n_step_prev_timestep["action"] 123 | self._rewards[add_idxs] = n_step_prev_timestep["reward"] 124 | self._terminateds[add_idxs] = n_step_prev_timestep["terminated"] 125 | self._truncateds[add_idxs] = n_step_prev_timestep["truncated"] 126 | self._next_observations[add_idxs] = n_step_prev_timestep["next_observation"] 127 | 128 | self._num_in_buffer = min( 129 | self._num_in_buffer + add_batch_size, self._max_length 130 | ) 131 | self._current_idx = (self._current_idx + add_batch_size) % self._max_length 132 | 133 | def can_sample(self) -> bool: 134 | if self._num_in_buffer < self._min_length: 135 | return False 136 | else: 137 | return True 138 | 139 | def sample(self, sample_idxs=None) -> Batch: 140 | if sample_idxs is None: 141 | sample_idxs = np.random.randint( 142 | 0, self._num_in_buffer, size=self._sample_batch_size 143 | ) 144 | 145 | # copy the data for safeness 146 | batch = {} 147 | batch["observation"] = np.array(self._observations[sample_idxs]) 148 | batch["action"] = np.array(self._actions[sample_idxs]) 149 | batch["reward"] = np.array(self._rewards[sample_idxs]) 150 | batch["terminated"] = np.array(self._terminateds[sample_idxs]) 151 | batch["truncated"] = np.array(self._truncateds[sample_idxs]) 152 | batch["next_observation"] = np.array(self._next_observations[sample_idxs]) 153 | 154 | return batch 155 | 156 | def save(self, path: str) -> None: 157 | dataset = {} 158 | dataset["observation"] = self._observations[: self._num_in_buffer] 159 | dataset["action"] = self._actions[: self._num_in_buffer] 160 | 
dataset["reward"] = self._rewards[: self._num_in_buffer] 161 | dataset["terminated"] = self._terminateds[: self._num_in_buffer] 162 | dataset["truncated"] = self._truncateds[: self._num_in_buffer] 163 | dataset["next_observation"] = self._next_observations[: self._num_in_buffer] 164 | with open(os.path.join(path, "dataset.pickle"), "wb") as f: 165 | pickle.dump(dataset, f) 166 | 167 | def get_observations(self) -> np.ndarray: 168 | return self._observations[: self._num_in_buffer] 169 | 170 | 171 | class NpyPrioritizedBuffer(NpyUniformBuffer): 172 | def __init__( 173 | self, 174 | observation_space: gym.spaces.Space, 175 | action_space: gym.spaces.Space, 176 | n_step: int, 177 | gamma: float, 178 | max_length: int, 179 | min_length: int, 180 | add_batch_size: int, 181 | sample_batch_size: int, 182 | ): 183 | super(NpyPrioritizedBuffer, self).__init__( 184 | observation_space, 185 | action_space, 186 | n_step, 187 | gamma, 188 | max_length, 189 | min_length, 190 | add_batch_size, 191 | sample_batch_size, 192 | ) 193 | 194 | def reset(self) -> None: 195 | super().reset() 196 | self._priority_tree = SegmentTree(self._max_length) 197 | 198 | def add(self, timestep: Batch) -> None: 199 | super().add(timestep) 200 | 201 | # add samples to the priority tree 202 | # SegmentTree class is not vectorized so just added instance one-by-one. 203 | if len(self._n_step_transitions) == self._n_step: 204 | for _ in range(self._add_batch_size): 205 | self._priority_tree.add(value=self._priority_tree.max) 206 | 207 | def _sample_idx_from_priority_tree(self): 208 | p_total = self._priority_tree.total # sum of the priorities 209 | segment_length = p_total / self._sample_batch_size 210 | segment_starts = np.arange(self._sample_batch_size) * segment_length 211 | valid = False 212 | 213 | while not valid: 214 | # Uniformly sample from within all segments 215 | samples = ( 216 | np.random.uniform(0.0, segment_length, [self._sample_batch_size]) 217 | + segment_starts 218 | ) 219 | # Retrieve samples from tree with un-normalised probability 220 | buffer_idxs, tree_idxs, sample_probs = self._priority_tree.find(samples) 221 | if np.all(sample_probs != 0): 222 | valid = True # Note that conditions are valid but extra conservative around buffer index 0 223 | 224 | return buffer_idxs, tree_idxs, sample_probs 225 | 226 | def sample(self) -> Batch: 227 | sample_idxs, tree_idxs, sample_probs = self._sample_idx_from_priority_tree() 228 | 229 | batch = {} 230 | batch["observation"] = self._observations[sample_idxs] 231 | batch["action"] = self._actions[sample_idxs] 232 | batch["reward"] = self._rewards[sample_idxs] 233 | batch["terminated"] = self._terminateds[sample_idxs] 234 | batch["truncated"] = self._truncateds[sample_idxs] 235 | batch["next_observation"] = self._next_observations[sample_idxs] 236 | 237 | batch["tree_idxs"] = tree_idxs 238 | batch["sample_probs"] = sample_probs 239 | 240 | return batch 241 | 242 | def save(self, path: str) -> None: 243 | pass 244 | -------------------------------------------------------------------------------- /scale_rl/agents/simba/simba_agent.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from dataclasses import dataclass 3 | from typing import Dict, Tuple 4 | 5 | import gymnasium as gym 6 | import jax 7 | import jax.numpy as jnp 8 | import numpy as np 9 | import optax 10 | 11 | from scale_rl.agents.base_agent import BaseAgent 12 | from scale_rl.agents.jax_utils.network import Network, PRNGKey 13 | from 
scale_rl.agents.simba.simba_network import ( 14 | SimbaActor, 15 | SimbaCritic, 16 | SimbaDoubleCritic, 17 | SimbaTemperature, 18 | ) 19 | from scale_rl.agents.simba.simba_update import ( 20 | update_actor, 21 | update_critic, 22 | update_target_network, 23 | update_temperature, 24 | ) 25 | from scale_rl.buffers.base_buffer import Batch 26 | 27 | """ 28 | The @dataclass decorator must have `frozen=True` to ensure the instance is immutable, 29 | allowing it to be treated as a static variable in JAX. 30 | """ 31 | 32 | 33 | @dataclass(frozen=True) 34 | class SimbaConfig: 35 | seed: int 36 | normalize_observation: bool 37 | normalize_reward: bool 38 | normalized_g_max: float 39 | 40 | load_only_param: bool 41 | load_param_key: bool 42 | load_observation_normalizer: bool 43 | load_reward_normalizer: bool 44 | 45 | learning_rate_init: float 46 | learning_rate_end: float 47 | learning_rate_decay_rate: float 48 | learning_rate_decay_step: int 49 | weight_decay: float 50 | 51 | actor_num_blocks: int 52 | actor_hidden_dim: int 53 | actor_bc_alpha: float 54 | 55 | critic_num_blocks: int 56 | critic_hidden_dim: int 57 | critic_use_cdq: bool 58 | 59 | target_tau: float 60 | 61 | temp_initial_value: float 62 | temp_target_entropy: float 63 | temp_target_entropy_coef: float 64 | 65 | gamma: float 66 | n_step: int 67 | 68 | 69 | @functools.partial( 70 | jax.jit, 71 | static_argnames=( 72 | "observation_dim", 73 | "action_dim", 74 | "cfg", 75 | ), 76 | ) 77 | def _init_simba_networks( 78 | observation_dim: int, 79 | action_dim: int, 80 | cfg: SimbaConfig, 81 | ) -> Tuple[PRNGKey, Network, Network, Network, Network]: 82 | fake_observations = jnp.zeros((1, observation_dim)) 83 | fake_actions = jnp.zeros((1, action_dim)) 84 | 85 | rng = jax.random.PRNGKey(cfg.seed) 86 | rng, actor_key, critic_key, temp_key = jax.random.split(rng, 4) 87 | 88 | # When initializing the network in the flax.nn.Module class, rng_key should be passed as rngs. 89 | actor = Network.create( 90 | network_def=SimbaActor( 91 | num_blocks=cfg.actor_num_blocks, 92 | hidden_dim=cfg.actor_hidden_dim, 93 | action_dim=action_dim, 94 | ), 95 | network_inputs={"rngs": actor_key, "observations": fake_observations}, 96 | tx=optax.adamw( 97 | learning_rate=optax.linear_schedule( 98 | init_value=cfg.learning_rate_init, 99 | end_value=cfg.learning_rate_end, 100 | transition_steps=cfg.learning_rate_decay_step, 101 | ), 102 | weight_decay=cfg.weight_decay, 103 | ), 104 | ) 105 | 106 | if cfg.critic_use_cdq: 107 | critic_network_def = SimbaDoubleCritic( 108 | num_blocks=cfg.critic_num_blocks, 109 | hidden_dim=cfg.critic_hidden_dim, 110 | ) 111 | else: 112 | critic_network_def = SimbaCritic( 113 | num_blocks=cfg.critic_num_blocks, 114 | hidden_dim=cfg.critic_hidden_dim, 115 | ) 116 | 117 | critic = Network.create( 118 | network_def=critic_network_def, 119 | network_inputs={ 120 | "rngs": critic_key, 121 | "observations": fake_observations, 122 | "actions": fake_actions, 123 | }, 124 | tx=optax.adamw( 125 | learning_rate=optax.linear_schedule( 126 | init_value=cfg.learning_rate_init, 127 | end_value=cfg.learning_rate_end, 128 | transition_steps=cfg.learning_rate_decay_step, 129 | ), 130 | weight_decay=cfg.weight_decay, 131 | ), 132 | ) 133 | 134 | # we set target critic's parameters identical to critic by using same rng. 
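    # Reusing critic_key below makes the target critic start as an exact copy of the
    # online critic; afterwards it only tracks the critic through the Polyak average
    # new_target = tau * params + (1 - tau) * target_params in update_target_network.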
135 |     target_network_def = critic_network_def
136 |     target_critic = Network.create(
137 |         network_def=target_network_def,
138 |         network_inputs={
139 |             "rngs": critic_key,
140 |             "observations": fake_observations,
141 |             "actions": fake_actions,
142 |         },
143 |         tx=None,
144 |     )
145 | 
146 |     temperature = Network.create(
147 |         network_def=SimbaTemperature(cfg.temp_initial_value),
148 |         network_inputs={
149 |             "rngs": temp_key,
150 |         },
151 |         tx=optax.adamw(
152 |             learning_rate=optax.linear_schedule(
153 |                 init_value=cfg.learning_rate_init,
154 |                 end_value=cfg.learning_rate_end,
155 |                 transition_steps=cfg.learning_rate_decay_step,
156 |             ),
157 |             weight_decay=cfg.weight_decay,
158 |         ),
159 |     )
160 | 
161 |     return rng, actor, critic, target_critic, temperature
162 | 
163 | 
164 | @jax.jit
165 | def _sample_simba_actions(
166 |     rng: PRNGKey,
167 |     actor: Network,
168 |     observations: jnp.ndarray,
169 |     temperature: float = 1.0,
170 | ) -> Tuple[PRNGKey, jnp.ndarray]:
171 |     rng, key = jax.random.split(rng)
172 |     dist, _ = actor(observations=observations, temperature=temperature)
173 |     actions = dist.sample(seed=key)
174 | 
175 |     return rng, actions
176 | 
177 | 
178 | @functools.partial(
179 |     jax.jit,
180 |     static_argnames=(
181 |         "gamma",
182 |         "n_step",
183 |         "critic_use_cdq",
184 |         "target_tau",
185 |         "temp_target_entropy",
186 |         "actor_bc_alpha",
187 |     ),
188 | )
189 | def _update_simba_networks(
190 |     rng: PRNGKey,
191 |     actor: Network,
192 |     critic: Network,
193 |     target_critic: Network,
194 |     temperature: Network,
195 |     batch: Batch,
196 |     gamma: float,
197 |     n_step: int,
198 |     actor_bc_alpha: float,
199 |     critic_use_cdq: bool,
200 |     target_tau: float,
201 |     temp_target_entropy: float,
202 | ) -> Tuple[PRNGKey, Network, Network, Network, Network, Dict[str, float]]:
203 |     rng, actor_key, critic_key = jax.random.split(rng, 3)
204 | 
205 |     new_actor, actor_info = update_actor(
206 |         key=actor_key,
207 |         actor=actor,
208 |         critic=critic,
209 |         temperature=temperature,
210 |         batch=batch,
211 |         use_cdq=critic_use_cdq,
212 |         bc_alpha=actor_bc_alpha,
213 |     )
214 | 
215 |     new_temperature, temperature_info = update_temperature(
216 |         temperature=temperature,
217 |         entropy=actor_info["actor/entropy"],
218 |         target_entropy=temp_target_entropy,
219 |     )
220 | 
221 |     new_critic, critic_info = update_critic(
222 |         key=critic_key,
223 |         actor=new_actor,
224 |         critic=critic,
225 |         target_critic=target_critic,
226 |         temperature=new_temperature,
227 |         batch=batch,
228 |         gamma=gamma,
229 |         n_step=n_step,
230 |         use_cdq=critic_use_cdq,
231 |     )
232 | 
233 |     new_target_critic, target_critic_info = update_target_network(
234 |         network=new_critic,
235 |         target_network=target_critic,
236 |         target_tau=target_tau,
237 |     )
238 | 
239 |     info = {
240 |         **actor_info,
241 |         **critic_info,
242 |         **target_critic_info,
243 |         **temperature_info,
244 |     }
245 | 
246 |     return (rng, new_actor, new_critic, new_target_critic, new_temperature, info)
247 | 
248 | 
249 | class SimbaAgent(BaseAgent):
250 |     def __init__(
251 |         self,
252 |         observation_space: gym.spaces.Space,
253 |         action_space: gym.spaces.Space,
254 |         cfg: SimbaConfig,
255 |     ):
256 |         """
257 |         A soft actor-critic style agent built on the Simba actor and critic networks,
258 |         with automatic entropy-temperature tuning and an optional behavior-cloning term.
259 | """ 260 | 261 | self._observation_dim = observation_space.shape[-1] 262 | self._action_dim = action_space.shape[-1] 263 | 264 | cfg["temp_target_entropy"] = cfg["temp_target_entropy_coef"] * self._action_dim 265 | 266 | super(SimbaAgent, self).__init__( 267 | observation_space, 268 | action_space, 269 | cfg, 270 | ) 271 | 272 | # map dictionary to dataclass 273 | self._cfg = SimbaConfig(**cfg) 274 | 275 | # initialize networks 276 | ( 277 | self._rng, 278 | self._actor, 279 | self._critic, 280 | self._target_critic, 281 | self._temperature, 282 | ) = _init_simba_networks(self._observation_dim, self._action_dim, self._cfg) 283 | 284 | def sample_actions( 285 | self, 286 | interaction_step: int, 287 | prev_timestep: Dict[str, np.ndarray], 288 | training: bool, 289 | ) -> np.ndarray: 290 | if training: 291 | temperature = 1.0 292 | else: 293 | temperature = 0.0 294 | 295 | # current timestep observation is "next" observations from the previous timestep 296 | observations = jnp.asarray(prev_timestep["next_observation"]) 297 | 298 | self._rng, actions = _sample_simba_actions( 299 | self._rng, self._actor, observations, temperature 300 | ) 301 | actions = np.array(actions) 302 | 303 | return actions 304 | 305 | def update(self, update_step: int, batch: Dict[str, np.ndarray]) -> Dict: 306 | for key, value in batch.items(): 307 | batch[key] = jnp.asarray(value) 308 | 309 | ( 310 | self._rng, 311 | self._actor, 312 | self._critic, 313 | self._target_critic, 314 | self._temperature, 315 | update_info, 316 | ) = _update_simba_networks( 317 | rng=self._rng, 318 | actor=self._actor, 319 | critic=self._critic, 320 | target_critic=self._target_critic, 321 | temperature=self._temperature, 322 | batch=batch, 323 | gamma=self._cfg.gamma, 324 | n_step=self._cfg.n_step, 325 | actor_bc_alpha=self._cfg.actor_bc_alpha, 326 | critic_use_cdq=self._cfg.critic_use_cdq, 327 | target_tau=self._cfg.target_tau, 328 | temp_target_entropy=self._cfg.temp_target_entropy, 329 | ) 330 | 331 | for key, value in update_info.items(): 332 | if isinstance(value, dict): 333 | continue 334 | update_info[key] = float(value) 335 | 336 | return update_info 337 | 338 | def save(self, path: str) -> None: 339 | self._actor.save(path + "/actor") 340 | self._critic.save(path + "/critic") 341 | self._target_critic.save(path + "/target_critic") 342 | self._temperature.save(path + "/temperature") 343 | 344 | def load(self, path: str) -> None: 345 | only_param = self._cfg.load_only_param 346 | param_key = self._cfg.load_param_key 347 | 348 | self._actor = self._actor.load(path + "/actor", param_key, only_param) 349 | self._critic = self._critic.load(path + "/critic", param_key, only_param) 350 | self._target_critic = self._target_critic.load( 351 | path + "/target_critic", param_key, only_param 352 | ) 353 | self._temperature = self._temperature.load( 354 | path + "/temperature", None, only_param 355 | ) 356 | --------------------------------------------------------------------------------
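A minimal usage sketch, not a repository file, showing how the wrappers defined in scale_rl/agents/wrappers/normalization.py compose around an agent. Here make_base_agent is a hypothetical stand-in for whatever constructs the underlying BaseAgent (for example a fully configured SimbaAgent), and the wrapping order shown is an assumption; the repository presumably selects it from the normalize_observation / normalize_reward flags in the agent config.

from scale_rl.agents.wrappers.normalization import (
    ObservationNormalizer,
    RewardNormalizer,
)

agent = make_base_agent()                    # hypothetical factory, not part of scale_rl
agent = RewardNormalizer(agent, gamma=0.99)  # scales rewards by a running return statistic
agent = ObservationNormalizer(agent)         # centers observations to zero mean, unit variance

# sample_actions(..., training=True) updates the running statistics during collection,
# update() then receives normalized batches, and save()/load() persist the agent checkpoint
# together with the wrappers' statistics under <path>/obs_norm and <path>/rew_norm.
agent.save("/tmp/simba_checkpoint")
agent.load("/tmp/simba_checkpoint")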