├── .gitignore ├── metamotivo ├── misc │ ├── __init__.py │ └── zbuffer.py ├── buffers │ ├── __init__.py │ └── buffers.py ├── wrappers │ ├── __init__.py │ └── humenvbench.py ├── fb │ ├── __init__.py │ ├── huggingface.py │ ├── model.py │ └── agent.py ├── fb_cpr │ ├── __init__.py │ ├── huggingface.py │ ├── model.py │ └── agent.py ├── __init__.py └── nn_models.py ├── pyproject.toml ├── CONTRIBUTING.md ├── CODE_OF_CONDUCT.md ├── examples ├── README.md ├── fb_train_dmc.py └── fbcpr_train_humenv.py ├── tutorial_benchmark.ipynb ├── README.md ├── tutorial.ipynb └── LICENSE /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc -------------------------------------------------------------------------------- /metamotivo/misc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metamotivo/buffers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metamotivo/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metamotivo/fb/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the CC BY-NC 4.0 license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .model import FBModel 7 | from .model import Config as FBModelConfig 8 | from .agent import Config as FBAgentConfig 9 | from .agent import FBAgent 10 | -------------------------------------------------------------------------------- /metamotivo/fb_cpr/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the CC BY-NC 4.0 license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from .model import FBcprModel 7 | from .model import Config as FBcprModelConfig 8 | from .agent import Config as FBcprAgentConfig 9 | from .agent import FBcprAgent 10 | -------------------------------------------------------------------------------- /metamotivo/fb/huggingface.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the CC BY-NC 4.0 license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from huggingface_hub import PyTorchModelHubMixin 7 | from .model import FBModel as BaseFBModel 8 | 9 | 10 | class FBModel( 11 | BaseFBModel, 12 | PyTorchModelHubMixin, 13 | library_name="metamotivo", 14 | tags=["facebook", "meta", "pytorch"], 15 | license="cc-by-nc-4.0", 16 | repo_url="https://github.com/facebookresearch/metamotivo", 17 | docs_url="https://metamotivo.metademolab.com/", 18 | ): ... 19 | -------------------------------------------------------------------------------- /metamotivo/fb_cpr/huggingface.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the CC BY-NC 4.0 license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from huggingface_hub import PyTorchModelHubMixin 7 | from .model import FBcprModel as BaseFBcprModel 8 | 9 | class FBcprModel( 10 | BaseFBcprModel, 11 | PyTorchModelHubMixin, 12 | library_name="metamotivo", 13 | tags=["facebook", "meta", "pytorch"], 14 | license="cc-by-nc-4.0", 15 | repo_url="https://github.com/facebookresearch/metamotivo", 16 | docs_url="https://metamotivo.metademolab.com/", 17 | ): ... 18 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "metamotivo" 3 | dynamic = ["version"] 4 | description = "Inference and Training of FB-CPR" 5 | readme = "README.md" 6 | requires-python = ">=3.10" 7 | dependencies = [ 8 | "safetensors>=0.4.5", 9 | "torch>=2.3", 10 | "tyro>=0.9.0", 11 | "wandb>=0.19.0" 12 | ] 13 | 14 | [project.urls] 15 | Homepage = "https://github.com/facebookresearch/metamotivo" 16 | 17 | [project.optional-dependencies] 18 | humenv=[ 19 | "humenv[bench] @ git+https://github.com/facebookresearch/humenv.git", 20 | ] 21 | huggingface = [ 22 | "huggingface-hub[cli,torch]>=0.26.3", 23 | ] 24 | all = [ 25 | "huggingface-hub[cli,torch]>=0.26.3", 26 | "humenv[bench] @ git+https://github.com/facebookresearch/humenv.git", 27 | "tensordict>=0.6.0" 28 | ] 29 | 30 | [tool.setuptools.dynamic] 31 | version = {attr = "metamotivo.__version__"} 32 | 33 | [tool.setuptools.packages.find] 34 | include = ["metamotivo*"] 35 | 36 | [tool.ruff] 37 | line-length = 140 38 | -------------------------------------------------------------------------------- /metamotivo/misc/zbuffer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the CC BY-NC 4.0 license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | import numpy as np 8 | from typing import Union 9 | 10 | 11 | class ZBuffer: 12 | def __init__(self, capacity: int, dim: int, device: Union[torch.device, str], dtype: torch.dtype = torch.float32): 13 | self._storage = torch.zeros((capacity, dim), device=device, dtype=dtype) 14 | self._idx = 0 15 | self._is_full = False 16 | self.capacity = capacity 17 | self.device = device 18 | 19 | def __len__(self) -> int: 20 | return self.capacity if self._is_full else self._idx 21 | 22 | def empty(self) -> bool: 23 | return self._idx == 0 and not self._is_full 24 | 25 | def add(self, data: torch.Tensor) -> None: 26 | if self._idx + data.shape[0] >= self.capacity: 27 | diff = self.capacity - self._idx 28 | self._storage[self._idx : self._idx + data.shape[0]] = data[:diff] 29 | self._storage[: data.shape[0] - diff] = data[diff:] 30 | self._is_full = True 31 | else: 32 | self._storage[self._idx : self._idx + data.shape[0]] = data 33 | self._idx = (self._idx + data.shape[0]) % self.capacity 34 | 35 | def sample(self, num, device=None) -> torch.Tensor: 36 | idx = np.random.randint(0, len(self), size=num) 37 | return self._storage[idx].clone().to(device if device is not None else self.device) 38 | -------------------------------------------------------------------------------- /metamotivo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. 
and affiliates.
2 | #
3 | # This source code is licensed under the CC BY-NC 4.0 license found in the
4 | # LICENSE file in the root directory of this source tree.
5 | 
6 | import dataclasses
7 | from typing import Any, Dict
8 | from collections.abc import Mapping
9 | from pathlib import Path
10 | from typing import Any
11 | import torch
12 | import json
13 | import safetensors.torch
14 | 
15 | 
16 | def load_model(path: str, device: str | None, cls: Any):
17 |     model_dir = Path(path)
18 |     with (model_dir / "config.json").open() as f:
19 |         loaded_config = json.load(f)
20 |     if device is not None:
21 |         loaded_config["device"] = device
22 |     loaded_agent = cls(**loaded_config)
23 |     safetensors.torch.load_model(loaded_agent, model_dir / "model.safetensors", device=device)
24 |     # loaded_agent.load_state_dict(
25 |     #     torch.load(model_dir / "model.pt", weights_only=True, map_location=device)
26 |     # )
27 |     return loaded_agent
28 | 
29 | 
30 | def dict_to_config(source: Mapping, target: Any):
31 |     target_fields = {field.name for field in dataclasses.fields(target)}
32 |     for field in target_fields:
33 |         if field in source.keys() and dataclasses.is_dataclass(getattr(target, field)):
34 |             dict_to_config(source[field], getattr(target, field))
35 |         elif field in source.keys():
36 |             setattr(target, field, source[field])
37 |         else:
38 |             print(f"[WARNING] field {field} not found in source config")
39 | 
40 | 
41 | def config_from_dict(source: Dict, config_class: Any):
42 |     target = config_class()
43 |     dict_to_config(source, target)
44 |     return target
45 | 
46 | 
47 | __version__ = "0.1.2"
48 | 
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to Meta Motivo
2 | We want to make contributing to this project as easy and transparent as possible.
3 | 
4 | ## Installing the library
5 | Install the library as suggested in the README.
6 | 
7 | ## Formatting your code
8 | **Type annotation**
9 | 
10 | Meta Motivo is not strongly-typed, i.e. we do not enforce type hints, nor do we check that the ones that are present are valid. We rely on type hints purely for documentation purposes. Although this might change in the future, there is currently no need for type hints to be enforced.
11 | 
12 | **Formatting**
13 | 
14 | Before your PR is ready, you'll probably want your code to be checked. This can be done easily by running
15 | ```
16 | ruff format
17 | ```
18 | and
19 | ```
20 | ruff check
21 | ```
22 | from within the Meta Motivo cloned directory.
23 | 
24 | ## Pull Requests
25 | We actively welcome your pull requests.
26 | 
27 | 1. Fork the repo and create your branch from `main`.
28 | 2. If you've added code that should be tested, add tests.
29 | 3. If you've changed APIs, update the documentation.
30 | 4. Ruff format and check the code.
31 | 5. Ensure the test suite passes.
32 | 6. If you haven't already, complete the Contributor License Agreement ("CLA").
33 | 
34 | When submitting a PR, we encourage you to link it to the related issue (if any) and add some tags to it.
35 | 
36 | ## Contributor License Agreement ("CLA")
37 | In order to accept your pull request, we need you to submit a CLA. You only need
38 | to do this once to work on any of Facebook's open source projects.
39 | 
40 | Complete your CLA here: 
41 | 
42 | ## Issues
43 | We use GitHub issues to track public bugs.
Please ensure your description is 44 | clear and has sufficient instructions to be able to reproduce the issue. 45 | 46 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 47 | disclosure of security bugs. In those cases, please go through the process 48 | outlined on that page and do not file a public issue. 49 | 50 | ## License 51 | By contributing to Meta Motivo, you agree that your contributions will be licensed 52 | under the LICENSE file in the root directory of this source tree. 53 | -------------------------------------------------------------------------------- /metamotivo/fb_cpr/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the CC BY-NC 4.0 license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import dataclasses 7 | from ..fb.model import FBModel 8 | from ..fb.model import Config as FBConfig 9 | from ..fb.model import ArchiConfig as FBArchiConfig 10 | from ..nn_models import build_forward, build_discriminator 11 | from .. import config_from_dict 12 | import torch 13 | import copy 14 | 15 | 16 | @dataclasses.dataclass 17 | class CriticArchiConfig: 18 | hidden_dim: int = 1024 19 | model: str = "simple" # {'simple', 'residual'} 20 | hidden_layers: int = 1 21 | embedding_layers: int = 2 22 | num_parallel: int = 2 23 | ensemble_mode: str = "batch" # {'batch', 'seq', 'vmap'} 24 | 25 | 26 | @dataclasses.dataclass 27 | class DiscriminatorArchiConfig: 28 | hidden_dim: int = 1024 29 | hidden_layers: int = 2 30 | 31 | 32 | @dataclasses.dataclass 33 | class ArchiConfig(FBArchiConfig): 34 | critic: CriticArchiConfig = dataclasses.field(default_factory=CriticArchiConfig) 35 | discriminator: DiscriminatorArchiConfig = dataclasses.field(default_factory=DiscriminatorArchiConfig) 36 | 37 | 38 | @dataclasses.dataclass 39 | class Config(FBConfig): 40 | archi: ArchiConfig = dataclasses.field(default_factory=ArchiConfig) 41 | 42 | 43 | class FBcprModel(FBModel): 44 | def __init__(self, **kwargs): 45 | super().__init__(**kwargs) 46 | self.cfg = config_from_dict(kwargs, Config) 47 | self._discriminator = build_discriminator(self.cfg.obs_dim, self.cfg.archi.z_dim, self.cfg.archi.discriminator) 48 | self._critic = build_forward(self.cfg.obs_dim, self.cfg.archi.z_dim, self.cfg.action_dim, self.cfg.archi.critic, output_dim=1) 49 | 50 | # make sure the model is in eval mode and never computes gradients 51 | self.train(False) 52 | self.requires_grad_(False) 53 | self.to(self.cfg.device) 54 | 55 | def _prepare_for_train(self) -> None: 56 | super()._prepare_for_train() 57 | self._target_critic = copy.deepcopy(self._critic) 58 | 59 | @torch.no_grad() 60 | def critic(self, obs: torch.Tensor, z: torch.Tensor, action: torch.Tensor): 61 | return self._critic(self._normalize(obs), z, action) 62 | 63 | @torch.no_grad() 64 | def discriminator(self, obs: torch.Tensor, z: torch.Tensor): 65 | return self._discriminator(self._normalize(obs), z) 66 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, 
body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 
72 | ## Attribution
73 | 
74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76 | 
77 | [homepage]: https://www.contributor-covenant.org
78 | 
79 | For answers to common questions about this code of conduct, see
80 | https://www.contributor-covenant.org/faq
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Examples
2 | 
3 | We provide a few examples of how to use the Meta Motivo repository.
4 | 
5 | ## FB: Offline training with ExoRL datasets
6 | 
7 | [ExoRL](https://github.com/denisyarats/exorl) has been widely used for training offline RL algorithms. We provide the code for training FB on standard domains such as `walker`, `cheetah`, `quadruped` and `pointmass`. We use the standard tasks in `dm_control`, but you can easily update the script to run the full set of tasks defined in `ExoRL` or in the paper [Fast Imitation via Behavior Foundation Models](https://openreview.net/forum?id=qnWtw3l0jb). We provide more details below.
8 | 
9 | To use the provided script, you can simply run from the terminal
10 | 
11 | ```bash
12 | python fb_train_dmc.py --domain_name walker --dataset_root 
13 | ```
14 | 
15 | The standard folder structure of ExoRL is `/datasets/${DOMAIN}/${ALGO}/buffer`, so we expect `dataset_root=/datasets`. Since the original creation of ExoRL, MuJoCo has seen many updates. To rerun all the actions and collect physics-consistent data, you may optionally replay the trajectories. We refer to [https://github.com/facebookresearch/mtm/tree/main/research/exorl](https://github.com/facebookresearch/mtm/tree/main/research/exorl) for this.
16 | 
17 | If you want to run auxiliary tasks and domains such as `walker_flip` or `pointmass`, we suggest downloading the files from [https://github.com/facebookresearch/offline_rl/tree/main/src/dmc_tasks](https://github.com/facebookresearch/offline_rl/tree/main/src/dmc_tasks) into `examples/dmc_tasks`. You can then simply modify `fb_train_dmc.py` as follows:
18 | 
19 | - add the import
20 | ```
21 | from dmc_tasks import dmc
22 | ```
23 | - add the new tasks
24 | ```
25 | ALL_TASKS = {
26 |     "walker": ["walk", "run", "stand", "flip", "spin"],
27 |     "cheetah": ["walk", "run", "walk_backward", "run_backward"],
28 |     "pointmass": ["reach_top_left", "reach_top_right", "reach_bottom_right", "reach_bottom_left", "loop", "square", "fast_slow"],
29 |     "quadruped": ["jump", "walk", "run", "stand"],
30 | }
31 | ```
32 | - use `dmc.make` for environment creation. For example, replace `suite.load(domain_name=self.cfg.domain_name,task_name=task,environment_kwargs={"flat_observation": True},)` with `dmc.make(f"{self.cfg.domain_name}_{task}")`.
33 | - This changes the way the observation is retrieved: from `time_step.observation["observations"]` to simply `time_step.observation`. Update the file accordingly.
34 | 
35 | 
36 | ## FB-CPR: Online training with HumEnv
37 | 
38 | We provide complete code for training FB-CPR as described in the paper [Zero-Shot Whole-Body Humanoid Control via Behavioral Foundation Models](https://ai.meta.com/research/publications/zero-shot-whole-body-humanoid-control-via-behavioral-foundation-models/).
39 | 
40 | **IMPORTANT!** We assume you have already preprocessed the AMASS motions as described [here](https://github.com/facebookresearch/humenv/tree/main/data_preparation).
In addition, we assume you also downloaded the `test_train_split` sub-folder. 41 | 42 | The script is setup with the S configuration (i.e., paper configuration) and can be run by simply calling 43 | 44 | ```bash 45 | python fbcpr_train_humenv.py --compile --motions test_train_split/large1_small1_train_0.1.txt --motions_root --prioritization 46 | ``` 47 | 48 | There are several parameters that can be changed to do evaluation more modular, checkpoint the models, etc. We refer to the code for more details. 49 | 50 | If you would like to train our largest model (the one deployed in the [demo](https://metamotivo.metademolab.com/)), replace the following line 51 | 52 | ``` 53 | model, hidden_dim, hidden_layers = "simple", 1024, 2 54 | ``` 55 | 56 | with 57 | 58 | ``` 59 | model, hidden_dim, hidden_layers = "residual", 2048, 12 60 | ``` 61 | 62 | NOTE: we recommend that you use compile=True on a A100 GPU or better, as otherwise training can be very slow. 63 | -------------------------------------------------------------------------------- /metamotivo/fb/model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the CC BY-NC 4.0 license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import math 7 | import dataclasses 8 | import numpy as np 9 | import torch 10 | from torch import nn 11 | import torch.nn.functional as F 12 | import copy 13 | from pathlib import Path 14 | from safetensors.torch import save_model as safetensors_save_model 15 | import json 16 | 17 | from ..nn_models import build_backward, build_forward, build_actor, eval_mode 18 | from .. import config_from_dict, load_model 19 | 20 | 21 | @dataclasses.dataclass 22 | class ActorArchiConfig: 23 | hidden_dim: int = 1024 24 | model: str = "simple" # {'simple', 'residual'} 25 | hidden_layers: int = 1 26 | embedding_layers: int = 2 27 | 28 | 29 | @dataclasses.dataclass 30 | class ForwardArchiConfig: 31 | hidden_dim: int = 1024 32 | model: str = "simple" # {'simple', 'residual'} 33 | hidden_layers: int = 1 34 | embedding_layers: int = 2 35 | num_parallel: int = 2 36 | ensemble_mode: str = "batch" # {'batch', 'seq', 'vmap'} 37 | 38 | 39 | @dataclasses.dataclass 40 | class BackwardArchiConfig: 41 | hidden_dim: int = 256 42 | hidden_layers: int = 2 43 | norm: bool = True 44 | 45 | 46 | @dataclasses.dataclass 47 | class ArchiConfig: 48 | z_dim: int = 100 49 | norm_z: bool = True 50 | f: ForwardArchiConfig = dataclasses.field(default_factory=ForwardArchiConfig) 51 | b: BackwardArchiConfig = dataclasses.field(default_factory=BackwardArchiConfig) 52 | actor: ActorArchiConfig = dataclasses.field(default_factory=ActorArchiConfig) 53 | 54 | 55 | @dataclasses.dataclass 56 | class Config: 57 | obs_dim: int = -1 58 | action_dim: int = -1 59 | device: str = "cpu" 60 | archi: ArchiConfig = dataclasses.field(default_factory=ArchiConfig) 61 | inference_batch_size: int = 500_000 62 | seq_length: int = 1 63 | actor_std: float = 0.2 64 | norm_obs: bool = True 65 | 66 | class FBModel(nn.Module): 67 | def __init__(self, **kwargs): 68 | super().__init__() 69 | self.cfg = config_from_dict(kwargs, Config) 70 | obs_dim, action_dim = self.cfg.obs_dim, self.cfg.action_dim 71 | arch = self.cfg.archi 72 | 73 | # create networks 74 | self._backward_map = build_backward(obs_dim, arch.z_dim, arch.b) 75 | self._forward_map = build_forward(obs_dim, arch.z_dim, action_dim, arch.f) 76 | self._actor = 
build_actor(obs_dim, arch.z_dim, action_dim, arch.actor) 77 | self._obs_normalizer = nn.BatchNorm1d(obs_dim, affine=False, momentum=0.01) if self.cfg.norm_obs else nn.Identity() 78 | 79 | # make sure the model is in eval mode and never computes gradients 80 | self.train(False) 81 | self.requires_grad_(False) 82 | self.to(self.cfg.device) 83 | 84 | def _prepare_for_train(self) -> None: 85 | # create TARGET networks 86 | self._target_backward_map = copy.deepcopy(self._backward_map) 87 | self._target_forward_map = copy.deepcopy(self._forward_map) 88 | 89 | def to(self, *args, **kwargs): 90 | device, _, _, _ = torch._C._nn._parse_to(*args, **kwargs) 91 | if device is not None: 92 | self.cfg.device = device.type # type: ignore 93 | return super().to(*args, **kwargs) 94 | 95 | @classmethod 96 | def load(cls, path: str, device: str | None = None): 97 | return load_model(path, device, cls=cls) 98 | 99 | def save(self, output_folder: str) -> None: 100 | output_folder = Path(output_folder) 101 | output_folder.mkdir(exist_ok=True) 102 | safetensors_save_model(self, output_folder / "model.safetensors") 103 | with (output_folder / "config.json").open("w+") as f: 104 | json.dump(dataclasses.asdict(self.cfg), f, indent=4) 105 | 106 | def _normalize(self, obs: torch.Tensor): 107 | with torch.no_grad(), eval_mode(self._obs_normalizer): 108 | return self._obs_normalizer(obs) 109 | 110 | @torch.no_grad() 111 | def backward_map(self, obs: torch.Tensor): 112 | return self._backward_map(self._normalize(obs)) 113 | 114 | @torch.no_grad() 115 | def forward_map(self, obs: torch.Tensor, z: torch.Tensor, action: torch.Tensor): 116 | return self._forward_map(self._normalize(obs), z, action) 117 | 118 | @torch.no_grad() 119 | def actor(self, obs: torch.Tensor, z: torch.Tensor, std: float): 120 | return self._actor(self._normalize(obs), z, std) 121 | 122 | def sample_z(self, size: int, device: str = "cpu") -> torch.Tensor: 123 | z = torch.randn((size, self.cfg.archi.z_dim), dtype=torch.float32, device=device) 124 | return self.project_z(z) 125 | 126 | def project_z(self, z): 127 | if self.cfg.archi.norm_z: 128 | z = math.sqrt(z.shape[-1]) * F.normalize(z, dim=-1) 129 | return z 130 | 131 | def act(self, obs: torch.Tensor, z: torch.Tensor, mean: bool = True) -> torch.Tensor: 132 | dist = self.actor(obs, z, self.cfg.actor_std) 133 | if mean: 134 | return dist.mean 135 | return dist.sample() 136 | 137 | def reward_inference(self, next_obs: torch.Tensor, reward: torch.Tensor, weight: torch.Tensor | None = None) -> torch.Tensor: 138 | num_batches = int(np.ceil(next_obs.shape[0] / self.cfg.inference_batch_size)) 139 | z = 0 140 | wr = reward if weight is None else reward * weight 141 | for i in range(num_batches): 142 | start_idx, end_idx = i * self.cfg.inference_batch_size, (i + 1) * self.cfg.inference_batch_size 143 | B = self.backward_map(next_obs[start_idx:end_idx].to(self.cfg.device)) 144 | z += torch.matmul(wr[start_idx:end_idx].to(self.cfg.device).T, B) 145 | return self.project_z(z) 146 | 147 | def reward_wr_inference(self, next_obs: torch.Tensor, reward: torch.Tensor) -> torch.Tensor: 148 | return self.reward_inference(next_obs, reward, F.softmax(10 * reward, dim=0)) 149 | 150 | def goal_inference(self, next_obs: torch.Tensor) -> torch.Tensor: 151 | z = self.backward_map(next_obs) 152 | return self.project_z(z) 153 | 154 | def tracking_inference(self, next_obs: torch.Tensor) -> torch.Tensor: 155 | z = self.backward_map(next_obs) 156 | for step in range(z.shape[0]): 157 | end_idx = min(step + self.cfg.seq_length, 
z.shape[0]) 158 | z[step] = z[step:end_idx].mean(dim=0) 159 | return self.project_z(z) 160 | -------------------------------------------------------------------------------- /tutorial_benchmark.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "7e714762-bf94-474a-a4b0-154d8715c9a0", 6 | "metadata": {}, 7 | "source": [ 8 | "# Meta Motivo benchmarking using HumEnv\n", 9 | "\n", 10 | "This notebook shows how to evaluate a Meta Motivo model using the benchmark proposed in HumEnv. It assumes that motions for tracking and poses for goal reaching have been processed by following the [instructions](https://github.com/facebookresearch/humenv/tree/main/data_preparation) in HumEnv and are available in the folder `MOTIONS_BASE_PATH`." 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": null, 16 | "id": "89840e8c-ac6d-4fd1-880f-53e034b56f36", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from metamotivo.fb_cpr.huggingface import FBcprModel\n", 21 | "from metamotivo.wrappers.humenvbench import RewardWrapper, TrackingWrapper, GoalWrapper\n", 22 | "from metamotivo.buffers.buffers import DictBuffer\n", 23 | "from huggingface_hub import hf_hub_download\n", 24 | "import h5py\n", 25 | "import json\n", 26 | "import numpy as np\n", 27 | "from humenv import STANDARD_TASKS\n", 28 | "from humenv.bench import (\n", 29 | " RewardEvaluation,\n", 30 | " GoalEvaluation,\n", 31 | " TrackingEvaluation,\n", 32 | ")\n", 33 | "\n", 34 | "# paths where to find the output of HumEnv's data preparation scripts\n", 35 | "MOTIONS_BASE_PATH = \"humenv/data_preparation/humenv_amass\"\n", 36 | "MOTIONS_TRACKING = \"humenv/data_preparation/test_train_split/large1_small1_test_0.1.txt\"\n", 37 | "GOAL_POSES = \"humenv/data_preparation/goal_poses/goals.json\"\n", 38 | "\n", 39 | "# load the goal poses into a dictionary\n", 40 | "with open(GOAL_POSES, \"r\") as json_file:\n", 41 | " GOAL_DICT = json.load(json_file)\n", 42 | "GOAL_DICT = {k: np.array(v[\"observation\"]) for k,v in GOAL_DICT.items()}" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "id": "3ef8eeea-ee50-464d-9ecc-45d44fc7a71d", 48 | "metadata": {}, 49 | "source": [ 50 | "Load inference buffer." 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": null, 56 | "id": "7690ac63-ac66-474a-95a0-9b6ccb33f0da", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "buffer_path = hf_hub_download(\n", 61 | " repo_id=\"facebook/metamotivo-S-1\",\n", 62 | " filename=\"data/buffer_inference_500000.hdf5\",\n", 63 | " repo_type=\"model\",\n", 64 | " local_dir=\"metamotivo-S-1-datasets\",\n", 65 | " )\n", 66 | "hf = h5py.File(buffer_path, \"r\")\n", 67 | "data = {k: v[:] for k, v in hf.items()}\n", 68 | "buffer = DictBuffer(capacity=data[\"qpos\"].shape[0], device=\"cpu\")\n", 69 | "buffer.extend(data)" 70 | ] 71 | }, 72 | { 73 | "cell_type": "markdown", 74 | "id": "0d4fa4f5-74fd-45b1-b5b3-13d6368fb794", 75 | "metadata": {}, 76 | "source": [ 77 | "Load model and prepare it for inference." 
78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": null, 83 | "id": "e23b3e6c-eeb1-4fa8-bc77-f2ef04a3d495", 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "device = \"cpu\" # it is normally faster to evaluate on cpu since tracking is parallelized\n", 88 | "model = FBcprModel.from_pretrained(\"facebook/metamotivo-S-1\").to(device)\n", 89 | "model = RewardWrapper(\n", 90 | " model=model,\n", 91 | " inference_dataset=buffer,\n", 92 | " num_samples_per_inference=100_000,\n", 93 | " inference_function=\"reward_wr_inference\",\n", 94 | " max_workers=80,\n", 95 | " )\n", 96 | "model = GoalWrapper(model=model)\n", 97 | "model = TrackingWrapper(model=model)" 98 | ] 99 | }, 100 | { 101 | "cell_type": "markdown", 102 | "id": "279ca5af-54f2-4daa-9476-6c2ec00109a7", 103 | "metadata": {}, 104 | "source": [ 105 | "Humenv provides 3 evaluation protocols:\n", 106 | "- reward based,\n", 107 | "- goal based,\n", 108 | "- tracking" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "id": "cc2f13f3-ee0e-4925-b87c-25e11a2f79f8", 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [ 118 | "reward_eval = RewardEvaluation(\n", 119 | " tasks=STANDARD_TASKS, # all the 45 tasks used in the paper\n", 120 | " env_kwargs={\"state_init\": \"Fall\"},\n", 121 | " num_contexts=1,\n", 122 | " num_envs=50,\n", 123 | " num_episodes=100,\n", 124 | " )\n", 125 | "\n", 126 | "reward_metrics = reward_eval.run(agent=model)\n", 127 | "print(reward_metrics)\n", 128 | "\n", 129 | "r = np.array([m['reward'] for m in reward_metrics.values()])\n", 130 | "print(f\"reward averaged across {r.shape[0]} tasks: {r.mean()}\")" 131 | ] 132 | }, 133 | { 134 | "cell_type": "code", 135 | "execution_count": null, 136 | "id": "0e79882e-1e48-42b4-8a0d-4c1bdaf12148", 137 | "metadata": {}, 138 | "outputs": [], 139 | "source": [ 140 | "goal_eval = GoalEvaluation(\n", 141 | " goals=GOAL_DICT,\n", 142 | " env_kwargs={\"state_init\": \"Fall\"},\n", 143 | " num_contexts=1,\n", 144 | " num_envs=50,\n", 145 | " num_episodes=100,\n", 146 | ")\n", 147 | "\n", 148 | "goal_metrics = goal_eval.run(agent=model)\n", 149 | "print(goal_metrics)\n", 150 | "\n", 151 | "for k in ['success', 'proximity']:\n", 152 | " r = np.array([m[k] for m in goal_metrics.values()])\n", 153 | " print(f\"goal {k} averaged across {r.shape[0]} poses: {r.mean()}\")" 154 | ] 155 | }, 156 | { 157 | "cell_type": "code", 158 | "execution_count": null, 159 | "id": "c6ccae6b-510e-43c0-a92b-f0772c150ee5", 160 | "metadata": {}, 161 | "outputs": [], 162 | "source": [ 163 | "tracking_eval = TrackingEvaluation(\n", 164 | " motions=MOTIONS_TRACKING,\n", 165 | " motion_base_path=MOTIONS_BASE_PATH,\n", 166 | " env_kwargs={\"state_init\": \"Default\"},\n", 167 | " num_envs=50,\n", 168 | ")\n", 169 | "\n", 170 | "tracking_metrics = tracking_eval.run(agent=model)\n", 171 | "print(tracking_metrics)\n", 172 | "\n", 173 | "for k in ['success_phc_linf', 'emd']:\n", 174 | " r = np.array([m[k] for m in tracking_metrics.values()])\n", 175 | " print(f\"tracking {k} averaged across {r.shape[0]} motions: {r.mean()}\")" 176 | ] 177 | } 178 | ], 179 | "metadata": { 180 | "kernelspec": { 181 | "display_name": "Python 3 (ipykernel)", 182 | "language": "python", 183 | "name": "python3" 184 | }, 185 | "language_info": { 186 | "codemirror_mode": { 187 | "name": "ipython", 188 | "version": 3 189 | }, 190 | "file_extension": ".py", 191 | "mimetype": "text/x-python", 192 | "name": "python", 193 | "nbconvert_exporter": "python", 194 | 
"pygments_lexer": "ipython3", 195 | "version": "3.12.2" 196 | } 197 | }, 198 | "nbformat": 4, 199 | "nbformat_minor": 5 200 | } 201 | -------------------------------------------------------------------------------- /metamotivo/buffers/buffers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the CC BY-NC 4.0 license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from __future__ import annotations 7 | 8 | import dataclasses 9 | import functools 10 | import numbers 11 | from collections import defaultdict 12 | from collections.abc import Mapping 13 | from typing import Any, Dict, List, Union 14 | 15 | import numpy as np 16 | import torch 17 | 18 | Device = Union[str, torch.device] 19 | 20 | 21 | @functools.singledispatch 22 | def _to_torch(value: Any, device: Device | None = None) -> Any: 23 | raise Exception(f"No known conversion for type ({type(value)}) to PyTorch registered. Report as issue on github.") 24 | 25 | 26 | @_to_torch.register(numbers.Number) 27 | @_to_torch.register(np.ndarray) 28 | def _np_to_torch(value: np.ndarray, device: Device | None = None) -> torch.Tensor: 29 | tensor = torch.tensor(value) 30 | if device: 31 | return tensor.to(device=device) 32 | return tensor 33 | 34 | 35 | @_to_torch.register(torch.Tensor) 36 | def _torch_to_torch(value: np.ndarray, device: Device | None = None) -> torch.Tensor: 37 | tensor = value.clone().detach() 38 | if device: 39 | return tensor.to(device=device) 40 | return tensor 41 | 42 | 43 | @dataclasses.dataclass 44 | class DictBuffer: 45 | capacity: int 46 | device: str = "cpu" 47 | 48 | def __post_init__(self) -> None: 49 | self.storage = None 50 | self._idx = 0 51 | self._is_full = False 52 | 53 | def __len__(self) -> int: 54 | return self.capacity if self._is_full else self._idx 55 | 56 | def empty(self) -> bool: 57 | return len(self) == 0 58 | 59 | @torch.no_grad 60 | def extend(self, data: Dict) -> None: 61 | if self.storage is None: 62 | self.storage = {} 63 | initialize_storage(data, self.storage, self.capacity, self.device) 64 | self._idx = 0 65 | self._is_full = False 66 | # let's store a key for easy inspection 67 | self._non_nested_key = [k for k, v in self.storage.items() if not isinstance(v, Mapping)][0] 68 | 69 | def add_new_data(data, storage, expected_dim: int): 70 | for k, v in data.items(): 71 | if isinstance(v, Mapping): 72 | # If the value is a dictionary, recursively call the function 73 | add_new_data(v, storage=storage[k], expected_dim=expected_dim) 74 | else: 75 | if v.shape[0] != expected_dim: 76 | raise ValueError("We expect all keys to have the same dimension") 77 | end = self._idx + v.shape[0] 78 | if end >= self.capacity: 79 | # Wrap data 80 | diff = self.capacity - self._idx 81 | storage[k][self._idx :] = _to_torch(v[:diff], device=self.device) 82 | storage[k][: v.shape[0] - diff] = _to_torch(v[diff:], device=self.device) 83 | self._is_full = True 84 | else: 85 | storage[k][self._idx : end] = _to_torch(v, device=self.device) 86 | 87 | data_dim = data[self._non_nested_key].shape[0] 88 | add_new_data(data, self.storage, expected_dim=data_dim) 89 | self._idx = (self._idx + data_dim) % self.capacity 90 | 91 | @torch.no_grad 92 | def sample(self, batch_size) -> Dict[str, torch.Tensor]: 93 | self.ind = torch.randint(0, len(self), (batch_size,)) 94 | return extract_values(self.storage, self.ind) 95 | 96 | def get_full_buffer(self) -> Dict: 97 | 
if self._is_full: 98 | return self.storage 99 | else: 100 | return extract_values(self.storage, torch.arange(0, len(self))) 101 | 102 | 103 | def extract_values(d: Dict, idxs: List | torch.Tensor | np.ndarray) -> Dict: 104 | result = {} 105 | for k, v in d.items(): 106 | if isinstance(v, Mapping): 107 | result[k] = extract_values(v, idxs) 108 | else: 109 | result[k] = v[idxs] 110 | return result 111 | 112 | 113 | @dataclasses.dataclass 114 | class TrajectoryBuffer: 115 | capacity: int 116 | device: str = "cpu" 117 | seq_length: int = 1 118 | output_key_t: list[str] = dataclasses.field(default_factory=lambda: ["observation"]) 119 | output_key_tp1: list[str] = dataclasses.field(default_factory=lambda: ["observation"]) 120 | 121 | def __post_init__(self) -> None: 122 | self._is_full = False 123 | self.storage = None 124 | self._idx = 0 125 | self.priorities = None 126 | 127 | def __len__(self) -> int: 128 | return self.capacity if self._is_full else self._idx 129 | 130 | def empty(self) -> bool: 131 | return len(self) == 0 132 | 133 | @torch.no_grad 134 | def extend(self, data: List[dict]) -> None: 135 | if self.storage is None: 136 | self.storage = [None for _ in range(self.capacity)] 137 | self._idx = 0 138 | self._is_full = False 139 | self.priorities = torch.ones(self.capacity, device=self.device, dtype=torch.float32) / self.capacity 140 | 141 | def add(new_data): 142 | storage = {} 143 | for k, v in new_data.items(): 144 | if isinstance(v, Mapping): 145 | storage[k] = add(v) 146 | else: 147 | storage[k] = _to_torch(v, device=self.device) 148 | if len(storage[k].shape) == 1: 149 | storage[k] = storage[k].reshape(-1, 1) 150 | return storage 151 | 152 | for episode in data: 153 | self.storage[self._idx] = add(new_data=episode) 154 | self._idx += 1 155 | if self._idx >= self.capacity: 156 | self._is_full = True 157 | self._idx = self._idx % self.capacity 158 | 159 | @torch.no_grad 160 | def sample(self, batch_size: int = 1) -> Dict[str, torch.Tensor]: 161 | if batch_size < self.seq_length: 162 | raise ValueError( 163 | f"The batch-size must be bigger than the sequence length, got batch_size={batch_size} and seq_length={self.seq_length}." 164 | ) 165 | 166 | if batch_size % self.seq_length != 0: 167 | raise ValueError( 168 | f"The batch-size must be divisible by the sequence length, got batch_size={batch_size} and seq_length={self.seq_length}." 
169 | ) 170 | num_slices = batch_size // self.seq_length 171 | 172 | # self.ep_ind = torch.randint(0, len(self), (num_slices,)) 173 | self.ep_ind = torch.multinomial(self.priorities, num_slices, replacement=True) 174 | output = defaultdict(list) 175 | offset = 0 176 | if len(self.output_key_tp1) > 0: 177 | offset = 1 178 | output["next"] = defaultdict(list) 179 | for ep_idx in self.ep_ind: 180 | _ep = self.storage[ep_idx.item()] 181 | length = _ep[self.output_key_t[0]].shape[0] 182 | time_idx = torch.randint(0, length - self.seq_length - offset, (1,)) 183 | for k in self.output_key_t: 184 | output[k].append(_ep[k][time_idx : time_idx + self.seq_length]) 185 | for k in self.output_key_tp1: 186 | output["next"][k].append(_ep[k][time_idx + offset : time_idx + offset + self.seq_length]) 187 | 188 | return dict_cat(output) 189 | 190 | def update_priorities(self, priorities: torch.Tensor, idxs: torch.Tensor) -> None: 191 | self.priorities[idxs] = priorities 192 | self.priorities = self.priorities / torch.sum(self.priorities) 193 | 194 | 195 | def initialize_storage(data: Dict, storage: Dict, capacity: int, device: Device) -> None: 196 | def recursive_initialize(d, s): 197 | for k, v in d.items(): 198 | if isinstance(v, Mapping): 199 | s[k] = {} 200 | recursive_initialize(v, s[k]) 201 | else: 202 | s[k] = torch.zeros( 203 | (capacity, v.shape[1] if len(v.shape) == 2 else v.shape[0]), 204 | device=device, 205 | dtype=dtype_numpytotorch(v.dtype), 206 | ) 207 | 208 | recursive_initialize(data, storage) 209 | 210 | 211 | def dtype_numpytotorch(np_dtype: Any) -> torch.dtype: 212 | if isinstance(np_dtype, torch.dtype): 213 | return np_dtype 214 | if np_dtype == np.float16: 215 | return torch.float16 216 | elif np_dtype == np.float32: 217 | return torch.float32 218 | elif np_dtype == np.float64: 219 | return torch.float64 220 | elif np_dtype == np.int16: 221 | return torch.int16 222 | elif np_dtype == np.int32: 223 | return torch.int32 224 | elif np_dtype == np.int64: 225 | return torch.int64 226 | elif np_dtype == bool: 227 | return torch.bool 228 | elif np_dtype == np.uint8: 229 | return torch.uint8 230 | else: 231 | raise ValueError(f"Unknown type {np_dtype}") 232 | 233 | 234 | def dict_cat(d: Mapping) -> Dict[str, torch.Tensor]: 235 | res = {} 236 | for k, v in d.items(): 237 | if isinstance(v, Mapping): 238 | res[k] = dict_cat(v) 239 | else: 240 | res[k] = torch.cat(v, dim=0) 241 | return res 242 | -------------------------------------------------------------------------------- /metamotivo/wrappers/humenvbench.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the CC BY-NC 4.0 license found in the 4 | # LICENSE file in the root directory of this source tree. 
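# The wrappers below adapt a Meta Motivo model to the humenv.bench evaluation protocols:
# RewardWrapper adds task-conditioned context inference via reward relabeling, while
# GoalWrapper and TrackingWrapper expose goal and tracking context inference.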
5 | 6 | import copy 7 | import torch 8 | from typing import Any 9 | import numpy as np 10 | import mujoco 11 | from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor 12 | import functools 13 | import dataclasses 14 | from humenv import make_humenv 15 | from humenv.rewards import RewardFunction 16 | 17 | 18 | def get_next(field: str, data: Any): 19 | if "next" in data and field in data["next"]: 20 | return data["next"][field] 21 | elif f"next_{field}" in data: 22 | return data[f"next_{field}"] 23 | else: 24 | raise ValueError(f"No next of {field} found in data.") 25 | 26 | 27 | @dataclasses.dataclass(kw_only=True) 28 | class BaseHumEnvBenchWrapper: 29 | model: Any 30 | numpy_output: bool = True 31 | _dtype: torch.dtype = dataclasses.field(default_factory=lambda: torch.float32) 32 | 33 | def act( 34 | self, 35 | obs: torch.Tensor | np.ndarray, 36 | z: torch.Tensor | np.ndarray, 37 | mean: bool = True, 38 | ) -> torch.Tensor: 39 | obs = to_torch(obs, device=self.device, dtype=self._dtype) 40 | z = to_torch(z, device=self.device, dtype=self._dtype) 41 | if self.numpy_output: 42 | return self.unwrapped_model.act(obs, z, mean).cpu().detach().numpy() 43 | return self.unwrapped_model.act(obs, z, mean) 44 | 45 | @property 46 | def device(self) -> Any: 47 | # this returns the base torch.nn.module 48 | return self.unwrapped_model.cfg.device 49 | 50 | @property 51 | def unwrapped_model(self): 52 | # this is used to call the base instance of model 53 | if hasattr(self.model, "unwrapped_model"): 54 | return self.model.unwrapped_model 55 | else: 56 | return self.model 57 | 58 | def __getattr__(self, name): 59 | # Delegate to the wrapped instance 60 | return getattr(self.model, name) 61 | 62 | def __deepcopy__(self, memo): 63 | return type(self)(model=copy.deepcopy(self.model, memo), numpy_output=self.numpy_output, _dtype=copy.deepcopy(self._dtype)) 64 | 65 | def __getstate__(self): 66 | # Return a dictionary containing the state of the object 67 | return { 68 | "model": self.model, 69 | "numpy_output": self.numpy_output, 70 | "_dtype": self._dtype, 71 | } 72 | 73 | def __setstate__(self, state): 74 | # Restore the state of the object from the given dictionary 75 | self.model = state["model"] 76 | self.numpy_output = state["numpy_output"] 77 | self._dtype = state["_dtype"] 78 | 79 | 80 | @dataclasses.dataclass(kw_only=True) 81 | class RewardWrapper(BaseHumEnvBenchWrapper): 82 | inference_dataset: Any 83 | num_samples_per_inference: int 84 | inference_function: str 85 | max_workers: int 86 | process_executor: bool = False 87 | process_context: str = "spawn" 88 | 89 | def reward_inference(self, task: str, **kwargs) -> torch.Tensor: 90 | env, _ = make_humenv(task=task, **kwargs) 91 | if self.num_samples_per_inference < len(self.inference_dataset): 92 | data = self.inference_dataset.sample(self.num_samples_per_inference) 93 | else: 94 | data = self.inference_dataset.get_full_buffer() 95 | qpos = get_next("qpos", data) 96 | qvel = get_next("qvel", data) 97 | action = data["action"] 98 | if isinstance(qpos, torch.Tensor): 99 | qpos = qpos.cpu().detach().numpy() 100 | qvel = qvel.cpu().detach().numpy() 101 | action = action.cpu().detach().numpy() 102 | rewards = relabel( 103 | env, 104 | qpos, 105 | qvel, 106 | action, 107 | env.unwrapped.task, 108 | max_workers=self.max_workers, 109 | process_executor=self.process_executor, 110 | ) 111 | env.close() 112 | 113 | td = { 114 | "reward": torch.tensor(rewards, dtype=torch.float32, device=self.device), 115 | } 116 | if "B" in data: 117 | td["B_vect"] 
= data["B"] 118 | else: 119 | td["next_obs"] = get_next("observation", data) 120 | inference_fn = getattr(self.model, self.inference_function, None) 121 | ctxs = inference_fn(**td).reshape(1, -1) 122 | return ctxs 123 | 124 | def __deepcopy__(self, memo): 125 | # Create a new instance of the same type as self 126 | return type(self)( 127 | model=copy.deepcopy(self.model, memo), 128 | numpy_output=self.numpy_output, 129 | _dtype=copy.deepcopy(self._dtype), 130 | inference_dataset=copy.deepcopy(self.inference_dataset), 131 | num_samples_per_inference=self.num_samples_per_inference, 132 | inference_function=self.inference_function, 133 | max_workers=self.max_workers, 134 | process_executor=self.process_executor, 135 | process_context=self.process_context, 136 | ) 137 | 138 | def __getstate__(self): 139 | # Return a dictionary containing the state of the object 140 | return { 141 | "model": self.model, 142 | "numpy_output": self.numpy_output, 143 | "_dtype": self._dtype, 144 | "inference_dataset": self.inference_dataset, 145 | "num_samples_per_inference": self.num_samples_per_inference, 146 | "inference_function": self.inference_function, 147 | "max_workers": self.max_workers, 148 | "process_executor": self.process_executor, 149 | "process_context": self.process_context, 150 | } 151 | 152 | def __setstate__(self, state): 153 | # Restore the state of the object from the given dictionary 154 | self.model = state["model"] 155 | self.numpy_output = state["numpy_output"] 156 | self._dtype = state["_dtype"] 157 | self.inference_dataset = state["inference_dataset"] 158 | self.num_samples_per_inference = state["num_samples_per_inference"] 159 | self.inference_function = state["inference_function"] 160 | self.max_workers = state["max_workers"] 161 | self.process_executor = state["process_executor"] 162 | self.process_context = state["process_context"] 163 | 164 | 165 | @dataclasses.dataclass(kw_only=True) 166 | class GoalWrapper(BaseHumEnvBenchWrapper): 167 | def goal_inference(self, goal_pose: torch.Tensor) -> torch.Tensor: 168 | next_obs = to_torch(goal_pose, device=self.device, dtype=self._dtype) 169 | ctx = self.unwrapped_model.goal_inference(next_obs=next_obs).reshape(1, -1) 170 | return ctx 171 | 172 | 173 | @dataclasses.dataclass(kw_only=True) 174 | class TrackingWrapper(BaseHumEnvBenchWrapper): 175 | def tracking_inference(self, next_obs: torch.Tensor | np.ndarray) -> torch.Tensor: 176 | next_obs = to_torch(next_obs, device=self.device, dtype=self._dtype) 177 | ctx = self.unwrapped_model.tracking_inference(next_obs=next_obs) 178 | return ctx 179 | 180 | 181 | def to_torch(x: np.ndarray | torch.Tensor, device: torch.device | str, dtype: torch.dtype): 182 | if len(x.shape) == 1: 183 | # adding batch dimension 184 | x = x[None, ...] 
185 | if not isinstance(x, torch.Tensor): 186 | x = torch.tensor(x, device=device, dtype=dtype) 187 | else: 188 | x = x.to(dtype) 189 | return x 190 | 191 | 192 | def _relabel_worker( 193 | x, 194 | model: mujoco.MjModel, 195 | reward_fn: RewardFunction, 196 | ): 197 | qpos, qvel, action = x 198 | assert len(qpos.shape) > 1 199 | assert qvel.shape[0] == qpos.shape[0] 200 | assert qvel.shape[0] == action.shape[0] 201 | rewards = np.zeros((qpos.shape[0], 1)) 202 | for i in range(qpos.shape[0]): 203 | rewards[i] = reward_fn(model, qpos[i], qvel[i], action[i]) 204 | return rewards 205 | 206 | 207 | def relabel( 208 | env: Any, 209 | qpos: np.ndarray, 210 | qvel: np.ndarray, 211 | action: np.ndarray, 212 | reward_fn: RewardFunction, 213 | max_workers: int = 5, 214 | process_executor: bool = False, 215 | process_context: str = "spawn", 216 | ): 217 | chunk_size = int(np.ceil(qpos.shape[0] / max_workers)) 218 | args = [(qpos[i : i + chunk_size], qvel[i : i + chunk_size], action[i : i + chunk_size]) for i in range(0, qpos.shape[0], chunk_size)] 219 | if max_workers == 1: 220 | result = [_relabel_worker(args[0], model=env.unwrapped.model, reward_fn=reward_fn)] 221 | else: 222 | if process_executor: 223 | import multiprocessing 224 | 225 | with ProcessPoolExecutor( 226 | max_workers=max_workers, 227 | mp_context=multiprocessing.get_context(process_context), 228 | ) as exe: 229 | f = functools.partial(_relabel_worker, model=env.unwrapped.model, reward_fn=reward_fn) 230 | result = exe.map(f, args) 231 | else: 232 | with ThreadPoolExecutor(max_workers=max_workers) as exe: 233 | f = functools.partial(_relabel_worker, model=env.unwrapped.model, reward_fn=reward_fn) 234 | result = exe.map(f, args) 235 | 236 | tmp = [r for r in result] 237 | return np.concatenate(tmp) 238 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Meta Motivo 2 | **[Meta, FAIR](https://ai.facebook.com/research/)** 3 | 4 | 5 | # Overview 6 | This repository provides a PyTorch implementation and pre-trained models for Meta Motivo. For details see the paper [Zero-Shot Whole-Body Humanoid Control via Behavioral Foundation Models](https://metamotivo.metademolab.com/). 7 | 8 | ### Features 9 | 10 | - We provide [**6** pretrained FB-CPR](https://huggingface.co/collections/facebook/meta-motivo-6757761e8fd4a032466fd129) models for controlling the humanoid model defined in [HumEnv](https://github.com/facebookresearch/HumEnv/). 11 | - **Fully reproducible** scripts for evaluating the model in HumEnv. 12 | - **Fully reproducible** [FB-CPR training code in HumEnv](examples/fbcpr_train_humenv.py) for the full results in the paper, and [FB training code in DMC](examples/fb_train_dmc.py) for faster experimentation. 13 | 14 | # Installation 15 | 16 | The project is pip installable in your environment. 17 | 18 | ``` 19 | pip install "metamotivo[huggingface,humenv] @ git+https://github.com/facebookresearch/metamotivo.git" 20 | ``` 21 | 22 | It requires Python 3.10+. Optional dependencies include `humenv["bench"]` and `huggingface_hub` for testing/training and loading models from HuggingFace. 
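As a quick sanity check of the installation, you can import the package and print its version string (a minimal sketch; the `__version__` attribute is defined in `metamotivo/__init__.py`):

```python
import metamotivo

# the version string is defined at the bottom of metamotivo/__init__.py
print(metamotivo.__version__)
```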
23 | 
24 | 
25 | # Pretrained models
26 | 
27 | For reproducibility, we provide all the **5** models (**metamotivo-S-X**) we trained for producing the results in the [paper](https://openreview.net/forum?id=9sOR0nYLtz), where each model is trained using a different random seed. We also provide our largest and most performant model (**metamotivo-M-1**), which can also be interactively tested in our [demo](https://metamotivo.metademolab.com/).
28 | 
29 | | Model | # of params | Download |
30 | | :--- | :---: | :---: |
31 | | metamotivo-S-1 | 24.5M | [link](https://huggingface.co/facebook/metamotivo-S-1) |
32 | | metamotivo-S-2 | 24.5M | [link](https://huggingface.co/facebook/metamotivo-S-2) |
33 | | metamotivo-S-3 | 24.5M | [link](https://huggingface.co/facebook/metamotivo-S-3) |
34 | | metamotivo-S-4 | 24.5M | [link](https://huggingface.co/facebook/metamotivo-S-4) |
35 | | metamotivo-S-5 | 24.5M | [link](https://huggingface.co/facebook/metamotivo-S-5) |
36 | | metamotivo-M-1 | 288M | [link](https://huggingface.co/facebook/metamotivo-M-1) |
37 | 
38 | 
39 | # Quick start
40 | 
41 | Once the library is installed, you can easily create an FB-CPR agent and download a pre-trained model from the Hugging Face hub. Note that the model is an instance of `torch.nn.Module` and by default it is initialized in "inference" mode (no_grad and eval mode).
42 | 
43 | We provide some simple code snippets to demonstrate how to use the model below. For more detailed examples, see our tutorials on [interacting with the model](https://github.com/facebookresearch/metamotivo/blob/main/tutorial.ipynb), [running an evaluation](https://github.com/facebookresearch/metamotivo/blob/main/tutorial_benchmark.ipynb), and [training from scratch](https://github.com/facebookresearch/metamotivo/tree/main/examples).
44 | 
45 | ### Download the pre-trained models
46 | 
47 | The following code snippet shows how to instantiate the model.
48 | 
49 | ```python
50 | from metamotivo.fb_cpr.huggingface import FBcprModel
51 | 
52 | model = FBcprModel.from_pretrained("facebook/metamotivo-S-1")
53 | ```
54 | 
55 | ### Download the buffers
56 | For each model we provide:
57 | - The training buffer (that can be used for inference or offline training)
58 | - A small reward inference buffer (that contains the minimum amount of information for doing reward inference)
59 | 
60 | ```python
61 | from huggingface_hub import hf_hub_download
from metamotivo.buffers.buffers import DictBuffer
62 | import h5py
63 | 
64 | local_dir = "metamotivo-S-1-datasets"
65 | dataset = "buffer_inference_500000.hdf5" # a smaller buffer that can be used for reward inference
66 | # dataset = "buffer.hdf5" # the full training buffer of the model
67 | buffer_path = hf_hub_download(
68 |     repo_id="facebook/metamotivo-S-1",
69 |     filename=f"data/{dataset}",
70 |     repo_type="model",
71 |     local_dir=local_dir,
72 | )
73 | hf = h5py.File(buffer_path, "r")
74 | print(hf.keys())
75 | 
76 | # create a DictBuffer object that can be used for sampling
77 | data = {k: v[:] for k, v in hf.items()}
78 | buffer = DictBuffer(capacity=data["qpos"].shape[0], device="cpu")
79 | buffer.extend(data)
80 | ```
81 | 
82 | ### The FB-CPR model
83 | The FB-CPR model contains several networks:
84 | - forward net
85 | - backward net
86 | - critic net
87 | - discriminator net
88 | - actor net
89 | 
90 | We provide functions for evaluating these networks:
91 | 
92 | ```python
93 | def backward_map(self, obs: torch.Tensor) -> torch.Tensor: ...
94 | def forward_map(self, obs: torch.Tensor, z: torch.Tensor, action: torch.Tensor) -> torch.Tensor: ... 95 | def actor(self, obs: torch.Tensor, z: torch.Tensor, std: float) -> torch.Tensor: ... 96 | def critic(self, obs: torch.Tensor, z: torch.Tensor, action: torch.Tensor) -> torch.Tensor: ... 97 | def discriminator(self, obs: torch.Tensor, z: torch.Tensor) -> torch.Tensor: ... 98 | ``` 99 | 100 | We also provide simple functions for prompting the model and obtaining a context vector `z` representing the task to execute. 101 | ```python 102 | #reward prompt (standard and weighted regression) 103 | def reward_inference(self, next_obs: torch.Tensor, reward: torch.Tensor, weight: torch.Tensor | None = None,) -> torch.Tensor: ... 104 | def reward_wr_inference(self, next_obs: torch.Tensor, reward: torch.Tensor) -> torch.Tensor: ... 105 | #goal prompt 106 | def goal_inference(self, next_obs: torch.Tensor) -> torch.Tensor: ... 107 | #tracking prompt 108 | def tracking_inference(self, next_obs: torch.Tensor) -> torch.Tensor: 109 | ``` 110 | Once we have a context vector `z` we can call the actor to get actions. We provide a function for acting in the environment with a standard interface. 111 | ```python 112 | def act(self, obs: torch.Tensor, z: torch.Tensor, mean: bool = True) -> torch.Tensor: 113 | ``` 114 | 115 | Note that these functions do not allow gradient computation and use eval mode since they are expected to be used for inference (`torch.no_grad()` and `model.eval()`). For training, you should directly access the class attributes. For training we also define target networks for the forward, backward and critic networks. 116 | 117 | 118 | ### Execute a policy 119 | 120 | This is the minimal example on how to execute a random policy 121 | 122 | ```python 123 | from humenv import make_humenv 124 | from gymnasium.wrappers import FlattenObservation, TransformObservation 125 | import torch 126 | from metamotivo.fb_cpr.huggingface import FBcprModel 127 | 128 | device = "cpu" 129 | env, _ = make_humenv( 130 | num_envs=1, 131 | wrappers=[ 132 | FlattenObservation, 133 | lambda env: TransformObservation( 134 | env, lambda obs: torch.tensor(obs.reshape(1, -1), dtype=torch.float32, device=device), env.observation_space # For gymnasium <1.0.0 remove the last argument: env.observation_space 135 | ), 136 | ], 137 | state_init="Default", 138 | ) 139 | 140 | model = FBcprModel.from_pretrained("facebook/metamotivo-S-1") 141 | model.to(device) 142 | z = model.sample_z(1) 143 | observation, _ = env.reset() 144 | for i in range(10): 145 | action = model.act(observation, z, mean=True) 146 | observation, reward, terminated, truncated, info = env.step(action.cpu().numpy().ravel()) 147 | ``` 148 | 149 | 150 | # Evaluation in HumEnv 151 | 152 | For reproducibility of the paper, we provide a way of evaluating the models using `HumEnv`. We provide wrappers that can be used to interface Meta Motivo with `humenv.bench` reward, goal and tracking evaluation. 
118 | ### Execute a policy
119 | 
120 | Here is a minimal example of how to execute a policy conditioned on a randomly sampled context `z`.
121 | 
122 | ```python
123 | from humenv import make_humenv
124 | from gymnasium.wrappers import FlattenObservation, TransformObservation
125 | import torch
126 | from metamotivo.fb_cpr.huggingface import FBcprModel
127 | 
128 | device = "cpu"
129 | env, _ = make_humenv(
130 |     num_envs=1,
131 |     wrappers=[
132 |         FlattenObservation,
133 |         lambda env: TransformObservation(
134 |             env, lambda obs: torch.tensor(obs.reshape(1, -1), dtype=torch.float32, device=device), env.observation_space  # For gymnasium <1.0.0 remove the last argument: env.observation_space
135 |         ),
136 |     ],
137 |     state_init="Default",
138 | )
139 | 
140 | model = FBcprModel.from_pretrained("facebook/metamotivo-S-1")
141 | model.to(device)
142 | z = model.sample_z(1)
143 | observation, _ = env.reset()
144 | for i in range(10):
145 |     action = model.act(observation, z, mean=True)
146 |     observation, reward, terminated, truncated, info = env.step(action.cpu().numpy().ravel())
147 | ```
148 | 
149 | 
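The prompting functions shown earlier work the same way. For example, a goal prompt encodes a single target observation into a context vector. The sketch below assumes the `model`, `env`, `device`, and `buffer` objects from the snippets above, and simply reuses a state sampled from the inference buffer as the goal (the field name `next_observation` refers to that buffer); see `tutorial.ipynb` for a full goal-reaching example with a hand-specified pose.

```python
goal_obs = buffer.sample(1)["next_observation"].to(device=device, dtype=torch.float32)
z_goal = model.goal_inference(next_obs=goal_obs)

observation, _ = env.reset()
for i in range(30):
    action = model.act(observation, z_goal, mean=True)
    observation, reward, terminated, truncated, info = env.step(action.cpu().numpy().ravel())
```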
150 | # Evaluation in HumEnv
151 | 
152 | To reproduce the results in the paper, we provide a way of evaluating the models using `HumEnv`. We provide wrappers that interface Meta Motivo with the `humenv.bench` reward, goal and tracking evaluations.
153 | 
154 | Here is an example of how to use the wrappers for reward evaluation:
155 | 
156 | ```python
157 | from metamotivo.fb_cpr.huggingface import FBcprModel
158 | from metamotivo.wrappers.humenvbench import RewardWrapper
159 | import humenv.bench
160 | 
161 | model = FBcprModel.from_pretrained("facebook/metamotivo-S-1")
162 | 
163 | # this enables reward relabeling and context inference
164 | model = RewardWrapper(
165 |     model=model,
166 |     inference_dataset=buffer,  # see above how to download and create a buffer
167 |     num_samples_per_inference=100_000,
168 |     inference_function="reward_wr_inference",
169 |     max_workers=80,
170 | )
171 | # create the evaluation from humenv
172 | reward_eval = humenv.bench.RewardEvaluation(
173 |     tasks=["move-ego-0-0"],
174 |     env_kwargs={
175 |         "state_init": "Default",
176 |     },
177 |     num_contexts=1,
178 |     num_envs=50,
179 |     num_episodes=100,
180 | )
181 | scores = reward_eval.run(model)
182 | ```
183 | 
184 | You can do the same for the other evaluations provided in `humenv.bench`. Please refer to `tutorial_benchmark.ipynb` for a full evaluation loop.
185 | 
186 | # Rendering a reward-based or tracking policy
187 | 
188 | We show how to render an episode with a reward-based policy.
189 | 
190 | ```python
191 | import os
192 | os.environ["OMP_NUM_THREADS"] = "1"
193 | import gymnasium
194 | import torch
195 | import mediapy as media
196 | from humenv import STANDARD_TASKS, make_humenv
197 | 
198 | task = STANDARD_TASKS[0]
199 | model = FBcprModel.from_pretrained("facebook/metamotivo-S-1", device="cpu")
200 | rew_model = RewardWrapper(
201 |     model=model,
202 |     inference_dataset=buffer,  # see above how to download and create a buffer
203 |     num_samples_per_inference=100_000,
204 |     inference_function="reward_wr_inference",
205 |     max_workers=40,
206 |     process_executor=True,
207 |     process_context="forkserver",
208 | )
209 | z = rew_model.reward_inference(task)
210 | env, _ = make_humenv(num_envs=1, task=task, state_init="DefaultAndFall", wrappers=[gymnasium.wrappers.FlattenObservation])
211 | done = False
212 | observation, info = env.reset()
213 | frames = [env.render()]
214 | while not done:
215 |     obs = torch.tensor(observation.reshape(1, -1), dtype=torch.float32, device=rew_model.device)
216 |     action = rew_model.act(obs=obs, z=z).ravel()
217 |     observation, reward, terminated, truncated, info = env.step(action)
218 |     frames.append(env.render())
219 |     done = bool(terminated or truncated)
220 | 
221 | media.show_video(frames, fps=30)
222 | ```
223 | 
224 | It is also easy to render a policy for tracking a motion.
225 | 
226 | ```python
227 | import os
228 | os.environ["OMP_NUM_THREADS"] = "1"
229 | from metamotivo.wrappers.humenvbench import TrackingWrapper
230 | from humenv.misc.motionlib import MotionBuffer
231 | 
232 | model = FBcprModel.from_pretrained("facebook/metamotivo-S-1", device="cpu")
233 | track_model = TrackingWrapper(model=model)
234 | motion_buffer = MotionBuffer(files=ADD_THE_DESIRED_MOTION, base_path=ADD_YOUR_MOTION_ROOT, keys=["qpos", "qvel", "observation"])
235 | ep_ = motion_buffer.get(motion_buffer.get_motion_ids()[0])
236 | ctx = track_model.tracking_inference(next_obs=ep_["observation"][1:])
237 | env, _ = make_humenv(num_envs=1, state_init="Default", wrappers=[gymnasium.wrappers.FlattenObservation])
238 | observation, info = env.reset(options={"qpos": ep_["qpos"][0], "qvel": ep_["qvel"][0]})
239 | frames = [env.render()]
240 | for t in range(len(ctx)):
241 |     obs = torch.tensor(observation.reshape(1, -1), dtype=torch.float32, device=track_model.device)
242 |     action = track_model.act(obs=obs, z=ctx[t]).ravel()
243 |     observation, reward, terminated, truncated, info = env.step(action)
244 |     frames.append(env.render())
245 | 
246 | media.show_video(frames, fps=30)
247 | ```
248 | 
249 | # Citation
250 | ```
251 | @article{tirinzoni2024metamotivo,
252 |   title={Zero-shot Whole-Body Humanoid Control via Behavioral Foundation Models},
253 |   author={Tirinzoni, Andrea and Touati, Ahmed and Farebrother, Jesse and Guzek, Mateusz and Kanervisto, Anssi and Xu, Yingchen and Lazaric, Alessandro and Pirotta, Matteo},
254 | }
255 | ```
256 | 
257 | # License
258 | 
259 | Meta Motivo is licensed under the CC BY-NC 4.0 license. See [LICENSE](LICENSE) for details.
260 | 
--------------------------------------------------------------------------------
/examples/fb_train_dmc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | #
3 | # This source code is licensed under the CC BY-NC 4.0 license found in the
4 | # LICENSE file in the root directory of this source tree.
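#
# This example script trains a Forward-Backward (FB) agent offline on
# ExORL-style datasets (see the warning at the bottom of this file) for the
# DMC domains configured in ALL_TASKS below. Run `python fb_train_dmc.py --help`
# to list the available options; they mirror the fields of the TrainConfig
# dataclass defined in this file.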
5 | 6 | from __future__ import annotations 7 | import torch 8 | 9 | torch.set_float32_matmul_precision("high") 10 | 11 | import numpy as np 12 | import dataclasses 13 | from metamotivo.buffers.buffers import DictBuffer 14 | from metamotivo.fb import FBAgent, FBAgentConfig 15 | from metamotivo.nn_models import eval_mode 16 | from tqdm import tqdm 17 | import time 18 | from dm_control import suite 19 | import random 20 | from pathlib import Path 21 | import wandb 22 | import json 23 | from typing import List 24 | import mujoco 25 | import warnings 26 | import tyro 27 | 28 | ALL_TASKS = { 29 | "walker": [ 30 | "walk", 31 | "run", 32 | "stand", 33 | ], 34 | "cheetah": ["walk", "run"], 35 | "quadruped": ["walk", "run"], 36 | } 37 | 38 | 39 | def create_agent( 40 | domain_name="walker", 41 | task_name="walk", 42 | device="cpu", 43 | compile=False, 44 | cudagraphs=False, 45 | ) -> FBAgent: 46 | if domain_name not in ["walker", "pointmass", "cheetah", "quadruped"]: 47 | raise RuntimeError('FB configuration defined only for "walker", "pointmass", "cheetah", "quadruped"') 48 | env = suite.load( 49 | domain_name=domain_name, 50 | task_name=task_name, 51 | environment_kwargs={"flat_observation": True}, 52 | ) 53 | 54 | agent_config = FBAgentConfig() 55 | agent_config.model.obs_dim = env.observation_spec()["observations"].shape[0] 56 | agent_config.model.action_dim = env.action_spec().shape[0] 57 | agent_config.model.device = device 58 | agent_config.model.norm_obs = False 59 | agent_config.model.seq_length = 1 60 | agent_config.train.batch_size = 1024 61 | # archi 62 | if domain_name in ["walker", "pointmass"]: 63 | agent_config.model.archi.z_dim = 100 64 | else: 65 | agent_config.model.archi.z_dim = 50 66 | agent_config.model.archi.b.norm = True 67 | agent_config.model.archi.norm_z = True 68 | agent_config.model.archi.b.hidden_dim = 256 69 | agent_config.model.archi.f.hidden_dim = 1024 70 | agent_config.model.archi.actor.hidden_dim = 1024 71 | agent_config.model.archi.f.hidden_layers = 1 72 | agent_config.model.archi.actor.hidden_layers = 1 73 | agent_config.model.archi.b.hidden_layers = 2 74 | # optim 75 | if domain_name == "pointmass": 76 | agent_config.train.lr_f = 1e-4 77 | agent_config.train.lr_b = 1e-6 78 | agent_config.train.lr_actor = 1e-6 79 | else: 80 | agent_config.train.lr_f = 1e-4 81 | agent_config.train.lr_b = 1e-4 82 | agent_config.train.lr_actor = 1e-4 83 | agent_config.train.ortho_coef = 1 84 | agent_config.train.train_goal_ratio = 0.5 85 | agent_config.train.fb_pessimism_penalty = 0 86 | agent_config.train.actor_pessimism_penalty = 0.5 87 | 88 | if domain_name == "pointmass": 89 | agent_config.train.discount = 0.99 90 | else: 91 | agent_config.train.discount = 0.98 92 | agent_config.compile = compile 93 | agent_config.cudagraphs = cudagraphs 94 | 95 | return agent_config 96 | 97 | 98 | def load_data(dataset_path, expl_agent, domain_name, num_episodes=1): 99 | path = Path(dataset_path) / f"{domain_name}/{expl_agent}/buffer" 100 | print(f"Data path: {path}") 101 | storage = { 102 | "observation": [], 103 | "action": [], 104 | "physics": [], 105 | "next": {"observation": [], "terminated": [], "physics": []}, 106 | } 107 | files = list(path.glob("*.npz")) 108 | num_episodes = min(num_episodes, len(files)) 109 | for i in tqdm(range(num_episodes)): 110 | f = files[i] 111 | data = np.load(str(f)) 112 | storage["observation"].append(data["observation"][:-1].astype(np.float32)) 113 | storage["action"].append(data["action"][1:].astype(np.float32)) 114 | 
storage["next"]["observation"].append(data["observation"][1:].astype(np.float32)) 115 | storage["next"]["terminated"].append(np.array(1 - data["discount"][1:], dtype=np.bool)) 116 | storage["physics"].append(data["physics"][:-1]) 117 | storage["next"]["physics"].append(data["physics"][1:]) 118 | 119 | for k in storage: 120 | if k == "next": 121 | for k1 in storage[k]: 122 | storage[k][k1] = np.concat(storage[k][k1]) 123 | else: 124 | storage[k] = np.concat(storage[k]) 125 | return storage 126 | 127 | 128 | def set_seed_everywhere(seed): 129 | torch.manual_seed(seed) 130 | if torch.cuda.is_available(): 131 | torch.cuda.manual_seed_all(seed) 132 | np.random.seed(seed) 133 | random.seed(seed) 134 | 135 | 136 | @dataclasses.dataclass 137 | class TrainConfig: 138 | dataset_root: str 139 | seed: int = 0 140 | domain_name: str = "walker" 141 | task_name: str | None = None 142 | dataset_expl_agent: str = "rnd" 143 | num_train_steps: int = 3_000_000 144 | load_n_episodes: int = 5_000 145 | log_every_updates: int = 10_000 146 | work_dir: str | None = None 147 | 148 | checkpoint_every_steps: int = 1_000_000 149 | 150 | # eval 151 | num_eval_episodes: int = 10 152 | num_inference_samples: int = 50_000 153 | eval_every_steps: int = 100_000 154 | eval_tasks: List[str] | None = None 155 | 156 | # misc 157 | compile: bool = False 158 | cudagraphs: bool = False 159 | device: str = "cuda" 160 | 161 | # WANDB 162 | use_wandb: bool = False 163 | wandb_ename: str | None = None 164 | wandb_gname: str | None = None 165 | wandb_pname: str | None = "fb_train_dmc" 166 | wandb_name_prefix: str | None = None 167 | 168 | def __post_init__(self): 169 | if self.eval_tasks is None: 170 | self.eval_tasks = ALL_TASKS[self.domain_name] 171 | 172 | 173 | class Workspace: 174 | def __init__(self, cfg: TrainConfig, agent_cfg: FBAgentConfig) -> None: 175 | self.cfg = cfg 176 | self.agent_cfg = agent_cfg 177 | if self.cfg.work_dir is None: 178 | import string 179 | 180 | tmp_name = "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) 181 | self.work_dir = Path.cwd() / "tmp_fbcpr" / tmp_name 182 | self.cfg.work_dir = str(self.work_dir) 183 | else: 184 | self.work_dir = Path(self.cfg.work_dir) 185 | self.work_dir = Path(self.work_dir) 186 | self.work_dir.mkdir(exist_ok=True, parents=True) 187 | print(f"working dir: {self.work_dir}") 188 | 189 | self.agent = FBAgent(**dataclasses.asdict(self.agent_cfg)) 190 | set_seed_everywhere(self.cfg.seed) 191 | 192 | if self.cfg.use_wandb: 193 | exp_name = "fb" 194 | wandb_name = exp_name 195 | if self.cfg.wandb_name_prefix: 196 | wandb_name = f"{self.cfg.wandb_name_prefix}_{exp_name}" 197 | # fmt: off 198 | wandb_config = dataclasses.asdict(self.cfg) 199 | wandb.init(entity=self.cfg.wandb_ename, project=self.cfg.wandb_pname, 200 | group=self.cfg.agent.name if self.cfg.wandb_gname is None else self.cfg.wandb_gname, name=wandb_name, # mode="disabled", 201 | config=wandb_config) # type: ignore 202 | # fmt: on 203 | 204 | with (self.work_dir / "config.json").open("w") as f: 205 | json.dump(dataclasses.asdict(self.cfg), f, indent=4) 206 | 207 | def train(self): 208 | self.start_time = time.time() 209 | self.train_offline() 210 | 211 | def train_offline(self) -> None: 212 | self.replay_buffer = {} 213 | # LOAD DATA FROM EXORL 214 | data = load_data( 215 | self.cfg.dataset_root, 216 | self.cfg.dataset_expl_agent, 217 | self.cfg.domain_name, 218 | self.cfg.load_n_episodes, 219 | ) 220 | self.replay_buffer = {"train": DictBuffer(capacity=data["observation"].shape[0], 
device=self.agent.device)} 221 | self.replay_buffer["train"].extend(data) 222 | print(self.replay_buffer["train"]) 223 | del data 224 | 225 | total_metrics = None 226 | fps_start_time = time.time() 227 | for t in tqdm(range(0, int(self.cfg.num_train_steps))): 228 | if t % self.cfg.eval_every_steps == 0: 229 | self.eval(t) 230 | 231 | # torch.compiler.cudagraph_mark_step_begin() 232 | metrics = self.agent.update(self.replay_buffer, t) 233 | 234 | # we need to copy tensors returned by a cudagraph module 235 | if total_metrics is None: 236 | total_metrics = {k: metrics[k].clone() for k in metrics.keys()} 237 | else: 238 | total_metrics = {k: total_metrics[k] + metrics[k] for k in metrics.keys()} 239 | 240 | if t % self.cfg.log_every_updates == 0: 241 | m_dict = {} 242 | for k in sorted(list(total_metrics.keys())): 243 | tmp = total_metrics[k] / (1 if t == 0 else self.cfg.log_every_updates) 244 | m_dict[k] = np.round(tmp.mean().item(), 6) 245 | m_dict["duration"] = time.time() - self.start_time 246 | m_dict["FPS"] = (1 if t == 0 else self.cfg.log_every_updates) / (time.time() - fps_start_time) 247 | if self.cfg.use_wandb: 248 | wandb.log( 249 | {f"train/{k}": v for k, v in m_dict.items()}, 250 | step=t, 251 | ) 252 | print(m_dict) 253 | total_metrics = None 254 | fps_start_time = time.time() 255 | if t % self.cfg.checkpoint_every_steps == 0: 256 | self.agent.save(str(self.work_dir / "checkpoint")) 257 | self.agent.save(str(self.work_dir / "checkpoint")) 258 | return 259 | 260 | def eval(self, t): 261 | for task in self.cfg.eval_tasks: 262 | z = self.reward_inference(task).reshape(1, -1) 263 | eval_env = suite.load( 264 | domain_name=self.cfg.domain_name, 265 | task_name=task, 266 | environment_kwargs={"flat_observation": True}, 267 | ) 268 | num_ep = self.cfg.num_eval_episodes 269 | total_reward = np.zeros((num_ep,), dtype=np.float64) 270 | for ep in range(num_ep): 271 | time_step = eval_env.reset() 272 | while not time_step.last(): 273 | with torch.no_grad(), eval_mode(self.agent._model): 274 | obs = torch.tensor( 275 | time_step.observation["observations"].reshape(1, -1), 276 | device=self.agent.device, 277 | dtype=torch.float32, 278 | ) 279 | action = self.agent.act(obs=obs, z=z, mean=True).cpu().numpy() 280 | time_step = eval_env.step(action) 281 | total_reward[ep] += time_step.reward 282 | m_dict = { 283 | "reward": np.mean(total_reward), 284 | "reward#std": np.std(total_reward), 285 | } 286 | if self.cfg.use_wandb: 287 | wandb.log( 288 | {f"{task}/{k}": v for k, v in m_dict.items()}, 289 | step=t, 290 | ) 291 | m_dict["task"] = task 292 | print(m_dict) 293 | 294 | def reward_inference(self, task) -> torch.Tensor: 295 | env = suite.load( 296 | domain_name=self.cfg.domain_name, 297 | task_name=task, 298 | environment_kwargs={"flat_observation": True}, 299 | ) 300 | num_samples = self.cfg.num_inference_samples 301 | batch = self.replay_buffer["train"].sample(num_samples) 302 | rewards = [] 303 | for i in range(num_samples): 304 | with env._physics.reset_context(): 305 | env._physics.set_state(batch["next"]["physics"][i].cpu().numpy()) 306 | env._physics.set_control(batch["action"][i].cpu().detach().numpy()) 307 | mujoco.mj_forward(env._physics.model.ptr, env._physics.data.ptr) # pylint: disable=no-member 308 | mujoco.mj_fwdPosition(env._physics.model.ptr, env._physics.data.ptr) # pylint: disable=no-member 309 | mujoco.mj_sensorVel(env._physics.model.ptr, env._physics.data.ptr) # pylint: disable=no-member 310 | mujoco.mj_subtreeVel(env._physics.model.ptr, env._physics.data.ptr) # pylint: 
disable=no-member 311 | rewards.append(env._task.get_reward(env._physics)) 312 | rewards = np.array(rewards).reshape(-1, 1) 313 | z = self.agent._model.reward_inference( 314 | next_obs=batch["next"]["observation"], 315 | reward=torch.tensor(rewards, dtype=torch.float32, device=self.agent.device), 316 | ) 317 | return z 318 | 319 | 320 | if __name__ == "__main__": 321 | config = tyro.cli(TrainConfig) 322 | 323 | warnings.warn( 324 | "Since the original creation of ExORL, mujoco has seen many updates. To rerun all the actions and collect a physics consistent data, you may optionally use the update_data.py utility from MTM (https://github.com/facebookresearch/mtm/tree/main/research/exorl)." 325 | ) 326 | if config.task_name is None: 327 | if config.domain_name == "walker": 328 | config.task_name = "walk" 329 | elif config.domain_name == "cheetah": 330 | config.task_name = "run" 331 | elif config.domain_name == "pointmass": 332 | config.task_name = "reach_top_left" 333 | elif config.domain_name == "quadruped": 334 | config.task_name = "run" 335 | else: 336 | raise RuntimeError("Unsupported domain, you need to specify task_name") 337 | agent_config = create_agent( 338 | domain_name=config.domain_name, 339 | task_name=config.task_name, 340 | device=config.device, 341 | compile=config.compile, 342 | cudagraphs=config.cudagraphs, 343 | ) 344 | 345 | ws = Workspace(config, agent_cfg=agent_config) 346 | ws.train() 347 | -------------------------------------------------------------------------------- /metamotivo/fb/agent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the CC BY-NC 4.0 license found in the 4 | # LICENSE file in the root directory of this source tree. 
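#
# This module defines the FB (Forward-Backward) agent used for training:
# it wraps an FBModel with Adam optimizers and target networks, implements
# the forward-backward and actor update rules (see update_fb and update_actor
# below), and provides save/load helpers for checkpointing.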
5 | 6 | import dataclasses 7 | import torch 8 | import torch.nn.functional as F 9 | from typing import Dict, Tuple 10 | 11 | from .model import FBModel, config_from_dict 12 | from .model import Config as FBModelConfig 13 | from ..nn_models import weight_init, _soft_update_params, eval_mode 14 | from ..misc.zbuffer import ZBuffer 15 | from pathlib import Path 16 | import json 17 | import safetensors 18 | 19 | 20 | @dataclasses.dataclass 21 | class TrainConfig: 22 | lr_f: float = 1e-4 23 | lr_b: float = 1e-4 24 | lr_actor: float = 1e-4 25 | weight_decay: float = 0.0 26 | clip_grad_norm: float = 0.0 27 | fb_target_tau: float = 0.01 28 | ortho_coef: float = 1.0 29 | train_goal_ratio: float = 0.5 30 | fb_pessimism_penalty: float = 0.0 31 | actor_pessimism_penalty: float = 0.5 32 | stddev_clip: float = 0.3 33 | q_loss_coef: float = 0.0 34 | batch_size: int = 1024 35 | discount: float | None = None 36 | use_mix_rollout: bool = False 37 | update_z_every_step: int = 150 38 | z_buffer_size: int = 10000 39 | 40 | 41 | @dataclasses.dataclass 42 | class Config: 43 | model: FBModelConfig = dataclasses.field(default_factory=FBModelConfig) 44 | train: TrainConfig = dataclasses.field(default_factory=TrainConfig) 45 | cudagraphs: bool = False 46 | compile: bool = False 47 | 48 | 49 | class FBAgent: 50 | def __init__(self, **kwargs): 51 | self.cfg = config_from_dict(kwargs, Config) 52 | self.cfg.train.fb_target_tau = float(min(max(self.cfg.train.fb_target_tau, 0), 1)) 53 | self._model = FBModel(**dataclasses.asdict(self.cfg.model)) 54 | self.setup_training() 55 | self.setup_compile() 56 | self._model.to(self.cfg.model.device) 57 | 58 | @property 59 | def device(self): 60 | return self._model.cfg.device 61 | 62 | def setup_training(self) -> None: 63 | self._model.train(True) 64 | self._model.requires_grad_(True) 65 | self._model.apply(weight_init) 66 | self._model._prepare_for_train() # ensure that target nets are initialized after applying the weights 67 | 68 | self.backward_optimizer = torch.optim.Adam( 69 | self._model._backward_map.parameters(), 70 | lr=self.cfg.train.lr_b, 71 | capturable=self.cfg.cudagraphs and not self.cfg.compile, 72 | weight_decay=self.cfg.train.weight_decay, 73 | ) 74 | self.forward_optimizer = torch.optim.Adam( 75 | self._model._forward_map.parameters(), 76 | lr=self.cfg.train.lr_f, 77 | capturable=self.cfg.cudagraphs and not self.cfg.compile, 78 | weight_decay=self.cfg.train.weight_decay, 79 | ) 80 | self.actor_optimizer = torch.optim.Adam( 81 | self._model._actor.parameters(), 82 | lr=self.cfg.train.lr_actor, 83 | capturable=self.cfg.cudagraphs and not self.cfg.compile, 84 | weight_decay=self.cfg.train.weight_decay, 85 | ) 86 | 87 | # prepare parameter list 88 | self._forward_map_paramlist = tuple(x for x in self._model._forward_map.parameters()) 89 | self._target_forward_map_paramlist = tuple(x for x in self._model._target_forward_map.parameters()) 90 | self._backward_map_paramlist = tuple(x for x in self._model._backward_map.parameters()) 91 | self._target_backward_map_paramlist = tuple(x for x in self._model._target_backward_map.parameters()) 92 | 93 | # precompute some useful variables 94 | self.off_diag = 1 - torch.eye(self.cfg.train.batch_size, self.cfg.train.batch_size, device=self.device) 95 | self.off_diag_sum = self.off_diag.sum() 96 | 97 | self.z_buffer = ZBuffer(self.cfg.train.z_buffer_size, self.cfg.model.archi.z_dim, self.cfg.model.device) 98 | 99 | def setup_compile(self): 100 | print(f"compile {self.cfg.compile}") 101 | if self.cfg.compile: 102 | mode = 
"reduce-overhead" if not self.cfg.cudagraphs else None 103 | print(f"compiling with mode '{mode}'") 104 | self.update_fb = torch.compile(self.update_fb, mode=mode) # use fullgraph=True to debug for graph breaks 105 | self.update_actor = torch.compile(self.update_actor, mode=mode) # use fullgraph=True to debug for graph breaks 106 | self.sample_mixed_z = torch.compile(self.sample_mixed_z, mode=mode, fullgraph=True) 107 | 108 | print(f"cudagraphs {self.cfg.cudagraphs}") 109 | if self.cfg.cudagraphs: 110 | from tensordict.nn import CudaGraphModule 111 | 112 | self.update_fb = CudaGraphModule(self.update_fb, warmup=5) 113 | self.update_actor = CudaGraphModule(self.update_actor, warmup=5) 114 | 115 | def act(self, obs: torch.Tensor, z: torch.Tensor, mean: bool = True) -> torch.Tensor: 116 | return self._model.act(obs, z, mean) 117 | 118 | @torch.no_grad() 119 | def sample_mixed_z(self, train_goal: torch.Tensor | None = None, *args, **kwargs): 120 | # samples a batch from the z distribution used to update the networks 121 | z = self._model.sample_z(self.cfg.train.batch_size, device=self.device) 122 | 123 | if train_goal is not None: 124 | perm = torch.randperm(self.cfg.train.batch_size, device=self.device) 125 | goals = self._model._backward_map(train_goal[perm]) 126 | goals = self._model.project_z(goals) 127 | mask = torch.rand((self.cfg.train.batch_size, 1), device=self.device) < self.cfg.train.train_goal_ratio 128 | z = torch.where(mask, goals, z) 129 | return z 130 | 131 | def update(self, replay_buffer, step: int) -> Dict[str, torch.Tensor]: 132 | batch = replay_buffer["train"].sample(self.cfg.train.batch_size) 133 | 134 | obs, action, next_obs, terminated = ( 135 | batch["observation"], 136 | batch["action"], 137 | batch["next"]["observation"], 138 | batch["next"]["terminated"], 139 | ) 140 | discount = self.cfg.train.discount * ~terminated 141 | 142 | self._model._obs_normalizer(obs) 143 | self._model._obs_normalizer(next_obs) 144 | with torch.no_grad(), eval_mode(self._model._obs_normalizer): 145 | obs, next_obs = self._model._obs_normalizer(obs), self._model._obs_normalizer(next_obs) 146 | 147 | torch.compiler.cudagraph_mark_step_begin() 148 | z = self.sample_mixed_z(train_goal=next_obs).clone() 149 | self.z_buffer.add(z) 150 | 151 | q_loss_coef = self.cfg.train.q_loss_coef if self.cfg.train.q_loss_coef > 0 else None 152 | clip_grad_norm = self.cfg.train.clip_grad_norm if self.cfg.train.clip_grad_norm > 0 else None 153 | 154 | torch.compiler.cudagraph_mark_step_begin() 155 | metrics = self.update_fb( 156 | obs=obs, 157 | action=action, 158 | discount=discount, 159 | next_obs=next_obs, 160 | goal=next_obs, 161 | z=z, 162 | q_loss_coef=q_loss_coef, 163 | clip_grad_norm=clip_grad_norm, 164 | ) 165 | metrics.update( 166 | self.update_actor( 167 | obs=obs, 168 | action=action, 169 | z=z, 170 | clip_grad_norm=clip_grad_norm, 171 | ) 172 | ) 173 | 174 | with torch.no_grad(): 175 | _soft_update_params(self._forward_map_paramlist, self._target_forward_map_paramlist, self.cfg.train.fb_target_tau) 176 | _soft_update_params(self._backward_map_paramlist, self._target_backward_map_paramlist, self.cfg.train.fb_target_tau) 177 | 178 | return metrics 179 | 180 | def update_fb( 181 | self, 182 | obs: torch.Tensor, 183 | action: torch.Tensor, 184 | discount: torch.Tensor, 185 | next_obs: torch.Tensor, 186 | goal: torch.Tensor, 187 | z: torch.Tensor, 188 | q_loss_coef: float | None, 189 | clip_grad_norm: float | None, 190 | ) -> Dict[str, torch.Tensor]: 191 | with torch.no_grad(): 192 | dist = 
self._model._actor(next_obs, z, self._model.cfg.actor_std) 193 | next_action = dist.sample(clip=self.cfg.train.stddev_clip) 194 | target_Fs = self._model._target_forward_map(next_obs, z, next_action) # num_parallel x batch x z_dim 195 | target_B = self._model._target_backward_map(goal) # batch x z_dim 196 | target_Ms = torch.matmul(target_Fs, target_B.T) # num_parallel x batch x batch 197 | _, _, target_M = self.get_targets_uncertainty(target_Ms, self.cfg.train.fb_pessimism_penalty) # batch x batch 198 | 199 | # compute FB loss 200 | Fs = self._model._forward_map(obs, z, action) # num_parallel x batch x z_dim 201 | B = self._model._backward_map(goal) # batch x z_dim 202 | Ms = torch.matmul(Fs, B.T) # num_parallel x batch x batch 203 | 204 | diff = Ms - discount * target_M # num_parallel x batch x batch 205 | fb_offdiag = 0.5 * (diff * self.off_diag).pow(2).sum() / self.off_diag_sum 206 | fb_diag = -torch.diagonal(diff, dim1=1, dim2=2).mean() * Ms.shape[0] 207 | fb_loss = fb_offdiag + fb_diag 208 | 209 | # compute orthonormality loss for backward embedding 210 | Cov = torch.matmul(B, B.T) 211 | orth_loss_diag = -Cov.diag().mean() 212 | orth_loss_offdiag = 0.5 * (Cov * self.off_diag).pow(2).sum() / self.off_diag_sum 213 | orth_loss = orth_loss_offdiag + orth_loss_diag 214 | fb_loss += self.cfg.train.ortho_coef * orth_loss 215 | 216 | q_loss = torch.zeros(1, device=z.device, dtype=z.dtype) 217 | if q_loss_coef is not None: 218 | with torch.no_grad(): 219 | next_Qs = (target_Fs * z).sum(dim=-1) # num_parallel x batch 220 | _, _, next_Q = self.get_targets_uncertainty(next_Qs, self.cfg.train.fb_pessimism_penalty) # batch 221 | cov = torch.matmul(B.T, B) / B.shape[0] # z_dim x z_dim 222 | inv_cov = torch.inverse(cov) # z_dim x z_dim 223 | implicit_reward = (torch.matmul(B, inv_cov) * z).sum(dim=-1) # batch 224 | target_Q = implicit_reward.detach() + discount.squeeze() * next_Q # batch 225 | expanded_targets = target_Q.expand(Fs.shape[0], -1) 226 | Qs = (Fs * z).sum(dim=-1) # num_parallel x batch 227 | q_loss = 0.5 * Fs.shape[0] * F.mse_loss(Qs, expanded_targets) 228 | fb_loss += q_loss_coef * q_loss 229 | 230 | # optimize FB 231 | self.forward_optimizer.zero_grad(set_to_none=True) 232 | self.backward_optimizer.zero_grad(set_to_none=True) 233 | fb_loss.backward() 234 | if clip_grad_norm is not None: 235 | torch.nn.utils.clip_grad_norm_(self._model._forward_map.parameters(), clip_grad_norm) 236 | torch.nn.utils.clip_grad_norm_(self._model._backward_map.parameters(), clip_grad_norm) 237 | self.forward_optimizer.step() 238 | self.backward_optimizer.step() 239 | 240 | with torch.no_grad(): 241 | output_metrics = { 242 | "target_M": target_M.mean(), 243 | "M1": Ms[0].mean(), 244 | "F1": Fs[0].mean(), 245 | "B": B.mean(), 246 | "B_norm": torch.norm(B, dim=-1).mean(), 247 | "z_norm": torch.norm(z, dim=-1).mean(), 248 | "fb_loss": fb_loss, 249 | "fb_diag": fb_diag, 250 | "fb_offdiag": fb_offdiag, 251 | "orth_loss": orth_loss, 252 | "orth_loss_diag": orth_loss_diag, 253 | "orth_loss_offdiag": orth_loss_offdiag, 254 | "q_loss": q_loss, 255 | } 256 | return output_metrics 257 | 258 | def update_actor( 259 | self, 260 | obs: torch.Tensor, 261 | action: torch.Tensor, 262 | z: torch.Tensor, 263 | clip_grad_norm: float | None, 264 | ) -> Dict[str, torch.Tensor]: 265 | return self.update_td3_actor(obs=obs, z=z, clip_grad_norm=clip_grad_norm) 266 | 267 | def update_td3_actor(self, obs: torch.Tensor, z: torch.Tensor, clip_grad_norm: float | None) -> Dict[str, torch.Tensor]: 268 | dist = self._model._actor(obs, z, 
self._model.cfg.actor_std) 269 | action = dist.sample(clip=self.cfg.train.stddev_clip) 270 | Fs = self._model._forward_map(obs, z, action) # num_parallel x batch x z_dim 271 | Qs = (Fs * z).sum(-1) # num_parallel x batch 272 | _, _, Q = self.get_targets_uncertainty(Qs, self.cfg.train.actor_pessimism_penalty) # batch 273 | actor_loss = -Q.mean() 274 | 275 | # optimize actor 276 | self.actor_optimizer.zero_grad(set_to_none=True) 277 | actor_loss.backward() 278 | if clip_grad_norm is not None: 279 | torch.nn.utils.clip_grad_norm_(self._model._actor.parameters(), clip_grad_norm) 280 | self.actor_optimizer.step() 281 | 282 | return {"actor_loss": actor_loss.detach(), "q": Q.mean().detach()} 283 | 284 | def get_targets_uncertainty( 285 | self, preds: torch.Tensor, pessimism_penalty: torch.Tensor | float 286 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 287 | dim = 0 288 | preds_mean = preds.mean(dim=dim) 289 | preds_uns = preds.unsqueeze(dim=dim) # 1 x n_parallel x ... 290 | preds_uns2 = preds.unsqueeze(dim=dim + 1) # n_parallel x 1 x ... 291 | preds_diffs = torch.abs(preds_uns - preds_uns2) # n_parallel x n_parallel x ... 292 | num_parallel_scaling = preds.shape[dim] ** 2 - preds.shape[dim] 293 | preds_unc = ( 294 | preds_diffs.sum( 295 | dim=(dim, dim + 1), 296 | ) 297 | / num_parallel_scaling 298 | ) 299 | return preds_mean, preds_unc, preds_mean - pessimism_penalty * preds_unc 300 | 301 | def maybe_update_rollout_context(self, z: torch.Tensor | None, step_count: torch.Tensor) -> torch.Tensor: 302 | # get mask for environmets where we need to change z 303 | if z is not None: 304 | mask_reset_z = step_count % self.cfg.train.update_z_every_step == 0 305 | if self.cfg.train.use_mix_rollout and not self.z_buffer.empty(): 306 | new_z = self.z_buffer.sample(z.shape[0], device=self.cfg.model.device) 307 | else: 308 | new_z = self._model.sample_z(z.shape[0], device=self.cfg.model.device) 309 | z = torch.where(mask_reset_z, new_z, z.to(self.cfg.model.device)) 310 | else: 311 | z = self._model.sample_z(step_count.shape[0], device=self.cfg.model.device) 312 | return z 313 | 314 | @classmethod 315 | def load(cls, path: str, device: str | None = None): 316 | path = Path(path) 317 | with (path / "config.json").open() as f: 318 | loaded_config = json.load(f) 319 | if device is not None: 320 | loaded_config["model"]["device"] = device 321 | agent = cls(**loaded_config) 322 | optimizers = torch.load(str(path / "optimizers.pth"), weights_only=True) 323 | agent.actor_optimizer.load_state_dict(optimizers["actor_optimizer"]) 324 | agent.backward_optimizer.load_state_dict(optimizers["backward_optimizer"]) 325 | agent.forward_optimizer.load_state_dict(optimizers["forward_optimizer"]) 326 | 327 | safetensors.torch.load_model(agent._model, path / "model/model.safetensors", device=device) 328 | return agent 329 | 330 | def save(self, output_folder: str) -> None: 331 | output_folder = Path(output_folder) 332 | output_folder.mkdir(exist_ok=True) 333 | with (output_folder / "config.json").open("w+") as f: 334 | json.dump(dataclasses.asdict(self.cfg), f, indent=4) 335 | # save optimizer 336 | torch.save( 337 | { 338 | "actor_optimizer": self.actor_optimizer.state_dict(), 339 | "backward_optimizer": self.backward_optimizer.state_dict(), 340 | "forward_optimizer": self.forward_optimizer.state_dict(), 341 | }, 342 | output_folder / "optimizers.pth", 343 | ) 344 | # save model 345 | model_folder = output_folder / "model" 346 | model_folder.mkdir(exist_ok=True) 347 | self._model.save(output_folder=str(model_folder)) 
348 | -------------------------------------------------------------------------------- /metamotivo/fb_cpr/agent.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the CC BY-NC 4.0 license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import dataclasses 7 | from typing import Dict 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import autograd 12 | 13 | from ..fb.agent import FBAgent 14 | from ..fb.agent import TrainConfig as FBTrainConfig 15 | from ..nn_models import _soft_update_params, eval_mode 16 | from .model import Config as FBcprModelConfig 17 | from .model import FBcprModel, config_from_dict 18 | 19 | 20 | @dataclasses.dataclass 21 | class TrainConfig(FBTrainConfig): 22 | lr_discriminator: float = 1e-4 23 | lr_critic: float = 1e-4 24 | critic_target_tau: float = 0.005 25 | critic_pessimism_penalty: float = 0.5 26 | reg_coeff: float = 1 27 | scale_reg: bool = True 28 | # the z distribution for rollouts (when agent.use_mix_rollout=1) and for the mini-batches used in the network updates is: 29 | # - a fraction of 'expert_asm_ratio' zs from expert trajectory encoding 30 | # - a fraction of 'train_goal_ratio' zs from goal encoding (goals sampled from the train buffer) 31 | # - the remaining fraction from the uniform distribution 32 | expert_asm_ratio: float = 0 33 | # a fraction of 'relabel_ratio' transitions in each mini-batch are relabeled with a z sampled from the above distribution 34 | relabel_ratio: float | None = 1 35 | grad_penalty_discriminator: float = 10.0 36 | weight_decay_discriminator: float = 0.0 37 | 38 | 39 | @dataclasses.dataclass 40 | class Config: 41 | model: FBcprModelConfig = dataclasses.field(default_factory=FBcprModelConfig) 42 | train: TrainConfig = dataclasses.field(default_factory=TrainConfig) 43 | cudagraphs: bool = False 44 | compile: bool = False 45 | 46 | 47 | class FBcprAgent(FBAgent): 48 | def __init__(self, **kwargs): 49 | # make sure batch size is a multiple of seq_length 50 | seq_length = kwargs["model"]["seq_length"] 51 | batch_size = kwargs["train"]["batch_size"] 52 | kwargs["train"]["batch_size"] = int(torch.ceil(torch.tensor([batch_size / seq_length])) * seq_length) 53 | del seq_length, batch_size 54 | 55 | self.cfg = config_from_dict(kwargs, Config) 56 | self._model = FBcprModel(**dataclasses.asdict(self.cfg.model)) 57 | self._model.to(self.cfg.model.device) 58 | self.setup_training() 59 | self.setup_compile() 60 | 61 | def setup_training(self) -> None: 62 | super().setup_training() 63 | 64 | # prepare parameter list 65 | self._critic_map_paramlist = tuple(x for x in self._model._critic.parameters()) 66 | self._target_critic_map_paramlist = tuple(x for x in self._model._target_critic.parameters()) 67 | 68 | self.critic_optimizer = torch.optim.Adam( 69 | self._model._critic.parameters(), 70 | lr=self.cfg.train.lr_critic, 71 | capturable=self.cfg.cudagraphs and not self.cfg.compile, 72 | weight_decay=self.cfg.train.weight_decay, 73 | ) 74 | self.discriminator_optimizer = torch.optim.Adam( 75 | self._model._discriminator.parameters(), 76 | lr=self.cfg.train.lr_discriminator, 77 | capturable=self.cfg.cudagraphs and not self.cfg.compile, 78 | weight_decay=self.cfg.train.weight_decay_discriminator, 79 | ) 80 | 81 | def setup_compile(self): 82 | super().setup_compile() 83 | if self.cfg.compile: 84 | mode = "reduce-overhead" if not self.cfg.cudagraphs else 
None 85 | self.update_critic = torch.compile(self.update_critic, mode=mode) 86 | self.update_discriminator = torch.compile(self.update_discriminator, mode=mode) 87 | self.encode_expert = torch.compile(self.encode_expert, mode=mode, fullgraph=True) 88 | 89 | if self.cfg.cudagraphs: 90 | from tensordict.nn import CudaGraphModule 91 | 92 | self.update_critic = CudaGraphModule(self.update_critic, warmup=5) 93 | self.update_discriminator = CudaGraphModule(self.update_discriminator, warmup=5) 94 | self.encode_expert = CudaGraphModule(self.encode_expert, warmup=5) 95 | 96 | @torch.no_grad() 97 | def sample_mixed_z(self, train_goal: torch.Tensor, expert_encodings: torch.Tensor, *args, **kwargs): 98 | z = self._model.sample_z(self.cfg.train.batch_size, device=self.device) 99 | p_goal = self.cfg.train.train_goal_ratio 100 | p_expert_asm = self.cfg.train.expert_asm_ratio 101 | prob = torch.tensor( 102 | [p_goal, p_expert_asm, 1 - p_goal - p_expert_asm], 103 | dtype=torch.float32, 104 | device=self.device, 105 | ) 106 | mix_idxs = torch.multinomial(prob, num_samples=self.cfg.train.batch_size, replacement=True).reshape(-1, 1) 107 | 108 | # zs obtained by encoding train goals 109 | perm = torch.randperm(self.cfg.train.batch_size, device=self.device) 110 | goals = self._model._backward_map(train_goal[perm]) 111 | goals = self._model.project_z(goals) 112 | z = torch.where(mix_idxs == 0, goals, z) 113 | 114 | # zs obtained by encoding expert trajectories 115 | perm = torch.randperm(self.cfg.train.batch_size, device=self.device) 116 | z = torch.where(mix_idxs == 1, expert_encodings[perm], z) 117 | 118 | return z 119 | 120 | @torch.no_grad() 121 | def encode_expert(self, next_obs: torch.Tensor): 122 | # encode expert trajectories through B 123 | B_expert = self._model._backward_map(next_obs).detach() # batch x d 124 | B_expert = B_expert.view( 125 | self.cfg.train.batch_size // self.cfg.model.seq_length, 126 | self.cfg.model.seq_length, 127 | B_expert.shape[-1], 128 | ) # N x L x d 129 | z_expert = B_expert.mean(dim=1) # N x d 130 | z_expert = self._model.project_z(z_expert) 131 | z_expert = torch.repeat_interleave(z_expert, self.cfg.model.seq_length, dim=0) # batch x d 132 | return z_expert 133 | 134 | def update(self, replay_buffer, step: int) -> Dict[str, torch.Tensor]: 135 | expert_batch = replay_buffer["expert_slicer"].sample(self.cfg.train.batch_size) 136 | train_batch = replay_buffer["train"].sample(self.cfg.train.batch_size) 137 | 138 | train_obs, train_action, train_next_obs = ( 139 | train_batch["observation"].to(self.device), 140 | train_batch["action"].to(self.device), 141 | train_batch["next"]["observation"].to(self.device), 142 | ) 143 | discount = self.cfg.train.discount * ~train_batch["next"]["terminated"].to(self.device) 144 | expert_obs, expert_next_obs = ( 145 | expert_batch["observation"].to(self.device), 146 | expert_batch["next"]["observation"].to(self.device), 147 | ) 148 | 149 | self._model._obs_normalizer(train_obs) 150 | self._model._obs_normalizer(train_next_obs) 151 | 152 | with torch.no_grad(), eval_mode(self._model._obs_normalizer): 153 | train_obs, train_next_obs = ( 154 | self._model._obs_normalizer(train_obs), 155 | self._model._obs_normalizer(train_next_obs), 156 | ) 157 | expert_obs, expert_next_obs = ( 158 | self._model._obs_normalizer(expert_obs), 159 | self._model._obs_normalizer(expert_next_obs), 160 | ) 161 | 162 | torch.compiler.cudagraph_mark_step_begin() 163 | expert_z = self.encode_expert(next_obs=expert_next_obs) 164 | train_z = train_batch["z"].to(self.device) 
165 | 166 | # train the discriminator 167 | grad_penalty = self.cfg.train.grad_penalty_discriminator if self.cfg.train.grad_penalty_discriminator > 0 else None 168 | metrics = self.update_discriminator( 169 | expert_obs=expert_obs, 170 | expert_z=expert_z, 171 | train_obs=train_obs, 172 | train_z=train_z, 173 | grad_penalty=grad_penalty, 174 | ) 175 | 176 | z = self.sample_mixed_z(train_goal=train_next_obs, expert_encodings=expert_z).clone() 177 | self.z_buffer.add(z) 178 | 179 | if self.cfg.train.relabel_ratio is not None: 180 | mask = torch.rand((self.cfg.train.batch_size, 1), device=self.device) <= self.cfg.train.relabel_ratio 181 | train_z = torch.where(mask, z, train_z) 182 | 183 | q_loss_coef = self.cfg.train.q_loss_coef if self.cfg.train.q_loss_coef > 0 else None 184 | clip_grad_norm = self.cfg.train.clip_grad_norm if self.cfg.train.clip_grad_norm > 0 else None 185 | 186 | metrics.update( 187 | self.update_fb( 188 | obs=train_obs, 189 | action=train_action, 190 | discount=discount, 191 | next_obs=train_next_obs, 192 | goal=train_next_obs, 193 | z=train_z, 194 | q_loss_coef=q_loss_coef, 195 | clip_grad_norm=clip_grad_norm, 196 | ) 197 | ) 198 | metrics.update( 199 | self.update_critic( 200 | obs=train_obs, 201 | action=train_action, 202 | discount=discount, 203 | next_obs=train_next_obs, 204 | z=train_z, 205 | ) 206 | ) 207 | metrics.update( 208 | self.update_actor( 209 | obs=train_obs, 210 | action=train_action, 211 | z=train_z, 212 | clip_grad_norm=clip_grad_norm, 213 | ) 214 | ) 215 | 216 | with torch.no_grad(): 217 | _soft_update_params( 218 | self._forward_map_paramlist, 219 | self._target_forward_map_paramlist, 220 | self.cfg.train.fb_target_tau, 221 | ) 222 | _soft_update_params( 223 | self._backward_map_paramlist, 224 | self._target_backward_map_paramlist, 225 | self.cfg.train.fb_target_tau, 226 | ) 227 | _soft_update_params( 228 | self._critic_map_paramlist, 229 | self._target_critic_map_paramlist, 230 | self.cfg.train.critic_target_tau, 231 | ) 232 | 233 | return metrics 234 | 235 | @torch.compiler.disable 236 | def gradient_penalty_wgan( 237 | self, 238 | real_obs: torch.Tensor, 239 | real_z: torch.Tensor, 240 | fake_obs: torch.Tensor, 241 | fake_z: torch.Tensor, 242 | ) -> torch.Tensor: 243 | batch_size = real_obs.shape[0] 244 | alpha = torch.rand(batch_size, 1, device=real_obs.device) 245 | interpolates = torch.cat( 246 | [ 247 | (alpha * real_obs + (1 - alpha) * fake_obs).requires_grad_(True), 248 | (alpha * real_z + (1 - alpha) * fake_z).requires_grad_(True), 249 | ], 250 | dim=1, 251 | ) 252 | d_interpolates = self._model._discriminator.compute_logits( 253 | interpolates[:, 0 : real_obs.shape[1]], interpolates[:, real_obs.shape[1] :] 254 | ) 255 | gradients = autograd.grad( 256 | outputs=d_interpolates, 257 | inputs=interpolates, 258 | grad_outputs=torch.ones_like(d_interpolates), 259 | create_graph=True, 260 | retain_graph=True, 261 | only_inputs=True, 262 | )[0] 263 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() 264 | return gradient_penalty 265 | 266 | def update_discriminator( 267 | self, 268 | expert_obs: torch.Tensor, 269 | expert_z: torch.Tensor, 270 | train_obs: torch.Tensor, 271 | train_z: torch.Tensor, 272 | grad_penalty: float | None, 273 | ) -> Dict[str, torch.Tensor]: 274 | expert_logits = self._model._discriminator.compute_logits(obs=expert_obs, z=expert_z) 275 | unlabeled_logits = self._model._discriminator.compute_logits(obs=train_obs, z=train_z) 276 | # these are equivalent to binary cross entropy 277 | expert_loss = 
-torch.nn.functional.logsigmoid(expert_logits) 278 | unlabeled_loss = torch.nn.functional.softplus(unlabeled_logits) 279 | loss = torch.mean(expert_loss + unlabeled_loss) 280 | 281 | if grad_penalty is not None: 282 | wgan_gp = self.gradient_penalty_wgan(expert_obs, expert_z, train_obs, train_z) 283 | loss += grad_penalty * wgan_gp 284 | 285 | self.discriminator_optimizer.zero_grad(set_to_none=True) 286 | loss.backward() 287 | self.discriminator_optimizer.step() 288 | 289 | with torch.no_grad(): 290 | output_metrics = { 291 | "disc_loss": loss.detach(), 292 | "disc_expert_loss": expert_loss.detach().mean().detach(), 293 | "disc_train_loss": unlabeled_loss.detach().mean().detach(), 294 | } 295 | if grad_penalty is not None: 296 | output_metrics["disc_wgan_gp_loss"] = wgan_gp.detach() 297 | return output_metrics 298 | 299 | def update_critic( 300 | self, 301 | obs: torch.Tensor, 302 | action: torch.Tensor, 303 | discount: torch.Tensor, 304 | next_obs: torch.Tensor, 305 | z: torch.Tensor, 306 | ) -> Dict[str, torch.Tensor]: 307 | num_parallel = self.cfg.model.archi.critic.num_parallel 308 | # compute target critic 309 | with torch.no_grad(): 310 | reward = self._model._discriminator.compute_reward(obs=obs, z=z) 311 | dist = self._model._actor(next_obs, z, self._model.cfg.actor_std) 312 | next_action = dist.sample(clip=self.cfg.train.stddev_clip) 313 | next_Qs = self._model._target_critic(next_obs, z, next_action) # num_parallel x batch x 1 314 | Q_mean, Q_unc, next_V = self.get_targets_uncertainty(next_Qs, self.cfg.train.critic_pessimism_penalty) 315 | target_Q = reward + discount * next_V 316 | expanded_targets = target_Q.expand(num_parallel, -1, -1) 317 | 318 | # compute critic loss 319 | Qs = self._model._critic(obs, z, action) # num_parallel x batch x (1 or n_bins) 320 | critic_loss = 0.5 * num_parallel * F.mse_loss(Qs, expanded_targets) 321 | 322 | # optimize critic 323 | self.critic_optimizer.zero_grad(set_to_none=True) 324 | critic_loss.backward() 325 | self.critic_optimizer.step() 326 | 327 | with torch.no_grad(): 328 | output_metrics = { 329 | "target_Q": target_Q.mean().detach(), 330 | "Q1": Qs.mean().detach(), 331 | "mean_next_Q": Q_mean.mean().detach(), 332 | "unc_Q": Q_unc.mean().detach(), 333 | "critic_loss": critic_loss.mean().detach(), 334 | "mean_disc_reward": reward.mean().detach(), 335 | } 336 | return output_metrics 337 | 338 | def update_actor( 339 | self, 340 | obs: torch.Tensor, 341 | action: torch.Tensor, 342 | z: torch.Tensor, 343 | clip_grad_norm: float | None, 344 | ) -> Dict[str, torch.Tensor]: 345 | dist = self._model._actor(obs, z, self._model.cfg.actor_std) 346 | action = dist.sample(clip=self.cfg.train.stddev_clip) 347 | 348 | # compute discriminator reward loss 349 | Qs_discriminator = self._model._critic(obs, z, action) # num_parallel x batch x (1 or n_bins) 350 | _, _, Q_discriminator = self.get_targets_uncertainty(Qs_discriminator, self.cfg.train.actor_pessimism_penalty) # batch 351 | 352 | # compute fb reward loss 353 | Fs = self._model._forward_map(obs, z, action) # num_parallel x batch x z_dim 354 | Qs_fb = (Fs * z).sum(-1) # num_parallel x batch 355 | _, _, Q_fb = self.get_targets_uncertainty(Qs_fb, self.cfg.train.actor_pessimism_penalty) # batch 356 | 357 | weight = Q_fb.abs().mean().detach() if self.cfg.train.scale_reg else 1.0 358 | actor_loss = -Q_discriminator.mean() * self.cfg.train.reg_coeff * weight - Q_fb.mean() 359 | 360 | # optimize actor 361 | self.actor_optimizer.zero_grad(set_to_none=True) 362 | actor_loss.backward() 363 | if 
clip_grad_norm is not None: 364 | torch.nn.utils.clip_grad_norm_(self._model._actor.parameters(), clip_grad_norm) 365 | self.actor_optimizer.step() 366 | 367 | with torch.no_grad(): 368 | output_metrics = { 369 | "actor_loss": actor_loss.detach(), 370 | "Q_discriminator": Q_discriminator.mean().detach(), 371 | "Q_fb": Q_fb.mean().detach(), 372 | } 373 | return output_metrics 374 | -------------------------------------------------------------------------------- /tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "d0167d15-a6ff-4602-835d-c88851e83113", 6 | "metadata": {}, 7 | "source": [ 8 | "# Meta Motivo Tutorial\n", 9 | "This notebook provides a simple introduction on how to use the Meta Motivo api." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "id": "be0d8e4a-882c-467e-bce6-ef2f33b509e2", 15 | "metadata": {}, 16 | "source": [ 17 | "## All imports" 18 | ] 19 | }, 20 | { 21 | "cell_type": "code", 22 | "execution_count": null, 23 | "id": "03d3cd63-1d2e-4bda-b224-d2b2b73bf655", 24 | "metadata": {}, 25 | "outputs": [], 26 | "source": [ 27 | "from packaging.version import Version\n", 28 | "from metamotivo.fb_cpr.huggingface import FBcprModel\n", 29 | "from huggingface_hub import hf_hub_download\n", 30 | "from humenv import make_humenv\n", 31 | "import gymnasium\n", 32 | "from gymnasium.wrappers import FlattenObservation, TransformObservation\n", 33 | "from metamotivo.buffers.buffers import DictBuffer\n", 34 | "from humenv.env import make_from_name\n", 35 | "from humenv import rewards as humenv_rewards\n", 36 | "\n", 37 | "import torch\n", 38 | "import mediapy as media\n", 39 | "import math\n", 40 | "import h5py\n", 41 | "from pathlib import Path\n", 42 | "import numpy as np" 43 | ] 44 | }, 45 | { 46 | "cell_type": "markdown", 47 | "id": "aa35d241-fa2d-4ad3-aaa9-9e2b4175e742", 48 | "metadata": {}, 49 | "source": [ 50 | "## Model download\n", 51 | "The first step is to download the model. We show how to use HuggingFace hub for that." 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "00f7b632-864d-4b05-848c-f7a22b662a12", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "model = FBcprModel.from_pretrained(\"facebook/metamotivo-S-1\")\n", 62 | "print(model)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "markdown", 67 | "id": "3609f0e1-2ff9-462b-9bea-15695a128da7", 68 | "metadata": {}, 69 | "source": [ 70 | "**Run a policy from Meta Motivo:**\n", 71 | "\n", 72 | "Now that we saw how to load a pre-trained Meta Motivo policy, we can prompt it and execute actions with it. \n", 73 | "\n", 74 | "The first step is to sample a context embedding `z` that needs to be passed to the policy." 
75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "2f4e85ca-1940-403b-adfd-126c244e39ed", 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "device = \"cpu\"\n", 85 | "\n", 86 | "if Version(\"0.26\") <= Version(gymnasium.__version__) < Version(\"1.0\"):\n", 87 | " transform_obs_wrapper = lambda env: TransformObservation(\n", 88 | " env, lambda obs: torch.tensor(obs.reshape(1, -1), dtype=torch.float32, device=device)\n", 89 | " )\n", 90 | "else:\n", 91 | " transform_obs_wrapper = lambda env: TransformObservation(\n", 92 | " env, lambda obs: torch.tensor(obs.reshape(1, -1), dtype=torch.float32, device=device), env.observation_space\n", 93 | " )\n", 94 | "\n", 95 | "env, _ = make_humenv(\n", 96 | " num_envs=1,\n", 97 | " wrappers=[\n", 98 | " FlattenObservation,\n", 99 | " transform_obs_wrapper,\n", 100 | " ],\n", 101 | " state_init=\"Default\",\n", 102 | ")\n", 103 | "\n", 104 | "model.to(device)\n", 105 | "z = model.sample_z(1)\n", 106 | "print(f\"embedding size {z.shape}\")\n", 107 | "print(f\"z norm: {torch.norm(z)}\")\n", 108 | "print(f\"z norm / sqrt(d): {torch.norm(z) / math.sqrt(z.shape[-1])}\")\n", 109 | "observation, _ = env.reset()\n", 110 | "frames = [env.render()]\n", 111 | "for i in range(30):\n", 112 | " action = model.act(observation, z, mean=True)\n", 113 | " observation, reward, terminated, truncated, info = env.step(action.cpu().numpy().ravel())\n", 114 | " frames.append(env.render())\n", 115 | "\n", 116 | "media.show_video(frames, fps=30)" 117 | ] 118 | }, 119 | { 120 | "cell_type": "markdown", 121 | "id": "05338016-0890-4f9e-8741-73985dbc89b3", 122 | "metadata": {}, 123 | "source": [ 124 | "### Computing Q-functions\n", 125 | "\n", 126 | "FB-CPR provides a way of directly computing the action-value function of any policy embedding `z` on any task embedding `z_r`. Then, the Q function of a policy $z$ is given by\n", 127 | "\n", 128 | "$Q(s,a, z) = F(s,a,z) \\cdot z_r$\n", 129 | "\n", 130 | "The task embedding can be computed in the following way. Given a set of samples labeled with rewards $(s,a,s',r)$, the task embedding is given by: \n", 131 | "\n", 132 | "$z_r = \\mathrm{normalised}(\\sum_{i \\in \\mathrm{batch}} r_i B(s'_i))$." 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "id": "98e3b022-22d1-4602-ab84-b056d621b702", 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "def Qfunction(state, action, z_reward, z_policy):\n", 143 | " F = model.forward_map(obs=state, z=z_policy.repeat(state.shape[0],1), action=action) # num_parallel x num_samples x z_dim\n", 144 | " Q = F @ z_reward.ravel()\n", 145 | " return Q.mean(axis=0)\n", 146 | "\n", 147 | "z_reward = model.sample_z(1)\n", 148 | "z_policy = model.sample_z(1)\n", 149 | "state = torch.rand((10, env.observation_space.shape[0]), device=model.cfg.device, dtype=torch.float32)\n", 150 | "action = torch.rand((10, env.action_space.shape[0]), device=model.cfg.device, dtype=torch.float32)*2 - 1\n", 151 | "Q = Qfunction(state, action, z_reward, z_policy)\n", 152 | "print(Q)" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "id": "5f168c33-a35f-4335-bc97-e6d4eeb5fce5", 158 | "metadata": {}, 159 | "source": [ 160 | "## Prompting the model\n", 161 | "\n", 162 | "We have seen that we can condition the model via the context variable `z`. We can control the task to execute via _prompting_ (or _policy inference_)." 
163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "id": "10d2b486-63c1-45f7-bfca-acd5308da1ec", 168 | "metadata": {}, 169 | "source": [ 170 | "### Reward prompts\n", 171 | "The first version of inference we investigate is the reward prompting, i.e., given a set of reward label samples we can infer in a zero-shot way the near-optimal policy for solving such task.\n", 172 | "\n", 173 | "First step, download the data for inference. We provide a buffer for inference of about 500k samples. This buffer has been generated by randomly subsampling the final replay buffer." 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": null, 179 | "id": "26312a39-fdb8-4843-a3e2-08f0dafc3db5", 180 | "metadata": {}, 181 | "outputs": [], 182 | "source": [ 183 | "local_dir = \"metamotivo-S-1-datasets\"\n", 184 | "dataset = \"buffer_inference_500000.hdf5\"\n", 185 | "buffer_path = hf_hub_download(\n", 186 | " repo_id=\"facebook/metamotivo-S-1\",\n", 187 | " filename=f\"data/{dataset}\",\n", 188 | " repo_type=\"model\",\n", 189 | " local_dir=local_dir,\n", 190 | " )\n", 191 | "print(buffer_path)" 192 | ] 193 | }, 194 | { 195 | "cell_type": "markdown", 196 | "id": "ac150fc7-0719-4428-8ef1-eac48ddf0d9a", 197 | "metadata": {}, 198 | "source": [ 199 | "Now that we have download the h5 file for inference, we can conveniently loaded it in a buffer." 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": null, 205 | "id": "20c39961-15e6-4739-96fd-bb7aa47e60f9", 206 | "metadata": {}, 207 | "outputs": [], 208 | "source": [ 209 | "hf = h5py.File(buffer_path, \"r\")\n", 210 | "print(hf.keys())\n", 211 | "data = {}\n", 212 | "for k, v in hf.items():\n", 213 | " print(f\"{k:20s}: {v.shape}\")\n", 214 | " data[k] = v[:]\n", 215 | "buffer = DictBuffer(capacity=data[\"qpos\"].shape[0], device=\"cpu\")\n", 216 | "buffer.extend(data)\n", 217 | "del data" 218 | ] 219 | }, 220 | { 221 | "cell_type": "code", 222 | "execution_count": null, 223 | "id": "af64888e-c992-4c0b-aa7b-9421942ee605", 224 | "metadata": {}, 225 | "outputs": [], 226 | "source": [ 227 | "batch = buffer.sample(5)\n", 228 | "for k, v in batch.items():\n", 229 | " print(f\"{k:20s}: {v.shape}\")" 230 | ] 231 | }, 232 | { 233 | "cell_type": "markdown", 234 | "id": "1f1238a0-1ffc-4de9-a698-7853ea7fdd92", 235 | "metadata": {}, 236 | "source": [ 237 | "As you can see, the buffer does not provide a reward signal. We need to label this buffer with the desired reward function. We provide API for that but here we start looking into the basic steps:\n", 238 | "* Instantiate a reward function\n", 239 | "* Computing the reward from the batch data" 240 | ] 241 | }, 242 | { 243 | "cell_type": "code", 244 | "execution_count": null, 245 | "id": "dedac7ea-6365-41ce-b7cb-6b2f1ccbd4ed", 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "reward_fn = humenv_rewards.LocomotionReward(move_speed=2.0) # move ahead with speed 2\n", 250 | "# humenv provides also a name-base reward initialization. We could\n", 251 | "# get the same reward function in this way\n", 252 | "reward_fn = make_from_name(\"move-ego-0-2\") \n", 253 | "print(reward_fn)" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "id": "d95db165-5c5c-402c-9ddf-6cfed05f48a8", 259 | "metadata": {}, 260 | "source": [ 261 | "We can call the method `__call__` to obtain a reward value from the physics state. This function receives a mujoco model, qpos, qvel and the action. See the humenv tutorial for more information." 
262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": null, 267 | "id": "caa8a3dd-159e-45a9-95b7-c52d6244b6f9", 268 | "metadata": {}, 269 | "outputs": [], 270 | "source": [ 271 | "N = 100_000\n", 272 | "batch = buffer.sample(N)\n", 273 | "rewards = []\n", 274 | "for i in range(N):\n", 275 | " rewards.append(\n", 276 | " reward_fn(\n", 277 | " env.unwrapped.model,\n", 278 | " qpos=batch[\"next_qpos\"][i],\n", 279 | " qvel=batch[\"next_qvel\"][i],\n", 280 | " ctrl=batch[\"action\"][i])\n", 281 | " )\n", 282 | "rewards = np.stack(rewards).reshape(-1,1)\n", 283 | "print(rewards.ravel())" 284 | ] 285 | }, 286 | { 287 | "cell_type": "markdown", 288 | "id": "f1ccc931-9be2-4a88-b66f-70d369648c2c", 289 | "metadata": {}, 290 | "source": [ 291 | "**Note** that the reward functions implemented in humenv are functions of next state and action which means we need to use `next_qpos` and `next_qvel` that are the physical state of the system at the next state.\n", 292 | "\n", 293 | "We provide a multi-thread version for faster relabeling, see `metamotivo.wrappers.humenvbench.relabel`." 294 | ] 295 | }, 296 | { 297 | "cell_type": "code", 298 | "execution_count": null, 299 | "id": "10ca4068-e529-43d4-a110-5f224d118d2e", 300 | "metadata": {}, 301 | "outputs": [], 302 | "source": [ 303 | "from metamotivo.wrappers.humenvbench import relabel\n", 304 | "rewards = relabel(\n", 305 | " env,\n", 306 | " qpos=batch[\"next_qpos\"],\n", 307 | " qvel=batch[\"next_qvel\"],\n", 308 | " action=batch[\"action\"],\n", 309 | " reward_fn=reward_fn, \n", 310 | " max_workers=8\n", 311 | ")\n", 312 | "print(rewards.ravel())" 313 | ] 314 | }, 315 | { 316 | "cell_type": "markdown", 317 | "id": "8327f700-746f-4d76-b0e2-4b66cd4b44de", 318 | "metadata": {}, 319 | "source": [ 320 | "We can now infer the context `z` for the selected task." 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "id": "5537121c-ce60-44e0-9fdf-5457bcad6c08", 327 | "metadata": {}, 328 | "outputs": [], 329 | "source": [ 330 | "z = model.reward_wr_inference(\n", 331 | " next_obs=batch[\"next_observation\"],\n", 332 | " reward=torch.tensor(rewards, device=model.cfg.device, dtype=torch.float32)\n", 333 | ")\n", 334 | "print(z.shape)\n", 335 | "\n", 336 | "observation, _ = env.reset()\n", 337 | "frames = [env.render()]\n", 338 | "for i in range(30):\n", 339 | " action = model.act(observation, z, mean=True)\n", 340 | " observation, reward, terminated, truncated, info = env.step(action.cpu().numpy().ravel())\n", 341 | " frames.append(env.render())\n", 342 | "\n", 343 | "media.show_video(frames, fps=30)" 344 | ] 345 | }, 346 | { 347 | "cell_type": "markdown", 348 | "id": "6bd2550c-06a5-4353-b3f5-4eeb56c30d70", 349 | "metadata": {}, 350 | "source": [ 351 | "Let's compute the **Q-function** for this policy." 
352 | ] 353 | }, 354 | { 355 | "cell_type": "code", 356 | "execution_count": null, 357 | "id": "3719a3bf-8282-430a-84a6-c5067a4cfcd0", 358 | "metadata": {}, 359 | "outputs": [], 360 | "source": [ 361 | "z_reward = torch.sum(\n", 362 | " model.backward_map(obs=batch[\"next_observation\"]) * torch.tensor(rewards, dtype=torch.float32, device=model.cfg.device),\n", 363 | " dim=0\n", 364 | ")\n", 365 | "z_reward = model.project_z(z_reward)\n", 366 | "Q = Qfunction(batch[\"observation\"], batch[\"action\"], z_reward, z)\n", 367 | "print(Q)" 368 | ] 369 | }, 370 | { 371 | "cell_type": "markdown", 372 | "id": "42121e99-c920-4d69-a499-ec039e0b8e05", 373 | "metadata": {}, 374 | "source": [ 375 | "# Goal and Tracking prompts\n", 376 | "The model supports two other modalities, `goal` and `tracking`. These two modalities expose similar functions for context inference:\n", 377 | "- `def goal_inference(self, next_obs: torch.Tensor) -> torch.Tensor`\n", 378 | "- `def tracking_inference(self, next_obs: torch.Tensor) -> torch.Tensor`\n", 379 | " \n", 380 | "We show an example on how to perform goal inference." 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": null, 386 | "id": "c3e1de2e-d792-446c-ba1f-7b07abf69cac", 387 | "metadata": {}, 388 | "outputs": [], 389 | "source": [ 390 | "goal_qpos = np.array([0.13769039,-0.20029453,0.42305034,0.21707786,0.94573617,0.23868944\n", 391 | ",0.03856998,-1.05566834,-0.12680767,0.11718296,1.89464102,-0.01371153\n", 392 | ",-0.07981451,-0.70497424,-0.0478,-0.05700732,-0.05363342,-0.0657329\n", 393 | ",0.08163511,-1.06263979,0.09788937,-0.22008936,1.85898192,0.08773695\n", 394 | ",0.06200327,-0.3802791,0.07829525,0.06707749,0.14137152,0.08834448\n", 395 | ",-0.07649805,0.78328658,0.12580912,-0.01076061,-0.35937259,-0.13176489\n", 396 | ",0.07497022,-0.2331914,-0.11682692,0.04782308,-0.13571422,0.22827948\n", 397 | ",-0.23456622,-0.12406075,-0.04466465,0.2311667,-0.12232673,-0.25614032\n", 398 | ",-0.36237662,0.11197906,-0.08259534,-0.634934,-0.30822742,-0.93798716\n", 399 | ",0.08848668,0.4083417,-0.30910404,0.40950143,0.30815359,0.03266103\n", 400 | ",1.03959336,-0.19865537,0.25149713,0.3277561,0.16943092,0.69125975\n", 401 | ",0.21721349,-0.30871948,0.88890484,-0.08884043,0.38474549,0.30884107\n", 402 | ",-0.40933304,0.30889523,-0.29562966,-0.6271498])\n", 403 | "env.unwrapped.set_physics(qpos=goal_qpos, qvel=np.zeros(75))\n", 404 | "goal_obs = torch.tensor(env.unwrapped.get_obs()[\"proprio\"].reshape(1,-1), device=model.cfg.device, dtype=torch.float32)\n", 405 | "print(\"goal pose\")\n", 406 | "media.show_image(env.render())" 407 | ] 408 | }, 409 | { 410 | "cell_type": "code", 411 | "execution_count": null, 412 | "id": "24755455-28f9-41ad-b7e6-26f5b10ae13e", 413 | "metadata": {}, 414 | "outputs": [], 415 | "source": [ 416 | "z = model.goal_inference(next_obs=goal_obs)\n", 417 | "\n", 418 | "\n", 419 | "observation, _ = env.reset()\n", 420 | "frames = [env.render()]\n", 421 | "for i in range(30):\n", 422 | " action = model.act(observation, z, mean=True)\n", 423 | " observation, reward, terminated, truncated, info = env.step(action.cpu().numpy().ravel())\n", 424 | " frames.append(env.render())\n", 425 | "\n", 426 | "media.show_video(frames, fps=30)" 427 | ] 428 | } 429 | ], 430 | "metadata": { 431 | "kernelspec": { 432 | "display_name": "Python 3 (ipykernel)", 433 | "language": "python", 434 | "name": "python3" 435 | }, 436 | "language_info": { 437 | "codemirror_mode": { 438 | "name": "ipython", 439 | "version": 3 440 | }, 441 | 
"file_extension": ".py", 442 | "mimetype": "text/x-python", 443 | "name": "python", 444 | "nbconvert_exporter": "python", 445 | "pygments_lexer": "ipython3", 446 | "version": "3.12.2" 447 | } 448 | }, 449 | "nbformat": 4, 450 | "nbformat_minor": 5 451 | } 452 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. 
More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. 
For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. 
Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. 
For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | 409 | -------------------------------------------------------------------------------- /examples/fbcpr_train_humenv.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
2 | # 3 | # This source code is licensed under the CC BY-NC 4.0 license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from __future__ import annotations 7 | 8 | import os 9 | 10 | os.environ["OMP_NUM_THREADS"] = "1" 11 | 12 | import torch 13 | 14 | torch.set_float32_matmul_precision("high") 15 | 16 | import collections 17 | import dataclasses 18 | import json 19 | import numbers 20 | import random 21 | import time 22 | from pathlib import Path 23 | from typing import List 24 | 25 | import gymnasium 26 | import humenv 27 | import numpy as np 28 | import tyro 29 | from gymnasium.wrappers import TimeAwareObservation 30 | from humenv import make_humenv 31 | from humenv.bench import ( 32 | RewardEvaluation, 33 | TrackingEvaluation, 34 | ) 35 | from humenv.misc.motionlib import canonicalize, load_episode_based_h5 36 | from packaging.version import Version 37 | from tqdm import tqdm 38 | 39 | import wandb 40 | from metamotivo.buffers.buffers import DictBuffer, TrajectoryBuffer 41 | from metamotivo.fb_cpr import FBcprAgent, FBcprAgentConfig 42 | from metamotivo.wrappers.humenvbench import RewardWrapper, TrackingWrapper 43 | 44 | if Version(humenv.__version__) < Version("0.1.2"): 45 | raise RuntimeError("This script requires humenv>=0.1.2") 46 | if Version(gymnasium.__version__) < Version("1.0"): 47 | raise RuntimeError("This script requires gymnasium>=1.0") 48 | 49 | 50 | def set_seed_everywhere(seed): 51 | torch.manual_seed(seed) 52 | if torch.cuda.is_available(): 53 | torch.cuda.manual_seed_all(seed) 54 | np.random.seed(seed) 55 | random.seed(seed) 56 | 57 | 58 | def load_expert_trajectories(motions: str | Path, motions_root: str | Path, device: str, sequence_length: int) -> TrajectoryBuffer: 59 | with open(motions, "r") as txtf: 60 | h5files = [el.strip().replace(" ", "") for el in txtf.readlines()] 61 | episodes = [] 62 | for h5 in tqdm(h5files, leave=False): 63 | h5 = canonicalize(h5, base_path=motions_root) 64 | _ep = load_episode_based_h5(h5, keys=None) 65 | for el in _ep: 66 | el["observation"] = el["observation"].astype(np.float32) 67 | del el["file_name"] 68 | episodes.extend(_ep) 69 | buffer = TrajectoryBuffer( 70 | capacity=len(episodes), 71 | seq_length=sequence_length, 72 | device=device, 73 | ) 74 | buffer.extend(episodes) 75 | return buffer 76 | 77 | 78 | @dataclasses.dataclass 79 | class TrainConfig: 80 | seed: int = 0 81 | motions: str = "" 82 | motions_root: str = "" 83 | buffer_size: int = 5_000_000 84 | online_parallel_envs: int = 50 85 | log_every_updates: int = 100_000 86 | work_dir: str | None = None 87 | num_env_steps: int = 30_000_000 88 | update_agent_every: int | None = None 89 | num_seed_steps: int | None = None 90 | num_agent_updates: int | None = None 91 | checkpoint_every_steps: int = 5_000_000 92 | prioritization: bool = False 93 | prioritization_min_val: float = 0.5 94 | prioritization_max_val: float = 5 95 | prioritization_scale: float = 2 96 | 97 | # WANDB 98 | use_wandb: bool = False 99 | wandb_ename: str | None = None 100 | wandb_gname: str | None = None 101 | wandb_pname: str | None = "fbcpr_humenv" 102 | 103 | # misc 104 | compile: bool = False 105 | cudagraphs: bool = False 106 | device: str = "cuda" 107 | buffer_device: str = "cpu" 108 | 109 | # eval 110 | evaluate: bool = False 111 | eval_every_steps: int = 1_000_000 112 | reward_eval_num_envs: int = 5 113 | reward_eval_num_eval_episodes: int = 10 114 | reward_eval_num_inference_samples: int = 50_000 115 | reward_eval_tasks: List[str] | None = None 116 | 117 | 
tracking_eval_num_envs: int = 60 118 | tracking_eval_motions: str | None = None 119 | tracking_eval_motions_root: str | None = None 120 | 121 | def __post_init__(self): 122 | if self.reward_eval_tasks is None: 123 | # this is just a subset of the tasks available in humenv 124 | self.reward_eval_tasks = [ 125 | "move-ego-0-0", 126 | "jump-2", 127 | "move-ego-0-2", 128 | "move-ego-90-2", 129 | "move-ego-180-2", 130 | "rotate-x-5-0.8", 131 | "rotate-y-5-0.8", 132 | "rotate-z-5-0.8" 133 | ] 134 | if self.update_agent_every is None: 135 | self.update_agent_every = 10 * self.online_parallel_envs 136 | if self.num_seed_steps is None: 137 | self.num_seed_steps = 1000 * self.online_parallel_envs 138 | if self.num_agent_updates is None: 139 | self.num_agent_updates = self.online_parallel_envs 140 | if self.prioritization: 141 | # NOTE: when using prioritization train and eval motions must match 142 | self.tracking_eval_motions = self.motions 143 | self.tracking_eval_motions_root = self.motions_root 144 | self.evaluate = True 145 | 146 | 147 | class Workspace: 148 | def __init__(self, cfg: TrainConfig, agent_cfg: FBcprAgentConfig) -> None: 149 | self.cfg = cfg 150 | self.agent_cfg = agent_cfg 151 | if self.cfg.work_dir is None: 152 | import string 153 | 154 | tmp_name = "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) 155 | self.work_dir = Path.cwd() / "tmp_fbcpr" / tmp_name 156 | self.cfg.work_dir = str(self.work_dir) 157 | else: 158 | self.work_dir = self.cfg.work_dir 159 | print(f"Workdir: {self.work_dir}") 160 | self.work_dir = Path(self.work_dir) 161 | self.work_dir.mkdir(exist_ok=True, parents=True) 162 | 163 | set_seed_everywhere(self.cfg.seed) 164 | self.agent = FBcprAgent(**dataclasses.asdict(agent_cfg)) 165 | 166 | if self.cfg.use_wandb: 167 | exp_name = "fbcpr" 168 | wandb_name = exp_name 169 | # fmt: off 170 | wandb_config = dataclasses.asdict(self.cfg) 171 | wandb.init(entity=self.cfg.wandb_ename, project=self.cfg.wandb_pname, 172 | group=self.cfg.wandb_gname, name=wandb_name, # mode="disabled", 173 | config=wandb_config) # type: ignore 174 | # fmt: on 175 | 176 | with (self.work_dir / "config.json").open("w") as f: 177 | json.dump(dataclasses.asdict(self.cfg), f, indent=4) 178 | 179 | self.manager = None 180 | 181 | def train(self): 182 | self.start_time = time.time() 183 | self.train_online() 184 | 185 | def train_online(self) -> None: 186 | print("Loading expert trajectories") 187 | expert_buffer = load_expert_trajectories(self.cfg.motions, self.cfg.motions_root, device=self.cfg.buffer_device, sequence_length=self.agent_cfg.model.seq_length) 188 | 189 | print("Creating the training environment") 190 | train_env, mp_info = make_humenv( 191 | num_envs=self.cfg.online_parallel_envs, 192 | # vectorization_mode="sync", 193 | wrappers=[ 194 | gymnasium.wrappers.FlattenObservation, 195 | lambda env: TimeAwareObservation(env, flatten=False), 196 | ], 197 | render_width=320, 198 | render_height=320, 199 | motions=self.cfg.motions, 200 | motion_base_path=self.cfg.motions_root, 201 | fall_prob=0.2, 202 | state_init="MoCapAndFall", 203 | ) 204 | 205 | print("Allocating buffers") 206 | replay_buffer = { 207 | "train": DictBuffer(capacity=self.cfg.buffer_size, device=self.cfg.buffer_device), 208 | "expert_slicer": expert_buffer, 209 | } 210 | 211 | print("Starting training") 212 | progb = tqdm(total=self.cfg.num_env_steps) 213 | td, info = train_env.reset() 214 | done = np.zeros(self.cfg.online_parallel_envs, dtype=np.bool) 215 | total_metrics, context = None, None 
216 | start_time = time.time() 217 | fps_start_time = time.time() 218 | for t in range(0, self.cfg.num_env_steps, self.cfg.online_parallel_envs): 219 | if self.cfg.evaluate and t % self.cfg.eval_every_steps == 0: 220 | eval_metrics = self.eval(t, replay_buffer=replay_buffer) 221 | if self.cfg.prioritization: 222 | # priorities 223 | index_in_buffer = {} 224 | for i, ep in enumerate(replay_buffer["expert_slicer"].storage): 225 | index_in_buffer[ep["motion_id"][0].item()] = i 226 | motions_id, priorities, idxs = [], [], [] 227 | for _, metr in eval_metrics["tracking"].items(): 228 | motions_id.append(metr["motion_id"]) 229 | priorities.append(metr["emd"]) 230 | idxs.append(index_in_buffer[metr["motion_id"]]) 231 | priorities = ( 232 | torch.clamp( 233 | torch.tensor(priorities, dtype=torch.float32, device=self.agent.device), 234 | min=self.cfg.prioritization_min_val, 235 | max=self.cfg.prioritization_max_val, 236 | ) 237 | * self.cfg.prioritization_scale 238 | ) 239 | bins = torch.floor(priorities) 240 | for i in range(int(bins.min().item()), int(bins.max().item()) + 1): 241 | mask = bins == i 242 | n = mask.sum().item() 243 | if n > 0: 244 | priorities[mask] = 1 / n 245 | 246 | if mp_info is not None: 247 | mp_info["motion_buffer"].update_priorities(motions_id=motions_id, priorities=priorities.cpu().numpy()) 248 | else: 249 | train_env.unwrapped.motion_buffer.update_priorities(motions_id=motions_id, priorities=priorities.cpu().numpy()) 250 | replay_buffer["expert_slicer"].update_priorities( 251 | priorities=priorities.to(self.cfg.buffer_device), idxs=torch.tensor(np.array(idxs), device=self.cfg.buffer_device) 252 | ) 253 | 254 | with torch.no_grad(): 255 | obs = torch.tensor(td["obs"], dtype=torch.float32, device=self.agent.device) 256 | step_count = torch.tensor(td["time"], device=self.agent.device) 257 | context = self.agent.maybe_update_rollout_context(z=context, step_count=step_count) 258 | if t < self.cfg.num_seed_steps: 259 | action = train_env.action_space.sample().astype(np.float32) 260 | else: 261 | # this works in inference mode 262 | action = self.agent.act(obs=obs, z=context, mean=False).cpu().detach().numpy() 263 | new_td, reward, terminated, truncated, new_info = train_env.step(action) 264 | real_next_obs = new_td["obs"].astype(np.float32).copy() 265 | new_done = np.logical_or(terminated.ravel(), truncated.ravel()) 266 | 267 | if Version(gymnasium.__version__) >= Version("1.0"): 268 | # We add only transitions corresponding to environments that have not reset in the previous step. 269 | # For environments that have reset in the previous step, the new observation corresponds to the state after reset. 
270 | indexes = ~done 271 | data = { 272 | "observation": obs[indexes], 273 | "action": action[indexes], 274 | "z": context[indexes], 275 | "step_count": step_count[indexes], 276 | "qpos": info["qpos"][indexes], 277 | "qvel": info["qvel"][indexes], 278 | "next": { 279 | "observation": real_next_obs[indexes], 280 | "terminated": terminated[indexes].reshape(-1, 1), 281 | "truncated": truncated[indexes].reshape(-1, 1), 282 | "reward": reward[indexes].reshape(-1, 1), 283 | "qpos": new_info["qpos"][indexes], 284 | "qvel": new_info["qvel"][indexes], 285 | }, 286 | } 287 | else: 288 | raise NotImplementedError("still some work to do for gymnasium < 1.0") 289 | replay_buffer["train"].extend(data) 290 | 291 | if len(replay_buffer["train"]) > 0 and t > self.cfg.num_seed_steps and t % self.cfg.update_agent_every == 0: 292 | for _ in range(self.cfg.num_agent_updates): 293 | metrics = self.agent.update(replay_buffer, t) 294 | if total_metrics is None: 295 | num_metrics_updates = 1 296 | total_metrics = {k: metrics[k].clone() for k in metrics.keys()} 297 | else: 298 | num_metrics_updates += 1 299 | total_metrics = {k: total_metrics[k] + metrics[k] for k in metrics.keys()} 300 | 301 | if t % self.cfg.log_every_updates == 0 and total_metrics is not None: 302 | m_dict = {} 303 | for k in sorted(list(total_metrics.keys())): 304 | tmp = total_metrics[k] / num_metrics_updates 305 | m_dict[k] = np.round(tmp.mean().item(), 6) 306 | m_dict["duration [minutes]"] = (time.time() - start_time) / 60 307 | m_dict["FPS"] = (1 if t == 0 else self.cfg.log_every_updates) / (time.time() - fps_start_time) 308 | if self.cfg.use_wandb: 309 | wandb.log( 310 | {f"train/{k}": v for k, v in m_dict.items()}, 311 | step=t, 312 | ) 313 | print(m_dict) 314 | total_metrics = None 315 | fps_start_time = time.time() 316 | 317 | if t % self.cfg.checkpoint_every_steps == 0: 318 | self.agent.save(str(self.work_dir / "checkpoint")) 319 | progb.update(self.cfg.online_parallel_envs) 320 | td = new_td 321 | done = new_done 322 | info = new_info 323 | self.agent.save(str(self.work_dir / "checkpoint")) 324 | if mp_info is not None: 325 | mp_info["manager"].shutdown() 326 | 327 | def eval(self, t, replay_buffer): 328 | print(f"Starting evaluation at time {t}") 329 | inference_function: str = "reward_wr_inference" 330 | 331 | self.agent._model.to("cpu") 332 | self.agent._model.train(False) 333 | 334 | # --------------------------------------------------------------- 335 | # Reward evaluation 336 | # --------------------------------------------------------------- 337 | eval_agent = RewardWrapper( 338 | model=self.agent._model, 339 | inference_dataset=replay_buffer["train"], 340 | num_samples_per_inference=self.cfg.reward_eval_num_inference_samples, 341 | inference_function=inference_function, 342 | max_workers=1, 343 | process_executor=False, 344 | ) 345 | reward_eval = RewardEvaluation( 346 | tasks=self.cfg.reward_eval_tasks, 347 | env_kwargs={"state_init": "Fall", "context": "spawn"}, 348 | num_contexts=1, 349 | num_envs=self.cfg.reward_eval_num_envs, 350 | num_episodes=self.cfg.reward_eval_num_eval_episodes, 351 | ) 352 | start_t = time.time() 353 | reward_metrics = {} 354 | if not replay_buffer["train"].empty(): 355 | print(f"Reward started at {time.ctime(start_t)}", flush=True) 356 | reward_metrics = reward_eval.run(agent=eval_agent) 357 | duration = time.time() - start_t 358 | print(f"Reward eval time: {duration}") 359 | if self.cfg.use_wandb: 360 | m_dict = {} 361 | avg_return = [] 362 | for task in reward_metrics.keys(): 363 | 
m_dict[f"{task}/return"] = np.mean(reward_metrics[task]["reward"]) 364 | m_dict[f"{task}/return#std"] = np.std(reward_metrics[task]["reward"]) 365 | avg_return.append(reward_metrics[task]["reward"]) 366 | m_dict["reward/return"] = np.mean(avg_return) 367 | m_dict["reward/return#std"] = np.std(avg_return) 368 | m_dict["reward/time"] = duration 369 | wandb.log( 370 | {f"eval/reward/{k}": v for k, v in m_dict.items()}, 371 | step=t, 372 | ) 373 | # --------------------------------------------------------------- 374 | # Tracking evaluation 375 | # --------------------------------------------------------------- 376 | eval_agent = TrackingWrapper(model=self.agent._model) 377 | tracking_eval = TrackingEvaluation( 378 | motions=self.cfg.tracking_eval_motions, 379 | motion_base_path=self.cfg.tracking_eval_motions_root, 380 | env_kwargs={ 381 | "state_init": "Default", 382 | }, 383 | num_envs=self.cfg.tracking_eval_num_envs, 384 | ) 385 | start_t = time.time() 386 | print(f"Tracking started at {time.ctime(start_t)}", flush=True) 387 | tracking_metrics = tracking_eval.run(agent=eval_agent) 388 | duration = time.time() - start_t 389 | print(f"Tracking eval time: {duration}") 390 | if self.cfg.use_wandb: 391 | aggregate, m_dict = collections.defaultdict(list), {} 392 | for _, metr in tracking_metrics.items(): 393 | for k, v in metr.items(): 394 | if isinstance(v, numbers.Number): 395 | aggregate[k].append(v) 396 | for k, v in aggregate.items(): 397 | m_dict[k] = np.mean(v) 398 | m_dict[f"{k}#std"] = np.std(v) 399 | m_dict["time"] = duration 400 | 401 | wandb.log( 402 | {f"eval/tracking/{k}": v for k, v in m_dict.items()}, 403 | step=t, 404 | ) 405 | # --------------------------------------------------------------- 406 | # this is important, move back the agent to cuda and 407 | # restart the training 408 | self.agent._model.to("cuda") 409 | self.agent._model.train() 410 | 411 | return {"reward": reward_metrics, "tracking": tracking_metrics} 412 | 413 | 414 | if __name__ == "__main__": 415 | config = tyro.cli(TrainConfig) 416 | 417 | env, _ = make_humenv( 418 | num_envs=1, 419 | vectorization_mode="sync", 420 | wrappers=[gymnasium.wrappers.FlattenObservation], 421 | render_width=320, 422 | render_height=320, 423 | ) 424 | 425 | agent_config = FBcprAgentConfig() 426 | agent_config.model.obs_dim = env.observation_space.shape[0] 427 | agent_config.model.action_dim = env.action_space.shape[0] 428 | agent_config.model.norm_obs = True 429 | agent_config.train.batch_size = 1024 430 | agent_config.train.use_mix_rollout = 1 431 | agent_config.train.update_z_every_step = 150 432 | agent_config.model.actor_std = 0.2 433 | agent_config.model.seq_length = 8 434 | # archi 435 | # the config of the model trained in the paper 436 | model, hidden_dim, hidden_layers = "simple", 1024, 2 437 | # uncomment the line below for the config of model deployed in the demo 438 | # WARNING: you need to use compile=True on a A100 GPU or better, as otherwise training can be very slow 439 | # model, hidden_dim, hidden_layers = "residual", 2048, 12 440 | agent_config.model.archi.z_dim = 256 441 | agent_config.model.archi.b.norm = 1 442 | agent_config.model.archi.norm_z = 1 443 | agent_config.model.archi.f.hidden_dim = hidden_dim 444 | agent_config.model.archi.b.hidden_dim = 256 445 | agent_config.model.archi.actor.hidden_dim = hidden_dim 446 | agent_config.model.archi.critic.hidden_dim = hidden_dim 447 | agent_config.model.archi.f.hidden_layers = hidden_layers 448 | agent_config.model.archi.b.hidden_layers = 1 449 | 
agent_config.model.archi.actor.hidden_layers = hidden_layers 450 | agent_config.model.archi.critic.hidden_layers = hidden_layers 451 | agent_config.model.archi.f.model = model 452 | agent_config.model.archi.actor.model = model 453 | agent_config.model.archi.critic.model = model 454 | # optim 455 | agent_config.train.lr_f = 1e-4 456 | agent_config.train.lr_b = 1e-5 457 | agent_config.train.lr_actor = 1e-4 458 | agent_config.train.lr_critic = 1e-4 459 | agent_config.train.ortho_coef = 100 460 | agent_config.train.train_goal_ratio = 0.2 461 | agent_config.train.expert_asm_ratio = 0.6 462 | agent_config.train.relabel_ratio = 0.8 463 | agent_config.train.reg_coeff = 0.01 464 | agent_config.train.q_loss_coef = 0.1 # or 0 465 | # discriminator cfg 466 | agent_config.train.grad_penalty_discriminator = 10 467 | agent_config.train.weight_decay_discriminator = 0 468 | agent_config.train.lr_discriminator = 1e-5 469 | agent_config.model.archi.discriminator.hidden_layers = 3 470 | agent_config.model.archi.discriminator.hidden_dim = 1024 471 | agent_config.model.device = config.device 472 | # misc 473 | agent_config.train.discount = 0.98 474 | agent_config.compile = config.compile 475 | agent_config.cudagraphs = config.cudagraphs 476 | env.close() 477 | 478 | ws = Workspace(config, agent_cfg=agent_config) 479 | ws.train() 480 | -------------------------------------------------------------------------------- /metamotivo/nn_models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # 3 | # This source code is licensed under the CC BY-NC 4.0 license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import torch 7 | from torch import nn 8 | from torch import distributions as pyd 9 | from torch.distributions.utils import _standard_normal 10 | import numpy as np 11 | import torch.nn.functional as F 12 | import numbers 13 | import math 14 | from typing import Any 15 | 16 | 17 | ########################## 18 | # Initialization utils 19 | ########################## 20 | 21 | # Initialization for parallel layers 22 | def parallel_orthogonal_(tensor, gain=1): 23 | if tensor.ndimension() == 2: 24 | tensor = nn.init.orthogonal_(tensor, gain=gain) 25 | return tensor 26 | if tensor.ndimension() < 3: 27 | raise ValueError("Only tensors with 3 or more dimensions are supported") 28 | n_parallel = tensor.size(0) 29 | rows = tensor.size(1) 30 | cols = tensor.numel() // n_parallel // rows 31 | flattened = tensor.new(n_parallel, rows, cols).normal_(0, 1) 32 | 33 | qs = [] 34 | for flat_tensor in torch.unbind(flattened, dim=0): 35 | if rows < cols: 36 | flat_tensor.t_() 37 | 38 | # Compute the qr factorization 39 | q, r = torch.linalg.qr(flat_tensor) 40 | # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf 41 | d = torch.diag(r, 0) 42 | ph = d.sign() 43 | q *= ph 44 | 45 | if rows < cols: 46 | q.t_() 47 | qs.append(q) 48 | 49 | qs = torch.stack(qs, dim=0) 50 | with torch.no_grad(): 51 | tensor.view_as(qs).copy_(qs) 52 | tensor.mul_(gain) 53 | return tensor 54 | 55 | def weight_init(m): 56 | if isinstance(m, nn.Linear): 57 | nn.init.orthogonal_(m.weight.data) 58 | if hasattr(m.bias, "data"): 59 | m.bias.data.fill_(0.0) 60 | elif isinstance(m, DenseParallel): 61 | gain = nn.init.calculate_gain("relu") 62 | parallel_orthogonal_(m.weight.data, gain) 63 | if hasattr(m.bias, "data"): 64 | m.bias.data.fill_(0.0) 65 | elif hasattr(m, "reset_parameters"): 66 | m.reset_parameters() 67 | 68 | 
69 | ########################## 70 | # Update utils 71 | ########################## 72 | 73 | def _soft_update_params(net_params: Any, target_net_params: Any, tau: float): 74 | torch._foreach_mul_(target_net_params, 1 - tau) 75 | torch._foreach_add_(target_net_params, net_params, alpha=tau) 76 | 77 | def soft_update_params(net, target_net, tau) -> None: 78 | tau = float(min(max(tau, 0), 1)) 79 | net_params = tuple(x.data for x in net.parameters()) 80 | target_net_params = tuple(x.data for x in target_net.parameters()) 81 | _soft_update_params(net_params, target_net_params, tau) 82 | 83 | class eval_mode: 84 | def __init__(self, *models) -> None: 85 | self.models = models 86 | self.prev_states = [] 87 | 88 | def __enter__(self) -> None: 89 | self.prev_states = [] 90 | for model in self.models: 91 | self.prev_states.append(model.training) 92 | model.train(False) 93 | 94 | def __exit__(self, *args) -> None: 95 | for model, state in zip(self.models, self.prev_states): 96 | model.train(state) 97 | 98 | 99 | ########################## 100 | # Creation utils 101 | ########################## 102 | 103 | def build_backward(obs_dim, z_dim, cfg): 104 | return BackwardMap(obs_dim, z_dim, cfg.hidden_dim, cfg.hidden_layers, cfg.norm) 105 | 106 | def build_forward(obs_dim, z_dim, action_dim, cfg, output_dim=None): 107 | if cfg.ensemble_mode == "seq": 108 | return SequetialFMap(obs_dim, z_dim, action_dim, cfg) 109 | elif cfg.ensemble_mode == "vmap": 110 | raise NotImplementedError("vmap ensemble mode is currently not supported") 111 | 112 | assert cfg.ensemble_mode == "batch", "Invalid value for ensemble_mode. Use {'batch', 'seq', 'vmap'}" 113 | return _build_batch_forward(obs_dim, z_dim, action_dim, cfg, output_dim) 114 | 115 | def _build_batch_forward(obs_dim, z_dim, action_dim, cfg, output_dim=None, parallel=True): 116 | if cfg.model == "residual": 117 | forward_cls = ResidualForwardMap 118 | elif cfg.model == "simple": 119 | forward_cls = ForwardMap 120 | else: 121 | raise ValueError(f"Unsupported forward_map model {cfg.model}") 122 | num_parallel = cfg.num_parallel if parallel else 1 123 | return forward_cls(obs_dim, z_dim, action_dim, cfg.hidden_dim, cfg.hidden_layers, cfg.embedding_layers, num_parallel, output_dim) 124 | 125 | def build_actor(obs_dim, z_dim, action_dim, cfg): 126 | if cfg.model == "residual": 127 | actor_cls = ResidualActor 128 | elif cfg.model == "simple": 129 | actor_cls = Actor 130 | else: 131 | raise ValueError(f"Unsupported actor model {cfg.model}") 132 | return actor_cls(obs_dim, z_dim, action_dim, cfg.hidden_dim, cfg.hidden_layers, cfg.embedding_layers) 133 | 134 | def build_discriminator(obs_dim, z_dim, cfg): 135 | return Discriminator(obs_dim, z_dim, cfg.hidden_dim, cfg.hidden_layers) 136 | 137 | def linear(input_dim, output_dim, num_parallel=1): 138 | if num_parallel > 1: 139 | return DenseParallel(input_dim, output_dim, n_parallel=num_parallel) 140 | return nn.Linear(input_dim, output_dim) 141 | 142 | def layernorm(input_dim, num_parallel=1): 143 | if num_parallel > 1: 144 | return ParallelLayerNorm([input_dim], n_parallel=num_parallel) 145 | return nn.LayerNorm(input_dim) 146 | 147 | 148 | ########################## 149 | # Simple MLP models 150 | ########################## 151 | 152 | class BackwardMap(nn.Module): 153 | def __init__(self, goal_dim, z_dim, hidden_dim, hidden_layers: int = 2, norm=True) -> None: 154 | super().__init__() 155 | seq = [nn.Linear(goal_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh()] 156 | for _ in range(hidden_layers-1): 157 | seq += 
[nn.Linear(hidden_dim, hidden_dim), nn.ReLU()] 158 | seq += [nn.Linear(hidden_dim, z_dim)] 159 | if norm: 160 | seq += [Norm()] 161 | self.net = nn.Sequential(*seq) 162 | 163 | def forward(self, x): 164 | return self.net(x) 165 | 166 | 167 | def simple_embedding(input_dim, hidden_dim, hidden_layers, num_parallel=1): 168 | assert hidden_layers >= 2, "must have at least 2 embedding layers" 169 | seq = [linear(input_dim, hidden_dim, num_parallel), layernorm(hidden_dim, num_parallel), nn.Tanh()] 170 | for _ in range(hidden_layers - 2): 171 | seq += [linear(hidden_dim, hidden_dim, num_parallel), nn.ReLU()] 172 | seq += [linear(hidden_dim, hidden_dim // 2, num_parallel), nn.ReLU()] 173 | return nn.Sequential(*seq) 174 | 175 | 176 | class ForwardMap(nn.Module): 177 | def __init__(self, obs_dim, z_dim, action_dim, hidden_dim, hidden_layers: int = 1, 178 | embedding_layers: int = 2, num_parallel: int = 2, output_dim=None) -> None: 179 | super().__init__() 180 | self.z_dim = z_dim 181 | self.num_parallel = num_parallel 182 | self.hidden_dim = hidden_dim 183 | 184 | self.embed_z = simple_embedding(obs_dim + z_dim, hidden_dim, embedding_layers, num_parallel) 185 | self.embed_sa = simple_embedding(obs_dim + action_dim, hidden_dim, embedding_layers, num_parallel) 186 | 187 | seq = [] 188 | for _ in range(hidden_layers): 189 | seq += [linear(hidden_dim, hidden_dim, num_parallel), nn.ReLU()] 190 | seq += [linear(hidden_dim, output_dim if output_dim else z_dim, num_parallel)] 191 | self.Fs = nn.Sequential(*seq) 192 | 193 | def forward(self, obs: torch.Tensor, z: torch.Tensor, action: torch.Tensor): 194 | if self.num_parallel > 1: 195 | obs = obs.expand(self.num_parallel, -1, -1) 196 | z = z.expand(self.num_parallel, -1, -1) 197 | action = action.expand(self.num_parallel, -1, -1) 198 | z_embedding = self.embed_z(torch.cat([obs, z], dim=-1)) # num_parallel x bs x h_dim // 2 199 | sa_embedding = self.embed_sa(torch.cat([obs, action], dim=-1)) # num_parallel x bs x h_dim // 2 200 | return self.Fs(torch.cat([sa_embedding, z_embedding], dim=-1)) 201 | 202 | 203 | class SequetialFMap(nn.Module): 204 | def __init__(self, obs_dim, z_dim, action_dim, cfg, output_dim=None): 205 | super().__init__() 206 | self.models = nn.ModuleList([_build_batch_forward(obs_dim, z_dim, action_dim, 207 | cfg, output_dim, parallel=False) for _ in range(cfg.num_parallel)]) 208 | 209 | def forward(self, obs: torch.Tensor, z: torch.Tensor, action: torch.Tensor) -> torch.Tensor: 210 | predictions = [model(obs, z, action) for model in self.models] 211 | return torch.stack(predictions) 212 | 213 | 214 | class Actor(nn.Module): 215 | def __init__(self, obs_dim, z_dim, action_dim, hidden_dim, hidden_layers: int = 1, 216 | embedding_layers: int = 2) -> None: 217 | super().__init__() 218 | 219 | self.embed_z = simple_embedding(obs_dim + z_dim, hidden_dim, embedding_layers) 220 | self.embed_s = simple_embedding(obs_dim, hidden_dim, embedding_layers) 221 | 222 | seq = [] 223 | for _ in range(hidden_layers): 224 | seq += [linear(hidden_dim, hidden_dim), nn.ReLU()] 225 | seq += [linear(hidden_dim, action_dim)] 226 | self.policy = nn.Sequential(*seq) 227 | 228 | def forward(self, obs, z, std): 229 | z_embedding = self.embed_z(torch.cat([obs, z], dim=-1)) # bs x h_dim // 2 230 | s_embedding = self.embed_s(obs) # bs x h_dim // 2 231 | embedding = torch.cat([s_embedding, z_embedding], dim=-1) 232 | mu = torch.tanh(self.policy(embedding)) 233 | std = torch.ones_like(mu) * std 234 | dist = TruncatedNormal(mu, std) 235 | return dist 236 | 237 | 238 | class 
Discriminator(nn.Module): 239 | def __init__(self, obs_dim, z_dim, hidden_dim, hidden_layers) -> None: 240 | super().__init__() 241 | seq = [nn.Linear(obs_dim + z_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.Tanh()] 242 | for _ in range(hidden_layers-1): 243 | seq += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU()] 244 | seq += [nn.Linear(hidden_dim, 1)] 245 | self.trunk = nn.Sequential(*seq) 246 | 247 | def forward(self, obs: torch.Tensor, z: torch.Tensor) -> torch.Tensor: 248 | s = self.compute_logits(obs, z) 249 | return torch.sigmoid(s) 250 | 251 | def compute_logits(self, obs: torch.Tensor, z: torch.Tensor) -> torch.Tensor: 252 | x = torch.cat([z, obs], dim=1) 253 | logits = self.trunk(x) 254 | return logits 255 | 256 | def compute_reward(self, obs: torch.Tensor, z: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: 257 | s = self.forward(obs, z) 258 | s = torch.clamp(s, eps, 1 - eps) 259 | reward = s.log() - (1 - s).log() 260 | return reward 261 | 262 | 263 | ########################## 264 | # Residual models 265 | ########################## 266 | 267 | class ResidualBlock(nn.Module): 268 | def __init__(self, dim, num_parallel: int = 1): 269 | super().__init__() 270 | ln = layernorm(dim, num_parallel) 271 | lin = linear(dim, dim, num_parallel) 272 | self.mlp = nn.Sequential(ln, lin, nn.Mish()) 273 | 274 | def forward(self, x): 275 | return x + self.mlp(x) 276 | 277 | 278 | class Block(nn.Module): 279 | def __init__(self, input_dim, output_dim, activation, num_parallel: int = 1): 280 | super().__init__() 281 | ln = layernorm(input_dim, num_parallel) 282 | lin = linear(input_dim, output_dim, num_parallel) 283 | seq = [ln, lin] + ([nn.Mish()] if activation else []) 284 | self.mlp = nn.Sequential(*seq) 285 | 286 | def forward(self, x): 287 | return self.mlp(x) 288 | 289 | 290 | def residual_embedding(input_dim, hidden_dim, hidden_layers, num_parallel=1): 291 | assert hidden_layers >= 2, "must have at least 2 embedding layers" 292 | seq = [Block(input_dim, hidden_dim, True, num_parallel)] 293 | for _ in range(hidden_layers-2): 294 | seq += [ResidualBlock(hidden_dim, num_parallel)] 295 | seq += [Block(hidden_dim, hidden_dim // 2, True, num_parallel)] 296 | return nn.Sequential(*seq) 297 | 298 | 299 | class ResidualForwardMap(nn.Module): 300 | def __init__(self, obs_dim, z_dim, action_dim, hidden_dim, hidden_layers: int = 1, 301 | embedding_layers: int = 2, num_parallel: int = 2, output_dim=None) -> None: 302 | super().__init__() 303 | self.z_dim = z_dim 304 | self.num_parallel = num_parallel 305 | self.hidden_dim = hidden_dim 306 | 307 | self.embed_z = residual_embedding(obs_dim + z_dim, hidden_dim, embedding_layers, num_parallel) 308 | self.embed_sa = residual_embedding(obs_dim + action_dim, hidden_dim, embedding_layers, num_parallel) 309 | 310 | seq = [ResidualBlock(hidden_dim, num_parallel) for _ in range(hidden_layers)] 311 | seq += [Block(hidden_dim, output_dim if output_dim else z_dim, False, num_parallel)] 312 | self.Fs = nn.Sequential(*seq) 313 | 314 | def forward(self, obs: torch.Tensor, z: torch.Tensor, action: torch.Tensor): 315 | if self.num_parallel > 1: 316 | obs = obs.expand(self.num_parallel, -1, -1) 317 | z = z.expand(self.num_parallel, -1, -1) 318 | action = action.expand(self.num_parallel, -1, -1) 319 | z_embedding = self.embed_z(torch.cat([obs, z], dim=-1)) # num_parallel x bs x h_dim // 2 320 | sa_embedding = self.embed_sa(torch.cat([obs, action], dim=-1)) # num_parallel x bs x h_dim // 2 321 | return self.Fs(torch.cat([sa_embedding, z_embedding], dim=-1)) 322 | 323 | 
324 | class ResidualActor(nn.Module): 325 | def __init__(self, obs_dim, z_dim, action_dim, hidden_dim, hidden_layers: int = 1, 326 | embedding_layers: int = 2) -> None: 327 | super().__init__() 328 | 329 | self.embed_z = residual_embedding(obs_dim + z_dim, hidden_dim, embedding_layers) 330 | self.embed_s = residual_embedding(obs_dim, hidden_dim, embedding_layers) 331 | 332 | seq = [ResidualBlock(hidden_dim) for _ in range(hidden_layers)] + [Block(hidden_dim, action_dim, False)] 333 | self.policy = nn.Sequential(*seq) 334 | 335 | def forward(self, obs, z, std): 336 | z_embedding = self.embed_z(torch.cat([obs, z], dim=-1)) # bs x h_dim // 2 337 | s_embedding = self.embed_s(obs) # bs x h_dim // 2 338 | embedding = torch.cat([s_embedding, z_embedding], dim=-1) 339 | mu = torch.tanh(self.policy(embedding)) 340 | std = torch.ones_like(mu) * std 341 | dist = TruncatedNormal(mu, std) 342 | return dist 343 | 344 | 345 | ########################## 346 | # Helper modules 347 | ########################## 348 | 349 | class DenseParallel(nn.Module): 350 | def __init__( 351 | self, 352 | in_features: int, 353 | out_features: int, 354 | n_parallel: int, 355 | bias: bool = True, 356 | device=None, 357 | dtype=None, 358 | reset_params=True, 359 | ) -> None: 360 | factory_kwargs = {"device": device, "dtype": dtype} 361 | super(DenseParallel, self).__init__() 362 | self.in_features = in_features 363 | self.out_features = out_features 364 | self.n_parallel = n_parallel 365 | if n_parallel is None or (n_parallel == 1): 366 | self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs)) 367 | if bias: 368 | self.bias = nn.Parameter(torch.empty(out_features, **factory_kwargs)) 369 | else: 370 | self.register_parameter("bias", None) 371 | else: 372 | self.weight = nn.Parameter( 373 | torch.empty((n_parallel, in_features, out_features), **factory_kwargs) 374 | ) 375 | if bias: 376 | self.bias = nn.Parameter( 377 | torch.empty((n_parallel, 1, out_features), **factory_kwargs) 378 | ) 379 | else: 380 | self.register_parameter("bias", None) 381 | if self.bias is None: 382 | raise NotImplementedError 383 | if reset_params: 384 | self.reset_parameters() 385 | 386 | def load_module_list_weights(self, module_list) -> None: 387 | with torch.no_grad(): 388 | assert len(module_list) == self.n_parallel 389 | weight_list = [m.weight.T for m in module_list] 390 | target_weight = torch.stack(weight_list, dim=0) 391 | self.weight.data.copy_(target_weight.data) 392 | if self.bias: 393 | bias_list = [ln.bias.unsqueeze(0) for ln in module_list] 394 | target_bias = torch.stack(bias_list, dim=0) 395 | self.bias.data.copy_(target_bias.data) 396 | 397 | # TODO why do these layers have their own reset scheme? 
398 |     def reset_parameters(self) -> None:
399 |         nn.init.kaiming_uniform_(self.weight, a=np.sqrt(5))
400 |         if self.bias is not None:
401 |             fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
402 |             bound = 1 / np.sqrt(fan_in) if fan_in > 0 else 0
403 |             nn.init.uniform_(self.bias, -bound, bound)
404 | 
405 |     def forward(self, input):
406 |         if self.n_parallel is None or (self.n_parallel == 1):
407 |             return F.linear(input, self.weight, self.bias)
408 |         else:
409 |             return torch.baddbmm(self.bias, input, self.weight)
410 | 
411 |     def extra_repr(self) -> str:
412 |         return "in_features={}, out_features={}, n_parallel={}, bias={}".format(
413 |             self.in_features, self.out_features, self.n_parallel, self.bias is not None
414 |         )
415 | 
416 | 
417 | class ParallelLayerNorm(nn.Module):
418 |     def __init__(self, normalized_shape, n_parallel, eps=1e-5, elementwise_affine=True,
419 |                  device=None, dtype=None) -> None:
420 |         factory_kwargs = {'device': device, 'dtype': dtype}
421 |         super(ParallelLayerNorm, self).__init__()
422 |         if isinstance(normalized_shape, numbers.Integral):
423 |             normalized_shape = [normalized_shape, ]
424 |         assert len(normalized_shape) == 1
425 |         self.n_parallel = n_parallel
426 |         self.normalized_shape = list(normalized_shape)
427 |         self.eps = eps
428 |         self.elementwise_affine = elementwise_affine
429 |         if self.elementwise_affine:
430 |             if n_parallel is None or (n_parallel == 1):
431 |                 self.weight = nn.Parameter(torch.empty([*self.normalized_shape], **factory_kwargs))
432 |                 self.bias = nn.Parameter(torch.empty([*self.normalized_shape], **factory_kwargs))
433 |             else:
434 |                 self.weight = nn.Parameter(torch.empty([n_parallel, 1, *self.normalized_shape], **factory_kwargs))
435 |                 self.bias = nn.Parameter(torch.empty([n_parallel, 1, *self.normalized_shape], **factory_kwargs))
436 |         else:
437 |             self.register_parameter('weight', None)
438 |             self.register_parameter('bias', None)
439 | 
440 |         self.reset_parameters()
441 | 
442 |     def reset_parameters(self) -> None:
443 |         if self.elementwise_affine:
444 |             nn.init.ones_(self.weight)
445 |             nn.init.zeros_(self.bias)
446 | 
447 |     def load_module_list_weights(self, module_list) -> None:
448 |         with torch.no_grad():
449 |             assert len(module_list) == self.n_parallel
450 |             if self.elementwise_affine:
451 |                 ln_weights = [ln.weight.unsqueeze(0) for ln in module_list]
452 |                 ln_biases = [ln.bias.unsqueeze(0) for ln in module_list]
453 |                 target_ln_weights = torch.stack(ln_weights, dim=0)
454 |                 target_ln_bias = torch.stack(ln_biases, dim=0)
455 |                 self.weight.data.copy_(target_ln_weights.data)
456 |                 self.bias.data.copy_(target_ln_bias.data)
457 | 
458 | 
459 |     def forward(self, input):
460 |         norm_input = F.layer_norm(
461 |             input, self.normalized_shape, None, None, self.eps)
462 |         if self.elementwise_affine:
463 |             return (norm_input * self.weight) + self.bias
464 |         else:
465 |             return norm_input
466 | 
467 |     def extra_repr(self) -> str:
468 |         return '{normalized_shape}, eps={eps}, ' \
469 |             'elementwise_affine={elementwise_affine}'.format(**self.__dict__)
470 | 
471 | 
472 | class TruncatedNormal(pyd.Normal):
473 |     def __init__(self, loc, scale, low=-1.0, high=1.0, eps=1e-6) -> None:
474 |         super().__init__(loc, scale, validate_args=False)
475 |         self.low = low
476 |         self.high = high
477 |         self.eps = eps
478 |         self.noise_upper_limit = high - self.loc
479 |         self.noise_lower_limit = low - self.loc
480 | 
481 |     def _clamp(self, x) -> torch.Tensor:
482 |         clamped_x = torch.clamp(x, self.low + self.eps, self.high - self.eps)
483 |         x = x - x.detach() + clamped_x.detach()
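        # (Editor's note, not in the original file.) Straight-through clamp: the value returned
        # below equals clamped_x, while the gradient still flows through the unclamped x.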
484 |         return x
485 | 
486 |     def sample(self, clip=None, sample_shape=torch.Size()) -> torch.Tensor:  # type: ignore
487 |         shape = self._extended_shape(sample_shape)
488 |         eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
489 |         eps *= self.scale
490 |         if clip is not None:
491 |             eps = torch.clamp(eps, -clip, clip)
492 |         x = self.loc + eps
493 |         return self._clamp(x)
494 | 
495 | 
496 | class Norm(nn.Module):
497 | 
498 |     def __init__(self) -> None:
499 |         super().__init__()
500 | 
501 |     def forward(self, x) -> torch.Tensor:
502 |         return math.sqrt(x.shape[-1]) * F.normalize(x, dim=-1)
503 | 
--------------------------------------------------------------------------------
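Editor's addendum (not part of the repository): a minimal, hedged usage sketch of DenseParallel.load_module_list_weights, showing how a single batched layer can reproduce a list of independent nn.Linear modules. It assumes the metamotivo package above is installed; all sizes are illustrative.

import torch
import torch.nn as nn
from metamotivo.nn_models import DenseParallel

n_parallel, in_features, out_features, batch = 2, 8, 4, 32
members = nn.ModuleList([nn.Linear(in_features, out_features) for _ in range(n_parallel)])

parallel = DenseParallel(in_features, out_features, n_parallel=n_parallel)
parallel.load_module_list_weights(members)

x = torch.randn(batch, in_features)
# the parallel path expects an input of shape (n_parallel, batch, in_features)
y_parallel = parallel(x.expand(n_parallel, -1, -1))
y_members = torch.stack([m(x) for m in members], dim=0)
print(torch.allclose(y_parallel, y_members, atol=1e-6))  # expected: True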