├── research ├── datasets │ ├── replay_buffer │ │ └── __init__.py │ ├── __init__.py │ ├── robomimic_dataset.py │ ├── wgcsl_dataset.py │ ├── d4rl_dataset.py │ └── rollout_buffer.py ├── __init__.py ├── processors │ ├── __init__.py │ ├── base.py │ ├── concatenate.py │ ├── image_augmentation.py │ └── normalization.py ├── algs │ ├── __init__.py │ ├── offline │ │ ├── bc.py │ │ ├── dp.py │ │ ├── iql.py │ │ └── idql.py │ └── online │ │ ├── td3.py │ │ ├── sac.py │ │ ├── dqn.py │ │ └── drqv2.py ├── networks │ ├── __init__.py │ ├── drqv2.py │ ├── base.py │ ├── transformer.py │ └── common.py ├── utils │ ├── schedules.py │ ├── logger.py │ └── evaluate.py └── envs │ ├── base.py │ ├── robomimic.py │ ├── __init__.py │ └── metaworld.py ├── setup.cfg ├── tools ├── cleanup.py ├── parse_sweep.py ├── run_slurm.py └── run_local.py ├── .pre-commit-config.yaml ├── LICENSE ├── environment_m1.yaml ├── environment_cpu.yaml ├── environment_gpu.yaml ├── configs └── examples │ ├── dqn.yaml │ ├── bc.yaml │ ├── drqv2.yaml │ ├── td3.yaml │ ├── gcsl.yaml │ ├── franka_reach_sac.yaml │ ├── ppo.yaml │ ├── sac.yaml │ ├── idql.yaml │ ├── iql.yaml │ ├── example.yaml │ └── diffusion_policy.yaml ├── pyproject.toml ├── scripts ├── train.py ├── compute_action_normalization.py ├── plot.py ├── evaluate.py └── create_dataset.py ├── setup_shell.sh ├── environment_polymetis.yaml └── .gitignore /research/datasets/replay_buffer/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [options] 2 | packages = find: 3 | -------------------------------------------------------------------------------- /research/__init__.py: -------------------------------------------------------------------------------- 1 | from . import algs, datasets, envs, networks, processors 2 | -------------------------------------------------------------------------------- /research/processors/__init__.py: -------------------------------------------------------------------------------- 1 | # Register Preprocessors here 2 | from .base import Compose 3 | from .concatenate import Concatenate 4 | from .image_augmentation import RandomCrop 5 | from .normalization import GaussianActionNormalizer, MinMaxActionNormalizer, RunningObservationNormalizer 6 | -------------------------------------------------------------------------------- /research/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Register dataset classes here 2 | from .d4rl_dataset import D4RLDataset 3 | from .replay_buffer.buffer import ReplayBuffer 4 | from .robomimic_dataset import RobomimicDataset 5 | from .rollout_buffer import RolloutBuffer 6 | from .wgcsl_dataset import WGCSLDataset 7 | -------------------------------------------------------------------------------- /research/algs/__init__.py: -------------------------------------------------------------------------------- 1 | # Register Algorithms here. 
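# Illustrative registration pattern (the module and class names below are hypothetical):
# an algorithm defined in offline/my_alg.py becomes available to configs once imported here,
#   from .offline.my_alg import MyAlgorithm
# after which a config can select it with `alg: MyAlgorithm`.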
2 | 3 | from .offline.bc import BehaviorCloning 4 | from .offline.dp import DiffusionPolicy 5 | from .offline.idql import IDQL 6 | from .offline.iql import IQL 7 | from .online.dqn import DQN, DoubleDQN, SoftDoubleDQN, SoftDQN 8 | from .online.drqv2 import DRQV2 9 | from .online.ppo import PPO, AdaptiveKLPPO 10 | from .online.sac import SAC 11 | from .online.td3 import TD3 12 | -------------------------------------------------------------------------------- /research/networks/__init__.py: -------------------------------------------------------------------------------- 1 | # Register Network Classes here. 2 | from .base import ActorCriticPolicy, ActorCriticValuePolicy, ActorPolicy, ActorValuePolicy, MultiEncoder 3 | from .diffusion import ConditionalUnet1D, MLPResNet 4 | from .drqv2 import DrQv2Actor, DrQv2Critic, DrQv2Encoder, DrQv2Value 5 | from .mlp import ( 6 | ContinuousMLPActor, 7 | ContinuousMLPCritic, 8 | DiagonalGaussianMLPActor, 9 | DiscreteMLPCritic, 10 | GaussianMixtureMLPActor, 11 | MLPEncoder, 12 | MLPValue, 13 | ) 14 | from .resnet import RobomimicEncoder 15 | from .transformer import StateTransformerEncoder 16 | -------------------------------------------------------------------------------- /tools/cleanup.py: -------------------------------------------------------------------------------- 1 | # This script cleans up all the temporary files used by the research codebase. 2 | 3 | import os 4 | import shutil 5 | 6 | if __name__ == "__main__": 7 | base_path = "/tmp/" 8 | 9 | job_scripts_removed = 0 10 | replay_buffers_removed = 0 11 | sweeper_configs_removed = 0 12 | 13 | for name in os.listdir(base_path): 14 | path = os.path.join(base_path, name) 15 | try: 16 | if name.startswith("job_"): 17 | os.remove(path) 18 | job_scripts_removed += 1 19 | elif name.startswith("config_"): 20 | os.remove(path) 21 | sweeper_configs_removed += 1 22 | elif name.startswith("replay_buffer_"): 23 | shutil.rmtree(path) 24 | replay_buffers_removed += 1 25 | except OSError: 26 | continue 27 | 28 | print("Finished Cleanup.") 29 | print("Removed", job_scripts_removed, "job scripts.") 30 | print("Removed", sweeper_configs_removed, "sweeper configs.") 31 | print("Removed", replay_buffers_removed, "replay buffers.") 32 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | 4 | exclude: ".git" 5 | default_stages: 6 | - commit 7 | 8 | repos: 9 | 10 | - repo: https://github.com/hadialqattan/pycln 11 | rev: v2.1.5 12 | hooks: 13 | - id: pycln 14 | 15 | - repo: https://github.com/timothycrosley/isort 16 | rev: 5.12.0 17 | hooks: 18 | - id: isort 19 | 20 | - repo: https://github.com/psf/black 21 | rev: 23.3.0 22 | hooks: 23 | - id: black 24 | 25 | - repo: https://github.com/pre-commit/pre-commit-hooks 26 | rev: v4.4.0 27 | hooks: 28 | - id: check-ast 29 | - id: trailing-whitespace 30 | - id: end-of-file-fixer 31 | - id: check-yaml 32 | - id: check-toml 33 | - id: check-merge-conflict 34 | - id: check-case-conflict 35 | - id: check-added-large-files 36 | - id: debug-statements 37 | 38 | - repo: https://github.com/charliermarsh/ruff-pre-commit 39 | rev: v0.0.274 40 | hooks: 41 | - id: ruff 42 | args: [ --fix, --exit-non-zero-on-fix ] 43 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Joey Hejna 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /environment_m1.yaml: -------------------------------------------------------------------------------- 1 | name: research 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python 7 | 8 | # Development 9 | - conda-forge::pre-commit 10 | - Cython<3.0 11 | 12 | # pytorch 13 | - pytorch::pytorch 14 | - torchvision 15 | - torchtext 16 | - torchaudio 17 | 18 | # NP Family 19 | - numpy 20 | - scipy 21 | - scikit-image 22 | 23 | # IO 24 | - imageio 25 | - pillow 26 | - pyyaml 27 | - cloudpickle 28 | - h5py 29 | - absl-py 30 | - pyparsing 31 | 32 | # Plotting 33 | - tensorboard 34 | - pandas 35 | - matplotlib 36 | - seaborn 37 | 38 | # Other 39 | - pytest 40 | - tqdm 41 | - future 42 | 43 | # For Robosuite (But install robosuite via source) 44 | - numba 45 | 46 | - pip 47 | - pip: 48 | - gym==0.23.1 49 | - gym-robotics==0.1.0 50 | - mujoco-py<2.2,>=2.0 51 | # - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb 52 | # The D4RL dependency does not yet support python 3.11 via pip. Instead run this command outside of conda. 
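# i.e. install it manually with `pip install`, using the URL on the next line; the
# `--ignore-requires-python` flag bypasses pip's python-version check.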
53 | # - git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl --ignore-requires-python 54 | -------------------------------------------------------------------------------- /environment_cpu.yaml: -------------------------------------------------------------------------------- 1 | name: research 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python 7 | 8 | # Development 9 | - conda-forge::pre-commit 10 | - Cython<3.0 11 | 12 | # pytorch 13 | - pytorch 14 | - cpuonly 15 | - torchvision 16 | - torchtext 17 | - torchaudio 18 | 19 | # NP Family 20 | - numpy 21 | - scipy 22 | - scikit-image 23 | 24 | # IO 25 | - imageio 26 | - pillow 27 | - pyyaml 28 | - cloudpickle 29 | - h5py 30 | - absl-py 31 | - pyparsing 32 | 33 | # Plotting 34 | - tensorboard 35 | - pandas 36 | - matplotlib 37 | - seaborn 38 | 39 | # Other 40 | - pytest 41 | - tqdm 42 | - future 43 | 44 | # For Robosuite (But install robosuite via source) 45 | - numba 46 | 47 | - pip 48 | - pip: 49 | - gym==0.23.1 50 | - gym-robotics==0.1.0 51 | - mujoco-py<2.2,>=2.0 52 | - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb 53 | # The D4RL dependency does not yet support python 3.11 via pip. Instead run this command outside of conda. 54 | # - git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl --ignore-requires-python 55 | -------------------------------------------------------------------------------- /environment_gpu.yaml: -------------------------------------------------------------------------------- 1 | name: research 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - python 8 | 9 | # Development 10 | - conda-forge::pre-commit 11 | - Cython<3.0 12 | 13 | # pytorch 14 | - pytorch 15 | - pytorch-cuda=11.7 16 | - torchvision 17 | - torchtext 18 | - torchaudio 19 | 20 | # NP Family 21 | - numpy 22 | - scipy 23 | - scikit-image 24 | 25 | # IO 26 | - imageio 27 | - pillow 28 | - pyyaml 29 | - cloudpickle 30 | - h5py 31 | - absl-py 32 | - pyparsing 33 | 34 | # Plotting 35 | - tensorboard 36 | - pandas 37 | - matplotlib 38 | - seaborn 39 | 40 | # Other 41 | - pytest 42 | - tqdm 43 | - future 44 | 45 | # For Robosuite (But install robosuite via source) 46 | - numba 47 | 48 | - pip 49 | - pip: 50 | - gym==0.23.1 51 | - gym-robotics==0.1.0 52 | - mujoco-py<2.2,>=2.0 53 | - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb 54 | # The D4RL dependency does not yet support python 3.11 via pip. Instead run this command outside of conda. 55 | # - git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl --ignore-requires-python 56 | -------------------------------------------------------------------------------- /research/utils/schedules.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file contains schedule functions that can be used as learning rate schedules 3 | 4 | All learning rate schedulers use the pytorch LambdaLR function and any additional kwargs. 
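Illustrative usage (assuming an already-constructed torch optimizer named `optim`):

    lr_fn = linear_warmup(total_steps=10_000)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lr_fn)

Each function below returns a callable that maps the current step to a multiplier on the base learning rate.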
5 | """ 6 | import math 7 | 8 | 9 | def linear_decay(total_steps: int, start_step: int = 1, offset: int = 0): 10 | def fn(step): 11 | return 1.0 - max(0, step + offset - start_step) / (total_steps - start_step) 12 | 13 | return fn 14 | 15 | 16 | def linear_warmup(total_steps: int, multiplier: float = 1.0): 17 | def fn(step): 18 | return multiplier * min(1.0, step / total_steps) 19 | 20 | return fn 21 | 22 | 23 | def cosine_with_linear_warmup(warmup_steps: int, total_steps: int, num_cycles: float = 0.5, min_lr_ratio=1e-1): 24 | def fn(step): 25 | step = min(step, total_steps) 26 | if step < warmup_steps: 27 | return float(step) / float(max(1, warmup_steps)) 28 | progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps)) 29 | out = max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) 30 | if out < min_lr_ratio: 31 | return min_lr_ratio 32 | else: 33 | return out 34 | 35 | return fn 36 | -------------------------------------------------------------------------------- /configs/examples/dqn.yaml: -------------------------------------------------------------------------------- 1 | # Example Config that uses almost all values 2 | 3 | alg: DoubleDQN 4 | alg_kwargs: 5 | tau: 1.0 6 | target_freq: 1000 7 | max_grad_norm: 10 8 | eps_start: 1.0 9 | eps_end: 0.05 10 | eps_frac: 0.1 11 | random_steps: 1000 12 | 13 | optim: Adam 14 | optim_kwargs: 15 | lr: 0.00025 16 | 17 | network: DiscreteMLPCritic 18 | network_kwargs: 19 | hidden_layers: [256, 256] 20 | act: ["import", "torch.nn", "ReLU"] 21 | 22 | env: CartPole-v1 23 | 24 | dataset: ReplayBuffer 25 | dataset_kwargs: 26 | sample_fn: sample_qlearning 27 | sample_kwargs: 28 | discount: 0.99 29 | nstep: 1 30 | batch_size: 64 31 | capacity: 500000 32 | fetch_every: 500 33 | 34 | processor: null 35 | 36 | trainer_kwargs: # Arguments given to Algorithm.train 37 | total_steps: 500000 # The total number of steps to train 38 | log_freq: 250 # How often to log values 39 | profile_freq: 100 40 | env_freq: 0.25 41 | eval_freq: 1000 # How often to run evals 42 | eval_fn: eval_policy 43 | eval_kwargs: 44 | num_ep: 10 # Number of enviornment episodes to run for evaluation 45 | loss_metric: reward # The validation metric that determines when to save the "best_checkpoint" 46 | train_dataloader_kwargs: 47 | num_workers: 0 # Number of dataloader workers. 48 | batch_size: null 49 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "research" 7 | version = "0.0.1" 8 | description = "Research experimentation package." 
9 | readme = "README.md" 10 | requires-python = ">=3.9" 11 | authors = [ 12 | { name = "Joey Hejna", email = "jhejna@cs.stanford.edu" }, 13 | ] 14 | license = { file = "LICENSE" } 15 | 16 | classifiers = [ 17 | "Topic :: Research", 18 | "Private :: Do Not Upload" 19 | ] 20 | 21 | [dependencies] 22 | 23 | [project.urls] 24 | homepage = "https://github.com/jhejna/research-lightning" 25 | repository = "https://github.com/jhejna/research-lightning" 26 | documentation = "https://github.com/jhejna/research-lightning" 27 | 28 | [tool.black] 29 | line-length = 120 30 | target-version = ["py39", "py310", "py311"] 31 | preview = true 32 | 33 | [tool.ruff] 34 | line-length = 120 35 | target-version = "py39" 36 | select = ["A", "B", "C90", "E", "F", "I", "RUF", "W"] 37 | ignore = ["A002", "A003", "B027", "C901", "RUF012"] 38 | 39 | [tool.ruff.per-file-ignores] 40 | "__init__.py" = ["E402", "F401"] 41 | 42 | [tool.isort] 43 | profile = "black" 44 | line_length = 120 45 | skip = ["__init__.py"] 46 | filter_files = true 47 | py_version = "all" 48 | 49 | 50 | [tool.setuptools.packages.find] 51 | where = ["."] 52 | exclude = ["cache"] 53 | -------------------------------------------------------------------------------- /configs/examples/bc.yaml: -------------------------------------------------------------------------------- 1 | alg: BehaviorCloning 2 | alg_kwargs: 3 | # Configure offline steps. These aren't needed, but good to set. 4 | offline_steps: -1 5 | random_steps: 0 6 | 7 | optim: Adam 8 | optim_kwargs: 9 | lr: 0.0003 10 | 11 | network: ActorPolicy 12 | network_kwargs: 13 | actor_class: ContinuousMLPActor 14 | hidden_layers: [256, 256] 15 | ortho_init: True 16 | 17 | # Example config using Robomimic. 18 | eval_env: RobomimicEnv 19 | eval_env_kwargs: 20 | path: path/to/robomimic/dataset 21 | horizon: 500 22 | 23 | dataset: RobomimicDataset 24 | dataset_kwargs: 25 | path: path/to/robomimic/dataset 26 | sample_fn: sample 27 | sample_kwargs: 28 | batch_size: 256 29 | validation_dataset_kwargs: 30 | train: False 31 | 32 | schedule: null 33 | processor: Concatenate 34 | 35 | trainer_kwargs: # Arguments given to Algorithm.train 36 | total_steps: 1000000 # The total number of steps to train 37 | log_freq: 250 # How often to log values 38 | profile_freq: 100 39 | eval_freq: 10000 # How often to run evals 40 | eval_fn: eval_policy 41 | eval_kwargs: 42 | num_ep: 10 # Number of enviornment episodes to run for evaluation, or -1 if none should be run. 43 | loss_metric: reward # The validation metric that determines when to save the "best_checkpoint" 44 | max_validation_steps: 25 # Will run forever otherwise due to continuous replay buffer iter. 45 | train_dataloader_kwargs: 46 | num_workers: 0 # Number of dataloader workers. 47 | batch_size: null 48 | -------------------------------------------------------------------------------- /configs/examples/drqv2.yaml: -------------------------------------------------------------------------------- 1 | # NOTE: DrQv2 Implementation is UNTESTED!!!!!!!!!! 
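# std_schedule below appears to follow the DrQ-v2 convention of [initial_std, final_std, decay_steps].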
2 | 3 | alg: DRQV2 4 | alg_kwargs: 5 | tau: 0.01 6 | critic_freq: 1 7 | actor_freq: 1 8 | target_freq: 1 9 | noise_clip: 0.3 10 | std_schedule: [1.0, 0.1, 500000] 11 | init_steps: 4000 12 | 13 | optim: Adam 14 | optim_kwargs: 15 | lr: 0.0001 16 | 17 | network: ActorCriticPolicy 18 | network_kwargs: 19 | actor_class: DrQv2Actor 20 | actor_kwargs: 21 | feature_dim: 50 22 | hidden_layers: [1024, 1024] 23 | critic_class: DrQv2Critic 24 | critic_kwargs: 25 | feature_dim: 50 26 | hidden_layers: [1024, 1024] 27 | encoder_class: DrQv2Encoder 28 | 29 | env: CheetahRun-vision-v0 30 | 31 | dataset: ReplayBuffer 32 | dataset_kwargs: 33 | sample_fn: sample_qlearning 34 | sample_kwargs: 35 | discount: 0.99 36 | nstep: 3 37 | batch_size: 256 38 | capacity: 1000000 39 | fetch_every: 1000 40 | 41 | processor: RandomCrop # Add in the data augmentations! 42 | 43 | trainer_kwargs: # Arguments given to Algorithm.train 44 | total_steps: 3100000 # The total number of steps to train 45 | log_freq: 250 # How often to log values 46 | profile_freq: 200 47 | eval_freq: 10000 # How often to run evals 48 | eval_fn: eval_policy 49 | eval_kwargs: 50 | num_ep: 10 # Number of enviornment episodes to run for evaluation 51 | loss_metric: reward # The validation metric that determines when to save the "best_checkpoint" 52 | train_dataloader_kwargs: 53 | num_workers: 4 # Number of dataloader workers. 54 | -------------------------------------------------------------------------------- /configs/examples/td3.yaml: -------------------------------------------------------------------------------- 1 | # Example Config that uses almost all values 2 | 3 | alg: TD3 4 | alg_kwargs: 5 | tau: 0.005 6 | policy_noise: 0.1 7 | target_noise: 0.2 8 | noise_clip: 0.5 9 | critic_freq: 1 10 | actor_freq: 2 11 | target_freq: 2 12 | random_steps: 10000 13 | 14 | optim: Adam 15 | optim_kwargs: 16 | lr: 0.001 17 | 18 | network: ActorCriticPolicy 19 | network_kwargs: 20 | actor_class: ContinuousMLPActor 21 | actor_kwargs: 22 | hidden_layers: [256, 256] 23 | output_act: ["import", "torch.nn", "Tanh"] 24 | critic_class: ContinuousMLPCritic 25 | critic_kwargs: 26 | hidden_layers: [256, 256] 27 | ensemble_size: 2 28 | ortho_init: true 29 | 30 | env: CheetahRun-v0 31 | 32 | dataset: ReplayBuffer 33 | dataset_kwargs: 34 | sample_fn: sample_qlearning 35 | sample_kwargs: 36 | discount: 0.99 37 | nstep: 1 38 | batch_size: 1 39 | sample_by_timesteps: False 40 | capacity: 1000000 41 | fetch_every: 1000 42 | distributed: true 43 | 44 | processor: null 45 | 46 | trainer_kwargs: # Arguments given to Algorithm.train 47 | total_steps: 1000000 # The total number of steps to train 48 | log_freq: 250 # How often to log values 49 | profile_freq: 100 50 | eval_freq: 10000 # How often to run evals 51 | eval_fn: eval_policy 52 | eval_kwargs: 53 | num_ep: 10 # Number of enviornment episodes to run for evaluation, or -1 if none should be run. 54 | loss_metric: reward # The validation metric that determines when to save the "best_checkpoint" 55 | train_dataloader_kwargs: 56 | batch_size: 256 57 | num_workers: 4 # Number of dataloader workers. 58 | -------------------------------------------------------------------------------- /configs/examples/gcsl.yaml: -------------------------------------------------------------------------------- 1 | alg: BehaviorCloning 2 | alg_kwargs: 3 | # Configure offline steps. These aren't needed, but good to set. 
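# (Presumably offline_steps: -1 means train purely from the offline dataset, and random_steps: 0
#  disables the initial random-action phase.)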
4 | offline_steps: -1 5 | random_steps: 0 6 | 7 | optim: Adam 8 | optim_kwargs: 9 | lr: 0.0003 10 | 11 | network: ActorPolicy 12 | network_kwargs: 13 | actor_class: ContinuousMLPActor 14 | hidden_layers: [256, 256, 256] 15 | ortho_init: True 16 | 17 | # Example config using Robomimic. 18 | eval_env: FetchPickAndPlace-v1 19 | 20 | dataset: WGCSLDataset 21 | dataset_kwargs: 22 | # path: path/to/wgcsl/dataset 23 | path: [../datasets/offline_goal_conditioned_data/expert/FetchPick/buffer.pkl, ../datasets/offline_goal_conditioned_data/random/FetchPick/buffer.pkl] 24 | percents: [0.9, 0.1] 25 | sample_fn: sample_her 26 | sample_kwargs: 27 | sample_by_timesteps: False 28 | batch_size: 256 29 | relabel_fraction: 1.0 30 | strategy: future 31 | 32 | schedule: null 33 | processor: Concatenate # Concatenates all goal keys together 34 | 35 | trainer_kwargs: # Arguments given to Algorithm.train 36 | total_steps: 1000000 # The total number of steps to train 37 | log_freq: 250 # How often to log values 38 | profile_freq: 100 39 | eval_freq: 10000 # How often to run evals 40 | eval_fn: eval_policy 41 | eval_kwargs: 42 | num_ep: 10 # Number of enviornment episodes to run for evaluation, or -1 if none should be run. 43 | loss_metric: reward # The validation metric that determines when to save the "best_checkpoint" 44 | max_validation_steps: 25 # Will run forever otherwise due to continuous replay buffer iter. 45 | train_dataloader_kwargs: 46 | num_workers: 0 # Number of dataloader workers. 47 | batch_size: null 48 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import subprocess 4 | 5 | from research.utils.config import Config 6 | 7 | 8 | def try_wandb_setup(path, config): 9 | wandb_api_key = os.getenv("WANDB_API_KEY") 10 | if wandb_api_key is not None and wandb_api_key != "": 11 | try: 12 | import wandb 13 | except ImportError: 14 | return 15 | project_dir = os.path.dirname(os.path.dirname(__file__)) 16 | wandb.init( 17 | project=os.path.basename(project_dir), 18 | name=os.path.basename(path), 19 | config=config.flatten(separator="-"), 20 | dir=os.path.join(os.path.dirname(project_dir), "wandb"), 21 | ) 22 | 23 | 24 | if __name__ == "__main__": 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--config", "-c", type=str, default=None) 27 | parser.add_argument("--path", "-p", type=str, default=None) 28 | parser.add_argument("--device", "-d", type=str, default="auto") 29 | args = parser.parse_args() 30 | 31 | config = Config.load(args.config) 32 | os.makedirs(args.path, exist_ok=True) 33 | try_wandb_setup(args.path, config) 34 | config.save(args.path) # Save the config 35 | # save the git hash 36 | process = subprocess.Popen(["git", "rev-parse", "HEAD"], shell=False, stdout=subprocess.PIPE) 37 | git_head_hash = process.communicate()[0].strip() 38 | with open(os.path.join(args.path, "git_hash.txt"), "wb") as f: 39 | f.write(git_head_hash) 40 | # Parse the config file to resolve names. 41 | config = config.parse() 42 | # Get everything at once. 
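# (config.get_trainer presumably builds the env, dataset, network, and algorithm from the parsed config.)
# Illustrative invocation of this script, with placeholder paths:
#   python scripts/train.py --config configs/examples/sac.yaml --path outputs/sac_run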
43 | trainer = config.get_trainer(device=args.device) 44 | # Train the model 45 | trainer.train(args.path) 46 | -------------------------------------------------------------------------------- /configs/examples/franka_reach_sac.yaml: -------------------------------------------------------------------------------- 1 | # Example Config that uses almost all values 2 | 3 | alg: SAC 4 | alg_kwargs: 5 | tau: 0.005 6 | init_temperature: 0.1 7 | critic_freq: 1 8 | actor_freq: 1 9 | target_freq: 2 10 | random_steps: 400 11 | 12 | optim: Adam 13 | optim_kwargs: 14 | lr: 0.0001 15 | 16 | network: ActorCriticPolicy 17 | network_kwargs: 18 | actor_class: DiagonalGaussianMLPActor 19 | actor_kwargs: 20 | hidden_layers: [1024, 1024] 21 | log_std_bounds: [-5, 2] 22 | critic_class: ContinuousMLPCritic 23 | critic_kwargs: 24 | hidden_layers: [1024, 1024] 25 | ensemble_size: 2 26 | ortho_init: true 27 | 28 | batch_size: null # Use serial replay buffer 29 | collate_fn: null # The collate function passed to the dataloader. None uses pytorch default. 30 | checkpoint: null # A checkpoint to initialize the network from. 31 | 32 | env: FrankaReach 33 | env_kwargs: 34 | control_hz: 10.0 35 | ip_address: 172.16.0.130 36 | controller: cartesian_delta 37 | 38 | dataset: ReplayBuffer 39 | dataset_kwargs: 40 | sample_fn: sample_qlearning 41 | sample_kwargs: 42 | discount: 0.99 43 | nstep: 1 44 | batch_size: 1024 45 | capacity: 1000000 46 | fetch_every: 1000 47 | distributed: False 48 | 49 | processor: Concatenate 50 | 51 | trainer_kwargs: # Arguments given to Algorithm.train 52 | total_steps: 1000000 # The total number of steps to train 53 | log_freq: 100 # How often to log values 54 | profile_freq: 100 55 | env_runner: AsyncEnv # This makes the environment run asynchronized! 56 | env_freq: 2 57 | eval_fn: null 58 | loss_metric: reward # The validation metric that determines when to save the "best_checkpoint" 59 | train_dataloader_kwargs: 60 | batch_size: null 61 | num_workers: 0 # Number of dataloader workers. 
62 | -------------------------------------------------------------------------------- /configs/examples/ppo.yaml: -------------------------------------------------------------------------------- 1 | # Example Config that uses almost all values 2 | 3 | alg: PPO 4 | alg_kwargs: 5 | clip_range: 0.2 6 | clip_range_vf: 0.5 7 | num_epochs: 10 8 | normalize_advantage: True 9 | ent_coeff: 0.0 10 | vf_coeff: 0.5 11 | normalize_returns: True 12 | reward_clip: 10 13 | 14 | optim: Adam 15 | optim_kwargs: 16 | lr: 0.0003 17 | 18 | network: ActorValuePolicy 19 | network_kwargs: 20 | actor_class: DiagonalGaussianMLPActor 21 | actor_kwargs: 22 | hidden_layers: [64, 64] 23 | log_std_bounds: null 24 | act: ["import", "torch.nn", "Tanh"] 25 | state_dependent_log_std: False 26 | squash_normal: False 27 | log_std_tanh: False 28 | ortho_init: 1.41421356237 29 | output_gain: 0.01 30 | value_class: MLPValue 31 | value_kwargs: 32 | hidden_layers: [64, 64] 33 | act: ["import", "torch.nn", "Tanh"] 34 | ortho_init: 1.41421356237 35 | output_gain: 0.01 36 | 37 | env: HalfCheetah-v2 38 | 39 | dataset: RolloutBuffer 40 | dataset_kwargs: 41 | discount: 0.99 42 | gae_lambda: 0.95 43 | capacity: 2048 44 | batch_size: 64 45 | 46 | processor: RunningObservationNormalizer 47 | processor_kwargs: 48 | clip: 10 49 | explicit_update: True 50 | 51 | trainer_kwargs: # Arguments given to Algorithm.train 52 | total_steps: 1000000 # The total number of steps to train 53 | log_freq: 1000 # How often to log values 54 | eval_freq: 10000 # How often to run evals 55 | profile_freq: 100 56 | eval_fn: eval_policy 57 | eval_kwargs: 58 | num_ep: 10 # Number of enviornment episodes to run for evaluation, or -1 if none should be run. 59 | loss_metric: reward # The validation metric that determines when to save the "best_checkpoint" 60 | train_dataloader_kwargs: 61 | num_workers: 0 # Number of dataloader workers. 62 | batch_size: null 63 | -------------------------------------------------------------------------------- /configs/examples/sac.yaml: -------------------------------------------------------------------------------- 1 | # Example Config that uses almost all values 2 | 3 | alg: SAC 4 | alg_kwargs: 5 | tau: 0.005 6 | init_temperature: 0.1 7 | critic_freq: 1 8 | actor_freq: 1 9 | target_freq: 2 10 | random_steps: 5000 11 | 12 | optim: Adam 13 | optim_kwargs: 14 | lr: 0.0001 15 | 16 | network: ActorCriticPolicy 17 | network_kwargs: 18 | actor_class: DiagonalGaussianMLPActor 19 | actor_kwargs: 20 | hidden_layers: [1024, 1024] 21 | log_std_bounds: [-5, 2] 22 | critic_class: ContinuousMLPCritic 23 | critic_kwargs: 24 | hidden_layers: [1024, 1024] 25 | ensemble_size: 2 26 | ortho_init: true 27 | 28 | batch_size: null # Use serial replay buffer 29 | collate_fn: null # The collate function passed to the dataloader. None uses pytorch default. 30 | checkpoint: null # A checkpoint to initialize the network from. 31 | 32 | env: CheetahRun-v0 33 | 34 | dataset: ReplayBuffer 35 | dataset_kwargs: 36 | sample_fn: sample_qlearning 37 | sample_kwargs: 38 | discount: 0.99 39 | nstep: 1 40 | batch_size: 1024 41 | sample_by_timesteps: False 42 | capacity: 1000000 43 | fetch_every: 1000 44 | distributed: False 45 | 46 | processor: null 47 | 48 | trainer_kwargs: # Arguments given to Algorithm.train 49 | total_steps: 1000000 # The total number of steps to train 50 | log_freq: 250 # How often to log values 51 | profile_freq: 100 52 | env_runner: null # Set to AsyncEnv to run the environment run asynchronized! 
53 | eval_freq: 10000 # How often to run evals 54 | eval_fn: eval_policy 55 | eval_kwargs: 56 | num_ep: 10 # Number of enviornment episodes to run for evaluation 57 | loss_metric: reward # The validation metric that determines when to save the "best_checkpoint" 58 | train_dataloader_kwargs: 59 | batch_size: null 60 | num_workers: 0 # Number of dataloader workers. 61 | -------------------------------------------------------------------------------- /setup_shell.sh: -------------------------------------------------------------------------------- 1 | # Make sure we have the conda environment set up. 2 | CONDA_PATH=~/miniconda3/bin/activate 3 | ENV_NAME=research 4 | REPO_PATH=path/to/your/repo 5 | USE_MUJOCO_PY=false # For using mujoco py 6 | WANDB_API_KEY="" # If you want to use wandb, set this to your API key. 7 | 8 | # Setup Conda 9 | source $CONDA_PATH 10 | conda activate $ENV_NAME 11 | cd $REPO_PATH 12 | unset DISPLAY # Make sure display is not set or it will prevent scripts from running in headless mode. 13 | 14 | if $WANDB_API_KEY; then 15 | export WANDB_API_KEY=$WANDB_API_KEY 16 | fi 17 | 18 | if $USE_MUJOCO_PY; then 19 | echo "Using mujoco_py" 20 | if [ -d "/usr/lib/nvidia" ]; then 21 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/lib/nvidia 22 | fi 23 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mujoco210/bin 24 | export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:~/.mujoco/mujoco200/bin 25 | fi 26 | 27 | # First check if we have a GPU available 28 | if nvidia-smi | grep "CUDA Version"; then 29 | if [ -d "/usr/local/cuda-11.8" ]; then # This is the only GPU version supported by compile. 30 | export PATH=/usr/local/cuda-11.8/bin:$PATH 31 | elif [ -d "/usr/local/cuda-11.7" ]; then # This is the only GPU version supported by compile. 32 | export PATH=/usr/local/cuda-11.7/bin:$PATH 33 | elif [ -d "/usr/local/cuda" ]; then 34 | export PATH=/usr/local/cuda/bin:$PATH 35 | echo "Using default CUDA. Compatibility should be verified. torch.compile requires >= 11.7" 36 | else 37 | echo "Warning: Could not find a CUDA version but GPU was found." 38 | fi 39 | export MUJOCO_GL="egl" 40 | # Setup any GPU specific flags 41 | else 42 | echo "GPU was not found, assuming CPU setup." 43 | export MUJOCO_GL="osmesa" # glfw doesn't support headless rendering 44 | fi 45 | 46 | export D4RL_SUPPRESS_IMPORT_ERROR=1 47 | -------------------------------------------------------------------------------- /configs/examples/idql.yaml: -------------------------------------------------------------------------------- 1 | alg: IDQL 2 | alg_kwargs: 3 | tau: 0.005 4 | target_freq: 1 5 | expectile: 0.7 6 | beta: 0.333333333 7 | # Configure offline steps. These aren't needed, but good to set. 
8 | offline_steps: -1 9 | random_steps: 0 10 | 11 | optim: Adam 12 | optim_kwargs: 13 | lr: 0.0003 14 | 15 | network: ActorCriticValuePolicy 16 | network_kwargs: 17 | actor_class: MLPResNet 18 | actor_kwargs: 19 | act: ["import", "torch.nn", "ReLU"] 20 | num_blocks: 2 21 | hidden_dim: 128 22 | critic_class: ContinuousMLPCritic 23 | critic_kwargs: 24 | ensemble_size: 2 25 | hidden_layers: [256, 256] 26 | ortho_init: True 27 | value_class: MLPValue 28 | value_kwargs: 29 | ensemble_size: 1 30 | hidden_layers: [256, 256] 31 | ortho_init: True 32 | 33 | eval_env: hopper-medium-replay-v2 34 | 35 | dataset: D4RLDataset 36 | dataset_kwargs: 37 | d4rl_path: ../datasets/d4rl/ 38 | name: hopper-medium-replay-v2 39 | distributed: False 40 | sample_fn: sample_qlearning 41 | sample_kwargs: 42 | batch_size: 256 43 | discount: 0.99 44 | normalize_reward: True 45 | reward_scale: 1000.0 # scale to 1000 like in IQL 46 | use_rtg: True 47 | use_timesteps: True 48 | action_eps: 0.00001 # necesary to prevent NaN in the dataset. 49 | 50 | schedule: 51 | actor: ["import", "torch.optim.lr_scheduler", "CosineAnnealingLR"] 52 | schedule_kwargs: 53 | actor: 54 | T_max: 1000000 55 | 56 | processor: null 57 | 58 | trainer_kwargs: # Arguments given to Algorithm.train 59 | total_steps: 1000000 # The total number of steps to train 60 | log_freq: 250 # How often to log values 61 | profile_freq: 100 62 | eval_freq: 10000 # How often to run evals 63 | eval_fn: eval_policy 64 | eval_kwargs: 65 | num_ep: 10 # Number of enviornment episodes to run for evaluation, or -1 if none should be run. 66 | loss_metric: reward # The validation metric that determines when to save the "best_checkpoint" 67 | train_dataloader_kwargs: 68 | num_workers: 0 # Number of dataloader workers. 69 | batch_size: null 70 | -------------------------------------------------------------------------------- /configs/examples/iql.yaml: -------------------------------------------------------------------------------- 1 | alg: IQL 2 | alg_kwargs: 3 | tau: 0.005 4 | target_freq: 1 5 | expectile: 0.7 6 | beta: 0.333333333 7 | # Configure offline steps. These aren't needed, but good to set. 8 | offline_steps: -1 9 | random_steps: 0 10 | 11 | optim: Adam 12 | optim_kwargs: 13 | lr: 0.0003 14 | 15 | network: ActorCriticValuePolicy 16 | network_kwargs: 17 | actor_class: DiagonalGaussianMLPActor 18 | actor_kwargs: 19 | log_std_bounds: [-5, 2] 20 | dropout: 0.0 # only actor gets dropout sometimes. 21 | output_act: ["import", "torch.nn", "Tanh"] 22 | state_dependent_log_std: False 23 | log_std_tanh: False 24 | squash_normal: False 25 | critic_class: ContinuousMLPCritic 26 | critic_kwargs: 27 | ensemble_size: 2 28 | value_class: MLPValue 29 | value_kwargs: 30 | ensemble_size: 1 31 | hidden_layers: [256, 256] 32 | ortho_init: True 33 | 34 | eval_env: hopper-medium-replay-v2 35 | 36 | dataset: D4RLDataset 37 | dataset_kwargs: 38 | d4rl_path: ../datasets/d4rl/ 39 | name: hopper-medium-replay-v2 40 | distributed: False 41 | sample_fn: sample_qlearning 42 | sample_kwargs: 43 | batch_size: 256 44 | discount: 0.99 45 | normalize_reward: True 46 | reward_scale: 1000.0 # scale to 1000 like in IQL 47 | use_rtg: True 48 | use_timesteps: True 49 | action_eps: 0.00001 # necesary to prevent NaN in the dataset. 
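# Note: the schedule below is keyed per sub-network; only the actor's optimizer gets cosine annealing here.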
50 | 51 | schedule: 52 | actor: ["import", "torch.optim.lr_scheduler", "CosineAnnealingLR"] 53 | schedule_kwargs: 54 | actor: 55 | T_max: 1000000 56 | 57 | processor: null 58 | 59 | trainer_kwargs: # Arguments given to Algorithm.train 60 | total_steps: 1000000 # The total number of steps to train 61 | log_freq: 250 # How often to log values 62 | profile_freq: 100 63 | eval_freq: 10000 # How often to run evals 64 | eval_fn: eval_policy 65 | eval_kwargs: 66 | num_ep: 10 # Number of enviornment episodes to run for evaluation, or -1 if none should be run. 67 | loss_metric: reward # The validation metric that determines when to save the "best_checkpoint" 68 | train_dataloader_kwargs: 69 | num_workers: 0 # Number of dataloader workers. 70 | batch_size: null 71 | -------------------------------------------------------------------------------- /environment_polymetis.yaml: -------------------------------------------------------------------------------- 1 | name: research 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - fair-robotics 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - python=3.9 # :( unfortunately locked at this by build deps. Faster to set here than search over versions. 10 | 11 | # Development 12 | - conda-forge::pre-commit 13 | 14 | # pytorch 15 | - pytorch 16 | - pytorch-cuda=11.8 # NOTE: this is for the workstation ONLY. NUC Should have monometis installed with cpu-only. 17 | - torchvision 18 | - torchtext 19 | - torchaudio 20 | 21 | # NP Family 22 | - numpy<1.24.0 # Must install lower version of numpy due to copy bug in python 3.9 :( 23 | - scipy 24 | - scikit-image 25 | 26 | # IO 27 | - imageio 28 | - pillow 29 | - pyyaml 30 | - cloudpickle 31 | - h5py 32 | - absl-py 33 | - pyparsing 34 | 35 | # Plotting 36 | - tensorboard 37 | - pandas 38 | - matplotlib 39 | - seaborn 40 | 41 | # Other 42 | - pytest 43 | - tqdm 44 | - future 45 | 46 | # For Robosuite (But install robosuite via source) 47 | - numba 48 | 49 | # Polymetis build dependencies 50 | - assimp=5.0.1=hdca8b6f_4 51 | - cmake 52 | - doxygen 53 | - eigen 54 | - grpc-cpp=1.41.1 55 | - hpp-fcl 56 | - libprotobuf=3.18.1 57 | - openmpi 58 | - pinocchio 59 | - urdfdom=2.3.3=hc9558a2_0 60 | - urdfdom_headers=1.0.5=hc9558a2_2 61 | - yaml-cpp 62 | 63 | # Polymetis run Dependencies 64 | - boost=1.72.0 65 | - boost-cpp=1.72.0 66 | - breathe 67 | - dash 68 | - grpcio=1.46.0 69 | - hydra-core 70 | - importlib-resources 71 | - myst-parser 72 | - protobuf 73 | - pymodbus 74 | - pyserial 75 | - spdlog=1.10.0=h924138e_0 76 | 77 | - pip 78 | - pip: 79 | - gym==0.23.1 80 | - mujoco-py<2.2,>=2.0 81 | - git+https://github.com/Farama-Foundation/Metaworld.git@04be337a12305e393c0caf0cbf5ec7755c7c8feb 82 | - git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl 83 | # NOTE: One of these commands (not sure which) installs protobuf OVER the conda version 84 | # This will break polymetis. After install run the following commands: 85 | # 1. pip uninstall protobuf 86 | # 2. 
conda install protobuf 87 | # This should fix the installation 88 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Development files 2 | propogate_from_main.sh 3 | demos/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | *$py.class 9 | 10 | # C extensions 11 | *.so 12 | 13 | # Distribution / packaging 14 | .Python 15 | build/ 16 | develop-eggs/ 17 | dist/ 18 | downloads/ 19 | eggs/ 20 | .eggs/ 21 | lib/ 22 | lib64/ 23 | parts/ 24 | sdist/ 25 | var/ 26 | wheels/ 27 | pip-wheel-metadata/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | 58 | # Translations 59 | *.mo 60 | *.pot 61 | 62 | # Django stuff: 63 | *.log 64 | local_settings.py 65 | db.sqlite3 66 | db.sqlite3-journal 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | -------------------------------------------------------------------------------- /configs/examples/example.yaml: -------------------------------------------------------------------------------- 1 | # Example Config that uses almost all values 2 | 3 | alg: AlgorithmName from algs/__init__.py 4 | alg_kwargs: 5 | kwarg_1: value 6 | # More keyword arguments for the algorithm... 7 | 8 | optim: OptimizerName from torch.optim 9 | optim_kwargs: 10 | lr: 0.001 11 | weight_decay: 0.05 12 | # More key word arguments for the optimizer... 
13 | 14 | network: NetworkName from networks/__init__.py 15 | network_kwargs: 16 | hidden_layers: [256, 256] 17 | act: ["import", "torch.nn", "Tanh"] # A demonstration of how to import a function 18 | # More key word arguments for the network 19 | 20 | # If you are running supervised learning, the env is likely Empty and can just be used to specify input/output spaces. 21 | env: EnvironmentName from envs/__init__.py 22 | env_kwargs: 23 | kwarg_1: value 24 | # More key word arguments for the environment... 25 | 26 | dataset: DatasetName from datasets/__init__.py 27 | dataset_kwargs: 28 | 29 | validation_dataset_kwargs: 30 | # If you want a validation dataset, specify the kwargs here 31 | # If none are specified, there will be no validation dataset. 32 | 33 | processor: ProcessorName from processors/__init__.py or null 34 | # Note that unlike other configuration types, the processor is unnecesary. 35 | processor_kwargs: 36 | kwarg_1: value 37 | # More key word arguments for the processor 38 | 39 | schedule: linear_decay # Schedule function from utils/schedules.py can be null. 40 | schedule_kwargs: 41 | # if a scheduler is specified, specify its kwargs here. 42 | # total_steps is alwasy passed into it as the first argument. 43 | 44 | checkpoint: null # A checkpoint to initialize the network from. 45 | 46 | trainer_kwargs: # Arguments given to Algorithm.train 47 | total_steps: 10000 # The total number of steps to train 48 | log_freq: 25 # How often to log values 49 | profile_freq: 10 # How often to time different componetns 50 | eval_freq: 500 # How often to run evals 51 | max_validation_steps: 100 # Maximum number of steps from the validation dataset, if included 52 | loss_metric: loss # The validation metric that determines when to save the "best_checkpoint" 53 | eval_fn: eval_policy # evaluation function to run 54 | eval_kwargs: 55 | num_ep: 10 # Evaluation kwargs 56 | train_dataloader_kwargs: 57 | workers: 2 58 | batch_size: 64 59 | validation_dataloader_kwargs: 60 | workers: 0 61 | benchmark: False # whether or not to enable torch.cuddn.benchmark 62 | torch_compile: False # wether or not to use torch compile. Currently exhibits bugs -- waiting for real release. 63 | torch_compile_kwargs: 64 | mode: null # set torch compile key word args. 65 | 66 | seed: null # For manually setting the seed. 67 | -------------------------------------------------------------------------------- /research/envs/base.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import numpy as np 3 | 4 | 5 | def _get_space(low=None, high=None, shape=None, dtype=None): 6 | all_vars = [low, high, shape, dtype] 7 | if any([isinstance(v, dict) for v in all_vars]): 8 | all_keys = set() # get all the keys 9 | for v in all_vars: 10 | if isinstance(v, dict): 11 | all_keys.update(v.keys()) 12 | # Construct all the sets 13 | spaces = {} 14 | for k in all_keys: 15 | space_low = low.get(k, None) if isinstance(low, dict) else low 16 | space_high = high.get(k, None) if isinstance(high, dict) else high 17 | space_shape = shape.get(k, None) if isinstance(shape, dict) else shape 18 | space_type = dtype.get(k, None) if isinstance(dtype, dict) else dtype 19 | spaces[k] = _get_space(space_low, space_high, space_shape, space_type) 20 | # Construct the gym dict space 21 | return gym.spaces.Dict(**spaces) 22 | 23 | if shape is None and isinstance(high, int): 24 | assert low is None, "Tried to specify a discrete space with both high and low." 
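# A bare integer `high` with no shape is interpreted as the number of values in a discrete space.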
25 | return gym.spaces.Discrete(high) 26 | 27 | # Otherwise assume its a box. 28 | if low is None: 29 | low = -np.inf 30 | if high is None: 31 | high = np.inf 32 | if dtype is None: 33 | dtype = np.float32 34 | return gym.spaces.Box(low=low, high=high, shape=shape, dtype=dtype) 35 | 36 | 37 | class EmptyEnv(gym.Env): 38 | 39 | """ 40 | An empty holder for defining supervised learning problems 41 | It works by specifying the ranges and shapes. 42 | """ 43 | 44 | def __init__( 45 | self, 46 | observation_low=None, 47 | observation_high=None, 48 | observation_shape=None, 49 | observation_dtype=np.float32, 50 | observation_space=None, 51 | action_low=None, 52 | action_high=None, 53 | action_shape=None, 54 | action_dtype=np.float32, 55 | action_space=None, 56 | ): 57 | if observation_space is not None: 58 | self.observation_space = observation_space 59 | else: 60 | self.observation_space = _get_space(observation_low, observation_high, observation_shape, observation_dtype) 61 | if action_space is not None: 62 | self.action_space = action_space 63 | else: 64 | self.action_space = _get_space(action_low, action_high, action_shape, action_dtype) 65 | 66 | def step(self, action): 67 | raise NotImplementedError("Empty Env does not have step") 68 | 69 | def reset(self, **kwargs): 70 | raise NotImplementedError("Empty Env does not have reset") 71 | -------------------------------------------------------------------------------- /scripts/compute_action_normalization.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | 5 | import research 6 | from research.utils.config import Config 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--config", type=str, required=True, help="Path to the config.") 11 | parser.add_argument("--clip", type=float, default=4.0, help="std-dev clipping for min-max normalization") 12 | args = parser.parse_args() 13 | 14 | config = Config.load(args.config) 15 | config = config.parse() 16 | dataset_class = None if config["dataset"] is None else vars(research.datasets)[config["dataset"]] 17 | dataset_kwargs = config["dataset_kwargs"] 18 | assert issubclass( 19 | dataset_class, research.datasets.ReplayBuffer 20 | ), "Must use replay buffer for normalization computation" 21 | dataset_kwargs["distributed"] = False # Ensure that we load all of the data. 22 | observation_space, action_space = config.get_spaces() 23 | 24 | # Exclude all observations for faster loading 25 | exclude_keys = list(dataset_kwargs.get("exclude_keys", [])) 26 | exclude_keys.extend(["obs.*", "reward", "discount"]) # Cannot remove done! 27 | dataset_kwargs["exclude_keys"] = exclude_keys 28 | 29 | # Create the dataset, exclude everything but actions so we don't load it 30 | dataset = dataset_class(observation_space, action_space, **dataset_kwargs) 31 | 32 | # Loop through the dataset to get all the actions, ignoring the dummy ones. 33 | all_actions = [] 34 | storage = dataset._storage 35 | # NOTE: we add one to starts for the offset :) 36 | starts, ends = storage.starts + 1, storage.ends 37 | for start, end in zip(starts, ends): 38 | actions = storage["action"][start:end] 39 | all_actions.append(actions) 40 | 41 | all_actions = np.concatenate(all_actions, axis=0) 42 | 43 | # Compute all normalization possibilities. 
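# (These statistics are presumably meant to be copied into a MinMaxActionNormalizer or
#  GaussianActionNormalizer processor config; see research/processors/normalization.py.)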
44 | 45 | action_min, action_max = np.min(all_actions, axis=0), np.max(all_actions, axis=0) 46 | 47 | print("Low: ", action_min) 48 | print("High: ", action_max) 49 | 50 | action_mean, action_std = np.mean(all_actions, axis=0), np.std(all_actions, axis=0) 51 | 52 | print("Mean: ", action_mean) 53 | print("Std: ", action_std) 54 | 55 | # Compute the min / max after clipping. 56 | gaussian_normalized_actions = (all_actions - action_mean) / action_std 57 | # now clip everything 58 | gaussian_normalized_actions = np.clip(gaussian_normalized_actions, a_min=-args.clip, a_max=args.clip) 59 | # now re-normalize everthing 60 | clipped_low, clipped_high = np.min(gaussian_normalized_actions, axis=0), np.max(gaussian_normalized_actions, axis=0) 61 | 62 | print("Clipped Low: ", clipped_low) 63 | print("Clipped High: ", clipped_high) 64 | -------------------------------------------------------------------------------- /configs/examples/diffusion_policy.yaml: -------------------------------------------------------------------------------- 1 | alg: DiffusionPolicy 2 | alg_kwargs: 3 | # Configure offline steps. These aren't needed, but good to set. 4 | offline_steps: -1 5 | random_steps: 0 6 | noise_scheduler: ["import", "diffusers.schedulers.scheduling_ddim", "DDIMScheduler"] 7 | noise_scheduler_kwargs: 8 | num_train_timesteps: 100 9 | beta_start: 0.0001 10 | beta_end: 0.02 11 | beta_schedule: squaredcos_cap_v2 12 | clip_sample: True 13 | set_alpha_to_one: True 14 | steps_offset: 0 15 | prediction_type: epsilon 16 | num_inference_steps: 20 17 | horizon: 16 18 | 19 | optim: AdamW 20 | optim_kwargs: 21 | lr: 0.0001 22 | betas: 23 | - 0.95 24 | - 0.999 25 | eps: 1.0e-08 26 | weight_decay: 1.0e-06 27 | 28 | network: ActorPolicy 29 | network_kwargs: 30 | encoder_class: MultiEncoder 31 | encoder_kwargs: 32 | agentview_image_class: RobomimicEncoder 33 | agentview_image_kwargs: 34 | backbone: 18 35 | feature_dim: 64 36 | use_group_norm: True 37 | num_kp: 64 38 | robot0_eye_in_hand_image_class: RobomimicEncoder 39 | robot0_eye_in_hand_image_kwargs: 40 | backbone: 18 41 | feature_dim: 64 42 | use_group_norm: True 43 | num_kp: 64 44 | robot0_eef_pos_class: ["import", "torch.nn", "Identity"] 45 | robot0_eef_quat_class: ["import", "torch.nn", "Identity"] 46 | robot0_eef_vel_lin_class: ["import", "torch.nn", "Identity"] 47 | robot0_eef_vel_ang_class: ["import", "torch.nn", "Identity"] 48 | robot0_gripper_qpos_class: ["import", "torch.nn", "Identity"] 49 | 50 | actor_class: ConditionalUnet1D 51 | actor_kwargs: 52 | diffusion_step_embed_dim: 128 53 | down_dims: [256, 512, 1024] # This is [512, 1024, 2048] 54 | kernel_size: 5 55 | n_groups: 8 56 | 57 | # Example config using Robomimic. 
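# (The hdf5 paths below are placeholders; channels_first: True keeps images in CHW order to match the image encoders.)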
58 | eval_env: RobomimicEnv 59 | eval_env_kwargs: 60 | path: path/to/robomimic/can/ph/image.hdf5 61 | horizon: 400 62 | channels_first: True 63 | 64 | dataset: RobomimicDataset 65 | dataset_kwargs: 66 | path: path/to/robomimic/can/ph/image.hdf5 67 | sample_fn: sample 68 | sample_kwargs: 69 | batch_size: 64 70 | seq: 16 71 | pad: 8 72 | seq_keys: ["action"] 73 | distributed: True 74 | validation_dataset_kwargs: 75 | train: False 76 | 77 | schedule: cosine_with_linear_warmup 78 | schedule_kwargs: 79 | warmup_steps: 2000 80 | total_steps: 500000 81 | 82 | processor: RandomCrop 83 | 84 | trainer_kwargs: 85 | total_steps: 500000 86 | log_freq: 250 # How often to log values 87 | profile_freq: 250 88 | eval_freq: 10000 # How often to run evals 89 | eval_fn: null 90 | loss_metric: loss # The validation metric that determines when to save the "best_checkpoint" 91 | max_validation_steps: 20 # Will run forever otherwise due to continuous replay buffer iter. 92 | train_dataloader_kwargs: 93 | num_workers: 4 # Number of dataloader workers. 94 | batch_size: null 95 | collate_fn: null 96 | -------------------------------------------------------------------------------- /scripts/plot.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from matplotlib import pyplot as plt 5 | 6 | from research.utils import plotter 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument("--output", "-o", type=str, default="plot.png", help="Path of output plot") 11 | parser.add_argument("--path", "-p", nargs="+", type=str, required=True, help="Paths of runs to plot") 12 | parser.add_argument( 13 | "--legend", "-l", nargs="+", type=str, required=False, help="Names of each run to display in the legend" 14 | ) 15 | parser.add_argument("--title", "-t", type=str, required=False, help="Plot title") 16 | parser.add_argument("--window", "-w", type=int, default=1, help="Moving window averaging parameter.") 17 | parser.add_argument("--x", "-x", type=str, default="step", help="X value to plot") 18 | parser.add_argument("--max-x", "-m", type=int, default=None, help="Max x value to plot") 19 | parser.add_argument("--xlabel", "-xl", type=str, default=None, help="X label to display on the plot") 20 | parser.add_argument("--y", "-y", type=str, nargs="+", default=["eval/loss"], help="Y value(s) to plot") 21 | parser.add_argument("--ylabel", "-yl", type=str, default=None, help="Y label to display on the plot") 22 | parser.add_argument("--fig-size", "-f", nargs=2, type=int, default=(3, 2)) 23 | args = parser.parse_args() 24 | 25 | paths = args.path 26 | 27 | if len(paths) == 1 and paths[0].endswith(".yaml"): 28 | # We are creating a plot via config 29 | plotter.plot_from_config(paths[0]) 30 | plt.savefig(args.output, dpi=300) # Increase DPI for higher res. 31 | else: 32 | # Check to see if we should auto-expand the path. 
33 | # Do this only if the number of paths specified is one and each sub-path is a directory 34 | if len(paths) == 1 and all([os.path.isdir(os.path.join(paths[0], d)) for d in os.listdir(paths[0])]): 35 | paths = sorted([os.path.join(paths[0], d) for d in os.listdir(paths[0])]) 36 | # Now create the labels 37 | labels = args.legend 38 | if labels is None: 39 | labels = [os.path.basename(path[:-1] if path.endswith("/") else path) for path in paths] 40 | # Sort the paths alphabetically by the labels 41 | paths, labels = zip(*sorted(zip(paths, labels), key=lambda x: x[0])) # Alphabetically sort by filename 42 | 43 | plotter.create_plot( 44 | paths, 45 | labels, 46 | title=args.title, 47 | xlabel=args.xlabel, 48 | ylabel=args.ylabel, 49 | x_key=args.x, 50 | y_keys=args.y, 51 | window_size=args.window, 52 | max_x_value=args.max_x, 53 | ) 54 | 55 | # Save the plot 56 | print("[research] Saving plot to", args.output) 57 | plt.gcf().set_size_inches(*args.fig_size) 58 | plt.tight_layout(pad=0) 59 | plt.savefig(args.output, dpi=300) # Increase DPI for higher res. 60 | -------------------------------------------------------------------------------- /research/datasets/robomimic_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Optional 3 | 4 | import h5py 5 | import numpy as np 6 | import torch 7 | 8 | from research.utils import utils 9 | 10 | from .replay_buffer.buffer import ReplayBuffer 11 | 12 | 13 | class RobomimicDataset(ReplayBuffer): 14 | """ 15 | Simple Class that writes the data from the RoboMimicDatasets into a ReplayBuffer 16 | """ 17 | 18 | def __init__( 19 | self, observation_space, action_space, *args, action_eps: Optional[float] = 1e-5, train=True, **kwargs 20 | ): 21 | self.action_eps = action_eps 22 | self.train = train 23 | self.channels_first_keys = [] 24 | for k in observation_space.keys(): 25 | if "image" in k and observation_space[k].shape[0] == 3: 26 | self.channels_first_keys.append(k) 27 | super().__init__(observation_space, action_space, *args, **kwargs) 28 | 29 | def _data_generator(self): 30 | # Compute the worker info 31 | worker_info = torch.utils.data.get_worker_info() 32 | num_workers = 1 if worker_info is None else worker_info.num_workers 33 | worker_id = 0 if worker_info is None else worker_info.id 34 | 35 | f = h5py.File(self.path, "r") 36 | 37 | if self.train: 38 | # Extract the training demonstrations 39 | demos = [elem.decode("utf-8") for elem in np.array(f["mask/train"][:])] 40 | else: 41 | # Extract the validation 42 | demos = [elem.decode("utf-8") for elem in np.array(f["mask/valid"][:])] 43 | 44 | # Assign demos to each worker 45 | demos = sorted(demos) # Deterministic ordering 46 | demos = demos[worker_id::num_workers] 47 | # Shuffle the data ordering 48 | random.shuffle(demos) 49 | 50 | for _i, demo in enumerate(demos): 51 | # Get obs from the start to the end. 
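# (Each demo contributes T transitions; the final next_obs is appended so the stored episode has T+1
#  observations, aligned with the dummy-prefixed action/reward/done arrays constructed below.)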
52 | obs = utils.get_from_batch(f["data"][demo]["obs"], 0, len(f["data"][demo]["dones"])) 53 | last_obs = utils.unsqueeze(utils.get_from_batch(f["data"][demo]["next_obs"], -1), 0) 54 | obs = utils.concatenate(obs, last_obs) 55 | obs = utils.remove_float64(obs) 56 | 57 | # Flip images if needed 58 | for k in self.channels_first_keys: 59 | obs[k] = np.transpose(obs[k], (0, 3, 1, 2)) 60 | 61 | dummy_action = np.expand_dims(self.dummy_action, axis=0) 62 | action = np.concatenate((dummy_action, f["data"][demo]["actions"]), axis=0) 63 | action = utils.remove_float64(action) 64 | 65 | if self.action_eps is not None: 66 | lim = 1 - self.action_eps 67 | action = np.clip(action, -lim, lim) 68 | 69 | reward = np.concatenate(([0], f["data"][demo]["rewards"]), axis=0) 70 | reward = utils.remove_float64(reward) 71 | 72 | done = np.concatenate(([0], f["data"][demo]["dones"]), axis=0).astype(np.bool_) 73 | done[-1] = True 74 | 75 | discount = (1 - done).astype(np.float32) 76 | 77 | obs_len = obs[next(iter(obs.keys()))].shape[0] 78 | assert all([len(obs[k]) == obs_len for k in obs.keys()]) 79 | assert obs_len == len(action) == len(reward) == len(done) == len(discount) 80 | 81 | yield dict(obs=obs, action=action, reward=reward, done=done, discount=discount) 82 | 83 | f.close() # Close the file handler. 84 | -------------------------------------------------------------------------------- /research/algs/offline/bc.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import gym 4 | import numpy as np 5 | import torch 6 | 7 | from research.utils import utils 8 | 9 | from ..off_policy_algorithm import OffPolicyAlgorithm 10 | 11 | IGNORE_INDEX = -100 12 | 13 | 14 | class BehaviorCloning(OffPolicyAlgorithm): 15 | """ 16 | BC Implementation. 17 | Uses MSE loss for continuous, and CE for discrete 18 | Supports arbitrary obs -> action networks or ActorPolicy ModuleContainers. 19 | """ 20 | 21 | def __init__(self, *args, grad_norm_clip: Optional[float] = None, **kwargs) -> None: 22 | super().__init__(*args, **kwargs) 23 | self.grad_norm_clip = grad_norm_clip 24 | 25 | def setup_optimizers(self) -> None: 26 | # create optim groups. Any parameters that is 2D or higher will be weight decayed, otherwise no. 27 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 28 | groups = utils.create_optim_groups(self.network.parameters(), self.optim_kwargs) 29 | self.optim["network"] = self.optim_class(groups) 30 | 31 | def _compute_loss(self, batch: Dict): 32 | dist = self.network(batch["obs"]) 33 | 34 | if isinstance(dist, torch.distributions.Distribution): 35 | loss = -dist.log_prob(batch["action"]) # NLL Loss 36 | elif torch.is_tensor(dist) and isinstance(self.processor.action_space, gym.spaces.Box): 37 | loss = torch.nn.functional.mse_loss(dist, batch["action"], reduction="none") # MSE Loss 38 | elif torch.is_tensor(dist) and isinstance(self.processor.action_space, gym.spaces.Discrete): 39 | loss = torch.nn.functional.cross_entropy(dist, batch["action"], ignore_index=IGNORE_INDEX, reduction="none") 40 | else: 41 | raise ValueError("Invalid Policy output") 42 | 43 | # Aggregate the losses 44 | if "mask" in batch: 45 | assert batch["mask"].shape == loss.shape 46 | mask = (1 - batch["mask"]).float() 47 | loss = mask * loss 48 | size = mask.sum() # how many elements we train on. 
49 | else: 50 | size = loss.numel() 51 | 52 | loss = loss.sum() / size 53 | return loss 54 | 55 | def train_step(self, batch: Dict, step: int, total_steps: int) -> Dict: 56 | """ 57 | Overriding the Algorithm BaseClass Method train_step. 58 | Returns a dictionary of training metrics. 59 | """ 60 | self.optim["network"].zero_grad(set_to_none=True) 61 | loss = self._compute_loss(batch) 62 | loss.backward() 63 | if self.grad_norm_clip is not None: 64 | torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.grad_norm_clip) 65 | self.optim["network"].step() 66 | metrics = dict(loss=loss.item()) 67 | return metrics 68 | 69 | def validation_step(self, batch: Any) -> Dict: 70 | """ 71 | Overriding the Algorithm BaseClass Method validation_step. 72 | Returns a dictionary of validation metrics. 73 | """ 74 | with torch.no_grad(): 75 | loss = self._compute_loss(batch) 76 | 77 | return dict(loss=loss.item()) 78 | 79 | def _get_train_action(self, obs: Any, step: int, total_steps: int) -> np.ndarray: 80 | batch = dict(obs=obs) 81 | with torch.no_grad(): 82 | action = self.predict(batch, is_batched=False, sample=True) 83 | return action 84 | -------------------------------------------------------------------------------- /research/datasets/wgcsl_dataset.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import pickle 4 | from typing import List, Optional, Tuple, Union 5 | 6 | import gym 7 | import numpy as np 8 | import torch 9 | 10 | from research.utils.utils import remove_float64 11 | 12 | from .replay_buffer.buffer import ReplayBuffer 13 | 14 | 15 | class WGCSLDataset(ReplayBuffer): 16 | """ 17 | Simple Class that writes the data from the WGCSL buffers into a HindsightReplayBuffer 18 | """ 19 | 20 | def __init__( 21 | self, 22 | observation_space: gym.Space, 23 | action_space: gym.Space, 24 | path: Union[str, Tuple[str]] = (), 25 | percents: Optional[List[float]] = None, 26 | train: bool = True, 27 | terminal_threshold: Optional[float] = None, 28 | **kwargs, 29 | ): 30 | assert path is not None 31 | if isinstance(path, str): 32 | path = [path] 33 | percents = [1.0] * len(path) if percents is None else percents 34 | self.percents = percents 35 | self.train = train 36 | self.terminal_threshold = terminal_threshold 37 | super().__init__(observation_space, action_space, path=path, **kwargs) 38 | assert not self.distributed, "WGCSL datasets do not support distributed training." 
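# Sketch of the pickle layout the generator below assumes (field names from the WGCSL buffers,
# shapes illustrative): data["ag"] and data["o"] hold T + 1 per-state entries per episode, while
# data["g"] and data["u"] hold T per-action entries; a dummy action/reward/done is prepended so
# that every array ends up with T + 1 aligned entries.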
39 | 40 | def _data_generator(self): 41 | for path, percent in zip(self.path, self.percents): 42 | with open(path, "rb") as f: 43 | data = pickle.load(f) 44 | num_ep = data["ag"].shape[0] 45 | # Add the episodes 46 | ep_idxs = range(int(num_ep * percent)) if self.train else range(num_ep - int(num_ep * percent), num_ep) 47 | for i in ep_idxs: 48 | # We need to make sure we appropriately handle the dummy transition 49 | obs = dict(achieved_goal=data["ag"][i].copy()) 50 | if "o" in data: 51 | obs["observation"] = data["o"][i].copy() 52 | if "g" in data: 53 | goal = data["g"][i] 54 | obs["desired_goal"] = np.concatenate((goal[:1], goal), axis=0) 55 | obs = remove_float64(obs) 56 | dummy_action = np.expand_dims(self.dummy_action, axis=0) 57 | action = np.concatenate((dummy_action, data["u"][i]), axis=0) 58 | action = remove_float64(action) 59 | 60 | # If we have a terminal threshold compute and store the horizon 61 | if self.terminal_threshold is not None: 62 | goal_distance = np.linalg.norm(obs["desired_goal"] - obs["achieved_goal"], axis=-1) 63 | done = (goal_distance < self.terminal_threshold).astype(np.bool_) 64 | else: 65 | done = np.zeros(action.shape[0], dtype=np.bool_) 66 | done[-1] = True # Add the episode delineation 67 | discount = np.ones(action.shape[0]) # Gets recomputed with HER 68 | reward = np.zeros(action.shape[0]) # Gets recomputed with HER 69 | assert len(obs["achieved_goal"]) == len(action) == len(reward) == len(done) == len(discount) 70 | yield dict(obs=obs, action=action, reward=reward, done=done, discount=discount) 71 | 72 | # Explicitly delete the data objects to save memory 73 | del data 74 | del obs 75 | del action 76 | del reward 77 | del done 78 | del discount 79 | gc.collect() 80 | -------------------------------------------------------------------------------- /tools/parse_sweep.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import collections 3 | import os 4 | 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from research.utils.plotter import LOG_FILE_NAME, moving_avg 9 | 10 | 11 | def get_score(path, y_key, window=1, use_max=False): 12 | # Get all of the seeds 13 | if LOG_FILE_NAME in os.listdir(path): 14 | paths = [path] 15 | else: 16 | assert all(["seed-" in p for p in os.listdir(path)]) 17 | paths = [os.path.join(path, seed) for seed in os.listdir(path)] 18 | values = collections.defaultdict(list) 19 | for p in paths: 20 | df = pd.read_csv(os.path.join(p, LOG_FILE_NAME)) 21 | if y_key not in df: 22 | print("[tools] Error: key", y_key, "not in", p) 23 | x, y = moving_avg(df["step"].to_numpy(), df[y_key].to_numpy(), window) 24 | for i in range(len(x)): 25 | values[x[i]].append(y[i]) 26 | # Compute the final values by averaging 27 | values = {k: np.mean(v) for k, v in values.items()} 28 | if use_max: 29 | return max(values.values()) 30 | else: 31 | return min(values.values()) 32 | 33 | 34 | def get_params(path): 35 | # Separate out the hyperparamters 36 | name = os.path.basename(path) 37 | parts = name.split("_") 38 | params = {} 39 | for part in parts: 40 | split_part = part.split("-") 41 | name, value = split_part[0], "-".join(split_part[1:]) 42 | params[name] = value 43 | return params 44 | 45 | 46 | def get_paths(path): 47 | files = os.listdir(path) 48 | if LOG_FILE_NAME in files or any(["seed" in f for f in files]): 49 | return [path] 50 | else: 51 | return sum([get_paths(os.path.join(path, f)) for f in files], start=[]) 52 | 53 | 54 | if __name__ == "__main__": 55 | # Do something 56 | 
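# Example invocation (hypothetical paths and keys):
#   python tools/parse_sweep.py --path sweeps/my_sweep --y-key eval/reward --window 5 --use-max
# Run directories are expected to encode hyperparameters in their names, e.g. "lr-0.0003_gamma-0.99",
# which get_params() above turns into {"lr": "0.0003", "gamma": "0.99"}.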
parser = argparse.ArgumentParser() 57 | parser.add_argument("--path", type=str, help="Path to the run to handle.") 58 | parser.add_argument("--y-key", type=str, default="step", help="name of string to parse") 59 | parser.add_argument("--window", type=int, default=1, help="averaging window") 60 | parser.add_argument("--use-max", action="store_true", help="if we should return the max for each metric") 61 | parser.add_argument("--include", nargs="+", type=str, default=[], help="Include filters") 62 | parser.add_argument("--exclude", nargs="+", type=str, default=[], help="Exclude filters") 63 | args = parser.parse_args() 64 | 65 | paths = get_paths(args.path) 66 | # Run the path filters 67 | for include in args.include: 68 | paths = [path for path in paths if include in path] 69 | for exclude in args.exclude: 70 | paths = [path for path in paths if exclude not in path] 71 | params_list = [get_params(path) for path in paths] 72 | scores = [get_score(path, args.y_key, window=args.window, use_max=args.use_max) for path in paths] 73 | 74 | # Get all hyperparameter configurations 75 | hyperparameters = collections.defaultdict(set) 76 | for params in params_list: 77 | for name, value in params.items(): 78 | hyperparameters[name].add(value) 79 | 80 | # For each hyperparameter, construct a report for its values averaged over scores 81 | for param, values in hyperparameters.items(): 82 | print("[Sweep Parser] Ablating parameter", param) 83 | for value in sorted(values): 84 | avg_score = np.mean( 85 | [scores[i] for i in range(len(scores)) if param in params_list[i] and params_list[i][param] == value] 86 | ) 87 | print(value, ":", avg_score) 88 | -------------------------------------------------------------------------------- /scripts/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | from research.utils.config import Config 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument("--checkpoint", type=str, required=True, help="Path to the model") 9 | parser.add_argument("--path", type=str, default=None, required=False, help="Path to save the gif") 10 | parser.add_argument("--device", "-d", type=str, default="auto") 11 | parser.add_argument("--num-ep", type=int, default=1, help="Number of episodes") 12 | parser.add_argument("--num-gifs", type=int, default=0, help="Number of gifs to save.") 13 | parser.add_argument("--every-n-frames", type=int, default=2, help="Save every n frames to the gif.") 14 | parser.add_argument("--width", type=int, default=160, help="Width of image") 15 | parser.add_argument("--height", type=int, default=120, help="Height of image") 16 | parser.add_argument("--strict", action="store_true", default=False, help="Strict") 17 | parser.add_argument( 18 | "--terminate-on-success", action="store_true", default=False, help="Terminate gif on success condition." 19 | ) 20 | parser.add_argument( 21 | "--override", 22 | metavar="KEY=VALUE", 23 | nargs="+", 24 | default=[], 25 | help="Set kv pairs used as args for the entry point script.", 26 | ) 27 | parser.add_argument("--max-len", type=int, default=1000, help="maximum length of an episode.") 28 | args = parser.parse_args() 29 | 30 | assert args.checkpoint.endswith(".pt"), "Must provide a model checkpoint" 31 | config = Config.load(os.path.dirname(args.checkpoint)) 32 | config["checkpoint"] = None # Set checkpoint to None, we don't actually need to load it. 
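# Example invocation (paths are hypothetical):
#   python scripts/evaluate.py --checkpoint path/to/run/best_model.pt --num-ep 20 --num-gifs 2 \
#       --override eval_env_kwargs.path=path/to/another/dataset.hdf5
# Each --override entry is a dotted KEY=VALUE path into the loaded config and is applied below.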
33 | 34 | if args.path is None: 35 | args.path = os.path.dirname(args.checkpoint) 36 | 37 | # Print the overrides 38 | print("Overrides:") 39 | for override in args.override: 40 | print(override) 41 | 42 | # Apply the overrides to the loaded config 43 | for override in args.override: 44 | items = override.split("=") 45 | key, value = items[0].strip(), "=".join(items[1:]) 46 | # Progress down the config path (separated by '.') until we reach the final value to override. 47 | config_path = key.split(".") 48 | config_dict = config 49 | while len(config_path) > 1: 50 | config_dict = config_dict[config_path[0]] 51 | config_path.pop(0) 52 | config_dict[config_path[0]] = value 53 | 54 | if len(args.override) > 0: 55 | print(config) 56 | 57 | # Make sure we don't use subprocess evaluation 58 | config["trainer_kwargs"]["eval_env_runner"] = None 59 | 60 | # Overwrite the parameters in the eval_kwargs 61 | assert config["trainer_kwargs"]["eval_fn"] == "eval_policy", "Evaluate only works with eval_policy for now." 62 | config["trainer_kwargs"]["eval_kwargs"]["num_ep"] = args.num_ep 63 | config["trainer_kwargs"]["eval_kwargs"]["num_gifs"] = args.num_gifs 64 | config["trainer_kwargs"]["eval_kwargs"]["width"] = args.width 65 | config["trainer_kwargs"]["eval_kwargs"]["height"] = args.height 66 | config["trainer_kwargs"]["eval_kwargs"]["every_n_frames"] = args.every_n_frames 67 | config["trainer_kwargs"]["eval_kwargs"]["terminate_on_success"] = args.terminate_on_success 68 | config = config.parse() 69 | model = config.get_model(device=args.device) 70 | metadata = model.load(args.checkpoint) 71 | trainer = config.get_trainer(model=model) 72 | # Run the evaluation loop 73 | os.makedirs(args.path, exist_ok=True) 74 | metrics = trainer.evaluate(args.path, metadata["step"]) 75 | 76 | print("[research] Eval policy finished:") 77 | for k, v in metrics.items(): 78 | print(k, v) 79 | -------------------------------------------------------------------------------- /research/processors/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Processors are designed as ways of manipulating entire batches of tensors at once to prepare them for the network. 3 | Examples are as follows: 4 | 1. Normalization 5 | 2. Image Augmentations applied on the entire batch at once. 6 | """ 7 | from typing import Any, Dict, List, Optional, Tuple 8 | 9 | import gym 10 | import torch 11 | 12 | import research 13 | 14 | 15 | class Processor(torch.nn.Module): 16 | """ 17 | This is the base processor class. All processors should inherit from it. 18 | """ 19 | 20 | def __init__(self, observation_space: gym.Space, action_space: gym.Space): 21 | super().__init__() 22 | self.training = True 23 | self._observation_space = observation_space 24 | self._action_space = action_space 25 | 26 | def unprocess(self, batch: Any) -> Any: 27 | raise NotImplementedError 28 | 29 | @property 30 | def supports_gpu(self): 31 | return True 32 | 33 | @property 34 | def observation_space(self): 35 | """ 36 | Outputs the observation space for the network. 37 | Can be overridden if the processor changes this space. 38 | """ 39 | return self._observation_space 40 | 41 | @property 42 | def action_space(self): 43 | """ 44 | Outputs the action space for the network. 45 | Can be overridden if the processor changes this space.
46 | """ 47 | return self._action_space 48 | 49 | 50 | class Identity(Processor): 51 | """ 52 | This processor just performs the identity operation 53 | """ 54 | 55 | def forward(self, batch: Any) -> Any: 56 | return batch 57 | 58 | def unprocess(self, batch: Any) -> Any: 59 | return batch 60 | 61 | 62 | class Compose(Processor): 63 | """ 64 | This Processor Composes multiple processors 65 | """ 66 | 67 | def __init__( 68 | self, 69 | observation_space: gym.Space, 70 | action_space: gym.Space, 71 | processors: List[Tuple[str, Optional[Dict]]] = (("Identity", None),), 72 | ): 73 | super().__init__(observation_space, action_space) 74 | created_processors = [] 75 | current_observation_space, current_action_space = observation_space, action_space 76 | for processor_class, processor_kwargs in processors: 77 | processor_class = vars(research.processors)[processor_class] 78 | processor_kwargs = {} if processor_kwargs is None else processor_kwargs 79 | processor = processor_class(current_observation_space, current_action_space, **processor_kwargs) 80 | created_processors.append(processor) 81 | current_observation_space, current_action_space = processor.observation_space, processor.action_space 82 | self.processors = torch.nn.ModuleList(created_processors) 83 | 84 | @property 85 | def observation_space(self): 86 | # Return the space of the last processor 87 | return self.processors[-1].observation_space 88 | 89 | @property 90 | def action_space(self): 91 | # Return the space of the last processor 92 | return self.processors[-1].action_space 93 | 94 | @property 95 | def supports_gpu(self): 96 | return all([processor.supports_gpu for processor in self.processors]) 97 | 98 | def forward(self, batch: Any) -> Any: 99 | for processor in self.processors: 100 | batch = processor(batch) 101 | return batch 102 | 103 | def unprocess(self, batch: Any) -> Any: 104 | for processor in reversed(self.processors): 105 | batch = processor.unprocess(batch) 106 | return batch 107 | -------------------------------------------------------------------------------- /research/envs/robomimic.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import gym 4 | import h5py 5 | import numpy as np 6 | from robomimic.utils import env_utils, file_utils 7 | 8 | 9 | class RobomimicEnv(gym.Env): 10 | def __init__( 11 | self, 12 | path: str, 13 | keys: Optional[List] = None, 14 | terminate_early: bool = True, 15 | horizon: Optional[int] = None, 16 | channels_first: bool = True, 17 | ): 18 | # Get the observation space 19 | f = h5py.File(path, "r") 20 | demo_id = list(f["data"].keys())[0] 21 | demo = f["data/{}".format(demo_id)] 22 | if keys is None: 23 | self.keys = list(demo["obs"].keys()) 24 | else: 25 | self.keys = keys 26 | spaces = {} 27 | self.image_keys = [] 28 | use_image_obs = False 29 | for k in self.keys: 30 | if "image" in k: 31 | use_image_obs = True 32 | obs_modality = demo["obs/{}".format(k)] 33 | if obs_modality.dtype == np.uint8: 34 | low, high = 0, 255 35 | self.image_keys.append(k) 36 | if channels_first: 37 | spaces[k] = gym.spaces.Box( 38 | low=low, 39 | high=high, 40 | shape=(obs_modality.shape[-1], obs_modality.shape[-3], obs_modality.shape[-2]), 41 | dtype=np.uint8, 42 | ) 43 | else: 44 | # just add normally 45 | spaces[k] = gym.spaces.Box(low=low, high=high, shape=obs_modality.shape[1:], dtype=np.uint8) 46 | elif obs_modality.dtype == np.float32 or obs_modality.dtype == np.float64: 47 | low, high = -np.inf, np.inf 48 | dtype = 
np.float32 if obs_modality.dtype == np.float64 else obs_modality.dtype 49 | spaces[k] = gym.spaces.Box(low=low, high=high, shape=obs_modality.shape[1:], dtype=dtype) 50 | else: 51 | raise ValueError("Unsupported dtype in Robomimic Env.") 52 | 53 | self.observation_space = gym.spaces.Dict(spaces) 54 | self.channels_first = channels_first 55 | f.close() 56 | 57 | # Create the environment. 58 | env_meta = file_utils.get_env_metadata_from_dataset(dataset_path=path) 59 | self.env = env_utils.create_env_from_metadata( 60 | env_meta=env_meta, 61 | env_name=env_meta["env_name"], 62 | render=False, 63 | render_offscreen=False, 64 | use_image_obs=use_image_obs, 65 | ).env 66 | self.env.ignore_done = False 67 | if horizon is not None: 68 | self.env.horizon = horizon 69 | self.env._max_episode_steps = self.env.horizon 70 | self.terminate_early = terminate_early 71 | 72 | # Get the action space 73 | low, high = self.env.action_spec 74 | self.action_space = gym.spaces.Box(low, high) 75 | 76 | def _format_obs(self, obs): 77 | if "object-state" in obs: 78 | # Need to duplicate because of robomimic bug. 79 | obs["object"] = obs["object-state"] 80 | obs = {k: obs[k] for k in self.keys} 81 | if self.channels_first: 82 | for k in self.image_keys: 83 | obs[k] = np.transpose(obs[k], (2, 0, 1)) 84 | return obs 85 | 86 | def step(self, action: np.ndarray): 87 | obs, reward, done, info = self.env.step(action) 88 | if self.terminate_early and self.env._check_success(): 89 | done = True 90 | return self._format_obs(obs), reward, done, info 91 | 92 | def reset(self, *args, **kwargs): 93 | return self._format_obs(self.env.reset(*args, **kwargs)) 94 | -------------------------------------------------------------------------------- /research/envs/__init__.py: -------------------------------------------------------------------------------- 1 | # If we want to register environments in gym. 2 | # These will be loaded when we import the research package. 3 | from gym.envs import register 4 | 5 | from .base import EmptyEnv 6 | 7 | # Add things you explicitly want exported here. 8 | # Otherwise, all imports are deleted. 9 | __all__ = ["EmptyEnv"] 10 | 11 | try: 12 | import gym_robotics 13 | except ImportError: 14 | print("[research] skipping gym robotics, package not found.") 15 | 16 | try: 17 | import d4rl 18 | except ImportError: 19 | print("[research] skipping d4rl, package not found.") 20 | 21 | try: 22 | # Register environment classes here 23 | # Register the DM Control environments. 24 | from dm_control import suite 25 | 26 | # Custom DM Control domains can be registered as follows: 27 | # from . 
import <custom_domain_module> 28 | # assert hasattr(<custom_domain_module>, 'SUITE') 29 | # suite._DOMAINS['<custom_domain_name>'] = <custom_domain_module> 30 | 31 | # Register all of the DM control tasks 32 | for domain_name, task_name in suite._get_tasks(tag=None): 33 | # Import state domains 34 | ID = f"{domain_name.capitalize()}{task_name.capitalize()}-v0" 35 | register( 36 | id=ID, 37 | entry_point="research.envs.dm_control:DMControlEnv", 38 | kwargs={ 39 | "domain_name": domain_name, 40 | "task_name": task_name, 41 | "action_minimum": -1.0, 42 | "action_maximum": 1.0, 43 | "action_repeat": 1, 44 | "from_pixels": False, 45 | "flatten": True, 46 | "stack": 1, 47 | }, 48 | ) 49 | 50 | # Import vision domains as specified in DRQ-v2 51 | ID = f"{domain_name.capitalize()}{task_name.capitalize()}-vision-v0" 52 | camera_id = dict(quadruped=2).get(domain_name, 0) 53 | register( 54 | id=ID, 55 | entry_point="research.envs.dm_control:DMControlEnv", 56 | kwargs={ 57 | "domain_name": domain_name, 58 | "task_name": task_name, 59 | "action_repeat": 2, 60 | "action_minimum": -1.0, 61 | "action_maximum": 1.0, 62 | "from_pixels": True, 63 | "height": 84, 64 | "width": 84, 65 | "camera_id": camera_id, 66 | "flatten": False, 67 | "stack": 3, 68 | }, 69 | ) 70 | 71 | # Cleanup extra imports 72 | del suite 73 | except ImportError: 74 | print("[research] Skipping dm_control, package not found.") 75 | 76 | try: 77 | from metaworld.envs.mujoco.env_dict import ALL_V2_ENVIRONMENTS 78 | 79 | # Add the meta world test environments. 80 | # For each one, register the different tasks. 81 | 82 | for env_name, _env_cls in ALL_V2_ENVIRONMENTS.items(): 83 | ID = f"mw_{env_name}" 84 | register(id=ID, entry_point="research.envs.metaworld:MetaWorldSawyerEnv", kwargs={"env_name": env_name}) 85 | id_parts = ID.split("-") 86 | id_parts[-1] = "image-" + id_parts[-1] 87 | ID = "-".join(id_parts) 88 | register(id=ID, entry_point="research.envs.metaworld:get_mw_image_env", kwargs={"env_name": env_name}) 89 | except ImportError: 90 | print("[research] Skipping metaworld, package not found.") 91 | 92 | try: 93 | from .robomimic import RobomimicEnv 94 | except ImportError: 95 | print("[research] Skipping robomimic, package not found.") 96 | 97 | try: 98 | from .franka import FrankaEnv, FrankaReach 99 | except ImportError: 100 | print("[research] Skipping polymetis, package not found.") 101 | 102 | del register 103 | -------------------------------------------------------------------------------- /research/envs/metaworld.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple wrapper for registering metaworld environments 3 | properly with gym. 4 | """ 5 | import gym 6 | import numpy as np 7 | 8 | 9 | class MetaWorldSawyerEnv(gym.Env): 10 | def __init__(self, env_name, seed=True): 11 | from metaworld.envs.mujoco.env_dict import ALL_V2_ENVIRONMENTS 12 | 13 | self._env = ALL_V2_ENVIRONMENTS[env_name]() 14 | self._env._freeze_rand_vec = False 15 | self._env._set_task_called = True 16 | self._seed = seed 17 | if self._seed: 18 | self._env.seed(0) # Seed it at zero for now.
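# Usage sketch: once the research package is imported, the registration loop in envs/__init__.py
# above exposes these wrappers through gym (exact IDs depend on the installed metaworld version), e.g.:
#   env = gym.make("mw_drawer-open-v2")        # state-based Sawyer env
#   env = gym.make("mw_drawer-open-image-v2")  # image observations via get_mw_image_env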
19 | 20 | self.observation_space = self._env.observation_space 21 | self.action_space = self._env.action_space 22 | self._max_episode_steps = self._env.max_path_length 23 | 24 | def seed(self, seed=None): 25 | super().seed(seed=seed) 26 | if self._seed: 27 | self._env.seed(0) 28 | 29 | def evaluate_state(self, state, action): 30 | return self._env.evaluate_state(state, action) 31 | 32 | def step(self, action): 33 | self._episode_steps += 1 34 | obs, reward, done, info = self._env.step(action) 35 | if self._episode_steps == self._max_episode_steps: 36 | done = True 37 | info["discount"] = 1.0 # Ensure infinite boostrap. 38 | # Add the underlying state to the info 39 | state = self._env.sim.get_state() 40 | info["state"] = np.concatenate((state.qpos, state.qvel), axis=0) 41 | return obs.astype(np.float32), reward, done, info 42 | 43 | def set_state(self, state): 44 | qpos, qvel = state[: self._env.model.nq], state[self._env.model.nq :] 45 | self._env.set_state(qpos, qvel) 46 | 47 | def reset(self, **kwargs): 48 | self._episode_steps = 0 49 | return self._env.reset(**kwargs).astype(np.float32) 50 | 51 | def render(self, mode="rgb_array", camera_name="corner2", width=640, height=480): 52 | assert mode == "rgb_array", "Only RGB array is supported" 53 | # stack multiple views 54 | for ctx in self._env.sim.render_contexts: 55 | ctx.opengl_context.make_context_current() 56 | return self._env.render(offscreen=True, camera_name=camera_name, resolution=(width, height)) 57 | 58 | def __getattr__(self, name): 59 | return getattr(self._env, name) 60 | 61 | 62 | class MetaWorldSawyerImageWrapper(gym.Wrapper): 63 | def __init__(self, env, width=84, height=84, camera="corner2", show_goal=False): 64 | assert isinstance( 65 | env.unwrapped, MetaWorldSawyerEnv 66 | ), "MetaWorld Wrapper must be used with a MetaWorldSawyerEnv class" 67 | super().__init__(env) 68 | self._width = width 69 | self._height = height 70 | self._camera = camera 71 | self._show_goal = show_goal 72 | shape = (3, self._height, self._width) 73 | self.observation_space = gym.spaces.Box(low=0, high=255, shape=shape, dtype=np.uint8) 74 | 75 | def _get_image(self): 76 | if not self._show_goal: 77 | try: 78 | self.env.unwrapped._set_pos_site("goal", np.inf * self.env.unwrapped._target_pos) 79 | except ValueError: 80 | pass # If we don't have the goal site, just continue. 81 | img = self.env.render(mode="rgb_array", camera_name=self._camera, width=self._width, height=self._height) 82 | return img.transpose(2, 0, 1) 83 | 84 | def get_state_obs(self): 85 | return self.env.unwrapped._get_obs() 86 | 87 | def step(self, action): 88 | state_obs, reward, done, info = self.env.step(action) 89 | # Throw away the state-based observation. 90 | info["state"] = state_obs 91 | return self._get_image().copy(), reward, done, info 92 | 93 | def reset(self): 94 | # Zoom in camera corner2 to make it better for control 95 | # I found this view to work well across a lot of the tasks. 
96 | camera_name = "corner2" 97 | # Original XYZ is 1.3 -0.2 1.1 98 | index = self.model.camera_name2id(camera_name) 99 | self.model.cam_fovy[index] = 20.0 # FOV 100 | self.model.cam_pos[index][0] = 1.5 # X 101 | self.model.cam_pos[index][1] = -0.35 # Y 102 | self.model.cam_pos[index][2] = 1.1 # Z 103 | 104 | self.env.reset() 105 | return self._get_image().copy() # Return the image observation 106 | 107 | 108 | def get_mw_image_env(env_name): 109 | env = MetaWorldSawyerEnv(env_name) 110 | return MetaWorldSawyerImageWrapper(env) 111 | -------------------------------------------------------------------------------- /research/utils/logger.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import os 3 | from abc import ABC, abstractmethod 4 | from collections.abc import Iterable 5 | from typing import Any 6 | 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | try: 10 | import wandb 11 | except ModuleNotFoundError: 12 | pass 13 | 14 | 15 | class Writer(ABC): 16 | def __init__(self, path, on_eval=False): 17 | self.path = path 18 | self.on_eval = on_eval 19 | self.values = {} 20 | 21 | def record(self, key: str, value: Any) -> None: 22 | self.values[key] = value 23 | 24 | def dump(self, step: int, eval: bool = False) -> None: 25 | if not self.on_eval or eval: 26 | self._dump(step) 27 | 28 | @abstractmethod 29 | def _dump(self, step: int) -> None: 30 | raise NotImplementedError 31 | 32 | @abstractmethod 33 | def close(self) -> None: 34 | raise NotImplementedError 35 | 36 | 37 | class TensorBoardWriter(Writer): 38 | def __init__(self, path, on_eval=False): 39 | super().__init__(path, on_eval=on_eval) 40 | self.writer = SummaryWriter(self.path) 41 | 42 | def _dump(self, step): 43 | for k in self.values.keys(): 44 | self.writer.add_scalar(k, self.values[k], step) 45 | self.writer.flush() 46 | self.values.clear() 47 | 48 | def close(self): 49 | self.writer.close() 50 | 51 | 52 | class CSVWriter(Writer): 53 | def __init__(self, path, on_eval=True): 54 | super().__init__(path, on_eval=on_eval) 55 | self._csv_path = os.path.join(self.path, "log.csv") 56 | self.csv_file_handler = None 57 | self.csv_logger = None 58 | self.num_keys = 0 59 | 60 | # If we are continuing to train, make sure that we know how many keys to expect. 61 | if os.path.exists(self._csv_path): 62 | with open(self._csv_path, "r") as f: 63 | reader = csv.DictReader(f) 64 | fieldnames = reader.fieldnames.copy() 65 | num_keys = len(fieldnames) 66 | if num_keys > self.num_keys: 67 | self.num_keys = num_keys 68 | # Create a new CSV handler with the fieldnames set. 69 | self.csv_file_handler = open(self._csv_path, "a") 70 | self.csv_logger = csv.DictWriter(self.csv_file_handler, fieldnames=list(fieldnames)) 71 | 72 | def _reset_csv_handler(self): 73 | if self.csv_file_handler is not None: 74 | self.csv_file_handler.close() # Close our fds 75 | self.csv_file_handler = open(self._csv_path, "w") # Write a new one 76 | self.csv_logger = csv.DictWriter(self.csv_file_handler, fieldnames=list(self.values.keys())) 77 | self.csv_logger.writeheader() 78 | 79 | def _dump(self, step): 80 | # Record the step 81 | self.values["step"] = step 82 | if len(self.values) < self.num_keys: 83 | # We haven't gotten all keys yet, return without doing anything. 84 | return 85 | if len(self.values) > self.num_keys: 86 | # Got a new key, so re-create the writer 87 | self.num_keys = len(self.values) 88 | # We encountered a new key.
We need to recreate the file handler and overwrite old data 89 | self._reset_csv_handler() 90 | 91 | # We should now have all the keys 92 | self.csv_logger.writerow(self.values) 93 | self.csv_file_handler.flush() 94 | # Note: Don't reset the CSV because the file handler doesn't support it. 95 | 96 | def close(self): 97 | self.csv_file_handler.close() 98 | 99 | 100 | class WandBWriter(Writer): 101 | def __init__(self, path: str, on_eval: bool = True): 102 | super().__init__(path, on_eval=on_eval) 103 | # No extra init steps, just mark eval as True 104 | 105 | def _dump(self, step: int) -> None: 106 | wandb.log(self.values, step=step) 107 | self.values.clear() # reset the values 108 | 109 | def close(self) -> None: 110 | wandb.finish() 111 | 112 | 113 | class Logger(object): 114 | def __init__(self, path: str, writers: Iterable[str] = ("tb", "csv")): 115 | self.writers = [] 116 | for writer in writers: 117 | self.writers.append({"tb": TensorBoardWriter, "csv": CSVWriter, "wandb": WandBWriter}[writer](path)) 118 | 119 | def record(self, key: str, value: Any) -> None: 120 | for writer in self.writers: 121 | writer.record(key, value) 122 | 123 | def dump(self, step: int, eval: bool = False) -> None: 124 | for writer in self.writers: 125 | writer.dump(step, eval=eval) 126 | 127 | def close(self) -> None: 128 | for writer in self.writers: 129 | writer.close() 130 | -------------------------------------------------------------------------------- /research/algs/offline/dp.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | 3 | import diffusers 4 | import numpy as np 5 | import torch 6 | 7 | from research.networks.base import ActorPolicy 8 | from research.utils import utils 9 | 10 | from ..off_policy_algorithm import OffPolicyAlgorithm 11 | 12 | 13 | class DiffusionPolicy(OffPolicyAlgorithm): 14 | """ 15 | Diffusion Policy implementation. 16 | Trains the actor to predict the noise added to ground-truth action sequences, and iteratively 17 | denoises sampled noise into actions at inference time. Requires an ActorPolicy ModuleContainer. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | *args, 23 | noise_scheduler=diffusers.schedulers.DDIMScheduler, 24 | noise_scheduler_kwargs: Optional[Dict] = None, 25 | num_inference_steps: Optional[int] = 10, 26 | horizon: int = 16, 27 | **kwargs, 28 | ) -> None: 29 | super().__init__(*args, **kwargs) 30 | assert isinstance(self.network, ActorPolicy), "Must use an ActorPolicy with DiffusionPolicy" 31 | noise_scheduler_kwargs = {} if noise_scheduler_kwargs is None else noise_scheduler_kwargs 32 | self.noise_scheduler = noise_scheduler(**noise_scheduler_kwargs) 33 | if num_inference_steps is None: 34 | self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps 35 | else: 36 | self.num_inference_steps = num_inference_steps 37 | self.horizon = horizon 38 | 39 | def setup_optimizers(self) -> None: 40 | """ 41 | Decay support added explicitly. Maybe move this to base implementation? 42 | """ 43 | # create optim groups. Any parameters that is 2D or higher will be weight decayed, otherwise no. 44 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 45 | groups = utils.create_optim_groups(self.network.parameters(), self.optim_kwargs) 46 | self.optim["network"] = self.optim_class(groups) 47 | 48 | def _compute_loss(self, batch: Dict) -> torch.Tensor: 49 | obs = self.network.encoder(batch["obs"]) 50 | B, T = batch["action"].shape[:2] 51 | assert T == self.horizon, "Received unexpected temporal dimension."
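# Standard denoising objective: draw a random diffusion timestep for each batch element, corrupt the
# ground-truth action sequence with Gaussian noise through the scheduler, and regress the network's
# noise prediction against the injected noise (masked elements are excluded from the average).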
52 | noise = torch.randn_like(batch["action"]) 53 | timesteps = torch.randint( 54 | low=0, high=self.noise_scheduler.config.num_train_timesteps, size=(B,), device=self.device 55 | ).long() 56 | noisy_actions = self.noise_scheduler.add_noise(batch["action"], noise, timesteps) 57 | 58 | noise_pred = self.network.actor(noisy_actions, timesteps, cond=obs) 59 | loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none").sum(dim=-1) # Sum over action Dim 60 | if "mask" in batch: 61 | mask = (~batch["mask"]).float() 62 | loss = loss * mask 63 | size = mask.sum() 64 | else: 65 | size = loss.numel() 66 | loss = loss.sum() / size 67 | return loss 68 | 69 | def train_step(self, batch: Dict, step: int, total_steps: int) -> Dict: 70 | """ 71 | Overriding the Algorithm BaseClass Method train_step. 72 | Returns a dictionary of training metrics. 73 | """ 74 | self.optim["network"].zero_grad(set_to_none=True) 75 | loss = self._compute_loss(batch) 76 | # Update the networks. These are done in a stack to support different grad options for the encoder. 77 | loss.backward() 78 | self.optim["network"].step() 79 | return dict(loss=loss.item()) 80 | 81 | def validation_step(self, batch: Any) -> Dict: 82 | """ 83 | Overriding the Algorithm BaseClass Method validation_step. 84 | Returns a dictionary of validation metrics. 85 | """ 86 | with torch.no_grad(): 87 | loss = self._compute_loss(batch) 88 | return dict(loss=loss.item()) 89 | 90 | def _predict(self, batch: Dict) -> torch.Tensor: 91 | B = batch["obs"].shape[0] 92 | noisy_actions = torch.randn(B, self.horizon, self.processor.action_space.shape[0], device=self.device) 93 | with torch.no_grad(): 94 | obs = self.network.encoder(batch["obs"]) 95 | self.noise_scheduler.set_timesteps(self.num_inference_steps) 96 | for timestep in self.noise_scheduler.timesteps: 97 | noise_pred = self.network.actor(noisy_actions, timestep.unsqueeze(0).to(self.device), cond=obs) 98 | noisy_actions = self.noise_scheduler.step( 99 | model_output=noise_pred, timestep=timestep, sample=noisy_actions 100 | ).prev_sample 101 | return noisy_actions 102 | 103 | def _get_train_action(self, obs: Any, step: int, total_steps: int) -> np.ndarray: 104 | batch = dict(obs=obs) 105 | with torch.no_grad(): 106 | action = self.predict(batch, is_batched=False, sample=True) 107 | return action 108 | -------------------------------------------------------------------------------- /research/processors/concatenate.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional 2 | 3 | import gym 4 | import numpy as np 5 | import torch 6 | 7 | from .base import Processor 8 | 9 | 10 | class Concatenate(Processor): 11 | def __init__( 12 | self, 13 | observation_space: gym.Space, 14 | action_space: gym.Space, 15 | concat_obs: bool = True, 16 | concat_action: bool = True, 17 | obs_dim: int = -1, 18 | action_dim: int = -1, 19 | ) -> None: 20 | super().__init__(observation_space, action_space) 21 | self.concat_action = concat_action and isinstance(action_space, gym.spaces.Dict) 22 | self.action_dim = action_dim 23 | self.forward_action_dim = action_dim if action_dim < 0 else action_dim + 1 24 | if self.concat_action: 25 | self.action_order = list(action_space.keys()) 26 | self.concat_obs = concat_obs and isinstance(observation_space, gym.spaces.Dict) 27 | self.obs_dim = obs_dim 28 | self.forward_obs_dim = obs_dim if obs_dim < 0 else obs_dim + 1 29 | if self.concat_obs: 30 | self.obs_order = list(observation_space.keys()) 31 | 32 | 
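# Illustrative behavior (hypothetical spaces): given a Dict observation space such as
#   {"state": Box(shape=(7,)), "object": Box(shape=(10,))}
# the processor exposes a flat Box(shape=(17,)) space and forward() concatenates the corresponding
# tensors along the last dimension; Dict action spaces are handled the same way.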
@property 33 | def observation_space(self): 34 | if self.concat_obs: 35 | # Concatenate the spaces on the last dim 36 | low = np.concatenate([space.low for space in self._observation_space.values()], axis=self.obs_dim) 37 | high = np.concatenate([space.high for space in self._observation_space.values()], axis=self.obs_dim) 38 | return gym.spaces.Box(low=low, high=high, dtype=np.float32) # force float32 conversion 39 | else: 40 | return self._observation_space 41 | 42 | @property 43 | def action_space(self): 44 | if self.concat_action: 45 | # Concatenate the spaces on the last dim 46 | low = np.concatenate([space.low for space in self._action_space.values()], axis=self.action_dim) 47 | high = np.concatenate([space.high for space in self._action_space.values()], axis=self.action_dim) 48 | return gym.spaces.Box(low=low, high=high, dtype=np.float32) # force float32 conversion 49 | else: 50 | return self._action_space 51 | 52 | def forward(self, batch: Dict) -> Dict: 53 | batch = {k: v for k, v in batch.items()} # Perform a shallow copy of the batch 54 | if self.concat_action and "action" in batch: 55 | batch["action"] = torch.cat( 56 | [batch["action"][act_key] for act_key in self.action_order], dim=self.forward_action_dim 57 | ) 58 | for k in ("obs", "next_obs", "init_obs"): 59 | if self.concat_obs and k in batch: 60 | batch[k] = torch.cat([batch[k][obs_key] for obs_key in self.obs_order], dim=self.forward_obs_dim) 61 | return batch 62 | 63 | 64 | class SelectProcessor(Processor): 65 | def __init__( 66 | self, 67 | observation_space: gym.Space, 68 | action_space: gym.Space, 69 | obs_include: Optional[List[str]] = None, 70 | obs_exclude: Optional[List[str]] = None, 71 | action_include: Optional[List[str]] = None, 72 | action_exclude: Optional[List[str]] = None, 73 | ): 74 | super().__init__(observation_space, action_space) 75 | assert not (action_include is not None and action_exclude is not None) 76 | assert not (obs_include is not None and obs_exclude is not None) 77 | 78 | if action_include is not None: 79 | self.action_keys = [k for k in action_space.keys() if k in action_include] 80 | elif action_exclude is not None: 81 | self.action_keys = [k for k in action_space.keys() if k not in action_exclude] 82 | else: 83 | self.action_keys = None 84 | if self.action_keys is not None: 85 | self._action_space = gym.spaces.Dict({k: v for k, v in self._action_space.items() if k in self.action_keys}) 86 | 87 | if obs_include is not None: 88 | self.obs_keys = [k for k in observation_space.keys() if k in obs_include] 89 | elif obs_exclude is not None: 90 | self.obs_keys = [k for k in observation_space.keys() if k not in obs_exclude] 91 | else: 92 | self.obs_keys = None 93 | if self.obs_keys is not None: 94 | self._observation_space = gym.spaces.Dict( 95 | {k: v for k, v in self._observation_space.items() if k in self.obs_keys} 96 | ) 97 | 98 | def forward(self, batch: Dict) -> Dict: 99 | if "action" in batch and self.action_keys is not None: 100 | batch["action"] = {k: batch["action"][k] for k in self.action_keys} 101 | for k in ("obs", "next_obs", "init_obs"): 102 | if k in batch and self.obs_keys is not None: 103 | batch[k] = {obs_key: batch[k][obs_key] for obs_key in self.obs_keys} 104 | return batch 105 | -------------------------------------------------------------------------------- /tools/run_slurm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | import subprocess 5 | import tempfile 6 | from typing import 
TextIO 7 | 8 | import utils 9 | 10 | SLURM_LOG_DEFAULT = os.path.join(utils.STORAGE_ROOT, "slurm_logs") 11 | 12 | SLURM_ARGS = { 13 | "partition": {"type": str, "required": True}, 14 | "time": {"type": str, "default": "48:00:00"}, 15 | "nodes": {"type": int, "default": 1}, 16 | "ntasks-per-node": {"type": int, "default": 1}, 17 | "cpus": {"type": int, "required": True}, 18 | "gpus": {"type": str, "required": False, "default": None}, 19 | "mem": {"type": str, "required": True}, 20 | "output": {"type": str, "default": SLURM_LOG_DEFAULT}, 21 | "error": {"type": str, "default": SLURM_LOG_DEFAULT}, 22 | "job-name": {"type": str, "required": True}, 23 | "exclude": {"type": str, "required": False, "default": None}, 24 | "nodelist": {"type": str, "required": False, "default": None}, 25 | "account": {"type": str, "required": False, "default": None}, 26 | } 27 | 28 | SLURM_NAME_OVERRIDES = {"gpus": "gres", "cpus": "cpus-per-task"} 29 | 30 | 31 | def write_slurm_header(f: TextIO, args: argparse.Namespace) -> None: 32 | # Make a copy of the args to prevent corruption 33 | args = copy.deepcopy(args) 34 | # Modify everything in the name space to later write it all at once 35 | for key in SLURM_ARGS.keys(): 36 | assert key.replace("-", "_") in args, "Key " + key + " not found." 37 | 38 | if not os.path.isdir(args.output): 39 | os.makedirs(args.output) 40 | if not os.path.isdir(args.error): 41 | os.makedirs(args.error) 42 | 43 | args.output = os.path.join(args.output, args.job_name + "_%A.out") 44 | args.error = os.path.join(args.error, args.job_name + "_%A.err") 45 | args.gpus = "gpu:" + str(args.gpus) if args.gpus is not None else args.gpus 46 | 47 | NL = "\n" 48 | f.write("#!/bin/bash" + NL) 49 | f.write(NL) 50 | for arg_name in SLURM_ARGS.keys(): 51 | arg_value = vars(args)[arg_name.replace("-", "_")] 52 | if arg_name in SLURM_NAME_OVERRIDES: 53 | arg_name = SLURM_NAME_OVERRIDES[arg_name] 54 | if arg_value is not None: 55 | f.write("#SBATCH --" + arg_name + "=" + str(arg_value) + NL) 56 | 57 | f.write(NL) 58 | f.write('echo "SLURM_JOBID = "$SLURM_JOBID' + NL) 59 | f.write('echo "SLURM_JOB_NODELIST = "$SLURM_JOB_NODELIST' + NL) 60 | f.write('echo "SLURM_JOB_NODELIST = "$SLURM_JOB_NODELIST' + NL) 61 | f.write('echo "SLURM_NNODES = "$SLURM_NNODES' + NL) 62 | f.write('echo "SLURMTMPDIR = "$SLURMTMPDIR' + NL) 63 | f.write('echo "working directory = "$SLURM_SUBMIT_DIR' + NL) 64 | f.write(NL) 65 | f.write(". " + utils.ENV_SETUP_SCRIPT) 66 | f.write(NL) 67 | 68 | 69 | if __name__ == "__main__": 70 | parser = utils.get_parser() 71 | # Add Slurm Arguments 72 | for k, v in SLURM_ARGS.items(): 73 | parser.add_argument("--" + k, **v) 74 | parser.add_argument( 75 | "--remainder", 76 | default="split", 77 | choices=["split", "new"], 78 | help="Whether or not to spread out jobs that don't divide evently, or place them in a new job", 79 | ) 80 | 81 | args = parser.parse_args() 82 | scripts = utils.get_scripts(args) 83 | 84 | # Call python subprocess to launch the slurm jobs. 85 | num_slurm_calls = len(scripts) // args.scripts_per_job 86 | remainder_scripts = len(scripts) - num_slurm_calls * args.scripts_per_job 87 | scripts_per_call = [args.scripts_per_job for _ in range(num_slurm_calls)] 88 | if args.remainder == "split": 89 | for i in range(remainder_scripts): 90 | scripts_per_call[i] += 1 # Add the remainder jobs to spread them out as evenly as possible. 
91 | elif args.remainder == "new": 92 | scripts_per_call.append(remainder_scripts) 93 | else: 94 | raise ValueError("Invalid job remainder specification.") 95 | assert sum(scripts_per_call) == len(scripts) 96 | script_index = 0 97 | procs = [] 98 | for num_scripts in scripts_per_call: 99 | current_scripts = scripts[script_index : script_index + num_scripts] 100 | script_index += num_scripts 101 | 102 | _, slurm_file = tempfile.mkstemp(text=True, prefix="job_", suffix=".sh") 103 | print("Launching job with slurm configuration:", slurm_file) 104 | 105 | with open(slurm_file, "w+") as f: 106 | write_slurm_header(f, args) 107 | # Now that we have written the header we can launch the jobs. 108 | for entry_point, script_args in current_scripts: 109 | command_str = ["python", entry_point] 110 | for arg_name, arg_value in script_args.items(): 111 | command_str.append("--" + arg_name) 112 | command_str.append(str(arg_value)) 113 | if len(current_scripts) != 1: 114 | command_str.append("&") 115 | command_str = " ".join(command_str) + "\n" 116 | f.write(command_str) 117 | if len(current_scripts) != 1: 118 | f.write("wait") 119 | 120 | # Now launch the job 121 | proc = subprocess.Popen(["sbatch", slurm_file]) 122 | procs.append(proc) 123 | 124 | exit_codes = [p.wait() for p in procs] 125 | -------------------------------------------------------------------------------- /research/algs/online/td3.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Any, Dict, Type 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from research.networks.base import ActorCriticPolicy 8 | 9 | from ..off_policy_algorithm import OffPolicyAlgorithm 10 | 11 | 12 | class TD3(OffPolicyAlgorithm): 13 | def __init__( 14 | self, 15 | *args, 16 | tau=0.005, 17 | policy_noise=0.1, 18 | target_noise=0.2, 19 | noise_clip=0.5, 20 | critic_freq=1, 21 | actor_freq=2, 22 | target_freq=2, 23 | average_actor_q=True, 24 | bc_coeff=0.0, 25 | **kwargs, 26 | ): 27 | super().__init__(*args, **kwargs) 28 | assert isinstance(self.network, ActorCriticPolicy) 29 | # Save extra parameters 30 | self.tau = tau 31 | self.policy_noise = policy_noise 32 | self.target_noise = target_noise 33 | self.noise_clip = noise_clip 34 | self.critic_freq = critic_freq 35 | self.actor_freq = actor_freq 36 | self.target_freq = target_freq 37 | self.average_actor_q = average_actor_q 38 | self.bc_coeff = bc_coeff 39 | 40 | def setup_network(self, network_class: Type[torch.nn.Module], network_kwargs: Dict) -> None: 41 | self.network = network_class( 42 | self.processor.observation_space, self.processor.action_space, **network_kwargs 43 | ).to(self.device) 44 | self.target_network = network_class( 45 | self.processor.observation_space, self.processor.action_space, **network_kwargs 46 | ).to(self.device) 47 | self.target_network.load_state_dict(self.network.state_dict()) 48 | for param in self.target_network.parameters(): 49 | param.requires_grad = False 50 | 51 | def setup_optimizers(self) -> None: 52 | # Default optimizer initialization 53 | self.optim["actor"] = self.optim_class(self.network.actor.parameters(), **self.optim_kwargs) 54 | # Update the encoder with the critic. 
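# (Only the critic optimizer receives the encoder parameters; the actor update below detaches the
# encoded observations, so representation learning is driven purely by critic gradients, as in DrQ-v2.)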
55 | critic_params = itertools.chain(self.network.critic.parameters(), self.network.encoder.parameters()) 56 | self.optim["critic"] = self.optim_class(critic_params, **self.optim_kwargs) 57 | 58 | def _get_train_action(self, obs: Any, step: int, total_steps: int) -> np.ndarray: 59 | batch = dict(obs=obs) 60 | with torch.no_grad(): 61 | action = self.predict(batch, is_batched=False) 62 | action += self.policy_noise * np.random.randn(action.shape[0]) 63 | return action 64 | 65 | def _update_critic(self, batch: Dict) -> Dict: 66 | with torch.no_grad(): 67 | noise = (torch.randn_like(batch["action"]) * self.target_noise).clamp(-self.noise_clip, self.noise_clip) 68 | next_action = self.target_network.actor(batch["next_obs"]) 69 | noisy_next_action = (next_action + noise).clamp(*self.action_range) 70 | target_q = self.target_network.critic(batch["next_obs"], noisy_next_action) 71 | target_q = torch.min(target_q, dim=0)[0] 72 | target_q = batch["reward"] + batch["discount"] * target_q 73 | 74 | qs = self.network.critic(batch["obs"], batch["action"]) 75 | q_loss = torch.nn.functional.mse_loss(qs, target_q.expand(qs.shape[0], -1), reduction="none").mean() 76 | 77 | self.optim["critic"].zero_grad(set_to_none=True) 78 | q_loss.backward() 79 | self.optim["critic"].step() 80 | 81 | return dict(q_loss=q_loss.item(), target_q=target_q.mean().item()) 82 | 83 | def _update_actor(self, batch: Dict) -> Dict: 84 | obs = batch["obs"].detach() # Detach the encoder so it isn't updated. 85 | action = self.network.actor(obs) 86 | qs = self.network.critic(obs, action) 87 | if self.average_actor_q: 88 | q = qs.mean(dim=0) # average the qs over the ensemble 89 | else: 90 | q = qs[0] # Take only the first Q function 91 | actor_loss = -q.mean() 92 | 93 | if self.bc_coeff > 0.0: 94 | bc_loss = torch.nn.functional.mse_loss(action, batch["action"]) 95 | actor_loss = actor_loss + self.bc_coeff * bc_loss 96 | 97 | self.optim["actor"].zero_grad(set_to_none=True) 98 | actor_loss.backward() 99 | self.optim["actor"].step() 100 | 101 | return dict(actor_loss=actor_loss.item()) 102 | 103 | def train_step(self, batch: Dict, step: int, total_steps: int) -> Dict: 104 | all_metrics = {} 105 | 106 | if "obs" not in batch or step < self.random_steps: 107 | return all_metrics 108 | 109 | batch["obs"] = self.network.encoder(batch["obs"]) 110 | with torch.no_grad(): 111 | batch["next_obs"] = self.target_network.encoder(batch["next_obs"]) 112 | 113 | if step % self.critic_freq == 0: 114 | metrics = self._update_critic(batch) 115 | all_metrics.update(metrics) 116 | 117 | if step % self.actor_freq == 0: 118 | metrics = self._update_actor(batch) 119 | all_metrics.update(metrics) 120 | 121 | if step % self.target_freq == 0: 122 | with torch.no_grad(): 123 | for param, target_param in zip(self.network.parameters(), self.target_network.parameters()): 124 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 125 | 126 | return all_metrics 127 | -------------------------------------------------------------------------------- /research/utils/evaluate.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import os 3 | from typing import Any, Dict, List, Optional 4 | 5 | import gym 6 | import imageio 7 | import numpy as np 8 | import torch 9 | 10 | from . 
import utils 11 | 12 | MAX_METRICS = {"success", "is_success", "completions"} 13 | LAST_METRICS = {"goal_distance"} 14 | MEAN_METRICS = {"discount"} 15 | EXCLUDE_METRICS = {"TimeLimit.truncated"} 16 | 17 | 18 | class EvalMetricTracker(object): 19 | """ 20 | A simple class to make keeping track of eval metrics easy. 21 | Usage: 22 | Call reset before each episode starts 23 | Call step after each environment step 24 | call export to get the final metrics 25 | """ 26 | 27 | def __init__(self): 28 | self.metrics = collections.defaultdict(list) 29 | self.ep_length = 0 30 | self.ep_reward = 0 31 | self.ep_metrics = collections.defaultdict(list) 32 | 33 | def reset(self) -> None: 34 | if self.ep_length > 0: 35 | # Add the episode to overall metrics 36 | self.metrics["reward"].append(self.ep_reward) 37 | self.metrics["length"].append(self.ep_length) 38 | for k, v in self.ep_metrics.items(): 39 | if k in MAX_METRICS: 40 | self.metrics[k].append(np.max(v)) 41 | elif k in LAST_METRICS: # Append the last value 42 | self.metrics[k].append(v[-1]) 43 | elif k in MEAN_METRICS: 44 | self.metrics[k].append(np.mean(v)) 45 | else: 46 | self.metrics[k].append(np.sum(v)) 47 | 48 | self.ep_length = 0 49 | self.ep_reward = 0 50 | self.ep_metrics = collections.defaultdict(list) 51 | 52 | def step(self, reward: float, info: Dict) -> None: 53 | self.ep_length += 1 54 | self.ep_reward += reward 55 | for k, v in info.items(): 56 | if (isinstance(v, float) or np.isscalar(v)) and k not in EXCLUDE_METRICS: 57 | self.ep_metrics[k].append(v) 58 | 59 | def add(self, k: str, v: Any): 60 | self.metrics[k].append(v) 61 | 62 | def export(self) -> Dict: 63 | if self.ep_length > 0: 64 | # We have one remaining episode to log, make sure to get it. 65 | self.reset() 66 | metrics = {k: np.mean(v) for k, v in self.metrics.items()} 67 | metrics["reward_std"] = np.std(self.metrics["reward"]) 68 | return metrics 69 | 70 | 71 | def eval_multiple(env, model, path: str, step: int, eval_fns: List[str], eval_kwargs: List[Dict]): 72 | all_metrics = dict() 73 | for eval_fn, eval_kwarg in zip(eval_fns, eval_kwargs): 74 | metrics = locals()[eval_fn](env, model, path, step, **eval_kwarg) 75 | all_metrics.update(metrics) 76 | return all_metrics 77 | 78 | 79 | def eval_policy( 80 | env: gym.Env, 81 | model, 82 | path: str, 83 | step: int, 84 | num_ep: int = 10, 85 | num_gifs: int = 0, 86 | width=200, 87 | height=200, 88 | every_n_frames: int = 2, 89 | terminate_on_success=False, 90 | history_length: int = 0, 91 | predict_kwargs: Optional[Dict] = None, 92 | ) -> Dict: 93 | metric_tracker = EvalMetricTracker() 94 | predict_kwargs = {} if predict_kwargs is None else predict_kwargs 95 | assert num_gifs <= num_ep, "Cannot save more gifs than eval ep." 
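# The loop below rolls out num_ep episodes, optionally saving every_n_frames-th rendered frame to a
# gif, and funnels per-step rewards and info dicts through EvalMetricTracker; export() then returns
# per-episode means (e.g. "reward", "length", plus scalar info keys such as "success") and "reward_std".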
96 | 97 | for i in range(num_ep): 98 | # Reset Metrics 99 | done = False 100 | ep_length, ep_reward = 0, 0 101 | frames = [] 102 | save_gif = i < num_gifs 103 | render_kwargs = dict(mode="rgb_array", width=width, height=height) if save_gif else dict() 104 | obs = env.reset() 105 | if history_length > 0: 106 | obs = utils.unsqueeze(obs, 0) 107 | if save_gif: 108 | frames.append(env.render(**render_kwargs)) 109 | metric_tracker.reset() 110 | while not done: 111 | batch = dict(obs=obs) 112 | if hasattr(env, "_max_episode_steps"): 113 | batch["horizon"] = env._max_episode_steps - ep_length 114 | with torch.no_grad(): 115 | action = model.predict(batch, **predict_kwargs) 116 | if history_length > 0: 117 | action = action[-1] 118 | next_obs, reward, done, info = env.step(action) 119 | ep_reward += reward 120 | metric_tracker.step(reward, info) 121 | ep_length += 1 122 | if save_gif and ep_length % every_n_frames == 0: 123 | frames.append(env.render(**render_kwargs)) 124 | if terminate_on_success and (info.get("success", False) or info.get("is_success", False)): 125 | done = True 126 | # Update the observation if we have history 127 | if history_length > 0: 128 | obs = utils.concatenate(obs, utils.unsqueeze(next_obs, 0), dim=0) 129 | if ep_length + 1 > history_length: 130 | # Drop the last observation from the sequence 131 | obs = utils.get_from_batch(obs, start=1, end=history_length + 1) 132 | else: 133 | obs = next_obs 134 | 135 | if hasattr(env, "get_normalized_score"): 136 | metric_tracker.add("score", env.get_normalized_score(ep_reward)) 137 | 138 | if save_gif: 139 | gif_name = "vis-{}_ep-{}.gif".format(step, i) 140 | imageio.mimsave(os.path.join(path, gif_name), frames) 141 | 142 | return metric_tracker.export() 143 | -------------------------------------------------------------------------------- /research/networks/drqv2.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable 2 | 3 | import gym 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | 8 | from .mlp import MLP, EnsembleMLP 9 | 10 | 11 | def drqv2_weight_init(m: nn.Module) -> None: 12 | if isinstance(m, nn.Linear): 13 | nn.init.orthogonal_(m.weight.data) 14 | if hasattr(m.bias, "data"): 15 | m.bias.data.fill_(0.0) 16 | elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 17 | gain = nn.init.calculate_gain("relu") 18 | nn.init.orthogonal_(m.weight.data, gain) 19 | if hasattr(m.bias, "data"): 20 | m.bias.data.fill_(0.0) 21 | 22 | 23 | class DrQv2Encoder(nn.Module): 24 | def __init__(self, observation_space: gym.Space, action_space: gym.Space) -> None: 25 | super().__init__() 26 | if len(observation_space.shape) == 4: 27 | s, c, h, w = observation_space.shape 28 | channels = s * c 29 | elif len(observation_space.shape) == 3: 30 | c, h, w = observation_space.shape 31 | channels = c 32 | else: 33 | raise ValueError("Invalid observation space for DRQV2 Image encoder.") 34 | self.convnet = nn.Sequential( 35 | nn.Conv2d(channels, 32, 3, stride=2), 36 | nn.ReLU(), 37 | nn.Conv2d(32, 32, 3, stride=1), 38 | nn.ReLU(), 39 | nn.Conv2d(32, 32, 3, stride=1), 40 | nn.ReLU(), 41 | nn.Conv2d(32, 32, 3, stride=1), 42 | nn.ReLU(), 43 | nn.Flatten(), 44 | ) 45 | self.reset_parameters() 46 | 47 | with torch.no_grad(): 48 | sample = torch.as_tensor(observation_space.sample()[None]) / 255.0 - 0.5 49 | self.repr_dim = self.convnet(sample).shape[1] 50 | 51 | def reset_parameters(self): 52 | self.apply(drqv2_weight_init) 53 | 54 | @property 55 | def 
output_space(self) -> gym.Space: 56 | return gym.spaces.Box(shape=(self.repr_dim,), low=-np.inf, high=np.inf, dtype=np.float32) 57 | 58 | def forward(self, obs: torch.Tensor) -> torch.Tensor: 59 | if len(obs.shape) == 5: 60 | b, s, c, h, w = obs.shape 61 | obs = obs.view(b, s * c, h, w) 62 | obs = obs / 255.0 - 0.5 63 | h = self.convnet(obs) 64 | return h 65 | 66 | 67 | class DrQv2Critic(nn.Module): 68 | def __init__( 69 | self, 70 | observation_space: gym.Space, 71 | action_space: gym.Space, 72 | feature_dim: int = 50, 73 | hidden_layers: Iterable[int] = (1024, 1024), 74 | ensemble_size: int = 2, 75 | **kwargs, 76 | ): 77 | super().__init__() 78 | self.trunk = nn.Sequential( 79 | nn.Linear(observation_space.shape[0], feature_dim), nn.LayerNorm(feature_dim), nn.Tanh() 80 | ) 81 | self.ensemble_size = ensemble_size 82 | input_dim = feature_dim + action_space.shape[0] 83 | if self.ensemble_size > 1: 84 | self.mlp = EnsembleMLP(input_dim, 1, ensemble_size=ensemble_size, hidden_layers=hidden_layers, **kwargs) 85 | else: 86 | self.mlp = MLP(input_dim, 1, hidden_layers=hidden_layers, **kwargs) 87 | self.reset_parameters() 88 | 89 | def reset_parameters(self): 90 | self.apply(drqv2_weight_init) 91 | 92 | def forward(self, obs, action): 93 | x = self.trunk(obs) 94 | x = torch.cat((x, action), dim=-1) 95 | q = self.mlp(x).squeeze(-1) # Remove the last dim 96 | if self.ensemble_size == 1: 97 | q = q.unsqueeze(0) # add in the ensemble dim 98 | return q 99 | 100 | 101 | class DrQv2Value(nn.Module): 102 | def __init__( 103 | self, 104 | observation_space: gym.Space, 105 | action_space: gym.Space, 106 | feature_dim: int = 50, 107 | hidden_layers: Iterable[int] = (1024, 1024), 108 | ensemble_size: int = 1, 109 | **kwargs, 110 | ): 111 | super().__init__() 112 | self.trunk = nn.Sequential( 113 | nn.Linear(observation_space.shape[0], feature_dim), nn.LayerNorm(feature_dim), nn.Tanh() 114 | ) 115 | self.ensemble_size = ensemble_size 116 | if self.ensemble_size > 1: 117 | self.mlp = EnsembleMLP(feature_dim, 1, ensemble_size=ensemble_size, hidden_layers=hidden_layers, **kwargs) 118 | else: 119 | self.mlp = MLP(feature_dim, 1, hidden_layers=hidden_layers, **kwargs) 120 | self.reset_parameters() 121 | 122 | def reset_parameters(self): 123 | self.apply(drqv2_weight_init) 124 | 125 | def forward(self, obs): 126 | v = self.trunk(obs) 127 | v = self.mlp(v).squeeze(-1) # Remove the last dim 128 | if self.ensemble_size == 1: 129 | v = v.unsqueeze(0) # add in the ensemble dim 130 | return v 131 | 132 | 133 | class DrQv2Actor(nn.Module): 134 | def __init__( 135 | self, 136 | observation_space: gym.Space, 137 | action_space: gym.Space, 138 | feature_dim: int = 50, 139 | hidden_layers: Iterable[int] = (1024, 1024), 140 | **kwargs, 141 | ): 142 | super().__init__() 143 | self.trunk = nn.Sequential( 144 | nn.Linear(observation_space.shape[0], feature_dim), nn.LayerNorm(feature_dim), nn.Tanh() 145 | ) 146 | self.mlp = MLP(feature_dim, action_space.shape[0], hidden_layers=hidden_layers, **kwargs) 147 | self.reset_parameters() 148 | 149 | def reset_parameters(self): 150 | self.apply(drqv2_weight_init) 151 | 152 | def forward(self, obs: torch.Tensor) -> torch.Tensor: 153 | x = self.trunk(obs) 154 | return self.mlp(x) 155 | -------------------------------------------------------------------------------- /tools/run_local.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import time 4 | 5 | import utils 6 | 7 | """ 8 | A script for launching jobs on a 
local machine. This is designed to mimic the design of the SLURM launch script. 9 | """ 10 | 11 | if __name__ == "__main__": 12 | parser = utils.get_parser() 13 | parser.add_argument("--cpus", "-c", type=int, default=None, help="Number of CPUs per job instance") 14 | parser.add_argument("--gpus", "-g", type=int, default=None, help="Number of GPUs per job instance") 15 | parser.add_argument("--valid-gpus", type=int, nargs="+", default=None, help="Specifies which GPUs to use.") 16 | parser.add_argument("--valid-cpus", type=str, nargs="+", default=None) 17 | parser.add_argument( 18 | "--use-taskset", action="store_true", default=False, help="Whether or not to CPU load balance with taskset" 19 | ) 20 | 21 | # Add Taskset and GPU arguments 22 | args = parser.parse_args() 23 | assert isinstance(args.valid_gpus, list) or args.valid_gpus is None, "Valid GPUs must be a list of ints or None." 24 | assert isinstance(args.valid_cpus, list) or args.valid_cpus is None, "Valid CPUs must be a list" 25 | if args.gpus is not None: 26 | assert ( 27 | isinstance(args.valid_gpus, list) and len(args.valid_gpus) >= args.gpus 28 | ), "If GPU, must provide valid gpus >= num gpus" 29 | 30 | scripts = utils.get_scripts(args) 31 | 32 | if args.valid_cpus is None: 33 | args.valid_cpus = ["0-" + str(os.cpu_count())] 34 | 35 | cpu_list = [] # Populate CPU list with a list of all valid CPU cores [1,2,3,4,5,6, ...] etc. 36 | for cpu_item in args.valid_cpus: 37 | if isinstance(cpu_item, str) and "-" in cpu_item: 38 | # We have a CPU range 39 | cpu_min, cpu_max = cpu_item.split("-") 40 | cpu_min, cpu_max = int(cpu_min), int(cpu_max) 41 | cpu_list.extend(list(range(cpu_min, cpu_max))) 42 | else: 43 | cpu_list.append(int(cpu_item)) 44 | assert ( 45 | args.cpus is None or len(cpu_list) >= args.cpus 46 | ), "Must have more valid CPUs than cpus per script, otherwise nothing can launch" 47 | 48 | gpu_list = [] if args.valid_gpus is None else args.valid_gpus 49 | 50 | job_list = [] 51 | 52 | try: 53 | while len(scripts) > 0: 54 | # Check on existing processes 55 | finished_jobs = [] 56 | for i, (processes, job_cpus, job_gpus) in enumerate(job_list): 57 | if all([process.poll() is not None for process in processes]): 58 | cpu_list.extend(job_cpus) 59 | gpu_list.extend(job_gpus) 60 | finished_jobs.append(i) 61 | for i in reversed(finished_jobs): 62 | del job_list[i] 63 | 64 | # Next, check to see if we can launch a process 65 | have_sufficient_cpus = args.cpus is None or len(cpu_list) >= args.cpus 66 | have_sufficient_gpus = args.gpus is None or len(gpu_list) >= args.gpus 67 | if have_sufficient_cpus and have_sufficient_gpus: 68 | # we have the resources to launch a job, so launch it 69 | job_cpus = cpu_list[: args.cpus] if args.cpus is not None else [] 70 | job_gpus = gpu_list[: args.gpus] if args.gpus is not None else [] 71 | job_scripts = scripts[: args.scripts_per_job] 72 | processes = [] 73 | for entry_point, script_args in job_scripts: 74 | command_list = [] 75 | if args.use_taskset: 76 | command_list.extend(["taskset", "-c", ",".join(map(str, job_cpus))]) 77 | command_list.extend(["python", entry_point]) 78 | for arg_name, arg_value in script_args.items(): 79 | command_list.append("--" + arg_name) 80 | command_list.append(str(arg_value)) 81 | if len(job_gpus) > 0: 82 | env = os.environ.copy() 83 | env["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, job_gpus)) 84 | else: 85 | env = None 86 | 87 | print("[Local Sweeper] launching script on gpu:", job_gpus, "and cpus:", job_cpus) 88 | proc = subprocess.Popen(command_list, env=env) 89 |
processes.append(proc) 90 | 91 | # After all the jobs have launched, update the set of available resources and remaining scripts 92 | cpu_list = cpu_list[len(job_cpus) :] 93 | gpu_list = gpu_list[len(job_gpus) :] 94 | scripts = scripts[len(job_scripts) :] 95 | # Append to the set of currently running jobs 96 | job_list.append((processes, job_cpus, job_gpus)) 97 | 98 | else: 99 | # If we were unable to launch a job, sleep for a while. 100 | time.sleep(10) 101 | 102 | # We have launched all the scripts, now wait for the remaining ones to complete. 103 | all_processes = [] 104 | for processes, _, _ in job_list: 105 | all_processes.extend(processes) 106 | exit_codes = [p.wait() for p in all_processes] 107 | print("[Local Sweeper] Completed.") 108 | 109 | except KeyboardInterrupt: 110 | # If we detect a keyboard interrupt, manually send a kill signal to all subprocesses. 111 | all_processes = [] 112 | for processes, _, _ in job_list: 113 | all_processes.extend(processes) 114 | 115 | for p in all_processes: 116 | try: 117 | p.terminate() 118 | except OSError: 119 | pass 120 | p.wait() 121 | -------------------------------------------------------------------------------- /research/networks/base.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import gym 4 | import numpy as np 5 | import torch 6 | 7 | import research 8 | 9 | """ 10 | There are two special network functions used by research lightning: 11 | 1. output_space - this is used to give the observation_space to different networks in a container group 12 | 2. compile - this is used when torch.compile is called. 13 | """ 14 | 15 | 16 | def reset(module): 17 | if hasattr(module, "reset_parameters"): 18 | module.reset_parameters() 19 | 20 | 21 | class ModuleContainer(torch.nn.Module): 22 | CONTAINERS = [] 23 | 24 | def __init__(self, observation_space: gym.Space, action_space: gym.Space, **kwargs) -> None: 25 | super().__init__() 26 | # save the classes and containers 27 | base_kwargs = {k: v for k, v in kwargs.items() if not k.endswith("_class") and not k.endswith("_kwargs")} 28 | 29 | output_space = observation_space 30 | for container in self.CONTAINERS: 31 | module_class = kwargs.get(container + "_class", torch.nn.Identity) 32 | module_class = vars(research.networks)[module_class] if isinstance(module_class, str) else module_class 33 | if module_class is torch.nn.Identity: 34 | module_kwargs = dict() 35 | else: 36 | module_kwargs = base_kwargs.copy() 37 | module_kwargs.update(kwargs.get(container + "_kwargs", dict())) 38 | # Create the module, and attach it to self 39 | module = module_class(output_space, action_space, **module_kwargs) 40 | setattr(self, container, module) 41 | 42 | # Set a reset function 43 | setattr(self, "reset_" + container, partial(self._reset, container)) 44 | 45 | if hasattr(getattr(self, container), "output_space"): 46 | # update the output space 47 | output_space = getattr(self, container).output_space 48 | 49 | # Done creating all sub-modules.
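        # Illustrative (hypothetical) usage: a config might construct
        #   ActorCriticPolicy(obs_space, act_space,
        #                     encoder_class="DrQv2Encoder",
        #                     actor_class="ContinuousMLPActor", actor_kwargs=dict(hidden_layers=(256, 256)),
        #                     critic_class="ContinuousMLPCritic"),
        # in which case the encoder consumes obs_space and exposes an output_space that becomes the
        # input space handed to the actor and critic containers created above.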
50 | 51 | @classmethod 52 | def create_subset(cls, containers): 53 | assert all([container in cls.CONTAINERS for container in containers]) 54 | name = "".join([container.capitalize() for container in containers]) + "Subset" 55 | return type(name, (ModuleContainer,), {"CONTAINERS": containers}) 56 | 57 | def _reset(self, container: str) -> None: 58 | module = getattr(self, container) 59 | with torch.no_grad(): 60 | module.apply(reset) 61 | 62 | def compile(self, **kwargs): 63 | for container in self.CONTAINERS: 64 | attr = getattr(self, container) 65 | if type(attr).forward == torch.nn.Module.forward: 66 | assert hasattr(attr, "compile"), ( 67 | "container " + container + " is nn.Module without forward() but didn't define `compile`." 68 | ) 69 | attr.compile(**kwargs) 70 | else: 71 | setattr(self, container, torch.compile(attr, **kwargs)) 72 | 73 | def forward(self, x): 74 | # Use all of the modules in order 75 | for container in self.CONTAINERS: 76 | x = getattr(self, container)(x) 77 | return x 78 | 79 | 80 | class MultiEncoder(torch.nn.Module): 81 | def __init__(self, observation_space: gym.Space, action_space: gym.Space, **kwargs): 82 | super().__init__() 83 | assert isinstance(observation_space, gym.spaces.Dict) 84 | base_kwargs = {k: v for k, v in kwargs.items() if not k.endswith("_class") and not k.endswith("_kwargs")} 85 | # parse unique modalities from args that are passed with "class" 86 | self.obs_keys = sorted([k[: -len("_class")] for k in kwargs if k.endswith("_class")]) 87 | assert all([k in set(observation_space.keys()) for k in self.obs_keys]) 88 | 89 | modules = dict() 90 | for k in self.obs_keys: 91 | # Build the modules 92 | module_class = kwargs[k + "_class"] 93 | module_class = vars(research.networks)[module_class] if isinstance(module_class, str) else module_class 94 | module_kwargs = base_kwargs.copy() 95 | module_kwargs.update(kwargs.get(k + "_kwargs", dict())) 96 | module = module_class(observation_space[k], action_space, **module_kwargs) 97 | modules[k] = module 98 | 99 | # register all the modules 100 | self.encoders = torch.nn.ModuleDict(modules) 101 | 102 | # compute the output space 103 | output_dim = 0 104 | for k in self.obs_keys: 105 | module = self.encoders[k] 106 | if hasattr(module, "output_space"): 107 | output_shape = module.output_space.shape 108 | else: 109 | assert isinstance(module, torch.nn.Identity) 110 | output_shape = observation_space[k].shape 111 | assert len(output_shape) == 1 112 | output_dim += output_shape[0] 113 | 114 | self.output_dim = output_dim 115 | 116 | @property 117 | def output_space(self) -> gym.Space: 118 | return gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.output_dim,), dtype=np.float32) 119 | 120 | def forward(self, obs): 121 | return torch.cat([self.encoders[k](obs[k]) for k in self.obs_keys], dim=-1) 122 | 123 | 124 | class ActorCriticPolicy(ModuleContainer): 125 | CONTAINERS = ["encoder", "actor", "critic"] 126 | 127 | 128 | class ActorCriticValuePolicy(ModuleContainer): 129 | CONTAINERS = ["encoder", "actor", "critic", "value"] 130 | 131 | 132 | class ActorValuePolicy(ModuleContainer): 133 | CONTAINERS = ["encoder", "actor", "value"] 134 | 135 | 136 | class ActorPolicy(ModuleContainer): 137 | CONTAINERS = ["encoder", "actor"] 138 | -------------------------------------------------------------------------------- /research/datasets/d4rl_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import d4rl 4 | import gym 5 | import numpy as 
np 6 | 7 | from .replay_buffer.buffer import ReplayBuffer 8 | 9 | 10 | class D4RLDataset(ReplayBuffer): 11 | """ 12 | This class is designed to be able to produce the same dataset configs used in the IQL paper. 13 | See https://github.com/ikostrikov/implicit_q_learning 14 | """ 15 | 16 | def __init__( 17 | self, 18 | observation_space: gym.Space, 19 | action_space: gym.Space, 20 | name: str, 21 | d4rl_path: Optional[str] = None, # where to save D4RL files. 22 | use_rtg: bool = False, 23 | use_timesteps: bool = False, 24 | normalize_reward: bool = False, 25 | reward_scale: float = 1.0, 26 | reward_shift: float = 0.0, 27 | action_eps: float = 0.00001, 28 | **kwargs, 29 | ) -> None: 30 | self.env_name = name 31 | self.reward_scale = reward_scale 32 | self.reward_shift = reward_shift 33 | self.normalize_reward = normalize_reward 34 | self.action_eps = action_eps 35 | self.use_rtg = use_rtg 36 | self.use_timesteps = use_timesteps 37 | if d4rl_path is not None: 38 | d4rl.set_dataset_path(d4rl_path) 39 | super().__init__(observation_space, action_space, **kwargs) 40 | 41 | def _data_generator(self): 42 | env = gym.make(self.env_name) 43 | dataset = env.get_dataset() 44 | 45 | def get_done(i, ep_step=None): 46 | nonlocal dataset, env 47 | done = False 48 | if "ant" not in self.env_name: 49 | # use terminals 50 | done = done or dataset["terminals"][i] 51 | if "timeouts" in dataset: 52 | done = done or dataset["timeouts"][i] 53 | elif ep_step is not None: 54 | done = done or (episode_step == env._max_episode_steps - 1) 55 | return done 56 | 57 | # Compute dataset normalization as in https://github.com/ikostrikov/implicit_q_learning 58 | if self.normalize_reward: 59 | ep_rewards, ep_lengths = [], [] 60 | ep_reward, ep_length = 0, 0 61 | for i in range(len(dataset["observations"])): 62 | ep_reward += dataset["rewards"][i] 63 | done = get_done(i, ep_length) 64 | ep_length += 1 65 | if done: 66 | ep_rewards.append(ep_reward) 67 | ep_lengths.append(ep_length) 68 | ep_reward, ep_length = 0, 0 69 | min_reward, max_reward = min(ep_rewards), max(ep_rewards) 70 | max_length = max(ep_lengths) 71 | print("[research] Normalized D4RL range:", min_reward, max_reward) 72 | self.reward_scale *= max_length / (max_reward - min_reward) 73 | 74 | # Lots of this code was borrowed from https://github.com/rail-berkeley/d4rl/blob/master/d4rl/__init__.py 75 | obs_ = [] 76 | action_ = [self.dummy_action] 77 | reward_ = [0.0] 78 | done_ = [False] 79 | discount_ = [1.0] 80 | 81 | episode_step = 0 82 | for i in range(dataset["rewards"].shape[0]): 83 | obs = dataset["observations"][i].astype(np.float32) 84 | action = dataset["actions"][i].astype(np.float32) 85 | reward = dataset["rewards"][i].astype(np.float32) 86 | terminal = bool(dataset["terminals"][i]) 87 | done = get_done(i, episode_step) 88 | 89 | obs_.append(obs) 90 | action_.append(action) 91 | reward_.append(reward) 92 | discount_.append(1 - float(terminal)) 93 | done_.append(done) 94 | 95 | episode_step += 1 96 | 97 | if done: 98 | if "next_observations" in dataset: 99 | obs_.append(dataset["next_observations"][i].astype(np.float32)) 100 | else: 101 | # We need to do something to pad to the full length. 102 | # The default solution is to get rid of this transtion 103 | # but we need a transition with the terminal flag for our replay buffer 104 | # implementation to work. 
105 | # Since we always end up masking this out anyways, it shouldn't matter and we can just repeat 106 | obs_.append(dataset["observations"][i].astype(np.float32)) 107 | 108 | obs_ = np.array(obs_) 109 | action_ = np.array(action_) 110 | if self.action_eps > 0.0: 111 | action_ = np.clip(action_, -1.0 + self.action_eps, 1.0 - self.action_eps) 112 | reward_ = np.array(reward_).astype(np.float32) * self.reward_scale + self.reward_shift 113 | discount_ = np.array(discount_).astype(np.float32) 114 | done_ = np.array(done_, dtype=np.bool_) 115 | 116 | data = dict(obs=obs_, action=action_, reward=reward_, done=done_, discount=discount_) 117 | 118 | # Support Decision Transformer. 119 | if self.use_rtg: 120 | # Compute reward to go 121 | discount = self.sample_fn.keywords.get("discount", 0.99) 122 | rtg = np.zeros_like(reward_, dtype=np.float32) 123 | rtg[-1] = reward_[-1] 124 | for t in reversed(range(reward_.shape[0] - 1)): 125 | rtg[t] = reward_[t] + discount * rtg[t + 1] 126 | data["rtg"] = rtg 127 | 128 | if self.use_timesteps: 129 | # WARNING: Might be an error in this because of the dummy transition 130 | data["timestep"] = np.arange(len(reward_), dtype=np.int64) 131 | 132 | yield data 133 | 134 | # reset the episode trackers 135 | episode_step = 0 136 | obs_ = [] 137 | action_ = [self.dummy_action] 138 | reward_ = [0.0] 139 | done_ = [False] 140 | discount_ = [1.0] 141 | 142 | # Finally clean up the environment 143 | del dataset 144 | del env 145 | -------------------------------------------------------------------------------- /research/datasets/rollout_buffer.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Any, Dict, Optional 3 | 4 | import gym 5 | import numpy as np 6 | import torch 7 | 8 | from research.utils.utils import np_dataset_alloc 9 | 10 | 11 | class RolloutBuffer(torch.utils.data.IterableDataset): 12 | """ """ 13 | 14 | def __init__( 15 | self, 16 | observation_space: gym.Space, 17 | action_space: gym.Space, 18 | discount: float = 0.99, 19 | batch_size: Optional[int] = None, 20 | gae_lambda: float = 0.95, 21 | capacity: int = 2048, 22 | ): 23 | # Observation and action space values 24 | self.observation_space = observation_space 25 | self.action_space = action_space 26 | 27 | # Queuing values 28 | self.discount = discount 29 | self.gae_lambda = gae_lambda 30 | self.batch_size = 1 if batch_size is None else batch_size 31 | self._capacity = capacity + 2 # Add one for the first timestep and one for the last timestep 32 | self._last_batch = True 33 | self._idx = 0 34 | 35 | @property 36 | def is_full(self) -> bool: 37 | return self._idx >= self._capacity 38 | 39 | @property 40 | def last_batch(self) -> bool: 41 | return self._last_batch 42 | 43 | def setup(self) -> None: 44 | # Setup the required rollout buffers 45 | self._obs_buffer = np_dataset_alloc(self.observation_space, self._capacity) 46 | self._action_buffer = np_dataset_alloc(self.action_space, self._capacity) 47 | self._reward_buffer = np_dataset_alloc(0.0, self._capacity) 48 | self._done_buffer = np_dataset_alloc(False, self._capacity) 49 | self._info_buffers = dict() 50 | self._idx = 0 51 | 52 | def __del__(self): 53 | pass 54 | 55 | def add( 56 | self, obs: Any, action: Optional[Any] = None, reward: Optional[Any] = None, done: Optional[Any] = None, **kwargs 57 | ) -> None: 58 | assert (action is None) == (reward is None) == (done is None) 59 | if action is None: 60 | # TODO: figure out if we should have the discount factor here. 
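            # Convention: add() is typically called with only `obs` right after env.reset(), before any
            # action exists, so placeholder action/reward/done values are stored to keep every buffer
            # index-aligned. This first slot is only ever read back as an observation (see _get, which
            # looks up observations at idx - 1).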
61 | action = self.action_space.sample() 62 | reward = 0.0 63 | done = False 64 | 65 | assert self._idx < self._capacity, "Called add on a full buffer" 66 | 67 | def add_to_buffer_helper(buffer, value): 68 | if isinstance(buffer, dict): 69 | for k, v in buffer.items(): 70 | add_to_buffer_helper(v, value[k]) 71 | elif isinstance(buffer, np.ndarray): 72 | buffer[self._idx] = value 73 | else: 74 | raise ValueError("Attempted buffer ran out of space!") 75 | 76 | add_to_buffer_helper(self._obs_buffer, obs.copy()) 77 | add_to_buffer_helper(self._action_buffer, action.copy()) 78 | add_to_buffer_helper(self._reward_buffer, reward) 79 | add_to_buffer_helper(self._done_buffer, done) 80 | 81 | for k, v in kwargs.items(): 82 | if k not in self._info_buffers: 83 | self._info_buffers[k] = np_dataset_alloc(v, self._capacity) 84 | add_to_buffer_helper(self._info_buffers[k], v.copy()) 85 | 86 | self._idx += 1 # increase the index 87 | 88 | def prepare_buffer(self) -> None: 89 | assert "value" in self._info_buffers, "Attempted to use Rollout Buffer but values were not added." 90 | self._advantage_buffer = np_dataset_alloc(0.0, self._capacity) 91 | 92 | last_gae_lam = 0 93 | for step in reversed(range(1, self._capacity - 1)): # Stay within the valid range 94 | next_non_terminal = ( 95 | 1.0 - self._done_buffer[step] 96 | ) # Get done from the current step. Maybe should be step + 1? But i think not. 97 | next_values = self._info_buffers["value"][step + 1] 98 | 99 | delta = ( 100 | self._reward_buffer[step] 101 | + self.discount * next_values * next_non_terminal 102 | - self._info_buffers["value"][step] 103 | ) 104 | last_gae_lam = delta + self.discount * self.gae_lambda * next_non_terminal * last_gae_lam 105 | self._advantage_buffer[step] = last_gae_lam 106 | 107 | # TD(lambda) estimator, see Github PR #375 or "Telescoping in TD(lambda)" 108 | # in David Silver Lecture 4: https://www.youtube.com/watch?v=PnHCvfgC_ZA 109 | self._return_buffer = self._advantage_buffer + self._info_buffers["value"] 110 | 111 | def _get(self, idxs: np.ndarray) -> Dict: 112 | if idxs.shape[0] == 1: 113 | idxs = idxs[0] 114 | 115 | obs_idxs = idxs - 1 116 | obs = ( 117 | {k: v[obs_idxs] for k, v in self._obs_buffer} 118 | if isinstance(self._obs_buffer, dict) 119 | else self._obs_buffer[obs_idxs] 120 | ) 121 | action = ( 122 | {k: v[idxs] for k, v in self._action_buffer} 123 | if isinstance(self._action_buffer, dict) 124 | else self._action_buffer[idxs] 125 | ) 126 | returns = self._return_buffer[idxs] 127 | advantage = self._advantage_buffer[idxs] 128 | 129 | batch = dict(obs=obs, action=action, returns=returns, advantage=advantage) 130 | for k, v in self._info_buffers.items(): 131 | batch[k] = v[idxs] 132 | return batch 133 | 134 | def __iter__(self): 135 | worker_info = torch.utils.data.get_worker_info() 136 | assert worker_info is None, "Rollout Buffer does not support worker parallelism." 
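        # Iteration protocol: until the collector has filled the buffer, yield an empty batch
        # (presumably a signal to keep collecting). Once full, compute advantages and returns and
        # yield shuffled minibatches, flipping `last_batch` to True on the final one so the caller
        # knows the buffer can be reset and refilled.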
137 | # Return Empty Batches if we are not full 138 | if not self.is_full: 139 | self._last_batch = True 140 | yield dict() 141 | return 142 | 143 | self.prepare_buffer() 144 | self._last_batch = False 145 | idxs = np.random.permutation(self._capacity - 2) + 1 # Add one offset for initial observation 146 | num_batches = math.ceil(len(idxs) / self.batch_size) 147 | for i in range(num_batches - 1): # Do up to the last 148 | batch_idxs = idxs[i * self.batch_size : (i + 1) * self.batch_size] 149 | yield self._get(batch_idxs) 150 | self._last_batch = True # Flag last batch 151 | last_idxs = idxs[(num_batches - 1) * self.batch_size :] 152 | yield self._get(last_idxs) 153 | -------------------------------------------------------------------------------- /research/algs/online/sac.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Any, Dict, Type 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from research.networks.base import ActorCriticPolicy 8 | 9 | from ..off_policy_algorithm import OffPolicyAlgorithm 10 | 11 | 12 | class SAC(OffPolicyAlgorithm): 13 | def __init__( 14 | self, 15 | *args, 16 | tau: float = 0.005, 17 | init_temperature: float = 0.1, 18 | critic_freq: int = 1, 19 | actor_freq: int = 1, 20 | target_freq: int = 2, 21 | bc_coeff=0.0, 22 | **kwargs, 23 | ): 24 | # Save values needed for network setup. 25 | self.init_temperature = init_temperature 26 | super().__init__(*args, **kwargs) 27 | assert isinstance(self.network, ActorCriticPolicy) 28 | 29 | # Save extra parameters 30 | self.tau = tau 31 | self.critic_freq = critic_freq 32 | self.actor_freq = actor_freq 33 | self.target_freq = target_freq 34 | self.bc_coeff = bc_coeff 35 | self.target_entropy = -np.prod(self.processor.action_space.low.shape) 36 | 37 | @property 38 | def alpha(self) -> torch.Tensor: 39 | return self.log_alpha.exp() 40 | 41 | def setup_network(self, network_class: Type[torch.nn.Module], network_kwargs: Dict) -> None: 42 | # Setup network and target network 43 | self.network = network_class( 44 | self.processor.observation_space, self.processor.action_space, **network_kwargs 45 | ).to(self.device) 46 | self.target_network = network_class( 47 | self.processor.observation_space, self.processor.action_space, **network_kwargs 48 | ).to(self.device) 49 | self.target_network.load_state_dict(self.network.state_dict()) 50 | for param in self.target_network.parameters(): 51 | param.requires_grad = False 52 | 53 | # Setup the log alpha 54 | log_alpha = torch.tensor(np.log(self.init_temperature), dtype=torch.float).to(self.device) 55 | self.log_alpha = torch.nn.Parameter(log_alpha, requires_grad=True) 56 | 57 | def setup_optimizers(self) -> None: 58 | # Default optimizer initialization 59 | self.optim["actor"] = self.optim_class(self.network.actor.parameters(), **self.optim_kwargs) 60 | # Update the encoder with the critic. 
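        # Design note: the encoder is optimized jointly with the critic only; the actor update uses
        # detached features (see _update_actor_and_alpha), the usual recipe in pixel-based SAC
        # variants so that representation learning is driven by the Q-loss.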
61 | critic_params = itertools.chain(self.network.critic.parameters(), self.network.encoder.parameters()) 62 | self.optim["critic"] = self.optim_class(critic_params, **self.optim_kwargs) 63 | self.optim["log_alpha"] = self.optim_class([self.log_alpha], **self.optim_kwargs) 64 | 65 | def _update_critic(self, batch: Dict) -> Dict: 66 | with torch.no_grad(): 67 | dist = self.network.actor(batch["next_obs"]) 68 | next_action = dist.rsample() 69 | log_prob = dist.log_prob(next_action) 70 | target_qs = self.target_network.critic(batch["next_obs"], next_action) 71 | target_v = torch.min(target_qs, dim=0)[0] - self.alpha.detach() * log_prob 72 | target_q = batch["reward"] + batch["discount"] * target_v 73 | 74 | qs = self.network.critic(batch["obs"], batch["action"]) 75 | q_loss = torch.nn.functional.mse_loss(qs, target_q.expand(qs.shape[0], -1), reduction="none").mean() 76 | 77 | self.optim["critic"].zero_grad(set_to_none=True) 78 | q_loss.backward() 79 | self.optim["critic"].step() 80 | 81 | return dict(q_loss=q_loss.item(), target_q=target_q.mean().item()) 82 | 83 | def _update_actor_and_alpha(self, batch: Dict) -> Dict: 84 | obs = batch["obs"].detach() # Detach the encoder so it isn't updated. 85 | dist = self.network.actor(obs) 86 | action = dist.rsample() 87 | log_prob = dist.log_prob(action) 88 | qs = self.network.critic(obs, action) 89 | q = torch.min(qs, dim=0)[0] 90 | 91 | actor_loss = (self.alpha.detach() * log_prob - q).mean() 92 | if self.bc_coeff > 0.0: 93 | bc_loss = -dist.log_prob(batch["action"]).mean() # Simple NLL loss. 94 | actor_loss = actor_loss + self.bc_coeff * bc_loss 95 | 96 | self.optim["actor"].zero_grad(set_to_none=True) 97 | actor_loss.backward() 98 | self.optim["actor"].step() 99 | entropy = -log_prob.mean() 100 | 101 | # Update the learned temperature 102 | self.optim["log_alpha"].zero_grad(set_to_none=True) 103 | alpha_loss = (self.alpha * (-log_prob - self.target_entropy).detach()).mean() 104 | alpha_loss.backward() 105 | self.optim["log_alpha"].step() 106 | 107 | return dict( 108 | actor_loss=actor_loss.item(), 109 | entropy=entropy.item(), 110 | alpha_loss=alpha_loss.item(), 111 | alpha=self.alpha.detach().item(), 112 | ) 113 | 114 | def train_step(self, batch: Dict, step: int, total_steps: int) -> Dict: 115 | all_metrics = {} 116 | 117 | if "obs" not in batch or step < self.random_steps: 118 | return all_metrics 119 | 120 | batch["obs"] = self.network.encoder(batch["obs"]) 121 | with torch.no_grad(): 122 | batch["next_obs"] = self.target_network.encoder(batch["next_obs"]) 123 | 124 | if step % self.critic_freq == 0: 125 | metrics = self._update_critic(batch) 126 | all_metrics.update(metrics) 127 | 128 | if step % self.actor_freq == 0: 129 | metrics = self._update_actor_and_alpha(batch) 130 | all_metrics.update(metrics) 131 | 132 | if step % self.target_freq == 0: 133 | # Only update the critic and encoder for speed. Ignore the actor. 
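            # Polyak averaging: target <- tau * online + (1 - tau) * target, with tau = 0.005 by
            # default, applied every `target_freq` steps.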
134 | with torch.no_grad(): 135 | for param, target_param in zip( 136 | self.network.encoder.parameters(), self.target_network.encoder.parameters() 137 | ): 138 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 139 | for param, target_param in zip( 140 | self.network.critic.parameters(), self.target_network.critic.parameters() 141 | ): 142 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 143 | 144 | return all_metrics 145 | 146 | def _get_train_action(self, obs: Any, step: int, total_steps: int) -> np.ndarray: 147 | batch = dict(obs=obs) 148 | with torch.no_grad(): 149 | action = self.predict(batch, is_batched=False, sample=True) 150 | return action 151 | -------------------------------------------------------------------------------- /research/algs/online/dqn.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Any, Dict, Type 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from ..off_policy_algorithm import OffPolicyAlgorithm 8 | 9 | 10 | class DQN(OffPolicyAlgorithm): 11 | def __init__( 12 | self, 13 | *args, 14 | target_freq: int = 1000, 15 | tau: float = 1.0, 16 | max_grad_norm: float = 10, 17 | eps_start: float = 1.0, 18 | eps_end: float = 0.05, 19 | eps_frac: float = 0.1, 20 | loss: str = "huber", 21 | **kwargs, 22 | ): 23 | super().__init__(*args, **kwargs) 24 | # Save extra parameters 25 | self.tau = tau 26 | self.target_freq = target_freq 27 | self.max_grad_norm = max_grad_norm 28 | self.eps_start = eps_start 29 | self.eps_end = eps_end 30 | self.eps_frac = eps_frac 31 | self.loss = self._get_loss(loss) 32 | 33 | def _get_loss(self, loss: str): 34 | if loss == "mse": 35 | return torch.nn.MSELoss() 36 | elif loss == "huber": 37 | return torch.nn.SmoothL1Loss() 38 | else: 39 | raise ValueError("Invalid loss specification") 40 | 41 | def setup_network(self, network_class: Type[torch.nn.Module], network_kwargs: Dict) -> None: 42 | self.network = network_class( 43 | self.processor.observation_space, self.processor.action_space, **network_kwargs 44 | ).to(self.device) 45 | self.target_network = network_class( 46 | self.processor.observation_space, self.processor.action_space, **network_kwargs 47 | ).to(self.device) 48 | self.target_network.load_state_dict(self.network.state_dict()) 49 | for param in self.target_network.parameters(): 50 | param.requires_grad = False 51 | 52 | def _get_train_action(self, obs: Any, step: int, total_steps: int) -> np.ndarray: 53 | if self.eps_frac > 0: 54 | frac = min(1.0, step / (total_steps * self.eps_frac)) 55 | eps = (1 - frac) * self.eps_start + frac * self.eps_end 56 | else: 57 | eps = 0.0 58 | 59 | if random.random() < eps: 60 | action = self.action_space.sample() 61 | else: 62 | with torch.no_grad(): 63 | action = self.predict(dict(obs=obs), sample=False) 64 | return action 65 | 66 | def _compute_value(self, batch: Any) -> torch.Tensor: 67 | next_q = self.target_network(batch["next_obs"]) 68 | next_v, _ = next_q.max(dim=-1) 69 | return next_v 70 | 71 | def train_step(self, batch: Any, step: int, total_steps: int) -> Dict: 72 | all_metrics = {} 73 | 74 | if step < self.random_steps or "obs" not in batch: 75 | return all_metrics 76 | 77 | # Update the agent 78 | with torch.no_grad(): 79 | next_v = self._compute_value(batch) 80 | target_q = batch["reward"] + batch["discount"] * next_v 81 | 82 | q = self.network(batch["obs"]) 83 | q = torch.gather(q, dim=-1, 
index=batch["action"].long().unsqueeze(-1)).squeeze(-1) 84 | loss = self.loss(q, target_q) 85 | 86 | self.optim["network"].zero_grad(set_to_none=True) 87 | loss.backward() 88 | torch.nn.utils.clip_grad_norm_(self.network.parameters(), self.max_grad_norm) 89 | self.optim["network"].step() 90 | 91 | all_metrics["q_loss"] = loss.item() 92 | all_metrics["target_q"] = target_q.mean().item() 93 | 94 | if step % self.target_freq == 0: 95 | with torch.no_grad(): 96 | for param, target_param in zip(self.network.parameters(), self.target_network.parameters()): 97 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 98 | 99 | return all_metrics 100 | 101 | def _validation_step(self, batch: Any): 102 | raise NotImplementedError("RL Algorithm does not have a validation dataset.") 103 | 104 | 105 | class DoubleDQN(DQN): 106 | def _compute_value(self, batch: Any) -> torch.Tensor: 107 | next_a = self.network(batch["next_obs"]).argmax(dim=-1) 108 | next_q = self.target_network(batch["next_obs"]) 109 | next_v = torch.gather(next_q, dim=-1, index=next_a.unsqueeze(-1)).squeeze(-1) 110 | return next_v 111 | 112 | 113 | class SoftDQN(DQN): 114 | def __init__(self, *args, exploration_alpha=0.01, target_alpha=0.1, **kwargs): 115 | super().__init__(*args, **kwargs) 116 | self.exploration_alpha = exploration_alpha 117 | self.target_alpha = target_alpha 118 | 119 | def _get_train_action(self, obs: Any, step: int, total_steps: int) -> np.ndarray: 120 | if self.eps_frac > 0: 121 | frac = min(1.0, step / (total_steps * self.eps_frac)) 122 | eps = (1 - frac) * self.eps_start + frac * self.eps_end 123 | else: 124 | eps = 0.0 125 | 126 | if random.random() < eps: 127 | action = self.action_space.sample() 128 | else: 129 | with torch.no_grad(): 130 | action = self.predict(dict(obs=obs), sample=True, temperature=self.exploration_alpha) 131 | return action 132 | 133 | def _compute_value(self, batch: Any) -> torch.Tensor: 134 | next_q = self.target_network(batch["next_obs"]) 135 | next_v = self.target_alpha * torch.logsumexp(next_q / self.target_alpha, dim=-1) 136 | return next_v 137 | 138 | 139 | class SoftDoubleDQN(DQN): 140 | def __init__(self, *args, exploration_alpha=0.01, target_alpha=0.1, **kwargs): 141 | super().__init__(*args, **kwargs) 142 | self.exploration_alpha = exploration_alpha 143 | self.target_alpha = target_alpha 144 | 145 | def _get_train_action(self, obs: Any, step: int, total_steps: int) -> np.ndarray: 146 | if self.eps_frac > 0: 147 | frac = min(1.0, step / (total_steps * self.eps_frac)) 148 | eps = (1 - frac) * self.eps_start + frac * self.eps_end 149 | else: 150 | eps = 0.0 151 | 152 | if random.random() < eps: 153 | action = self.action_space.sample() 154 | else: 155 | with torch.no_grad(): 156 | action = self.predict(dict(obs=obs), sample=True, temperature=self.exploration_alpha) 157 | return action 158 | 159 | def _compute_value(self, batch: Any) -> torch.Tensor: 160 | log_pi = torch.nn.functional.log_softmax(self.network(batch["next_obs"]), dim=-1) 161 | next_q = self.target_network(batch["next_obs"]) 162 | next_v = self.target_alpha * torch.logsumexp(next_q / self.target_alpha + log_pi, dim=-1) 163 | return next_v 164 | -------------------------------------------------------------------------------- /scripts/create_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import os 4 | import time 5 | 6 | import gym 7 | import numpy as np 8 | 9 | from research.datasets import 
ReplayBuffer 10 | from research.utils.config import Config 11 | from research.utils.evaluate import EvalMetricTracker 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--path", type=str, required=True) 16 | parser.add_argument("--num-ep", type=int, default=np.inf) 17 | parser.add_argument("--num-steps", type=int, default=np.inf) 18 | parser.add_argument( 19 | "--shard", action="store_true", default=False, help="Whether or not to shard the dataset into episodes." 20 | ) 21 | parser.add_argument("--noise", type=float, default=0.0, help="Gaussian noise std.") 22 | parser.add_argument("--random-percent", type=float, default=0.0, help="percent of dataset to be purely random.") 23 | parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint to load") 24 | parser.add_argument("--device", type=str, default="auto") 25 | parser.add_argument( 26 | "--override", 27 | metavar="KEY=VALUE", 28 | nargs="+", 29 | default=[], 30 | help="Set kv pairs used as args for the entry point script.", 31 | ) 32 | args = parser.parse_args() 33 | 34 | assert not (args.num_steps == np.inf and args.num_ep == np.inf), "Must set one of num-steps and num-ep" 35 | assert not (args.num_steps != np.inf and args.num_ep != np.inf), "Cannot set both num-steps and num-ep" 36 | assert args.random_percent <= 1.0 and args.random_percent >= 0.0, "Invalid random-percent" 37 | 38 | if os.path.exists(args.path): 39 | print("[research] Warning: saving dataset to an existing directory.") 40 | os.makedirs(args.path, exist_ok=True) 41 | 42 | # Load the config 43 | config = Config.load(os.path.dirname(args.checkpoint) if args.checkpoint.endswith(".pt") else args.checkpoint) 44 | config["checkpoint"] = None # Set checkpoint to None 45 | 46 | # Overrides 47 | print("Overrides:") 48 | for override in args.override: 49 | print(override) 50 | 51 | for override in args.override: 52 | items = override.split("=") 53 | key, value = items[0].strip(), "=".join(items[1:]) 54 | # Progress down the config path (seperated by '.') until we reach the final value to override. 55 | config_path = key.split(".") 56 | config_dict = config 57 | while len(config_path) > 1: 58 | config_dict = config_dict[config_path[0]] 59 | config_path.pop(0) 60 | config_dict[config_path[0]] = value 61 | 62 | # Parse the config 63 | config = config.parse() 64 | 65 | # Get the environment 66 | env = config.get_train_env_fn()() 67 | if env is None: 68 | env = config.get_eval_env_fn()() 69 | 70 | if args.random_percent < 1.0: 71 | assert args.checkpoint.endswith(".pt"), "Did not specify checkpoint file." 72 | model = config.get_model( 73 | observation_space=env.observation_space, action_space=env.action_space, device=args.device 74 | ) 75 | metadata = model.load(args.checkpoint) 76 | else: 77 | model = None 78 | 79 | # Calculate the replay buffer capacity 80 | if args.num_ep < np.inf and not args.shard: 81 | try: 82 | max_ep_steps = env._max_episode_steps 83 | except AttributeError: 84 | max_ep_steps = env.unwrapped._max_episode_steps 85 | capacity = (max_ep_steps + 2) * args.num_ep 86 | elif not args.shard: 87 | capacity = args.num_steps 88 | else: 89 | capacity = 2 90 | 91 | # Init the replay buffer. 
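    # Capacity note: (max_ep_steps + 2) * num_ep leaves headroom for the extra reset-time entry
    # stored per episode on top of the per-step transitions; when sharding, the buffer is kept tiny
    # and is presumably flushed out per episode via the cleanup/distributed flags.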
92 | replay_buffer = ReplayBuffer( 93 | env.observation_space, env.action_space, capacity=capacity, cleanup=not args.shard, distributed=args.shard 94 | ) 95 | 96 | print(replay_buffer._storage["done"].dtype) 97 | 98 | # Track data collection 99 | num_steps = 0 100 | num_ep = 0 101 | finished_data_collection = False 102 | # Episode metrics 103 | metric_tracker = EvalMetricTracker() 104 | start_time = time.time() 105 | 106 | while not finished_data_collection: 107 | # Determine if we should use random actions or not. 108 | progress = num_ep / args.num_ep if args.num_ep != np.inf else num_steps / args.num_steps 109 | use_random_actions = progress < args.random_percent 110 | 111 | # Collect an episode 112 | done = False 113 | ep_length = 0 114 | obs = env.reset() 115 | metric_tracker.reset() 116 | replay_buffer.add(obs=obs) 117 | while not done: 118 | if use_random_actions: 119 | action = env.action_space.sample() 120 | else: 121 | action = model.predict(dict(obs=obs)) 122 | if args.noise > 0: 123 | assert isinstance(env.action_space, gym.spaces.Box) 124 | action = action + args.noise * np.random.randn(*action.shape) 125 | # Step the environment with the predicted action 126 | env_action = np.clip(action, env.action_space.low, env.action_space.high) 127 | 128 | obs, reward, done, info = env.step(action) 129 | metric_tracker.step(reward, info) 130 | ep_length += 1 131 | 132 | # Determine the discount factor. 133 | if "discount" in info: 134 | discount = info["discount"] 135 | elif hasattr(env, "_max_episode_steps") and ep_length == env._max_episode_steps: 136 | discount = 1.0 137 | else: 138 | discount = 1 - float(done) 139 | 140 | # Store the consequences. 141 | replay_buffer.add(obs=obs, action=action, reward=reward, done=done, discount=discount) 142 | num_steps += 1 143 | 144 | num_ep += 1 145 | # Determine if we should stop data collection 146 | finished_data_collection = num_steps >= args.num_steps or num_ep >= args.num_ep 147 | 148 | end_time = time.time() 149 | print("Finished", num_ep, "episodes in", num_steps, "steps.") 150 | print("It took", (end_time - start_time) / num_steps, "seconds per step") 151 | 152 | replay_buffer.save(args.path) 153 | 154 | # Write the metrics 155 | metrics = metric_tracker.export() 156 | dt = datetime.datetime.now().strftime("%Y%m%dT%H%M%S") 157 | print("Metrics:") 158 | print(metrics) 159 | with open(os.path.join(args.path, "metrics.txt"), "a") as f: 160 | f.write("Collected data: " + str(dt) + "\n") 161 | for k, v in metrics.items(): 162 | f.write(k + ": " + str(v) + "\n") 163 | -------------------------------------------------------------------------------- /research/algs/offline/iql.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Any, Dict, Optional, Type 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from research.networks.base import ActorCriticValuePolicy 8 | 9 | from ..off_policy_algorithm import OffPolicyAlgorithm 10 | 11 | 12 | def iql_loss(pred, target, expectile=0.5): 13 | err = target - pred 14 | weight = torch.abs(expectile - (err < 0).float()) 15 | return weight * torch.square(err) 16 | 17 | 18 | class IQL(OffPolicyAlgorithm): 19 | def __init__( 20 | self, 21 | *args, 22 | tau: float = 0.005, 23 | target_freq: int = 1, 24 | expectile: Optional[float] = None, 25 | beta: float = 1, 26 | clip_score: float = 100.0, 27 | **kwargs, 28 | ) -> None: 29 | super().__init__(*args, **kwargs) 30 | self.tau = tau 31 | self.target_freq = target_freq 32 | self.expectile = 
expectile 33 | self.beta = beta 34 | self.clip_score = clip_score 35 | assert isinstance(self.network, ActorCriticValuePolicy) 36 | 37 | def setup_network(self, network_class: Type[torch.nn.Module], network_kwargs: Dict) -> None: 38 | self.network = network_class( 39 | self.processor.observation_space, self.processor.action_space, **network_kwargs 40 | ).to(self.device) 41 | self.target_network = network_class( 42 | self.processor.observation_space, self.processor.action_space, **network_kwargs 43 | ).to(self.device) 44 | self.target_network.load_state_dict(self.network.state_dict()) 45 | for param in self.target_network.parameters(): 46 | param.requires_grad = False 47 | 48 | def setup_optimizers(self) -> None: 49 | actor_params = itertools.chain(self.network.actor.parameters(), self.network.encoder.parameters()) 50 | self.optim["actor"] = self.optim_class(actor_params, **self.optim_kwargs) 51 | self.optim["critic"] = self.optim_class(self.network.critic.parameters(), **self.optim_kwargs) 52 | self.optim["value"] = self.optim_class(self.network.value.parameters(), **self.optim_kwargs) 53 | 54 | def train_step(self, batch: Dict, step: int, total_steps: int) -> Dict: 55 | # We use the online encoder for everything in this IQL implementation 56 | # That is because we need to use the current obs for the target critic and online value. 57 | # This is done by default in DrQv2. 58 | with torch.no_grad(): 59 | batch["next_obs"] = self.network.encoder(batch["next_obs"]) 60 | batch["obs"] = self.network.encoder(batch["obs"]) 61 | 62 | # First compute the value loss 63 | with torch.no_grad(): 64 | target_q = self.target_network.critic(batch["obs"], batch["action"]) 65 | target_q = torch.min(target_q, dim=0)[0] 66 | vs = self.network.value(batch["obs"].detach()) # Always detach for value learning 67 | v_loss = iql_loss(vs, target_q.expand(vs.shape[0], -1), self.expectile).mean() 68 | 69 | self.optim["value"].zero_grad(set_to_none=True) 70 | v_loss.backward() 71 | self.optim["value"].step() 72 | 73 | # Next, compute the critic loss 74 | with torch.no_grad(): 75 | next_vs = self.network.value(batch["next_obs"]) 76 | next_v = torch.min(next_vs, dim=0)[0] 77 | target = batch["reward"] + batch["discount"] * next_v 78 | qs = self.network.critic(batch["obs"].detach(), batch["action"]) 79 | q_loss = torch.nn.functional.mse_loss(qs, target.expand(qs.shape[0], -1), reduction="none").mean() 80 | 81 | self.optim["critic"].zero_grad(set_to_none=True) 82 | q_loss.backward() 83 | self.optim["critic"].step() 84 | 85 | # Next, update the actor. We detach and use the old value, v for computational efficiency 86 | # though the JAX IQL recomputes it, while Pytorch IQL versions do not. 87 | with torch.no_grad(): 88 | adv = target_q - torch.min(vs, dim=0)[0] 89 | exp_adv = torch.exp(adv / self.beta) 90 | if self.clip_score is not None: 91 | exp_adv = torch.clamp(exp_adv, max=self.clip_score) 92 | 93 | dist = self.network.actor(batch["obs"]) 94 | if isinstance(dist, torch.distributions.Distribution): 95 | bc_loss = -dist.log_prob(batch["action"]) 96 | elif torch.is_tensor(dist): 97 | assert dist.shape == batch["action"].shape 98 | bc_loss = torch.nn.functional.mse_loss(dist, batch["action"], reduction="none").sum(dim=-1) 99 | else: 100 | raise ValueError("Invalid policy output provided") 101 | actor_loss = (exp_adv * bc_loss).mean() 102 | 103 | # Update the networks. These are done in a stack to support different grad options for the encoder. 
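        # Advantage-weighted regression step: the weights are w = exp((Q(s, a) - V(s)) / beta),
        # clamped at clip_score, and the actor minimizes E[w * BC], where BC is the negative
        # log-likelihood for distributional policies or a squared error for deterministic ones.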
104 | self.optim["actor"].zero_grad(set_to_none=True) 105 | actor_loss.backward() 106 | self.optim["actor"].step() 107 | 108 | if step % self.target_freq == 0: 109 | with torch.no_grad(): 110 | # Only run on the critic and encoder, those are the only weights we update. 111 | for param, target_param in zip( 112 | self.network.critic.parameters(), self.target_network.critic.parameters() 113 | ): 114 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 115 | for param, target_param in zip( 116 | self.network.encoder.parameters(), self.target_network.encoder.parameters() 117 | ): 118 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 119 | 120 | return dict( 121 | q_loss=q_loss.item(), 122 | v_loss=v_loss.item(), 123 | actor_loss=actor_loss.item(), 124 | v=vs.mean().item(), 125 | q=qs.mean().item(), 126 | advantage=adv.mean().item(), 127 | ) 128 | 129 | def _predict(self, batch: Dict, sample: bool = False, noise: float = 0.0) -> torch.Tensor: 130 | with torch.no_grad(): 131 | z = self.network.encoder(batch["obs"]) 132 | dist = self.network.actor(z) 133 | if isinstance(dist, torch.distributions.Distribution): 134 | action = dist.sample() if sample else dist.base_dist.loc 135 | elif torch.is_tensor(dist): 136 | action = dist 137 | else: 138 | raise ValueError("Invalid policy output") 139 | if noise > 0.0: 140 | action = action + noise * torch.randn_like(action) 141 | action = action.clamp(*self.action_range) 142 | return action 143 | 144 | def _get_train_action(self, obs: Any, step: int, total_steps: int) -> np.ndarray: 145 | batch = dict(obs=obs) 146 | with torch.no_grad(): 147 | action = self.predict(batch, is_batched=False, sample=True) 148 | return action 149 | -------------------------------------------------------------------------------- /research/algs/online/drqv2.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Any, Dict, Tuple, Type, Union 3 | 4 | import gym 5 | import numpy as np 6 | import torch 7 | 8 | from research.networks.base import ActorCriticPolicy 9 | 10 | from ..off_policy_algorithm import OffPolicyAlgorithm 11 | 12 | """ 13 | Note: this implementation is untested! 14 | """ 15 | 16 | 17 | class TruncatedNormal(torch.distributions.Normal): 18 | def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6): 19 | super().__init__(loc, scale, validate_args=False) 20 | self.low = low 21 | self.high = high 22 | self.eps = eps 23 | 24 | def _clamp(self, x): 25 | clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps) 26 | x = x - x.detach() + clamped_x.detach() 27 | return x 28 | 29 | def sample(self, clip=None, sample_shape=None): 30 | shape = self._extended_shape(torch.Size() if sample_shape is None else sample_shape) 31 | eps = torch.distributions.utils._standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) 32 | eps *= self.scale 33 | if clip is not None: 34 | eps = torch.clamp(eps, -clip, clip) 35 | x = self.loc + eps 36 | return self._clamp(x) 37 | 38 | 39 | class DRQV2(OffPolicyAlgorithm): 40 | """ 41 | NOTE: DrQv2 implementation is untested and not verified yet. 42 | Please do not use this implementation for baseline comparisons. 
43 | """ 44 | 45 | def __init__( 46 | self, 47 | env: gym.Env, 48 | network_class: Type[torch.nn.Module], 49 | dataset_class: Union[Type[torch.utils.data.IterableDataset], Type[torch.utils.data.Dataset]], 50 | tau: float = 0.005, 51 | critic_freq: int = 1, 52 | actor_freq: int = 1, 53 | target_freq: int = 1, 54 | init_steps: int = 1000, 55 | std_schedule: Tuple[float, float, int] = (1.0, 0.1, 500000), 56 | noise_clip: float = 0.3, 57 | **kwargs, 58 | ): 59 | super().__init__(env, network_class, dataset_class, **kwargs) 60 | assert isinstance(self.network, ActorCriticPolicy) 61 | # Save extra parameters 62 | self.tau = tau 63 | self.critic_freq = critic_freq 64 | self.actor_freq = actor_freq 65 | self.target_freq = target_freq 66 | self.std_schedule = std_schedule 67 | self.init_steps = init_steps 68 | self.noise_clip = noise_clip 69 | 70 | def setup_network(self, network_class: Type[torch.nn.Module], network_kwargs: Dict) -> None: 71 | self.network = network_class( 72 | self.processor.observation_space, self.processor.action_space, **network_kwargs 73 | ).to(self.device) 74 | self.target_network = network_class( 75 | self.processor.observation_space, self.processor.action_space, **network_kwargs 76 | ).to(self.device) 77 | self.target_network.load_state_dict(self.network.state_dict()) 78 | for param in self.target_network.parameters(): 79 | param.requires_grad = False 80 | 81 | def setup_optimizers(self) -> None: 82 | # Default optimizer initialization 83 | self.optim["actor"] = self.optim_class(self.network.actor.parameters(), **self.optim_kwargs) 84 | # Update the encoder with the critic. 85 | critic_params = itertools.chain(self.network.critic.parameters(), self.network.encoder.parameters()) 86 | self.optim["critic"] = self.optim_class(critic_params, **self.optim_kwargs) 87 | 88 | def _get_std(self, step: int): 89 | init, final, duration = self.std_schedule 90 | mix = np.clip(step / duration, 0.0, 1.0) 91 | std = (1.0 - mix) * init + mix * final 92 | return std 93 | 94 | def _update_critic(self, batch: Dict, step: int) -> Dict: 95 | with torch.no_grad(): 96 | mu = self.network.actor(batch["next_obs"]) 97 | std = self._get_std(step) * torch.ones_like(mu) 98 | next_action = TruncatedNormal(mu, std).sample(clip=self.noise_clip) 99 | target_qs = self.target_network.critic(batch["next_obs"], next_action) 100 | target_v = torch.min(target_qs, dim=0)[0] 101 | target_q = batch["reward"] + batch["discount"] * target_v 102 | 103 | qs = self.network.critic(batch["obs"], batch["action"]) 104 | q_loss = ( 105 | torch.nn.functional.mse_loss(qs, target_q.expand(qs.shape[0], -1), reduction="none").mean(dim=-1).sum() 106 | ) # averages over the ensemble. No for loop! 107 | 108 | self.optim["critic"].zero_grad(set_to_none=True) 109 | q_loss.backward() 110 | self.optim["critic"].step() 111 | 112 | return dict(q_loss=q_loss.item(), target_q=target_q.mean().item()) 113 | 114 | def _update_actor(self, batch: Dict, step: int) -> Dict: 115 | obs = batch["obs"].detach() # Detach the encoder so it isn't updated. 
116 | mu = self.network.actor(obs) 117 | std = self._get_std(step) * torch.ones_like(mu) 118 | dist = TruncatedNormal(mu, std) 119 | action = dist.sample(clip=self.noise_clip) 120 | log_prob = dist.log_prob(action) 121 | 122 | q1, q2 = self.network.critic(obs, action) 123 | q = torch.min(q1, q2) 124 | actor_loss = -q.mean() 125 | 126 | self.optim["actor"].zero_grad(set_to_none=True) 127 | actor_loss.backward() 128 | self.optim["actor"].step() 129 | 130 | return dict(actor_loss=actor_loss.item(), log_prob=log_prob.mean().item()) 131 | 132 | def _get_train_action(self, obs: Any, step: int, total_steps: int) -> np.ndarray: 133 | batch = dict(obs=obs) 134 | with torch.no_grad(): 135 | action = self.predict(batch, noise=self._get_std(step), noise_clip=None) 136 | action = np.clip(action, self.processor.action_space.low + 1e-6, self.processor.action_space.high - 1e-6) 137 | return action 138 | 139 | def train_step(self, batch: Dict, step: int, total_steps: int) -> Dict: 140 | all_metrics = {} 141 | 142 | if "obs" not in batch or step < self.random_steps: 143 | return all_metrics 144 | 145 | batch["obs"] = self.network.encoder(batch["obs"]) 146 | with torch.no_grad(): 147 | batch["next_obs"] = self.network.encoder(batch["next_obs"]) 148 | 149 | if step % self.critic_freq == 0: 150 | metrics = self._update_critic(batch, step) 151 | all_metrics.update(metrics) 152 | 153 | if step % self.actor_freq == 0: 154 | metrics = self._update_actor(batch, step) 155 | all_metrics.update(metrics) 156 | 157 | if step % self.target_freq == 0: 158 | # Only update the critic for speed. Ignore the actor. 159 | with torch.no_grad(): 160 | for param, target_param in zip( 161 | self.network.critic.parameters(), self.target_network.critic.parameters() 162 | ): 163 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 164 | 165 | return all_metrics 166 | -------------------------------------------------------------------------------- /research/networks/transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Taken from nanoGPT by Andrej Karpathy 3 | https://github.com/karpathy/nanoGPT/blob/master/model.py 4 | """ 5 | import math 6 | 7 | import gym 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import functional as F 12 | 13 | 14 | def transformer_weight_init(module: nn.Module): 15 | if isinstance(module, nn.Linear): 16 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 17 | if module.bias is not None: 18 | torch.nn.init.zeros_(module.bias) 19 | elif isinstance(module, nn.Embedding): 20 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 21 | 22 | 23 | class LayerNorm(nn.Module): 24 | """LayerNorm but with an optional bias. 
PyTorch doesn't support simply bias=False""" 25 | 26 | def __init__(self, n_embd, bias, eps=1e-5): 27 | super().__init__() 28 | self.weight = nn.Parameter(torch.ones(n_embd)) 29 | self.bias = nn.Parameter(torch.zeros(n_embd)) if bias else None 30 | self.eps = eps 31 | 32 | def forward(self, input): 33 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, self.eps) 34 | 35 | 36 | class MLP(nn.Module): 37 | def __init__(self, n_embd=128, dropout=0.1, dense_multiplier=4, bias=True): 38 | super().__init__() 39 | self.c_fc = nn.Linear(n_embd, int(dense_multiplier * n_embd), bias=bias) 40 | self.gelu = nn.GELU() 41 | self.c_proj = nn.Linear(4 * n_embd, n_embd, bias=bias) 42 | self.dropout = nn.Dropout(dropout) 43 | 44 | def forward(self, x): 45 | x = self.c_fc(x) 46 | x = self.gelu(x) 47 | x = self.c_proj(x) 48 | x = self.dropout(x) 49 | return x 50 | 51 | 52 | class SelfAttention(nn.Module): 53 | def __init__( 54 | self, n_embd: int = 128, n_head: int = 4, dropout: float = 0.1, bias: bool = True, causal: bool = True 55 | ): 56 | super().__init__() 57 | assert n_embd % n_head == 0 58 | # key, query, value projections for all heads, but in a batch 59 | self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias=bias) 60 | # output projection 61 | self.c_proj = nn.Linear(n_embd, n_embd, bias=bias) 62 | # regularization 63 | self.attn_dropout = nn.Dropout(dropout) 64 | self.resid_dropout = nn.Dropout(dropout) 65 | self.n_head = n_head 66 | self.n_embd = n_embd 67 | self.dropout = dropout 68 | # Causal 69 | self.causal = causal 70 | 71 | def forward(self, x, attn_mask=None): 72 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 73 | 74 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 75 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 76 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 77 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 78 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 79 | 80 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 81 | if attn_mask is not None: 82 | assert attn_mask.dtype == torch.bool 83 | attn_mask = attn_mask.unsqueeze(1) # Expand for attention heads. 
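        # Efficient attention kernel: computes softmax(q @ k^T / sqrt(head_dim) + mask) @ v per head.
        # attn_mask broadcasts over heads as (B, 1, T, T), while is_causal applies a lower-triangular
        # mask; PyTorch errors if both an explicit attn_mask and is_causal=True are supplied, so only
        # one of the two should be set for a given call.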
84 | y = torch.nn.functional.scaled_dot_product_attention( 85 | q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0, is_causal=self.causal 86 | ) 87 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 88 | # output projection 89 | y = self.resid_dropout(self.c_proj(y)) 90 | return y 91 | 92 | 93 | class TransformerBlock(nn.Module): 94 | def __init__( 95 | self, 96 | n_embd: int = 128, 97 | n_head: int = 4, 98 | dropout: float = 0.1, 99 | dense_multiplier: int = 4, 100 | bias: bool = False, 101 | causal: bool = True, 102 | eps: float = 1e-5, 103 | ): 104 | super().__init__() 105 | self.ln_1 = LayerNorm(n_embd, bias=bias, eps=eps) 106 | self.attn = SelfAttention(n_embd=n_embd, n_head=n_head, dropout=dropout, bias=bias, causal=causal) 107 | self.ln_2 = LayerNorm(n_embd, bias=bias) 108 | self.mlp = MLP(n_embd=n_embd, dropout=dropout, dense_multiplier=dense_multiplier, bias=bias) 109 | 110 | def forward(self, x, attn_mask=None): 111 | x = x + self.attn(self.ln_1(x), attn_mask=attn_mask) 112 | x = x + self.mlp(self.ln_2(x)) 113 | return x 114 | 115 | 116 | class TransformerEncoder(nn.Module): 117 | def __init__( 118 | self, 119 | n_embd: int = 128, 120 | n_head: int = 4, 121 | n_layer: int = 2, 122 | dropout: float = 0.1, 123 | dense_multiplier: int = 4, 124 | bias: bool = False, 125 | causal: bool = True, 126 | eps: float = 1e-5, 127 | block_size: int = 128, 128 | ): 129 | super().__init__() 130 | self.n_embd = n_embd 131 | self.block_size = block_size 132 | self.pos_embedding = nn.Embedding(block_size, n_embd) 133 | self.dropout = nn.Dropout(dropout) 134 | self.blocks = nn.ModuleList( 135 | [ 136 | TransformerBlock( 137 | n_embd=n_embd, 138 | n_head=n_head, 139 | dropout=dropout, 140 | dense_multiplier=dense_multiplier, 141 | bias=bias, 142 | causal=causal, 143 | eps=eps, 144 | ) 145 | for _ in range(n_layer) 146 | ] 147 | ) 148 | self.layer_norm = LayerNorm(n_embd, bias=bias, eps=eps) 149 | 150 | self.apply(transformer_weight_init) 151 | # apply special scaled init to the residual projections, per GPT-2 paper 152 | for pn, p in self.named_parameters(): 153 | if pn.endswith("c_proj.weight"): 154 | torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * n_layer)) 155 | 156 | def forward(self, x, attn_mask=None): 157 | assert len(x.shape) == 3 158 | assert x.shape[1] <= self.block_size, "Insufficient block size." 
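        # The positional embedding is a learned table with block_size rows, so sequences longer than
        # block_size cannot be encoded; shorter sequences simply index the first T rows of the table.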
159 | pos_idxs = torch.arange(0, x.shape[1], device=x.device, dtype=torch.long) 160 | x = x + self.pos_embedding(pos_idxs) 161 | x = self.dropout(x) 162 | for block in self.blocks: 163 | x = block(x, attn_mask=attn_mask) 164 | x = self.layer_norm(x) 165 | return x 166 | 167 | 168 | class StateTransformerEncoder(nn.Module): 169 | def __init__(self, observation_space: gym.Space, action_space: gym.Space, n_embd=128, bias=True, **kwargs): 170 | super().__init__() 171 | assert isinstance(observation_space, gym.spaces.Box) and len(observation_space.shape) == 1 172 | self.n_embd = n_embd 173 | self.transformer = TransformerEncoder(n_embd=n_embd, bias=bias, **kwargs) 174 | self.token_ln = LayerNorm(self.n_embd, bias=bias) 175 | self.obs_embedding = nn.Linear(observation_space.shape[0], n_embd) 176 | nn.init.normal_(self.obs_embedding.weight, mean=0.0, std=0.02) 177 | 178 | @property 179 | def output_space(self): 180 | return gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.n_embd,), dtype=np.float32) 181 | 182 | def forward(self, obs, mask=None): 183 | assert len(obs.shape) == 3 184 | return self.transformer(self.token_ln(self.obs_embedding(obs)), attn_mask=mask) 185 | -------------------------------------------------------------------------------- /research/processors/image_augmentation.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from typing import Dict, Optional, Tuple, Union 3 | 4 | import gym 5 | import numpy as np 6 | import torch 7 | import torchvision 8 | from torch.nn import functional as F 9 | 10 | from .base import Processor 11 | 12 | 13 | def is_image_space(space): 14 | shape = space.shape 15 | is_image_space = (len(shape) == 3 or len(shape) == 4) and space.dtype == np.uint8 16 | return is_image_space 17 | 18 | 19 | def modify_space_hw(space, h, w, dtype=None): 20 | if isinstance(space, gym.spaces.Box) and is_image_space(space): 21 | shape = list(space.shape) 22 | shape[-2] = h 23 | shape[-1] = w 24 | dtype = space.dtype if dtype is None else dtype 25 | return gym.spaces.Box(low=0, high=255, shape=shape, dtype=dtype) 26 | elif isinstance(space, gym.spaces.Dict): 27 | return gym.spaces.Dict({k: modify_space_hw(v, h, w, dtype=dtype) for k, v in space.items()}) 28 | else: 29 | return space 30 | 31 | 32 | class RandomCrop(Processor): 33 | def __init__( 34 | self, 35 | observation_space: gym.Space, 36 | action_space: gym.Space, 37 | size: Optional[Tuple[int, int]] = None, 38 | pad: Union[int, Tuple[int, int]] = 4, 39 | consistent: bool = True, 40 | ) -> None: 41 | super().__init__(observation_space, action_space) 42 | self.consistent = consistent 43 | 44 | # Get the image keys and sequence lengths 45 | if isinstance(observation_space, gym.spaces.Box): 46 | assert is_image_space(observation_space) 47 | self.is_sequence = len(observation_space.shape) == 4 48 | self.in_h, self.in_w = observation_space.shape[-2], observation_space.shape[-1] 49 | self.image_keys = None 50 | elif isinstance(observation_space, gym.spaces.Dict): 51 | image_keys = [] 52 | sequence = [] 53 | hs, ws = [], [] 54 | for k, v in observation_space.items(): 55 | if is_image_space(v): 56 | image_keys.append(k) 57 | if len(v.shape) == 4: 58 | sequence.append(v.shape[0]) # Append the sequence dim 59 | else: 60 | sequence.append(0) 61 | ws.append(v.shape[-1]) 62 | hs.append(v.shape[-2]) 63 | assert all(sequence) or (not any(sequence)), "All image keys must be sequence or not" 64 | assert all([h == hs[0] for h in hs]) 65 | assert all([w == ws[0] for w in ws]) 
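            # All image keys must share the same spatial size (and all be sequences or all be single
            # frames) because they are stacked together and cropped with the same base grid and, when
            # consistent=True, with the same random offsets per batch element.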
66 | self.in_h, self.in_w = hs[0], ws[0] 67 | self.is_sequence = sequence[0] 68 | self.image_keys = image_keys 69 | else: 70 | raise ValueError("Invalid observation space specified") 71 | 72 | # Save output sizes 73 | if size is None: 74 | self.out_h, self.out_w = self.in_h, self.in_w 75 | else: 76 | self.out_h, self.out_w = size 77 | assert self.out_h <= self.in_h and self.out_w <= self.in_w 78 | 79 | self.pad = (pad, pad) if isinstance(pad, int) else pad 80 | self.padding = [self.pad[0], self.pad[0], self.pad[1], self.pad[1]] 81 | self.do_pad = self.pad[0] > 0 or self.pad[1] > 0 82 | 83 | # Save intermediate sizes 84 | self.middle_h, self.middle_w = self.in_h + 2 * self.pad[0], self.in_w + 2 * self.pad[1] 85 | 86 | self.is_square = self.in_h == self.in_w 87 | if self.is_square: 88 | assert self.out_h == self.out_w, "Must use square output on square images for acceleration" 89 | assert self.pad[0] == self.pad[1], "Must use uniform pad with square images" 90 | 91 | eps_h = 1.0 / (self.middle_h) 92 | eps_w = 1.0 / (self.middle_w) 93 | 94 | grid_h = torch.linspace(-1.0 + eps_h, 1.0 - eps_h, self.middle_h, dtype=torch.float)[: self.out_h] 95 | grid_h = grid_h.unsqueeze(1).repeat(1, self.out_w) 96 | grid_w = torch.linspace(-1.0 + eps_w, 1.0 - eps_w, self.middle_w, dtype=torch.float)[: self.out_w] 97 | grid_w = grid_w.unsqueeze(0).repeat(self.out_h, 1) 98 | base_grid = torch.stack((grid_w, grid_h), dim=-1).unsqueeze(0) # Shape (1, out_h, out_w, 2) 99 | 100 | self.register_buffer("base_grid", base_grid, persistent=False) # Do note save the grid in state_dict 101 | 102 | # Now set the eval op 103 | if self.out_h == self.in_h and self.out_w == self.in_w: 104 | self.eval_op = None 105 | else: 106 | self.eval_op = functools.partial( 107 | torchvision.transforms.functional.center_crop, output_size=(self.out_h, self.out_w) 108 | ) 109 | 110 | @property 111 | def observation_space(self): 112 | return modify_space_hw(self._observation_space, self.out_h, self.out_w, dtype=np.float32) 113 | 114 | def _aug(self, x: torch.Tensor) -> torch.Tensor: 115 | size = x.size() 116 | assert len(size) == 4, "_aug supports images of shape (b, c, h, w)" 117 | b = size[0] 118 | # Determine if we should pad 119 | if self.do_pad: 120 | x = F.pad(x, self.padding, "replicate") 121 | 122 | if self.is_square: 123 | # offsets are computed in the pad and subsample size 124 | offsets = ( 125 | torch.randint(0, self.middle_h - self.out_h + 1, size=(b, 1, 1, 2), device=x.device, dtype=torch.float) 126 | * 2.0 127 | / (self.middle_h) 128 | ) 129 | else: 130 | # We need to compute individual h and w offsets. 
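            # The random integer crop positions are converted into grid_sample coordinates: the padded
            # image spans [-1, 1] over middle_h (or middle_w) pixels, so shifting the crop by k pixels
            # corresponds to adding k * 2 / middle_h (or k * 2 / middle_w) to the base grid.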
131 | h_offsets = ( 132 | torch.randint(0, self.middle_h - self.out_h + 1, size=(b, 1, 1), device=x.device, dtype=torch.float) 133 | * 2.0 134 | / (self.middle_h) 135 | ) 136 | w_offsets = ( 137 | torch.randint(0, self.middle_w - self.out_w + 1, size=(b, 1, 1), device=x.device, dtype=torch.float) 138 | * 2.0 139 | / (self.middle_w) 140 | ) 141 | offsets = torch.stack((w_offsets, h_offsets), dim=-1) 142 | 143 | grid = self.base_grid + offsets 144 | 145 | return F.grid_sample(x, grid, padding_mode="zeros", align_corners=False) 146 | 147 | def forward(self, batch: Dict) -> Dict: 148 | op = self._aug 149 | if not self.training: 150 | if self.eval_op is None: 151 | return batch 152 | else: 153 | op = self.eval_op 154 | 155 | # Images are assumed to be of shape (B, S, C, H, W) or (B, C, H, W) if there is no sequence dimension 156 | images = [] 157 | split = [] 158 | for k in ("obs", "next_obs", "init_obs"): 159 | if k in batch: 160 | if self.image_keys is None: 161 | images.append(batch[k]) 162 | split.append(batch[k].shape[1]) 163 | else: 164 | images.extend([batch[k][img_key] for img_key in self.image_keys]) 165 | split.extend([batch[k][img_key].shape[1] for img_key in self.image_keys]) 166 | 167 | is_sequence = self.is_sequence or len(images[0].shape) > 4 # See if we have a sequence dimension 168 | with torch.no_grad(): 169 | images = torch.cat(images, dim=1 if self.consistent else 0) # This is either the seq dim or channel dim. 170 | if is_sequence: 171 | n, s, c, h, w = images.size() 172 | images = images.view(n, s * c, h, w) # Apply same augmentations across sequence. 173 | images = op(images.float()) # Apply the same augmentation to each data pt. 174 | if is_sequence: 175 | images = images.view(n, s, c, h, w) 176 | # Split according to the dimension 1 splits 177 | images = torch.split(images, split, dim=1 if self.consistent else 0) 178 | 179 | # Iterate over everything in the same order and overwrite in the batch 180 | i = 0 181 | for k in ("obs", "next_obs", "init_obs"): 182 | if k in batch: 183 | if self.image_keys is None: 184 | batch[k] = images[i] 185 | i += 1 186 | else: 187 | for img_key in self.image_keys: 188 | batch[k][img_key] = images[i] 189 | i += 1 190 | assert i == len(images), "Did not write batch all augmented images." 
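        # Note on `consistent`: with consistent=True the obs/next_obs/init_obs images are concatenated
        # along dim 1, so each batch element receives the same crop across all keys (and across the
        # sequence); with consistent=False they are concatenated along the batch dim and each tensor is
        # cropped independently.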
191 | return batch 192 | -------------------------------------------------------------------------------- /research/algs/offline/idql.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Any, Dict, Optional, Type 3 | 4 | import diffusers 5 | import numpy as np 6 | import torch 7 | 8 | from research.networks.base import ActorCriticValuePolicy 9 | from research.utils import utils 10 | 11 | from ..off_policy_algorithm import OffPolicyAlgorithm 12 | 13 | 14 | def iql_loss(pred, target, expectile=0.5): 15 | err = target - pred 16 | weight = torch.abs(expectile - (err < 0).float()) 17 | return weight * torch.square(err) 18 | 19 | 20 | class IDQL(OffPolicyAlgorithm): 21 | def __init__( 22 | self, 23 | *args, 24 | tau: float = 0.005, 25 | target_freq: int = 1, 26 | expectile: Optional[float] = None, 27 | beta: float = 1, 28 | noise_scheduler=diffusers.schedulers.DDIMScheduler, 29 | noise_scheduler_kwargs: Optional[Dict] = None, 30 | num_inference_steps: Optional[int] = 10, 31 | num_samples: int = 64, 32 | **kwargs, 33 | ) -> None: 34 | super().__init__(*args, **kwargs) 35 | self.tau = tau 36 | self.target_freq = target_freq 37 | self.expectile = expectile 38 | self.beta = beta 39 | self.num_samples = num_samples 40 | assert isinstance(self.network, ActorCriticValuePolicy) 41 | noise_scheduler_kwargs = {} if noise_scheduler_kwargs is None else noise_scheduler_kwargs 42 | self.noise_scheduler = noise_scheduler(**noise_scheduler_kwargs) 43 | if num_inference_steps is None: 44 | self.num_inference_steps = self.noise_scheduler.config.num_train_timesteps 45 | else: 46 | self.num_inference_steps = num_inference_steps 47 | 48 | def setup_network(self, network_class: Type[torch.nn.Module], network_kwargs: Dict) -> None: 49 | self.network = network_class( 50 | self.processor.observation_space, self.processor.action_space, **network_kwargs 51 | ).to(self.device) 52 | 53 | self.target_network = self.network.create_subset(["encoder", "critic"])( 54 | self.processor.observation_space, self.processor.action_space, **network_kwargs 55 | ) 56 | # Delete the unneeded things from the target network. 57 | del self.target_network.encoder 58 | self.target_network = self.target_network.to(self.device) 59 | 60 | # Set up the target network. 61 | self.target_network.critic.load_state_dict(self.network.critic.state_dict()) 62 | for param in self.target_network.parameters(): 63 | param.requires_grad = False 64 | 65 | def setup_optimizers(self) -> None: 66 | actor_params = itertools.chain(self.network.actor.parameters(), self.network.encoder.parameters()) 67 | actor_groups = utils.create_optim_groups(actor_params, self.optim_kwargs) 68 | # NOTE: Optim class only affects the Actor. 69 | self.optim["actor"] = self.optim_class(actor_groups) 70 | 71 | # Remove weight decay from critics. 72 | value_optim_kwargs = self.optim_kwargs.copy() 73 | if "weight_decay" in value_optim_kwargs: 74 | del value_optim_kwargs["weight_decay"] 75 | self.optim["critic"] = torch.optim.Adam(self.network.critic.parameters(), **value_optim_kwargs) 76 | self.optim["value"] = torch.optim.Adam(self.network.value.parameters(), **value_optim_kwargs) 77 | 78 | def train_step(self, batch: Dict, step: int, total_steps: int) -> Dict: 79 | # We use the online encoder for everything in this IQL implementation 80 | # That is because we need to use the current obs for the target critic and online value. 81 | # This is done by default in DrQv2. 
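        # The encoded obs is detached for the value and critic losses below, so encoder gradients flow
        # only through the diffusion BC actor loss; the actor optimizer is the only one that includes
        # the encoder parameters (see setup_optimizers).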
82 | batch["obs"] = self.network.encoder(batch["obs"]) 83 | 84 | with torch.no_grad(): 85 | batch["next_obs"] = self.network.encoder(batch["next_obs"]) 86 | 87 | # First compute the value loss 88 | with torch.no_grad(): 89 | target_q = self.target_network.critic(batch["obs"], batch["action"]) 90 | target_q = torch.min(target_q, dim=0)[0] 91 | vs = self.network.value(batch["obs"].detach()) # Always detach for value learning 92 | v_loss = iql_loss(vs, target_q.expand(vs.shape[0], -1), self.expectile).mean() 93 | 94 | self.optim["value"].zero_grad(set_to_none=True) 95 | v_loss.backward() 96 | self.optim["value"].step() 97 | 98 | # Next, compute the critic loss 99 | with torch.no_grad(): 100 | next_vs = self.network.value(batch["next_obs"]) 101 | next_v = torch.min(next_vs, dim=0)[0] 102 | target = batch["reward"] + batch["discount"] * next_v 103 | qs = self.network.critic(batch["obs"].detach(), batch["action"]) 104 | q_loss = torch.nn.functional.mse_loss(qs, target.expand(qs.shape[0], -1), reduction="none").mean() 105 | 106 | self.optim["critic"].zero_grad(set_to_none=True) 107 | q_loss.backward() 108 | self.optim["critic"].step() 109 | 110 | # Update the actor, just with BC. We will sample from it later using re-weighting. 111 | B = batch["action"].shape[0] 112 | noise = torch.randn_like(batch["action"]) 113 | timesteps = torch.randint( 114 | low=0, high=self.noise_scheduler.config.num_train_timesteps, size=(B,), device=self.device 115 | ).long() 116 | noisy_actions = self.noise_scheduler.add_noise(batch["action"], noise, timesteps) 117 | 118 | noise_pred = self.network.actor(noisy_actions, timesteps, cond=batch["obs"]) 119 | actor_loss = torch.nn.functional.mse_loss(noise_pred, noise, reduction="none").sum( 120 | dim=-1 121 | ) # Sum over action Dim 122 | 123 | if "mask" in batch: 124 | mask = (~batch["mask"]).float() 125 | actor_loss = actor_loss * mask 126 | size = mask.sum() 127 | else: 128 | size = actor_loss.numel() 129 | actor_loss = actor_loss.sum() / size 130 | 131 | # Update the networks. These are done in a stack to support different grad options for the encoder. 132 | self.optim["actor"].zero_grad(set_to_none=True) 133 | actor_loss.backward() 134 | self.optim["actor"].step() 135 | 136 | if step % self.target_freq == 0: 137 | with torch.no_grad(): 138 | # Only run on the critic and encoder, those are the only weights we update. 
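                # Polyak (EMA) averaging of the target critic:
                #   target <- tau * online + (1 - tau) * target
                # With the default tau=0.005 the target trails the online critic slowly, stabilizing
                # the TD targets used for the value and critic losses above.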
139 | for param, target_param in zip( 140 | self.network.critic.parameters(), self.target_network.critic.parameters() 141 | ): 142 | target_param.data.copy_(self.tau * param.data + (1 - self.tau) * target_param.data) 143 | 144 | return dict( 145 | q_loss=q_loss.item(), 146 | v_loss=v_loss.item(), 147 | actor_loss=actor_loss.item(), 148 | v=vs.mean().item(), 149 | q=qs.mean().item(), 150 | ) 151 | 152 | def _predict(self, batch: Dict): 153 | with torch.no_grad(): 154 | obs = self.network.encoder(batch["obs"]) 155 | B, D = obs.shape 156 | obs = obs.unsqueeze(1).expand(B, self.num_samples, D) 157 | 158 | noisy_actions = torch.randn(B, self.num_samples, self.processor.action_space.shape[0], device=self.device) 159 | self.noise_scheduler.set_timesteps(self.num_inference_steps) 160 | for timestep in self.noise_scheduler.timesteps: 161 | noise_pred = self.network.actor( 162 | noisy_actions, timestep.to(self.device).expand(B, self.num_samples), cond=obs 163 | ) 164 | noisy_actions = self.noise_scheduler.step( 165 | model_output=noise_pred, timestep=timestep, sample=noisy_actions 166 | ).prev_sample 167 | 168 | # Now we have finished generating the actions, now we need to figure out their weights 169 | v = self.network.value(obs).mean(dim=0) 170 | q = torch.min(self.target_network.critic(obs, noisy_actions), dim=0)[0] 171 | adv = q - v # Shape (B, self.num_samples) 172 | expectile_weights = torch.where(adv > 0, self.expectile, 1 - self.expectile) 173 | sample_idx = torch.multinomial(expectile_weights / expectile_weights.sum(), 1) # (B, 1) 174 | actions = noisy_actions[torch.arange(B), sample_idx.squeeze(-1)] 175 | 176 | return actions 177 | 178 | def _get_train_action(self, obs: Any, step: int, total_steps: int) -> np.ndarray: 179 | batch = dict(obs=obs) 180 | with torch.no_grad(): 181 | action = self.predict(batch, is_batched=False) 182 | return action[0] # return the first one. 
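# Configuration sketch (illustrative only; these values are assumptions, not taken from the repo's
# configs): IDQL pairs the diffusion BC actor with IQL-style critics, so an instantiation might pass
# something like
#   noise_scheduler=diffusers.schedulers.DDIMScheduler,
#   noise_scheduler_kwargs=dict(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2"),
#   num_inference_steps=10, num_samples=64, expectile=0.7,
# where the kwargs are standard diffusers DDIMScheduler arguments.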
183 | -------------------------------------------------------------------------------- /research/networks/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional, Tuple, Type, Union 3 | 4 | import torch 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | 9 | class MLP(nn.Module): 10 | def __init__( 11 | self, 12 | input_dim: int, 13 | output_dim: int, 14 | hidden_layers: List[int] = (256, 256), 15 | act: nn.Module = nn.ReLU, 16 | dropout: float = 0.0, 17 | normalization: Optional[Type[nn.Module]] = None, 18 | output_act: Optional[Type[nn.Module]] = None, 19 | ): 20 | super().__init__() 21 | net = [] 22 | last_dim = input_dim 23 | for dim in hidden_layers: 24 | net.append(nn.Linear(last_dim, dim)) 25 | if dropout > 0.0: 26 | net.append(nn.Dropout(dropout)) 27 | if normalization is not None: 28 | net.append(normalization(dim)) 29 | net.append(act()) 30 | last_dim = dim 31 | net.append(nn.Linear(last_dim, output_dim)) 32 | if output_act is not None: 33 | net.append(output_act()) 34 | self.net = nn.Sequential(*net) 35 | self._has_output_act = False if output_act is None else True 36 | 37 | @property 38 | def last_layer(self) -> nn.Module: 39 | if self._has_output_act: 40 | return self.net[-2] 41 | else: 42 | return self.net[-1] 43 | 44 | def forward(self, x: torch.Tensor) -> torch.Tensor: 45 | return self.net(x) 46 | 47 | 48 | class LinearEnsemble(nn.Module): 49 | def __init__( 50 | self, 51 | in_features: int, 52 | out_features: int, 53 | ensemble_size: int = 3, 54 | bias: bool = True, 55 | device: Optional[Union[str, torch.device]] = None, 56 | dtype: Optional[torch.dtype] = None, 57 | ): 58 | """ 59 | An Ensemble linear layer. 60 | For inputs of shape (B, H) will return (E, B, H) where E is the ensemble size 61 | See https://github.com/pytorch/pytorch/issues/54147 62 | """ 63 | factory_kwargs = {"device": device, "dtype": dtype} 64 | super().__init__() 65 | self.in_features = in_features 66 | self.out_features = out_features 67 | self.ensemble_size = ensemble_size 68 | self.weight = nn.Parameter(torch.empty((ensemble_size, in_features, out_features), **factory_kwargs)) 69 | if bias: 70 | self.bias = nn.Parameter(torch.empty((ensemble_size, 1, out_features), **factory_kwargs)) 71 | else: 72 | self.register_parameter("bias", None) 73 | self.reset_parameters() 74 | 75 | def reset_parameters(self) -> None: 76 | # The default torch init for Linear is a complete mess 77 | # https://github.com/pytorch/pytorch/issues/57109 78 | # If we use the same init, we will end up scaling incorrectly 79 | # 1. Compute the fan in of the 2D tensor = dim 1 of 2D matrix (0 index) 80 | # 2. Comptue the gain with param=math.sqrt(5.0) 81 | # This returns math.sqrt(2.0 / 6.0) = sqrt(1/3) 82 | # 3. Compute std = gain / math.sqrt(fan) = sqrt(1/3) / sqrt(in). 83 | # 4. 
Compute bound as math.sqrt(3.0) * std = 1 / in di 84 | std = 1.0 / math.sqrt(self.in_features) 85 | nn.init.uniform_(self.weight, -std, std) 86 | if self.bias is not None: 87 | nn.init.uniform_(self.bias, -std, std) 88 | 89 | def forward(self, x: torch.Tensor) -> torch.Tensor: 90 | if len(x.shape) == 2: 91 | x = x.repeat(self.ensemble_size, 1, 1) 92 | elif len(x.shape) > 3: 93 | raise ValueError("LinearEnsemble layer does not support inputs with more than 3 dimensions.") 94 | return torch.baddbmm(self.bias, x, self.weight) 95 | 96 | def extra_repr(self) -> str: 97 | return "ensemble_size={}, in_features={}, out_features={}, bias={}".format( 98 | self.ensemble_size, self.in_features, self.out_features, self.bias is not None 99 | ) 100 | 101 | 102 | class LayerNormEnsemble(nn.Module): 103 | """ 104 | This is a re-implementation of the Pytorch nn.LayerNorm module with suport for the Ensemble dim. 105 | We need this custom class since we need to normalize over normalize dims, but have multiple weight/bais 106 | parameters for the ensemble. 107 | 108 | """ 109 | 110 | __constants__ = ["normalized_shape", "eps", "elementwise_affine"] 111 | normalized_shape: Tuple[int, ...] 112 | eps: float 113 | elementwise_affine: bool 114 | 115 | def __init__( 116 | self, 117 | normalized_shape: int, 118 | ensemble_size: int = 3, 119 | eps: float = 1e-5, 120 | elementwise_affine: bool = True, 121 | device=None, 122 | dtype=None, 123 | ) -> None: 124 | factory_kwargs = {"device": device, "dtype": dtype} 125 | super().__init__() 126 | assert isinstance(normalized_shape, int), "Currently EnsembleLayerNorm only supports final dim int shapes." 127 | self.normalized_shape = (normalized_shape,) 128 | self.eps = eps 129 | self.elementwise_affine = elementwise_affine 130 | self.ensemble_size = ensemble_size 131 | if self.elementwise_affine: 132 | self.weight = nn.Parameter(torch.empty((self.ensemble_size, 1, *self.normalized_shape), **factory_kwargs)) 133 | self.bias = nn.Parameter(torch.empty((self.ensemble_size, 1, *self.normalized_shape), **factory_kwargs)) 134 | else: 135 | self.register_parameter("weight", None) 136 | self.register_parameter("bias", None) 137 | 138 | self.reset_parameters() 139 | 140 | def reset_parameters(self) -> None: 141 | if self.elementwise_affine: 142 | nn.init.ones_(self.weight) 143 | nn.init.zeros_(self.bias) 144 | 145 | def forward(self, x: torch.Tensor) -> torch.Tensor: 146 | if len(x.shape) == 2: 147 | x = x.repeat(self.ensemble_size, 1, 1) 148 | elif len(x.shape) > 3: 149 | raise ValueError("LayerNormEnsemble layer does not support inputs with more than 3 dimensions.") 150 | x = F.layer_norm(x, self.normalized_shape, None, None, self.eps) # (E, B, *normalized shape) 151 | if self.elementwise_affine: 152 | x = x * self.weight + self.bias 153 | return x 154 | 155 | def extra_repr(self) -> str: 156 | return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__) 157 | 158 | 159 | class EnsembleMLP(nn.Module): 160 | def __init__( 161 | self, 162 | input_dim: int, 163 | output_dim: int, 164 | ensemble_size: int = 3, 165 | hidden_layers: List[int] = (256, 256), 166 | act: nn.Module = nn.ReLU, 167 | dropout: float = 0.0, 168 | normalization: Optional[Type[nn.Module]] = None, 169 | output_act: Optional[Type[nn.Module]] = None, 170 | ): 171 | """ 172 | An ensemble MLP 173 | Returns values of shape (E, B, H) from input (B, H) 174 | All extra dimensions are moved to batch 175 | """ 176 | super().__init__() 177 | # Change the normalization type to work 
over ensembles 178 | assert normalization is None or normalization is LayerNormEnsemble, "Ensemble only support EnsembleLayerNorm" 179 | net = [] 180 | last_dim = input_dim 181 | for dim in hidden_layers: 182 | net.append(LinearEnsemble(last_dim, dim, ensemble_size=ensemble_size)) 183 | if dropout > 0.0: 184 | net.append(nn.Dropout(dropout)) 185 | if normalization is not None: 186 | net.append(normalization(dim, ensemble_size=ensemble_size)) 187 | net.append(act()) 188 | last_dim = dim 189 | net.append(LinearEnsemble(last_dim, output_dim, ensemble_size=ensemble_size)) 190 | if output_act is not None: 191 | net.append(output_act()) 192 | self.net = nn.Sequential(*net) 193 | self.input_dim, self.output_dim = input_dim, output_dim 194 | self.ensemble_size = ensemble_size 195 | self._has_output_act = False if output_act is None else True 196 | 197 | def forward(self, x: torch.Tensor) -> torch.Tensor: 198 | # The input to this network is assumed to be (....., input_dim) 199 | assert x.shape[-1] == self.input_dim 200 | batch_dims = x.size()[:-1] 201 | if len(batch_dims) > 1: 202 | x = x.view(-1, self.input_dim) 203 | x = self.net(x) 204 | output_shape = (self.ensemble_size, *batch_dims, self.output_dim) 205 | x = x.view(*output_shape) 206 | else: 207 | x = self.net(x) 208 | return x 209 | 210 | @property 211 | def last_layer(self) -> torch.Tensor: 212 | if self._has_output_act: 213 | return self.net[-2] 214 | else: 215 | return self.net[-1] 216 | -------------------------------------------------------------------------------- /research/processors/normalization.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Union 2 | 3 | import gym 4 | import numpy as np 5 | import torch 6 | 7 | from research.utils import utils 8 | 9 | from .base import Processor 10 | 11 | 12 | class RunningMeanStd(torch.nn.Module): 13 | def __init__(self, shape, epsilon: float = 1e-6): 14 | super().__init__() 15 | self.shape = shape 16 | self._mean = torch.nn.Parameter(torch.zeros(shape, dtype=torch.float), requires_grad=False) 17 | self._var = torch.nn.Parameter(torch.ones(shape, dtype=torch.float), requires_grad=False) 18 | self._count = torch.nn.Parameter(torch.tensor(epsilon, dtype=torch.float), requires_grad=False) 19 | 20 | def update(self, x: Union[float, np.ndarray, torch.Tensor]) -> None: 21 | if isinstance(x, float): 22 | x = torch.tensor(x) 23 | elif isinstance(x, np.ndarray): 24 | x = torch.from_numpy(x) 25 | elif isinstance(x, torch.Tensor): 26 | pass 27 | else: 28 | raise ValueError("Invalid type provided") 29 | # If the data is unbatched, unsqueeze it 30 | if len(x.shape) == len(self.shape): 31 | x = x.unsqueeze(0) 32 | 33 | mean = torch.mean(x, dim=0) 34 | var = torch.var(x, dim=0, unbiased=False) 35 | count = x.shape[0] 36 | 37 | delta = mean - self._mean 38 | total_count = self._count + count 39 | 40 | new_mean = self._mean + delta * count / total_count 41 | 42 | m_a = self._var * self._count 43 | m_b = var * count 44 | m_2 = m_a + m_b + torch.square(delta) * self._count * count / total_count 45 | new_var = m_2 / total_count 46 | # Update member variables 47 | self._count.copy_(total_count) 48 | self._mean.copy_(new_mean) 49 | self._var.copy_(new_var) 50 | 51 | @property 52 | def mean(self): 53 | return self._mean.data # Make sure we return the data, not the parameter! 54 | 55 | @property 56 | def var(self): 57 | return self._var.data # Make sure we return the data, not the parameter! 
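    # Note: mean/var/count are stored as Parameters with requires_grad=False so they are included in
    # state_dict and moved by .to(device), but they are updated in place by update() rather than by an
    # optimizer.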
58 | 59 | @property 60 | def std(self): 61 | return torch.sqrt(self._var) 62 | 63 | 64 | class RunningObservationNormalizer(Processor): 65 | """ 66 | A running observation normalizer. 67 | Note that there are quite a few speed optimizations that could be performed: 68 | 1. We could cache computation of the variance etc. so it doesn't run everytime. 69 | 2. We could permanently store torch tensors so we don't recompute them and sync to GPU. 70 | """ 71 | 72 | def __init__( 73 | self, 74 | observation_space: gym.Space, 75 | action_space: gym.Space, 76 | epsilon: float = 1e-7, 77 | clip: float = 10, 78 | explicit_update: bool = False, 79 | paired_keys: Optional[List[str]] = None, 80 | ) -> None: 81 | super().__init__(observation_space, action_space) 82 | self.paired_keys = set() if paired_keys is None else set(paired_keys) 83 | if isinstance(observation_space, gym.spaces.Dict): 84 | assert all([isinstance(space, gym.spaces.Box) for space in observation_space.values()]) 85 | self.rms = { 86 | k: RunningMeanStd(space.shape, epsilon=epsilon) 87 | for k, space in observation_space.items() 88 | if k not in self.paired_keys 89 | } 90 | if len(self.paired_keys) > 0: 91 | self.rms["paired"] = RunningMeanStd(observation_space[paired_keys[0]].shape, epsilon=epsilon) 92 | elif isinstance(observation_space, gym.spaces.Box): 93 | self.rms = RunningMeanStd(observation_space.shape, epsilon=epsilon) 94 | else: 95 | raise ValueError("Invalid space type provided.") 96 | self._updated_stats = True 97 | self.clip = clip 98 | self.explicit_update = explicit_update 99 | 100 | @property 101 | def supports_gpu(self): 102 | return False 103 | 104 | def _get_key(self, k): 105 | if k in self.paired_keys: 106 | return "paired" 107 | else: 108 | return k 109 | 110 | def update(self, obs: Union[torch.Tensor, Dict]) -> None: 111 | if isinstance(obs, dict): 112 | for k in obs.keys(): 113 | self.rms[self._get_key(k)].update(obs[k]) 114 | else: 115 | self.rms.update(obs) 116 | self._updated_stats = True 117 | 118 | def normalize(self, obs: Union[torch.Tensor, Dict]) -> Union[torch.Tensor, Dict]: 119 | if self._updated_stats: 120 | # Grab the states from the RMS trackers 121 | self._mean = {k: self.rms[k].mean for k in self.rms.keys()} if isinstance(obs, dict) else self.rms.mean 122 | self._std = {k: self.rms[k].std for k in self.rms.keys()} if isinstance(obs, dict) else self.rms.std 123 | device = utils.get_device(obs) 124 | if device is not None: 125 | self._mean = utils.to_device(self._mean, device) 126 | self._std = utils.to_device(self._std, device) 127 | self._updated_stats = False 128 | # Normalize the observation 129 | if isinstance(obs, dict): 130 | obs = {k: (obs[k] - self._mean[self._get_key(k)]) / self._std[self._get_key(k)] for k in obs.keys()} 131 | if self.clip is not None: 132 | for k in obs.keys(): 133 | obs[k] = torch.clamp(obs[k], -self.clip, self.clip) 134 | return obs 135 | elif isinstance(obs, torch.Tensor): 136 | obs = (obs - self._mean) / self._std 137 | return obs if self.clip is None else torch.clamp(obs, -self.clip, self.clip) 138 | else: 139 | raise ValueError("Invalid Input provided") 140 | 141 | def forward(self, batch: Dict) -> Dict: 142 | # Check if we should update the statistics 143 | if not self.explicit_update and self.training and "obs" in batch: 144 | self.update(batch["obs"]) 145 | # Normalize 146 | for k in ("obs", "next_obs", "init_obs"): 147 | if k in batch: 148 | batch[k] = self.normalize(batch[k]) 149 | return batch 150 | 151 | 152 | class GaussianActionNormalizer(Processor): 153 | 
    def __init__(
154 |         self,
155 |         observation_space: gym.Space,
156 |         action_space: gym.Space,
157 |         mean: List[float],
158 |         std: List[float],
159 |         clip: Optional[float] = None,
160 |     ):
161 |         super().__init__(observation_space, action_space)
162 |         assert isinstance(action_space, gym.spaces.Box), "Must use box action space."
163 |         self.mean = np.array(mean, dtype=np.float32)
164 |         self.std = np.array(std, dtype=np.float32)
165 |         assert self.mean.shape == action_space.low.shape
166 |         assert self.std.shape == action_space.high.shape
167 |         self.clip = clip
168 | 
169 |     @property
170 |     def action_space(self):
171 |         if self.clip is None:
172 |             return gym.spaces.Box(
173 |                 low=(self._action_space.low - self.mean) / self.std,
174 |                 high=(self._action_space.high - self.mean) / self.std,
175 |             )
176 |         else:
177 |             return gym.spaces.Box(
178 |                 low=-self.clip * np.ones_like(self._action_space.low),
179 |                 high=self.clip * np.ones_like(self._action_space.high),
180 |             )
181 | 
182 |     def forward(self, batch: Dict):
183 |         # Process the action to be in the correct space
184 |         action = (batch["action"] - self.mean) / self.std
185 |         if self.clip is not None:
186 |             action = torch.clamp(action, min=-self.clip, max=self.clip)
187 |         batch["action"] = action
188 |         return batch
189 | 
190 |     def unprocess(self, batch: Dict) -> Dict:
191 |         # Map the action back to the original (unnormalized) space
192 |         batch["action"] = batch["action"] * self.std + self.mean
193 |         return batch
194 | 
195 | 
196 | class MinMaxActionNormalizer(Processor):
197 |     def __init__(
198 |         self,
199 |         observation_space: gym.Space,
200 |         action_space: gym.Space,
201 |         low: List[float],
202 |         high: List[float],
203 |         output_low: float = -1,
204 |         output_high: float = 1,
205 |     ):
206 |         super().__init__(observation_space, action_space)
207 |         assert isinstance(action_space, gym.spaces.Box), "Must use box action space."
208 |         self.low = np.array(low, dtype=np.float32)
209 |         self.high = np.array(high, dtype=np.float32)
210 |         assert self.low.shape == action_space.low.shape
211 |         assert self.high.shape == action_space.high.shape
212 |         self.output_high = output_high
213 |         self.output_low = output_low
214 | 
215 |     @property
216 |     def action_space(self):
217 |         return gym.spaces.Box(
218 |             low=self.output_low, high=self.output_high, shape=self._action_space.shape, dtype=np.float32
219 |         )
220 | 
221 |     def forward(self, batch: Dict):
222 |         # Process the action to be in the correct range
223 |         action = batch["action"]
224 |         action = (action - self.low) / (self.high - self.low)  # normalize to 0 to 1
225 |         action = action * (self.output_high - self.output_low) + self.output_low
226 |         batch["action"] = action
227 |         return batch
228 | 
229 |     def unprocess(self, batch: Dict) -> Dict:
230 |         # Map the action back to the original (unnormalized) range
231 |         action = batch["action"]
232 |         action = (action - self.output_low) / (self.output_high - self.output_low)
233 |         action = action * (self.high - self.low) + self.low
234 |         batch["action"] = action
235 |         return batch
236 | 
--------------------------------------------------------------------------------
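A minimal usage sketch of the action normalizers above (illustrative only: the spaces, bounds, and
sample values below are assumptions, not taken from any config in this repository):

    import gym
    import numpy as np
    import torch

    from research.processors.normalization import MinMaxActionNormalizer

    obs_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float32)
    act_space = gym.spaces.Box(low=np.array([-2.0, 0.0]), high=np.array([2.0, 1.0]), dtype=np.float32)

    # Rescale actions from their environment bounds to [-1, 1] for learning, and back for execution.
    normalizer = MinMaxActionNormalizer(obs_space, act_space, low=[-2.0, 0.0], high=[2.0, 1.0])

    batch = {"action": torch.tensor([[0.0, 0.75]])}
    batch = normalizer(batch)            # action becomes [[0.0, 0.5]] in the [-1, 1] range
    batch = normalizer.unprocess(batch)  # action is mapped back to [[0.0, 0.75]]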