├── tasks ├── __init__.py ├── motion_completion.py ├── tracking.py ├── cmu_relabel.py └── dmcontrol.py ├── trajectory ├── __init__.py ├── tfds │ ├── __init__.py │ ├── tfds │ │ └── mocapact │ │ │ ├── __init__.py │ │ │ └── mocapact.py │ └── generate_tfds_dataset.py ├── models │ ├── __init__.py │ ├── ein.py │ ├── transformers.py │ ├── cnn.py │ └── transformer_prior.py ├── search │ ├── __init__.py │ ├── utils.py │ └── planning.py ├── datasets │ ├── __init__.py │ ├── preprocessing.py │ ├── d4rl.py │ └── sequence.py └── utils │ ├── symlog.py │ ├── __init__.py │ ├── timer.py │ ├── scaling_law.py │ ├── arrays.py │ ├── video.py │ ├── git_utils.py │ ├── config.py │ ├── dataset.py │ ├── progress.py │ ├── serialization.py │ ├── setup.py │ ├── sampler.py │ └── relabel_humanoid.py ├── requirements ├── requirements.in └── requirements.txt ├── CONTRIBUTING.md ├── README.md ├── CODE_OF_CONDUCT.md ├── plotting └── read_results.py ├── config └── vqvae.py └── scripts ├── humanoid_plan.py ├── train.py └── trainprior.py /tasks/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /trajectory/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /trajectory/tfds/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /trajectory/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /trajectory/search/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .planning import * 3 | -------------------------------------------------------------------------------- /trajectory/tfds/tfds/mocapact/__init__.py: -------------------------------------------------------------------------------- 1 | """mocapact dataset.""" 2 | 3 | from .mocapact import Mocapact 4 | -------------------------------------------------------------------------------- /trajectory/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .d4rl import load_environment 2 | from .sequence import * 3 | from .preprocessing import get_preprocess_fn 4 | from .mocapact import MocapactDataset, SequentialDataLoader 5 | -------------------------------------------------------------------------------- /trajectory/utils/symlog.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | 9 | def symlog(x): 10 | return torch.sign(x) * torch.log1p(torch.abs(x)) 11 | 12 | def symexp(x): 13 | return torch.sign(x) * (torch.expm1(torch.abs(x))) 14 | -------------------------------------------------------------------------------- /trajectory/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .setup import Parser, watch 2 | from .arrays import * 3 | from .serialization import * 4 | from .progress import Progress, Silent 5 | from .rendering import make_renderer 6 | from .config import Config 7 | from .training import VQTrainer, PriorTrainer 8 | from .sampler import BatchSampler, RandomSampler 9 | from .timer import Timer 10 | 11 | try: 12 | from . import iql 13 | except Exception as e: 14 | print("fail to load iql") 15 | -------------------------------------------------------------------------------- /trajectory/utils/timer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import time 8 | 9 | class Timer: 10 | 11 | def __init__(self): 12 | self._start = time.time() 13 | 14 | def __call__(self, reset=True): 15 | now = time.time() 16 | diff = now - self._start 17 | if reset: 18 | self._start = now 19 | return diff -------------------------------------------------------------------------------- /trajectory/utils/scaling_law.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | from scipy.optimize import curve_fit 9 | 10 | def fit_power_law(S, L): 11 | def power_law(S, alpha, beta, C): 12 | return C + (beta / S)**alpha 13 | 14 | params, covariance = curve_fit(power_law, S, L) 15 | alpha, beta, C = params 16 | 17 | def fitted_power_law(S): 18 | return power_law(S, alpha, beta, C) 19 | 20 | return fitted_power_law -------------------------------------------------------------------------------- /requirements/requirements.in: -------------------------------------------------------------------------------- 1 | -f https://download.pytorch.org/whl/cu116 2 | torch==1.13.1 3 | xformers==0.0.16 4 | rotary-embedding-torch 5 | absl-py 6 | opt-einsum 7 | flatbuffers 8 | gym==0.21.0 9 | mujoco-py==2.1.2.14 10 | numpy<1.24.0 11 | matplotlib 12 | typed-argument-parser 13 | d4rl @ git+https://github.com/JannerM/d4rl.git@c3dd04da02acbf4de6cbaa1141deb4f958f03ca9 14 | scikit-image 15 | scikit-video 16 | gitpython 17 | einops==0.4.1 18 | mjrl @ git+https://github.com/aravindr93/mjrl@3871d93763d3b49c4741e6daeaebbc605fe140dc 19 | wandb 20 | dm_control==1.0.2 21 | stable_baselines3 22 | 23 | absl-py 24 | mujoco==2.1.5 25 | dm-env 26 | dm-tree 27 | tensorflow==2.11.1 28 | # tensorflow_datasets 29 | rlds 30 | h5py 31 | chex 32 | pandas 33 | cython<3 34 | 35 | accelerate==0.18.0 36 | -------------------------------------------------------------------------------- /trajectory/utils/arrays.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | 10 | DTYPE = torch.float 11 | DEVICE = 'cuda:0' 12 | 13 | def to_np(x): 14 | if torch.is_tensor(x): 15 | x = x.detach().cpu().numpy() 16 | return x 17 | 18 | def to_torch(x, dtype=None, device=None): 19 | dtype = dtype or DTYPE 20 | device = device or DEVICE 21 | return torch.tensor(x, dtype=dtype, device=device) 22 | 23 | def to_device(*xs, device=DEVICE): 24 | return [x.to(device) for x in xs] 25 | 26 | def normalize(x): 27 | """ 28 | scales `x` to [0, 1] 29 | """ 30 | x = x - x.min() 31 | x = x / x.max() 32 | return x 33 | 34 | def to_img(x): 35 | normalized = normalize(x) 36 | array = to_np(normalized) 37 | array = np.transpose(array, (1,2,0)) 38 | return (array * 255).astype(np.uint8) 39 | 40 | def set_device(device): 41 | DEVICE = device 42 | if 'cuda' in device: 43 | torch.set_default_tensor_type(torch.cuda.FloatTensor) 44 | -------------------------------------------------------------------------------- /trajectory/utils/video.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import numpy as np 9 | import skvideo.io 10 | 11 | def _make_dir(filename): 12 | folder = os.path.dirname(filename) 13 | if not os.path.exists(folder): 14 | os.makedirs(folder) 15 | 16 | def save_video(filename, video_frames, fps=30, video_format='mp4'): 17 | assert fps == int(fps), fps 18 | _make_dir(filename) 19 | 20 | skvideo.io.vwrite( 21 | filename, 22 | video_frames, 23 | inputdict={ 24 | '-r': str(int(fps)), 25 | }, 26 | outputdict={ 27 | '-f': video_format, 28 | #'-pix_fmt': 'yuv420p', # '-pix_fmt=yuv420p' needed for osx https://github.com/scikit-video/scikit-video/issues/74 29 | } 30 | ) 31 | 32 | def save_videos(filename, *video_frames, **kwargs): 33 | ## video_frame : [ N x H x W x C ] 34 | video_frames = np.concatenate(video_frames, axis=2) 35 | save_video(filename, video_frames, **kwargs) 36 | -------------------------------------------------------------------------------- /trajectory/utils/git_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import git 9 | import pdb 10 | 11 | PROJECT_PATH = os.path.dirname( 12 | os.path.realpath(os.path.join(__file__, '..', '..'))) 13 | 14 | def get_repo(path=PROJECT_PATH, search_parent_directories=True): 15 | repo = git.Repo( 16 | path, search_parent_directories=search_parent_directories) 17 | return repo 18 | 19 | def get_git_rev(*args, **kwargs): 20 | try: 21 | repo = get_repo(*args, **kwargs) 22 | if repo.head.is_detached: 23 | git_rev = repo.head.object.name_rev 24 | else: 25 | git_rev = repo.active_branch.commit.name_rev 26 | except: 27 | git_rev = None 28 | 29 | return git_rev 30 | 31 | def git_diff(*args, **kwargs): 32 | repo = get_repo(*args, **kwargs) 33 | diff = repo.git.diff() 34 | return diff 35 | 36 | def save_git_diff(savepath, *args, **kwargs): 37 | diff = git_diff(*args, **kwargs) 38 | with open(savepath, 'w') as f: 39 | f.write(diff) 40 | 41 | if __name__ == '__main__': 42 | 43 | git_rev = get_git_rev() 44 | print(git_rev) 45 | 46 | save_git_diff('diff_test.txt') -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to hgap 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to hgap, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /trajectory/models/ein.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | import pdb 11 | 12 | class EinLinear(nn.Module): 13 | 14 | def __init__(self, n_models, in_features, out_features, bias): 15 | super().__init__() 16 | self.n_models = n_models 17 | self.out_features = out_features 18 | self.in_features = in_features 19 | self.weight = nn.Parameter(torch.Tensor(n_models, out_features, in_features)) 20 | if bias: 21 | self.bias = nn.Parameter(torch.Tensor(n_models, out_features)) 22 | else: 23 | self.register_parameter('bias', None) 24 | self.reset_parameters() 25 | 26 | def reset_parameters(self): 27 | for i in range(self.n_models): 28 | nn.init.kaiming_uniform_(self.weight[i], a=math.sqrt(5)) 29 | if self.bias is not None: 30 | fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight[i]) 31 | bound = 1 / math.sqrt(fan_in) 32 | nn.init.uniform_(self.bias[i], -bound, bound) 33 | 34 | def forward(self, input): 35 | """ 36 | input : [ B x n_models x input_dim ] 37 | """ 38 | ## [ B x n_models x output_dim ] 39 | output = torch.einsum('eoi,bei->beo', self.weight, input) 40 | if self.bias is not None: 41 | raise RuntimeError() 42 | return output 43 | 44 | def extra_repr(self): 45 | return 'n_models={}, in_features={}, out_features={}, bias={}'.format( 46 | self.n_models, self.in_features, self.out_features, self.bias is not None 47 | ) 48 | -------------------------------------------------------------------------------- /trajectory/search/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import torch 9 | import pdb 10 | 11 | from ..utils.arrays import to_torch 12 | 13 | VALUE_PLACEHOLDER = 1e6 14 | 15 | def make_prefix(context, obs, transition_dim, device="cuda"): 16 | obs_discrete = to_torch(obs, dtype=torch.float32, device=device) 17 | pad_dims = to_torch(np.zeros(transition_dim - len(obs) - len(context)), dtype=torch.float32, device=device) 18 | if obs_discrete.ndim == 1: 19 | obs_discrete = obs_discrete.reshape(1, 1, -1) 20 | pad_dims = pad_dims.reshape(1, 1, -1) 21 | transition = torch.cat(context + [obs_discrete] + [pad_dims], axis=-1) 22 | prefix = transition 23 | return prefix 24 | 25 | def extract_actions(x, observation_dim, action_dim, t=None): 26 | actions = x[:, observation_dim:observation_dim+action_dim] 27 | if t is not None: 28 | return actions[t] 29 | else: 30 | return actions 31 | 32 | def extract_actions_continuous(x, observation_dim, action_dim, t=None): 33 | assert x.shape[0] == 1 34 | actions = x[0, :, observation_dim:observation_dim+action_dim] 35 | if t is not None: 36 | return actions[t] 37 | else: 38 | return actions 39 | 40 | def update_context(context, observation, action, reward, device): 41 | ''' 42 | context : list of transitions 43 | [ tensor( transition_dim ), ... ] 44 | ''' 45 | ## use a placeholder for value because input values are masked out by model 46 | rew_val = np.array([reward, VALUE_PLACEHOLDER]) 47 | transition = np.concatenate([observation, action, rew_val]) 48 | # context = [] 49 | 50 | transition_discrete = to_torch(transition, dtype=torch.float32, device=device) 51 | transition_discrete = transition_discrete.reshape(1, 1, -1) 52 | 53 | ## add new transition to context 54 | context.append(transition_discrete) 55 | return context 56 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # H-GAP: Humanoid Control with a Generalist Planner 2 | 3 | Implementation of [H-GAP: Humanoid Control with a Generalist Planner](https://ycxuyingchen.github.io/hgap/). 4 | 5 | 6 | ## Installation 7 | 1. Create a Python virtual environment with Python 3.9 via the method of your choice. For example with conda: 8 | ``` 9 | conda create -n hgap python=3.9 10 | ``` 11 | 12 | 2. Install the dependencies: 13 | ``` 14 | pip install -r requirements/requirements.txt 15 | ``` 16 | 17 | 3. Install MoCapAct following instruction at https://github.com/microsoft/MoCapAct 18 | 19 | ## Prepare MoCapAct datasets 20 | 21 | 1. Download MoCapAct datasets following instructions at https://github.com/microsoft/MoCapAct 22 | 23 | 2. Generate TFDS datasets for H-GAP training: 24 | 25 | ``` 26 | # This creates TFDS datasets from the original MoCapAct dataset (which is in HDF5 format). 27 | # Set mocapact_data_dir to be the path to the downloaded MoCapAct dataset, e.g. /home/usr/data/mocap 28 | # Note that there are two sizes of MoCapAct, i.e. small and large. Specify which one to build by setting size as small or large. 29 | 30 | python trajectory/datasets/generate_tfds_dataset.py --mocapact_data_dir $mocapact_data_dir --size small 31 | ``` 32 | 33 | ## Usage 34 | 35 | 1. Train VAE: 36 | ``` 37 | python scripts/train.py --dataset mocapact-large-compact --exp_name $vae_name --relabel_type none --n_epochs_ref 1200 38 | ``` 39 | 40 | 2. Train Prior Transformer: 41 | ``` 42 | python scripts/trainprior.py --dataset mocapact-large-compact --exp_name $prior_name --vae_name $vae_name --relabel_type none --n_epochs_ref 1200 43 | 44 | ``` 45 | 46 | 3. Plan: 47 | ``` 48 | python scripts/humanoid_plan.py --test_planner sample_with_prior --objective $relabel_type --temperature 2 --prob_weight 0 --nb_samples 64 --horizon 16 --dataset mocapact-large-compact --exp_name $plan_name --prior_name $prior_name --vae_name $vae_name --suffix $j --seed $j --task $relabel_type --top_p $top_p 49 | ``` 50 | 51 | ## License 52 | The majority of H-GAP is licensed under CC-BY-NC, however portions of the project are adapted from codes available under separate license terms: latentplan is licensed under the MIT license. 53 | -------------------------------------------------------------------------------- /trajectory/utils/config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import collections 9 | import pickle 10 | import torch 11 | 12 | class Config(collections.Mapping): 13 | 14 | def __init__(self, _class, verbose=True, savepath=None, **kwargs): 15 | self._class = _class 16 | self._dict = {} 17 | 18 | for key, val in kwargs.items(): 19 | self._dict[key] = val 20 | 21 | # only print and save on the main process 22 | try: 23 | rank = torch.distributed.get_rank() 24 | except: 25 | rank = 0 26 | if rank == 0: 27 | if verbose: 28 | print(self) 29 | 30 | if savepath is not None: 31 | savepath = os.path.join(*savepath) if type(savepath) is tuple else savepath 32 | pickle.dump(self, open(savepath, 'wb')) 33 | print(f'Saved config to: {savepath}\n') 34 | 35 | 36 | def __repr__(self): 37 | string = f'\nConfig: {self._class}\n' 38 | for key in sorted(self._dict.keys()): 39 | val = self._dict[key] 40 | string += f' {key}: {val}\n' 41 | return string 42 | 43 | def __iter__(self): 44 | return iter(self._dict) 45 | 46 | def __getitem__(self, item): 47 | return self._dict[item] 48 | 49 | def __len__(self): 50 | return len(self._dict) 51 | 52 | def __call__(self): 53 | return self.make() 54 | 55 | def __getattr__(self, attr): 56 | if attr == '_dict' and '_dict' not in vars(self): 57 | self._dict = {} 58 | try: 59 | return self._dict[attr] 60 | except KeyError: 61 | raise AttributeError(attr) 62 | 63 | def make(self, **kwargs): 64 | if 'GPT' in str(self._class) or 'VAE' in str(self._class) or 'Trainer' in str(self._class) or 'Prior' in str(self._class) or 'Critic' in str(self._class): 65 | return self._class(self) 66 | else: 67 | return self._class(**self._dict, **kwargs) 68 | -------------------------------------------------------------------------------- /trajectory/utils/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import trajectory.utils as utils 9 | import trajectory.datasets as datasets 10 | 11 | 12 | def create_dataset(args, world_size=1, rank=0, repeat=True): 13 | sequence_length = args.subsampled_sequence_length * args.step 14 | 15 | if "large" in args.dataset: 16 | if os.path.exists(os.path.expanduser('~/data/mocap/dataset/large')): 17 | file_path = os.path.expanduser('~/data/mocap/dataset/large') 18 | else: 19 | file_path = os.path.expanduser('~/data_local/mocap/dataset/large') 20 | else: 21 | if os.path.exists(os.path.expanduser('~/data/mocap/dataset/small')): 22 | file_path = os.path.expanduser('~/data/mocap/dataset/small') 23 | else: 24 | file_path = os.path.expanduser('~/data_local/mocap/dataset/small') 25 | 26 | file_names = args.dataset 27 | 28 | ignore_incomplete_episodes = args.ignore_incomplete_episodes 29 | 30 | dataset_config= utils.Config( 31 | datasets.SequentialDataLoader, 32 | batch_size=args.load_batch_size, 33 | savepath=(args.savepath, 'data_config.pkl'), 34 | fnames=file_names, 35 | normalize_obs=args.normalize, 36 | normalize_act=args.normalize, 37 | normalize_reward=args.normalize_reward, 38 | sequence_length=sequence_length, 39 | discount=args.discount, 40 | relabel_type=args.relabel_type, 41 | metrics_path=os.path.join(file_path, 'dataset_metrics.npz'), 42 | validation_episodes=args.validation_episodes, 43 | body_height_limit=args.body_height_limit, 44 | body_height_penalty=args.body_height_penalty, 45 | reward_clip=args.reward_clip, 46 | checkpoint_path=os.path.join(args.savepath, 'dataloader_checkpoint'), 47 | world_size=world_size, 48 | rank=rank, 49 | repeat=repeat, 50 | ignore_incomplete_episodes=ignore_incomplete_episodes, 51 | ) 52 | 53 | dataset = dataset_config() 54 | return dataset 55 | -------------------------------------------------------------------------------- /tasks/motion_completion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | from pathlib import Path 9 | from gym import spaces 10 | from typing import Any, Dict, Optional, Tuple, Union 11 | from dm_control.locomotion.tasks.reference_pose import types 12 | 13 | from dm_control.locomotion.mocap import cmu_mocap_data 14 | from dm_control.locomotion.walkers import cmu_humanoid 15 | 16 | from mocapact.envs import dm_control_wrapper 17 | from mocapact.tasks import motion_completion 18 | 19 | class MotionCompletionGymEnv(dm_control_wrapper.DmControlWrapper): 20 | def __init__( 21 | self, 22 | dataset: types.ClipCollection, 23 | ref_steps: Tuple[int] = (0,), 24 | mocap_path: Optional[Union[str, Path]] = None, 25 | task_kwargs: Optional[Dict[str, Any]] = None, 26 | environment_kwargs: Optional[Dict[str, Any]] = None, 27 | include_clip_id: bool = False, 28 | 29 | # for rendering 30 | width: int = 640, 31 | height: int = 480, 32 | camera_id: int = 3 33 | ): 34 | if dataset is None: 35 | self._dataset = types.ClipCollection(ids=['CMU_002_01', 'CMU_009_01', 'CMU_010_04', 'CMU_013_11', 'CMU_014_06', 'CMU_041_02', 36 | 'CMU_046_01', 'CMU_075_01', 'CMU_083_18', 'CMU_105_53', 'CMU_143_41', 'CMU_049_07']) 37 | else: 38 | self._dataset = types.ClipCollection(ids=[dataset]) 39 | task_kwargs = task_kwargs or dict() 40 | task_kwargs['ref_path'] = mocap_path if mocap_path else cmu_mocap_data.get_path_for_cmu(version='2020') 41 | task_kwargs['dataset'] = self._dataset 42 | task_kwargs['ref_steps'] = ref_steps 43 | self._include_clip_id = include_clip_id 44 | super().__init__( 45 | motion_completion.MotionCompletion, 46 | task_kwargs=task_kwargs, 47 | environment_kwargs=environment_kwargs, 48 | act_noise=0., 49 | arena_size=(100., 100.), 50 | width=width, 51 | height=height, 52 | camera_id=camera_id 53 | ) 54 | 55 | def _get_walker(self): 56 | return cmu_humanoid.CMUHumanoidPositionControlledV2020 57 | 58 | def _create_observation_space(self) -> spaces.Dict: 59 | obs_spaces = dict() 60 | for k, v in self._env.observation_spec().items(): 61 | if v.dtype == np.float64 and np.prod(v.shape) > 0: 62 | obs_spaces[k] = spaces.Box( 63 | -np.infty, 64 | np.infty, 65 | shape=(np.prod(v.shape),), 66 | dtype=np.float32 67 | ) 68 | elif k == 'walker/clip_id' and self._include_clip_id: 69 | obs_spaces[k] = spaces.Discrete(len(self._dataset.ids)) 70 | return spaces.Dict(obs_spaces) -------------------------------------------------------------------------------- /trajectory/datasets/preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | from trajectory.datasets.mocapact import CMU_HUMANOID_OBSERVABLES 9 | 10 | def kitchen_preprocess_fn(observations): 11 | ## keep first 30 dimensions of 60-dimension observations 12 | keep = observations[:, :30] 13 | remove = observations[:, 30:] 14 | assert (remove.max(0) == remove.min(0)).all(), 'removing important state information' 15 | return keep 16 | 17 | def ant_preprocess_fn(observations): 18 | qpos_dim = 13 ## root_x and root_y removed 19 | qvel_dim = 14 20 | cfrc_dim = 84 21 | assert observations.shape[1] == qpos_dim + qvel_dim + cfrc_dim 22 | keep = observations[:, :qpos_dim + qvel_dim] 23 | return keep 24 | 25 | def vmap(fn): 26 | 27 | def _fn(inputs): 28 | if isinstance(inputs, dict): 29 | return_1d = False 30 | else: 31 | if inputs.ndim == 1: 32 | inputs = inputs[None] 33 | return_1d = True 34 | else: 35 | return_1d = False 36 | 37 | outputs = fn(inputs) 38 | 39 | if return_1d: 40 | return outputs.squeeze(0) 41 | else: 42 | return outputs 43 | 44 | return _fn 45 | 46 | def preprocess_dataset(preprocess_fn): 47 | 48 | def _fn(dataset): 49 | for key in ['observations', 'next_observations']: 50 | dataset[key] = preprocess_fn(dataset[key]) 51 | return dataset 52 | 53 | return _fn 54 | 55 | def humanoid_preprocess_fn(obs_dict): 56 | obs_list = [] 57 | for key in CMU_HUMANOID_OBSERVABLES: 58 | obs_list.append(obs_dict[key]) 59 | obs = np.concatenate(obs_list) 60 | return obs 61 | 62 | def dmcontrol_preprocess_fn(obs_dict): 63 | obs_list = [] 64 | for key in obs_dict.keys(): 65 | obs_list.append(obs_dict[key]) 66 | obs = np.concatenate(obs_list) 67 | return obs 68 | 69 | preprocess_functions = { 70 | 'kitchen-complete-v0': vmap(kitchen_preprocess_fn), 71 | 'kitchen-mixed-v0': vmap(kitchen_preprocess_fn), 72 | 'kitchen-partial-v0': vmap(kitchen_preprocess_fn), 73 | 'ant-expert-v2': vmap(ant_preprocess_fn), 74 | 'ant-medium-expert-v2': vmap(ant_preprocess_fn), 75 | 'ant-medium-replay-v2': vmap(ant_preprocess_fn), 76 | 'ant-medium-v2': vmap(ant_preprocess_fn), 77 | 'ant-random-v2': vmap(ant_preprocess_fn), 78 | 'speed': vmap(humanoid_preprocess_fn), 79 | 'forward': vmap(humanoid_preprocess_fn), 80 | 'rotate_x': vmap(humanoid_preprocess_fn), 81 | 'rotate_y': vmap(humanoid_preprocess_fn), 82 | 'rotate_z': vmap(humanoid_preprocess_fn), 83 | 'x_vel': vmap(humanoid_preprocess_fn), 84 | 'y_vel': vmap(humanoid_preprocess_fn), 85 | 'jump': vmap(humanoid_preprocess_fn), 86 | 'shift_left': vmap(humanoid_preprocess_fn), 87 | 'backward': vmap(humanoid_preprocess_fn), 88 | 'z_vel': vmap(humanoid_preprocess_fn), 89 | 'negative_z_vel': vmap(humanoid_preprocess_fn), 90 | 'tracking': vmap(humanoid_preprocess_fn), 91 | } 92 | 93 | dataset_preprocess_functions = { 94 | k: preprocess_dataset(fn) for k, fn in preprocess_functions.items() 95 | } 96 | 97 | def get_preprocess_fn(env): 98 | return preprocess_functions.get(env, lambda x: x) -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /trajectory/models/transformers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import math 8 | import os 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn import functional as F 13 | from xformers.components.attention import ScaledDotProduct 14 | import xformers.ops as xops 15 | 16 | from xformers.factory import xFormerEncoderBlock, xFormerEncoderConfig 17 | 18 | 19 | 20 | class Block(nn.Module): 21 | def __init__(self, n_embd, n_head, resid_pdrop, attn_pdrop, sequence_length, causal=True, rotary=False, 22 | kv_cache=False): 23 | super().__init__() 24 | """ 25 | if "ma_update" in config: 26 | sequence_length = int(config.block_size / config.transition_dim) 27 | else: 28 | sequence_length = int(config.max_sequence_length / config.latent_step)*config.code_per_step 29 | """ 30 | 31 | self.sequence_length = sequence_length 32 | 33 | # if triton is installed, use FusedMLP, otherwise use MLP 34 | 35 | block_config = { 36 | "dim_model": n_embd, 37 | "residual_norm_style": "post", # Optional, pre/post 38 | 39 | "multi_head_config": { 40 | "num_heads": n_head, 41 | "residual_dropout": resid_pdrop, 42 | "attention": { 43 | "name": "scaled_dot_product", 44 | "dropout": attn_pdrop, 45 | "seq_len": sequence_length, 46 | "num_rules": n_head, 47 | }, 48 | }, 49 | "feedforward_config": { 50 | "name": "MLP", 51 | "dropout": resid_pdrop, 52 | "activation": "gelu", 53 | "hidden_layer_multiplier": 4, 54 | }, 55 | } 56 | if rotary: 57 | block_config["multi_head_config"]["attention"]["rotary"] = True 58 | block_config["multi_head_config"]["attention"]["rotary_dim"] = n_embd 59 | config = xFormerEncoderConfig(**block_config) 60 | self.block = xFormerEncoderBlock(config) 61 | self.causal = causal 62 | self.kv_cache = kv_cache 63 | self.attention_mask = torch.ones((sequence_length, sequence_length), device="cuda") 64 | if causal: 65 | self.attention_mask = torch.triu(self.attention_mask,diagonal=1)*(-1e9) 66 | 67 | def forward(self, x, kv_cache=None): 68 | # pad input to sequence length 69 | if self.block.patch_emb is not None: 70 | x = self.block.patch_emb(x) 71 | 72 | if self.block.pose_encoding is not None: 73 | x = self.block.pose_encoding(x) 74 | 75 | if hasattr(self.block, "embedding_projector"): 76 | x = self.block.embedding_projector(x) 77 | 78 | # Handle the optional input masking, differs on Q, K, V 79 | if kv_cache is None: 80 | q, k, v = x, x, x 81 | else: 82 | assert x.size(1) == 1, "Only used for autoregressive decoding" 83 | if kv_cache.size(1) == 1: 84 | kv_cache = kv_cache.repeat(x.size(0), 1, 1) 85 | q, k, v = x, torch.cat([kv_cache, x], dim=1), torch.cat([kv_cache, x], dim=1) 86 | 87 | new_cache = k 88 | # Pre/Post norms and residual paths are already handled 89 | if kv_cache is None: 90 | x += self.block.wrap_att.sublayer.layer(q, k, v, att_mask=self.attention_mask) 91 | else: 92 | x += self.block.wrap_att.sublayer.layer(q, k, v, att_mask=None) 93 | 94 | x = self.block.wrap_att.norm(x) 95 | x = self.block.wrap_ff(inputs=[x]) 96 | 97 | # Optional simplicial embeddings 98 | if self.block.simplicial_embedding is not None: 99 | x = self.block.simplicial_embedding(x) 100 | if self.kv_cache: 101 | return x, new_cache 102 | else: 103 | return x -------------------------------------------------------------------------------- /plotting/read_results.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import glob 9 | import numpy as np 10 | import json 11 | from scipy.stats import bootstrap 12 | import pdb 13 | 14 | from collections import defaultdict 15 | import trajectory.utils as utils 16 | 17 | DATASETS = [ 18 | f'{env}-{buffer}' 19 | for env in ['hopper', 'walker2d', 'halfcheetah', 'ant'] 20 | for buffer in ['medium-expert-v2', 'medium-v2', 'medium-replay-v2'] 21 | ] 22 | 23 | LOGBASE = os.path.expanduser('~/logs') 24 | TRIAL = '*' 25 | EXP_NAME = 'plans/defaults/freq1_H1_beam50' 26 | 27 | def load_results(paths, humanoid=False): 28 | ''' 29 | paths : path to directory containing experiment trials 30 | ''' 31 | scores = [] 32 | infos = defaultdict(list) 33 | mean_infos = {} 34 | for i, path in enumerate(sorted(paths)): 35 | if humanoid: 36 | score, info = load_humanoid_result(path) 37 | else: 38 | score, info = load_result(path) 39 | if score is None: 40 | continue 41 | scores.append(score) 42 | for k, v in info.items(): 43 | infos[k].append(v) 44 | 45 | suffix = path.split('/')[-1] 46 | 47 | for k, v in infos.items(): 48 | mean_infos[k] = np.nanmean(v) 49 | 50 | res = bootstrap(data=[scores], statistic=np.mean, axis=0) 51 | bootstrap_error = res.standard_error 52 | conf_interval = res.confidence_interval 53 | mean_infos['bootstrap_error'] = bootstrap_error 54 | mean_infos['conf_lower'] = conf_interval.low 55 | mean_infos['conf_upper'] = conf_interval.high 56 | 57 | mean = np.mean(scores) 58 | err = np.std(scores) / np.sqrt(len(scores)) 59 | return mean, err, scores, mean_infos 60 | 61 | 62 | def load_humanoid_result(path): 63 | fullpath = os.path.join(path, 'rollout.json') 64 | suffix = path.split('/')[-1] 65 | 66 | if not os.path.exists(fullpath): 67 | return None, None 68 | 69 | results = json.load(open(fullpath, 'rb')) 70 | info = dict(returns=results["return"], 71 | discount_return=results["discount_return"], 72 | prediction_error=results["prediction_error"], 73 | value_mean=results["value_mean"], 74 | step=results["step"]) 75 | 76 | return results["return"], info 77 | 78 | def load_result(path): 79 | ''' 80 | path : path to experiment directory; expects `rollout.json` to be in directory 81 | ''' 82 | #path = os.path.join(path, "0") 83 | fullpath = os.path.join(path, 'rollout.json') 84 | suffix = path.split('/')[-1] 85 | 86 | if not os.path.exists(fullpath): 87 | return None, None 88 | 89 | results = json.load(open(fullpath, 'rb')) 90 | score = results['score'] 91 | info = dict(returns=results["return"], 92 | first_value=results["first_value"], 93 | first_search_value=results["first_search_value"], 94 | discount_return=results["discount_return"], 95 | prediction_error=results["prediction_error"], 96 | step=results["step"]) 97 | 98 | return score * 100, info 99 | 100 | ####################### 101 | ######## setup ######## 102 | ####################### 103 | 104 | class Parser(utils.Parser): 105 | dataset: str = None 106 | exp_name: str = None 107 | output: str = None 108 | test_planner: str = None 109 | wildcard_exp_name: bool = True 110 | 111 | if __name__ == '__main__': 112 | 113 | args = Parser().parse_args() 114 | 115 | write_to_file = args.output is not None 116 | 117 | if args.wildcard_exp_name: 118 | exp_name = args.exp_name+"*" 119 | else: 120 | exp_name = args.exp_name 121 | 122 | if write_to_file: 123 | f = open(args.output, "a") 124 | 125 | for dataset in ([args.dataset] if args.dataset else DATASETS): 126 | subdirs = glob.glob(os.path.join(LOGBASE, dataset)) 127 | 128 | for subdir in subdirs: 129 | reldir = subdir.split('/')[-1] 130 | if args.test_planner is not None: 131 | paths = glob.glob(os.path.join(subdir, exp_name, TRIAL, args.test_planner)) 132 | else: 133 | paths = glob.glob(os.path.join(subdir, exp_name, TRIAL)) 134 | 135 | if "mocapact" in args.dataset: 136 | mean, err, returns, infos = load_results(paths, humanoid=True) 137 | string_print=f'{args.exp_name} | {dataset.ljust(30)} | {len(returns)} returns | return {mean:.2f} +/- {err:.2f} | value mean {infos["value_mean"]:.2f}' 138 | else: 139 | mean, err, scores, infos = load_results(paths) 140 | string_print=f'{dataset.ljust(30)} | {len(scores)} scores | score {mean:.2f} +/- {err:.2f} | ' 141 | for k, v in infos.items(): 142 | string_print += f'{k} {v:.4f} | ' 143 | print(string_print) 144 | if write_to_file: 145 | f.write(string_print+'\n') 146 | 147 | if write_to_file: 148 | f.close() 149 | 150 | -------------------------------------------------------------------------------- /trajectory/tfds/generate_tfds_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Example for the new TFDS based dataset. 9 | 10 | This shows how to use the new `MocapActTFDSDataSource` for loading 11 | episodes from the MocapAct dataset. 12 | 13 | Right now, the new interface does not yet implement inference of the example 14 | specs like the old interface. However, this will be added in the near future. 15 | Otherwise, the TFDS interface should provide identical episode tf.data.Dataset 16 | like the old HDF5 data source. 17 | 18 | The new TFDS interface uses https://www.tensorflow.org/datasets for converting 19 | and processing the raw HDF5 into TFRecord which provides a more scalable pipeline 20 | and addresses some issues with imbalanced loading from the old HDF5 pipeline. 21 | 22 | You don't have to know TFDS in detail to use it. In particular, how to implement 23 | the dataset builder in trajectory/jax/tap/mocap/tfds. 24 | 25 | Having said that, it will be useful to know to use effectively use the new interface 26 | if you spend some time reading through a basic tutorial, which can be found at 27 | 28 | https://www.tensorflow.org/datasets/overview#load_a_dataset 29 | 30 | There are some additional details which would be useful to know when we want to 31 | support building more customized datasets. The extension points have been 32 | annotated with `NOTE`. 33 | 34 | """ 35 | import functools 36 | import os 37 | 38 | import rlds 39 | import tensorflow as tf 40 | import tensorflow_datasets as tfds 41 | from absl import app, flags 42 | 43 | from trajectory.tfds import mocap_utils 44 | 45 | # This is needed so that tfds can discover our mocapact builder. 46 | from trajectory.tfds.tfds import mocapact 47 | 48 | _MOCAPACT_DATA_DIR = flags.DEFINE_string( 49 | "mocapact_data_dir", 50 | default=os.environ.get("MOCAPACT_DATA_DIR", None), 51 | help="Path to the MocapAct download directory.", 52 | ) 53 | _MOCAPACT_DATA_SIZE = flags.DEFINE_string( 54 | "size", 55 | default="small", 56 | help="Size of the MocapAct dataset.", 57 | ) 58 | 59 | def main(_): 60 | mocap_data_dir = _MOCAPACT_DATA_DIR.value 61 | assert mocap_data_dir is not None 62 | # Use case 1. Manual way of preparing dataset 63 | # Prepare the dataset. 64 | data_size = _MOCAPACT_DATA_SIZE.value 65 | builder = tfds.builder(f"mocapact/{data_size}_cmu_observable") 66 | # NOTE: There's no downloading. 67 | # Instead, we specify the manual_dir that includes the original HDF5 dataset. 68 | # Overriding the manual_dir allows us to specify a custom location 69 | # for the original HDF5 files. 70 | # The manual_dir should point to the directory where the MocapAct dataset is 71 | # downloaded. For example, consider the download directory `./data` and 72 | # you have downloaded the small subset. Then the manual_dir would be `./data` 73 | # and the builder will look for HDF5 files located in `./data/dataset/small` 74 | # The later parts of the path can be customized with a BuilderConfig. 75 | # See mocap/tfds/mocapact/mocapact.py 76 | builder.download_and_prepare( 77 | download_config=tfds.download.DownloadConfig(manual_dir=mocap_data_dir) 78 | ) 79 | # Use case 2. Use MocapActTFDSSource to retrieve episode dataset 80 | # similar to the HDF5 data source 81 | # Specify "mocapact" as the name will build the dataset with the default config 82 | # (small_cmu_observable). 83 | # Which is to build against the small/ directory with reduced CMU observables. 84 | # An alternative way of specifying the name is `mocapact/small_cmu_observable`. 85 | data_source = mocap_utils.MocapActTFDSDataSource( 86 | name="mocapact", 87 | # Specify a split to retrieve. Only subsets of train is available 88 | # as there are no explicit train/val splits from MocapAct. 89 | split="train", 90 | # Same as the HDF5 data source. 91 | normalize_observations=True, 92 | normalize_actions=True, 93 | use_mean_actions=True, 94 | ) 95 | # Same as the HDF5 data source. 96 | episode_dataset = data_source.make_episode_dataset() 97 | print(episode_dataset.element_spec["steps"]) 98 | sequence_length = 25 99 | # Same way to convert to the sequence dataset. 100 | sequence_dataset: tf.data.Dataset = episode_dataset.interleave( 101 | lambda episode: rlds.transformations.batch( 102 | episode["steps"], size=sequence_length, shift=1, drop_remainder=False 103 | ), 104 | deterministic=True, 105 | num_parallel_calls=tf.data.AUTOTUNE, 106 | cycle_length=16, 107 | block_length=16, 108 | ) 109 | sequence_dataset = sequence_dataset.map( 110 | functools.partial(mocap_utils.pad_steps, max_len=sequence_length) 111 | ) 112 | print(sequence_dataset.element_spec) 113 | 114 | 115 | if __name__ == "__main__": 116 | tf.config.set_visible_devices([], "GPU") 117 | app.run(main) 118 | -------------------------------------------------------------------------------- /trajectory/utils/progress.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import time 8 | import math 9 | import re 10 | import pdb 11 | 12 | class Progress: 13 | 14 | def __init__(self, total, name = 'Progress', ncol=3, max_length=30, indent=8, line_width=100, speed_update_freq=100): 15 | self.total = total 16 | self.name = name 17 | self.ncol = ncol 18 | self.max_length = max_length 19 | self.indent = indent 20 | self.line_width = line_width 21 | self._speed_update_freq = speed_update_freq 22 | 23 | self._step = 0 24 | self._prev_line = '\033[F' 25 | self._clear_line = ' ' * self.line_width 26 | 27 | self._pbar_size = self.ncol * self.max_length 28 | self._complete_pbar = '#' * self._pbar_size 29 | self._incomplete_pbar = ' ' * self._pbar_size 30 | 31 | self.lines = [''] 32 | self.fraction = '{} / {}'.format(0, self.total) 33 | 34 | self.resume() 35 | 36 | 37 | def update(self, description, n=1): 38 | self._step += n 39 | if self._step % self._speed_update_freq == 0: 40 | self._time0 = time.time() 41 | self._step0 = self._step 42 | self.set_description(description) 43 | 44 | def resume(self): 45 | self._skip_lines = 1 46 | print('\n', end='') 47 | self._time0 = time.time() 48 | self._step0 = self._step 49 | 50 | def pause(self): 51 | self._clear() 52 | self._skip_lines = 1 53 | 54 | def set_description(self, params=[]): 55 | 56 | if type(params) == dict: 57 | params = sorted([ 58 | (key, val) 59 | for key, val in params.items() 60 | ]) 61 | 62 | ############ 63 | # Position # 64 | ############ 65 | self._clear() 66 | 67 | ########### 68 | # Percent # 69 | ########### 70 | percent, fraction = self._format_percent(self._step, self.total) 71 | self.fraction = fraction 72 | 73 | ######### 74 | # Speed # 75 | ######### 76 | speed = self._format_speed(self._step) 77 | 78 | ########## 79 | # Params # 80 | ########## 81 | num_params = len(params) 82 | nrow = math.ceil(num_params / self.ncol) 83 | params_split = self._chunk(params, self.ncol) 84 | params_string, lines = self._format(params_split) 85 | self.lines = lines 86 | 87 | 88 | description = '{} | {}{}'.format(percent, speed, params_string) 89 | print(description) 90 | self._skip_lines = nrow + 1 91 | 92 | def append_description(self, descr): 93 | self.lines.append(descr) 94 | 95 | def _clear(self): 96 | position = self._prev_line * self._skip_lines 97 | empty = '\n'.join([self._clear_line for _ in range(self._skip_lines)]) 98 | print(position, end='') 99 | print(empty) 100 | print(position, end='') 101 | 102 | def _format_percent(self, n, total): 103 | if total: 104 | percent = n / float(total) 105 | 106 | complete_entries = int(percent * self._pbar_size) 107 | incomplete_entries = self._pbar_size - complete_entries 108 | 109 | pbar = self._complete_pbar[:complete_entries] + self._incomplete_pbar[:incomplete_entries] 110 | fraction = '{} / {}'.format(n, total) 111 | string = '{} [{}] {:3d}%'.format(fraction, pbar, int(percent*100)) 112 | else: 113 | fraction = '{}'.format(n) 114 | string = '{} iterations'.format(n) 115 | return string, fraction 116 | 117 | def _format_speed(self, n): 118 | num_steps = n - self._step0 119 | t = time.time() - self._time0 120 | speed = num_steps / t 121 | string = '{:.1f} Hz'.format(speed) 122 | if num_steps > 0: 123 | self._speed = string 124 | return string 125 | 126 | def _chunk(self, l, n): 127 | return [l[i:i+n] for i in range(0, len(l), n)] 128 | 129 | def _format(self, chunks): 130 | lines = [self._format_chunk(chunk) for chunk in chunks] 131 | lines.insert(0,'') 132 | padding = '\n' + ' '*self.indent 133 | string = padding.join(lines) 134 | return string, lines 135 | 136 | def _format_chunk(self, chunk): 137 | line = ' | '.join([self._format_param(param) for param in chunk]) 138 | return line 139 | 140 | def _format_param(self, param, str_length=8): 141 | k, v = param 142 | k = k.rjust(str_length) 143 | if type(v) == float or hasattr(v, 'item'): 144 | string = '{}: {:12.4f}' 145 | else: 146 | string = '{}: {}' 147 | v = str(v).rjust(12) 148 | return string.format(k, v)[:self.max_length] 149 | 150 | def stamp(self): 151 | if self.lines != ['']: 152 | params = ' | '.join(self.lines) 153 | string = '[ {} ] {}{} | {}'.format(self.name, self.fraction, params, self._speed) 154 | string = re.sub(r'\s+', ' ', string) 155 | self._clear() 156 | print(string, end='\n') 157 | self._skip_lines = 1 158 | else: 159 | self._clear() 160 | self._skip_lines = 0 161 | 162 | def close(self): 163 | self.pause() 164 | 165 | class Silent: 166 | 167 | def __init__(self, *args, **kwargs): 168 | pass 169 | 170 | def __getattr__(self, attr): 171 | return lambda *args: None 172 | 173 | 174 | if __name__ == '__main__': 175 | silent = Silent() 176 | silent.update() 177 | silent.stamp() 178 | 179 | num_steps = 1000 180 | progress = Progress(num_steps) 181 | for i in range(num_steps): 182 | progress.update() 183 | params = [ 184 | ['A', '{:06d}'.format(i)], 185 | ['B', '{:06d}'.format(i)], 186 | ['C', '{:06d}'.format(i)], 187 | ['D', '{:06d}'.format(i)], 188 | ['E', '{:06d}'.format(i)], 189 | ['F', '{:06d}'.format(i)], 190 | ['G', '{:06d}'.format(i)], 191 | ['H', '{:06d}'.format(i)], 192 | ] 193 | progress.set_description(params) 194 | time.sleep(0.01) 195 | progress.close() 196 | -------------------------------------------------------------------------------- /tasks/tracking.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Wraps the MultiClipMocapTracking dm_env into a Gym environment. 9 | """ 10 | import numpy as np 11 | from gym import spaces 12 | from pathlib import Path 13 | from typing import Any, Dict, Optional, Tuple, Union 14 | 15 | from dm_control.locomotion.mocap import cmu_mocap_data 16 | from dm_control.locomotion.tasks.reference_pose import types 17 | from dm_control.locomotion.tasks.reference_pose import tracking 18 | from dm_control.locomotion.walkers import cmu_humanoid 19 | 20 | from mocapact.envs import dm_control_wrapper 21 | 22 | class MocapTrackingGymEnv(dm_control_wrapper.DmControlWrapper): 23 | """ 24 | Wraps the MultiClipMocapTracking into a Gym env. 25 | Adapted from https://github.com/microsoft/MoCapAct/blob/main/mocapact/envs/tracking.py 26 | """ 27 | 28 | def __init__( 29 | self, 30 | dataset: str = None, 31 | ref_steps: Tuple[int] = (0,), 32 | mocap_path: Optional[Union[str, Path]] = None, 33 | task_kwargs: Optional[Dict[str, Any]] = None, 34 | environment_kwargs: Optional[Dict[str, Any]] = None, 35 | act_noise: float = 0.01, 36 | enable_all_proprios: bool = False, 37 | enable_cameras: bool = False, 38 | include_clip_id: bool = False, 39 | display_ghost: bool = True, 40 | 41 | # for rendering 42 | width: int = 640, 43 | height: int = 480, 44 | camera_id: int = 3 45 | ): 46 | if dataset is None: 47 | self._dataset = types.ClipCollection(ids=['CMU_002_01', 'CMU_009_01', 'CMU_010_04', 'CMU_013_11', 'CMU_014_06', 'CMU_041_02', 48 | 'CMU_046_01', 'CMU_075_01', 'CMU_083_18', 'CMU_105_53', 'CMU_143_41', 'CMU_049_07']) 49 | else: 50 | self._dataset = types.ClipCollection(ids=[dataset]) 51 | self._enable_all_proprios = enable_all_proprios 52 | self._enable_cameras = enable_cameras 53 | self._include_clip_id = include_clip_id 54 | task_kwargs = task_kwargs or dict() 55 | task_kwargs['ref_path'] = mocap_path if mocap_path else cmu_mocap_data.get_path_for_cmu(version='2020') 56 | task_kwargs['dataset'] = self._dataset 57 | task_kwargs['ref_steps'] = ref_steps 58 | if display_ghost: 59 | task_kwargs['ghost_offset'] = np.array([1., 0., 0.]) 60 | super().__init__( 61 | tracking.MultiClipMocapTracking, 62 | task_kwargs=task_kwargs, 63 | environment_kwargs=environment_kwargs, 64 | act_noise=act_noise, 65 | width=width, 66 | height=height, 67 | camera_id=camera_id 68 | ) 69 | 70 | def _get_walker(self): 71 | return cmu_humanoid.CMUHumanoidPositionControlledV2020 72 | 73 | def _create_env( 74 | self, 75 | task_type, 76 | task_kwargs, 77 | environment_kwargs, 78 | act_noise=0., 79 | arena_size=(8., 8.) 80 | ): 81 | env = super()._create_env(task_type, task_kwargs, environment_kwargs, act_noise, arena_size) 82 | walker = env._task._walker 83 | # Remove the contacts. 84 | # for geom in walker.mjcf_model.find_all('geom'): 85 | # # alpha=0.999 ensures grey ghost reference. 86 | # # for alpha=1.0 there is no visible difference between real walker and 87 | # # ghost reference. 88 | # alpha = 0.999 89 | # if geom.rgba is not None and geom.rgba[3] < alpha: 90 | # alpha = geom.rgba[3] 91 | 92 | # geom.set_attributes( 93 | # # contype=0, 94 | # # conaffinity=0, 95 | # rgba=(0.5, 0.5, 0.5, alpha)) 96 | 97 | if self._enable_all_proprios: 98 | walker.observables.enable_all() 99 | walker.observables.prev_action.enabled = False # this observable is not implemented 100 | if not self._enable_cameras: 101 | # TODO: procedurally find the cameras 102 | walker.observables.egocentric_camera.enabled = False 103 | walker.observables.body_camera.enabled = False 104 | env.reset() 105 | return env 106 | 107 | def _create_observation_space(self) -> spaces.Dict: 108 | obs_spaces = dict() 109 | for k, v in self._env.observation_spec().items(): 110 | if v.dtype == np.float64 and np.prod(v.shape) > 0: 111 | obs_spaces[k] = spaces.Box( 112 | -np.infty, 113 | np.infty, 114 | shape=(np.prod(v.shape),), 115 | dtype=np.float32 116 | ) 117 | elif v.dtype == np.uint8: 118 | tmp = v.generate_value() 119 | obs_spaces[k] = spaces.Box( 120 | v.minimum.item(), 121 | v.maximum.item(), 122 | shape=tmp.shape, 123 | dtype=np.uint8 124 | ) 125 | elif k == 'walker/clip_id' and self._include_clip_id: 126 | obs_spaces[k] = spaces.Discrete(len(self._dataset.ids)) 127 | return spaces.Dict(obs_spaces) 128 | 129 | def step(self, action: np.ndarray) -> Tuple[Dict[str, np.ndarray], float, bool, Dict[str, Any]]: 130 | obs, reward, done, info = super().step(action) 131 | info['time_in_clip'] = obs['walker/time_in_clip'].item() 132 | info['start_time_in_clip'] = self._start_time_in_clip 133 | info['last_time_in_clip'] = self._last_time_in_clip 134 | return obs, reward, done, info 135 | 136 | def reset(self): 137 | time_step = self._env.reset() 138 | obs = self.get_observation(time_step) 139 | self._start_time_in_clip = obs['walker/time_in_clip'].item() 140 | self._last_time_in_clip = self._env.task._last_step / (len(self._env.task._clip_reference_features['joints'])-1) 141 | return obs -------------------------------------------------------------------------------- /trajectory/utils/serialization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import time 8 | import sys 9 | import os 10 | import glob 11 | import pickle 12 | import json 13 | import torch 14 | import pdb 15 | 16 | def mkdir(savepath, prune_fname=False): 17 | """ 18 | returns `True` iff `savepath` is created 19 | """ 20 | if prune_fname: 21 | savepath = os.path.dirname(savepath) 22 | if not os.path.exists(savepath): 23 | try: 24 | os.makedirs(savepath) 25 | except: 26 | print(f'[ utils/serialization ] Warning: did not make directory: {savepath}') 27 | return False 28 | return True 29 | else: 30 | return False 31 | 32 | def get_latest_epoch(loadpath, prior='', debug=False): 33 | states = glob.glob1(loadpath, prior+'state_*') 34 | states = [s for s in states if 'running' not in s] 35 | if debug: 36 | states = [s for s in states if 'debug' in s] 37 | debug_suffx = '_debug' 38 | else: 39 | states = [s for s in states if 'debug' not in s] 40 | debug_suffx = '' 41 | latest_epoch = -1 42 | for state in states: 43 | epoch = int(state.replace(debug_suffx, '').replace(prior+'state_', '').replace('.pt', '')) 44 | latest_epoch = max(epoch, latest_epoch) 45 | return latest_epoch 46 | 47 | def load_transformer_model(logger, *loadpath, epoch=None, device='cuda:0', debug=False, type='prior'): 48 | loadpath = os.path.join(*loadpath) 49 | config_path = os.path.join(loadpath, f'{type}_model_config.pkl') 50 | debug_suffix = '' if not debug else '_debug' 51 | 52 | if epoch == 'latest': 53 | epoch = get_latest_epoch(loadpath, f"{type}_", debug) 54 | 55 | logger.debug(f'[ utils/serialization ] Loading model epoch: {epoch}', main_process_only=True) 56 | state_path = os.path.join(loadpath, f'{type}_state_{epoch}{debug_suffix}.pt') 57 | 58 | config = pickle.load(open(config_path, 'rb')) 59 | map_location = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 60 | state = torch.load(state_path, map_location=map_location) 61 | 62 | model = config() 63 | model.to(device) 64 | model.load_state_dict(state, strict=True) 65 | 66 | logger.debug(f'\n[ utils/serialization ] Loaded config from {config_path}\n', main_process_only=True) 67 | logger.debug(config, main_process_only=True) 68 | return model, epoch 69 | 70 | def load_optimizer(logger, optimizer, *loadpath, epoch=None, device='cuda:0', debug=False, type='prior'): 71 | loadpath = os.path.join(*loadpath) 72 | prefix = f"{type}_" if type != "" else "" 73 | 74 | if epoch == 'latest': 75 | epoch = get_latest_epoch(loadpath, f"{prefix}optimizer_", debug) 76 | 77 | if epoch < 0: 78 | return optimizer, epoch 79 | 80 | logger.debug(f'[ utils/serialization ] Loading optimizer epoch: {epoch}', main_process_only=True) 81 | state_path = os.path.join(loadpath, f'{prefix}optimizer_state_{epoch}.pt') 82 | 83 | map_location = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 84 | state = torch.load(state_path, map_location=map_location) 85 | 86 | optimizer.load_state_dict(state) 87 | 88 | return optimizer, epoch 89 | 90 | def load_scaler(logger, scaler, *loadpath, epoch=None, device='cuda:0', debug=False, type='prior'): 91 | loadpath = os.path.join(*loadpath) 92 | prefix = f"{type}_" if type != "" else "" 93 | 94 | if epoch == 'latest': 95 | epoch = get_latest_epoch(loadpath, f"{prefix}scaler_", debug) 96 | 97 | if epoch < 0: 98 | return scaler, epoch 99 | 100 | logger.debug(f'[ utils/serialization ] Loading optimizer epoch: {epoch}', main_process_only=True) 101 | state_path = os.path.join(loadpath, f'{prefix}scaler_state_{epoch}.pt') 102 | 103 | map_location = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 104 | state = torch.load(state_path, map_location=map_location) 105 | 106 | scaler.load_state_dict(state) 107 | 108 | return scaler, epoch 109 | 110 | def load_model(logger, *loadpath, epoch=None, device='cuda:0', data_parallel=False, debug=False): 111 | loadpath = os.path.join(*loadpath) 112 | config_path = os.path.join(loadpath, 'model_config.pkl') 113 | debug_suffix = '' if not debug else '_debug' 114 | 115 | if epoch == 'latest': 116 | epoch = get_latest_epoch(loadpath, debug=debug) 117 | 118 | logger.debug(f'[ utils/serialization ] Loading model epoch: {epoch}', main_process_only=True) 119 | state_path = os.path.join(loadpath, f'state_{epoch}{debug_suffix}.pt') 120 | 121 | config = pickle.load(open(config_path, 'rb')) 122 | map_location = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 123 | state = torch.load(state_path, map_location=map_location) 124 | 125 | model = config() 126 | model.to(device) 127 | model.load_state_dict(state, strict=True) 128 | if data_parallel: 129 | # use all available cuda devices for data parallelization 130 | num_gpus = torch.cuda.device_count() 131 | model = torch.nn.DataParallel(model, device_ids=range(num_gpus)) 132 | 133 | logger.debug(f'\n[ utils/serialization ] Loaded config from {config_path}\n', main_process_only=True) 134 | logger.debug(config, main_process_only=True) 135 | 136 | return model, epoch 137 | 138 | def load_config(logger, *loadpath): 139 | loadpath = os.path.join(*loadpath) 140 | config = pickle.load(open(loadpath, 'rb')) 141 | logger.debug(f'[ utils/serialization ] Loaded config from {loadpath}', main_process_only=True) 142 | logger.debug(config, main_process_only=True) 143 | return config 144 | 145 | def load_from_config(logger, *loadpath, **kwargs): 146 | config = load_config(logger, *loadpath) 147 | return config.make(**kwargs) 148 | 149 | def load_args(*loadpath): 150 | from .setup import Parser 151 | loadpath = os.path.join(*loadpath) 152 | args_path = os.path.join(loadpath, 'args.json') 153 | args = Parser() 154 | args.load(args_path) 155 | return args 156 | -------------------------------------------------------------------------------- /tasks/cmu_relabel.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """ 8 | Try to maximize speed on x-y plane, adopted from https://github.com/microsoft/MoCapAct/blob/main/mocapact/tasks/velocity_control.py 9 | """ 10 | import collections 11 | import numpy as np 12 | from dm_control import composer 13 | from dm_control.composer.observation import observable as dm_observable 14 | from dm_control.locomotion.tasks.reference_pose import tracking 15 | from trajectory.utils.relabel_humanoid import project_left, project_forward, project_height 16 | 17 | 18 | class CMURelabelTask(composer.Task): 19 | """ 20 | A task that requires the walker to track a randomly changing velocity. 21 | """ 22 | 23 | def __init__( 24 | self, 25 | walker, 26 | arena, 27 | max_speed=4.5, 28 | reward_margin=0.75, 29 | physics_timestep=tracking.DEFAULT_PHYSICS_TIMESTEP, 30 | control_timestep=0.03, 31 | contact_termination=True, 32 | relabel_type="speed" 33 | ): 34 | self._walker = walker 35 | self._arena = arena 36 | self._walker.create_root_joints(self._arena.attach(self._walker)) 37 | self._max_speed = max_speed 38 | self._reward_margin = reward_margin 39 | self._move_speed = 0. 40 | self._move_angle = 0. 41 | self._move_speed_counter = 0. 42 | self.relabel_type = relabel_type 43 | self.enabled_observables = [] 44 | self.enabled_observables += self._walker.observables.proprioception 45 | self.enabled_observables += self._walker.observables.kinematic_sensors 46 | self.enabled_observables += self._walker.observables.dynamic_sensors 47 | self.enabled_observables.append(self._walker.observables.sensors_touch) 48 | self.enabled_observables.append(self._walker.observables.egocentric_camera) 49 | for observable in self.enabled_observables: 50 | observable.enabled = True 51 | 52 | self.set_timesteps(physics_timestep=physics_timestep, control_timestep=control_timestep) 53 | self._contact_termination = contact_termination 54 | 55 | @property 56 | def root_entity(self): 57 | return self._arena 58 | 59 | def _is_disallowed_contact(self, contact): 60 | set1, set2 = self._walker_nonfoot_geomids, self._ground_geomids 61 | return ((contact.geom1 in set1 and contact.geom2 in set2) or 62 | (contact.geom1 in set2 and contact.geom2 in set1)) 63 | 64 | def should_terminate_episode(self, physics): 65 | del physics 66 | return self._failure_termination 67 | 68 | def get_discount(self, physics): 69 | del physics 70 | if self._failure_termination: 71 | return 0. 72 | else: 73 | return 1. 74 | 75 | def initialize_episode(self, physics, random_state): 76 | self._walker.reinitialize_pose(physics, random_state) 77 | 78 | self._failure_termination = False 79 | walker_foot_geoms = set(self._walker.ground_contact_geoms) 80 | walker_nonfoot_geoms = [ 81 | geom for geom in self._walker.mjcf_model.find_all('geom') 82 | if geom not in walker_foot_geoms 83 | ] 84 | self._walker_nonfoot_geomids = set(physics.bind(walker_nonfoot_geoms).element_id) 85 | self._ground_geomids = set(physics.bind(self._arena.ground_geoms).element_id) 86 | 87 | def get_reward(self, physics): 88 | if self.relabel_type == "speed": 89 | sensor_vel = self._walker.observables.sensors_velocimeter(physics) 90 | reward = np.linalg.norm(sensor_vel) 91 | elif self.relabel_type == "x_vel": 92 | reward = self._walker.observables.sensors_velocimeter(physics)[0] 93 | elif self.relabel_type == "y_vel": 94 | reward = self._walker.observables.sensors_velocimeter(physics)[1] 95 | elif self.relabel_type == "forward": 96 | reward = project_forward(self._walker.observables.sensors_velocimeter(physics), 97 | self._walker.observables.world_zaxis(physics)) 98 | elif self.relabel_type == "backward": 99 | reward = -project_forward(self._walker.observables.sensors_velocimeter(physics), 100 | self._walker.observables.world_zaxis(physics)) 101 | elif self.relabel_type == "shift_left": 102 | reward = project_left(self._walker.observables.sensors_velocimeter(physics), 103 | self._walker.observables.world_zaxis(physics)) 104 | elif self.relabel_type == "jump": 105 | reward = np.maximum(0, project_height(self._walker.observables.sensors_velocimeter(physics), 106 | self._walker.observables.world_zaxis(physics))) 107 | elif self.relabel_type == "z_vel": 108 | reward = self._walker.observables.sensors_velocimeter(physics)[2] 109 | elif self.relabel_type == "negative_z_vel": 110 | reward = -self._walker.observables.sensors_velocimeter(physics)[2] 111 | elif self.relabel_type == "rotate_x": 112 | reward = self._walker.observables.sensors_gyro(physics)[0] 113 | elif self.relabel_type == "rotate_y": 114 | reward = self._walker.observables.sensors_gyro(physics)[1] 115 | elif self.relabel_type == "rotate_z": 116 | reward = self._walker.observables.sensors_gyro(physics)[2] 117 | else: 118 | raise NotImplementedError() 119 | return reward 120 | 121 | def before_step(self, physics, action, random_state): 122 | self._walker.apply_action(physics, action, random_state) 123 | 124 | def after_step(self, physics, random_state): 125 | self._failure_termination = False 126 | if self._contact_termination: 127 | for contact in physics.data.contact: 128 | if self._is_disallowed_contact(contact): 129 | self._failure_termination = True 130 | break 131 | self._move_speed_counter += 1 132 | -------------------------------------------------------------------------------- /trajectory/utils/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import importlib 9 | import random 10 | import numpy as np 11 | import torch 12 | import pdb 13 | 14 | from tap import Tap 15 | from .serialization import mkdir 16 | from .arrays import set_device 17 | from .git_utils import ( 18 | get_git_rev, 19 | save_git_diff, 20 | ) 21 | 22 | def set_seed(seed): 23 | seed = int(seed) 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed_all(seed) 28 | 29 | def watch(args_to_watch): 30 | def _fn(args): 31 | exp_name = [] 32 | for key, label in args_to_watch: 33 | if not hasattr(args, key): 34 | continue 35 | val = getattr(args, key) 36 | exp_name.append(f'{label}{val}') 37 | exp_name = '_'.join(exp_name) 38 | exp_name = exp_name.replace('/_', '/') 39 | return exp_name 40 | return _fn 41 | 42 | def convert_to_type(string): 43 | if isinstance(string, bool): 44 | return string 45 | if isinstance(string, str) and string.lower() in ("true", "false", "True", "False"): 46 | return string.lower() == "true" 47 | try: 48 | return int(string) 49 | except ValueError: 50 | pass 51 | try: 52 | return float(string) 53 | except ValueError: 54 | pass 55 | return string 56 | 57 | class Parser(Tap): 58 | def save(self): 59 | fullpath = os.path.join(os.path.expanduser(self.savepath), 'args.json') 60 | print(f'[ utils/setup ] Saved args to {fullpath}') 61 | #super().save(fullpath, skip_unpicklable=True) 62 | 63 | def parse_args(self, experiment=None): 64 | args = super().parse_args(known_only=True) 65 | ## if not loading from a config script, skip the result of the setup 66 | if not hasattr(args, 'config'): return args 67 | args = self.read_config(args, experiment) 68 | self.add_extras(args) 69 | self.set_seed(args) 70 | self.get_commit(args) 71 | self.generate_exp_name(args) 72 | self.mkdir(args) 73 | self.save_diff(args) 74 | args.task_type = "locomotion" 75 | args.obs_shape = [-1] 76 | args.logbase = os.path.expanduser(args.logbase) 77 | args.savepath = os.path.expanduser(args.savepath) 78 | return args 79 | 80 | def read_config(self, args, experiment): 81 | ''' 82 | Load parameters from config file 83 | ''' 84 | dataset = args.dataset.replace('-', '_') 85 | print(f'[ utils/setup ] Reading config: {args.config}:{dataset}') 86 | module = importlib.import_module(args.config) 87 | params = getattr(module, 'base')[experiment] 88 | 89 | if hasattr(module, dataset) and experiment in getattr(module, dataset): 90 | print(f'[ utils/setup ] Using overrides | config: {args.config} | dataset: {dataset}') 91 | overrides = getattr(module, dataset)[experiment] 92 | params.update(overrides) 93 | else: 94 | print(f'[ utils/setup ] Not using overrides | config: {args.config} | dataset: {dataset}') 95 | 96 | for key, val in params.items(): 97 | setattr(args, key, val) 98 | 99 | return args 100 | 101 | def add_extras(self, args): 102 | ''' 103 | Override config parameters with command-line arguments 104 | ''' 105 | extras = args.extra_args 106 | if not len(extras): 107 | return 108 | 109 | print(f'[ utils/setup ] Found extras: {extras}') 110 | if len(extras) % 2 == 1: 111 | print(f'[ utils/setup ] Drop excessive extra arg: {extras[-1]}') 112 | extras = extras[:-1] 113 | assert len(extras) % 2 == 0, f'Found odd number ({len(extras)}) of extras: {extras}' 114 | for i in range(0, len(extras), 2): 115 | key = extras[i].replace('--', '') 116 | val = extras[i+1] 117 | # assert hasattr(args, key), f'[ utils/setup ] {key} not found in config: {args.config}' 118 | if not hasattr(args, key): 119 | print(f'[ utils/setup ] skipping {key} not found in config: {args.config}') 120 | continue 121 | old_val = getattr(args, key) 122 | old_type = type(old_val) 123 | print(f'[ utils/setup ] Overriding config | {key} : {old_val} --> {val}') 124 | if val == 'None': 125 | val = None 126 | elif val == 'latest': 127 | val = 'latest' 128 | elif old_type in [bool, type(None), list]: 129 | val = eval(val) 130 | else: 131 | val = val 132 | val = convert_to_type(val) 133 | setattr(args, key, val) 134 | 135 | def set_seed(self, args): 136 | if not 'seed' in dir(args): 137 | return 138 | set_seed(args.seed) 139 | 140 | def generate_exp_name(self, args): 141 | if not 'exp_name' in dir(args): 142 | return 143 | exp_name = getattr(args, 'exp_name') 144 | if callable(exp_name): 145 | exp_name_string = exp_name(args) 146 | print(f'[ utils/setup ] Setting exp_name to: {exp_name_string}') 147 | setattr(args, 'exp_name', exp_name_string) 148 | 149 | def mkdir(self, args): 150 | if 'logbase' in dir(args) and 'dataset' in dir(args) and 'exp_name' in dir(args): 151 | args.savepath = os.path.join(os.path.expanduser(args.logbase), args.dataset, args.exp_name) 152 | if 'suffix' in dir(args): 153 | args.savepath = os.path.join(os.path.expanduser(args.savepath), str(args.suffix)) 154 | if mkdir(args.savepath): 155 | print(f'[ utils/setup ] Made savepath: {args.savepath}') 156 | self.save() 157 | 158 | def get_commit(self, args): 159 | args.commit = get_git_rev() 160 | 161 | def save_diff(self, args): 162 | try: 163 | save_git_diff(os.path.join(os.path.expanduser(args.savepath), 'diff.txt')) 164 | except: 165 | print('[ utils/setup ] WARNING: did not save git diff') 166 | -------------------------------------------------------------------------------- /trajectory/utils/sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from torch.utils.data.sampler import Sampler 8 | from typing import Iterator, Iterable, Optional, Sequence, List, TypeVar, Generic, Sized, Union 9 | import torch 10 | 11 | class RandomSampler(Sampler[int]): 12 | r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset. 13 | If with replacement, then user can specify :attr:`num_samples` to draw. 14 | 15 | Args: 16 | data_source (Dataset): dataset to sample from 17 | replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False`` 18 | num_samples (int): number of samples to draw, default=`len(dataset)`. 19 | generator (Generator): Generator used in sampling. 20 | """ 21 | data_source: Sized 22 | replacement: bool 23 | 24 | def __init__(self, data_source: Sized, replacement: bool = False, 25 | num_samples: Optional[int] = None, generator=None, start_idx=None) -> None: 26 | self.data_source = data_source 27 | self.replacement = replacement 28 | self._num_samples = num_samples 29 | self.generator = generator 30 | self.start_idx = start_idx if start_idx is not None else 0 31 | 32 | if not isinstance(self.replacement, bool): 33 | raise TypeError("replacement should be a boolean value, but got " 34 | "replacement={}".format(self.replacement)) 35 | 36 | if not isinstance(self.num_samples, int) or self.num_samples <= 0: 37 | raise ValueError("num_samples should be a positive integer " 38 | "value, but got num_samples={}".format(self.num_samples)) 39 | 40 | @property 41 | def num_samples(self) -> int: 42 | # dataset size might change at runtime 43 | if self._num_samples is None: 44 | return len(self.data_source) 45 | return self._num_samples 46 | 47 | def __iter__(self) -> Iterator[int]: 48 | n = len(self.data_source) 49 | if self.generator is None: 50 | seed = int(torch.empty((), dtype=torch.int64).random_().item()) 51 | generator = torch.Generator() 52 | generator.manual_seed(seed) 53 | else: 54 | generator = self.generator 55 | 56 | if self.replacement: 57 | for _ in range(self.num_samples // 32): 58 | yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()[self.start_idx: ] 59 | yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()[self.start_idx: ] 60 | else: 61 | for _ in range(self.num_samples // n): 62 | yield from torch.randperm(n, generator=generator).tolist()[self.start_idx: ] 63 | yield from torch.randperm(n, generator=generator).tolist()[self.start_idx : self.num_samples % n] 64 | 65 | def __len__(self) -> int: 66 | return self.num_samples 67 | 68 | 69 | class BatchSampler(Sampler[List[int]]): 70 | r"""Wraps another sampler to yield a mini-batch of indices. 71 | 72 | Args: 73 | sampler (Sampler or Iterable): Base sampler. Can be any iterable object 74 | batch_size (int): Size of mini-batch. 75 | drop_last (bool): If ``True``, the sampler will drop the last batch if 76 | its size would be less than ``batch_size`` 77 | 78 | Example: 79 | >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False)) 80 | [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]] 81 | >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True)) 82 | [[0, 1, 2], [3, 4, 5], [6, 7, 8]] 83 | """ 84 | 85 | def __init__(self, data_source: Sized, batch_size: int, generator, drop_last: bool = False, start_idx: Optional[int] = None) -> None: 86 | # Since collections.abc.Iterable does not check for `__getitem__`, which 87 | # is one way for an object to be an iterable, we don't do an `isinstance` 88 | # check here. 89 | if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \ 90 | batch_size <= 0: 91 | raise ValueError("batch_size should be a positive integer value, " 92 | "but got batch_size={}".format(batch_size)) 93 | if not isinstance(drop_last, bool): 94 | raise ValueError("drop_last should be a boolean value, but got " 95 | "drop_last={}".format(drop_last)) 96 | self.batch_size = batch_size 97 | self.drop_last = drop_last 98 | start_idx = start_idx if start_idx is not None else 0 99 | self.sampler = RandomSampler(data_source, generator=generator, start_idx = start_idx*batch_size) 100 | 101 | def __iter__(self) -> Iterator[List[int]]: 102 | # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951 103 | if self.drop_last: 104 | sampler_iter = iter(self.sampler) 105 | while True: 106 | try: 107 | batch = [next(sampler_iter) for _ in range(self.batch_size)] 108 | yield batch 109 | except StopIteration: 110 | break 111 | else: 112 | batch = [0] * self.batch_size 113 | idx_in_batch = 0 114 | for idx in self.sampler: 115 | batch[idx_in_batch] = idx 116 | idx_in_batch += 1 117 | if idx_in_batch == self.batch_size: 118 | yield batch 119 | idx_in_batch = 0 120 | batch = [0] * self.batch_size 121 | if idx_in_batch > 0: 122 | yield batch[:idx_in_batch] 123 | 124 | def __len__(self) -> int: 125 | # Can only be called if self.sampler has __len__ implemented 126 | # We cannot enforce this condition, so we turn off typechecking for the 127 | # implementation below. 128 | # Somewhat related: see NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ] 129 | if self.drop_last: 130 | return len(self.sampler) // self.batch_size # type: ignore[arg-type] 131 | else: 132 | return (len(self.sampler) + self.batch_size - 1) // self.batch_size # type: ignore[arg-type] -------------------------------------------------------------------------------- /config/vqvae.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from trajectory.utils import watch 8 | 9 | #------------------------ base ------------------------# 10 | 11 | logbase = '~/logs/' 12 | gpt_expname = 'vae/vq' 13 | 14 | ## automatically make experiment names for planning 15 | ## by labelling folders with these args 16 | args_to_watch = [ 17 | ('prefix', ''), 18 | ('plan_freq', 'freq'), 19 | ('horizon', 'H'), 20 | ('beam_width', 'beam'), 21 | ] 22 | 23 | base = { 24 | 'train': { 25 | 'type': "prior", 26 | 'logbase': logbase, 27 | 'model': "VQTransformer", 28 | 'tag': "experiment", 29 | 'state_conditional': True, 30 | #'encoder_inputs': ["state", "action", "reward", "return", "terminal"], 31 | 'encoder_inputs': ["state", "action", "mask"], 32 | 'ae_type': "AttentionCNN", 33 | 'N': 100, 34 | 'discount': 0.995, 35 | 'iql_critic': False, 36 | 'downstream_task': "prior", 37 | 'n_layer': 4, 38 | 'n_head': 4, 39 | 'prior_layer': 4, 40 | 'prior_head': 4, 41 | 'value_layer': 4, 42 | 'value_head': 4, 43 | 44 | 'prior_learning_rate': 3e-4, 45 | "num_workers": 2, 46 | 'latent_step': 4, 47 | 'code_per_step': 8, 48 | 'causal_attention': True, 49 | 'causal_conv': True, 50 | 'n_embd': 128, 51 | 'prior_embd': 128, 52 | 'value_embd': 128, 53 | 'trajectory_embd': 1024, 54 | 'K': 512, 55 | 'blocks_per_layer': 2, 56 | 'load_batch_size': 128, 57 | 'train_batch_size': 128, 58 | 'learning_rate': 2e-4, 59 | 'lr_decay': True, 60 | 'seed': 42, 61 | 'device': 'cuda', 62 | 'n_epochs_ref': 1200, 63 | 'n_saves': 3, 64 | 'tau': 0.8, 65 | 66 | 'embd_pdrop': 0.0, 67 | 'resid_pdrop': 0.1, 68 | 'attn_pdrop': 0.1, 69 | 70 | 'step': 1, 71 | 'subsampled_sequence_length': 17, 72 | 'termination_penalty': -100, 73 | 'exp_name': gpt_expname, 74 | 75 | 'position_weight': 1, 76 | 'action_weight': 5, 77 | 'reward_weight': 0.0, 78 | 'value_weight': 0.0, 79 | 'prior_value_weight': 1.0, 80 | 81 | 'suffix': '', 82 | 83 | "normalize": True, 84 | "normalize_reward": False, 85 | "max_path_length": 1000, 86 | "bottleneck": "pooling", 87 | "masking": "none", 88 | "disable_goal": False, 89 | "residual": "absolute", 90 | "ma_update": True, 91 | "use_discriminator": False, 92 | "disc_start": 0.1, 93 | 94 | "position_embedding": "absolute", 95 | "keep_hdf5s_open": True, 96 | "n_tokens_target": 1e6, 97 | # debug 98 | "debug": False, 99 | 100 | "enable_fp16": True, 101 | "enable_prior_fp16": True, 102 | "bootstrap": True, 103 | "bootstrap_ignore_terminal": False, 104 | "twohot_value": False, 105 | "symlog": False, 106 | "datasource": "auto", 107 | "ignore_incomplete_episodes": True, 108 | 109 | "validation_episodes": 200, 110 | 111 | "data_parallel": False, 112 | "pretrained_model": "", 113 | 114 | 'relabel_type': "speed", 115 | 'reward_clip': "none", 116 | 'body_height_penalty': 0.0, 117 | 'body_height_limit': 0.8, 118 | 'value_policy_gradient_ratio': 0.5, 119 | 'value_ema_rate': 0.998, 120 | "prior_gradient_norm_clip": 1.0, 121 | 'initial_value_weight': 0.2, 122 | "cql_weight": 0.1, 123 | 124 | "prior_name": "", # only used for ciritc training 125 | "critic_name": "", # only used for prior finetuning 126 | }, 127 | 128 | 'plan': { 129 | "task": "", 130 | "vae_name": "", 131 | "prior_name": "", 132 | "critic_name": "", 133 | 'discrete': False, 134 | 'logbase': logbase, 135 | 'gpt_loadpath': gpt_expname, 136 | 'gpt_epoch': 'latest', 137 | 'device': 'cuda', 138 | 'renderer': 'Renderer', 139 | 'suffix': '0', 140 | 141 | 'plan_freq': 1, 142 | 'horizon': 16, 143 | 'temperature': 1.0, 144 | 145 | "rounds": 2, 146 | "nb_samples": 4096, 147 | "k": 16, 148 | 149 | 'beam_width': 64, 150 | 'n_expand': 4, 151 | 152 | 'prob_threshold': 1e-10, 153 | 'prob_weight': 5e3, 154 | 155 | 'advantage_weight': 1.0, 156 | 'top_p': 0.99, 157 | "discount": 1.0, 158 | 159 | 'vis_freq': 200, 160 | 'exp_name': watch(args_to_watch), 161 | 'verbose': True, 162 | 'uniform': False, 163 | 164 | # Planner 165 | 'test_planner': 'beam_prior', 166 | # debug 167 | "debug": False, 168 | # clip name subset of prior training data, e.g. ["CMU_001", "CMU_001"] 169 | "prior_data_subset": [], 170 | "objective": 'reward', 171 | }, 172 | } 173 | 174 | #------------------------ locomotion ------------------------# 175 | 176 | hammer_cloned_v0 = hammer_human_v0 = human_expert_v0 = relocate_cloned_v0 = relocate_human_v0 = relocate_expert_v0 = door_cloned_v0 = door_human_v0 = door_expert_v0 = { 177 | 'train': { 178 | "termination_penalty": None, 179 | "max_path_length": 200, 180 | 'n_epochs_ref': 10, 181 | 'subsampled_sequence_length': 25, 182 | 'n_layer': 3, 183 | 'n_embd': 64, 184 | }, 185 | 'plan': { 186 | 'horizon': 24, 187 | } 188 | } 189 | 190 | pen_cloned_v0 = pen_expert_v0 = pen_human_v0 = { 191 | 'train': { 192 | "termination_penalty": None, 193 | "max_path_length": 100, 194 | 'n_epochs_ref': 10, 195 | 'subsampled_sequence_length': 25, 196 | 'n_layer': 3, 197 | 'n_embd': 16, 198 | 'K': 32, 199 | 'prior_layer': 2, 200 | 'prior_embd': 64, 201 | 'prior_head': 2, 202 | }, 203 | 'plan': { 204 | 'prob_weight': 5e2, 205 | 'horizon': 24, 206 | } 207 | } 208 | 209 | antmaze_ultra_diverse_v0=antmaze_ultra_play_v0=antmaze_large_diverse_v0=antmaze_large_play_v0=antmaze_medium_diverse_v0=antmaze_medium_play_v0=antmaze_umaze_v0={ 210 | 'train':{ 211 | "disable_goal": False, 212 | "termination_penalty": None, 213 | "max_path_length": 1001, 214 | "normalize": False, 215 | "normalize_reward": False, 216 | 'lr_decay': False, 217 | 'K': 8192, 218 | "discount": 0.998, 219 | 'value_weight': 0.0001, 220 | 'subsampled_sequence_length': 16, 221 | }, 222 | 'plan': { 223 | 'iql_value': False, 224 | 'horizon': 15, 225 | 'vis_freq': 200, 226 | 'renderer': "AntMazeRenderer", 227 | 'beam_width': 2, 228 | 'n_expand': 4, 229 | } 230 | } 231 | 232 | mocapact= { 233 | 'train': { 234 | 'termination_penalty': None, 235 | } 236 | } 237 | -------------------------------------------------------------------------------- /trajectory/models/cnn.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | from torch import nn 9 | from einops.layers.torch import Rearrange 10 | 11 | class CausalConv1d(nn.Module): 12 | def __init__(self, in_channels, out_channels, kernel_size, dilation, stride): 13 | super(CausalConv1d, self).__init__() 14 | self.in_channels = in_channels 15 | self.out_channels = out_channels 16 | self.kernel_size = kernel_size 17 | self.padding = dilation*(kernel_size-1) 18 | self.stride = stride 19 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, padding=self.padding, dilation=dilation, stride=stride) 20 | 21 | def forward(self, x): 22 | """ 23 | shape of x: [total_seq, num_features, num_timesteps] 24 | """ 25 | 26 | x = self.conv(x) 27 | last_n = (2*self.padding-self.kernel_size)//self.stride + 1 28 | if last_n> 0: 29 | return x[:, :, :-last_n] 30 | else: 31 | return x 32 | 33 | 34 | class Conv1dBlock(nn.Module): 35 | ''' 36 | Conv1d --> GroupNorm --> Mish 37 | from https://github.com/jannerm/diffuser/blob/06b8e6a042e6a3312d50ed8048cba14afeab3085/diffuser/models/helpers.py#L46 38 | ''' 39 | def __init__(self, inp_channels, out_channels, kernel_size, stride, n_groups=8, causal=True): 40 | super().__init__() 41 | if causal: 42 | conv = CausalConv1d(inp_channels, out_channels, kernel_size, dilation=1, stride=stride) 43 | else: 44 | conv = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size//2, stride=stride) 45 | 46 | self.block = nn.Sequential( 47 | conv, 48 | Rearrange('batch channels horizon -> batch channels 1 horizon'), 49 | nn.GroupNorm(n_groups, out_channels), 50 | Rearrange('batch channels 1 horizon -> batch channels horizon'), 51 | nn.Mish(), 52 | ) 53 | 54 | def forward(self, x): 55 | return self.block(x) 56 | 57 | class CausalDeConv1d(nn.Module): 58 | def __init__(self, in_channels, out_channels, kernel_size, dilation, stride): 59 | super(CausalDeConv1d, self).__init__() 60 | self.in_channels = in_channels 61 | self.out_channels = out_channels 62 | self.kernel_size = kernel_size 63 | self.stride = stride 64 | # TODO: need to be double checked 65 | self.conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride) 66 | 67 | def forward(self, x): 68 | """ 69 | shape of x: [total_seq, num_features, num_timesteps] 70 | """ 71 | x = self.conv(x) 72 | last_n = self.kernel_size-self.stride 73 | if last_n> 0: 74 | return x[:, :, :-last_n] 75 | else: 76 | return x 77 | 78 | 79 | class DeConv1dBlock(nn.Module): 80 | ''' 81 | Conv1d --> GroupNorm --> Mish 82 | from https://github.com/jannerm/diffuser/blob/06b8e6a042e6a3312d50ed8048cba14afeab3085/diffuser/models/helpers.py#L46 83 | ''' 84 | 85 | def __init__(self, inp_channels, out_channels, kernel_size, stride, n_groups=8, causal=True): 86 | super().__init__() 87 | if causal: 88 | conv = CausalDeConv1d(inp_channels, out_channels, kernel_size, dilation=1, stride=stride), 89 | else: 90 | conv = nn.ConvTranspose1d(inp_channels, out_channels, kernel_size, padding=kernel_size//2, stride=stride) 91 | 92 | self.block = nn.Sequential( 93 | conv, 94 | Rearrange('batch channels horizon -> batch channels 1 horizon'), 95 | nn.GroupNorm(n_groups, out_channels), 96 | Rearrange('batch channels 1 horizon -> batch channels horizon'), 97 | nn.Mish(), 98 | ) 99 | 100 | def forward(self, x): 101 | return self.block(x) 102 | 103 | 104 | 105 | class ResidualTemporalBlock(nn.Module): 106 | def __init__(self, inp_channels, out_channels, kernel_size=3, pooling=1, stride=1, causal=True): 107 | super().__init__() 108 | 109 | second_kernel_size = kernel_size if stride==1 else stride 110 | 111 | self.blocks = nn.ModuleList([ 112 | Conv1dBlock(inp_channels, out_channels, kernel_size, 1, causal=causal), 113 | Conv1dBlock(out_channels, out_channels, second_kernel_size, stride=stride, causal=causal), 114 | ]) 115 | 116 | if out_channels == inp_channels and stride == 1: 117 | self.residual_conv = nn.Identity() 118 | else: 119 | self.residual_conv = nn.Conv1d(inp_channels, out_channels, kernel_size=1, stride=stride) 120 | 121 | if pooling==1: 122 | self.pooling = nn.Identity() 123 | else: 124 | self.pooling = nn.MaxPool1d(pooling, stride=pooling) 125 | 126 | def forward(self, input_dict): 127 | ''' 128 | x : [ batch_size x horizon x inp_channels ] 129 | returns: 130 | out : [ batch_size x horizon x out_channels ] 131 | ''' 132 | if isinstance(input_dict, dict): 133 | x = input_dict['x'] 134 | input_mask = input_dict['input_mask'] 135 | else: 136 | x = input_dict 137 | input_mask = None 138 | if input_mask is not None: 139 | x = x * input_mask.view(x.shape[0], x.shape[1], 1) 140 | x = torch.transpose(x, 1, 2) 141 | out = self.blocks[0](x) 142 | out = self.blocks[1](out) 143 | out = out + self.residual_conv(x) 144 | out = self.pooling(out) 145 | return torch.transpose(out, 1, 2) 146 | 147 | 148 | class ResidualTemporalDeConvBlock(nn.Module): 149 | def __init__(self, inp_channels, out_channels, kernel_size=3, pooling=1, stride=1, causal=True): 150 | super().__init__() 151 | second_kernel_size = kernel_size if stride==1 else stride 152 | 153 | self.blocks = nn.ModuleList([ 154 | DeConv1dBlock(inp_channels, out_channels, kernel_size, 1, causal=causal), 155 | DeConv1dBlock(out_channels, out_channels, second_kernel_size, stride=stride, causal=causal), 156 | ]) 157 | 158 | if out_channels == inp_channels and stride==1: 159 | self.residual_conv = nn.Identity() 160 | else: 161 | self.residual_conv = nn.ConvTranspose1d(inp_channels, out_channels, kernel_size=stride, stride=stride) 162 | if pooling==1: 163 | self.pooling = nn.Identity() 164 | else: 165 | self.pooling = nn.MaxPool1d(pooling, stride=pooling) 166 | 167 | def forward(self, input_dict): 168 | ''' 169 | x : [ batch_size x inp_channels x horizon ] 170 | returns: 171 | out : [ batch_size x out_channels x horizon ] 172 | ''' 173 | if isinstance(input_dict, dict): 174 | x = input_dict['x'] 175 | input_mask = input_dict['input_mask'] 176 | else: 177 | x = input_dict 178 | input_mask = None 179 | 180 | if input_mask is not None: 181 | x = x * input_mask.view(x.shape[0], x.shape[1], 1) 182 | x = torch.transpose(x, 1, 2) 183 | out = self.blocks[0](x) 184 | out = self.blocks[1](out) 185 | out = out + self.residual_conv(x) 186 | out = self.pooling(out) 187 | return torch.transpose(out, 1, 2) 188 | -------------------------------------------------------------------------------- /trajectory/datasets/d4rl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import numpy as np 9 | import gym 10 | import random 11 | from tasks.dmcontrol import CMUHumanoidGymWrapper, DMCWrapper 12 | from tasks.tracking import MocapTrackingGymEnv 13 | from tasks.motion_completion import MotionCompletionGymEnv 14 | 15 | from contextlib import ( 16 | contextmanager, 17 | redirect_stderr, 18 | redirect_stdout, 19 | ) 20 | 21 | @contextmanager 22 | def suppress_output(): 23 | """ 24 | A context manager that redirects stdout and stderr to devnull 25 | https://stackoverflow.com/a/52442331 26 | """ 27 | with open(os.devnull, 'w') as fnull: 28 | with redirect_stderr(fnull) as err, redirect_stdout(fnull) as out: 29 | yield (err, out) 30 | 31 | with suppress_output(): 32 | ## d4rl prints out a variety of warnings 33 | import d4rl 34 | 35 | def qlearning_dataset_with_timeouts(env, dataset=None, terminate_on_end=False, disable_goal=False, **kwargs): 36 | if dataset is None: 37 | dataset = env.get_dataset(**kwargs) 38 | 39 | N = dataset['rewards'].shape[0] 40 | obs_ = [] 41 | next_obs_ = [] 42 | action_ = [] 43 | reward_ = [] 44 | done_ = [] 45 | realdone_ = [] 46 | if "infos/goal" in dataset: 47 | if not disable_goal: 48 | dataset["observations"] = np.concatenate([dataset["observations"], dataset['infos/goal']], axis=1) 49 | else: 50 | dataset["observations"] = np.concatenate([dataset["observations"], np.zeros([dataset["observations"].shape[0], 2], dtype=np.float32)], 51 | axis=1) 52 | dataset["rewards"] = dataset["rewards"] - 1 53 | 54 | episode_step = 0 55 | for i in range(N-1): 56 | obs = dataset['observations'][i] 57 | new_obs = dataset['observations'][i+1] 58 | action = dataset['actions'][i] 59 | reward = dataset['rewards'][i] 60 | done_bool = bool(dataset['terminals'][i]) 61 | realdone_bool = bool(dataset['terminals'][i]) 62 | if "infos/goal" in dataset: 63 | final_timestep = True if (dataset['infos/goal'][i] != dataset['infos/goal'][i+1]).any() else False 64 | else: 65 | final_timestep = dataset['timeouts'][i] 66 | 67 | if i < N - 1: 68 | done_bool += final_timestep 69 | 70 | if (not terminate_on_end) and final_timestep: 71 | # Skip this transition and don't apply terminals on the last step of an episode 72 | episode_step = 0 73 | continue 74 | if done_bool or final_timestep: 75 | episode_step = 0 76 | 77 | obs_.append(obs) 78 | next_obs_.append(new_obs) 79 | action_.append(action) 80 | reward_.append(reward) 81 | done_.append(done_bool) 82 | realdone_.append(realdone_bool) 83 | episode_step += 1 84 | 85 | return { 86 | 'observations': np.array(obs_), 87 | 'actions': np.array(action_), 88 | 'next_observations': np.array(next_obs_), 89 | 'rewards': np.array(reward_)[:,None], 90 | 'terminals': np.array(done_)[:,None], 91 | 'realterminals': np.array(realdone_)[:,None], 92 | } 93 | 94 | 95 | def softmax(x): 96 | """Compute softmax values for each sets of scores in x.""" 97 | e_x = np.exp(x - np.max(x)) 98 | return e_x / e_x.sum(axis=0) 99 | 100 | 101 | def load_environment(name): 102 | from dm_control.locomotion.tasks import go_to_target, corridors 103 | from dm_control.locomotion.tasks.reference_pose import tracking 104 | if name == 'goto': 105 | CONTROL_TIMESTEP = 0.03 106 | #constructor = corridors.RunThroughCorridor 107 | constructor = go_to_target.GoToTarget 108 | env_ctor = CMUHumanoidGymWrapper.make_env_constructor(constructor) 109 | task_kwargs = dict( 110 | physics_timestep=tracking.DEFAULT_PHYSICS_TIMESTEP, 111 | control_timestep=CONTROL_TIMESTEP, 112 | moving_target=True, 113 | ) 114 | environment_kwargs = dict( 115 | time_limit=CONTROL_TIMESTEP * 1e7, 116 | random_state=np.random.randint(1e6) 117 | ) 118 | arena_size = (8., 8.) 119 | env = env_ctor( 120 | task_kwargs=task_kwargs, 121 | environment_kwargs=environment_kwargs, 122 | arena_size=arena_size 123 | ) 124 | elif name == 'corridor': 125 | CONTROL_TIMESTEP = 0.03 126 | constructor = corridors.RunThroughCorridor 127 | env_ctor = CMUHumanoidGymWrapper.make_env_constructor(constructor) 128 | task_kwargs = dict( 129 | physics_timestep=tracking.DEFAULT_PHYSICS_TIMESTEP, 130 | control_timestep=CONTROL_TIMESTEP, 131 | contact_termination=False, 132 | ) 133 | environment_kwargs = dict( 134 | time_limit=CONTROL_TIMESTEP * 400, 135 | random_state=np.random.randint(1e6), 136 | ) 137 | arena_size = (8., 8.) 138 | env = env_ctor( 139 | task_kwargs=task_kwargs, 140 | environment_kwargs=environment_kwargs, 141 | arena_size=arena_size 142 | ) 143 | elif name in ['speed', 'forward', 'rotate_x', 'rotate_y', 'rotate_z', 'x_vel', 'y_vel', 'z_vel', 'negative_z_vel', 'jump', 144 | 'backward', 'shift_left']: 145 | from tasks.cmu_relabel import CMURelabelTask 146 | CONTROL_TIMESTEP = 0.03 147 | constructor = CMURelabelTask 148 | env_ctor = CMUHumanoidGymWrapper.make_env_constructor(constructor) 149 | relabel_type = name 150 | contact_termination = False if relabel_type in ['rotate_x', 'rotate_z'] else True 151 | task_kwargs = dict( 152 | physics_timestep=tracking.DEFAULT_PHYSICS_TIMESTEP, 153 | control_timestep=CONTROL_TIMESTEP, 154 | relabel_type=relabel_type, 155 | contact_termination=contact_termination, 156 | ) 157 | 158 | environment_kwargs = dict( 159 | time_limit=CONTROL_TIMESTEP * 400, 160 | random_state=np.random.randint(1e6), 161 | ) 162 | arena_size = (8., 8.) 163 | env = env_ctor( 164 | task_kwargs=task_kwargs, 165 | environment_kwargs=environment_kwargs, 166 | arena_size=arena_size 167 | ) 168 | elif "tracking" in name: 169 | if ":" in name: 170 | name, dataset = name.split(":") 171 | env = MocapTrackingGymEnv(dataset=dataset) 172 | elif "motioncompletion" in name: 173 | # set up environment 174 | ghost_offset = 0. 175 | prompt_length = 0 176 | task_kwargs = dict( 177 | ghost_offset=np.array([ghost_offset, 0., 0.]), 178 | always_init_at_clip_start=False, 179 | termination_error_threshold=0.3, 180 | min_steps=10, 181 | max_steps=None, 182 | steps_before_color_change=prompt_length 183 | ) 184 | if ":" in name: 185 | name, dataset = name.split(":") 186 | env = MotionCompletionGymEnv(dataset=dataset, task_kwargs=task_kwargs) 187 | elif "." in name: 188 | domain_name, task_name = name.split(".") 189 | env = DMCWrapper(domain_name, task_name) 190 | else: 191 | with suppress_output(): 192 | wrapped_env = gym.make(name) 193 | 194 | env = wrapped_env.unwrapped 195 | env.max_episode_steps = wrapped_env._max_episode_steps 196 | 197 | env.name = name 198 | return env -------------------------------------------------------------------------------- /scripts/humanoid_plan.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import json 8 | import pdb 9 | from os.path import join 10 | from trajectory.utils.rendering import HumanoidRnederer 11 | import torch 12 | import os 13 | import numpy as np 14 | import imageio 15 | 16 | from accelerate.logging import get_logger 17 | 18 | os.environ["CUDA_HOME"] = os.environ["CONDA_PREFIX"] 19 | torch.backends.cuda.matmul.allow_tf32 = True 20 | 21 | import trajectory.utils as utils 22 | import trajectory.datasets as datasets 23 | from trajectory.search import ( 24 | sample_with_prior, 25 | beam_with_prior, 26 | make_prefix, 27 | extract_actions, 28 | ) 29 | 30 | class Parser(utils.Parser): 31 | dataset: str = 'mocapact' 32 | config: str = 'config.vqvae' 33 | 34 | 35 | def main(): 36 | ####################### 37 | ######## setup ######## 38 | ####################### 39 | 40 | args = Parser().parse_args('plan') 41 | 42 | if "vae_name" in args.__dict__ and args.vae_name != "": 43 | vae_name = args.vae_name 44 | else: 45 | vae_name = args.exp_name 46 | print(args.vae_name) 47 | 48 | logger = get_logger(__name__, log_level="DEBUG") 49 | 50 | ####################### 51 | ####### models ######## 52 | ####################### 53 | 54 | if "." in args.dataset: 55 | args.task = args.dataset 56 | if "task" not in args.__dict__ or args.task == "": 57 | args.task = Parser().parse_args('train').relabel_type 58 | 59 | env = datasets.load_environment(args.task) 60 | 61 | vae, gpt_epoch = utils.load_model(logger, args.logbase, args.dataset, vae_name, 62 | epoch=args.gpt_epoch, device=args.device) 63 | 64 | prior, _ = utils.load_transformer_model(logger, args.logbase, args.dataset, args.prior_name, 65 | epoch=args.gpt_epoch, device=args.device) 66 | 67 | if args.critic_name != "": 68 | critic, _ = utils.load_transformer_model(logger, args.logbase, args.dataset, args.critic_name, 69 | epoch=args.gpt_epoch, device=args.device, type='critic') 70 | 71 | vae.set_padding_vector(np.zeros(vae.transition_dim - 1)) 72 | ####################### 73 | ####### dataset ####### 74 | ####################### 75 | renderer = HumanoidRnederer(datasets.load_environment(args.task), observation_dim=vae.observation_dim) 76 | 77 | 78 | if args.critic_name != "": 79 | dataset = utils.load_from_config(logger, args.logbase, args.dataset, args.critic_name, 80 | 'data_config.pkl') 81 | else: 82 | dataset = utils.load_from_config(logger, args.logbase, args.dataset, args.prior_name, 83 | 'data_config.pkl') 84 | 85 | timer = utils.timer.Timer() 86 | 87 | discount = dataset.discount 88 | observation_dim = dataset.observation_dim 89 | action_dim = dataset.action_dim 90 | 91 | preprocess_fn = datasets.get_preprocess_fn(env.name) 92 | 93 | ####################### 94 | ###### main loop ###### 95 | ####################### 96 | REWARD_DIM = VALUE_DIM = 1 97 | transition_dim = observation_dim + action_dim + REWARD_DIM + VALUE_DIM 98 | time_step = env.dm_env.reset() 99 | observation = env.get_observation(time_step) 100 | total_reward = 0 101 | discount_return = 0 102 | frames = [] 103 | 104 | ## previous (tokenized) transitions for conditioning transformer 105 | context = [] 106 | values = [] 107 | 108 | T = 400 109 | vae.eval() 110 | for t in range(T): 111 | observation = preprocess_fn(observation) 112 | 113 | if dataset.normalize_obs: 114 | observation = dataset.normalize_observations(observation) 115 | 116 | if t % args.plan_freq == 0: 117 | ## concatenate previous transitions and current observations to input to model 118 | prefix = make_prefix(context, observation, transition_dim, device=args.device)[-1, -1, None, None] 119 | 120 | ## sample sequence from model beginning with `prefix` 121 | if args.test_planner == 'beam_with_prior': 122 | prior.eval() 123 | sequence = beam_with_prior(prior, vae, prefix, 124 | denormalize_rew=dataset.denormalize_reward, 125 | denormalize_val=dataset.denormalize_return, 126 | discount=args.discount, 127 | steps=args.horizon, 128 | optimize_target=args.objective, 129 | beam_width=args.beam_width, 130 | n_expand=args.n_expand) 131 | elif args.test_planner == 'beam_prior_perturb': 132 | prior.eval() 133 | sequence, value = beam_with_prior_perturb(prior, critic, vae, prefix, 134 | args.horizon, 135 | beam_width=args.beam_width, 136 | n_expand=args.n_expand, 137 | prob_threshold=args.prob_threshold, 138 | ood_weight=args.prob_weight, 139 | temperature=args.temperature, 140 | normalize_value=critic.normalize_returns, 141 | advantage_weight=args.advantage_weight, 142 | ) 143 | elif args.test_planner == "sample_with_prior": 144 | prior.eval() 145 | sequence = sample_with_prior(prior, vae, prefix, None, None, args.discount, args.horizon, nb_samples=args.nb_samples, rounds=1, 146 | likelihood_weight=args.prob_weight, prob_threshold=args.prob_threshold, top_p=args.top_p, 147 | temperature=args.temperature, objective=args.objective) 148 | else: 149 | raise NotImplementedError(f"Unknown planner type {args.test_planner}.") 150 | else: 151 | sequence = sequence[1:] 152 | 153 | ## [ horizon x transition_dim ] convert sampled tokens to continuous trajectory 154 | sequence_recon = sequence 155 | 156 | ## [ action_dim ] index into sampled trajectory to grab first action 157 | feature_dim = dataset.observation_dim 158 | action = extract_actions(sequence_recon, feature_dim, action_dim, t=0) 159 | if dataset.normalize_act: 160 | action = dataset.denormalize_actions(action) 161 | 162 | ## execute action in environment 163 | time_step = env.dm_env.step(action) 164 | next_observation = env.get_observation(time_step) 165 | reward = time_step.reward 166 | terminal = time_step.last() 167 | 168 | 169 | ## update return 170 | total_reward += reward 171 | discount_return += reward* discount**(t) 172 | 173 | img = env.render() 174 | frames.append(img) 175 | 176 | print( 177 | f'[ plan ] t: {t} / {T} | r: {reward:.2f} | R: {total_reward:.2f} | ' 178 | f'time: {timer():.4f} | {args.dataset} | {args.exp_name} | {args.suffix}\n' 179 | ) 180 | 181 | ## visualization 182 | if t % args.vis_freq == 0 or terminal or t == T-1: 183 | if not os.path.exists(args.savepath): 184 | os.makedirs(args.savepath) 185 | if terminal: break 186 | observation = next_observation 187 | 188 | imageio.mimsave(join(args.savepath, 'rollout.gif'), frames) 189 | ## save result as a json file 190 | json_path = join(args.savepath, 'rollout.json') 191 | json_data = {'step': t, 'return': float(total_reward), 'term': terminal, 'gpt_epoch': gpt_epoch, 'value_mean': float(np.mean(values)), 192 | 'first_value': np.nan, 'first_search_value': np.nan, 'discount_return': float(discount_return), 193 | 'prediction_error': np.nan} 194 | json.dump(json_data, open(json_path, 'w'), indent=2, sort_keys=True) 195 | 196 | if __name__ == '__main__': 197 | main() 198 | -------------------------------------------------------------------------------- /trajectory/utils/relabel_humanoid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | import torch 10 | 11 | OBSERVABLES_DIM = [('actuator_activation', 56), 12 | ('appendages_pos', 15), 13 | ('body_height', 1), 14 | ('end_effectors_pos', 12), 15 | ('joints_pos', 56), 16 | ('joints_vel', 56), 17 | ('sensors_accelerometer', 3), 18 | ('sensors_gyro', 3), 19 | ('sensors_torque', 6), 20 | ('sensors_touch', 10), 21 | ('sensors_velocimeter', 3), 22 | ('world_zaxis', 3)] 23 | 24 | def get_observable_range(observable_name): 25 | """ 26 | get the start and end index of the observable in the proprioceptive_obs 27 | """ 28 | start = 0 29 | for name, dim in OBSERVABLES_DIM: 30 | if name == observable_name: 31 | return start, start + dim 32 | start += dim 33 | raise ValueError(f"observable_name {observable_name} not found") 34 | 35 | def slice_observable(proprioceptive_obs, observable_name): 36 | """ 37 | get the observable from proprioceptive_obs 38 | """ 39 | start, end = get_observable_range(observable_name) 40 | return proprioceptive_obs[..., start:end] 41 | 42 | def get_angular_vel(proprioceptive_obs, direction): 43 | """ 44 | get the angular speed of the humanoid in from proprioceptive_obs. 45 | """ 46 | angular_vel = slice_observable(proprioceptive_obs, 'sensors_gyro') 47 | if direction == 'x': 48 | angular_vel = angular_vel[..., 0] 49 | elif direction == 'y': 50 | angular_vel = angular_vel[..., 1] 51 | elif direction == 'z': 52 | angular_vel = angular_vel[..., 2] 53 | 54 | return angular_vel 55 | 56 | def get_left_vel(proprioceptive_obs): 57 | """ 58 | get the left speed of the humanoid in from proprioceptive_obs. 59 | cancel the sub-velocity towards world z axis according to the ego-centric world z axis. 60 | """ 61 | ego_centric_vel = slice_observable(proprioceptive_obs, 'sensors_velocimeter') 62 | world_zaxis = slice_observable(proprioceptive_obs, 'world_zaxis') 63 | left_vel = project_left(ego_centric_vel, world_zaxis) 64 | return left_vel 65 | 66 | def project_left(ego_centric_vel, world_zaxis): 67 | """ 68 | project the velocity to the left direction 69 | """ 70 | world_z_projected_to_ego_xy = world_zaxis[..., :-1] 71 | if isinstance(ego_centric_vel, tf.Tensor): 72 | world_z_normalized_to_ego_xy = tf.math.l2_normalize(world_z_projected_to_ego_xy, axis=-1) 73 | elif isinstance(ego_centric_vel, np.ndarray): 74 | world_z_normalized_to_ego_xy = world_z_projected_to_ego_xy / np.linalg.norm(world_z_projected_to_ego_xy, axis=-1, keepdims=True) 75 | elif isinstance(ego_centric_vel, torch.Tensor): 76 | world_z_normalized_to_ego_xy = torch.nn.functional.normalize(world_z_projected_to_ego_xy, dim=-1) 77 | else: 78 | raise ValueError(f"ego_centric_vel type {type(ego_centric_vel)} not supported") 79 | left_vel = world_z_normalized_to_ego_xy[..., 1] * ego_centric_vel[..., 0] 80 | return left_vel 81 | 82 | def get_forward_vel(proprioceptive_obs): 83 | """ 84 | get the forward speed of the humanoid in from proprioceptive_obs. 85 | cancel the sub-velocity towards world z axis according to the ego-centric world z axis. 86 | """ 87 | ego_centric_vel = slice_observable(proprioceptive_obs, 'sensors_velocimeter') 88 | world_zaxis = slice_observable(proprioceptive_obs, 'world_zaxis') 89 | forward_vel = project_forward(ego_centric_vel, world_zaxis) 90 | return forward_vel 91 | 92 | def project_forward(ego_centric_vel, world_zaxis): 93 | """ 94 | project the velocity to the forward direction 95 | """ 96 | world_z_projected_to_ego_yz = world_zaxis[..., 1:] 97 | if isinstance(ego_centric_vel, tf.Tensor): 98 | world_z_normalized_to_ego_yz = tf.math.l2_normalize(world_z_projected_to_ego_yz, axis=-1) 99 | elif isinstance(ego_centric_vel, np.ndarray): 100 | world_z_normalized_to_ego_yz = world_z_projected_to_ego_yz / np.linalg.norm(world_z_projected_to_ego_yz, axis=-1, keepdims=True) 101 | elif isinstance(ego_centric_vel, torch.Tensor): 102 | world_z_normalized_to_ego_yz = torch.nn.functional.normalize(world_z_projected_to_ego_yz, dim=-1) 103 | else: 104 | raise ValueError(f"ego_centric_vel type {type(ego_centric_vel)} not supported") 105 | forward_vel = world_z_normalized_to_ego_yz[..., 0] * ego_centric_vel[..., 2] 106 | return forward_vel 107 | 108 | def get_height_vel(proprioceptive_obs): 109 | """ 110 | get the height speed of the humanoid in from proprioceptive_obs. 111 | """ 112 | ego_centric_vel = slice_observable(proprioceptive_obs, 'sensors_velocimeter') 113 | world_zaxis = slice_observable(proprioceptive_obs, 'world_zaxis') 114 | height_vel = project_height(ego_centric_vel, world_zaxis) 115 | return height_vel 116 | 117 | def project_height(ego_centric_vel, world_zaxis): 118 | """ 119 | project the velocity to the height direction 120 | """ 121 | if isinstance(ego_centric_vel, tf.Tensor): 122 | height_vel = tf.reduce_sum(ego_centric_vel * world_zaxis, axis=-1) 123 | elif isinstance(ego_centric_vel, np.ndarray): 124 | height_vel = np.sum(ego_centric_vel * world_zaxis, axis=-1) 125 | elif isinstance(ego_centric_vel, torch.Tensor): 126 | height_vel = torch.sum(ego_centric_vel * world_zaxis, dim=-1) 127 | else: 128 | raise ValueError("ego_centric_vel should be either tf.Tensor or np.ndarray") 129 | return height_vel 130 | 131 | 132 | def get_vel(proprioceptive_obs, direction): 133 | current_vel = slice_observable(proprioceptive_obs, 'sensors_velocimeter') 134 | if direction == 'x': 135 | vel = current_vel[..., 0] 136 | elif direction == 'y': 137 | vel = current_vel[..., 1] 138 | elif direction == 'z': 139 | vel = current_vel[..., 2] 140 | else: 141 | raise ValueError(f"direction {direction} not found") 142 | return vel 143 | 144 | 145 | def get_body_height(proprioceptive_obs): 146 | """ 147 | get the height of the humanoid in from proprioceptive_obs. 148 | """ 149 | return slice_observable(proprioceptive_obs, 'body_height')[..., 0] 150 | 151 | def get_speed(proprioceptive_obs): 152 | """ 153 | get the speed of the humanoid in from proprioceptive_obs. 154 | Note that the speed is a non-negative scalar and do not indicate direction (different from velocity) 155 | """ 156 | current_vel = slice_observable(proprioceptive_obs, 'sensors_velocimeter') 157 | if isinstance(current_vel, tf.Tensor): 158 | speed = tf.norm(current_vel, axis=-1) 159 | elif isinstance(current_vel, np.ndarray): 160 | speed = np.linalg.norm(current_vel, axis=-1) 161 | elif isinstance(current_vel, torch.Tensor): 162 | speed = torch.norm(current_vel, dim=-1) 163 | return speed 164 | 165 | 166 | def get_x_speed(proprioceptive_obs): 167 | """ 168 | get the speed of the humanoid in x direction from proprioceptive_obs. 169 | Note that the speed is a non-negative scalar and do not indicate direction (different from velocity) 170 | """ 171 | current_vel = proprioceptive_obs[..., -6, None] 172 | if isinstance(current_vel, tf.Tensor): 173 | speed = tf.norm(current_vel, axis=-1) 174 | elif isinstance(current_vel, np.ndarray): 175 | speed = np.linalg.norm(current_vel, axis=-1) 176 | return speed 177 | 178 | def get_x_negative(proprioceptive_obs, upper_bound=2.0): 179 | current_vel = proprioceptive_obs[..., -6] 180 | if isinstance(current_vel, tf.Tensor): 181 | reward = -tf.clip_by_value(current_vel, -upper_bound, upper_bound) 182 | elif isinstance(current_vel, np.ndarray): 183 | reward = -np.clip(current_vel, -upper_bound, upper_bound) 184 | return reward 185 | 186 | def get_target_similarity(obs, target): 187 | """ 188 | get the similarity between the current obs and the target with rbf kernel 189 | """ 190 | current_vel = slice_observable(obs, 'sensors_velocimeter') 191 | if isinstance(obs, tf.Tensor): 192 | return tf.exp(-tf.norm(current_vel - tf.convert_to_tensor(target), axis=-1)) 193 | elif isinstance(obs, np.ndarray): 194 | return np.exp(-np.linalg.norm(current_vel - target, axis=-1)) 195 | else: 196 | raise ValueError("obs should be either tf.Tensor or np.ndarray") -------------------------------------------------------------------------------- /trajectory/models/transformer_prior.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import torch 8 | import torch.nn as nn 9 | from trajectory.models.transformers import * 10 | from trajectory.models.ein import EinLinear 11 | 12 | class TransformerPrior(nn.Module): 13 | """ the full GPT language model, with a context size of block_size """ 14 | 15 | def __init__(self, config): 16 | super().__init__() 17 | if "data_parallel" in config: 18 | self.data_parallel = config.data_parallel 19 | else: 20 | self.data_parallel = False 21 | 22 | self.transition_dim = config.transition_dim 23 | self.observation_dim = config.observation_dim 24 | self.action_dim = config.action_dim 25 | # inputs to embed in additon to states 26 | self.n_embd = config.n_embd 27 | self.max_latents_length = config.max_sequence_length // config.latent_step * config.code_per_step 28 | 29 | # transformer for the policy 30 | self.tok_emb = nn.Embedding(config.K, config.n_embd) 31 | self.pos_emb = nn.Parameter(torch.zeros(1, self.max_latents_length, config.n_embd)) 32 | self.state_emb = nn.Linear(config.observation_dim, config.n_embd) 33 | self.drop = nn.Dropout(config.embd_pdrop) 34 | self.blocks = nn.ModuleList([Block(config.n_embd, config.n_head, config.resid_pdrop, 35 | config.attn_pdrop, sequence_length=self.max_latents_length, 36 | causal=True, rotary=(config.position_embedding == "rotary"), 37 | kv_cache=True) for _ in 38 | range(config.n_layer)]) 39 | 40 | # decoder head 41 | self.ln_f = nn.LayerNorm(config.n_embd) 42 | self.latent_step = config.latent_step 43 | self.code_per_step = config.code_per_step 44 | 45 | # output tensor is a concatenation of logits and q_value 46 | self.policy_head = nn.Linear(config.n_embd, config.K, bias=False) 47 | 48 | self.vocab_size = config.K 49 | self.embedding_dim = config.n_embd 50 | self.apply(self._init_weights) 51 | 52 | def get_block_size(self): 53 | return self.block_size 54 | 55 | def _init_weights(self, module): 56 | if isinstance(module, (nn.Linear, nn.Embedding)): 57 | module.weight.data.normal_(mean=0.0, std=0.02) 58 | if isinstance(module, nn.Linear) and module.bias is not None: 59 | module.bias.data.zero_() 60 | elif isinstance(module, nn.LayerNorm): 61 | module.bias.data.zero_() 62 | module.weight.data.fill_(1.0) 63 | 64 | def configure_optimizers(self, train_config): 65 | """ 66 | This long function is unfortunately doing something very simple and is being very defensive: 67 | We are separating out all parameters of the model into two buckets: those that will experience 68 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 69 | We are then returning the PyTorch optimizer object. 70 | """ 71 | 72 | # separate out all parameters to those that will and won't experience regularizing weight decay 73 | decay = set() 74 | no_decay = set() 75 | whitelist_weight_modules = (torch.nn.Linear, EinLinear) 76 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 77 | for mn, m in self.named_modules(): 78 | for pn, p in m.named_parameters(): 79 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 80 | 81 | if pn.endswith('bias'): 82 | # all biases will not be decayed 83 | no_decay.add(fpn) 84 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 85 | # weights of whitelist modules will be weight decayed 86 | decay.add(fpn) 87 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 88 | # weights of blacklist modules will NOT be weight decayed 89 | no_decay.add(fpn) 90 | elif pn.endswith('norm.weight'): 91 | no_decay.add(fpn) 92 | 93 | # special case the position embedding parameter in the root GPT module as not decayed 94 | no_decay.add('pos_emb') 95 | 96 | # validate that we considered every parameter 97 | param_dict = {pn: p for pn, p in self.named_parameters()} 98 | inter_params = decay & no_decay 99 | union_params = decay | no_decay 100 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) 101 | assert len( 102 | param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ 103 | % (str(param_dict.keys() - union_params),) 104 | 105 | # create the pytorch optimizer object 106 | optim_groups = [ 107 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, 108 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 109 | ] 110 | optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) 111 | return optimizer 112 | 113 | def forward(self, idx, state, targets=None, progress=0, mask=None, kv_cache=None): 114 | """ 115 | idx : [ B x T ] 116 | state: [ B ] 117 | """ 118 | 119 | state = state.to(device=self.pos_emb.device, dtype=torch.float32) 120 | ## [ B x T x embedding_dim ] 121 | 122 | if kv_cache is None: 123 | if not idx is None: 124 | b, t = idx.size() 125 | assert t <= self.max_latents_length, "Cannot forward, model block size is exhausted." 126 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector 127 | token_embeddings = torch.cat( 128 | [torch.zeros(size=(b, 1, self.embedding_dim)).to(token_embeddings), token_embeddings], 129 | dim=1) 130 | else: 131 | b = state.size(0) 132 | t = 0 133 | token_embeddings = torch.zeros(size=(b, 1, self.embedding_dim)).to(state) 134 | else: 135 | b = idx.size(0) 136 | t = kv_cache[0].shape[1] 137 | token_embeddings = self.tok_emb(idx).to(state) 138 | 139 | ## [ 1 x T+1 x embedding_dim ] 140 | if kv_cache is None: 141 | position_embeddings = self.pos_emb[:, :t + 1, :] # each position maps to a (learnable) vector 142 | else: 143 | position_embeddings = self.pos_emb[:, kv_cache[0].shape[1], :] 144 | ## [ B x 1 x embedding_dim] 145 | state_embeddings = self.state_emb(state)[:, None] 146 | ## [ B x T+1 x embedding_dim ] 147 | x = self.drop(token_embeddings + position_embeddings + state_embeddings) 148 | 149 | new_cache = [] 150 | if kv_cache is None: 151 | kv_cache = [None] * len(self.blocks) 152 | 153 | for block, block_cache in zip(self.blocks, kv_cache): 154 | x, new_block_cache = block(x, kv_cache=block_cache) 155 | if not self.training: 156 | new_cache.append(new_block_cache) 157 | 158 | ## [ B x T+1 x embedding_dim ] 159 | x = self.ln_f(x) 160 | 161 | # if we are given some desired targets also calculate the loss 162 | if targets is not None: 163 | policy_logits = self.policy_head(x).reshape(b, t + 1, -1) 164 | # update return stats 165 | target_latent_codes = targets["codes"] 166 | policy_loss = F.cross_entropy(policy_logits.reshape(-1, self.vocab_size), target_latent_codes.reshape([-1]), 167 | reduction='none') 168 | 169 | logs = {"latent_policy_loss": policy_loss.detach(), 170 | } 171 | 172 | if "weights" in targets: 173 | weights_selected = targets["weights"].gather(-1, target_latent_codes[:, :, None]) 174 | policy_loss = torch.reshape(policy_loss, [b, t+1, -1]) * weights_selected 175 | 176 | if mask is not None: 177 | policy_loss = policy_loss.reshape(b, -1) * mask[:, :, 0] 178 | 179 | return policy_logits, policy_loss, new_cache, logs 180 | else: 181 | if kv_cache[0] is not None: 182 | policy_logits = self.policy_head(x).reshape(b, 1, -1) 183 | else: 184 | policy_logits = self.policy_head(x).reshape(b, t+1, -1) 185 | loss = torch.tensor(0.0, device=state.device) 186 | return policy_logits, loss, new_cache, {} 187 | -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | # 2 | # This file is autogenerated by pip-compile with Python 3.9 3 | # by the following command: 4 | # 5 | # pip-compile --output-file=requirements/requirements.txt requirements/requirements.in 6 | # 7 | --extra-index-url http://webservice 8 | --find-links https://download.pytorch.org/whl/cu116 9 | --trusted-host webservice 10 | 11 | absl-py==1.4.0 12 | # via 13 | # -r requirements/requirements.in 14 | # chex 15 | # dm-control 16 | # dm-env 17 | # labmaze 18 | # mujoco 19 | # rlds 20 | # tensorboard 21 | # tensorflow 22 | accelerate==0.18.0 23 | # via -r requirements/requirements.in 24 | appdirs==1.4.4 25 | # via wandb 26 | astunparse==1.6.3 27 | # via tensorflow 28 | cachetools==5.3.2 29 | # via google-auth 30 | certifi==2023.11.17 31 | # via 32 | # requests 33 | # sentry-sdk 34 | cffi==1.16.0 35 | # via mujoco-py 36 | charset-normalizer==3.3.2 37 | # via requests 38 | chex==0.1.7 39 | # via -r requirements/requirements.in 40 | click==8.1.7 41 | # via 42 | # d4rl 43 | # wandb 44 | cloudpickle==3.0.0 45 | # via 46 | # gym 47 | # gymnasium 48 | # stable-baselines3 49 | contourpy==1.2.0 50 | # via matplotlib 51 | cycler==0.12.1 52 | # via matplotlib 53 | cython==0.29.36 54 | # via 55 | # -r requirements/requirements.in 56 | # mujoco-py 57 | d4rl @ git+https://github.com/JannerM/d4rl.git@c3dd04da02acbf4de6cbaa1141deb4f958f03ca9 58 | # via -r requirements/requirements.in 59 | dm-control==1.0.2 60 | # via -r requirements/requirements.in 61 | dm-env==1.6 62 | # via 63 | # -r requirements/requirements.in 64 | # dm-control 65 | dm-tree==0.1.8 66 | # via 67 | # -r requirements/requirements.in 68 | # chex 69 | # dm-control 70 | # dm-env 71 | docker-pycreds==0.4.0 72 | # via wandb 73 | docstring-parser==0.15 74 | # via typed-argument-parser 75 | einops==0.4.1 76 | # via 77 | # -r requirements/requirements.in 78 | # rotary-embedding-torch 79 | farama-notifications==0.0.4 80 | # via gymnasium 81 | fasteners==0.19 82 | # via mujoco-py 83 | flatbuffers==23.5.26 84 | # via 85 | # -r requirements/requirements.in 86 | # tensorflow 87 | fonttools==4.45.1 88 | # via matplotlib 89 | future==0.18.3 90 | # via dm-control 91 | gast==0.4.0 92 | # via tensorflow 93 | gitdb==4.0.11 94 | # via gitpython 95 | gitpython==3.1.40 96 | # via 97 | # -r requirements/requirements.in 98 | # wandb 99 | glfw==2.6.3 100 | # via 101 | # dm-control 102 | # mujoco 103 | # mujoco-py 104 | google-auth==2.23.4 105 | # via 106 | # google-auth-oauthlib 107 | # tensorboard 108 | google-auth-oauthlib==0.4.6 109 | # via tensorboard 110 | google-pasta==0.2.0 111 | # via tensorflow 112 | grpcio==1.59.3 113 | # via 114 | # tensorboard 115 | # tensorflow 116 | gym==0.21.0 117 | # via 118 | # -r requirements/requirements.in 119 | # d4rl 120 | gymnasium==0.29.1 121 | # via stable-baselines3 122 | h5py==3.10.0 123 | # via 124 | # -r requirements/requirements.in 125 | # d4rl 126 | # tensorflow 127 | idna==3.6 128 | # via requests 129 | imageio==2.33.0 130 | # via 131 | # mujoco-py 132 | # scikit-image 133 | importlib-metadata==6.8.0 134 | # via 135 | # gymnasium 136 | # jax 137 | # markdown 138 | importlib-resources==6.1.1 139 | # via matplotlib 140 | jax==0.4.20 141 | # via chex 142 | jaxlib==0.4.20 143 | # via chex 144 | keras==2.11.0 145 | # via tensorflow 146 | kiwisolver==1.4.5 147 | # via matplotlib 148 | labmaze==1.0.6 149 | # via dm-control 150 | lazy-loader==0.3 151 | # via scikit-image 152 | libclang==16.0.6 153 | # via tensorflow 154 | lxml==4.9.3 155 | # via dm-control 156 | markdown==3.5.1 157 | # via tensorboard 158 | markupsafe==2.1.3 159 | # via werkzeug 160 | matplotlib==3.8.2 161 | # via 162 | # -r requirements/requirements.in 163 | # stable-baselines3 164 | mjrl @ git+https://github.com/aravindr93/mjrl@3871d93763d3b49c4741e6daeaebbc605fe140dc 165 | # via -r requirements/requirements.in 166 | ml-dtypes==0.2.0 167 | # via 168 | # jax 169 | # jaxlib 170 | mujoco==2.1.5 171 | # via 172 | # -r requirements/requirements.in 173 | # dm-control 174 | mujoco-py==2.1.2.14 175 | # via 176 | # -r requirements/requirements.in 177 | # d4rl 178 | mypy-extensions==1.0.0 179 | # via typing-inspect 180 | networkx==3.2.1 181 | # via scikit-image 182 | numpy==1.23.5 183 | # via 184 | # -r requirements/requirements.in 185 | # accelerate 186 | # chex 187 | # contourpy 188 | # d4rl 189 | # dm-control 190 | # dm-env 191 | # gym 192 | # gymnasium 193 | # h5py 194 | # imageio 195 | # jax 196 | # jaxlib 197 | # labmaze 198 | # matplotlib 199 | # ml-dtypes 200 | # mujoco 201 | # mujoco-py 202 | # opt-einsum 203 | # pandas 204 | # rlds 205 | # scikit-image 206 | # scikit-video 207 | # scipy 208 | # stable-baselines3 209 | # tensorboard 210 | # tensorflow 211 | # tifffile 212 | # xformers 213 | nvidia-cublas-cu11==11.10.3.66 214 | # via 215 | # nvidia-cudnn-cu11 216 | # torch 217 | nvidia-cuda-nvrtc-cu11==11.7.99 218 | # via torch 219 | nvidia-cuda-runtime-cu11==11.7.99 220 | # via torch 221 | nvidia-cudnn-cu11==8.5.0.96 222 | # via torch 223 | oauthlib==3.2.2 224 | # via requests-oauthlib 225 | opt-einsum==3.3.0 226 | # via 227 | # -r requirements/requirements.in 228 | # jax 229 | # tensorflow 230 | packaging==23.2 231 | # via 232 | # accelerate 233 | # matplotlib 234 | # scikit-image 235 | # tensorflow 236 | pandas==2.1.3 237 | # via 238 | # -r requirements/requirements.in 239 | # stable-baselines3 240 | pillow==10.1.0 241 | # via 242 | # imageio 243 | # matplotlib 244 | # scikit-image 245 | # scikit-video 246 | protobuf==3.19.6 247 | # via 248 | # dm-control 249 | # tensorboard 250 | # tensorflow 251 | # wandb 252 | psutil==5.9.6 253 | # via 254 | # accelerate 255 | # wandb 256 | pyasn1==0.5.1 257 | # via 258 | # pyasn1-modules 259 | # rsa 260 | pyasn1-modules==0.3.0 261 | # via google-auth 262 | pybullet==3.2.5 263 | # via d4rl 264 | pycparser==2.21 265 | # via cffi 266 | pyopengl==3.1.7 267 | # via 268 | # dm-control 269 | # mujoco 270 | pyparsing==2.4.7 271 | # via 272 | # dm-control 273 | # matplotlib 274 | pyre-extensions==0.0.23 275 | # via xformers 276 | python-dateutil==2.8.2 277 | # via 278 | # matplotlib 279 | # pandas 280 | pytz==2023.3.post1 281 | # via pandas 282 | pyyaml==6.0.1 283 | # via 284 | # accelerate 285 | # wandb 286 | requests==2.31.0 287 | # via 288 | # dm-control 289 | # requests-oauthlib 290 | # tensorboard 291 | # wandb 292 | requests-oauthlib==1.3.1 293 | # via google-auth-oauthlib 294 | rlds==0.1.8 295 | # via -r requirements/requirements.in 296 | rotary-embedding-torch==0.3.6 297 | # via -r requirements/requirements.in 298 | rsa==4.9 299 | # via google-auth 300 | scikit-image==0.22.0 301 | # via -r requirements/requirements.in 302 | scikit-video==1.1.11 303 | # via -r requirements/requirements.in 304 | scipy==1.11.4 305 | # via 306 | # dm-control 307 | # jax 308 | # jaxlib 309 | # scikit-image 310 | # scikit-video 311 | sentry-sdk==1.37.1 312 | # via wandb 313 | setproctitle==1.3.3 314 | # via wandb 315 | six==1.16.0 316 | # via 317 | # astunparse 318 | # docker-pycreds 319 | # google-pasta 320 | # python-dateutil 321 | # tensorflow 322 | smmap==5.0.1 323 | # via gitdb 324 | stable-baselines3==2.2.1 325 | # via -r requirements/requirements.in 326 | tensorboard==2.11.2 327 | # via tensorflow 328 | tensorboard-data-server==0.6.1 329 | # via tensorboard 330 | tensorboard-plugin-wit==1.8.1 331 | # via tensorboard 332 | tensorflow==2.11.1 333 | # via -r requirements/requirements.in 334 | tensorflow-estimator==2.11.0 335 | # via tensorflow 336 | tensorflow-io-gcs-filesystem==0.34.0 337 | # via tensorflow 338 | termcolor==2.3.0 339 | # via 340 | # d4rl 341 | # tensorflow 342 | tifffile==2023.9.26 343 | # via scikit-image 344 | toolz==0.12.0 345 | # via chex 346 | torch==1.13.1 347 | # via 348 | # -r requirements/requirements.in 349 | # accelerate 350 | # rotary-embedding-torch 351 | # stable-baselines3 352 | # xformers 353 | tqdm==4.66.1 354 | # via dm-control 355 | typed-argument-parser==1.9.0 356 | # via -r requirements/requirements.in 357 | typing-extensions==4.8.0 358 | # via 359 | # chex 360 | # gymnasium 361 | # pyre-extensions 362 | # tensorflow 363 | # torch 364 | # typing-inspect 365 | # wandb 366 | typing-inspect==0.9.0 367 | # via 368 | # pyre-extensions 369 | # typed-argument-parser 370 | tzdata==2023.3 371 | # via pandas 372 | urllib3==2.1.0 373 | # via 374 | # requests 375 | # sentry-sdk 376 | wandb==0.16.0 377 | # via -r requirements/requirements.in 378 | werkzeug==3.0.1 379 | # via tensorboard 380 | wheel==0.42.0 381 | # via 382 | # astunparse 383 | # nvidia-cublas-cu11 384 | # nvidia-cuda-runtime-cu11 385 | # tensorboard 386 | wrapt==1.14.1 387 | # via tensorflow 388 | xformers==0.0.16 389 | # via -r requirements/requirements.in 390 | zipp==3.17.0 391 | # via 392 | # importlib-metadata 393 | # importlib-resources 394 | 395 | # The following packages are considered to be unsafe in a requirements file: 396 | # setuptools 397 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import numpy as np 9 | import torch 10 | import json 11 | import pdb 12 | import pickle 13 | 14 | import trajectory.utils as utils 15 | from trajectory.utils.dataset import create_dataset 16 | from trajectory.models.vqvae import VQContinuousVAE 17 | from trajectory.utils.serialization import get_latest_epoch 18 | import wandb 19 | from accelerate.logging import get_logger 20 | 21 | torch.backends.cuda.matmul.allow_tf32 = True 22 | os.environ["CUDA_HOME"] = os.environ["CONDA_PREFIX"] 23 | os.environ[ 24 | "TORCH_DISTRIBUTED_DEBUG" 25 | ] = "DETAIL" 26 | 27 | class Parser(utils.Parser): 28 | dataset: str = 'halfcheetah-medium-expert-v2' 29 | config: str = 'config.vqvae' 30 | 31 | def save_model(args, trainer, model, save_epoch): 32 | ## save state, optimizer, scaler to disk 33 | tmp_model_statepath = os.path.join(args.savepath, f'tmp_state_{save_epoch}.pt') 34 | model_statepath = os.path.join(args.savepath, f'state_{save_epoch}.pt') 35 | model_state = trainer.accelerator.unwrap_model(model).state_dict() 36 | tmp_optimizer_statepath = os.path.join(args.savepath, f'tmp_optimizer_state_{save_epoch}.pt') 37 | optimizer_statepath = os.path.join(args.savepath, f'optimizer_state_{save_epoch}.pt') 38 | optimizer_state = trainer.accelerator.unwrap_model(trainer.optimizer).state_dict() 39 | 40 | trainer.accelerator.save(model_state, tmp_model_statepath) 41 | trainer.accelerator.save(optimizer_state, tmp_optimizer_statepath) 42 | 43 | 44 | ## rename saved tmp files to files 45 | ## save stats and dataloader 46 | if trainer.accelerator.is_main_process: 47 | os.rename(tmp_model_statepath, model_statepath) 48 | trainer.accelerator.print(f"Saved model to {model_statepath}\n") 49 | os.rename(tmp_optimizer_statepath, optimizer_statepath) 50 | trainer.accelerator.print(f"Saved optimizer to {optimizer_statepath}\n") 51 | statspath = os.path.join(args.savepath, 'stats.pkl') 52 | with open(statspath, 'wb') as f: 53 | pickle.dump(trainer.stats, f) 54 | trainer.accelerator.print(f"Saved stats: {trainer.stats}\n") 55 | trainer.loader.save() 56 | trainer.accelerator.print(f"Saved dataloder state\n") 57 | accelerator_statepath = os.path.join(args.savepath, f'accelerator_state') 58 | trainer.accelerator.save_state(output_dir=accelerator_statepath) 59 | trainer.accelerator.print(f"Saved accelerator state to {accelerator_statepath}\n") 60 | 61 | def main(): 62 | 63 | args = Parser().parse_args('train') 64 | args.n_layer = int(args.n_layer) 65 | 66 | logger = get_logger(__name__, log_level="DEBUG") 67 | 68 | ####################### 69 | ####### loading ####### 70 | ####################### 71 | 72 | args.logbase = os.path.expanduser(args.logbase) 73 | args.savepath = os.path.expanduser(args.savepath) 74 | if not os.path.exists(args.savepath): 75 | os.makedirs(args.savepath) 76 | 77 | ####################### 78 | ####### trainer ####### 79 | ####################### 80 | 81 | # use all available cuda devices for data parallelization 82 | num_gpus = torch.cuda.device_count() 83 | logger.debug(f"Using {num_gpus} gpus.\n", main_process_only=True) 84 | 85 | n_tokens_target = int(args.n_tokens_target) 86 | n_epochs_ref = int(args.n_epochs_ref) 87 | warmup_tokens = n_epochs_ref * n_tokens_target * 0.2 88 | final_tokens = n_epochs_ref * n_tokens_target 89 | 90 | trainer_config = utils.Config( 91 | utils.VQTrainer, 92 | savepath=(args.savepath, 'trainer_config.pkl'), 93 | # optimization parameters 94 | learning_rate=args.learning_rate, 95 | betas=(0.9, 0.95), 96 | grad_norm_clip=1.0, 97 | weight_decay=0.1, # only applied on matmul weights 98 | # learning rate decay: linear warmup followed by cosine decay to 10% of original 99 | lr_decay=False, 100 | warmup_tokens=warmup_tokens, 101 | kl_warmup_tokens=warmup_tokens*10, 102 | final_tokens=final_tokens, 103 | ## dataloader 104 | num_workers=int(args.num_workers), 105 | device=args.device, 106 | train_batch_size=int(args.train_batch_size), 107 | load_batch_size=int(args.load_batch_size), 108 | n_tokens_target=n_tokens_target, 109 | enable_fp16=args.enable_fp16, 110 | ) 111 | 112 | trainer = trainer_config() 113 | 114 | ############################ 115 | ######## DataLoader ######## 116 | ############################ 117 | 118 | dataset = create_dataset(args) 119 | 120 | obs_dim = dataset.observation_dim 121 | act_dim = dataset.action_dim 122 | transition_dim = obs_dim + act_dim + 3 123 | 124 | ####################### 125 | ######## model ######## 126 | ####################### 127 | 128 | block_size = args.subsampled_sequence_length * transition_dim # total number of dimensionalities for a maximum length sequence (T) 129 | 130 | logger.debug( 131 | f'Joined dim: {transition_dim} ' 132 | f'(observation: {obs_dim}, action: {act_dim}) | Block size: {block_size}', 133 | main_process_only=True) 134 | 135 | 136 | model_config = utils.Config( 137 | VQContinuousVAE, 138 | savepath=(args.savepath, 'model_config.pkl'), 139 | ## discretization 140 | vocab_size=args.N, block_size=block_size, 141 | K=args.K, 142 | code_per_step=args.code_per_step, 143 | ## architecture 144 | n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd * args.n_head, 145 | ## dimensions 146 | observation_dim=obs_dim, action_dim=act_dim, transition_dim=transition_dim, 147 | ## loss weighting 148 | action_weight=args.action_weight, reward_weight=args.reward_weight, value_weight=args.value_weight, 149 | position_weight=args.position_weight, 150 | trajectory_embd=args.trajectory_embd, 151 | model=args.model, 152 | latent_step=args.latent_step, 153 | ma_update=args.ma_update, 154 | residual=args.residual, 155 | obs_shape=args.obs_shape, 156 | ## dropout probabilities 157 | embd_pdrop=args.embd_pdrop, resid_pdrop=args.resid_pdrop, attn_pdrop=args.attn_pdrop, 158 | bottleneck=args.bottleneck, 159 | masking=args.masking, 160 | ae_type=args.ae_type, 161 | state_conditional=args.state_conditional, 162 | use_discriminator=args.use_discriminator, 163 | disc_start=args.disc_start, 164 | blocks_per_layer=args.blocks_per_layer, 165 | encoder_inputs=args.encoder_inputs, 166 | position_embedding=args.position_embedding, 167 | causal_attention=args.causal_attention, 168 | causal_conv=args.causal_conv, 169 | symlog=args.symlog, 170 | data_parallel=args.data_parallel, 171 | ) 172 | model = model_config() 173 | 174 | # initialize or load stats 175 | statspath = os.path.join(args.savepath, 'stats.pkl') 176 | if os.path.isfile(statspath): 177 | with open(statspath, 'rb') as f: 178 | stats = pickle.load(f) 179 | logger.debug(f"Loaded stats: {stats}\n", main_process_only=True) 180 | model, _ = utils.load_model(logger, args.logbase, args.dataset, args.exp_name, epoch="latest", device=args.device) 181 | else: 182 | stats = { 183 | "n_epochs": 0, 184 | "n_tokens": 0, # counter used for learning rate decay 185 | "last_save_n_tokens": 0, 186 | "last_logging_n_tokens": 0, 187 | "n_steps": 0, 188 | "n_logging": 0, 189 | } 190 | trainer.accelerator.print("Training from scratch...\n") 191 | 192 | model.set_padding_vector(np.zeros(model.transition_dim-1)) 193 | 194 | ####################### 195 | ###### main loop ###### 196 | ####################### 197 | 198 | ## scale number of epochs to keep number of updates constant 199 | n_saves = int(args.n_saves) 200 | save_freq = int(n_epochs_ref // n_saves) 201 | wandb_conf = {"entity": "transferplan", "group": args.exp_name, "reinit": True, "config": args, "tags": [args.exp_name, args.tag]} 202 | trainer.init_stats(stats) 203 | trainer.init_data_loader(dataset, stats) 204 | trainer.init_wandb(wandb_conf, name=args.exp_name) 205 | trainer.accelerator.print(f'model parameters {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1000000}M') 206 | 207 | # load accelerator state 208 | accelerator_statepath = os.path.join(args.savepath, f'accelerator_state.pt') 209 | if os.path.exists(accelerator_statepath): 210 | trainer.accelerator.load_state(accelerator_statepath) 211 | 212 | optimizer = trainer.get_optimizer(model) 213 | 214 | # load optimizer and scaler if needed 215 | optimizer, _ = utils.load_optimizer(logger, optimizer, args.logbase, args.dataset, args.exp_name, 216 | epoch="latest", device=args.device, type="") 217 | 218 | model, trainer.optimizer = trainer.accelerator.prepare(model, optimizer) 219 | 220 | start_ep = trainer.stats['n_epochs'] 221 | for epoch in range(start_ep, n_epochs_ref): 222 | trainer.accelerator.print(f'\nEpoch: {epoch} / {n_epochs_ref} | {args.dataset} | {args.exp_name}') 223 | 224 | trainer.train(model, dataset, save_freq=1e4, savepath=args.savepath) 225 | 226 | if epoch % 10 == 0: 227 | ## get greatest multiple of `save_freq` less than or equal to `save_epoch` 228 | save_epoch = (epoch + 1) // save_freq * save_freq 229 | save_model(args, trainer, model, save_epoch) 230 | 231 | # save the final trained model 232 | save_epoch = n_epochs_ref // save_freq * save_freq 233 | save_model(args, trainer, model, save_epoch) 234 | 235 | if __name__ == '__main__': 236 | main() 237 | -------------------------------------------------------------------------------- /scripts/trainprior.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import numpy as np 9 | import torch 10 | import pickle 11 | import gc 12 | from GPUtil import showUtilization as gpu_usage 13 | 14 | import trajectory.utils as utils 15 | from trajectory.utils.dataset import create_dataset 16 | from trajectory.models.transformer_prior import TransformerPrior 17 | from accelerate.logging import get_logger 18 | 19 | os.environ["CUDA_HOME"] = os.environ["CONDA_PREFIX"] 20 | os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL" 21 | torch.backends.cuda.matmul.allow_tf32 = True 22 | 23 | class Parser(utils.Parser): 24 | dataset: str = 'halfcheetah-medium-expert-v2' 25 | config: str = 'config.vqvae' 26 | 27 | def save_model(args, trainer, model, save_epoch): 28 | ## save state, optimizer to disk 29 | tmp_model_statepath = os.path.join(args.savepath, f'tmp_{args.type}_state_{save_epoch}.pt') 30 | model_statepath = os.path.join(args.savepath, f'{args.type}_state_{save_epoch}.pt') 31 | model_state = trainer.accelerator.unwrap_model(model).state_dict() 32 | tmp_optimizer_statepath = os.path.join(args.savepath, f'tmp_{args.type}_optimizer_state_{save_epoch}.pt') 33 | optimizer_statepath = os.path.join(args.savepath, f'{args.type}_optimizer_state_{save_epoch}.pt') 34 | optimizer_state = trainer.accelerator.unwrap_model(trainer.optimizer).state_dict() 35 | 36 | trainer.accelerator.save(model_state, tmp_model_statepath) 37 | trainer.accelerator.save(optimizer_state, tmp_optimizer_statepath) 38 | 39 | ## rename saved tmp files to files 40 | ## save stats and dataloader 41 | if trainer.accelerator.is_main_process: 42 | os.rename(tmp_model_statepath, model_statepath) 43 | trainer.accelerator.print(f"Saved model to {model_statepath}\n") 44 | os.rename(tmp_optimizer_statepath, optimizer_statepath) 45 | trainer.accelerator.print(f"Saved optimizer to {optimizer_statepath}\n") 46 | statspath = os.path.join(args.savepath, f'{args.type}_stats.pkl') 47 | with open(statspath, 'wb') as f: 48 | pickle.dump(trainer.stats, f) 49 | trainer.accelerator.print(f"Saved stats: {trainer.stats}\n") 50 | trainer.loader.save() 51 | trainer.accelerator.print(f"Saved dataloder state\n") 52 | accelerator_statepath = os.path.join(args.savepath, f'{args.type}_accelerator_state') 53 | trainer.accelerator.save_state(output_dir=accelerator_statepath) 54 | trainer.accelerator.print(f"Saved accelerator state to {accelerator_statepath}\n") 55 | 56 | def main(): 57 | gc.collect() 58 | torch.cuda.empty_cache() 59 | # print("Initial GPU Usage") 60 | # gpu_usage() 61 | ####################### 62 | ######## setup ######## 63 | ####################### 64 | logger = get_logger(__name__, log_level="DEBUG") 65 | 66 | args = Parser().parse_args('plan') 67 | if "vae_name" in args.__dict__ and args.vae_name != "": 68 | vae_name = args.vae_name 69 | else: 70 | vae_name = args.exp_name 71 | 72 | representation, _ = utils.load_model(logger, args.logbase, args.dataset, vae_name, epoch=args.gpt_epoch, device=args.device) 73 | args = Parser().parse_args('train') 74 | 75 | sequence_length = args.subsampled_sequence_length * args.step 76 | args.logbase = os.path.expanduser(args.logbase) 77 | args.savepath = os.path.expanduser(args.savepath) 78 | args.code_per_step = int(args.code_per_step) 79 | 80 | if not os.path.exists(args.savepath): 81 | os.makedirs(args.savepath) 82 | ## HACK: to avoid launch duplicate job 83 | if os.path.exists(os.path.join(args.savepath, 'critic_state_600.pt')): 84 | return 85 | 86 | dataset = create_dataset(args) 87 | 88 | obs_dim = dataset.observation_dim 89 | act_dim = dataset.action_dim 90 | 91 | transition_dim = obs_dim + act_dim + 3 92 | 93 | representation.set_padding_vector(np.zeros(representation.transition_dim - 1)) 94 | 95 | obs_dim = dataset.observation_dim 96 | 97 | model_config = utils.Config( 98 | TransformerPrior, 99 | savepath=(args.savepath, f'{args.type}_model_config.pkl'), 100 | ## discretization 101 | K=representation.K, max_sequence_length=args.subsampled_sequence_length, 102 | ## architecture 103 | observation_dim=obs_dim, action_dim=act_dim, transition_dim=transition_dim, 104 | n_layer=args.prior_layer, n_head=args.prior_head, n_embd=args.prior_embd * args.prior_head, 105 | value_layer=args.value_layer, value_head=args.value_head, value_embd=args.value_embd * args.value_head, 106 | ## loss weighting 107 | latent_step=args.latent_step, 108 | code_per_step=representation.code_per_step, 109 | ## dropout probabilities 110 | embd_pdrop=args.embd_pdrop, resid_pdrop=args.resid_pdrop, attn_pdrop=args.attn_pdrop, 111 | obs_shape=args.obs_shape, 112 | position_embedding=args.position_embedding, 113 | latent_steps=representation.latent_step, 114 | data_parallel=args.data_parallel, 115 | twohot_value=args.twohot_value, 116 | value_ema_rate=args.value_ema_rate, 117 | tau=args.tau, 118 | cql_weight=args.cql_weight, 119 | ) 120 | 121 | 122 | num_gpus = torch.cuda.device_count() 123 | logger.debug(f"Using {num_gpus} gpus.\n", main_process_only=True) 124 | 125 | ####################### 126 | ####### trainer ####### 127 | ####################### 128 | 129 | n_tokens_target = int(args.n_tokens_target) 130 | n_epochs_ref = int(args.n_epochs_ref) 131 | warmup_tokens = n_epochs_ref*n_tokens_target*0.2 132 | final_tokens = n_epochs_ref*n_tokens_target 133 | 134 | trainer_config = utils.Config( 135 | utils.PriorTrainer, 136 | savepath=(args.savepath, f'{args.type}trainer_config.pkl'), 137 | # optimization parameters 138 | train_batch_size=int(args.train_batch_size), 139 | load_batch_size=int(args.load_batch_size), 140 | learning_rate=args.prior_learning_rate, 141 | betas=(0.9, 0.95), 142 | grad_norm_clip=2.0 if "prior_gradient_norm_clip" not in args.__dict__ else args.prior_gradient_norm_clip, 143 | weight_decay=0.1, # only applied on matmul weights 144 | # learning rate decay: linear warmup followed by cosine decay to 10% of original 145 | lr_decay=args.lr_decay, 146 | warmup_tokens=warmup_tokens, 147 | kl_warmup_tokens=warmup_tokens*10, 148 | final_tokens=final_tokens, 149 | ## dataloader 150 | num_workers=args.num_workers, 151 | device=args.device, 152 | n_tokens_target=n_tokens_target, 153 | discount=args.discount, 154 | enable_fp16=args.enable_prior_fp16, 155 | bootstrap=args.bootstrap, 156 | bootstrap_ignore_terminal=args.bootstrap_ignore_terminal, 157 | type=args.type, 158 | ) 159 | 160 | trainer = trainer_config() 161 | 162 | 163 | # initialize or load stats 164 | statspath = os.path.join(args.savepath, f'{args.type}_stats.pkl') 165 | if os.path.isfile(statspath): 166 | with open(statspath, 'rb') as f: 167 | stats = pickle.load(f) 168 | logger.debug(f"Loaded stats: {stats}\n", main_process_only=True) 169 | # if trainer.accelerator.is_main_process: 170 | model, _ = utils.load_transformer_model(logger, args.logbase, args.dataset, 171 | args.prior_name if args.type=="prior_finetune" else args.exp_name, 172 | epoch="latest", device=args.device, type=args.type) 173 | dataset.restore() 174 | elif args.type == "prior_finetune": 175 | model, _ = utils.load_transformer_model(logger, args.logbase, args.dataset, 176 | args.prior_name if args.type=="prior_finetune" else args.exp_name, 177 | epoch="latest", device=args.device, type="prior") 178 | dataset.restore() 179 | stats = { 180 | "n_epochs": 0, 181 | "n_tokens": 0, # counter used for learning rate decay 182 | "last_save_n_tokens": 0, 183 | "last_logging_n_tokens": 0, 184 | "n_steps": 0, 185 | "ema_reconstruction": 0, 186 | "n_logging": 0, 187 | } 188 | else: 189 | model = model_config() 190 | stats = { 191 | "n_epochs": 0, 192 | "n_tokens": 0, # counter used for learning rate decay 193 | "last_save_n_tokens": 0, 194 | "last_logging_n_tokens": 0, 195 | "n_steps": 0, 196 | "ema_reconstruction": 0, 197 | "n_logging": 0, 198 | } 199 | if args.type == "critic" and args.prior_name != "": 200 | if trainer.accelerator.is_main_process: 201 | prior_model, _ = utils.load_transformer_model(logger, args.logbase, args.dataset, args.prior_name, 202 | epoch="latest", device=args.device, type="prior") 203 | else: 204 | prior_model = None 205 | 206 | if args.type == "prior_finetune": 207 | if trainer.accelerator.is_main_process: 208 | critic_model, _ = utils.load_transformer_model(logger, args.logbase, args.dataset, args.critic_name, 209 | epoch="latest", device=args.device, type="critic") 210 | critic_model.eval() 211 | else: 212 | critic_model = None 213 | 214 | # print("After model init GPU Usage") 215 | # gpu_usage() 216 | 217 | ####################### 218 | ###### main loop ###### 219 | ####################### 220 | 221 | ## scale number of epochs to keep number of updates constant 222 | n_saves = int(args.n_saves) 223 | save_freq = int(n_epochs_ref // n_saves) 224 | # wandb.init(project="latentPlanning", entity="transferplan", group=args.exp_name, reinit=True, config=args, tags=[args.exp_name, args.tag, "prior"]) 225 | wandb_conf = {"entity": "transferplan", "group": args.exp_name, "reinit": True, "config": args, "tags": [args.exp_name, args.tag, args.type]} 226 | trainer.init_stats(stats) 227 | trainer.init_data_loader(dataset) 228 | trainer.init_wandb(wandb_conf, name=args.exp_name) 229 | trainer.accelerator.print(f'model parameters {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1000000}M') 230 | 231 | optimizer = trainer.get_optimizer(model) 232 | 233 | # load optimizer if needed 234 | if trainer.accelerator.is_main_process: 235 | optimizer, _ = utils.load_optimizer(logger, optimizer, args.logbase, args.dataset, args.exp_name, 236 | epoch="latest", device=args.device, type=args.type) 237 | 238 | torch.cuda.empty_cache() 239 | model, trainer.optimizer = trainer.accelerator.prepare(model, optimizer) 240 | # load accelerator state 241 | accelerator_statepath = os.path.join(args.savepath, f'{args.type}_accelerator_state.pt') 242 | if os.path.exists(accelerator_statepath): 243 | trainer.accelerator.load_state(accelerator_statepath) 244 | if prior_model: 245 | prior_model = trainer.accelerator.prepare(prior_model) 246 | if prior_model: 247 | critic_model = trainer.accelerator.prepare(critic_model) 248 | 249 | start_ep = trainer.stats['n_epochs'] 250 | for epoch in range(start_ep, n_epochs_ref): 251 | trainer.accelerator.print(f'\nEpoch: {epoch} / {n_epochs_ref} | {args.dataset} | {args.exp_name}') 252 | 253 | nan_loss = trainer.train(representation, model, dataset, type=args.type, prior_model=prior_model, critic_model=critic_model) 254 | if nan_loss: 255 | trainer.accelerator.print(f"Training aborted due to NaN losses!\n") 256 | return 257 | 258 | ## get greatest multiple of `save_freq` less than or equal to `save_epoch` 259 | if epoch % 10 == 0: 260 | save_epoch = (epoch + 1) // save_freq * save_freq 261 | save_model(args, trainer, model, save_epoch) 262 | 263 | # save the final trained model 264 | save_epoch = n_epochs_ref // save_freq * save_freq 265 | save_model(args, trainer, model, save_epoch) 266 | 267 | 268 | if __name__ == '__main__': 269 | main() 270 | -------------------------------------------------------------------------------- /trajectory/search/planning.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from collections import defaultdict 8 | import torch 9 | from trajectory.utils.relabel_humanoid import get_speed, get_angular_vel, get_forward_vel, get_height_vel, get_left_vel 10 | 11 | 12 | REWARD_DIM = VALUE_DIM = 1 13 | 14 | import numpy as np 15 | 16 | 17 | def trajectory2rv(trajectories, objective): 18 | """ 19 | trajectories: [B x T x D] 20 | """ 21 | if objective == "reward": 22 | return trajectories[:, :, -3], trajectories[:, :, -2] 23 | elif objective == "speed": 24 | speed = get_speed(trajectories) 25 | return speed, speed 26 | elif objective == "shift_left": 27 | left_vel = get_left_vel(trajectories) 28 | return left_vel, left_vel 29 | elif objective == "forward": 30 | forward_vel = get_forward_vel(trajectories) 31 | return forward_vel, forward_vel 32 | elif objective == "backward": 33 | forward_vel = get_forward_vel(trajectories) 34 | return -forward_vel, -forward_vel 35 | elif objective == "jump": 36 | height_vel = get_height_vel(trajectories) 37 | jump_vel = torch.maximum(height_vel, torch.zeros_like(height_vel)) 38 | return jump_vel, jump_vel 39 | elif objective == "rotate_x": 40 | x_angular_vel = get_angular_vel(trajectories, "x") 41 | return x_angular_vel, x_angular_vel 42 | elif objective == "rotate_y": 43 | y_angular_vel = get_angular_vel(trajectories, "y") 44 | return y_angular_vel, y_angular_vel 45 | elif objective == "rotate_z": 46 | z_angular_vel = get_angular_vel(trajectories, "z") 47 | return z_angular_vel, z_angular_vel 48 | elif objective == "tracking": 49 | zeros = torch.zeros([trajectories.shape[0], trajectories.shape[1]]).to(trajectories) 50 | return zeros, zeros 51 | 52 | 53 | @torch.no_grad() 54 | def sample_with_prior(prior, model, x, denormalize_rew, denormalize_val, discount, steps, nb_samples=4096, rounds=8, 55 | likelihood_weight=5e2, prob_threshold=0.05, uniform=False, return_info=False, objective="reward", 56 | temperature=1.0, top_p=0.95): 57 | x = x.to(torch.float32) 58 | state = x[:, 0, :model.observation_dim] 59 | optimals = [] 60 | optimal_values = [] 61 | log_prob_list = [] 62 | entropy_list = [] 63 | info = defaultdict(list) 64 | for round in range(rounds): 65 | contex = None 66 | samples = None 67 | log_acc_probs = torch.zeros([1]).to(x) 68 | kv_cache = None 69 | for step in range(steps // model.latent_step): 70 | for internal_step in range(model.code_per_step): 71 | logits, _, kv_cache, _ = prior(samples.reshape([-1, 1]) if samples is not None else None, 72 | state, kv_cache=kv_cache) # [B x t x K] 73 | logits = logits / temperature # Add temperature control 74 | probs = torch.softmax(logits[:, -1, :], dim=-1) # [B x K] 75 | entropy_list.append(torch.mean(-torch.sum(probs * torch.log(probs + 1e-8), dim=-1))) 76 | log_probs = torch.log(probs) 77 | 78 | # Top-p sampling 79 | sorted_probs, indices = torch.sort(probs, descending=True, dim=-1) 80 | cumulative_probs = torch.cumsum(sorted_probs, dim=-1) 81 | mask = cumulative_probs < top_p 82 | mask[:, 0] = 1 # To make sure the most probable token is always included in the sampling set 83 | valid_probs = sorted_probs*mask 84 | if step == 0 and internal_step == 0: 85 | sampled_indices = torch.multinomial(valid_probs, num_samples=nb_samples // rounds, 86 | replacement=True) 87 | else: 88 | sampled_indices = torch.multinomial(valid_probs.reshape(nb_samples, -1), 89 | num_samples=1, replacement=True) 90 | samples = torch.gather(indices, 1, sampled_indices) 91 | 92 | samples_log_prob = torch.cat( 93 | [torch.index_select(a, 0, i).unsqueeze(0) for a, i in zip(log_probs, samples)]) # [B, M] 94 | 95 | log_prob_list.append(samples_log_prob.squeeze()) 96 | log_acc_probs = log_acc_probs + samples_log_prob.reshape([-1]) 97 | if contex is not None: 98 | contex = torch.cat([contex, samples.reshape([-1, 1])], dim=1) 99 | else: 100 | contex = samples.reshape([-1, step * model.code_per_step + internal_step + 1]) # [(B*M) x t] 101 | 102 | prediction_raw = model.decode_from_indices(contex, state) 103 | 104 | r, v = trajectory2rv(prediction_raw, objective=objective) 105 | 106 | discounts = torch.cumprod(torch.ones_like(r) * discount, dim=-1) 107 | values = torch.sum(r[:, :-1] * discounts[:, :-1], dim=-1) + v[:, -1] * discounts[:, -1] 108 | 109 | 110 | likelihood_bonus = likelihood_weight * torch.minimum(log_acc_probs, torch.log(torch.tensor(prob_threshold))) 111 | info["log_probs"].append(log_acc_probs.cpu().numpy()) 112 | info["log_prob_list"].append([prob.cpu().numpy().squeeze() for prob in log_prob_list]) 113 | info["returns"].append(values.cpu().numpy()) 114 | info["predictions"].append(prediction_raw.cpu().numpy()) 115 | info["objectives"].append(values.cpu().numpy() + likelihood_bonus.cpu().numpy()) 116 | info["latent_codes"].append(contex.cpu().numpy()) 117 | info["entropy"].append(torch.stack(entropy_list).cpu().numpy()) 118 | max_idx = (values + likelihood_bonus).argmax() 119 | optimal_value = values[max_idx] 120 | optimal = prediction_raw[max_idx] 121 | optimals.append(optimal) 122 | optimal_values.append(optimal_value.item()) 123 | 124 | for key, val in info.items(): 125 | info[key] = np.concatenate(val, axis=0) 126 | 127 | max_idx = np.array(optimal_values).argmax() 128 | optimal = optimals[max_idx] 129 | # for key in ["returns", "objectives"]: 130 | # val = info[key] 131 | # print(f"{key} {val},\n with mean: {np.mean(val)}, std {np.std(val)} \n") 132 | 133 | if return_info: 134 | return optimal.cpu().numpy(), info 135 | else: 136 | return optimal.cpu().numpy() 137 | 138 | 139 | @torch.no_grad() 140 | def beam_with_prior(prior, model, x, denormalize_rew, denormalize_val, discount, steps, 141 | beam_width, n_expand, prob_threshold=0.05, likelihood_weight=5e2, prob_acc="product", return_info=False, 142 | optimize_target="reward", temperature=1.0, top_p=0.99): 143 | contex = None 144 | state = x[:, 0, :prior.observation_dim] 145 | acc_probs = torch.zeros([1]).to(x) 146 | info = {} 147 | kv_cache = None 148 | for step in range(steps//model.latent_step): 149 | for internal_step in range(model.code_per_step): 150 | logits, _, new_kv_cache, _ = prior(contex, state) # [B x t x K] 151 | logits = logits / temperature # Add temperature control 152 | probs = torch.softmax(logits[:, -1, :], dim=-1) # [B x K] 153 | log_probs = torch.log(probs) 154 | nb_samples = beam_width * n_expand if step == 0 and internal_step == 0 else n_expand 155 | samples = torch.multinomial(probs, num_samples=nb_samples, replacement=True) # [B, M] 156 | samples_log_prob = torch.cat([torch.index_select(a, 0, i).unsqueeze(0) for a, i in zip(log_probs, samples)]) # [B, M] 157 | if prob_acc in ["product", "expect"]: 158 | acc_probs = acc_probs.repeat_interleave(nb_samples, 0) + samples_log_prob.reshape([-1]) 159 | elif prob_acc == "min": 160 | acc_probs = torch.minimum(acc_probs.repeat_interleave(nb_samples, 0), samples_log_prob.reshape([-1])) 161 | 162 | if not contex is None: 163 | contex = torch.cat([torch.repeat_interleave(contex, nb_samples, 0), samples.reshape([-1, 1])], 164 | dim=1) 165 | else: 166 | contex = samples.reshape([-1, step+1]) # [(B*M) x t] 167 | 168 | 169 | if internal_step==model.code_per_step-1: 170 | prediction_raw = model.decode_from_indices(contex, state) 171 | prediction = prediction_raw.reshape([-1, prediction_raw.shape[-1]]) 172 | 173 | r_t, V_t = trajectory2rv(prediction, optimize_target) 174 | if denormalize_rew is not None: 175 | r_t = denormalize_rew(r_t).reshape([contex.shape[0], -1]) 176 | if denormalize_val is not None: 177 | V_t = denormalize_val(V_t).reshape([contex.shape[0], -1]) 178 | if return_info: 179 | info[(step + 1) * model.latent_step] = dict(predictions=prediction_raw.cpu(), returns=values.cpu(), 180 | latent_codes=contex.cpu(), log_probs=acc_probs.cpu(), 181 | objectives=values + likelihood_bonus, index=index.cpu()) 182 | discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1) 183 | values = torch.sum(r_t[:, :-1] * discounts[:, :-1], dim=-1) + V_t[:, -1] * discounts[:, -1] 184 | if prob_acc == "product": 185 | likelihood_bonus = likelihood_weight * torch.clamp(acc_probs, 0, np.log(prob_threshold) * ( 186 | steps // model.latent_step)) 187 | elif prob_acc == "min": 188 | likelihood_bonus = likelihood_weight * torch.clamp(acc_probs, 0, np.log(prob_threshold)) 189 | else: 190 | values = torch.zeros([contex.shape[0]]).to(x) 191 | likelihood_bonus = acc_probs 192 | 193 | nb_top = 1 if step == (steps//model.latent_step-1) and internal_step == (model.code_per_step-1) else beam_width 194 | if prob_acc == "expect": 195 | values_with_b, index = torch.topk(values*torch.exp(acc_probs), nb_top) 196 | else: 197 | values_with_b, index = torch.topk(values+likelihood_bonus, nb_top) 198 | 199 | contex = contex[index] 200 | acc_probs = acc_probs[index] 201 | 202 | optimal = prediction_raw[index[0]] 203 | if return_info: 204 | return optimal.cpu().numpy(), info 205 | else: 206 | return optimal.cpu().numpy() 207 | 208 | @torch.no_grad() 209 | def top_k(prior, model, x, denormalize_rew, denormalize_val, discount, steps, 210 | k, nb_samples, prob_threshold=0.05, likelihood_weight=5e2, return_info=False, optimize_target="reward"): 211 | x = x.to(torch.float32) 212 | contex = None 213 | state = x[:, 0, :prior.observation_dim] 214 | info = {} 215 | log_prob_list = [] 216 | log_acc_probs = torch.zeros([1]).to(x) 217 | for step in range(steps//model.latent_step): 218 | for internal_step in range(model.code_per_step): 219 | logits, _, values = prior(contex, state) # [B x t x M] 220 | logits = logits[:, -1, :] 221 | logits, candidate_latent_codes = torch.topk(logits, k, dim=-1) 222 | 223 | probs = torch.softmax(logits[:, :], dim=-1) # [B x K] 224 | log_probs = torch.log(probs) 225 | 226 | if step == 0 and internal_step == 0: 227 | samples_idx = torch.multinomial(probs, num_samples=nb_samples, replacement=True) # [B, K] 228 | else: 229 | samples_idx = torch.multinomial(probs, num_samples=1, replacement=True) # [B, K] 230 | latent_code = torch.cat([torch.index_select(c, 0, i) for c,i in zip(candidate_latent_codes, samples_idx)]) 231 | samples_log_prob = torch.cat( 232 | [torch.index_select(a, 0, i).unsqueeze(0) for a, i in zip(log_probs, samples_idx)]) # [B, K] 233 | log_prob_list.append(samples_log_prob.squeeze()) 234 | log_acc_probs = log_acc_probs + samples_log_prob.reshape([-1]) 235 | if not contex is None: 236 | contex = torch.cat([contex, latent_code.reshape([-1, 1])], dim=1) 237 | else: 238 | contex = latent_code.reshape([-1, step * model.code_per_step + internal_step + 1]) # [(B*M) x t] 239 | 240 | prediction_raw = model.decode_from_indices(contex, state) 241 | prediction = prediction_raw.reshape([-1, model.action_dim+model.observation_dim+3]) 242 | likelihood_bonus = likelihood_weight * torch.min(torch.stack(log_prob_list, dim=-1), 243 | torch.log(torch.tensor(prob_threshold))) 244 | likelihood_bonus = torch.sum(likelihood_bonus, dim=-1) 245 | 246 | r_t, V_t = prediction2rv(prediction, model, optimize_target) 247 | 248 | if denormalize_rew is not None: 249 | r_t = denormalize_rew(r_t).reshape([contex.shape[0], -1]) 250 | if denormalize_val is not None: 251 | V_t = denormalize_val(V_t).reshape([contex.shape[0], -1]) 252 | 253 | discounts = torch.cumprod(torch.ones_like(r_t) * discount, dim=-1) 254 | values = torch.sum(r_t[:,:-1] * discounts[:, :-1], dim=-1) + V_t[:,-1]*discounts[:, -1] 255 | #likelihood_bonus = likelihood_weight*torch.clamp(log_probs, -1e5, np.log(prob_threshold)*(steps//model.latent_step)) 256 | values_with_b, index = torch.topk(values+likelihood_bonus, 1) 257 | if return_info: 258 | info = dict(predictions=prediction_raw.cpu(), returns=values.cpu(), latent_codes=contex.cpu(), 259 | log_probs=torch.stack(log_prob_list).cpu().sum(), 260 | log_prob_list=[prob.cpu().numpy().squeeze() for prob in log_prob_list], 261 | objectives=values+likelihood_bonus, index=index.cpu()) 262 | 263 | optimal = prediction_raw[index[0]] 264 | if return_info: 265 | return optimal.cpu().numpy(), info 266 | else: 267 | return optimal.cpu().numpy() 268 | 269 | -------------------------------------------------------------------------------- /trajectory/datasets/sequence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | import numpy as np 9 | import torch 10 | import tqdm 11 | import pdb 12 | 13 | from os.path import join 14 | import trajectory.utils as utils 15 | from trajectory.utils.arrays import to_torch 16 | 17 | from .d4rl import load_environment, qlearning_dataset_with_timeouts 18 | from .preprocessing import dataset_preprocess_functions 19 | 20 | def segment(observations, terminals, max_path_length): 21 | """ 22 | segment `observations` into trajectories according to `terminals` 23 | """ 24 | assert len(observations) == len(terminals) 25 | observation_dim = observations.shape[1] 26 | 27 | trajectories = [[]] 28 | for obs, term in zip(observations, terminals): 29 | trajectories[-1].append(obs) 30 | if term.squeeze(): 31 | trajectories.append([]) 32 | 33 | if len(trajectories[-1]) == 0: 34 | trajectories = trajectories[:-1] 35 | 36 | ## list of arrays because trajectories lengths will be different 37 | trajectories = [np.stack(traj, axis=0) for traj in trajectories] 38 | 39 | n_trajectories = len(trajectories) 40 | path_lengths = [len(traj) for traj in trajectories] 41 | 42 | ## pad trajectories to be of equal length 43 | trajectories_pad = np.zeros((n_trajectories, max_path_length, observation_dim), dtype=trajectories[0].dtype) 44 | early_termination = np.zeros((n_trajectories, max_path_length), dtype=bool) 45 | for i, traj in enumerate(trajectories): 46 | path_length = path_lengths[i] 47 | trajectories_pad[i,:path_length] = traj 48 | early_termination[i,path_length:] = 1 49 | 50 | return trajectories_pad, early_termination, path_lengths 51 | 52 | class SequenceDataset(torch.utils.data.Dataset): 53 | def __init__(self, env, sequence_length=250, step=10, discount=0.99, max_path_length=1000, 54 | penalty=None, device='cuda:0', normalize_raw=True, normalize_reward=True, train_portion=1.0, disable_goal=False, iql_critic=None, mix_rate=1.0, log_base="/tmp"): 55 | print(f'[ datasets/sequence ] Sequence length: {sequence_length} | Step: {step} | Max path length: {max_path_length}') 56 | env_name = env 57 | self.env = env = load_environment(env) if type(env) is str else env 58 | self.sequence_length = sequence_length 59 | self.step = step 60 | self.max_path_length = max_path_length 61 | self.device = device 62 | self.disable_goal = disable_goal 63 | 64 | print(f'[ datasets/sequence ] Loading...', end=' ', flush=True) 65 | if 'MineRL' in env.name: 66 | raise ValueError() 67 | dataset = qlearning_dataset_with_timeouts(env.unwrapped, terminate_on_end=True, disable_goal=disable_goal) 68 | # dataset = qlearning_dataset_with_timeouts(env, dataset=None, terminate_on_end=False) 69 | print('✓') 70 | 71 | preprocess_fn = dataset_preprocess_functions.get(env.name) 72 | if preprocess_fn: 73 | print(f'[ datasets/sequence ] Modifying environment') 74 | dataset = preprocess_fn(dataset) 75 | ## 76 | 77 | observations = dataset['observations'] 78 | actions = dataset['actions'].astype(np.float32) 79 | rewards = dataset['rewards'].astype(np.float32) 80 | terminals = dataset['terminals'] 81 | realterminals = dataset['realterminals'] 82 | 83 | self.observation_dim = observations.shape[1] 84 | self.action_dim = actions.shape[1] 85 | 86 | self.normalized_raw = normalize_raw 87 | self.normalize_reward = normalize_reward 88 | self.obs_mean, self.obs_std = observations.mean(axis=0, keepdims=True), observations.std(axis=0, keepdims=True) 89 | self.act_mean, self.act_std = actions.mean(axis=0, keepdims=True), actions.std(axis=0, keepdims=True) 90 | self.reward_mean, self.reward_std = rewards.mean(), rewards.std() 91 | 92 | if normalize_raw: 93 | observations = (observations-self.obs_mean) / self.obs_std 94 | actions = (actions-self.act_mean) / self.act_std 95 | 96 | self.observations_raw = observations 97 | self.actions_raw = actions 98 | self.joined_raw = np.concatenate([observations, actions], axis=-1, dtype=np.float32) 99 | self.rewards_raw = rewards 100 | self.terminals_raw = terminals 101 | 102 | ## terminal penalty 103 | if penalty is not None: 104 | terminal_mask = realterminals.squeeze() 105 | self.rewards_raw[terminal_mask] = penalty 106 | 107 | ## segment 108 | print(f'[ datasets/sequence ] Segmenting...', end=' ', flush=True) 109 | self.joined_segmented, self.termination_flags, self.path_lengths = segment(self.joined_raw, terminals, max_path_length) 110 | self.rewards_segmented, *_ = segment(self.rewards_raw, terminals, max_path_length) 111 | print('✓') 112 | 113 | self.discount = discount 114 | self.discounts = (discount ** np.arange(self.max_path_length))[:,None] 115 | 116 | ## [ n_paths x max_path_length x 1 ] 117 | self.values_segmented = np.zeros(self.rewards_segmented.shape, dtype=np.float32) 118 | self.validation_episodes = 0 119 | 120 | for t in range(max_path_length): 121 | ## [ n_paths x 1 ] 122 | V = (self.rewards_segmented[:,t+1:] * self.discounts[:-t-1]).sum(axis=1) 123 | self.values_segmented[:,t] = V 124 | 125 | ## add (r, V) to `joined` 126 | values_raw = self.values_segmented.squeeze(axis=-1).reshape(-1) 127 | values_mask = ~self.termination_flags.reshape(-1) 128 | self.values_raw = values_raw[values_mask, None] 129 | 130 | if normalize_raw and normalize_reward: 131 | self.value_mean, self.value_std = self.values_raw.mean(), self.values_raw.std() 132 | self.values_raw = (self.values_raw-self.value_mean) / self.value_std 133 | self.rewards_raw = (self.rewards_raw - self.reward_mean) / self.reward_std 134 | 135 | self.values_segmented = (self.values_segmented-self.value_mean) / self.value_std 136 | self.rewards_segmented = (self.rewards_segmented - self.reward_mean) / self.reward_std 137 | else: 138 | self.value_mean, self.value_std = np.array(0), np.array(1) 139 | 140 | self.joined_raw = np.concatenate([self.joined_raw, self.rewards_raw, self.values_raw], axis=-1) 141 | self.joined_segmented = np.concatenate([self.joined_segmented, self.rewards_segmented, self.values_segmented], axis=-1) 142 | 143 | self.train_portion = train_portion 144 | self.test_portion = 1.0 - train_portion 145 | ## get valid indices 146 | indices = [] 147 | test_indices = [] 148 | for path_ind, length in enumerate(self.path_lengths): 149 | end = length - 1 150 | split = int(end * self.train_portion) 151 | for i in range(end): 152 | if i < split: 153 | indices.append((path_ind, i, i+sequence_length)) 154 | else: 155 | test_indices.append((path_ind, i, i+sequence_length)) 156 | 157 | self.indices = np.array(indices) 158 | self.test_indices = np.array(test_indices) 159 | self.joined_dim = self.joined_raw.shape[1] 160 | 161 | ## pad trajectories 162 | n_trajectories, _, joined_dim = self.joined_segmented.shape 163 | self.joined_segmented = np.concatenate([ 164 | self.joined_segmented, 165 | np.zeros((n_trajectories, sequence_length-1, joined_dim), dtype=np.float32), 166 | ], axis=1) 167 | 168 | #self.joined_segmented_tensor = torch.tensor(self.joined_segmented, device=device) 169 | self.termination_flags = np.concatenate([ 170 | self.termination_flags, 171 | np.ones((n_trajectories, sequence_length-1), dtype=bool), 172 | ], axis=1) 173 | 174 | def denormalize(self, states, actions, rewards, values): 175 | states = states*self.obs_std + self.obs_mean 176 | actions = actions*self.act_std + self.act_mean 177 | rewards = rewards*self.reward_std + self.reward_mean 178 | values = values*self.value_std + self.value_mean 179 | return states, actions, rewards, values 180 | 181 | def normalize_joined_single(self, joined): 182 | joined_std = np.concatenate([self.obs_std[0], self.act_std[0], self.reward_std[None], self.value_std[None]]) 183 | joined_mean = np.concatenate([self.obs_mean[0], self.act_mean[0], self.reward_mean[None], self.value_mean[None]]) 184 | return (joined-joined_mean) / joined_std 185 | 186 | def denormalize_joined(self, joined): 187 | states = joined[:,:self.observation_dim] 188 | actions = joined[:,self.observation_dim:self.observation_dim+self.action_dim] 189 | rewards = joined[:,-3, None] 190 | values = joined[:,-2, None] 191 | results = self.denormalize(states, actions, rewards, values) 192 | return np.concatenate(results+(joined[:, -1, None],), axis=-1) 193 | 194 | 195 | def normalize_states(self, states): 196 | if torch.is_tensor(states): 197 | obs_std = torch.Tensor(self.obs_std).to(states.device) 198 | obs_mean = torch.Tensor(self.obs_mean).to(states.device) 199 | else: 200 | obs_std = np.squeeze(np.array(self.obs_std)) 201 | obs_mean = np.squeeze(np.array(self.obs_mean)) 202 | states = (states - obs_mean) / obs_std 203 | return states 204 | 205 | def denormalize_observations(self, observations): 206 | return self.denormalize_states(observations) 207 | 208 | 209 | def normalize_observations(self, observations): 210 | return self.normalize_states(observations) 211 | 212 | def denormalize_states(self, states): 213 | if torch.is_tensor(states): 214 | act_std = torch.Tensor(self.obs_std).to(states.device) 215 | act_mean = torch.Tensor(self.obs_mean).to(states.device) 216 | else: 217 | act_std = np.squeeze(np.array(self.obs_std)) 218 | act_mean = np.squeeze(np.array(self.obs_mean)) 219 | states = states * act_std + act_mean 220 | return states 221 | 222 | def denormalize_actions(self, actions): 223 | if torch.is_tensor(actions): 224 | act_std = torch.Tensor(self.act_std).to(actions.device) 225 | act_mean = torch.Tensor(self.act_mean).to(actions.device) 226 | else: 227 | act_std = np.squeeze(np.array(self.act_std)) 228 | act_mean = np.squeeze(np.array(self.act_mean)) 229 | actions = actions*act_std + act_mean 230 | return actions 231 | 232 | def normalize_actions(self, actions): 233 | if torch.is_tensor(actions): 234 | act_std = torch.Tensor(self.act_std).to(actions.device) 235 | act_mean = torch.Tensor(self.act_mean).to(actions.device) 236 | else: 237 | act_std = np.squeeze(np.array(self.act_std)) 238 | act_mean = np.squeeze(np.array(self.act_mean)) 239 | actions = (actions - act_mean) / act_std 240 | return actions 241 | 242 | def denormalize_rewards(self, rewards): 243 | if (not self.normalized_raw) or (not self.normalize_reward): 244 | return rewards 245 | if torch.is_tensor(rewards): 246 | reward_std = torch.Tensor([self.reward_std]).to(rewards.device) 247 | reward_mean = torch.Tensor([self.reward_mean]).to(rewards.device) 248 | else: 249 | reward_std = np.array([self.reward_std]) 250 | reward_mean = np.array([self.reward_mean]) 251 | rewards = rewards*reward_std + reward_mean 252 | return rewards 253 | 254 | def denormalize_values(self, values): 255 | if (not self.normalized_raw) or (not self.normalize_reward): 256 | return values 257 | if torch.is_tensor(values): 258 | value_std = torch.Tensor([self.value_std]).to(values.device) 259 | value_mean = torch.Tensor([self.value_mean]).to(values.device) 260 | else: 261 | value_std = np.array([self.value_std]) 262 | value_mean = np.array([self.value_mean]) 263 | values = values*value_std + value_mean 264 | return values 265 | 266 | def __len__(self): 267 | return len(self.indices) 268 | 269 | def __getitem__(self, idx): 270 | path_ind, start_ind, end_ind = self.indices[idx] 271 | 272 | joined = self.joined_segmented[path_ind, start_ind:end_ind:self.step] 273 | 274 | ## don't compute loss for parts of the prediction that extend 275 | ## beyond the max path length 276 | traj_inds = torch.arange(start_ind, end_ind, self.step) 277 | mask = torch.ones(joined.shape[:-1], dtype=torch.bool) 278 | mask[traj_inds > self.max_path_length - self.step] = 0 279 | terminal = 1-torch.cumprod(~torch.tensor(self.termination_flags[path_ind, start_ind:end_ind:self.step, None]), 280 | dim=1) 281 | 282 | terminal = terminal.float() 283 | # mask out the terminal state 284 | mask = mask[:, None].float()*(1-terminal) 285 | return joined, mask, terminal 286 | 287 | def get_test(self): 288 | Xs = [] 289 | Ys = [] 290 | masks = [] 291 | terminals = [] 292 | for path_ind, start_ind, end_ind in self.test_indices: 293 | joined = self.joined_segmented[path_ind, start_ind:end_ind:self.step] 294 | 295 | ## don't compute loss for parts of the prediction that extend 296 | ## beyond the max path length 297 | traj_inds = torch.arange(start_ind, end_ind, self.step) 298 | mask = torch.ones(joined.shape, dtype=torch.bool) 299 | mask[traj_inds > self.max_path_length - self.step] = 0 300 | terminal = 1 - torch.cumprod( 301 | ~torch.tensor(self.termination_flags[path_ind, start_ind:end_ind:self.step, None]), 302 | dim=1) 303 | 304 | ## flatten everything 305 | X = joined[:-1] 306 | Y = joined[1:] 307 | mask = mask[:-1] 308 | terminal = terminal[:-1] 309 | Xs.append(torch.tensor(X)) 310 | Ys.append(torch.tensor(Y)) 311 | masks.append(torch.tensor(mask)) 312 | terminals.append(torch.tensor(terminal)) 313 | return torch.stack(Xs), torch.stack(Ys), torch.stack(masks), torch.stack(terminals) 314 | 315 | def one_hot(a, num_classes): 316 | return np.squeeze(np.eye(num_classes, dtype=np.uint8)[a.reshape(-1)]) -------------------------------------------------------------------------------- /tasks/dmcontrol.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os.path as osp 8 | import numpy as np 9 | import tree 10 | import mujoco 11 | 12 | from typing import Any, Callable, Dict, Optional, Text, Tuple 13 | from dm_env import TimeStep 14 | from dm_control import composer 15 | from dm_control.locomotion.arenas import floors 16 | from dm_control.locomotion.mocap import cmu_mocap_data 17 | from dm_control.locomotion.mocap import loader 18 | from dm_control.locomotion.tasks.reference_pose import tracking 19 | from dm_control.locomotion.tasks.reference_pose import utils 20 | from dm_control.locomotion.walkers import cmu_humanoid 21 | from dm_control.locomotion.walkers import initializers 22 | from dm_control.suite.wrappers import action_noise 23 | from gym import core 24 | from gym import spaces 25 | from gym import core, spaces 26 | from dm_control import suite 27 | from dm_env import specs 28 | import numpy as np 29 | 30 | 31 | class StandInitializer(initializers.WalkerInitializer): 32 | def __init__(self, mode='random'): 33 | ref_path = cmu_mocap_data.get_path_for_cmu(version='2020') 34 | mocap_loader = loader.HDF5TrajectoryLoader(ref_path) 35 | self.mode = mode 36 | if mode=='fixed': 37 | clip_ids = ['CMU_040_12'] 38 | elif mode=='random': 39 | clip_ids = ['CMU_002_01', 'CMU_009_01', 'CMU_010_04', 'CMU_013_11', 'CMU_014_06', 'CMU_041_02', 40 | 'CMU_046_01', 'CMU_075_01', 'CMU_083_18', 'CMU_105_53', 'CMU_143_41', 'CMU_049_07',] 41 | else: 42 | raise NotImplementedError() 43 | 44 | self._stand_features = [] 45 | for clip_id in clip_ids: 46 | trajectory = mocap_loader.get_trajectory(clip_id) 47 | clip_reference_features = trajectory.as_dict() 48 | clip_reference_features = tracking._strip_reference_prefix(clip_reference_features, 'walker/') 49 | self._stand_features.append(tree.map_structure(lambda x: x, clip_reference_features)) 50 | 51 | def initialize_pose(self, physics, walker, random_state): 52 | clip_index = random_state.randint(0, len(self._stand_features)) 53 | if self.mode=='fixed': 54 | index = random_state.randint(0, 100) 55 | elif self.mode=='random': 56 | index = random_state.randint(0, 10) 57 | random_features = tree.map_structure(lambda x: x[index], self._stand_features[clip_index]) 58 | utils.set_walker_from_features(physics, walker, random_features) 59 | 60 | # Add gaussain noise to the current velocities. 61 | if self.mode=='random': 62 | velocity, angular_velocity = walker.get_velocity(physics) 63 | walker.set_velocity( 64 | physics, 65 | velocity=random_state.normal(0, 0.1, size=3) + velocity, 66 | angular_velocity=random_state.normal(0, 0.1, size=3)+angular_velocity) 67 | 68 | mujoco.mj_kinematics(physics.model.ptr, physics.data.ptr) 69 | 70 | 71 | def _spec_to_box(spec, dtype): 72 | def extract_min_max(s): 73 | assert s.dtype == np.float64 or s.dtype == np.float32 74 | dim = np.int(np.prod(s.shape)) 75 | if type(s) == specs.Array: 76 | bound = np.inf * np.ones(dim, dtype=np.float32) 77 | return -bound, bound 78 | elif type(s) == specs.BoundedArray: 79 | zeros = np.zeros(dim, dtype=np.float32) 80 | return s.minimum + zeros, s.maximum + zeros 81 | 82 | mins, maxs = [], [] 83 | for s in spec: 84 | mn, mx = extract_min_max(s) 85 | mins.append(mn) 86 | maxs.append(mx) 87 | low = np.concatenate(mins, axis=0).astype(dtype) 88 | high = np.concatenate(maxs, axis=0).astype(dtype) 89 | assert low.shape == high.shape 90 | return spaces.Box(low, high, dtype=dtype) 91 | 92 | 93 | def _flatten_obs(obs): 94 | obs_pieces = [] 95 | for v in obs.values(): 96 | flat = np.array([v]) if np.isscalar(v) else v.ravel() 97 | obs_pieces.append(flat) 98 | return np.concatenate(obs_pieces, axis=0) 99 | 100 | class DMCWrapper(core.Env): 101 | """ 102 | from https://github.com/denisyarats/dmc2gym/blob/master/dmc2gym/wrappers.py 103 | """ 104 | def __init__( 105 | self, 106 | domain_name, 107 | task_name, 108 | task_kwargs=None, 109 | from_pixels=False, 110 | height=84, 111 | width=84, 112 | camera_id=0, 113 | frame_skip=1, 114 | environment_kwargs=None, 115 | channels_first=True 116 | ): 117 | self._from_pixels = from_pixels 118 | self._height = height 119 | self._width = width 120 | self._camera_id = camera_id 121 | self._frame_skip = frame_skip 122 | self._channels_first = channels_first 123 | 124 | # create task 125 | self._env = suite.load( 126 | domain_name=domain_name, 127 | task_name=task_name, 128 | task_kwargs=task_kwargs, 129 | environment_kwargs=environment_kwargs 130 | ) 131 | 132 | # true and normalized action spaces 133 | self._true_action_space = _spec_to_box([self._env.action_spec()], np.float32) 134 | 135 | # create observation space 136 | if from_pixels: 137 | shape = [3, height, width] if channels_first else [height, width, 3] 138 | self._observation_space = spaces.Box( 139 | low=0, high=255, shape=shape, dtype=np.uint8 140 | ) 141 | else: 142 | self._observation_space = _spec_to_box( 143 | self._env.observation_spec().values(), 144 | np.float64 145 | ) 146 | 147 | self._state_space = _spec_to_box( 148 | self._env.observation_spec().values(), 149 | np.float64 150 | ) 151 | 152 | self.current_state = None 153 | 154 | def __getattr__(self, name): 155 | return getattr(self._env, name) 156 | 157 | @property 158 | def dm_env(self): 159 | return self._env 160 | 161 | @property 162 | def observation_space(self): 163 | return self._observation_space 164 | 165 | @property 166 | def state_space(self): 167 | return self._state_space 168 | 169 | @property 170 | def action_space(self): 171 | return self._true_action_space 172 | 173 | @property 174 | def reward_range(self): 175 | return 0, self._frame_skip 176 | 177 | def seed(self, seed): 178 | self._true_action_space.seed(seed) 179 | self._observation_space.seed(seed) 180 | 181 | def step(self, action): 182 | assert self._true_action_space.contains(action) 183 | reward = 0 184 | extra = {'internal_state': self._env.physics.get_state().copy()} 185 | 186 | for _ in range(self._frame_skip): 187 | time_step = self._env.step(action) 188 | reward += time_step.reward or 0 189 | done = time_step.last() 190 | if done: 191 | break 192 | obs = self.get_observation(time_step) 193 | self.current_state = _flatten_obs(time_step.observation) 194 | extra['discount'] = time_step.discount 195 | return obs, reward, done, extra 196 | 197 | def reset(self): 198 | time_step = self._env.reset() 199 | return time_step 200 | 201 | def get_observation(self, time_step: TimeStep) -> Dict[str, np.ndarray]: 202 | dm_obs = time_step.observation 203 | obs = _flatten_obs(dm_obs) 204 | return obs 205 | 206 | 207 | def render(self, mode='rgb_array', height=None, width=None, camera_id=0): 208 | assert mode == 'rgb_array', 'only support rgb_array mode, given %s' % mode 209 | height = height or self._height 210 | width = width or self._width 211 | camera_id = camera_id or self._camera_id 212 | return self._env.physics.render( 213 | height=height, width=width, camera_id=camera_id 214 | ) 215 | 216 | class CMUHumanoidGymWrapper(core.Env): 217 | """ 218 | Wraps the dm_control environment and task into a Gym env. The task assumes 219 | the presence of a CMU position-controlled humanoid. 220 | Adapted from: 221 | https://github.com/denisyarats/dmc2gym/blob/master/dmc2gym/wrappers.py 222 | """ 223 | 224 | metadata = {"render.modes": ["rgb_array"], "videos.frames_per_second": 30} 225 | 226 | def __init__( 227 | self, 228 | task_type: Callable[..., composer.Task], 229 | task_kwargs: Optional[Dict[str, Any]] = None, 230 | environment_kwargs: Optional[Dict[str, Any]] = None, 231 | act_noise: float = 0., 232 | arena_size: Tuple[float, float] = (8., 8.), 233 | 234 | # for rendering 235 | width: int = 640, 236 | height: int = 480, 237 | camera_id: int = 3 238 | ): 239 | """ 240 | task_kwargs: passed to the task constructor 241 | environment_kwargs: passed to composer.Environment constructor 242 | """ 243 | task_kwargs = task_kwargs or dict() 244 | environment_kwargs = environment_kwargs or dict() 245 | 246 | # create task 247 | self._env = self._create_env( 248 | task_type, 249 | task_kwargs, 250 | environment_kwargs, 251 | act_noise=act_noise, 252 | arena_size=arena_size 253 | ) 254 | self._original_rng_state = self._env.random_state.get_state() 255 | 256 | # Set observation and actions spaces 257 | self._observation_space = self._create_observation_space() 258 | action_spec = self._env.action_spec() 259 | dtype = np.float32 260 | self._action_space = spaces.Box( 261 | low=action_spec.minimum.astype(dtype), 262 | high=action_spec.maximum.astype(dtype), 263 | shape=action_spec.shape, 264 | dtype=dtype 265 | ) 266 | 267 | # set seed 268 | self.seed() 269 | 270 | self._height = height 271 | self._width = width 272 | self._camera_id = camera_id 273 | 274 | @staticmethod 275 | def make_env_constructor(task_type: Callable[..., composer.Task]): 276 | return lambda *args, **kwargs: CMUHumanoidGymWrapper(task_type, *args, **kwargs) 277 | 278 | def __getattr__(self, name: str) -> Any: 279 | return getattr(self._env, name) 280 | 281 | @property 282 | def dm_env(self) -> composer.Environment: 283 | return self._env 284 | 285 | @property 286 | def observation_space(self) -> spaces.Dict: 287 | return self._observation_space 288 | 289 | @property 290 | def action_space(self) -> spaces.Box: 291 | return self._action_space 292 | 293 | @property 294 | def np_random(self): 295 | return self._env.random_state 296 | 297 | def seed(self, seed: Optional[int] = None): 298 | if seed: 299 | srng = np.random.RandomState(seed=seed) 300 | self._env.random_state.set_state(srng.get_state()) 301 | else: 302 | self._env.random_state.set_state(self._original_rng_state) 303 | return self._env.random_state.get_state()[1] 304 | 305 | def _create_env( 306 | self, 307 | task_type, 308 | task_kwargs, 309 | environment_kwargs, 310 | act_noise=0., 311 | arena_size=(8., 8.) 312 | ) -> composer.Environment: 313 | walker = self._get_walker() 314 | arena = self._get_arena(arena_size) 315 | task = task_type( 316 | walker, 317 | arena, 318 | **task_kwargs 319 | ) 320 | env = composer.Environment( 321 | task=task, 322 | **environment_kwargs 323 | ) 324 | task.random = env.random_state # for action noise 325 | if act_noise > 0.: 326 | env = action_noise.Wrapper(env, scale=act_noise / 2) 327 | 328 | return env 329 | 330 | def _get_walker(self): 331 | directory = osp.dirname(osp.abspath(__file__)) 332 | initializer = StandInitializer() 333 | return cmu_humanoid.CMUHumanoidPositionControlledV2020(initializer=initializer) 334 | 335 | def _get_arena(self, arena_size): 336 | return floors.Floor(arena_size) 337 | 338 | def _create_observation_space(self) -> spaces.Dict: 339 | obs_spaces = dict() 340 | for k, v in self._env.observation_spec().items(): 341 | if v.dtype == np.float64 and np.prod(v.shape) > 0: 342 | if np.prod(v.shape) > 0: 343 | obs_spaces[k] = spaces.Box( 344 | -np.infty, 345 | np.infty, 346 | shape=(np.prod(v.shape),), 347 | dtype=np.float32 348 | ) 349 | elif v.dtype == np.uint8: 350 | tmp = v.generate_value() 351 | obs_spaces[k] = spaces.Box( 352 | v.minimum.item(), 353 | v.maximum.item(), 354 | shape=tmp.shape, 355 | dtype=np.uint8 356 | ) 357 | return spaces.Dict(obs_spaces) 358 | 359 | def get_observation(self, time_step: TimeStep) -> Dict[str, np.ndarray]: 360 | dm_obs = time_step.observation 361 | obs = dict() 362 | for k in self.observation_space.spaces: 363 | if self.observation_space[k].dtype == np.uint8: # image 364 | obs[k] = dm_obs[k].squeeze() 365 | else: 366 | obs[k] = dm_obs[k].ravel().astype(self.observation_space[k].dtype) 367 | return obs 368 | 369 | def step(self, action: np.ndarray) -> Tuple[Dict[str, np.ndarray], float, bool, Dict[str, Any]]: 370 | time_step = self._env.step(action) 371 | reward = time_step.reward or 0. 372 | done = time_step.last() 373 | obs = self.get_observation(time_step) 374 | info = dict( 375 | internal_state=self._env.physics.get_state().copy(), 376 | discount=time_step.discount 377 | ) 378 | return obs, reward, done, info 379 | 380 | def reset(self) -> Dict[str, np.ndarray]: 381 | time_step = self._env.reset() 382 | return self.get_observation(time_step) 383 | 384 | def render( 385 | self, 386 | mode: Text = 'rgb_array', 387 | height: Optional[int] = None, 388 | width: Optional[int] = None, 389 | camera_id: Optional[int] = None 390 | ) -> np.ndarray: 391 | assert mode == 'rgb_array', "This wrapper only supports rgb_array mode, given %s" % mode 392 | height = height or self._height 393 | width = width or self._width 394 | camera_id = camera_id or self._camera_id 395 | return self._env.physics.render(height=height, width=width, camera_id=camera_id) -------------------------------------------------------------------------------- /trajectory/tfds/tfds/mocapact/mocapact.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import List 8 | import dataclasses 9 | import os 10 | import pickle 11 | import tensorflow_datasets as tfds 12 | import tensorflow as tf 13 | import h5py 14 | import numpy as np 15 | 16 | from trajectory.tfds import mocap_utils 17 | 18 | _DESCRIPTION = """ 19 | MoCapAct: A Multi-Task Dataset for Simulated Humanoid Control 20 | """ 21 | _CITATION = """ 22 | @inproceedings{wagener2022mocapact, 23 | title={{MoCapAct: A Multi-Task Dataset for Simulated Humanoid Control}}, 24 | author={Wagener, Nolan and Kolobov, Andrey and Frujeri, Felipe Vieira and Loynd, Ricky and Cheng, Ching-An and Hausknecht, Matthew}, 25 | booktitle={Neural Information Processing Systems Datasets and Benchmarks Track}, 26 | year={2022} 27 | } 28 | """ 29 | 30 | _OBSERVABLE_SHAPES = { 31 | "walker/actuator_activation": 56, 32 | "walker/appendages_pos": 15, 33 | "walker/body_height": 1, 34 | "walker/end_effectors_pos": 12, 35 | "walker/gyro_anticlockwise_spin": 1, 36 | "walker/gyro_backward_roll": 1, 37 | "walker/gyro_control": 3, 38 | "walker/gyro_rightward_roll": 1, 39 | "walker/head_height": 1, 40 | "walker/joints_pos": 56, 41 | "walker/joints_vel": 56, 42 | "walker/joints_vel_control": 56, 43 | "walker/orientation": 9, 44 | "walker/position": 3, 45 | "walker/reference_appendages_pos": 75, 46 | "walker/reference_ego_bodies_quats": 620, 47 | "walker/reference_rel_bodies_pos_global": 465, 48 | "walker/reference_rel_bodies_pos_local": 465, 49 | "walker/reference_rel_bodies_quats": 620, 50 | "walker/reference_rel_joints": 280, 51 | "walker/reference_rel_root_pos_local": 15, 52 | "walker/reference_rel_root_quat": 20, 53 | "walker/sensors_accelerometer": 3, 54 | "walker/sensors_gyro": 3, 55 | "walker/sensors_torque": 6, 56 | "walker/sensors_touch": 10, 57 | "walker/sensors_velocimeter": 3, 58 | "walker/time_in_clip": 1, 59 | "walker/torso_xvel": 1, 60 | "walker/torso_yvel": 1, 61 | "walker/veloc_forward": 1, 62 | "walker/veloc_strafe": 1, 63 | "walker/veloc_up": 1, 64 | "walker/velocimeter_control": 3, 65 | "walker/world_zaxis": 3, 66 | } 67 | 68 | # TODO(yl): Shared the processing with the HDF5 datasource. 69 | def convert_to_rlds_format(episode): 70 | metadata = {k: v for k, v in episode.items() if k != "steps"} 71 | 72 | def _pad_last(nest): 73 | return tf.nest.map_structure( 74 | lambda x: np.concatenate([x, np.zeros_like(x[-1:])], axis=0), nest 75 | ) 76 | 77 | steps = { 78 | "observation": episode["steps"]["observation"], 79 | "action": _pad_last(episode["steps"]["action"]), 80 | "mean_action": _pad_last(episode["steps"]["mean_action"]), 81 | "reward": _pad_last(episode["steps"]["reward"]), 82 | "value": _pad_last(episode["steps"]["value"]), 83 | # Additional fields required by the RLDS dataset 84 | "is_first": np.concatenate( 85 | [ 86 | tf.ones(1, dtype=bool), 87 | tf.zeros_like(episode["steps"]["reward"], bool), 88 | ], 89 | axis=0, 90 | ), 91 | "is_last": np.concatenate( 92 | [ 93 | np.zeros_like(episode["steps"]["reward"], bool), 94 | np.ones(1, dtype=bool), 95 | ], 96 | axis=0, 97 | ), 98 | "is_terminal": np.concatenate( 99 | [ 100 | np.zeros_like(episode["steps"]["reward"], bool), 101 | np.expand_dims(episode["early_termination"], 0), 102 | ], 103 | axis=0, 104 | ), 105 | } 106 | 107 | return { 108 | "steps": steps, 109 | **metadata, 110 | } 111 | 112 | 113 | # TODO(yl): Shared the processing with the HDF5 datasource. 114 | def generate_episodes(path, observable_keys): 115 | reader = mocap_utils.MocapActTrajectoryReader(path) 116 | observable_indices = reader.observable_indices() 117 | # Using raw access to the HDF5 file is faster? 118 | # perhaps using the reader causes some GIL issues. 119 | h5_file = reader.h5_file 120 | n_rsi_rollouts = reader.n_rsi_rollouts() 121 | n_start_rollouts = reader.n_start_rollouts() 122 | snippet_names = reader.snippet_group_names() 123 | 124 | for snippet_name in snippet_names: 125 | rsi_metrics = _read_group(h5_file[f"{snippet_name}/rsi_metrics"]) 126 | start_metrics = _read_group(h5_file[f"{snippet_name}/start_metrics"]) 127 | early_terminations = h5_file[f"{snippet_name}/early_termination"][:] 128 | for episode_id in range(n_rsi_rollouts + n_start_rollouts): 129 | key = f"{snippet_name}/{episode_id}" 130 | # format: disable 131 | if episode_id < n_rsi_rollouts: 132 | stats = { 133 | "episode_return": rsi_metrics["episode_returns"][episode_id], 134 | "norm_episode_return": rsi_metrics["norm_episode_returns"][ 135 | episode_id 136 | ], 137 | "episode_length": rsi_metrics["episode_lengths"][episode_id], 138 | "norm_episode_length": rsi_metrics["norm_episode_lengths"][ 139 | episode_id 140 | ], 141 | "early_termination": early_terminations[episode_id], 142 | } 143 | else: 144 | i = episode_id - n_rsi_rollouts 145 | stats = { 146 | "episode_return": start_metrics["episode_returns"][i], 147 | "norm_episode_return": start_metrics["norm_episode_returns"][i], 148 | "episode_length": start_metrics["episode_lengths"][i], 149 | "norm_episode_length": start_metrics["norm_episode_lengths"][i], 150 | "early_termination": early_terminations[episode_id], 151 | } 152 | # format: enable 153 | mean_actions = h5_file[f"{key}/mean_actions"][:] 154 | actions = h5_file[f"{key}/actions"][:] 155 | flat_observations = h5_file[f"{key}/observations/proprioceptive"][:] 156 | observations = dict() 157 | for observable_name in observable_keys: 158 | idx = observable_indices[observable_name] 159 | observations[observable_name] = flat_observations[..., idx] 160 | rewards = h5_file[f"{key}/rewards"][:] 161 | values = h5_file[f"{key}/values"][:] 162 | yield key, { 163 | "steps": { 164 | "observation": observations, 165 | "action": actions, 166 | "mean_action": mean_actions, 167 | "reward": rewards, 168 | "value": values, 169 | }, 170 | "episode_id": key, 171 | **stats, 172 | } 173 | 174 | h5_file.close() 175 | 176 | 177 | def _float_feature(shape, dtype, encoding=tfds.features.Encoding.ZLIB): 178 | return tfds.features.Tensor(shape=shape, dtype=dtype, encoding=encoding) 179 | 180 | 181 | @dataclasses.dataclass 182 | class MocapactBuilderConfig(tfds.core.BuilderConfig): 183 | """Configuration of the dataset generation process.""" 184 | 185 | # Prefix in the download directory of MoCapAct 186 | # See https://github.com/microsoft/MoCapAct/blob/main/mocapact/download_dataset.py 187 | # Valid values are small/large 188 | prefix: str = "small" 189 | # Used for filtering observables used in the dataset (to reduce file size). 190 | observables: List[str] = mocap_utils.CMU_HUMANOID_OBSERVABLES 191 | 192 | 193 | def _read_metrics(observable_indices, path: str): 194 | metrics_npz = np.load(path, allow_pickle=True) 195 | # Split proprio normalization stats based on observable_keys 196 | obs_mean = {} 197 | obs_std = {} 198 | metrics = {} 199 | 200 | for observable_name, indices in observable_indices.items(): 201 | obs_mean[observable_name] = metrics_npz["proprio_mean"][..., indices] 202 | obs_std[observable_name] = np.sqrt(metrics_npz["proprio_var"][..., indices]) 203 | 204 | metrics["proprio_mean"] = obs_mean 205 | metrics["proprio_std"] = obs_std 206 | 207 | metrics["act_mean"] = metrics_npz["act_mean"] 208 | metrics["act_std"] = np.sqrt(metrics_npz["act_var"]) + 1e-4 209 | metrics["mean_act_mean"] = metrics_npz["mean_act_mean"] 210 | metrics["mean_act_std"] = np.sqrt(metrics_npz["mean_act_var"]) 211 | 212 | metrics["values"] = metrics_npz["values"].item() 213 | metrics["advantages"] = metrics_npz["advantages"].item() 214 | metrics["snippet_returns"] = metrics_npz["snippet_returns"].item() 215 | metrics_npz.close() 216 | return metrics 217 | 218 | 219 | class MocapactMetadata(tfds.core.Metadata, dict): 220 | """MocapAct metrics saved as metadata""" 221 | 222 | def save_metadata(self, data_dir): 223 | """Save the metadata.""" 224 | if "metrics" in self.keys(): 225 | metrics_path = os.path.join(data_dir, "dataset_metrics.npz") 226 | with open(metrics_path, 'wb') as f: 227 | pickle.dump(self["metrics"], f) 228 | 229 | def load_metadata(self, data_dir): 230 | """Restore the metadata.""" 231 | self.clear() 232 | metrics_path = os.path.join(data_dir, "dataset_metrics.npz") 233 | if os.path.exists(metrics_path): 234 | with open(metrics_path, 'rb') as f: 235 | metrics = pickle.load(f) 236 | self.update({"metrics": metrics}) 237 | 238 | 239 | class Mocapact(tfds.core.GeneratorBasedBuilder): 240 | """DatasetBuilder for mocapact dataset.""" 241 | 242 | VERSION = tfds.core.Version("1.0.0") 243 | RELEASE_NOTES = { 244 | "1.0.0": "Initial release.", 245 | } 246 | 247 | # NOTE: BUILDER_CONFIGS are used to specify the options 248 | # for building different processed datasets. 249 | # For now, only building the small dataset 250 | # with reduced CMU_HUMANOID observables is included. 251 | # To support additional configurations, append to the list by 252 | # 1. Change `prefix` to 'large' to build the large dataset. 253 | # 2. Change `observables` to all observables if you want to include more 254 | # observations. 255 | # 3. Give a unique name to the config so that we can use tfds.load 256 | # for the different configuration. 257 | BUILDER_CONFIGS = [ 258 | MocapactBuilderConfig( 259 | prefix="small", 260 | observables=mocap_utils.CMU_HUMANOID_OBSERVABLES, 261 | name="small_cmu_observable", 262 | ), 263 | MocapactBuilderConfig( 264 | prefix="large", 265 | observables=mocap_utils.CMU_HUMANOID_OBSERVABLES, 266 | name="large_cmu_observable", 267 | ), 268 | ] 269 | 270 | MANUAL_DOWNLOAD_INSTRUCTIONS = """ 271 | Download and put the MocapAct dataset in manual_dir 272 | """ 273 | 274 | def _info(self) -> tfds.core.DatasetInfo: 275 | """Returns the dataset metadata.""" 276 | # TODO(mocapact): Specifies the tfds.core.DatasetInfo object 277 | observable_keys = self.builder_config.observables 278 | 279 | # Compress observation and actions on disk. 280 | # There are probably no disk saving as these arrays are not 281 | # easily compressible. However, this will encode the tensors 282 | # in bytes which takes up less space compared to the default. 283 | observation_spec = tfds.features.FeaturesDict( 284 | { 285 | key: _float_feature(shape=(_OBSERVABLE_SHAPES[key],), dtype=tf.float32) 286 | for key in observable_keys 287 | } 288 | ) 289 | action_spec = tfds.features.Tensor( 290 | shape=(56,), dtype=tf.float32, encoding=tfds.features.Encoding.ZLIB 291 | ) 292 | steps_dict = { 293 | "observation": observation_spec, 294 | "action": action_spec, 295 | "mean_action": action_spec, 296 | "reward": tfds.features.Tensor(shape=(), dtype=tf.float32), 297 | "value": tfds.features.Tensor(shape=(), dtype=tf.float32), 298 | # Below are fields required by RLDS. 299 | "is_first": tfds.features.Tensor(shape=(), dtype=tf.bool), 300 | "is_last": tfds.features.Tensor(shape=(), dtype=tf.bool), 301 | "is_terminal": tfds.features.Tensor(shape=(), dtype=tf.bool), 302 | } 303 | return tfds.core.DatasetInfo( 304 | builder=self, 305 | description=_DESCRIPTION, 306 | features=tfds.features.FeaturesDict( 307 | { 308 | # These are the features of your dataset like images, labels ... 309 | "steps": tfds.features.Dataset(steps_dict), 310 | "episode_id": tfds.features.Tensor(shape=(), dtype=tf.string), 311 | "episode_return": tf.float32, 312 | "norm_episode_return": tf.float32, 313 | "episode_length": tf.int64, 314 | "norm_episode_length": tf.int64, 315 | "early_termination": tf.bool, 316 | } 317 | ), 318 | # If there's a common (input, target) tuple from the 319 | # features, specify them here. They'll be used if 320 | # `as_supervised=True` in `builder.as_dataset`. 321 | supervised_keys=None, # Set to `None` to disable 322 | homepage="https://microsoft.github.io/MoCapAct", 323 | citation=_CITATION, 324 | metadata=MocapactMetadata(), 325 | ) 326 | 327 | def _split_generators(self, dl_manager: tfds.download.DownloadManager): 328 | """Returns SplitGenerators.""" 329 | path = dl_manager.manual_dir 330 | metrics_path = path / "dataset_metrics.npz" 331 | # save metadata 332 | # TODO(yl): This is probably a bad way to save the metrics. 333 | # Figure out a better way to split the metrics. 334 | reference_filename = list(path.glob("*.hdf5"))[0] 335 | reader = mocap_utils.MocapActTrajectoryReader(reference_filename) 336 | observable_indices = reader.observable_indices() 337 | metrics = _read_metrics(observable_indices, metrics_path) 338 | self.info.metadata["metrics"] = metrics 339 | # TODO(mocapact): Returns the Dict[split names, Iterator[Key, Example]] 340 | return { 341 | "train": self._generate_examples(path), 342 | } 343 | 344 | def _generate_examples(self, path): 345 | """Yields examples.""" 346 | # TODO(mocapact): Yields (key, example) tuples from the dataset 347 | for path in path.glob("*.hdf5"): 348 | for key, episode in generate_episodes( 349 | path, self.builder_config.observables 350 | ): 351 | yield key, convert_to_rlds_format(episode) 352 | 353 | 354 | def _read_group(dataset_file): 355 | dataset_dict = {} 356 | for k in _get_dataset_keys(dataset_file): 357 | try: 358 | # first try loading as an array 359 | dataset_dict[k] = dataset_file[k][:] 360 | except ValueError: # try loading as a scalar 361 | dataset_dict[k] = dataset_file[k][()] 362 | return dataset_dict 363 | 364 | 365 | def _get_dataset_keys(h5file): 366 | """Gets the keys present in the D4RL dataset.""" 367 | keys = [] 368 | 369 | def visitor(name, item): 370 | if isinstance(item, h5py.Dataset): 371 | keys.append(name) 372 | 373 | h5file.visititems(visitor) 374 | return keys 375 | --------------------------------------------------------------------------------