├── maze_envs
│   ├── __init__.py
│   ├── assets
│   │   └── point_mass.xml
│   ├── base.py
│   └── maze.py
├── conda_env.yml
├── agent
│   ├── __init__.py
│   ├── critic.py
│   ├── actor.py
│   └── sac.py
├── config
│   ├── agent
│   │   └── sac.yaml
│   └── imitate.yaml
├── video.py
├── CONTRIBUTING.md
├── README.md
├── utils.py
├── CODE_OF_CONDUCT.md
├── replay_buffer.py
├── logger.py
├── imitate.py
└── LICENSE

--------------------------------------------------------------------------------
/maze_envs/__init__.py:
--------------------------------------------------------------------------------
1 | from .maze import MazeEnd_PointMass
2 | 
--------------------------------------------------------------------------------
/conda_env.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | 
3 | name: gwil
4 | channels:
5 |   - defaults
6 | dependencies:
7 |   - python=3.6
8 |   - pytorch
9 |   - cudatoolkit=9.2
10 |   - absl-py
11 |   - pyparsing
12 |   - pip:
13 |     - termcolor
14 |     - git+git://github.com/deepmind/dm_control.git
15 |     - git+git://github.com/denisyarats/dmc2gym.git
16 |     - tb-nightly
17 |     - imageio
18 |     - imageio-ffmpeg
19 |     - hydra-core=1.1.0
20 |     - POT
21 |     - scipy
22 | 
--------------------------------------------------------------------------------
/agent/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | 
3 | import abc
4 | 
5 | 
6 | class Agent(object):
7 |     def reset(self):
8 |         """For stateful agents this function performs resetting at the beginning of each episode."""
9 |         pass
10 | 
11 |     @abc.abstractmethod
12 |     def train(self, training=True):
13 |         """Sets the agent in either training or evaluation mode."""
14 | 
15 |     @abc.abstractmethod
16 |     def update(self, replay_buffer, logger, step):
17 |         """Main function of the agent that performs learning."""
18 | 
19 |     @abc.abstractmethod
20 |     def act(self, obs, sample=False):
21 |         """Issues an action given an observation."""
22 | 
--------------------------------------------------------------------------------
/config/agent/sac.yaml:
--------------------------------------------------------------------------------
1 | # @package _global_
2 | 
3 | agent:
4 |   _target_: agent.sac.SACAgent
5 |   obs_dim: ??? # to be specified later
6 |   action_dim: ??? # to be specified later
7 |   action_range: ??? # to be specified later
8 |   device: ${device}
9 |   critic: ${double_q_critic}
10 |   actor: ${diag_gaussian_actor}
11 |   discount: 0.99
12 |   init_temperature: 0.1
13 |   alpha_lr: 1e-4
14 |   alpha_betas: [0.9, 0.999]
15 |   actor_lr: 1e-4
16 |   actor_betas: [0.9, 0.999]
17 |   actor_update_frequency: 1
18 |   critic_lr: 1e-4
19 |   critic_betas: [0.9, 0.999]
20 |   critic_tau: 0.005
21 |   critic_target_update_frequency: 2
22 |   batch_size: 1024
23 |   learnable_temperature: true
24 | 
25 | double_q_critic:
26 |   _target_: agent.critic.DoubleQCritic
27 |   obs_dim: ${agent.obs_dim}
28 |   action_dim: ${agent.action_dim}
29 |   hidden_dim: 1024
30 |   hidden_depth: 2
31 | 
32 | diag_gaussian_actor:
33 |   _target_: agent.actor.DiagGaussianActor
34 |   obs_dim: ${agent.obs_dim}
35 |   action_dim: ${agent.action_dim}
36 |   hidden_depth: 2
37 |   hidden_dim: 1024
38 |   log_std_bounds: [-5, 2]
39 | 
--------------------------------------------------------------------------------
/video.py:
--------------------------------------------------------------------------------
1 | ### Code for Gromov-Wasserstein Imitation Learning, Arnaud Fickinger, 2022
2 | # Copyright (c) Meta Platforms, Inc. and affiliates.
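# Illustrative usage sketch for the recorder defined below (assumes `env` is a
# dmc2gym-style environment whose render() accepts mode/height/width/camera_id,
# and that `work_dir` already exists); this mirrors how imitate.py drives it
# during evaluation:
#
#   recorder = VideoRecorder(work_dir)        # frames are written under work_dir/video
#   recorder.init(enabled=True)               # clear the frame list and enable capture
#   obs, done = env.reset(), False
#   while not done:
#       obs, reward, done, _ = env.step(env.action_space.sample())
#       recorder.record(env)                  # grab one rgb_array frame
#   recorder.save('episode_0.mp4')            # encode the clip with imageio.mimsave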
3 | 4 | import imageio 5 | import os 6 | import numpy as np 7 | import sys 8 | 9 | import utils 10 | 11 | class VideoRecorder(object): 12 | def __init__(self, root_dir, height=256, width=256, camera_id=0, fps=30): 13 | self.save_dir = utils.make_dir(root_dir, 'video') if root_dir else None 14 | self.height = height 15 | self.width = width 16 | self.camera_id = camera_id 17 | self.fps = fps 18 | self.frames = [] 19 | 20 | def init(self, enabled=True): 21 | self.frames = [] 22 | self.enabled = self.save_dir is not None and enabled 23 | 24 | def record(self, env): 25 | if self.enabled: 26 | frame = env.render(mode='rgb_array', 27 | height=self.height, 28 | width=self.width, 29 | camera_id=self.camera_id) 30 | self.frames.append(frame) 31 | 32 | def save(self, file_name): 33 | if self.enabled: 34 | path = os.path.join(self.save_dir, file_name) 35 | imageio.mimsave(path, self.frames, fps=self.fps) 36 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to gwil 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to svg, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /agent/critic.py: -------------------------------------------------------------------------------- 1 | ### Code for Gromov-Wasserstein Imitation Learning, Arnaud Fickinger, 2022 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
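# Shape and usage sketch for the critic defined below (the dimensions are
# illustrative, not taken from a particular task):
#
#   critic = DoubleQCritic(obs_dim=17, action_dim=6, hidden_dim=1024, hidden_depth=2)
#   q1, q2 = critic(obs, action)    # obs: (B, 17), action: (B, 6) -> q1, q2: (B, 1)
#
# SACAgent.update_critic consumes the pair through the clipped double-Q target
#   target_Q = r + gamma * not_done * (min(Q1', Q2') - alpha * log_prob),
# which is why two independent MLP heads are kept instead of a single Q network.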
3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | import utils 10 | 11 | 12 | class DoubleQCritic(nn.Module): 13 | """Critic network, employes double Q-learning.""" 14 | def __init__(self, obs_dim, action_dim, hidden_dim, hidden_depth): 15 | super().__init__() 16 | 17 | self.Q1 = utils.mlp(obs_dim + action_dim, hidden_dim, 1, hidden_depth) 18 | self.Q2 = utils.mlp(obs_dim + action_dim, hidden_dim, 1, hidden_depth) 19 | 20 | self.outputs = dict() 21 | self.apply(utils.weight_init) 22 | 23 | def forward(self, obs, action): 24 | assert obs.size(0) == action.size(0) 25 | 26 | obs_action = torch.cat([obs, action], dim=-1) 27 | q1 = self.Q1(obs_action) 28 | q2 = self.Q2(obs_action) 29 | 30 | self.outputs['q1'] = q1 31 | self.outputs['q2'] = q2 32 | 33 | return q1, q2 34 | 35 | def log(self, logger, step): 36 | for k, v in self.outputs.items(): 37 | logger.log_histogram(f'train_critic/{k}_hist', v, step) 38 | 39 | assert len(self.Q1) == len(self.Q2) 40 | for i, (m1, m2) in enumerate(zip(self.Q1, self.Q2)): 41 | assert type(m1) == type(m2) 42 | if type(m1) is nn.Linear: 43 | logger.log_param(f'train_critic/q1_fc{i}', m1, step) 44 | logger.log_param(f'train_critic/q2_fc{i}', m2, step) 45 | -------------------------------------------------------------------------------- /maze_envs/assets/point_mass.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /config/imitate.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - agent: sac 3 | - override hydra/launcher: submitit_slurm 4 | 5 | env_expert: pendulum_swingup 6 | env_agent: cartpole_swingup 7 | 8 | metric_expert: euclidean 9 | metric_agent: euclidean 10 | 11 | state_expert: state 12 | state_agent: state 13 | 14 | maze_id_agent: 0 15 | maze_id_expert: 0 16 | time_limit: 1000 17 | 18 | demonstration_name: '' 19 | 20 | nb_channels_expert: 3 21 | size_observation_expert: 84 22 | 23 | nb_channels_agent: 3 24 | size_observation_agent: 84 25 | 26 | experiment: test_exp 27 | 28 | num_train_steps: 1e6 29 | replay_buffer_capacity: ${num_train_steps} 30 | 31 | num_train_steps_expert: 1e6 32 | 33 | gw_include_actions_expert: true 34 | gw_include_actions_agent: true 35 | gw_entropic: true 36 | gw_epsilon: 5e-4 37 | gw_max_iter: 1000 38 | gw_tol: 1e-9 39 | gw_normalize: false 40 | gw_normalize_batch: false 41 | 42 | cutoff: 1e-5 43 | 44 | ot_cost: gw 45 | 46 | normalize_agent_with_expert: false 47 | 48 | sinkhorn_reg: 5e-3 49 | 50 | num_seed_steps: 5000 51 | 52 | eval_frequency: 50000 53 | num_eval_episodes: 10 54 | num_eval_episodes_expert: 10 55 | 56 | verbose: false 57 | 58 | device: cuda 59 | 60 | dmc: true 61 | gym: false 62 | 63 | weight_external_reward: 1 64 | weight_gw_reward: 1 65 | 66 | pretrained_agent: '' 67 | 68 | include_external_reward: false 69 | 70 | ultra_sparse: false 71 | 72 | # logger 73 | log_frequency: 10000 74 | log_save_tb: false 75 | 76 | # video recorder 77 | save_video: true 78 | 79 | timeout_min: 2000 80 | 81 | seed: 1 82 | 83 | comment: '' 84 | partition: learnlab 85 | gpus_per_node: 1 86 | cpus_per_task: 10 87 | 88 | expert_model: '' 89 | 90 | project_name: gwil 91 | 92 | # hydra configuration 93 | hydra: 94 | run: 95 | dir: ./exp/local/${now:%Y.%m.%d.%H%M%S}/${experiment}_${now:%H%M%S} 96 | sweep: 97 | dir: ./exp/${now:%Y.%m.%d.%H%M%S}/${now:%H%M%S}_${experiment} 98 | subdir: ${hydra.job.num} 
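# Note: the top-level keys in this file can be overridden from the command line
# in the usual Hydra way, e.g. (illustrative):
#   python imitate.py env_expert=cheetah_run env_agent=walker_walk gw_entropic=false gw_normalize=true seed=2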
99 | launcher: 100 | max_num_timeout: 100000 101 | timeout_min: ${timeout_min} 102 | partition: ${partition} 103 | comment: ${comment} 104 | mem_gb: 64 105 | gpus_per_node: ${gpus_per_node} 106 | cpus_per_task: ${cpus_per_task} 107 | -------------------------------------------------------------------------------- /agent/actor.py: -------------------------------------------------------------------------------- 1 | ### Code for Gromov-Wasserstein Imitation Learning, Arnaud Fickinger, 2022 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | import numpy as np 5 | import torch 6 | import math 7 | from torch import nn 8 | import torch.nn.functional as F 9 | from torch import distributions as pyd 10 | 11 | import utils 12 | 13 | 14 | class TanhTransform(pyd.transforms.Transform): 15 | domain = pyd.constraints.real 16 | codomain = pyd.constraints.interval(-1.0, 1.0) 17 | bijective = True 18 | sign = +1 19 | 20 | def __init__(self, cache_size=1): 21 | super().__init__(cache_size=cache_size) 22 | 23 | @staticmethod 24 | def atanh(x): 25 | return 0.5 * (x.log1p() - (-x).log1p()) 26 | 27 | def __eq__(self, other): 28 | return isinstance(other, TanhTransform) 29 | 30 | def _call(self, x): 31 | return x.tanh() 32 | 33 | def _inverse(self, y): 34 | # We do not clamp to the boundary here as it may degrade the performance of certain algorithms. 35 | # one should use `cache_size=1` instead 36 | return self.atanh(y) 37 | 38 | def log_abs_det_jacobian(self, x, y): 39 | # We use a formula that is more numerically stable, see details in the following link 40 | # https://github.com/tensorflow/probability/commit/ef6bb176e0ebd1cf6e25c6b5cecdd2428c22963f#diff-e120f70e92e6741bca649f04fcd907b7 41 | return 2. * (math.log(2.) - x - F.softplus(-2. * x)) 42 | 43 | 44 | class SquashedNormal(pyd.transformed_distribution.TransformedDistribution): 45 | def __init__(self, loc, scale): 46 | self.loc = loc 47 | self.scale = scale 48 | 49 | self.base_dist = pyd.Normal(loc, scale) 50 | transforms = [TanhTransform()] 51 | super().__init__(self.base_dist, transforms) 52 | 53 | @property 54 | def mean(self): 55 | mu = self.loc 56 | for tr in self.transforms: 57 | mu = tr(mu) 58 | return mu 59 | 60 | 61 | class DiagGaussianActor(nn.Module): 62 | """torch.distributions implementation of an diagonal Gaussian policy.""" 63 | def __init__(self, obs_dim, action_dim, hidden_dim, hidden_depth, 64 | log_std_bounds): 65 | super().__init__() 66 | 67 | self.log_std_bounds = log_std_bounds 68 | self.trunk = utils.mlp(obs_dim, hidden_dim, 2 * action_dim, 69 | hidden_depth) 70 | 71 | self.outputs = dict() 72 | self.apply(utils.weight_init) 73 | 74 | def forward(self, obs): 75 | mu, log_std = self.trunk(obs).chunk(2, dim=-1) 76 | 77 | # constrain log_std inside [log_std_min, log_std_max] 78 | log_std = torch.tanh(log_std) 79 | log_std_min, log_std_max = self.log_std_bounds 80 | log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1) 81 | 82 | std = log_std.exp() 83 | 84 | self.outputs['mu'] = mu 85 | self.outputs['std'] = std 86 | 87 | dist = SquashedNormal(mu, std) 88 | return dist 89 | 90 | def log(self, logger, step): 91 | for k, v in self.outputs.items(): 92 | logger.log_histogram(f'train_actor/{k}_hist', v, step) 93 | 94 | for i, m in enumerate(self.trunk): 95 | if type(m) == nn.Linear: 96 | logger.log_param(f'train_actor/fc{i}', m, step) 97 | -------------------------------------------------------------------------------- /README.md: 
-------------------------------------------------------------------------------- 1 | # Gromov-Wasserstein Cross Domain Imitation Learning 2 | 3 | This is the official PyTorch implementation of the ICLR 2022 paper [Cross-Domain Imitation Learning via Optimal Transport](https://arxiv.org/abs/2110.03684). 4 | 5 | If you use this code in your research project please cite us as: 6 | ``` 7 | @inproceedings{fickinger2022gromov, 8 | title={Cross-Domain Imitation Learning via Optimal Transport}, 9 | author={Fickinger, Arnaud and Cohen, Samuel and Russell, Stuart and Amos, Brandon}, 10 | booktitle={10th International Conference on Learning Representations, ICLR}, 11 | year={2022} 12 | } 13 | ``` 14 | 15 | ## Requirements 16 | We assume you have access to a gpu that can run CUDA 9.2. Then, the simplest way to install all required dependencies is to create an anaconda environment and activate it: 17 | ``` 18 | conda env create -f conda_env.yml 19 | source activate gwil 20 | ``` 21 | 22 | ## Instructions 23 | 24 | Expert demonstrations are available [here](https://drive.google.com/file/d/1xE882IuQkXUuaeXHInYaP9eqvhQm48Et/view?usp=sharing). Copy the directory exp at the root of this repo. 25 | 26 | ### Training the expert policies 27 | 28 | Only needed if new expert demonstrations are needed. The parameter num_train_steps is set such that the policy obtained is approximately optimal in the environment as observed in [this repo's result plots](https://github.com/denisyarats/pytorch_sac). 29 | 30 | ``` 31 | python train.py env=pendulum_swingup num_train_steps=1e6 experiment=expert 32 | python train.py env=cartpole_swingup num_train_steps=5e5 experiment=expert 33 | python train.py env=cheetah_run num_train_steps=2e6 experiment=expert 34 | python train.py env=walker_walk num_train_steps=2e6 experiment=expert 35 | ``` 36 | 37 | ### Saving the expert demonstrations 38 | 39 | Only needed if new expert demonstrations are needed. The parameter num_train_steps is set to be the same as when training the expert policy. 40 | 41 | ``` 42 | python save_expert_demonstration.py env=pendulum_swingup num_train_steps=1e6 experiment=expert_demonstration 43 | python save_expert_demonstration.py env=cartpole_swingup num_train_steps=5e5 experiment=expert_demonstration 44 | python save_expert_demonstration.py env=cheetah_run num_train_steps=2e6 experiment=expert_demonstration 45 | python save_expert_demonstration.py env=walker_walk num_train_steps=2e6 experiment=expert_demonstration 46 | ``` 47 | 48 | ### Training the imitation policies 49 | 50 | The parameter num_train_steps is set to be the same as when training the expert policy in the agent environment. 51 | ``` 52 | python imitate.py env_expert=pendulum_swingup env_agent=cartpole_swingup num_train_steps=1e6 experiment=imitation_normalize gw_entropic=false gw_normalize=true 53 | python imitate.py env_expert=cheetah_run env_agent=walker_walk num_train_steps=2e6 experiment=imitation_normalize gw_entropic=false gw_normalize=true 54 | python imitate.py env_expert=MazeEnd_PointMass env_agent=MazeEnd_PointMass maze_id_expert=0 maze_id_agent=2 num_train_steps=1e6 experiment=imitation_normalize gw_entropic=false gw_normalize=true dmc=false 55 | 56 | ``` 57 | 58 | ## Credits 59 | 60 | The code is based on the SAC Pytorch implementation available [here](https://github.com/denisyarats/pytorch_sac) 61 | 62 | # Licensing 63 | This repository is licensed under the 64 | [CC BY-NC 4.0 License](https://creativecommons.org/licenses/by-nc/4.0/). 
65 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | ### Code for Gromov-Wasserstein Imitation Learning, Arnaud Fickinger, 2022 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn 7 | import os 8 | import random 9 | 10 | 11 | class eval_mode(object): 12 | def __init__(self, *models): 13 | self.models = models 14 | 15 | def __enter__(self): 16 | self.prev_states = [] 17 | for model in self.models: 18 | self.prev_states.append(model.training) 19 | model.train(False) 20 | 21 | def __exit__(self, *args): 22 | for model, state in zip(self.models, self.prev_states): 23 | model.train(state) 24 | return False 25 | 26 | 27 | class train_mode(object): 28 | def __init__(self, *models): 29 | self.models = models 30 | 31 | def __enter__(self): 32 | self.prev_states = [] 33 | for model in self.models: 34 | self.prev_states.append(model.training) 35 | model.train(True) 36 | 37 | def __exit__(self, *args): 38 | for model, state in zip(self.models, self.prev_states): 39 | model.train(state) 40 | return False 41 | 42 | 43 | def soft_update_params(net, target_net, tau): 44 | for param, target_param in zip(net.parameters(), target_net.parameters()): 45 | target_param.data.copy_(tau * param.data + 46 | (1 - tau) * target_param.data) 47 | 48 | 49 | def set_seed_everywhere(seed): 50 | torch.manual_seed(seed) 51 | if torch.cuda.is_available(): 52 | torch.cuda.manual_seed_all(seed) 53 | np.random.seed(seed) 54 | random.seed(seed) 55 | 56 | 57 | def make_dir(*path_parts): 58 | dir_path = os.path.join(*path_parts) 59 | try: 60 | os.mkdir(dir_path) 61 | except OSError: 62 | pass 63 | return dir_path 64 | 65 | 66 | def weight_init(m): 67 | """Custom weight init for Conv2D and Linear layers.""" 68 | if isinstance(m, nn.Linear): 69 | nn.init.orthogonal_(m.weight.data) 70 | if hasattr(m.bias, 'data'): 71 | m.bias.data.fill_(0.0) 72 | 73 | 74 | class MLP(nn.Module): 75 | def __init__(self, 76 | input_dim, 77 | hidden_dim, 78 | output_dim, 79 | hidden_depth, 80 | output_mod=None): 81 | super().__init__() 82 | self.trunk = mlp(input_dim, hidden_dim, output_dim, hidden_depth, 83 | output_mod) 84 | self.apply(weight_init) 85 | 86 | def forward(self, x): 87 | return self.trunk(x) 88 | 89 | 90 | def mlp(input_dim, hidden_dim, output_dim, hidden_depth, output_mod=None): 91 | if hidden_depth == 0: 92 | mods = [nn.Linear(input_dim, output_dim)] 93 | else: 94 | mods = [nn.Linear(input_dim, hidden_dim), nn.ReLU(inplace=True)] 95 | for i in range(hidden_depth - 1): 96 | mods += [nn.Linear(hidden_dim, hidden_dim), nn.ReLU(inplace=True)] 97 | mods.append(nn.Linear(hidden_dim, output_dim)) 98 | if output_mod is not None: 99 | mods.append(output_mod) 100 | trunk = nn.Sequential(*mods) 101 | return trunk 102 | 103 | 104 | def to_np(t): 105 | if t is None: 106 | return None 107 | elif t.nelement() == 0: 108 | return np.array([]) 109 | else: 110 | return t.cpu().detach().numpy() 111 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, 
body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /agent/sac.py: -------------------------------------------------------------------------------- 1 | ### Code for Gromov-Wasserstein Imitation Learning, Arnaud Fickinger, 2022 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import math 9 | import copy 10 | 11 | from agent import Agent 12 | import utils 13 | 14 | import hydra 15 | 16 | 17 | class SACAgent(Agent): 18 | """SAC algorithm.""" 19 | def __init__(self, obs_dim, action_dim, action_range, device, critic, 20 | actor, discount, init_temperature, alpha_lr, alpha_betas, 21 | actor_lr, actor_betas, actor_update_frequency, critic_lr, 22 | critic_betas, critic_tau, critic_target_update_frequency, 23 | batch_size, learnable_temperature): 24 | super().__init__() 25 | 26 | self.action_range = action_range 27 | self.device = torch.device(device) 28 | self.discount = discount 29 | self.critic_tau = critic_tau 30 | self.actor_update_frequency = actor_update_frequency 31 | self.critic_target_update_frequency = critic_target_update_frequency 32 | self.batch_size = batch_size 33 | self.learnable_temperature = learnable_temperature 34 | 35 | self.actor = actor.to(self.device) 36 | self.critic = critic.to(self.device) 37 | self.critic_target = copy.deepcopy(self.critic) 38 | self.critic_target.load_state_dict(self.critic.state_dict()) 39 | 40 | self.log_alpha = torch.tensor(np.log(init_temperature)).to(self.device) 41 | self.log_alpha.requires_grad = True 42 | # set target entropy to -|A| 43 | self.target_entropy = -action_dim 44 | 45 | # optimizers 46 | self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), 47 | lr=actor_lr, 48 | betas=actor_betas) 49 | 50 | self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), 51 | lr=critic_lr, 52 | betas=critic_betas) 53 | 54 | self.log_alpha_optimizer = torch.optim.Adam([self.log_alpha], 55 | lr=alpha_lr, 56 | betas=alpha_betas) 57 | 58 | self.train() 59 | self.critic_target.train() 60 | 61 | def train(self, training=True): 62 | self.training = training 63 | self.actor.train(training) 64 | self.critic.train(training) 65 | 66 | @property 67 | def alpha(self): 68 | return self.log_alpha.exp() 69 | 70 | def act(self, obs, sample=False): 71 | obs = torch.FloatTensor(obs).to(self.device) 72 | obs = obs.unsqueeze(0) 73 | dist = self.actor(obs) 74 | action = dist.sample() if sample else dist.mean 75 | action = action.clamp(*self.action_range) 76 | assert action.ndim == 2 and action.shape[0] == 1 77 | return utils.to_np(action[0]) 78 | 79 | def update_critic(self, obs, action, reward, next_obs, not_done, logger, 80 | step): 81 | dist = self.actor(next_obs) 82 | next_action = dist.rsample() 83 | log_prob = dist.log_prob(next_action).sum(-1, keepdim=True) 84 | target_Q1, target_Q2 = self.critic_target(next_obs, next_action) 85 | target_V = torch.min(target_Q1, 86 | target_Q2) - self.alpha.detach() * log_prob 87 | target_Q = reward + (not_done * self.discount * target_V) 88 | target_Q = target_Q.detach() 89 | 90 | # get 
current Q estimates 91 | current_Q1, current_Q2 = self.critic(obs, action) 92 | critic_loss = F.mse_loss(current_Q1, target_Q) + F.mse_loss( 93 | current_Q2, target_Q) 94 | logger.log('train_critic/loss', critic_loss, step) 95 | 96 | # Optimize the critic 97 | self.critic_optimizer.zero_grad() 98 | critic_loss.backward() 99 | self.critic_optimizer.step() 100 | 101 | self.critic.log(logger, step) 102 | 103 | def update_actor_and_alpha(self, obs, logger, step): 104 | dist = self.actor(obs) 105 | action = dist.rsample() 106 | log_prob = dist.log_prob(action).sum(-1, keepdim=True) 107 | actor_Q1, actor_Q2 = self.critic(obs, action) 108 | 109 | actor_Q = torch.min(actor_Q1, actor_Q2) 110 | actor_loss = (self.alpha.detach() * log_prob - actor_Q).mean() 111 | 112 | logger.log('train_actor/loss', actor_loss, step) 113 | logger.log('train_actor/target_entropy', self.target_entropy, step) 114 | logger.log('train_actor/entropy', -log_prob.mean(), step) 115 | 116 | # optimize the actor 117 | self.actor_optimizer.zero_grad() 118 | actor_loss.backward() 119 | self.actor_optimizer.step() 120 | 121 | self.actor.log(logger, step) 122 | 123 | if self.learnable_temperature: 124 | self.log_alpha_optimizer.zero_grad() 125 | alpha_loss = (self.alpha * 126 | (-log_prob - self.target_entropy).detach()).mean() 127 | logger.log('train_alpha/loss', alpha_loss, step) 128 | logger.log('train_alpha/value', self.alpha, step) 129 | alpha_loss.backward() 130 | self.log_alpha_optimizer.step() 131 | 132 | def update(self, replay_buffer, logger, step, gw=False, normalize_reward=False, normalize_reward_batch=False, include_external_reward=False, weight_external_reward=1, weight_gw_reward=1): 133 | obs, action, reward, next_obs, not_done, not_done_no_max = replay_buffer.sample( 134 | self.batch_size, gw=gw, normalize_reward=normalize_reward, normalize_reward_batch=normalize_reward_batch, include_external_reward=include_external_reward, weight_external_reward=weight_external_reward, weight_gw_reward=weight_gw_reward) 135 | 136 | logger.log('train/batch_reward', reward.mean(), step) 137 | 138 | self.update_critic(obs, action, reward, next_obs, not_done_no_max, 139 | logger, step) 140 | 141 | if step % self.actor_update_frequency == 0: 142 | self.update_actor_and_alpha(obs, logger, step) 143 | 144 | if step % self.critic_target_update_frequency == 0: 145 | utils.soft_update_params(self.critic, self.critic_target, 146 | self.critic_tau) 147 | -------------------------------------------------------------------------------- /replay_buffer.py: -------------------------------------------------------------------------------- 1 | ### Code for Gromov-Wasserstein Imitation Learning, Arnaud Fickinger, 2022 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
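# Sketch of the per-step Gromov-Wasserstein reward computed in compute_gw_reward
# below (notation only; the exact calls are the POT ones used in that method).
# Writing C_E and C_A for the max-normalized intra-trajectory distance matrices
# of the expert and agent trajectories, and T for the (entropic) GW coupling
# between uniform marginals,
#
#   T ~ argmin_T  sum_{i,j,k,l} |C_E[i, j] - C_A[k, l]|^2 * T[i, k] * T[j, l]
#   r_k = - sum_i tens[i, k] * T[i, k],
#       with tens = ot.gromov.tensor_product(constC, hExpert, hAgent, T)
#
# so the rewards of one episode sum to minus the (square-loss) GW objective value
# between the two trajectories; process_trajectory stores these r_k in
# self.gw_rewards for the transitions of the episode that just ended.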
3 | 4 | import numpy as np 5 | import torch 6 | import ot 7 | import scipy as sp 8 | 9 | class ReplayBuffer(object): 10 | """Buffer to store environment transitions.""" 11 | def __init__(self, obs_shape, action_shape, capacity, device, cfg): 12 | self.capacity = capacity 13 | self.device = device 14 | 15 | # the proprioceptive obs is stored as float32, pixels obs as uint8 16 | obs_dtype = np.float32 if len(obs_shape) == 1 else np.uint8 17 | 18 | self.obses = np.empty((capacity, *obs_shape), dtype=obs_dtype) 19 | self.next_obses = np.empty((capacity, *obs_shape), dtype=obs_dtype) 20 | self.actions = np.empty((capacity, *action_shape), dtype=np.float32) 21 | self.rewards = np.empty((capacity, 1), dtype=np.float32) 22 | self.gw_rewards = np.empty((capacity, 1), dtype=np.float32) 23 | self.not_dones = np.empty((capacity, 1), dtype=np.float32) 24 | self.not_dones_no_max = np.empty((capacity, 1), dtype=np.float32) 25 | 26 | self.idx = 0 27 | self.last_save = 0 28 | self.full = False 29 | 30 | self.idx_gw = 0 31 | self.full_gw = False 32 | 33 | self.cfg = cfg 34 | 35 | def __len__(self): 36 | return self.capacity if self.full else self.idx 37 | 38 | def add(self, obs, action, reward, next_obs, done, done_no_max): 39 | # import pdb;pdb.set_trace() 40 | np.copyto(self.obses[self.idx], obs) 41 | np.copyto(self.actions[self.idx], action) 42 | np.copyto(self.rewards[self.idx], reward) 43 | np.copyto(self.next_obses[self.idx], next_obs) 44 | np.copyto(self.not_dones[self.idx], not done) 45 | np.copyto(self.not_dones_no_max[self.idx], not done_no_max) 46 | 47 | self.idx = (self.idx + 1) % self.capacity 48 | self.full = self.full or self.idx == 0 49 | 50 | def process_trajectory(self, traj_expert, metric_expert = 'euclidean', metric_agent = 'euclidean', sinkhorn_reg=5e-3, normalize_agent_with_expert=False, include_actions=True, entropic=True): 51 | assert not (self.idx == 0 and not self.full) 52 | if self.idx == 0: 53 | traj_agent = self.obses[self.idx_gw:] 54 | else: 55 | traj_agent = self.obses[self.idx_gw:self.idx] 56 | 57 | if normalize_agent_with_expert: 58 | traj_agent = (traj_agent - traj_expert.mean())/(traj_expert.std()) 59 | 60 | if include_actions: 61 | if self.idx == 0: 62 | actions_trajectory = self.actions[self.idx_gw:] 63 | else: 64 | actions_trajectory = self.actions[self.idx_gw:self.idx] 65 | traj_agent = np.concatenate((traj_agent,actions_trajectory), axis=1) 66 | 67 | gw_rewards = self.compute_gw_reward(traj_expert, traj_agent, metric_expert, metric_agent, 68 | entropic, sinkhorn_reg=sinkhorn_reg) 69 | 70 | if self.idx == 0: 71 | self.gw_rewards[self.idx_gw:] = np.expand_dims(gw_rewards, axis=1) 72 | normalized_reward = ((self.gw_rewards[:self.idx] - self.gw_rewards[:self.idx].mean())/(1e-5+self.gw_rewards[:self.idx].std()))[self.idx_gw:].sum() 73 | 74 | else: 75 | self.gw_rewards[self.idx_gw:self.idx] = np.expand_dims(gw_rewards, axis=1) 76 | normalized_reward = ((self.gw_rewards[:self.idx] - self.gw_rewards[:self.idx].mean())/(1e-5+self.gw_rewards[:self.idx].std()))[self.idx_gw:self.idx].sum() 77 | 78 | self.idx_gw = self.idx 79 | 80 | return gw_rewards.sum(), normalized_reward 81 | 82 | def compute_gw_reward(self, traj_expert, traj_agent, metric_expert = 'euclidean', metric_agent = 'euclidean', entropic=True, sinkhorn_reg=5e-3, return_coupling = False): 83 | distances_expert = sp.spatial.distance.cdist(traj_expert, traj_expert, metric=metric_expert) 84 | 85 | distances_agent = sp.spatial.distance.cdist(traj_agent, traj_agent, metric=metric_agent) 86 | 87 | 
distances_expert/=distances_expert.max() 88 | distances_agent/=distances_agent.max() 89 | 90 | distribution_expert = ot.unif(len(traj_expert)) 91 | distribution_agent = ot.unif(len(traj_agent)) 92 | 93 | if entropic: 94 | optimal_coupling = ot.gromov.entropic_gromov_wasserstein( 95 | distances_expert, distances_agent, distribution_expert, distribution_agent, 'square_loss', epsilon=sinkhorn_reg, max_iter=1000, tol=1e-9) 96 | else: 97 | optimal_coupling= ot.gromov.gromov_wasserstein(distances_expert, distances_agent, distribution_expert, distribution_agent, 'square_loss') 98 | 99 | 100 | constC, hExpert, hAgent = ot.gromov.init_matrix(distances_expert, distances_agent, distribution_expert, distribution_agent, loss_fun='square_loss') 101 | 102 | tens = ot.gromov.tensor_product(constC, hExpert, hAgent, optimal_coupling) 103 | 104 | rewards = -(tens*optimal_coupling).sum(axis=0) 105 | 106 | if return_coupling: 107 | return rewards, optimal_coupling 108 | 109 | return rewards 110 | 111 | def sample(self, batch_size, gw=False, normalize_reward=False,normalize_reward_batch=False, include_external_reward=False, weight_external_reward=1, weight_gw_reward=1): 112 | 113 | if gw: 114 | end_idxs = self.capacity if self.full_gw else self.idx_gw 115 | else: 116 | end_idxs = self.capacity if self.full else self.idx 117 | 118 | idxs = np.random.randint(0, 119 | end_idxs, 120 | size=batch_size) 121 | 122 | obses = torch.as_tensor(self.obses[idxs], device=self.device).float() 123 | actions = torch.as_tensor(self.actions[idxs], device=self.device) 124 | if gw: 125 | if normalize_reward_batch: 126 | rewards = torch.as_tensor((self.gw_rewards[idxs] - self.gw_rewards[idxs].mean())/(1e-5+self.gw_rewards[idxs].std()), device=self.device) 127 | elif normalize_reward: 128 | gw_rewards_normalized = (self.gw_rewards[:end_idxs] - self.gw_rewards[:end_idxs].mean())/(1e-5+self.gw_rewards[:end_idxs].std()) 129 | rewards = torch.as_tensor(gw_rewards_normalized[idxs], device=self.device) 130 | else: 131 | rewards = torch.as_tensor(self.gw_rewards[idxs], device=self.device) 132 | 133 | else: 134 | rewards = torch.as_tensor(self.rewards[idxs], device=self.device) 135 | 136 | if include_external_reward: 137 | assert gw 138 | rewards=weight_gw_reward*rewards+weight_external_reward*torch.as_tensor(self.rewards[idxs], device=self.device) 139 | 140 | next_obses = torch.as_tensor(self.next_obses[idxs], 141 | device=self.device).float() 142 | not_dones = torch.as_tensor(self.not_dones[idxs], device=self.device) 143 | not_dones_no_max = torch.as_tensor(self.not_dones_no_max[idxs], 144 | device=self.device) 145 | 146 | return obses, actions, rewards, next_obses, not_dones, not_dones_no_max 147 | -------------------------------------------------------------------------------- /logger.py: -------------------------------------------------------------------------------- 1 | ### Code for Gromov-Wasserstein Imitation Learning, Arnaud Fickinger, 2022 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
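# Illustrative usage sketch (mirrors the calls made from imitate.py; metric keys
# must start with 'train' or 'eval'):
#
#   logger = Logger(work_dir, save_tb=False, log_frequency=10000, agent='sac')
#   logger.log('train/episode_reward', episode_reward, step)
#   logger.log('eval/episode_reward', average_episode_reward, step)
#   logger.dump(step)   # print the averaged meters and append them to train.csv / eval.csv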
3 | 4 | from torch.utils.tensorboard import SummaryWriter 5 | from collections import defaultdict 6 | import json 7 | import os 8 | import csv 9 | import shutil 10 | import torch 11 | import numpy as np 12 | from termcolor import colored 13 | 14 | COMMON_TRAIN_FORMAT = [ 15 | ('episode', 'E', 'int'), 16 | ('step', 'S', 'int'), 17 | ('episode_reward', 'R', 'float'), 18 | ('episode_gw_reward', 'GWR', 'float'), 19 | ('normalized_episode_gw_reward', 'NGWR', 'float'), 20 | ('duration', 'D', 'time') 21 | ] 22 | 23 | COMMON_EVAL_FORMAT = [ 24 | ('episode', 'E', 'int'), 25 | ('step', 'S', 'int'), 26 | ('episode_reward', 'R', 'float') 27 | ] 28 | 29 | 30 | AGENT_TRAIN_FORMAT = { 31 | 'sac': [ 32 | ('batch_reward', 'BR', 'float'), 33 | ('actor_loss', 'ALOSS', 'float'), 34 | ('critic_loss', 'CLOSS', 'float'), 35 | ('alpha_loss', 'TLOSS', 'float'), 36 | ('alpha_value', 'TVAL', 'float'), 37 | ('actor_entropy', 'AENT', 'float') 38 | ] 39 | } 40 | 41 | 42 | class AverageMeter(object): 43 | def __init__(self): 44 | self._sum = 0 45 | self._count = 0 46 | 47 | def update(self, value, n=1): 48 | self._sum += value 49 | self._count += n 50 | 51 | def value(self): 52 | return self._sum / max(1, self._count) 53 | 54 | 55 | class MetersGroup(object): 56 | def __init__(self, file_name, formating): 57 | self._csv_file_name = self._prepare_file(file_name, 'csv') 58 | self._formating = formating 59 | self._meters = defaultdict(AverageMeter) 60 | self._csv_file = open(self._csv_file_name, 'w') 61 | self._csv_writer = None 62 | 63 | def _prepare_file(self, prefix, suffix): 64 | file_name = f'{prefix}.{suffix}' 65 | if os.path.exists(file_name): 66 | os.remove(file_name) 67 | return file_name 68 | 69 | def log(self, key, value, n=1): 70 | self._meters[key].update(value, n) 71 | 72 | def _prime_meters(self): 73 | data = dict() 74 | for key, meter in self._meters.items(): 75 | if key.startswith('train'): 76 | key = key[len('train') + 1:] 77 | else: 78 | key = key[len('eval') + 1:] 79 | key = key.replace('/', '_') 80 | data[key] = meter.value() 81 | return data 82 | 83 | def _dump_to_csv(self, data): 84 | if self._csv_writer is None: 85 | self._csv_writer = csv.DictWriter(self._csv_file, 86 | fieldnames=sorted(data.keys()), 87 | restval=0.0) 88 | self._csv_writer.writeheader() 89 | self._csv_writer.writerow(data) 90 | self._csv_file.flush() 91 | 92 | def _format(self, key, value, ty): 93 | if ty == 'int': 94 | value = int(value) 95 | return f'{key}: {value}' 96 | elif ty == 'float': 97 | return f'{key}: {value:.04f}' 98 | elif ty == 'time': 99 | return f'{key}: {value:04.1f} s' 100 | else: 101 | raise f'invalid format type: {ty}' 102 | 103 | def _dump_to_console(self, data, prefix): 104 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') 105 | pieces = [f'| {prefix: <14}'] 106 | for key, disp_key, ty in self._formating: 107 | value = data.get(key, 0) 108 | pieces.append(self._format(disp_key, value, ty)) 109 | print(' | '.join(pieces)) 110 | 111 | def dump(self, step, prefix, save=True): 112 | if len(self._meters) == 0: 113 | return 114 | if save: 115 | data = self._prime_meters() 116 | data['step'] = step 117 | self._dump_to_csv(data) 118 | self._dump_to_console(data, prefix) 119 | self._meters.clear() 120 | 121 | 122 | class Logger(object): 123 | def __init__(self, 124 | log_dir, 125 | save_tb=False, 126 | log_frequency=10000, 127 | agent='sac'): 128 | self._log_dir = log_dir 129 | self._log_frequency = log_frequency 130 | if save_tb: 131 | tb_dir = os.path.join(log_dir, 'tb') 132 | if 
os.path.exists(tb_dir): 133 | try: 134 | shutil.rmtree(tb_dir) 135 | except: 136 | print("logger.py warning: Unable to remove tb directory") 137 | pass 138 | self._sw = SummaryWriter(tb_dir) 139 | else: 140 | self._sw = None 141 | # each agent has specific output format for training 142 | assert agent in AGENT_TRAIN_FORMAT 143 | train_format = COMMON_TRAIN_FORMAT + AGENT_TRAIN_FORMAT[agent] 144 | self._train_mg = MetersGroup(os.path.join(log_dir, 'train'), 145 | formating=train_format) 146 | self._eval_mg = MetersGroup(os.path.join(log_dir, 'eval'), 147 | formating=COMMON_EVAL_FORMAT) 148 | 149 | def _should_log(self, step, log_frequency): 150 | log_frequency = log_frequency or self._log_frequency 151 | return step % log_frequency == 0 152 | 153 | def _try_sw_log(self, key, value, step): 154 | if self._sw is not None: 155 | self._sw.add_scalar(key, value, step) 156 | 157 | def _try_sw_log_video(self, key, frames, step): 158 | if self._sw is not None: 159 | frames = torch.from_numpy(np.array(frames)) 160 | frames = frames.unsqueeze(0) 161 | self._sw.add_video(key, frames, step, fps=30) 162 | 163 | def _try_sw_log_histogram(self, key, histogram, step): 164 | if self._sw is not None: 165 | self._sw.add_histogram(key, histogram, step) 166 | 167 | def log(self, key, value, step, n=1, log_frequency=1): 168 | if not self._should_log(step, log_frequency): 169 | return 170 | assert key.startswith('train') or key.startswith('eval') 171 | if type(value) == torch.Tensor: 172 | value = value.item() 173 | self._try_sw_log(key, value / n, step) 174 | mg = self._train_mg if key.startswith('train') else self._eval_mg 175 | mg.log(key, value, n) 176 | 177 | def log_param(self, key, param, step, log_frequency=None): 178 | if not self._should_log(step, log_frequency): 179 | return 180 | self.log_histogram(key + '_w', param.weight.data, step) 181 | if hasattr(param.weight, 'grad') and param.weight.grad is not None: 182 | self.log_histogram(key + '_w_g', param.weight.grad.data, step) 183 | if hasattr(param, 'bias') and hasattr(param.bias, 'data'): 184 | self.log_histogram(key + '_b', param.bias.data, step) 185 | if hasattr(param.bias, 'grad') and param.bias.grad is not None: 186 | self.log_histogram(key + '_b_g', param.bias.grad.data, step) 187 | 188 | def log_video(self, key, frames, step, log_frequency=None): 189 | if not self._should_log(step, log_frequency): 190 | return 191 | assert key.startswith('train') or key.startswith('eval') 192 | self._try_sw_log_video(key, frames, step) 193 | 194 | def log_histogram(self, key, histogram, step, log_frequency=None): 195 | if not self._should_log(step, log_frequency): 196 | return 197 | assert key.startswith('train') or key.startswith('eval') 198 | self._try_sw_log_histogram(key, histogram, step) 199 | 200 | def dump(self, step, save=True, ty=None): 201 | if ty is None: 202 | self._train_mg.dump(step, 'train', save) 203 | self._eval_mg.dump(step, 'eval', save) 204 | elif ty == 'eval': 205 | self._eval_mg.dump(step, 'eval', save) 206 | elif ty == 'train': 207 | self._train_mg.dump(step, 'train', save) 208 | else: 209 | raise f'invalid log type: {ty}' 210 | -------------------------------------------------------------------------------- /maze_envs/base.py: -------------------------------------------------------------------------------- 1 | ### Code for Gromov-Wasserstein Imitation Learning, Arnaud Fickinger, 2022 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 
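# Worked example of the observation layout documented in the module docstring
# below (section lengths are illustrative, not taken from a concrete env):
# with sections of length 3 (agent specific), 1 (extra agent), 2 (skill) and
# 2 (task info), the constants are AGENT_DIM = 4, SKILL_DIM = 2, TASK_DIM = 5,
# and for the resulting 8-dimensional obs the default accessors return
#
#   agent_obs(obs) = obs[:4]     # sections 1-2: agent-specific part
#   skill_obs(obs) = obs[4:6]    # section 3:    skill-space part
#   task_obs(obs)  = obs[-5:]    # sections 2-4: agent-agnostic part (overlaps agent_obs)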
3 | 4 | """ 5 | Convention for Environments 6 | Name: _ 7 | Methods: 8 | - the usual gym methods 9 | - optional compute_reward for goal envs 10 | - task_obs(obs): returns agent agnostic componets 11 | - agent_obs(obs): returns agent specific components 12 | - skill_obs(obs): returns agent skill space components 13 | - goal(obs): returns current goal position 14 | 15 | The state method should be implemented as follows 16 | 1. Agent specific information 17 | 2. extra agent information (must be relative to the agent skill spc. components.) 18 | 3. agent skill space components 19 | 4. Task Info 20 | The below indiciates how these values should be set. The numbers reference the above. 21 | len(1 to 2) = AGENT_DIM 22 | len(3) = SKILL_DIM 23 | len(2 to 4) = TASK_DIM 24 | 25 | Defaults: 26 | agent_obs(obs): return obs[:AGENT_DIM] 27 | skill_obs(obs): return obs[AGENT_DIM:AGENT_DIM+SKILL_DIM] 28 | task_obs(obs): return obs[-TASK_DIM:] 29 | 30 | For GoalEnvs, they will handle the defining the goal space as part of the state. 31 | """ 32 | 33 | import os 34 | from collections import OrderedDict 35 | import gym 36 | import mujoco_py 37 | import numpy as np 38 | from gym import spaces 39 | from gym.utils import seeding 40 | 41 | def convert_observation_to_space(observation): 42 | if isinstance(observation, dict): 43 | space = spaces.Dict(OrderedDict([ 44 | (key, convert_observation_to_space(value)) 45 | for key, value in observation.items() 46 | ])) 47 | elif isinstance(observation, np.ndarray): 48 | low = np.full(observation.shape, -float('inf'), dtype=np.float32) 49 | high = np.full(observation.shape, float('inf'), dtype=np.float32) 50 | space = spaces.Box(low, high, dtype=observation.dtype) 51 | else: 52 | raise NotImplementedError(type(observation), observation) 53 | 54 | return space 55 | 56 | class Env(gym.Env): 57 | 58 | ASSET = None 59 | AGENT_DIM = None 60 | TASK_DIM = None 61 | SKILL_DIM = None 62 | FRAME_SKIP = None 63 | NSUBSTEPS = 1 64 | 65 | def __init__(self, model_path=None, frame_skip=None): 66 | if model_path is None: 67 | model_path = self.ASSET 68 | if frame_skip is None: 69 | frame_skip = self.FRAME_SKIP 70 | if model_path.startswith("/"): 71 | fullpath = model_path 72 | else: 73 | fullpath = os.path.join(os.path.dirname(__file__), "assets", model_path) 74 | if not os.path.exists(fullpath): 75 | raise IOError("File %s does not exist" % fullpath) 76 | self.frame_skip = frame_skip 77 | self.model = mujoco_py.load_model_from_path(fullpath) 78 | self.sim = mujoco_py.MjSim(self.model, nsubsteps=self.NSUBSTEPS) 79 | self.data = self.sim.data 80 | self.viewer = None 81 | self._viewers = {} 82 | 83 | self.metadata = { 84 | 'render.modes': ['human', 'rgb_array', 'depth_array'], 85 | 'video.frames_per_second': int(np.round(1.0 / self.dt)) 86 | } 87 | 88 | self.init_qpos = self.sim.data.qpos.ravel().copy() 89 | self.init_qvel = self.sim.data.qvel.ravel().copy() 90 | 91 | bounds = self.model.actuator_ctrlrange.copy().astype(np.float32) 92 | low, high = bounds.T 93 | self.action_space = spaces.Box(low=low, high=high, dtype=np.float32) 94 | 95 | action = self.action_space.sample() 96 | observation, _reward, done, _info = self.step(action) 97 | assert not done 98 | 99 | # Set the observation space 100 | self.observation_space = convert_observation_to_space(observation) 101 | 102 | self.seed() 103 | 104 | def seed(self, seed=None): 105 | self.np_random, seed = seeding.np_random(seed) 106 | return [seed] 107 | 108 | # methods to override (in addition to those required by gym.Env): 109 | # 
---------------------------- 110 | 111 | def viewer_setup(self): 112 | """ 113 | This method is called when the viewer is initialized. 114 | Optionally implement this method, if you need to tinker with camera position 115 | and so forth. 116 | """ 117 | pass 118 | 119 | def agent_obs(self, obs): 120 | return obs[:self.AGENT_DIM] 121 | 122 | def skill_obs(self, obs): 123 | return obs[self.AGENT_DIM:self.AGENT_DIM + self.SKILL_DIM] 124 | 125 | def task_obs(self, obs): 126 | return obs[-self.TASK_DIM:] 127 | 128 | def display_skill(self, skill): 129 | self.model.body_pos[-1][:self.SKILL_DIM] = skill 130 | 131 | # Utils Methods 132 | # ----------------------------- 133 | 134 | def set_state(self, qpos, qvel): 135 | assert qpos.shape == (self.model.nq,) and qvel.shape == (self.model.nv,) 136 | old_state = self.sim.get_state() 137 | new_state = mujoco_py.MjSimState(old_state.time, qpos, qvel, 138 | old_state.act, old_state.udd_state) 139 | self.sim.set_state(new_state) 140 | self.sim.forward() 141 | 142 | @property 143 | def dt(self): 144 | return self.model.opt.timestep * self.frame_skip 145 | 146 | def do_simulation(self, ctrl, n_frames): 147 | self.sim.data.ctrl[:] = ctrl 148 | for _ in range(n_frames): 149 | self.sim.step() 150 | 151 | def render(self, 152 | mode='human', 153 | width=400, 154 | height=400, 155 | camera_id=None, 156 | camera_name=None): 157 | if mode == 'rgb_array': 158 | # import pdb;pdb.set_trace() 159 | if camera_id is not None and camera_name is not None: 160 | raise ValueError("Both `camera_id` and `camera_name` cannot be" 161 | " specified at the same time.") 162 | 163 | no_camera_specified = camera_name is None and camera_id is None 164 | if no_camera_specified: 165 | camera_name = 'track' 166 | 167 | if camera_id is None and camera_name in self.model._camera_name2id: 168 | camera_id = self.model.camera_name2id(camera_name) 169 | # import pdb; pdb.set_trace() 170 | 171 | # self._get_viewer(mode) 172 | 173 | self._get_viewer(mode).render(width, height, camera_id=camera_id) 174 | # window size used for old mujoco-py: 175 | data = self._get_viewer(mode).read_pixels(width, height, depth=False) 176 | # original image is upside-down, so flip it 177 | return data[::-1, :, :] 178 | elif mode == 'depth_array': 179 | self._get_viewer(mode).render(width, height) 180 | # window size used for old mujoco-py: 181 | # Extract depth part of the read_pixels() tuple 182 | data = self._get_viewer(mode).read_pixels(width, height, depth=True)[1] 183 | # original image is upside-down, so flip it 184 | return data[::-1, :] 185 | elif mode == 'human': 186 | self._get_viewer(mode).render() 187 | 188 | def close(self): 189 | if self.viewer is not None: 190 | # self.viewer.finish() 191 | self.viewer = None 192 | self._viewers = {} 193 | 194 | def _get_viewer(self, mode): 195 | # import pdb; 196 | # pdb.set_trace() 197 | self.viewer = self._viewers.get(mode) 198 | 199 | if self.viewer is None: 200 | if mode == 'human': 201 | self.viewer = mujoco_py.MjViewer(self.sim) 202 | elif mode == 'rgb_array' or mode == 'depth_array': 203 | self.viewer = mujoco_py.MjRenderContextOffscreen(self.sim, -1) 204 | 205 | self.viewer_setup() 206 | # import pdb; pdb.set_trace() 207 | self.viewer.cam.distance = 50. 208 | # self.viewer.cam.azimuth = 132. 209 | self.viewer.cam.elevation = -90. 
210 | self._viewers[mode] = self.viewer 211 | return self.viewer 212 | 213 | def get_body_com(self, body_name): 214 | return self.data.get_body_xpos(body_name) 215 | 216 | def get_site_com(self, site_name): 217 | return self.data.get_site_xpos(site_name) 218 | 219 | class GoalEnv(gym.GoalEnv, Env): 220 | 221 | def agent_obs(self, obs): 222 | return obs['observation'][:self.AGENT_DIM] 223 | 224 | def skill_obs(self, obs): 225 | return obs['observation'][self.AGENT_DIM:self.AGENT_DIM + self.SKILL_DIM] 226 | 227 | def task_obs(self, obs): 228 | return obs['observation'][-self.TASK_DIM:] -------------------------------------------------------------------------------- /imitate.py: -------------------------------------------------------------------------------- 1 | ### Code for Gromov-Wasserstein Imitation Learning, Arnaud Fickinger, 2022 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | # !/usr/bin/env python3 5 | 6 | import numpy as np 7 | import torch 8 | import os 9 | import time 10 | import pickle as pkl 11 | from video import VideoRecorder 12 | from logger import Logger 13 | from replay_buffer import ReplayBuffer 14 | import utils 15 | import dmc2gym 16 | import hydra 17 | # import wrappers 18 | 19 | 20 | def make_maze(cfg, env_id, maze_id): 21 | from gym.envs.mujoco import mujoco_env 22 | from gym.wrappers import TimeLimit 23 | from maze_envs import MazeEnd_PointMass 24 | if env_id == 'MazeEnd_PointMass': 25 | env = MazeEnd_PointMass(maze_id=maze_id) 26 | else: 27 | assert False 28 | if cfg.time_limit > 0: 29 | env = TimeLimit(env, cfg.time_limit) 30 | return env 31 | 32 | 33 | def make_env(cfg, env_id, maze_id=0): 34 | """Helper function to create dm_control environment""" 35 | if cfg.dmc: 36 | if env_id == 'ball_in_cup_catch': 37 | domain_name = 'ball_in_cup' 38 | task_name = 'catch' 39 | else: 40 | domain_name = env_id.split('_')[0] 41 | task_name = '_'.join(env_id.split('_')[1:]) 42 | env = dmc2gym.make(domain_name=domain_name, 43 | task_name=task_name, 44 | seed=cfg.seed, 45 | visualize_reward=False) 46 | elif 'Maze' in env_id: 47 | env = make_maze(cfg, env_id, maze_id) 48 | random_rgb = np.array([0., 0., 0., 0.]) 49 | env.sim.model.geom_rgba[0, :] = random_rgb 50 | else: 51 | assert False 52 | if cfg.ultra_sparse: 53 | env = wrappers.SparseRewardCartpole(env) 54 | env.seed(cfg.seed) 55 | assert env.action_space.low.min() >= -1 56 | assert env.action_space.high.max() <= 1 57 | return env 58 | 59 | 60 | class Workspace(object): 61 | def __init__(self, cfg): 62 | self.work_dir = os.getcwd() 63 | print(f'workspace: {self.work_dir}') 64 | 65 | self.cfg = cfg 66 | 67 | agent_name = cfg.agent._target_.split('.')[1] 68 | self.logger = Logger(self.work_dir, 69 | save_tb=cfg.log_save_tb, 70 | log_frequency=cfg.log_frequency, 71 | agent=agent_name) 72 | 73 | utils.set_seed_everywhere(cfg.seed) 74 | self.device = torch.device(cfg.device) 75 | 76 | self.video_recorder = VideoRecorder( 77 | self.work_dir if cfg.save_video else None) 78 | self.step = 0 79 | 80 | # load the expert demonstration 81 | with open(f'{cfg.demonstration_name}', 'rb') as handle: 82 | dict_demonstration = pkl.load(handle) 83 | traj_expert = dict_demonstration['obs'] 84 | 85 | if cfg.gw_include_actions_expert: 86 | traj_expert = np.concatenate((traj_expert, dict_demonstration['action']), axis=1) 87 | self.traj_expert = traj_expert 88 | 89 | self.env = make_env(cfg, cfg.env_agent, cfg.maze_id_agent) 90 | 91 | cfg.agent.obs_dim = self.env.observation_space.shape[0] 92 | cfg.agent.action_dim = 
self.env.action_space.shape[0] 93 | cfg.agent.action_range = [ 94 | float(self.env.action_space.low.min()), 95 | float(self.env.action_space.high.max()) 96 | ] 97 | self.agent = hydra.utils.instantiate(cfg.agent) 98 | 99 | if cfg.pretrained_agent != '': 100 | self.agent.actor.load_state_dict( 101 | torch.load(f'{cfg.pretrained_agent}')) 102 | 103 | self.replay_buffer = ReplayBuffer(self.env.observation_space.shape, 104 | self.env.action_space.shape, 105 | int(cfg.replay_buffer_capacity), 106 | self.device, cfg) 107 | 108 | def evaluate(self): 109 | average_episode_reward = 0 110 | for episode in range(self.cfg.num_eval_episodes): 111 | obs = self.env.reset() 112 | self.agent.reset() 113 | self.video_recorder.init(enabled=(episode == 0)) 114 | done = False 115 | episode_reward = 0 116 | while not done: 117 | with utils.eval_mode(self.agent): 118 | action = self.agent.act(obs, sample=False) 119 | obs, reward, done, _ = self.env.step(action) 120 | self.video_recorder.record(self.env) 121 | episode_reward += reward 122 | 123 | average_episode_reward += episode_reward 124 | self.video_recorder.save(f'{self.step}.mp4') 125 | average_episode_reward /= self.cfg.num_eval_episodes 126 | self.logger.log('eval/episode_reward', average_episode_reward, 127 | self.step) 128 | self.logger.dump(self.step) 129 | 130 | def run(self): 131 | episode, episode_reward, done = 0, 0, True 132 | to_evaluate = False 133 | 134 | if 'Maze' in self.cfg.env_agent or self.cfg.ultra_sparse: 135 | episode_dense_reward = 0 136 | start_time = time.time() 137 | while self.step < self.cfg.num_train_steps: 138 | if done: 139 | if self.step > 0: 140 | duration = time.time() - start_time 141 | self.logger.log('train/duration', 142 | duration, self.step) 143 | self.logger.dump( 144 | self.step, save=(self.step > self.cfg.num_seed_steps)) 145 | 146 | # evaluate agent periodically 147 | if to_evaluate: 148 | self.logger.log('eval/episode', episode, self.step) 149 | self.evaluate() 150 | to_evaluate = False 151 | 152 | self.logger.log('train/episode_reward', episode_reward, 153 | self.step) 154 | 155 | obs = self.env.reset() 156 | 157 | self.agent.reset() 158 | done = False 159 | episode_reward = 0 160 | if 'Maze' in self.cfg.env_agent or self.cfg.ultra_sparse: 161 | episode_dense_reward = 0 162 | episode_step = 0 163 | episode += 1 164 | 165 | self.logger.log('train/episode', episode, self.step) 166 | 167 | # sample action for data collection 168 | if self.step < self.cfg.num_seed_steps: 169 | action = self.env.action_space.sample() 170 | else: 171 | with utils.eval_mode(self.agent): 172 | action = self.agent.act(obs, sample=True) 173 | 174 | # run training update 175 | if self.step >= self.cfg.num_seed_steps: 176 | self.agent.update(self.replay_buffer, self.logger, self.step, gw=True, 177 | normalize_reward=self.cfg.gw_normalize, 178 | normalize_reward_batch=self.cfg.gw_normalize_batch, 179 | include_external_reward=self.cfg.include_external_reward, 180 | weight_external_reward=self.cfg.weight_external_reward, 181 | weight_gw_reward=self.cfg.weight_gw_reward) 182 | 183 | next_obs, reward, done, info = self.env.step(action) 184 | 185 | if 'Maze' in self.cfg.env_agent or self.cfg.ultra_sparse: 186 | episode_dense_reward += info['dense_reward'] 187 | 188 | # allow infinite bootstrap 189 | done = float(done) 190 | done_no_max = 0 if episode_step + 1 == self.env._max_episode_steps else done 191 | episode_reward += reward 192 | 193 | self.replay_buffer.add(obs, action, reward, next_obs, done, done_no_max) 194 | 195 | if done: 196 | 
episode_gw_reward, normalized_episode_gw_reward = self.replay_buffer.process_trajectory( 197 | self.traj_expert, 198 | metric_expert=self.cfg.metric_expert, metric_agent=self.cfg.metric_agent, 199 | include_actions=self.cfg.gw_include_actions_agent, entropic=self.cfg.gw_entropic, 200 | normalize_agent_with_expert=self.cfg.normalize_agent_with_expert, 201 | sinkhorn_reg=self.cfg.sinkhorn_reg) 202 | self.logger.log('train/episode_gw_reward', episode_gw_reward, 203 | self.step) 204 | self.logger.log('train/normalized_episode_gw_reward', normalized_episode_gw_reward, 205 | self.step) 206 | obs = next_obs 207 | episode_step += 1 208 | 209 | self.step += 1 210 | if self.cfg.eval_frequency > 0 and self.step % self.cfg.eval_frequency == 0: 211 | to_evaluate = True 212 | self.evaluate_sample = self.step 213 | 214 | 215 | @hydra.main(config_path='config', config_name='imitate.yaml') 216 | def main(cfg): 217 | workspace = Workspace(cfg) 218 | workspace.run() 219 | 220 | 221 | if __name__ == '__main__': 222 | main() 223 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. 
If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. 
Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. 
Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. 
If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. 
automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 
406 | 407 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /maze_envs/maze.py: -------------------------------------------------------------------------------- 1 | ### Code for Gromov-Wasserstein Imitation Learning, Arnaud Fickinger, 2022 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | 4 | import os 5 | import numpy as np 6 | import tempfile 7 | import gym 8 | import xml.etree.ElementTree as ET 9 | from .base import Env 10 | import math 11 | 12 | def construct_maze(maze_id=0, length=1): 13 | # define the maze to use 14 | if maze_id == 0: 15 | if length != 1: 16 | raise NotImplementedError("Maze_id 0 only has length 1!") 17 | # structure = [ 18 | # [1, 1, 1, 1, 1], 19 | # [1, 'r', 0, 0, 1], 20 | # [1, 1, 1, 0, 1], 21 | # [1, 1, 1, 0, 1], 22 | # [1, 'g', 0, 0, 1], 23 | # [1, 1, 1, 1, 1], 24 | # ] 25 | structure = [ 26 | [1, 1, 1, 1, 1], 27 | [1, 'r', 0, 0, 1], 28 | [1, 1, 1, 0, 1], 29 | [1, 'g', 0, 0, 1], 30 | [1, 1, 1, 1, 1], 31 | ] 32 | # dense_reward_direction = [ 33 | # ["rr", "rr", "rr", "dd", "dd"], 34 | # ["rr", "rr", "rr", "dd", "dd"], 35 | # ["rr", "rr", "rr", "dd", "dd"], 36 | # ["ll", "ll", "ll", "dd", "dd"], 37 | # ["ll", "ll", "ll", "dl", "dl"], 38 | # ["ll", "ll", "ll", "dl", "dl"], 39 | # ] 40 | 41 | dense_reward_direction = [ 42 | [1, 1, 1, 1, 1], 43 | [1, 'r', 'r', 'd', 1], 44 | [1, 1, 1, 'd', 1], 45 | [1, 'l', 'l', 'l', 1], 46 | [1, 1, 1, 1, 1], 47 | ] 48 | # ##3: 49 | # dense_reward_direction = [ 50 | # ["r", "r", "r", "d", "d"], 51 | # ["r", "r", "r", "d", "d"], 52 | # 53 | # [1, 1, 1, "d", "d"], 54 | # ["l", "l", "l", "l", "l"], 55 | # ["l", "l", "l", "l", "l"], 56 | # ] 57 | 58 | # dense_reward = [ 59 | # [0, 1, 2, 3, 3], 60 | # [0, 1, 2, 3, 3], 61 | # [0, 1, 2, 4, 4], 62 | # [8, 8, 7, 5, 5], 63 | # [8, 8, 7, 6, 6], 64 | # [8, 8, 7, 6, 6], 65 | # ] 66 | dense_reward = [ 67 | [0, 0, 0, 0, 0], 68 | [0, 1, 2, 3, 0], 69 | [0, 0, 0, 4, 0], 70 | [0, 7, 6, 5, 0], 71 | [0, 0, 0, 0, 0], 72 | ] 73 | return structure, dense_reward, dense_reward_direction 74 | elif maze_id == 1: 75 | if length != 1: 76 | raise NotImplementedError("Maze_id 0 only has length 1!") 77 | structure = [ 78 | [1, 1, 1, 1, 1], 79 | [1, 'r', 0, 'g', 1], 80 | [1, 1, 1, 1, 1] 81 | ] 82 | # dense_reward_direction = [ 83 | # ['rr', 'rr', 'rr', 'rr', 'rr'], 84 | # ['rr', 'rr', 'rr', 'rr', 'rr'], 85 | # ['rr', 'rr', 'rr', 'rr', 'rr'] 86 | # ] 87 | dense_reward_direction = [ 88 | [1, 1, 1, 1, 1], 89 | [1, 'r', 'r', 'r', 1], 90 | [1, 1, 1, 1, 1] 91 | ] 92 | dense_reward = [ 93 | [0, 0, 0, 0, 0], 94 | [0, 1, 2, 3, 0], 95 | [0, 0, 0, 0, 0] 96 | ] 97 | return structure, dense_reward, dense_reward_direction 98 | if maze_id == 2: 99 | if length != 1: 100 | raise NotImplementedError("Maze_id 0 only has length 1!") 101 | structure = [ 102 | [1, 1, 1, 1, 1], 103 | [1, 'g', 0, 0, 1], 104 | [1, 1, 1, 0, 1], 105 | [1, 'r', 0, 0, 1], 106 | [1, 1, 1, 1, 1], 107 | ] 108 | dense_reward_direction = [ 109 | [1, 1, 1, 1, 1], 110 | [1, 'l', 'l', 'l', 1], 111 | [1, 1, 1, 'u', 1], 112 | [1, 'r', 'r', 'u', 1], 113 | [1, 1, 1, 1, 1], 114 | ] 115 | dense_reward = [ 116 | [0, 0, 0, 0, 0], 117 | [0, 7, 6, 5, 0], 118 | [0, 0, 0, 4, 0], 119 | [0, 1, 2, 3, 0], 120 | [0, 0, 0, 0, 0], 121 | ] 122 | return structure, dense_reward, dense_reward_direction 123 | elif maze_id == 3: 124 | if length != 1: 125 | raise NotImplementedError("Maze_id 0 only has length 1!") 126 | structure = [ 127 | [1, 1, 1, 1, 1], 128 | [1, 'g', 0, 'r', 1], 129 | [1, 1, 
1, 1, 1] 130 | ] 131 | dense_reward_direction = [ 132 | [1, 1, 1, 1, 1], 133 | [1, 'l', 'l', 'l', 1], 134 | [1, 1, 1, 1, 1] 135 | ] 136 | dense_reward = [ 137 | [0, 0, 0, 0, 0], 138 | [0, 3, 2, 1, 0], 139 | [0, 0, 0, 0, 0] 140 | ] 141 | return structure, dense_reward, dense_reward_direction 142 | else: 143 | raise NotImplementedError("The provided MazeId is not recognized") 144 | 145 | class Maze(Env): 146 | VISUALIZE = True 147 | SCALING = 8.0 148 | # MAZE_ID = 0 149 | DIST_REWARD = 0 150 | SPARSE_REWARD = 1000.0 151 | RANDOM_GOALS = False 152 | HEIGHT = 2 153 | 154 | # Fixed constants for agents 155 | SKILL_DIM = 2 # X, Y 156 | TASK_DIM = 4 # agent position, goal position. 157 | 158 | def __init__(self, model_path=None, maze_id=0): 159 | self.MAZE_ID = maze_id 160 | if model_path is None: 161 | model_path = os.path.join(os.path.dirname(__file__), "assets", self.ASSET) 162 | 163 | # Initialize the maze and its parameters 164 | self.STRUCTURE, self.DENSE_REWARD, self.DENSE_REWARD_DIRECTION = construct_maze(maze_id=self.MAZE_ID, length=1) 165 | self.interm_goals = {} 166 | self.vel_rew = {} 167 | 168 | torso_x, torso_y = self.get_agent_start() 169 | self._init_torso_x = torso_x 170 | self._init_torso_y = torso_y 171 | 172 | for i in range(len(self.DENSE_REWARD)): 173 | for j in range(len(self.DENSE_REWARD[0])): 174 | minx = j * self.SCALING - self.SCALING * 0.5 - self._init_torso_x 175 | maxx = j * self.SCALING + self.SCALING * 0.5 - self._init_torso_x 176 | miny = i * self.SCALING - self.SCALING * 0.5 - self._init_torso_y 177 | maxy = i * self.SCALING + self.SCALING * 0.5 - self._init_torso_y 178 | self.interm_goals[(minx, maxx, miny, maxy)] = self.DENSE_REWARD[i][j] 179 | self.vel_rew[(minx, maxx, miny, maxy)] = self.DENSE_REWARD_DIRECTION[i][j] 180 | 181 | tree = ET.parse(model_path) 182 | worldbody = tree.find(".//worldbody") 183 | for i in range(len(self.STRUCTURE)): 184 | for j in range(len(self.STRUCTURE[0])): 185 | if self.STRUCTURE[i][j] == 'g': 186 | self.current_goal_pos = (i,j) 187 | if (isinstance(self.STRUCTURE[i][j], int) or isinstance(self.STRUCTURE[i][j], float)) \ 188 | and self.STRUCTURE[i][j] > 0: 189 | height = float(self.STRUCTURE[i][j]) 190 | ET.SubElement( 191 | worldbody, "geom", 192 | name="block_%d_%d" % (i, j), 193 | pos="%f %f %f" % (j * self.SCALING - torso_x, 194 | i * self.SCALING - torso_y, 195 | self.HEIGHT / 2 * height), 196 | size="%f %f %f" % (0.5 * self.SCALING, 197 | 0.5 * self.SCALING, 198 | self.HEIGHT / 2 * height), 199 | type="box", 200 | material="", 201 | contype="1", 202 | conaffinity="1", 203 | rgba="%f %f 0.3 1" % (height * 0.3, height * 0.3) 204 | ) 205 | 206 | if self.VISUALIZE: 207 | world_body = tree.find(".//worldbody") 208 | waypoint_elem = ET.Element('body') 209 | waypoint_elem.set("name", "waypoint") 210 | waypoint_elem.set("pos", "0 0 " + str(self.SCALING/10)) 211 | waypoint_geom = ET.SubElement(waypoint_elem, "geom") 212 | waypoint_geom.set("conaffinity", "0") 213 | waypoint_geom.set("contype", "0") 214 | waypoint_geom.set("name", "waypoint") 215 | waypoint_geom.set("pos", "0 0 0") 216 | waypoint_geom.set("rgba", "0.2 0.9 0.2 0.8") 217 | waypoint_geom.set("size", str(self.SCALING/10)) 218 | waypoint_geom.set("type", "sphere") 219 | world_body.insert(-1, waypoint_elem) 220 | xml_path = model_path 221 | 222 | _, xml_path = tempfile.mkstemp(text=True, suffix='.xml') 223 | tree.write(xml_path) 224 | 225 | # Get the list of possible segments of the maze to be the goal. 
226 | self.possible_goal_positions = list() 227 | for i in range(len(self.STRUCTURE)): 228 | for j in range(len(self.STRUCTURE[0])): 229 | if self.STRUCTURE[i][j] == 0 or self.STRUCTURE[i][j] == 'g': 230 | self.possible_goal_positions.append((i,j)) 231 | self.goal_range = self.get_goal_range() 232 | self.center_goal = np.array([(self.goal_range[0] + self.goal_range[1]) / 2, 233 | (self.goal_range[2] + self.goal_range[3]) / 2]) 234 | 235 | super(Maze, self).__init__(model_path=xml_path) 236 | 237 | def sample_goal_pos(self): 238 | if not self.RANDOM_GOALS: 239 | return 240 | cur_x, cur_y = self.current_goal_pos 241 | self.STRUCTURE[cur_x][cur_y] = 0 242 | new_x, new_y = self.possible_goal_positions[self.np_random.randint(low=0, high=len(self.possible_goal_positions))] 243 | self.STRUCTURE[new_x][new_y] = 'g' 244 | self.current_goal_pos = (new_x, new_y) 245 | self.goal_range = self.get_goal_range() 246 | self.center_goal = np.array([(self.goal_range[0] + self.goal_range[1]) / 2, 247 | (self.goal_range[2] + self.goal_range[3]) / 2]) 248 | 249 | def get_agent_start(self): 250 | for i in range(len(self.STRUCTURE)): 251 | for j in range(len(self.STRUCTURE[0])): 252 | if self.STRUCTURE[i][j] == 'r': 253 | return j * self.SCALING, i * self.SCALING 254 | assert False 255 | 256 | def get_goal_range(self): 257 | for i in range(len(self.STRUCTURE)): 258 | for j in range(len(self.STRUCTURE[0])): 259 | if self.STRUCTURE[i][j] == 'g': 260 | minx = j * self.SCALING - self.SCALING * 0.5 - self._init_torso_x 261 | maxx = j * self.SCALING + self.SCALING * 0.5 - self._init_torso_x 262 | miny = i * self.SCALING - self.SCALING * 0.5 - self._init_torso_y 263 | maxy = i * self.SCALING + self.SCALING * 0.5 - self._init_torso_y 264 | return minx, maxx, miny, maxy 265 | 266 | 267 | def xy_to_discrete(self, x, y): 268 | x = x + self._init_torso_x + self.SCALING * 0.5 269 | y = y + self._init_torso_y + self.SCALING * 0.5 270 | i = int(x//self.SCALING) 271 | j = int(y//self.SCALING) 272 | return i,j 273 | 274 | 275 | 276 | # def get_dense_reward2(self): 277 | # # we need to remove int_torso to get mujoco positions 278 | # x, y = self.get_body_com("torso")[:2] 279 | # i, j = self.xy_to_discrete(x,y) 280 | # # i = max(i,0) 281 | # # j = max(i,0) 282 | # # i = min(i, len(self.DENSE_REWARD_DIRECTION)-1) 283 | # # j = min(j, len(self.DENSE_REWARD_DIRECTION[0])-1) 284 | # # return self.DENSE_REWARD[i][j]*10 285 | # try: 286 | # direction = self.DENSE_REWARD_DIRECTION[i][j] 287 | # except: 288 | # print("wrong coord should not happen") 289 | # return -100 290 | # if direction=='i': 291 | # 292 | # return -100 293 | # # miny = i * self.SCALING - self.SCALING * 0.5 - self._init_torso_y 294 | # # maxy = i * self.SCALING + self.SCALING * 0.5 - self._init_torso_y 295 | # # center = (miny+maxy)/2 296 | # # if y + self._init_torso_y + self.SCALING * 0.5>center: 297 | # # i = i+1 298 | # # else: 299 | # # i = i-1 300 | # # direction = self.DENSE_REWARD_DIRECTION[i][j] 301 | # if direction=='j': 302 | # return -100 303 | # # minx = j * self.SCALING - self.SCALING * 0.5 - self._init_torso_x 304 | # # maxx = j * self.SCALING + self.SCALING * 0.5 - self._init_torso_x 305 | # # center = (minx+maxx)/2 306 | # # if x + self._init_torso_x + self.SCALING * 0.5>center: 307 | # # j=j+1 308 | # # else: 309 | # # j=j-1 310 | # # direction = self.DENSE_REWARD_DIRECTION[i][j] 311 | # base_reward = self.DENSE_REWARD[i][j] 312 | # r_pos = 0 313 | # r_vel = 0 314 | # for dir in direction: 315 | # if dir=='r': 316 | # r_vel+= 
self.sim.data.qvel.flat[:][0] 317 | # minx = j * self.SCALING - self.SCALING * 0.5 - self._init_torso_x 318 | # maxx = j * self.SCALING + self.SCALING * 0.5 - self._init_torso_x 319 | # perturb = (x-minx)/(maxx-minx) 320 | # perturb = max(perturb,0) 321 | # perturb = min(perturb,1) 322 | # r_pos += (base_reward+(perturb))*10 323 | # elif dir == 'u': 324 | # r_vel-=self.sim.data.qvel.flat[:][1] 325 | # miny = i * self.SCALING - self.SCALING * 0.5 - self._init_torso_y 326 | # maxy = i * self.SCALING + self.SCALING * 0.5 - self._init_torso_y 327 | # perturb = (maxy-y)/(maxy-miny) 328 | # perturb = max(perturb, 0) 329 | # perturb = min(perturb, 1) 330 | # r_pos += (base_reward+(perturb))*10 331 | # elif dir == 'l': 332 | # r_vel -= self.sim.data.qvel.flat[:][0] 333 | # minx = j * self.SCALING - self.SCALING * 0.5 - self._init_torso_x 334 | # maxx = j * self.SCALING + self.SCALING * 0.5 - self._init_torso_x 335 | # perturb = (maxx - x) / (maxx - minx) 336 | # perturb = max(perturb, 0) 337 | # perturb = min(perturb, 1) 338 | # r_pos += (base_reward+(perturb))*10 339 | # elif dir == 'd': 340 | # r_vel+= self.sim.data.qvel.flat[:][1] 341 | # miny = i * self.SCALING - self.SCALING * 0.5 - self._init_torso_y 342 | # maxy = i * self.SCALING + self.SCALING * 0.5 - self._init_torso_y 343 | # perturb = (y - miny) / (maxy - miny) 344 | # perturb = max(perturb, 0) 345 | # perturb = min(perturb, 1) 346 | # r_pos += (base_reward+(perturb))*10 347 | # return r_pos 348 | 349 | def _get_obs(self): 350 | return NotImplemented 351 | 352 | def step(self, action): 353 | # import pdb; pdb.set_trace() 354 | self.do_simulation(action, self.frame_skip) 355 | obs = self._get_obs() 356 | # Compute the reward 357 | minx, maxx, miny, maxy = self.goal_range 358 | x, y = self.get_body_com("torso")[:2] 359 | reward = 0 360 | if minx <= x <= maxx and miny <= y <= maxy: 361 | reward += self.SPARSE_REWARD 362 | done = True 363 | else: 364 | done = False 365 | # if self.DIST_REWARD > 0: 366 | # # adds L2 reward 367 | # reward += -self.DIST_REWARD * np.linalg.norm(self.skill_obs(obs)[:2] - self.center_goal) 368 | 369 | dense_reward = self.get_dense_reward()+reward-1 370 | 371 | obs_dic = self.get_obs_dic() 372 | 373 | return obs, reward, done, {'is_success' : done, 'dense_reward':dense_reward, 'dmc_obs': obs_dic} 374 | 375 | def reset(self): 376 | return NotImplemented 377 | 378 | class MazeEnd_PointMass(Maze): 379 | ASSET = 'point_mass.xml' 380 | AGENT_DIM = 2 381 | FRAME_SKIP = 3 382 | 383 | def __init__(self, maze_id=0): 384 | super(MazeEnd_PointMass, self).__init__(maze_id=maze_id) 385 | 386 | def _get_obs(self): 387 | return np.concatenate([ 388 | self.sim.data.qvel.flat[:], 389 | self.get_body_com("torso")[:2], 390 | self.center_goal, 391 | ]) 392 | 393 | def get_obs_dic(self): 394 | # import pdb; pdb.set_trace() 395 | return { 396 | 'velocity': self.sim.data.qvel.flat[:], 397 | 'position': self.get_body_com("torso")[:2], 398 | 'position_torso': self.get_body_com("torso")[:2] 399 | } 400 | 401 | def get_dense_reward(self): 402 | # we need to remove int_torso to get mujoco positions 403 | rew_pos = 0 404 | rew_vel = 0 405 | x, y = self.get_body_com("torso")[:2] 406 | for key in self.interm_goals: 407 | minx, maxx, miny, maxy = key 408 | if minx <= x <= maxx and miny <= y <= maxy: 409 | rew_pos+=self.interm_goals[key] 410 | if self.vel_rew[key]=='r': 411 | rew_vel +=self.sim.data.qvel.flat[:][0] 412 | elif self.vel_rew[key] == 'u': 413 | rew_vel -=self.sim.data.qvel.flat[:][1] 414 | elif self.vel_rew[key] == 'l': 415 | 
rew_vel -= self.sim.data.qvel.flat[:][0] 416 | elif self.vel_rew[key] == 'd': 417 | rew_vel += self.sim.data.qvel.flat[:][1] 418 | break 419 | # print(f"rew_pos: {rew_pos}, rew_vel: {rew_vel}") 420 | return rew_pos+rew_vel 421 | # i, j = self.xy_to_discrete(x,y) 422 | # # i = max(i,0) 423 | # # j = max(i,0) 424 | # # i = min(i, len(self.DENSE_REWARD_DIRECTION)-1) 425 | # # j = min(j, len(self.DENSE_REWARD_DIRECTION[0])-1) 426 | # # return self.DENSE_REWARD[i][j]*10 427 | # try: 428 | # direction = self.DENSE_REWARD_DIRECTION[i][j] 429 | # except: 430 | # # print("wrong coord should not happen") 431 | # return -100 432 | # if direction==1: 433 | # 434 | # return -100 435 | # # miny = i * self.SCALING - self.SCALING * 0.5 - self._init_torso_y 436 | # # maxy = i * self.SCALING + self.SCALING * 0.5 - self._init_torso_y 437 | # # center = (miny+maxy)/2 438 | # # if y + self._init_torso_y + self.SCALING * 0.5>center: 439 | # # i = i+1 440 | # # else: 441 | # # i = i-1 442 | # # direction = self.DENSE_REWARD_DIRECTION[i][j] 443 | # # if direction=='j': 444 | # # return -100 445 | # # minx = j * self.SCALING - self.SCALING * 0.5 - self._init_torso_x 446 | # # maxx = j * self.SCALING + self.SCALING * 0.5 - self._init_torso_x 447 | # # center = (minx+maxx)/2 448 | # # if x + self._init_torso_x + self.SCALING * 0.5>center: 449 | # # j=j+1 450 | # # else: 451 | # # j=j-1 452 | # # direction = self.DENSE_REWARD_DIRECTION[i][j] 453 | # base_reward = self.DENSE_REWARD[i][j] 454 | # r_pos = 0 455 | # r_vel = 0 456 | # for dir in direction: 457 | # if dir=='r': 458 | # r_vel+= self.sim.data.qvel.flat[:][0] 459 | # minx = j * self.SCALING - self.SCALING * 0.5 - self._init_torso_x 460 | # maxx = j * self.SCALING + self.SCALING * 0.5 - self._init_torso_x 461 | # perturb = (x-minx)/(maxx-minx) 462 | # perturb = max(perturb,0) 463 | # perturb = min(perturb,1) 464 | # r_pos += (base_reward+(perturb))*10 465 | # elif dir == 'u': 466 | # r_vel-=self.sim.data.qvel.flat[:][1] 467 | # miny = i * self.SCALING - self.SCALING * 0.5 - self._init_torso_y 468 | # maxy = i * self.SCALING + self.SCALING * 0.5 - self._init_torso_y 469 | # perturb = (maxy-y)/(maxy-miny) 470 | # perturb = max(perturb, 0) 471 | # perturb = min(perturb, 1) 472 | # r_pos += (base_reward+(perturb))*10 473 | # elif dir == 'l': 474 | # r_vel -= self.sim.data.qvel.flat[:][0] 475 | # minx = j * self.SCALING - self.SCALING * 0.5 - self._init_torso_x 476 | # maxx = j * self.SCALING + self.SCALING * 0.5 - self._init_torso_x 477 | # perturb = (maxx - x) / (maxx - minx) 478 | # perturb = max(perturb, 0) 479 | # perturb = min(perturb, 1) 480 | # r_pos += (base_reward+(perturb))*10 481 | # elif dir == 'd': 482 | # r_vel+= self.sim.data.qvel.flat[:][1] 483 | # miny = i * self.SCALING - self.SCALING * 0.5 - self._init_torso_y 484 | # maxy = i * self.SCALING + self.SCALING * 0.5 - self._init_torso_y 485 | # perturb = (y - miny) / (maxy - miny) 486 | # perturb = max(perturb, 0) 487 | # perturb = min(perturb, 1) 488 | # r_pos += (base_reward+(perturb))*10 489 | # return r_pos 490 | 491 | def reset(self): 492 | self.sim.reset() 493 | self.sample_goal_pos() 494 | if self.VISUALIZE: 495 | self.model.body_pos[-2][:2] = self.center_goal 496 | qpos = self.init_qpos + self.np_random.uniform(low=-self.SCALING/10.0, high=self.SCALING/10.0, size=self.model.nq) 497 | qvel = self.init_qvel + self.np_random.uniform(low=-0.2, high=0.2, size=self.model.nv) 498 | self.set_state(qpos, qvel) 499 | return self._get_obs() 500 | 
--------------------------------------------------------------------------------
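The training loop in imitate.py above obtains its imitation signal from `self.replay_buffer.process_trajectory(self.traj_expert, ...)`, whose implementation (in replay_buffer.py) is not reproduced in this listing. The snippet below is only a rough sketch of the (entropic) Gromov-Wasserstein alignment that those keyword arguments (`metric_expert`, `metric_agent`, `entropic`, `sinkhorn_reg`, `normalize_agent_with_expert`) point at, written against POT and scipy from conda_env.yml. The function name, default values, and normalization here are placeholders, not the repository's actual reward shaping, which presumably also converts the coupling into per-step rewards stored in the buffer.

```python
# Illustrative sketch only -- NOT the repo's ReplayBuffer.process_trajectory.
# It shows the kind of Gromov-Wasserstein computation that call refers to,
# using POT ("POT" in conda_env.yml) and scipy. Defaults are placeholders.
import numpy as np
import ot
from scipy.spatial.distance import cdist


def gw_episode_reward(traj_agent, traj_expert,
                      metric_agent='euclidean', metric_expert='euclidean',
                      entropic=True, sinkhorn_reg=5e-3,
                      normalize_agent_with_expert=False):
    """Episode-level pseudo-reward: negative (entropic) GW discrepancy
    between an agent trajectory and an expert trajectory.

    traj_agent:  (T_a, d_a) array of agent observations (optionally + actions)
    traj_expert: (T_e, d_e) array of expert observations; the two observation
                 spaces may have different dimensions, which is what GW allows.
    """
    # GW compares only intra-trajectory distance matrices, never raw states
    # across the two (possibly incomparable) spaces.
    C_agent = cdist(traj_agent, traj_agent, metric=metric_agent)
    C_expert = cdist(traj_expert, traj_expert, metric=metric_expert)
    if normalize_agent_with_expert:
        # Optional rescaling so the two metric spaces have comparable ranges.
        C_agent = C_agent * (C_expert.max() / max(C_agent.max(), 1e-8))

    # Uniform weights over the time steps of each trajectory.
    p = ot.unif(C_agent.shape[0])
    q = ot.unif(C_expert.shape[0])

    if entropic:
        gw_dist = ot.gromov.entropic_gromov_wasserstein2(
            C_agent, C_expert, p, q, 'square_loss', epsilon=sinkhorn_reg)
    else:
        gw_dist = ot.gromov.gromov_wasserstein2(
            C_agent, C_expert, p, q, 'square_loss')

    episode_gw_reward = -gw_dist
    # Placeholder normalization; the real scheme is config-controlled
    # (gw_normalize / gw_normalize_batch) and lives in replay_buffer.py.
    normalized_episode_gw_reward = -gw_dist / max(C_expert.max(), 1e-8)
    return episode_gw_reward, normalized_episode_gw_reward
```

A "trajectory" here is simply the stacked per-step observations of one episode, i.e. the same kind of array imitate.py builds when it loads `dict_demonstration['obs']` and optionally concatenates `dict_demonstration['action']` along axis 1.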
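For orientation, here is a small usage sketch of the point-mass maze defined in maze_envs/maze.py, wrapped with gym's TimeLimit the way make_maze in imitate.py does. It assumes a working MuJoCo / mujoco-py installation and the older 4-tuple gym step API used throughout this code; the 300-step limit is an arbitrary stand-in for cfg.time_limit.

```python
# Usage sketch (assumes mujoco-py and an older gym with env.seed() and the
# (obs, reward, done, info) step API, matching the rest of this repository).
from gym.wrappers import TimeLimit
from maze_envs import MazeEnd_PointMass

env = TimeLimit(MazeEnd_PointMass(maze_id=0), 300)  # 300 is a placeholder for cfg.time_limit
env.seed(0)

obs = env.reset()
# Observation layout from MazeEnd_PointMass._get_obs():
#   [qvel (velocities), torso x-y position, goal center x-y]
print(obs.shape)

done = False
while not done:
    action = env.action_space.sample()
    obs, reward, done, info = env.step(action)
    # info carries 'is_success', 'dense_reward', and a dict-style 'dmc_obs'
    # (see Maze.step and MazeEnd_PointMass.get_obs_dic above).
print(info['is_success'], info['dense_reward'])
```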