├── .gitignore ├── LICENSE ├── README.md ├── compute_goal_embedding.py ├── environment.yaml ├── graphirl ├── __init__.py ├── bbox_util.py ├── common.py ├── data_aug.py ├── dataset.py ├── evaluators │ ├── __init__.py │ ├── base.py │ ├── cycle_consistency.py │ ├── emb_visualizer.py │ ├── kendalls_tau.py │ ├── manager.py │ ├── nn_visualizer.py │ ├── reconstruction_visualizer.py │ └── reward_visualizer.py ├── factory.py ├── file_utils.py ├── frame_samplers.py ├── losses.py ├── models.py ├── tensorizers.py ├── trainers │ ├── __init__.py │ ├── base.py │ ├── classification.py │ ├── lifs.py │ ├── tcc.py │ └── tcn.py ├── transforms.py ├── types.py └── video_samplers.py ├── media ├── preview.gif └── summary.png ├── scripts ├── pegbox_helper.sh ├── push_helper.sh ├── reach_helper.sh ├── run_pegbox.sh ├── run_push.sh └── run_reach.sh └── src ├── algorithms ├── drq.py ├── drq_multiview.py ├── drqv2.py ├── factory.py ├── modules.py ├── modules_3d.py ├── multiview.py ├── rot_utils.py ├── sac.py ├── sacv2.py ├── sacv2_3d.py ├── svea_multiview.py └── sveav2.py ├── arguments.py ├── augmentations.py ├── color_jitter.py ├── env ├── __pycache__ │ └── wrappers.cpython-37.pyc ├── robot │ ├── __pycache__ │ │ ├── base.cpython-37.pyc │ │ ├── gym_utils.cpython-37.pyc │ │ ├── push.cpython-37.pyc │ │ └── registration.cpython-37.pyc │ ├── assets │ │ └── robot │ │ │ ├── fetch │ │ │ ├── base_link.STL │ │ │ ├── base_link_collision.stl │ │ │ ├── bellows_link_collision.stl │ │ │ ├── elbow_flex_link_collision.stl │ │ │ ├── estop_link.stl │ │ │ ├── forearm_roll_link_collision.stl │ │ │ ├── gripper_link.stl │ │ │ ├── head_pan_link_collision.stl │ │ │ ├── head_tilt_link_collision.stl │ │ │ ├── l_wheel_link_collision.stl │ │ │ ├── laser_link.stl │ │ │ ├── left_finger.STL │ │ │ ├── left_inner_knuckle.STL │ │ │ ├── left_outer_knuckle.STL │ │ │ ├── link1.STL │ │ │ ├── link2.STL │ │ │ ├── link3.STL │ │ │ ├── link4.STL │ │ │ ├── link5.STL │ │ │ ├── link6.STL │ │ │ ├── link7.STL │ │ │ ├── link_base copy.STL │ │ │ ├── link_base.STL │ │ │ ├── r_wheel_link_collision.stl │ │ │ ├── right_finger.STL │ │ │ ├── right_inner_knuckle.STL │ │ │ ├── right_outer_knuckle.STL │ │ │ ├── shoulder_lift_link_collision.stl │ │ │ ├── shoulder_pan_link_collision.stl │ │ │ ├── torso_fixed_link.stl │ │ │ ├── torso_lift_link_collision.stl │ │ │ ├── upperarm_roll_link_collision.stl │ │ │ ├── wrist_flex_link_collision.stl │ │ │ └── wrist_roll_link_collision.stl │ │ │ ├── golf.xml │ │ │ ├── golfbot.xml │ │ │ ├── hammer.xml │ │ │ ├── hammer_all.xml │ │ │ ├── lift.xml │ │ │ ├── mesh │ │ │ ├── arm │ │ │ │ ├── Base_Link.stl │ │ │ │ ├── Bracelet_Link.stl │ │ │ │ ├── ForeArm_Link.stl │ │ │ │ ├── HalfArm1_Link.stl │ │ │ │ ├── HalfArm2_Link.stl │ │ │ │ ├── Shoulder_Link.stl │ │ │ │ ├── SphericalWrist1_Link.stl │ │ │ │ └── SphericalWrist2_Link.stl │ │ │ ├── robotiq │ │ │ │ ├── inner_finger_coarse.stl │ │ │ │ ├── inner_knuckle_coarse.stl │ │ │ │ ├── kinova_robotiq_coupler.stl │ │ │ │ ├── outer_finger_coarse.stl │ │ │ │ ├── outer_knuckle_coarse.stl │ │ │ │ ├── robotiq_85_base_link.stl │ │ │ │ ├── robotiq_85_base_link_coarse.stl │ │ │ │ ├── robotiq_85_finger_link.stl │ │ │ │ ├── robotiq_85_finger_tip_link.stl │ │ │ │ ├── robotiq_85_inner_knuckle_link.stl │ │ │ │ └── robotiq_85_knuckle_link.stl │ │ │ ├── robotiq_85_gripper │ │ │ │ ├── robotiq_arg2f_85_base_link.stl │ │ │ │ ├── robotiq_arg2f_85_base_link_vis.dae │ │ │ │ ├── robotiq_arg2f_85_inner_finger.dae │ │ │ │ ├── robotiq_arg2f_85_inner_finger.stl │ │ │ │ ├── robotiq_arg2f_85_inner_finger_vis.dae │ │ │ │ ├── 
robotiq_arg2f_85_inner_finger_vis.stl │ │ │ │ ├── robotiq_arg2f_85_inner_knuckle.dae │ │ │ │ ├── robotiq_arg2f_85_inner_knuckle.stl │ │ │ │ ├── robotiq_arg2f_85_inner_knuckle_vis.dae │ │ │ │ ├── robotiq_arg2f_85_inner_knuckle_vis.stl │ │ │ │ ├── robotiq_arg2f_85_outer_finger.dae │ │ │ │ ├── robotiq_arg2f_85_outer_finger.stl │ │ │ │ ├── robotiq_arg2f_85_outer_finger_vis.dae │ │ │ │ ├── robotiq_arg2f_85_outer_finger_vis.stl │ │ │ │ ├── robotiq_arg2f_85_outer_knuckle.dae │ │ │ │ ├── robotiq_arg2f_85_outer_knuckle.stl │ │ │ │ ├── robotiq_arg2f_85_outer_knuckle_vis.dae │ │ │ │ ├── robotiq_arg2f_85_outer_knuckle_vis.stl │ │ │ │ ├── robotiq_arg2f_85_pad_vis.dae │ │ │ │ ├── robotiq_arg2f_85_pad_vis.stl │ │ │ │ ├── robotiq_arg2f_base_link.stl │ │ │ │ └── robotiq_gripper_coupling_vis.stl │ │ │ ├── xarm │ │ │ │ ├── link1.STL │ │ │ │ ├── link2.STL │ │ │ │ ├── link3.STL │ │ │ │ ├── link4.STL │ │ │ │ ├── link5.STL │ │ │ │ ├── link6.STL │ │ │ │ ├── link7.STL │ │ │ │ └── link_base.STL │ │ │ └── xarm_gripper │ │ │ │ ├── base_link.STL │ │ │ │ ├── left_finger.STL │ │ │ │ ├── left_inner_knuckle.STL │ │ │ │ ├── left_outer_knuckle.STL │ │ │ │ ├── right_finger.STL │ │ │ │ ├── right_inner_knuckle.STL │ │ │ │ └── right_outer_knuckle.STL │ │ │ ├── peg_in_box.xml │ │ │ ├── pick_place.xml │ │ │ ├── push.xml │ │ │ ├── reach.xml │ │ │ ├── robot_gen3.xml │ │ │ ├── robot_xarm.xml │ │ │ ├── robot_xarm_hammer.xml │ │ │ ├── shared.xml │ │ │ ├── shelf_placing_classic.xml │ │ │ ├── shelf_placing_far.xml │ │ │ ├── shelf_placing_near.xml │ │ │ └── texture │ │ │ ├── block.png │ │ │ ├── block_hidden.png │ │ │ ├── carpet-black.png │ │ │ ├── concrete.png │ │ │ ├── light_wood.png │ │ │ ├── sponge.png │ │ │ └── table1.png │ ├── base.py │ ├── golf.py │ ├── gym_utils.py │ ├── hammer.py │ ├── hammer_all.py │ ├── lift.py │ ├── peg_in_box.py │ ├── pick_place.py │ ├── push.py │ ├── reach.py │ ├── registration.py │ ├── reward_utils.py │ ├── shelf_go_and_back.py │ └── shelf_placing.py └── wrappers.py ├── graphirl_wrapper.py ├── logger.py ├── state.py ├── train.py ├── train_policy.py ├── train_reward.py ├── utils.py └── video.py /.gitignore: -------------------------------------------------------------------------------- 1 | src/base_configs 2 | src/configs 3 | src/torchkit 4 | src/x-magical 5 | src/x-magical-diverse 6 | src/graphirl 7 | src/sac 8 | logs/ 9 | datasets/ 10 | utils/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Sateesh Kumar 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Graph Inverse Reinforcement Learning 2 | 3 | Original PyTorch implementation of **GraphIRL** from 4 | 5 | [Graph Inverse Reinforcement Learning from Diverse Videos](https://arxiv.org/abs/2207.14299) by 6 | 7 | [Sateesh Kumar](https://sateeshkumar21.github.io/), [Jonathan Zamora](https://jonzamora.dev/)\*, [Nicklas Hansen](https://nicklashansen.github.io/)\*, [Rishabh Jangir](https://jangirrishabh.github.io/), [Xiaolong Wang](https://xiaolonw.github.io/) 8 | 9 | Conference on Robot Learning (CoRL), 2022 **(Oral, Top 6.5%)** 10 | 11 |

12 | 13 | [Paper][Website][Video][Dataset] 14 |
15 | 16 | 17 | ## Method 18 | 19 | **GraphIRL** is a self-supervised method for learning a visually invariant reward function directly from a set of diverse third-person video demonstrations via a graph abstraction. Our framework builds an object-centric graph abstraction from the video demonstrations, then learns an embedding space that captures task progression by exploiting temporal cues in the videos. This embedding space is used to construct a domain- and embodiment-invariant reward function that can be used to train any standard reinforcement learning algorithm. 20 | 21 |
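To make this concrete, the sketch below shows one way the learned embedding space can be turned into a dense reward: embed the current observation and take its negative distance to a precomputed goal embedding, scaled by the inverse of the average start-to-goal distance. It assumes the `goal_emb.pkl` and `distance_scale.pkl` files written by `compute_goal_embedding.py`; this is a simplified illustration, not the repository's exact reward code.

```
import pickle

import numpy as np

# Goal embedding and distance scale produced by compute_goal_embedding.py
# ("/path/to/experiment" is a placeholder for your experiment directory).
with open("/path/to/experiment/goal_emb.pkl", "rb") as fp:
    goal_emb = pickle.load(fp)        # shape (1, emb_dim)
with open("/path/to/experiment/distance_scale.pkl", "rb") as fp:
    distance_scale = pickle.load(fp)  # 1 / mean distance from initial frames to goal


def embedding_reward(frame_emb):
    """Negative scaled distance between a frame embedding and the mean goal embedding."""
    dist = np.linalg.norm(frame_emb - goal_emb.squeeze(0))
    return -distance_scale * dist
```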

22 | 23 |

24 | 25 | 26 | ## Citation 27 | 28 | If you use our method or code in your research, please consider citing the paper as follows: 29 | 30 | ``` 31 | @article{kumar2022inverse, 32 | title={Graph Inverse Reinforcement Learning from Diverse Videos}, 33 | author={Kumar, Sateesh and Zamora, Jonathan and Hansen, Nicklas and Jangir, Rishabh and Wang, Xiaolong}, 34 | journal={arXiv preprint arXiv:2207.14299}, 35 | year={2022} 36 | } 37 | ``` 38 | 39 | ## Instructions 40 | 41 | We assume you have installed [MuJoCo](http://www.mujoco.org) on your machine. You can install dependencies using `conda`: 42 | 43 | ``` 44 | conda env create -f environment.yaml 45 | conda activate graphirl 46 | ``` 47 | 48 | You can use trained models to extract the reward and train a policy by running: 49 | 50 | ``` 51 | bash scripts/run_${Task_Name} /path/to/trained_model 52 | ``` 53 | 54 | ## Dataset 55 | 56 | The dataset for GraphIRL can be found in this [Google Drive Folder](https://drive.google.com/drive/folders/1mNJmnyzIoCudRcTdRVrN3WAiuWIM8355?usp=share_link). We have also released the trained reward models for Reach, Push and Peg in Box [here](https://drive.google.com/drive/folders/1O69YtAmq7hqU6kmutqGr-I0vBo7bu8f3?usp=sharing). 57 | We include a [script](scripts/download_data.sh) for downloading all data with the `gdown` package, though feel free to directly download the files from our drive folder if you run into issues with the script. 58 | 59 | ## License & Acknowledgements 60 | 61 | GraphIRL is licensed under the MIT license. [MuJoCo](https://github.com/deepmind/mujoco) is licensed under the Apache 2.0 license. We thank the [XIRL](https://x-irl.github.io/) authors for open-sourcing their codebase to the community, our work is built on top of their engineering efforts. 62 | -------------------------------------------------------------------------------- /compute_goal_embedding.py: -------------------------------------------------------------------------------- 1 | """Compute and store the mean goal embedding using a trained model.""" 2 | 3 | import os 4 | import pickle 5 | import typing 6 | 7 | from absl import app 8 | from absl import flags 9 | import numpy as np 10 | import torch 11 | from torchkit import checkpoint 12 | from graphirl import common 13 | import tqdm 14 | 15 | FLAGS = flags.FLAGS 16 | 17 | flags.DEFINE_string("experiment_path", None, "Path to model checkpoint.") 18 | flags.DEFINE_boolean("restore_checkpoint", True, "Restore model checkpoint.") 19 | 20 | flags.mark_flag_as_required("experiment_path") 21 | 22 | ModelType = torch.nn.Module 23 | DataLoaderType = typing.Dict[str, torch.utils.data.DataLoader] 24 | 25 | 26 | def embed( 27 | model, 28 | downstream_loader, 29 | device, 30 | ): 31 | """Embed the stored trajectories and compute mean goal embedding.""" 32 | goal_embs = [] 33 | init_embs = [] 34 | for class_name, class_loader in downstream_loader.items(): 35 | print(f"\tEmbedding class: {class_name}...") 36 | for batch_idx, batch in enumerate(class_loader): 37 | 38 | #if batch_idx % 100 == 0: 39 | 40 | out = model.infer(batch["frames"].to(device)) 41 | emb = out.numpy().embs 42 | goal_embs.append(emb[-1, :]) 43 | init_embs.append(emb[0, :]) 44 | 45 | goal_emb = np.mean(np.stack(goal_embs, axis=0), axis=0, keepdims=True) 46 | dist_to_goal = np.linalg.norm( 47 | np.stack(init_embs, axis=0) - goal_emb, axis=-1).mean() 48 | distance_scale = 1.0 / dist_to_goal 49 | return goal_emb, distance_scale 50 | 51 | 52 | def setup(device): 53 | """Load the latest embedder checkpoint and dataloaders.""" 54 
| config = common.load_config_from_dir(FLAGS.experiment_path) 55 | model = common.get_model(config) 56 | downstream_loaders = common.get_downstream_dataloaders(config, False)["train"] 57 | checkpoint_dir = os.path.join(FLAGS.experiment_path, "checkpoints") 58 | if FLAGS.restore_checkpoint: 59 | checkpoint_manager = checkpoint.CheckpointManager( 60 | checkpoint.Checkpoint(model=model), 61 | checkpoint_dir, 62 | device, 63 | ) 64 | global_step = checkpoint_manager.restore_or_initialize() 65 | print(f"Restored model from checkpoint {global_step}.") 66 | else: 67 | print("Skipping checkpoint restore.") 68 | return model, downstream_loaders 69 | 70 | 71 | def main(_): 72 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 73 | model, downstream_loader = setup(device) 74 | model.to(device).eval() 75 | goal_emb, distance_scale = embed(model, downstream_loader, device) 76 | with open(os.path.join(FLAGS.experiment_path, "goal_emb.pkl"), "wb") as fp: 77 | pickle.dump(goal_emb, fp) 78 | with open(os.path.join(FLAGS.experiment_path, "distance_scale.pkl"), "wb") as fp: 79 | pickle.dump(distance_scale, fp) 80 | 81 | 82 | if __name__ == "__main__": 83 | app.run(main) 84 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: graphirl 2 | channels: 3 | - defaults 4 | - pytorch 5 | dependencies: 6 | - python=3.8 7 | - pytorch=1.7.1 8 | - torchvision=0.8.2 9 | - cudatoolkit=11.0 10 | - pip 11 | - pip: 12 | - albumentations==0.5.2 13 | - gym==0.17.* 14 | - pymunk==5.6.0 15 | - ml-collections==0.1.0 16 | - scikit-learn==0.24.1 17 | - tensorflow-cpu==2.6.0 18 | - imageio 19 | - imageio-ffmpeg 20 | - tqdm 21 | - absl-py 22 | - numpy 23 | - scipy 24 | - pandas 25 | - gdown==4.4.0 26 | - protobuf~=3.19.0 27 | - wandb 28 | - natsort 29 | - pillow==6.2.0 30 | - termcolor 31 | - opencv-python 32 | - xmltodict 33 | - scikit-image 34 | - seaborn 35 | - mujoco-py 36 | - kornia 37 | - einops 38 | -------------------------------------------------------------------------------- /graphirl/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | -------------------------------------------------------------------------------- /graphirl/common.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Functionality common to pretraining and evaluation.""" 17 | 18 | import os 19 | from typing import Dict 20 | 21 | from ml_collections import ConfigDict 22 | import torch 23 | from graphirl import factory 24 | import yaml 25 | 26 | DataLoadersDict = Dict[str, torch.utils.data.DataLoader] 27 | 28 | 29 | def get_pretraining_dataloaders( 30 | config, 31 | debug = False, 32 | ): 33 | """Construct a train/valid pair of pretraining dataloaders. 34 | 35 | Args: 36 | config: ConfigDict object with config parameters. 37 | debug: When set to True, the following happens: 1. Data augmentation is 38 | disabled regardless of config values. 2. Sequential sampling of videos is 39 | turned on. 3. The number of dataloader workers is set to 0. 40 | 41 | Returns: 42 | A dict of train/valid pretraining dataloaders. 43 | """ 44 | 45 | def _loader(split): 46 | dataset = factory.dataset_from_config(config, False, split, debug) 47 | batch_sampler = factory.video_sampler_from_config( 48 | config, dataset.dir_tree, downstream=False, sequential=debug) 49 | return torch.utils.data.DataLoader( 50 | dataset, 51 | collate_fn=dataset.collate_fn, 52 | batch_sampler=batch_sampler, 53 | num_workers=4 if torch.cuda.is_available() and not debug else 0, 54 | pin_memory=torch.cuda.is_available() and not debug, 55 | ) 56 | 57 | return { 58 | "train": _loader("train"), 59 | "valid": _loader("valid"), 60 | } 61 | 62 | 63 | def get_downstream_dataloaders( 64 | config, 65 | debug, 66 | ): 67 | """Construct a train/valid pair of downstream dataloaders. 68 | 69 | Args: 70 | config: ConfigDict object with config parameters. 71 | debug: When set to True, the following happens: 1. Data augmentation is 72 | disabled regardless of config values. 2. Sequential sampling of videos is 73 | turned on. 3. The number of dataloader workers is set to 0. 
74 | 75 | Returns: 76 | A dict of train/valid downstream dataloaders 77 | """ 78 | 79 | def _loader(split): 80 | datasets = factory.dataset_from_config(config, True, split, debug) 81 | loaders = {} 82 | for action_class, dataset in datasets.items(): 83 | batch_sampler = factory.video_sampler_from_config( 84 | config, dataset.dir_tree, downstream=True, sequential=debug) 85 | loaders[action_class] = torch.utils.data.DataLoader( 86 | dataset, 87 | collate_fn=dataset.collate_fn, 88 | batch_sampler=batch_sampler, 89 | num_workers=4 if torch.cuda.is_available() and not debug else 0, 90 | pin_memory=torch.cuda.is_available() and not debug, 91 | ) 92 | return loaders 93 | 94 | return { 95 | "train": _loader("train"), 96 | "valid": _loader("valid"), 97 | } 98 | 99 | 100 | def get_factories( 101 | config, 102 | device, 103 | debug = False, 104 | ): 105 | """Feed config to factories and return objects.""" 106 | pretrain_loaders = get_pretraining_dataloaders(config, debug) 107 | downstream_loaders = get_downstream_dataloaders(config, debug) 108 | model = factory.model_from_config(config) 109 | optimizer = factory.optim_from_config(config, model) 110 | trainer = factory.trainer_from_config(config, model, optimizer, device) 111 | eval_manager = factory.evaluator_from_config(config) 112 | return ( 113 | model, 114 | optimizer, 115 | pretrain_loaders, 116 | downstream_loaders, 117 | trainer, 118 | eval_manager, 119 | ) 120 | 121 | 122 | def get_model(config): 123 | """Construct a model from a config.""" 124 | return factory.model_from_config(config) 125 | 126 | 127 | def load_config_from_dir(exp_dir): 128 | """Load experiment config.""" 129 | try: 130 | with open(os.path.join(exp_dir, "config.yaml"), "r") as fp: 131 | cfg = yaml.load(fp, Loader=yaml.FullLoader) 132 | return ConfigDict(cfg) 133 | except FileNotFoundError as e: 134 | raise e 135 | -------------------------------------------------------------------------------- /graphirl/evaluators/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
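# ---------------------------------------------------------------------------
# Editor's sketch: how the helpers in graphirl/common.py above are typically
# combined (compare compute_goal_embedding.py). The experiment path is a
# placeholder and the snippet is illustrative, not part of this module.
#
#   import torch
#   from graphirl import common
#
#   config = common.load_config_from_dir("/path/to/experiment")  # reads config.yaml
#   model = common.get_model(config)
#   train_loaders = common.get_downstream_dataloaders(config, False)["train"]
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   model.to(device).eval()
# ---------------------------------------------------------------------------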
15 | 16 | """Evaluators.""" 17 | 18 | from .cycle_consistency import ThreeWayCycleConsistency 19 | from .cycle_consistency import TwoWayCycleConsistency 20 | from .emb_visualizer import EmbeddingVisualizer 21 | from .kendalls_tau import KendallsTau 22 | from .manager import EvalManager 23 | from .nn_visualizer import NearestNeighbourVisualizer 24 | from .reconstruction_visualizer import ReconstructionVisualizer 25 | from .reward_visualizer import RewardVisualizer 26 | 27 | __all__ = [ 28 | "EvalManager", 29 | "KendallsTau", 30 | "TwoWayCycleConsistency", 31 | "ThreeWayCycleConsistency", 32 | "NearestNeighbourVisualizer", 33 | "RewardVisualizer", 34 | "EmbeddingVisualizer", 35 | "ReconstructionVisualizer", 36 | ] 37 | -------------------------------------------------------------------------------- /graphirl/evaluators/base.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Base evaluator.""" 17 | 18 | import abc 19 | from typing import List, Optional, Union 20 | 21 | import dataclasses 22 | import numpy as np 23 | from torchkit import Logger 24 | from xirl.models import SelfSupervisedOutput 25 | 26 | 27 | @dataclasses.dataclass 28 | class EvaluatorOutput: 29 | """The output of an evaluator.""" 30 | 31 | # An evaluator does not necessarily generate all fields below. For example, 32 | # some evaluators like Kendalls Tau return a scalar and image metric, while 33 | # TwoWayCycleConsistency only generates a scalar metric. 34 | scalar: Optional[Union[float, List[float]]] = None 35 | image: Optional[Union[np.ndarray, List[np.ndarray]]] = None 36 | video: Optional[Union[np.ndarray, List[np.ndarray]]] = None 37 | 38 | @staticmethod 39 | def _assert_same_attrs(list_out): 40 | """Ensures a list of this class instance have the same attributes.""" 41 | 42 | def _not_none(o): 43 | return [getattr(o, a) is not None for a in ["scalar", "image", "video"]] 44 | 45 | expected = _not_none(list_out[0]) 46 | for o in list_out[1:]: 47 | actual = _not_none(o) 48 | assert np.array_equal(expected, actual) 49 | 50 | @staticmethod 51 | def merge(list_out): 52 | """Merge a list of this class instance into one.""" 53 | # We need to make sure that all elements of the list have the same 54 | # non-empty (i.e. != None) attributes. 55 | EvaluatorOutput._assert_same_attrs(list_out) 56 | # At this point, we're confident that we only need to check the 57 | # attributes of the first member of the list to guarantee the same 58 | # availability for *all* other members of the list. 
59 | scalars = None 60 | if list_out[0].scalar is not None: 61 | scalars = [o.scalar for o in list_out] 62 | images = None 63 | if list_out[0].image is not None: 64 | images = [o.image for o in list_out] 65 | videos = None 66 | if list_out[0].video is not None: 67 | videos = [o.video for o in list_out] 68 | return EvaluatorOutput(scalars, images, videos) 69 | 70 | def log(self, logger, global_step, name, 71 | prefix): 72 | """Log the attributes to tensorboard.""" 73 | if self.scalar is not None: 74 | if isinstance(self.scalar, list): 75 | self.scalar = np.mean(self.scalar) 76 | logger.log_scalar(self.scalar, global_step, name, prefix) 77 | if self.image is not None: 78 | if isinstance(self.image, list): 79 | for i, image in enumerate(self.image): 80 | logger.log_image(image, global_step, name + f"_{i}", prefix) 81 | else: 82 | logger.log_image(self.image, global_step, name, prefix) 83 | if self.video is not None: 84 | if isinstance(self.video, list): 85 | for i, video in enumerate(self.video): 86 | logger.log_video(video, global_step, name + f"_{i}", prefix) 87 | else: 88 | logger.log_video(self.video, global_step, name, prefix) 89 | 90 | 91 | class Evaluator(abc.ABC): 92 | """Base class for evaluating a self-supervised model on downstream tasks. 93 | 94 | Subclasses must implement the `_evaluate` method. 95 | """ 96 | 97 | def __init__(self, inter_class): 98 | self.inter_class = inter_class 99 | 100 | @abc.abstractmethod 101 | def evaluate(self, outs): 102 | """Evaluate the downstream task in embedding space. 103 | 104 | Args: 105 | outs: A list of outputs generated by the model on the downstream dataset. 106 | :meta public: 107 | """ 108 | pass 109 | -------------------------------------------------------------------------------- /graphirl/evaluators/cycle_consistency.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Cycle-consistency evaluator.""" 17 | 18 | import itertools 19 | from typing import List 20 | 21 | from .base import Evaluator 22 | from .base import EvaluatorOutput 23 | import numpy as np 24 | from scipy.spatial.distance import cdist 25 | from xirl.models import SelfSupervisedOutput 26 | 27 | 28 | class _CycleConsistency(Evaluator): 29 | """Base class for cycle consistency evaluation.""" 30 | 31 | def __init__(self, n_way, stride, distance): 32 | """Constructor. 33 | 34 | Args: 35 | n_way: The number of cycle-consistency ways. 36 | stride: Controls how many frames are skipped in each video sequence. For 37 | example, if the embedding vector of the first video is (100, 128), a 38 | stride of 5 reduces it to (20, 128). 39 | distance: The distance metric to use when calculating nearest-neighbours. 40 | 41 | Raises: 42 | ValueError: If the distance metric is invalid or the 43 | mode is invalid. 
44 | """ 45 | super().__init__(inter_class=False) 46 | 47 | assert n_way in [2, 3], "n_way must be 2 or 3." 48 | assert isinstance(stride, int), "stride must be an integer." 49 | if distance not in ["sqeuclidean", "cosine"]: 50 | raise ValueError( 51 | "{} is not a supported distance metric.".format(distance)) 52 | 53 | self.n_way = n_way 54 | self.stride = stride 55 | self.distance = distance 56 | 57 | def _evaluate_two_way(self, embs): 58 | """Two-way cycle consistency.""" 59 | num_embs = len(embs) 60 | total_combinations = num_embs * (num_embs - 1) 61 | ccs = np.zeros((total_combinations)) 62 | idx = 0 63 | for i in range(num_embs): 64 | query_emb = embs[i][::self.stride] 65 | ground_truth = np.arange(len(embs[i]))[::self.stride] 66 | for j in range(num_embs): 67 | if i == j: 68 | continue 69 | candidate_emb = embs[j][::self.stride] 70 | dists = cdist(query_emb, candidate_emb, self.distance) 71 | nns = np.argmin(dists[:, np.argmin(dists, axis=1)], axis=0) 72 | ccs[idx] = np.mean(np.abs(nns - ground_truth) <= 1) 73 | idx += 1 74 | ccs = ccs[~np.isnan(ccs)] 75 | return EvaluatorOutput(scalar=np.mean(ccs)) 76 | 77 | def _evaluate_three_way(self, embs): 78 | """Three-way cycle consistency.""" 79 | num_embs = len(embs) 80 | cycles = np.stack(list(itertools.permutations(np.arange(num_embs), 3))) 81 | total_combinations = len(cycles) 82 | ccs = np.zeros((total_combinations)) 83 | for c_idx, cycle in enumerate(cycles): 84 | # Forward consistency check. Each cycle will be a length 3 85 | # permutation, e.g. U - V - W. We compute nearest neighbours across 86 | # consecutive pairs in the cycle and loop back to the first cycle 87 | # index to obtain: U - V - W - U. 88 | query_emb = None 89 | for i in range(len(cycle)): 90 | if query_emb is None: 91 | query_emb = embs[cycle[i]][::self.stride] 92 | candidate_emb = embs[cycle[(i + 1) % len(cycle)]][::self.stride] 93 | dists = cdist(query_emb, candidate_emb, self.distance) 94 | nns_forward = np.argmin(dists, axis=1) 95 | query_emb = candidate_emb[nns_forward] 96 | ground_truth_forward = np.arange(len(embs[cycle[0]]))[::self.stride] 97 | cc_forward = np.abs(nns_forward - ground_truth_forward) <= 1 98 | # Backward consistency check. A backward check is equivalent to 99 | # reversing the middle pair V - W and performing a forward check, 100 | # e.g. U - W - V - U. 101 | cycle[1:] = cycle[1:][::-1] 102 | query_emb = None 103 | for i in range(len(cycle)): 104 | if query_emb is None: 105 | query_emb = embs[cycle[i]][::self.stride] 106 | candidate_emb = embs[cycle[(i + 1) % len(cycle)]][::self.stride] 107 | dists = cdist(query_emb, candidate_emb, self.distance) 108 | nns_backward = np.argmin(dists, axis=1) 109 | query_emb = candidate_emb[nns_backward] 110 | ground_truth_backward = np.arange(len(embs[cycle[0]]))[::self.stride] 111 | cc_backward = np.abs(nns_backward - ground_truth_backward) <= 1 112 | # Require consistency both ways. 113 | cc = np.logical_and(cc_forward, cc_backward) 114 | ccs[c_idx] = np.mean(cc) 115 | ccs = ccs[~np.isnan(ccs)] 116 | return EvaluatorOutput(scalar=np.mean(ccs)) 117 | 118 | def evaluate(self, outs): 119 | embs = [o.embs for o in outs] 120 | if self.n_way == 2: 121 | return self._evaluate_two_way(embs) 122 | return self._evaluate_three_way(embs) 123 | 124 | 125 | class TwoWayCycleConsistency(_CycleConsistency): 126 | """2-way cycle consistency evaluator [1]. 
127 | 128 | References: 129 | [1]: https://arxiv.org/abs/1805.11592 130 | """ 131 | 132 | def __init__(self, stride, distance): 133 | super().__init__(2, stride, distance) 134 | 135 | 136 | class ThreeWayCycleConsistency(_CycleConsistency): 137 | """2-way cycle consistency evaluator [1]. 138 | 139 | References: 140 | [1]: https://arxiv.org/abs/1805.11592 141 | """ 142 | 143 | def __init__(self, stride, distance): 144 | super().__init__(3, stride, distance) 145 | -------------------------------------------------------------------------------- /graphirl/evaluators/emb_visualizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """2D embedding visualizer.""" 17 | 18 | from typing import List 19 | 20 | from .base import Evaluator 21 | from .base import EvaluatorOutput 22 | import matplotlib.pyplot as plt 23 | import numpy as np 24 | from sklearn.decomposition import PCA 25 | from xirl.models import SelfSupervisedOutput 26 | 27 | 28 | class EmbeddingVisualizer(Evaluator): 29 | """Visualize PCA of the embeddings.""" 30 | 31 | def __init__(self, num_seqs): 32 | """Constructor. 33 | 34 | Args: 35 | num_seqs: How many embedding sequences to visualize. 36 | 37 | Raises: 38 | ValueError: If the distance metric is invalid. 39 | """ 40 | super().__init__(inter_class=True) 41 | 42 | self.num_seqs = num_seqs 43 | 44 | def _gen_emb_plot(self, embs): 45 | """Create a pyplot plot and save to buffer.""" 46 | fig = plt.figure() 47 | for emb in embs: 48 | plt.scatter(emb[:, 0], emb[:, 1]) 49 | fig.canvas.draw() 50 | img_arr = np.array(fig.canvas.renderer.buffer_rgba())[:, :, :3] 51 | plt.close() 52 | return img_arr 53 | 54 | def evaluate(self, outs): 55 | embs = [o.embs for o in outs] 56 | 57 | # Randomly sample the embedding sequences we'd like to plot. 58 | seq_idxs = np.random.choice( 59 | np.arange(len(embs)), size=self.num_seqs, replace=False) 60 | seq_embs = [embs[idx] for idx in seq_idxs] 61 | 62 | # Subsample embedding sequences to make them the same length. 63 | seq_lens = [s.shape[0] for s in seq_embs] 64 | min_len = np.min(seq_lens) 65 | same_length_embs = [] 66 | for emb in seq_embs: 67 | emb_len = len(emb) 68 | stride = emb_len / min_len 69 | idxs = np.arange(0.0, emb_len, stride).round().astype(int) 70 | idxs = np.clip(idxs, a_min=0, a_max=emb_len - 1) 71 | idxs = idxs[:min_len] 72 | same_length_embs.append(emb[idxs]) 73 | 74 | # Flatten embeddings to perform PCA. 
75 | same_length_embs = np.stack(same_length_embs) 76 | num_seqs, seq_len, emb_dim = same_length_embs.shape 77 | embs_flat = same_length_embs.reshape(-1, emb_dim) 78 | embs_2d = PCA(n_components=2, random_state=0).fit_transform(embs_flat) 79 | embs_2d = embs_2d.reshape(num_seqs, seq_len, 2) 80 | 81 | image = self._gen_emb_plot(embs_2d) 82 | return EvaluatorOutput(image=image) 83 | -------------------------------------------------------------------------------- /graphirl/evaluators/kendalls_tau.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Kendall rank correlation coefficient evaluator.""" 17 | 18 | from typing import List 19 | 20 | from .base import Evaluator 21 | from .base import EvaluatorOutput 22 | import numpy as np 23 | from scipy.spatial.distance import cdist 24 | from scipy.stats import kendalltau 25 | from xirl.models import SelfSupervisedOutput 26 | 27 | 28 | def softmax(dists, temp = 1.0): 29 | dists_ = np.array(dists - np.max(dists)) 30 | exp = np.exp(dists_ / temp) 31 | return exp / np.sum(exp) 32 | 33 | 34 | class KendallsTau(Evaluator): 35 | """Kendall rank correlation coefficient [1]. 36 | 37 | References: 38 | [1]: https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient 39 | """ 40 | 41 | def __init__(self, stride, distance): 42 | """Constructor. 43 | 44 | Args: 45 | stride: Controls how many frames are skipped in each video sequence. For 46 | example, if the embedding vector of the first video is (100, 128), a 47 | stride of 5 reduces it to (20, 128). 48 | distance: The distance metric to use when calculating nearest-neighbours. 49 | 50 | Raises: 51 | ValueError: If the distance metric is invalid. 52 | """ 53 | super().__init__(inter_class=False) 54 | 55 | assert isinstance(stride, int), "stride must be an integer." 
56 | if distance not in ["sqeuclidean", "cosine"]: 57 | raise ValueError( 58 | "{} is not a supported distance metric.".format(distance)) 59 | 60 | self.stride = stride 61 | self.distance = distance 62 | 63 | def evaluate(self, outs): 64 | """Get pairwise nearest-neighbours then compute KT.""" 65 | embs = [o.embs for o in outs] 66 | num_embs = len(embs) 67 | total_combinations = num_embs * (num_embs - 1) 68 | taus = np.zeros((total_combinations)) 69 | idx = 0 70 | img = None 71 | for i in range(num_embs): 72 | query_emb = embs[i][::self.stride] 73 | for j in range(num_embs): 74 | if i == j: 75 | continue 76 | candidate_emb = embs[j][::self.stride] 77 | dists = cdist(query_emb, candidate_emb, self.distance) 78 | if i == 0 and j == 1: 79 | sim_matrix = [] 80 | for k in range(len(query_emb)): 81 | sim_matrix.append(softmax(-dists[k])) 82 | img = np.array(sim_matrix, dtype=np.float32)[Ellipsis, None] 83 | nns = np.argmin(dists, axis=1) 84 | taus[idx] = kendalltau(np.arange(len(nns)), nns).correlation 85 | idx += 1 86 | taus = taus[~np.isnan(taus)] 87 | if taus.size == 0: 88 | tau = 0.0 89 | else: 90 | tau = np.mean(taus) 91 | return EvaluatorOutput(scalar=tau, image=img) 92 | -------------------------------------------------------------------------------- /graphirl/evaluators/manager.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """An evaluator manager.""" 17 | 18 | from typing import Dict, List, Mapping, Optional 19 | 20 | from .base import Evaluator 21 | from .base import EvaluatorOutput 22 | import torch 23 | import torch.nn as nn 24 | from xirl.models import SelfSupervisedOutput 25 | 26 | 27 | class EvalManager: 28 | """Manage a bunch of downstream task evaluators and aggregate their results. 29 | 30 | Specifically, the manager embeds the downstream dataset *once*, and shares 31 | the embeddings across all evaluators for more efficient evaluation. 32 | """ 33 | 34 | def __init__(self, evaluators): 35 | """Constructor. 36 | 37 | Args: 38 | evaluators: A mapping from evaluator name to Evaluator instance. 
39 | """ 40 | self._evaluators = evaluators 41 | 42 | @staticmethod 43 | @torch.no_grad() 44 | def embed( 45 | model, 46 | downstream_loader, 47 | device, 48 | eval_iters, 49 | ): 50 | """Run the model on the downstream data and generate embeddings.""" 51 | loader_to_output = {} 52 | for action_name, valid_loader in downstream_loader.items(): 53 | outs = [] 54 | for batch_idx, batch in enumerate(valid_loader): 55 | if eval_iters is not None and batch_idx >= eval_iters: 56 | break 57 | outs.append(model.infer(batch["frames"].to(device)).numpy()) 58 | loader_to_output[action_name] = outs 59 | return loader_to_output 60 | 61 | @torch.no_grad() 62 | def evaluate( 63 | self, 64 | model, 65 | downstream_loader, 66 | device, 67 | eval_iters=None, 68 | ): 69 | """Evaluate the model on the validation data. 70 | 71 | Args: 72 | model: The self-supervised model that will embed the frames in the 73 | downstream loader. 74 | downstream_loader: A downstream dataloader. Has a batch size of 1 and 75 | loads all frames of the video. 76 | device: The compute device. 77 | eval_iters: The number of time to call `next()` on the downstream 78 | iterator. Set to None to evaluate on the entire iterator. 79 | 80 | Returns: 81 | A dict mapping from evaluator name to EvaluatorOutput. 82 | """ 83 | model.eval() 84 | print("Embedding downstream dataset...") 85 | downstream_outputs = EvalManager.embed(model, downstream_loader, device, 86 | eval_iters) 87 | eval_to_metric = {} 88 | for evaluator_name, evaluator in self._evaluators.items(): 89 | print(f"\tRunning {evaluator_name} evaluator...") 90 | if evaluator.inter_class: 91 | # Merge all downstream classes into a single list and do one 92 | # eval computation. 93 | outs = [ 94 | o for out in downstream_outputs.values() for o in out # pylint: disable=g-complex-comprehension 95 | ] 96 | metric = evaluator.evaluate(outs) 97 | else: 98 | # Loop and evaluate over downstream classes separately, then 99 | # merge into a single EvaluatorOutput whose fields are lists. 100 | metrics = [] 101 | for outs in downstream_outputs.values(): 102 | metrics.append(evaluator.evaluate(outs)) 103 | metric = EvaluatorOutput.merge(metrics) 104 | eval_to_metric[evaluator_name] = metric 105 | return eval_to_metric 106 | -------------------------------------------------------------------------------- /graphirl/evaluators/nn_visualizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
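# ---------------------------------------------------------------------------
# Editor's sketch: wiring the exported evaluators into an EvalManager (defined
# in graphirl/evaluators/manager.py). The evaluator names, stride value, and
# the `model` / `downstream_loader` / `device` / `logger` variables are
# illustrative assumptions; in practice these objects are built via
# graphirl.factory from a config.
#
#   from graphirl.evaluators import EvalManager, KendallsTau, TwoWayCycleConsistency
#
#   manager = EvalManager({
#       "kendalls_tau": KendallsTau(stride=3, distance="sqeuclidean"),
#       "two_way_cc": TwoWayCycleConsistency(stride=3, distance="sqeuclidean"),
#   })
#   metrics = manager.evaluate(model, downstream_loader, device, eval_iters=10)
#   for name, out in metrics.items():
#       out.log(logger, global_step, name, prefix="valid")
# ---------------------------------------------------------------------------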
15 | 16 | """Nearest-neighbor evaluator.""" 17 | 18 | from typing import List 19 | 20 | from .base import Evaluator 21 | from .base import EvaluatorOutput 22 | import numpy as np 23 | from scipy.spatial.distance import cdist 24 | from xirl.models import SelfSupervisedOutput 25 | 26 | 27 | class NearestNeighbourVisualizer(Evaluator): 28 | """Nearest-neighbour frame visualizer.""" 29 | 30 | def __init__( 31 | self, 32 | distance, 33 | num_videos, 34 | num_ctx_frames, 35 | ): 36 | """Constructor. 37 | 38 | Args: 39 | distance: The distance metric to use when calculating nearest-neighbours. 40 | num_videos: The number of video sequences to display. 41 | num_ctx_frames: The number of context frames stacked together for each 42 | individual video frame. 43 | 44 | Raises: 45 | ValueError: If the distance metric is invalid. 46 | """ 47 | super().__init__(inter_class=True) 48 | 49 | if distance not in ["sqeuclidean", "cosine"]: 50 | raise ValueError( 51 | "{} is not a supported distance metric.".format(distance)) 52 | 53 | self.distance = distance 54 | self.num_videos = num_videos 55 | self.num_ctx_frames = num_ctx_frames 56 | 57 | def evaluate(self, outs): 58 | """Sample source and target sequences and plot nn frames.""" 59 | 60 | def _reshape(frame): 61 | s, h, w, c = frame.shape 62 | seq_len = s // self.num_ctx_frames 63 | return frame.reshape(seq_len, self.num_ctx_frames, h, w, c) 64 | 65 | embs = [o.embs for o in outs] 66 | frames = [o.frames for o in outs] 67 | 68 | # Randomly sample the video sequences we'd like to plot. 69 | seq_idxs = np.random.choice( 70 | np.arange(len(embs)), size=self.num_videos, replace=False) 71 | 72 | # Perform nearest-neighbor lookup in embedding space and retrieve the 73 | # frames associated with those embeddings. 74 | cand_frames = [_reshape(frames[seq_idxs[0]])[:, -1]] 75 | for cand_idx in seq_idxs[1:]: 76 | dists = cdist(embs[seq_idxs[0]], embs[cand_idx], self.distance) 77 | nn_ids = np.argmin(dists, axis=1) 78 | c_frames = _reshape(frames[cand_idx]) 79 | nn_frames = [c_frames[idx, -1] for idx in nn_ids] 80 | cand_frames.append(np.stack(nn_frames)) 81 | 82 | video = np.stack(cand_frames) 83 | return EvaluatorOutput(video=video) 84 | -------------------------------------------------------------------------------- /graphirl/evaluators/reconstruction_visualizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Frame reconstruction visualizer.""" 17 | 18 | from typing import List 19 | 20 | from .base import Evaluator 21 | from .base import EvaluatorOutput 22 | import numpy as np 23 | import torch 24 | import torch.nn.functional as F 25 | from torchvision.utils import make_grid 26 | from xirl.models import SelfSupervisedReconOutput 27 | 28 | 29 | class ReconstructionVisualizer(Evaluator): 30 | """Frame reconstruction visualizer.""" 31 | 32 | def __init__(self, num_frames, num_ctx_frames): 33 | """Constructor. 34 | 35 | Args: 36 | num_frames: The number of reconstructed frames in a sequence to display. 37 | num_ctx_frames: The number of context frames stacked together for each 38 | individual video frame. 39 | """ 40 | super().__init__(inter_class=False) 41 | 42 | self.num_frames = num_frames 43 | self.num_ctx_frames = num_ctx_frames 44 | 45 | def evaluate(self, outs): 46 | """Plot a frame along with its reconstruction.""" 47 | 48 | def _remove_ctx_frames(frame): 49 | s, h, w, c = frame.shape 50 | seq_len = s // self.num_ctx_frames 51 | frame = frame.reshape(seq_len, self.num_ctx_frames, h, w, c) 52 | return frame[:, -1] 53 | 54 | frames = [o.frames for o in outs] 55 | recons = [o.reconstruction for o in outs] 56 | 57 | r_idx = np.random.randint(0, len(frames)) 58 | frame = _remove_ctx_frames(frames[r_idx]) 59 | recon = _remove_ctx_frames(recons[r_idx]) 60 | 61 | # Select which frames we want to plot from the sequence. 62 | frame_idxs = np.random.choice( 63 | np.arange(frame.shape[0]), size=self.num_frames, replace=False) 64 | frame = frame[frame_idxs] 65 | recon = recon[frame_idxs] 66 | 67 | # Downsample the frame. 68 | _, _, sh, _ = recon.shape 69 | _, _, h, _ = frame.shape 70 | scale_factor = sh / h 71 | frame_ds = F.interpolate( 72 | torch.from_numpy(frame).permute(0, 3, 1, 2), 73 | mode='bilinear', 74 | scale_factor=scale_factor, 75 | recompute_scale_factor=False, 76 | align_corners=True).permute(0, 2, 3, 1).numpy() 77 | 78 | # Clip reconstruction between 0 and 1. 79 | recon = np.clip(recon, 0.0, 1.0) 80 | 81 | imgs = np.concatenate([frame_ds, recon], axis=0) 82 | img = make_grid(torch.from_numpy(imgs).permute(0, 3, 1, 2), nrow=2) 83 | 84 | return EvaluatorOutput(image=img.permute(1, 2, 0).numpy()) 85 | -------------------------------------------------------------------------------- /graphirl/evaluators/reward_visualizer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Reward visualizer.""" 17 | 18 | from typing import List 19 | 20 | from .base import Evaluator 21 | from .base import EvaluatorOutput 22 | import matplotlib.pyplot as plt 23 | import numpy as np 24 | from scipy.spatial.distance import cdist 25 | from xirl.models import SelfSupervisedOutput 26 | 27 | 28 | class RewardVisualizer(Evaluator): 29 | """Distance to goal state visualizer.""" 30 | 31 | def __init__(self, distance, num_plots): 32 | """Constructor. 33 | 34 | Args: 35 | distance: The distance metric to use when calculating nearest-neighbours. 36 | num_plots: The number of reward plots to display. 37 | 38 | Raises: 39 | ValueError: If the distance metric is invalid. 40 | """ 41 | super().__init__(inter_class=False) 42 | 43 | if distance not in ["sqeuclidean", "cosine"]: 44 | raise ValueError( 45 | "{} is not a supported distance metric.".format(distance)) 46 | 47 | self.distance = distance 48 | self.num_plots = num_plots 49 | 50 | def _gen_reward_plot(self, rewards): 51 | """Create a pyplot plot and save to buffer.""" 52 | fig, axes = plt.subplots(1, len(rewards), figsize=(6.4 * len(rewards), 4.8)) 53 | if len(rewards) == 1: 54 | axes = [axes] 55 | for i, rew in enumerate(rewards): 56 | axes[i].plot(rew) 57 | fig.text(0.5, 0.04, "Timestep", ha="center") 58 | fig.text(0.04, 0.5, "Reward", va="center", rotation="vertical") 59 | fig.canvas.draw() 60 | img_arr = np.array(fig.canvas.renderer.buffer_rgba())[:, :, :3] 61 | plt.close() 62 | return img_arr 63 | 64 | def _compute_goal_emb(self, embs): 65 | """Compute the mean of all last frame embeddings.""" 66 | goal_emb = [emb[-1, :] for emb in embs] 67 | goal_emb = np.stack(goal_emb, axis=0) 68 | goal_emb = np.mean(goal_emb, axis=0, keepdims=True) 69 | return goal_emb 70 | 71 | def evaluate(self, outs): 72 | embs = [o.embs for o in outs] 73 | goal_emb = self._compute_goal_emb(embs) 74 | 75 | # Make sure we sample only as many as are available. 76 | num_plots = min(len(embs), self.num_plots) 77 | rand_idxs = np.random.choice( 78 | np.arange(len(embs)), size=num_plots, replace=False) 79 | 80 | # Compute rewards as distances to the goal embedding. 81 | rewards = [] 82 | for idx in rand_idxs: 83 | emb = embs[idx] 84 | dists = cdist(emb, goal_emb, self.distance) 85 | rewards.append(-dists) 86 | 87 | image = self._gen_reward_plot(rewards) 88 | return EvaluatorOutput(image=image) 89 | -------------------------------------------------------------------------------- /graphirl/tensorizers.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | """Tensorizers convert a packet of video data into a packet of video tensors.""" 17 | 18 | import abc 19 | from typing import Any, Dict, Union 20 | 21 | import numpy as np 22 | import torch 23 | import torchvision.transforms.functional as TF 24 | 25 | from graphirl.types import SequenceType 26 | 27 | DataArrayPacket = Dict[SequenceType, Union[np.ndarray, str, int]] 28 | DataTensorPacket = Dict[SequenceType, Union[torch.Tensor, str]] 29 | 30 | 31 | class Tensorizer(abc.ABC): 32 | """Base tensorizer class. 33 | 34 | Custom tensorizers must subclass this class. 35 | """ 36 | 37 | @abc.abstractmethod 38 | def __call__(self, x): 39 | pass 40 | 41 | 42 | class IdentityTensorizer(Tensorizer): 43 | """Outputs the input as is.""" 44 | 45 | def __call__(self, x): 46 | return x 47 | 48 | 49 | class LongTensorizer(Tensorizer): 50 | """Converts the input to a LongTensor.""" 51 | 52 | def __call__(self, x): 53 | return torch.from_numpy(np.asarray(x)).long() 54 | 55 | class FloatTensorizer(Tensorizer): 56 | 57 | def __call__(self, x): 58 | return torch.from_numpy(np.asarray(x)).float() 59 | 60 | class FramesTensorizer(Tensorizer): 61 | """Converts a sequence of video frames to a batched FloatTensor.""" 62 | 63 | def __call__(self, x): 64 | assert x.ndim == 4, "Input must be a 4D sequence of frames." 65 | frames = [] 66 | for frame in x: 67 | frames.append(TF.to_tensor(frame)) 68 | return torch.stack(frames, dim=0) 69 | 70 | 71 | class ToTensor: 72 | """Convert video data to video tensors.""" 73 | 74 | MAP = { 75 | SequenceType.FRAMES: FramesTensorizer, 76 | SequenceType.FRAME_IDXS: LongTensorizer, 77 | SequenceType.VIDEO_NAME: IdentityTensorizer, 78 | SequenceType.VIDEO_LEN: LongTensorizer, 79 | } 80 | 81 | def __call__(self, data): 82 | """Iterate and transform the data values. 83 | 84 | Args: 85 | data: A dictionary containing key, value pairs where the key is an enum 86 | member of `SequenceType` and the value is either an int, a string or an 87 | ndarray respecting the key type. 88 | 89 | Raises: 90 | ValueError: If the input is not a dictionary or one of its keys is 91 | not a supported sequence type. 92 | 93 | Returns: 94 | The dictionary with the values tensorized. 95 | """ 96 | tensors = {} 97 | for key, np_arr in data.items(): 98 | tensors[key] = ToTensor.MAP[key]()(np_arr) 99 | return tensors 100 | 101 | 102 | 103 | class ToTensor_bbox: 104 | """Convert video data to bbox tensors.""" 105 | 106 | MAP = { 107 | SequenceType.FRAMES: FloatTensorizer, 108 | SequenceType.FRAME_IDXS: LongTensorizer, 109 | SequenceType.VIDEO_NAME: IdentityTensorizer, 110 | SequenceType.VIDEO_LEN: LongTensorizer, 111 | } 112 | 113 | def __call__(self, data): 114 | """Iterate and transform the data values. 115 | 116 | Args: 117 | data: A dictionary containing key, value pairs where the key is an enum 118 | member of `SequenceType` and the value is either an int, a string or an 119 | ndarray respecting the key type. 120 | 121 | Raises: 122 | ValueError: If the input is not a dictionary or one of its keys is 123 | not a supported sequence type. 124 | 125 | Returns: 126 | The dictionary with the values tensorized. 
127 | """ 128 | tensors = {} 129 | for key, np_arr in data.items(): 130 | tensors[key] = ToTensor_bbox.MAP[key]()(np_arr) 131 | return tensors 132 | -------------------------------------------------------------------------------- /graphirl/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Trainers.""" 17 | 18 | from .base import Trainer 19 | from .classification import GoalFrameClassifierTrainer 20 | from .lifs import LIFSTrainer 21 | from .tcc import TCCTrainer 22 | from .tcn import TCNTrainer 23 | 24 | __all__ = [ 25 | "Trainer", 26 | "TCCTrainer", 27 | "TCNTrainer", 28 | "LIFSTrainer", 29 | "GoalFrameClassifierTrainer", 30 | ] 31 | -------------------------------------------------------------------------------- /graphirl/trainers/base.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Base class for defining training algorithms.""" 17 | 18 | import abc 19 | from typing import Dict, List, Optional, Union 20 | from ml_collections.config_dict.config_dict import ConfigDict 21 | 22 | import torch 23 | import torch.nn as nn 24 | 25 | from graphirl.models import SelfSupervisedOutput, SelfSupervisedOutputBbox 26 | 27 | BatchType = Dict[str, Union[torch.Tensor, List[str]]] 28 | 29 | 30 | class Trainer(abc.ABC): 31 | """Base trainer abstraction. 32 | 33 | Subclasses should override `compute_loss`. 34 | """ 35 | 36 | def __init__( 37 | self, 38 | model, 39 | optimizer, 40 | device, 41 | config, 42 | ): 43 | """Constructor. 44 | 45 | Args: 46 | model: The model to train. 47 | optimizer: The optimizer to use. 48 | device: The computing device. 49 | config: The config dict. 50 | """ 51 | self._model = model 52 | self._optimizer = optimizer 53 | self._device = device 54 | self._config = config 55 | 56 | self._model.to(self._device).train() 57 | 58 | @abc.abstractmethod 59 | def compute_loss( 60 | self, 61 | embs, 62 | batch, 63 | ): 64 | """Compute the loss on a single batch. 65 | 66 | Args: 67 | embs: The output of the embedding network. 68 | batch: The output of a VideoDataset dataloader. 
69 | 70 | Returns: 71 | A tensor corresponding to the value of the loss function evaluated 72 | on the given batch. 73 | """ 74 | pass 75 | 76 | def compute_auxiliary_loss( 77 | self, 78 | out, # pylint: disable=unused-argument 79 | batch, # pylint: disable=unused-argument 80 | ): 81 | """Compute an auxiliary loss on a single batch. 82 | 83 | Args: 84 | out: The output of the self-supervised model. 85 | batch: The output of a VideoDataset dataloader. 86 | 87 | Returns: 88 | A tensor corresponding to the value of the auxiliary loss function 89 | evaluated on the given batch. 90 | """ 91 | return 0.0 92 | 93 | def train_one_iter(self, batch): 94 | """Single forward + backward pass of the model. 95 | 96 | Args: 97 | batch: The output of a VideoDataset dataloader. 98 | 99 | Returns: 100 | A dict of loss values. 101 | """ 102 | self._model.train() 103 | 104 | self._optimizer.zero_grad() 105 | 106 | # Forward pass to compute embeddings. 107 | frames = batch["frames"].to(self._device) 108 | out = self._model(frames) 109 | 110 | # Compute losses. 111 | loss = self.compute_loss(out.embs, batch) 112 | aux_loss = self.compute_auxiliary_loss(out, batch) 113 | total_loss = loss + aux_loss 114 | 115 | # Backwards pass + optimization step. 116 | total_loss.backward() 117 | self._optimizer.step() 118 | 119 | return { 120 | "train/base_loss": loss, 121 | "train/auxiliary_loss": aux_loss, 122 | "train/total_loss": total_loss, 123 | } 124 | 125 | @torch.no_grad() 126 | def eval_num_iters( 127 | self, 128 | valid_loader, 129 | eval_iters = None, 130 | ): 131 | """Compute the loss with the model in `eval()` mode. 132 | 133 | Args: 134 | valid_loader: The validation data loader. 135 | eval_iters: The number of time to call `next()` on the data iterator. Set 136 | to None to evaluate on the whole validation set. 137 | 138 | Returns: 139 | A dict of validation losses. 140 | """ 141 | self._model.eval() 142 | 143 | val_base_loss = 0.0 144 | val_aux_loss = 0.0 145 | it_ = 0 146 | for batch_idx, batch in enumerate(valid_loader): 147 | if eval_iters is not None and batch_idx >= eval_iters: 148 | break 149 | 150 | frames = batch["frames"].to(self._device) 151 | out = self._model(frames) 152 | 153 | val_base_loss += self.compute_loss(out.embs, batch) 154 | val_aux_loss += self.compute_auxiliary_loss(out, batch) 155 | it_ += 1 156 | val_base_loss /= it_ 157 | val_aux_loss /= it_ 158 | 159 | return { 160 | "valid/base_loss": val_base_loss, 161 | "valid/auxiliary_loss": val_aux_loss, 162 | "valid/total_loss": val_base_loss + val_aux_loss, 163 | } 164 | -------------------------------------------------------------------------------- /graphirl/trainers/classification.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 
16 | """Goal classifier trainer."""
17 | 
18 | from typing import Dict, List, Union
19 | 
20 | import torch
21 | import torch.nn.functional as F
22 | from graphirl.trainers.base import Trainer
23 | 
24 | BatchType = Dict[str, Union[torch.Tensor, List[str]]]
25 | 
26 | 
27 | class GoalFrameClassifierTrainer(Trainer):
28 |   """A trainer that learns to classify whether an image is a goal frame.
29 | 
30 |   This should be used in conjunction with the LastFrameAndRandomFrames frame
31 |   sampler, which ensures each batch of frame sequences consists of one goal
32 |   frame followed by N - 1 random other frames.
33 |   """
34 | 
35 |   def compute_loss(
36 |       self,
37 |       embs,
38 |       batch,
39 |   ):
40 |     del batch
41 | 
42 |     batch_size, num_cc_frames, _ = embs.shape
43 | 
44 |     # Create the labels tensor.
45 |     row_tensor = torch.FloatTensor([1] + [0] * (num_cc_frames - 1))
46 |     label_tensor = row_tensor.unsqueeze(0).repeat(batch_size, 1)
47 |     label_tensor = label_tensor.to(self._device)
48 | 
49 |     return F.binary_cross_entropy_with_logits(
50 |         embs.view(batch_size * num_cc_frames),
51 |         label_tensor.view(batch_size * num_cc_frames),
52 |     )
53 | 
--------------------------------------------------------------------------------
/graphirl/trainers/lifs.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """LIFS trainer."""
17 | 
18 | from typing import Dict, List, Union
19 | 
20 | from ml_collections import ConfigDict
21 | import torch
22 | import torch.nn as nn
23 | import torch.nn.functional as F
24 | from graphirl.models import SelfSupervisedReconOutput
25 | from graphirl.trainers.base import Trainer
26 | 
27 | BatchType = Dict[str, Union[torch.Tensor, List[str]]]
28 | 
29 | 
30 | class LIFSTrainer(Trainer):
31 |   """A trainer that implements LIFS from [1].
32 | 
33 |   This should be used in conjunction with the VariableStridedSampler frame
34 |   sampler, which assumes rough alignment between pairs of sequences and hence
35 |   a time index can be used to correspond frames across sequences.
36 | 
37 |   Note that the authors of [1] do not implement a negative term in the
38 |   contrastive loss. It is just a similarity (l2) loss with an autoencoding
39 |   loss to prevent the embeddings from collapsing to trivial constants.
40 | 41 | References: 42 | [1]: https://arxiv.org/abs/1703.02949 43 | """ 44 | 45 | def __init__( 46 | self, 47 | model, 48 | optimizer, 49 | device, 50 | config, 51 | ): 52 | super().__init__(model, optimizer, device, config) 53 | 54 | self.temperature = config.LOSS.LIFS.TEMPERATURE 55 | 56 | def compute_auxiliary_loss( 57 | self, 58 | out, 59 | batch, 60 | ): 61 | reconstruction = out.reconstruction 62 | frames = batch["frames"].to(self._device) 63 | b, t, _, _, _ = reconstruction.shape 64 | reconstruction = reconstruction.view((b * t, *reconstruction.shape[2:])) 65 | frames = frames.view((b * t, *frames.shape[2:])) 66 | _, _, sh, _ = reconstruction.shape 67 | _, _, h, _ = frames.shape 68 | scale_factor = sh / h 69 | frames_ds = F.interpolate( 70 | frames, 71 | mode="bilinear", 72 | scale_factor=scale_factor, 73 | recompute_scale_factor=False, 74 | align_corners=True, 75 | ) 76 | return F.mse_loss(reconstruction, frames_ds) 77 | 78 | def compute_loss( 79 | self, 80 | embs, 81 | batch, 82 | ): 83 | del batch 84 | 85 | batch_size, num_cc_frames, num_dims = embs.shape 86 | 87 | # Compute pairwise squared L2 distances between embeddings. 88 | embs_flat = embs.view(-1, num_dims) 89 | distances = torch.cdist(embs_flat, embs_flat).pow(2) 90 | distances = distances / self.temperature 91 | 92 | # Each row in a batch corresponds to a frame sequence. Since this 93 | # baseline assumes rough alignment between sequences, we want columns, 94 | # i.e. frames in each row that belong to the same index to be close 95 | # together in embedding space. 96 | labels = torch.arange(num_cc_frames).unsqueeze(0).repeat(batch_size, 1) 97 | labels = labels.to(self._device) 98 | mask = labels.flatten()[:, None] == labels.flatten()[None, :] 99 | return (distances * mask.float()).sum(dim=-1).mean() 100 | -------------------------------------------------------------------------------- /graphirl/trainers/tcc.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """TCC trainer.""" 17 | 18 | from typing import Dict, List, Union 19 | 20 | from ml_collections import ConfigDict 21 | import torch 22 | from graphirl.losses import compute_tcc_loss 23 | from graphirl.trainers.base import Trainer 24 | 25 | BatchType = Dict[str, Union[torch.Tensor, List[str]]] 26 | 27 | 28 | class TCCTrainer(Trainer): 29 | """A trainer for Temporal Cycle Consistency Learning [1]. 
30 | 31 | References: 32 | [1]: arxiv.org/abs/1904.07846 33 | """ 34 | 35 | def __init__( 36 | self, 37 | model, 38 | optimizer, 39 | device, 40 | config, 41 | ): 42 | super().__init__(model, optimizer, device, config) 43 | 44 | self.normalize_embeddings = config.MODEL.NORMALIZE_EMBEDDINGS 45 | self.stochastic_matching = config.LOSS.TCC.STOCHASTIC_MATCHING 46 | self.loss_type = config.LOSS.TCC.LOSS_TYPE 47 | self.similarity_type = config.LOSS.TCC.SIMILARITY_TYPE 48 | self.cycle_length = config.LOSS.TCC.CYCLE_LENGTH 49 | self.temperature = config.LOSS.TCC.SOFTMAX_TEMPERATURE 50 | self.label_smoothing = config.LOSS.TCC.LABEL_SMOOTHING 51 | self.variance_lambda = config.LOSS.TCC.VARIANCE_LAMBDA 52 | self.huber_delta = config.LOSS.TCC.HUBER_DELTA 53 | self.normalize_indices = config.LOSS.TCC.NORMALIZE_INDICES 54 | 55 | def compute_loss( 56 | self, 57 | embs, 58 | batch, 59 | ): 60 | steps = batch["frame_idxs"].to(self._device) 61 | seq_lens = batch["video_len"].to(self._device) 62 | 63 | # Dynamically determine the number of cycles if using stochastic 64 | # matching. 65 | batch_size, num_cc_frames = embs.shape[:2] 66 | num_cycles = int(batch_size * num_cc_frames) 67 | 68 | return compute_tcc_loss( 69 | embs=embs, 70 | idxs=steps, 71 | seq_lens=seq_lens, 72 | stochastic_matching=self.stochastic_matching, 73 | normalize_embeddings=self.normalize_embeddings, 74 | loss_type=self.loss_type, 75 | similarity_type=self.similarity_type, 76 | num_cycles=num_cycles, 77 | cycle_length=self.cycle_length, 78 | temperature=self.temperature, 79 | label_smoothing=self.label_smoothing, 80 | variance_lambda=self.variance_lambda, 81 | huber_delta=self.huber_delta, 82 | normalize_indices=self.normalize_indices, 83 | ) 84 | -------------------------------------------------------------------------------- /graphirl/trainers/tcn.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """TCN trainer.""" 17 | 18 | from typing import Dict, List, Union 19 | 20 | import numpy as np 21 | import torch 22 | from graphirl.trainers.base import Trainer 23 | 24 | BatchType = Dict[str, Union[torch.Tensor, List[str]]] 25 | 26 | 27 | class TCNTrainer(Trainer): 28 | """A trainer that implements a single-view Time Contrastive Network [1]. 29 | 30 | This should be used in conjunction with the WindowSampler frame sampler. 
31 | 32 | References: 33 | [1]: https://arxiv.org/abs/1704.06888 34 | """ 35 | 36 | def __init__( 37 | self, 38 | model, 39 | optimizer, 40 | device, 41 | config, 42 | ): 43 | super().__init__(model, optimizer, device, config) 44 | 45 | self.temperature = config.LOSS.TCN.TEMPERATURE 46 | self.num_pairs = config.LOSS.TCN.NUM_PAIRS 47 | self.pos_radius = config.LOSS.TCN.POS_RADIUS 48 | self.neg_radius = config.LOSS.TCN.NEG_RADIUS 49 | 50 | def compute_loss( 51 | self, 52 | embs, 53 | batch, 54 | ): 55 | del batch 56 | 57 | batch_size, num_cc_frames, _ = embs.shape 58 | 59 | # A positive is centered at the current index and can extend pos_radius 60 | # forward or backward. 61 | # A negative can be sampled anywhere outside a negative radius centered 62 | # around the current index. 63 | batch_pos = [] 64 | batch_neg = [] 65 | idxs = np.arange(num_cc_frames) 66 | for i in range(batch_size): 67 | pos_delta = np.random.choice( 68 | [-self.pos_radius, self.pos_radius], 69 | size=(num_cc_frames, self.num_pairs), 70 | ) 71 | pos_idxs = np.clip(idxs[:, None] + pos_delta, 0, num_cc_frames - 1) 72 | batch_pos.append(torch.LongTensor(pos_idxs)) 73 | 74 | negatives = [] 75 | for idx in idxs: 76 | allowed = (idxs > (idx + self.neg_radius)) | ( 77 | idxs < (idx - self.neg_radius)) 78 | neg_idxs = np.random.choice(idxs[allowed], size=self.num_pairs) 79 | negatives.append(neg_idxs) 80 | batch_neg.append(torch.LongTensor(np.vstack(negatives))) 81 | 82 | pos_losses = 0.0 83 | neg_losses = 0.0 84 | for i, (positives, negatives) in enumerate(zip(batch_pos, batch_neg)): 85 | row_idx = torch.arange(num_cc_frames).unsqueeze(1) 86 | 87 | # Compute pairwise squared L2 distances between the embeddings in a 88 | # sequence. 89 | emb_seq = embs[i] 90 | distances = torch.cdist(emb_seq, emb_seq).pow(2) 91 | distances = distances / self.temperature 92 | 93 | # For every embedding in the sequence, we need to minimize its 94 | # distance to every positive we sampled for it. 95 | pos_loss = distances[row_idx, positives] 96 | pos_losses += pos_loss.sum() 97 | 98 | # And for negatives, we need to ensure they are at least a distance 99 | # M apart. We use the squared hinge loss to express this constraint. 100 | neg_margin = 1 - distances[row_idx, negatives] 101 | neg_loss = torch.clamp(neg_margin, min=0).pow(2) 102 | neg_losses += neg_loss.sum() 103 | 104 | total_loss = (pos_losses + neg_losses) / (batch_size * num_cc_frames) 105 | return total_loss 106 | -------------------------------------------------------------------------------- /graphirl/transforms.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2021 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 
16 | """Transformations for video data."""
17 | 
18 | import enum
19 | from typing import Callable, Dict, Mapping, Sequence, Tuple
20 | import warnings
21 | 
22 | import albumentations as alb
23 | import numpy as np
24 | import torch
25 | from graphirl.types import SequenceType
26 | 
27 | 
28 | @enum.unique
29 | class PretrainedMeans(enum.Enum):
30 |   """Pretrained mean normalization values."""
31 | 
32 |   IMAGENET = (0.485, 0.456, 0.406)
33 | 
34 | 
35 | @enum.unique
36 | class PretrainedStds(enum.Enum):
37 |   """Pretrained std deviation normalization values."""
38 | 
39 |   IMAGENET = (0.229, 0.224, 0.225)
40 | 
41 | 
42 | class UnNormalize:
43 |   """Unnormalize a batch of images that have been normalized.
44 | 
45 |   Specifically, re-multiply by the standard deviation and shift by the mean.
46 |   """
47 | 
48 |   def __init__(
49 |       self,
50 |       mean,
51 |       std,
52 |   ):
53 |     """Constructor.
54 | 
55 |     Args:
56 |       mean: The color channel means.
57 |       std: The color channel standard deviation.
58 |     """
59 |     if np.asarray(mean).shape:
60 |       self.mean = torch.tensor(mean)[Ellipsis, :, None, None]
61 |     if np.asarray(std).shape:
62 |       self.std = torch.tensor(std)[Ellipsis, :, None, None]
63 | 
64 |   def __call__(self, tensor):
65 |     return (tensor * self.std) + self.mean
66 | 
67 | 
68 | def augment_video(
69 |     frames,
70 |     pipeline,  # pylint: disable=g-bare-generic
71 | ):
72 |   """Apply the same augmentation pipeline to all frames in a video.
73 | 
74 |   Args:
75 |     frames: A numpy array of shape (T, H, W, 3), where T is the number of frames
76 |       in the video.
77 |     pipeline (list): A list containing albumentation augmentations.
78 | 
79 |   Returns:
80 |     The augmented frames of shape (T, H, W, 3).
81 | 
82 |   Raises:
83 |     ValueError: If the input video doesn't have the correct shape.
84 |   """
85 |   if frames.ndim != 4:
86 |     raise ValueError("Input video must be a 4D sequence of frames.")
87 | 
88 |   transform = alb.ReplayCompose(pipeline, p=1.0)
89 | 
90 |   # Apply a transformation to the first frame and record the parameters
91 |   # that were sampled in a replay, then use the parameters stored in the
92 |   # replay to apply an identical transform to the remaining frames in the
93 |   # sequence.
94 |   with warnings.catch_warnings():
95 |     # This suppresses albumentations' warning related to ReplayCompose.
96 |     warnings.simplefilter("ignore")
97 | 
98 |     replay, frames_aug = None, []
99 |     for frame in frames:
100 |       if replay is None:
101 |         aug = transform(image=frame)
102 |         replay = aug.pop("replay")
103 |       else:
104 |         aug = transform.replay(replay, image=frame)
105 |       frames_aug.append(aug["image"])
106 | 
107 |   return np.stack(frames_aug, axis=0)
108 | 
109 | 
110 | class VideoAugmentor:
111 |   """Data augmentation for videos.
112 | 
113 |   Augmentor consistently augments data across the time dimension (i.e. dim 0).
114 |   In other words, the same transformation is applied to every single frame in
115 |   a video sequence.
116 | 
117 |   Currently, only image frames, i.e. SequenceType.FRAMES in a video can be
118 |   augmented.
119 |   """
120 | 
121 |   MAP = {
122 |       SequenceType.FRAMES: augment_video,
123 |   }
124 | 
125 |   def __init__(
126 |       self,
127 |       params,  # pylint: disable=g-bare-generic
128 |   ):
129 |     """Constructor.
130 | 
131 |     Args:
132 |       params: A dict mapping a `SequenceType` to the list of albumentation augmentations to apply to that sequence.
133 | 
134 |     Raises:
135 |       ValueError: If params contains an unsupported data augmentation.
136 |     """
137 |     for key in params.keys():
138 |       if key not in SequenceType:
139 |         raise ValueError(f"{key} is not a supported SequenceType.")
140 |     self._params = params
141 | 
142 |   def __call__(
143 |       self,
144 |       data,
145 |   ):
146 |     """Iterate and transform the data values.
147 | 
148 |     Currently, data augmentation is only applied to video frames, i.e. the
149 |     value of the data dict associated with the SequenceType.FRAMES key.
150 | 
151 |     Args:
152 |       data: A dict mapping from sequence type to sequence value.
153 | 
154 |     Returns:
155 |       An augmented dict.
156 |     """
157 |     for key, transforms in self._params.items():
158 |       data[key] = VideoAugmentor.MAP[key](data[key], transforms)
159 |     return data
160 | 
--------------------------------------------------------------------------------
/graphirl/types.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | 
16 | """Types shared across modules."""
17 | 
18 | import enum
19 | 
20 | 
21 | @enum.unique
22 | class SequenceType(enum.Enum):
23 |   """Sequence data types we know how to preprocess.
24 | 
25 |   If you need to preprocess additional video data, you must add it here.
26 |   """
27 | 
28 |   FRAMES = "frames"
29 |   FRAME_IDXS = "frame_idxs"
30 |   VIDEO_NAME = "video_name"
31 |   VIDEO_LEN = "video_len"
32 | 
33 |   def __str__(self):  # pylint: disable=invalid-str-returned
34 |     return self.value
35 | 
36 | 
37 | @enum.unique
38 | class ImageTransformationType(enum.Enum):
39 |   """Transformations we know how to run on images.
40 | 
41 |   If you want to add additional augmentations, you must add them here.
42 |   """
43 | 
44 |   RANDOM_RESIZED_CROP = "random_resized_crop"
45 |   CENTER_CROP = "center_crop"
46 |   GLOBAL_RESIZE = "global_resize"
47 |   VERTICAL_FLIP = "vertical_flip"
48 |   HORIZONTAL_FLIP = "horizontal_flip"
49 |   COLOR_JITTER = "color_jitter"
50 |   ROTATE = "rotate"
51 |   DROPOUT = "dropout"
52 |   NORMALIZE = "normalize"
53 |   UNNORMALIZE = "unnormalize"
54 | 
55 |   def __str__(self):  # pylint: disable=invalid-str-returned
56 |     return self.value
57 | 
--------------------------------------------------------------------------------
/graphirl/video_samplers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | #     http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """Video samplers for mini-batch creation.""" 17 | 18 | import abc 19 | from typing import Dict, Iterator, List, Tuple 20 | 21 | import numpy as np 22 | import torch 23 | from torch.utils.data import Sampler 24 | 25 | ClassIdxVideoIdx = Tuple[int, int] 26 | DirTreeIndices = List[List[ClassIdxVideoIdx]] 27 | VideoBatchIter = Iterator[List[ClassIdxVideoIdx]] 28 | 29 | 30 | class VideoBatchSampler(abc.ABC, Sampler): 31 | """Base class for all video samplers.""" 32 | 33 | def __init__( 34 | self, 35 | dir_tree, 36 | batch_size, 37 | sequential = False, 38 | ): 39 | """Constructor. 40 | 41 | Args: 42 | dir_tree: The directory tree of a `datasets.VideoDataset`. 43 | batch_size: The number of videos in a batch. 44 | sequential: Set to `True` to disable any shuffling or randomness. 45 | """ 46 | assert isinstance(batch_size, int) 47 | 48 | self._batch_size = batch_size 49 | self._dir_tree = dir_tree 50 | self._sequential = sequential 51 | 52 | @abc.abstractmethod 53 | def _generate_indices(self): 54 | """Generate batch chunks containing (class idx, video_idx) tuples.""" 55 | pass 56 | 57 | def __iter__(self): 58 | idxs = self._generate_indices() 59 | if self._sequential: 60 | return iter(idxs) 61 | return iter(idxs[i] for i in torch.randperm(len(idxs))) 62 | 63 | def __len__(self): 64 | num_vids = 0 65 | for vids in self._dir_tree.values(): 66 | num_vids += len(vids) 67 | return num_vids // self.batch_size 68 | 69 | @property 70 | def batch_size(self): 71 | return self._batch_size 72 | 73 | @property 74 | def dir_tree(self): 75 | return self._dir_tree 76 | 77 | 78 | class RandomBatchSampler(VideoBatchSampler): 79 | """Randomly samples videos from different classes into the same batch. 80 | 81 | Note the `sequential` arg is disabled here. 82 | """ 83 | 84 | def _generate_indices(self): 85 | # Generate a list of video indices for every class. 86 | all_idxs = [] 87 | for k, v in enumerate(self._dir_tree.values()): 88 | seq = list(range(len(v))) 89 | all_idxs.extend([(k, s) for s in seq]) 90 | # Shuffle the indices. 91 | all_idxs = [all_idxs[i] for i in torch.randperm(len(all_idxs))] 92 | # If we have less total videos than the batch size, we pad with clones 93 | # until we reach a length of batch_size. 94 | if len(all_idxs) < self._batch_size: 95 | while len(all_idxs) < self._batch_size: 96 | all_idxs.append(all_idxs[np.random.randint(0, len(all_idxs))]) 97 | # Split the list of indices into chunks of len `batch_size`. 98 | idxs = [] 99 | end = self._batch_size * (len(all_idxs) // self._batch_size) 100 | for i in range(0, end, self._batch_size): 101 | batch_idxs = all_idxs[i:i + self._batch_size] 102 | idxs.append(batch_idxs) 103 | return idxs 104 | 105 | 106 | class SameClassBatchSampler(VideoBatchSampler): 107 | """Ensures all videos in a batch belong to the same class.""" 108 | 109 | def _generate_indices(self): 110 | idxs = [] 111 | for k, v in enumerate(self._dir_tree.values()): 112 | # Generate a list of indices for every video in the class. 113 | len_v = len(v) 114 | seq = list(range(len_v)) 115 | if not self._sequential: 116 | seq = [seq[i] for i in torch.randperm(len(seq))] 117 | # Split the list of indices into chunks of len `batch_size`, 118 | # ensuring we drop the last chunk if it is not of adequate length. 
119 |       batch_idxs = []
120 |       end = self._batch_size * (len_v // self._batch_size)
121 |       for i in range(0, end, self._batch_size):
122 |         xs = seq[i:i + self._batch_size]
123 |         # Add the class index to the video index.
124 |         xs = [(k, x) for x in xs]
125 |         batch_idxs.append(xs)
126 |       idxs.extend(batch_idxs)
127 |     return idxs
128 | 
129 | 
130 | class SameClassBatchSamplerDownstream(SameClassBatchSampler):
131 |   """A same class batch sampler with a batch size of 1.
132 | 
133 |   This batch sampler is used for downstream datasets. Since such datasets
134 |   typically load a variable number of frames per video, we are forced to use
135 |   a batch size of 1.
136 |   """
137 | 
138 |   def __init__(
139 |       self,
140 |       dir_tree,
141 |       sequential = False,
142 |   ):
143 |     super().__init__(dir_tree, batch_size=1, sequential=sequential)
144 | 
--------------------------------------------------------------------------------
/media/preview.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/media/preview.gif
--------------------------------------------------------------------------------
/media/summary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/media/summary.png
--------------------------------------------------------------------------------
/scripts/pegbox_helper.sh:
--------------------------------------------------------------------------------
1 | export MUJOCO_GL=egl
2 | python3 src/train.py \
3 |     --algorithm ${12} \
4 |     --domain_name robot \
5 |     --task_name $1 \
6 |     --episode_length 50 \
7 |     --exp_suffix ${11} \
8 |     --eval_mode none \
9 |     --save_video \
10 |     --eval_freq 10k \
11 |     --train_steps 1500k \
12 |     --buffer_size 500k \
13 |     --save_freq 50k \
14 |     --frame_stack 1 \
15 |     --log_dir /data/sateesh/logs_pegbox_graphirl \
16 |     --eval_episodes 50 \
17 |     --image_size 84 \
18 |     --render_image_size 84 \
19 |     --seed $4 \
20 |     --cameras $2 \
21 |     --camera_dropout $3 \
22 |     --action_space $5 \
23 |     --attention $7 \
24 |     --num_head_layers $6 \
25 |     --concat $8 \
26 |     --observation_type ${13} \
27 |     --context1 $9 \
28 |     --context2 ${10} \
29 |     --wandb_project pegbox_graphirl \
30 |     --save_video \
31 |     --pretrained_path ${15} \
32 |     --reward_wrapper gil_pegbox \
33 |     --apply_wrapper \
34 |     --wandb
--------------------------------------------------------------------------------
/scripts/push_helper.sh:
--------------------------------------------------------------------------------
1 | export MUJOCO_GL=egl
2 | python3 src/train.py \
3 |     --algorithm ${12} \
4 |     --domain_name robot \
5 |     --task_name $1 \
6 |     --episode_length 50 \
7 |     --exp_suffix ${11} \
8 |     --eval_mode none \
9 |     --save_video \
10 |     --eval_freq 10k \
11 |     --train_steps 800k \
12 |     --buffer_size 500k \
13 |     --save_freq 50k \
14 |     --frame_stack 1 \
15 |     --log_dir push_graphirl \
16 |     --eval_episodes 50 \
17 |     --image_size 84 \
18 |     --render_image_size 84 \
19 |     --seed $4 \
20 |     --cameras $2 \
21 |     --camera_dropout $3 \
22 |     --action_space $5 \
23 |     --attention $7 \
24 |     --num_head_layers $6 \
25 |     --concat $8 \
26 |     --observation_type ${13} \
27 |     --context1 $9 \
28 |     --context2 ${10} \
29 |     --wandb_project push_graphirl \
30 |     --save_video \
31 |     --pretrained_path ${15} \
32 |     --reward_wrapper gil \
33 |     --apply_wrapper \
34 |     --wandb
35 | 
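For reference, the pegbox, push, and reach helper scripts all consume the same fifteen positional arguments and forward them to src/train.py. The mapping below is a sketch read off the flags each helper sets; argument ${14} does not appear to be referenced by any of the helpers:

# $1  task_name           $2  cameras            $3  camera_dropout
# $4  seed                $5  action_space       $6  num_head_layers
# $7  attention           $8  concat             $9  context1
# ${10} context2          ${11} exp_suffix       ${12} algorithm
# ${13} observation_type  ${14} (not referenced) ${15} pretrained_path

The run_*.sh wrappers further below pass their own single argument in the last position, so it ends up as ${15}, i.e. the --pretrained_path of the learned reward model.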
-------------------------------------------------------------------------------- /scripts/reach_helper.sh: -------------------------------------------------------------------------------- 1 | export MUJOCO_GL=egl 2 | python3 src/train.py \ 3 | --algorithm ${12} \ 4 | --domain_name robot \ 5 | --task_name $1 \ 6 | --episode_length 50 \ 7 | --exp_suffix ${11} \ 8 | --eval_mode none \ 9 | --save_video \ 10 | --eval_freq 10k \ 11 | --train_steps 300k \ 12 | --save_freq 50k \ 13 | --frame_stack 1 \ 14 | --log_dir /data/sateesh/reach_graphirl_logs \ 15 | --eval_episodes 50 \ 16 | --image_size 84 \ 17 | --render_image_size 224 \ 18 | --seed $4 \ 19 | --cameras $2 \ 20 | --camera_dropout $3 \ 21 | --action_space $5 \ 22 | --attention $7 \ 23 | --num_head_layers $6 \ 24 | --concat $8 \ 25 | --observation_type ${13} \ 26 | --context1 $9 \ 27 | --context2 ${10} \ 28 | --wandb_project reach_graphirl \ 29 | --save_video \ 30 | --pretrained_path ${15} \ 31 | --reward_wrapper gil_reach \ 32 | --apply_wrapper \ 33 | --wandb -------------------------------------------------------------------------------- /scripts/run_pegbox.sh: -------------------------------------------------------------------------------- 1 | sh scripts_export/pegbox_helper.sh push 2 0 0 xy 3 1 0 1 1 push_graphirl sacv2 statefull+image 0 $1 -------------------------------------------------------------------------------- /scripts/run_push.sh: -------------------------------------------------------------------------------- 1 | sh scripts_export/push_helper.sh push 2 0 0 xy 3 1 0 1 1 push_graphirl sacv2 statefull+image 0 $1 -------------------------------------------------------------------------------- /scripts/run_reach.sh: -------------------------------------------------------------------------------- 1 | sh scripts_export/reach_helper.sh reach 2 0 0 xy 3 1 0 1 1 reach_graphirl sacv2 statefull+image 0 $1 -------------------------------------------------------------------------------- /src/algorithms/drq.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from copy import deepcopy 6 | import utils 7 | import algorithms.modules as m 8 | from algorithms.sac import SAC 9 | 10 | 11 | class DrQ(SAC): # [K=1, M=1] 12 | def __init__(self, obs_shape, state_space, action_shape, args): 13 | super().__init__(obs_shape, action_shape, args) 14 | 15 | def update(self, replay_buffer, L, step): 16 | obs, _, action, reward, next_obs, _ = replay_buffer.sample() 17 | 18 | self.update_critic(obs, action, reward, next_obs, L, step) 19 | 20 | if step % self.actor_update_freq == 0: 21 | self.update_actor_and_alpha(obs, L, step) 22 | 23 | if step % self.critic_target_update_freq == 0: 24 | self.soft_update_critic_target() 25 | -------------------------------------------------------------------------------- /src/algorithms/drq_multiview.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from copy import deepcopy 6 | import utils 7 | import algorithms.modules as m 8 | from algorithms.sac import SAC 9 | from algorithms.multiview import MultiView 10 | 11 | 12 | class DrQMultiView(MultiView): # [K=1, M=1] 13 | def __init__(self, obs_shape, action_shape, args): 14 | super().__init__(obs_shape, action_shape, args) 15 | 16 | def update(self, replay_buffer, L, step): 17 | obs, action, reward, 
next_obs = replay_buffer.sample_drq() 18 | 19 | self.update_critic(obs, action, reward, next_obs, L, step) 20 | 21 | if step % self.actor_update_freq == 0: 22 | self.update_actor_and_alpha(obs, L, step) 23 | 24 | if step % self.critic_target_update_freq == 0: 25 | self.soft_update_critic_target() 26 | -------------------------------------------------------------------------------- /src/algorithms/factory.py: -------------------------------------------------------------------------------- 1 | from algorithms.sac import SAC 2 | from algorithms.sacv2 import SACv2 3 | from algorithms.drq import DrQ 4 | from algorithms.sveav2 import SVEAv2 5 | from algorithms.multiview import MultiView 6 | from algorithms.drq_multiview import DrQMultiView 7 | from algorithms.drqv2 import DrQv2 8 | from algorithms.sacv2_3d import SACv2_3D 9 | 10 | algorithm = { 11 | 'sac': SAC, 12 | 'sacv2': SACv2, 13 | 'sacv2_3d':SACv2_3D, 14 | 'drq': DrQ, 15 | 'sveav2': SVEAv2, 16 | 'multiview': MultiView, 17 | 'drq_multiview': DrQMultiView, 18 | 'drqv2': DrQv2 19 | } 20 | 21 | 22 | def make_agent(obs_shape, state_shape , action_shape, args): 23 | if args.algorithm=='sacv2_3d': 24 | if args.use_latent: 25 | a_obs_shape = (args.bottleneck*32, 32, 32) 26 | elif args.use_impala: 27 | a_obs_shape = (32, 8, 8) 28 | else: 29 | a_obs_shape = (32, 26, 26) 30 | return algorithm[args.algorithm](a_obs_shape, (3, 64, 64), action_shape, args) 31 | else: 32 | return algorithm[args.algorithm](obs_shape, state_shape, action_shape, args) -------------------------------------------------------------------------------- /src/algorithms/multiview.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from copy import deepcopy 6 | import utils 7 | import algorithms.modules as m 8 | 9 | 10 | class MultiView(object): 11 | def __init__(self, obs_shape, action_shape, args): 12 | self.discount = args.discount 13 | self.critic_tau = args.critic_tau 14 | self.encoder_tau = args.encoder_tau 15 | self.actor_update_freq = args.actor_update_freq 16 | self.critic_target_update_freq = args.critic_target_update_freq 17 | self.use_vit = args.use_vit 18 | 19 | shared_cnn_1 = m.SharedCNN(obs_shape, args.num_shared_layers, args.num_filters, args.mean_zero).cuda() # Third Person Image 20 | shared_cnn_2 = m.SharedCNN(obs_shape, args.num_shared_layers, args.num_filters, args.mean_zero).cuda() # First Person Image 21 | 22 | integrator = m.Integrator(shared_cnn_1.out_shape, shared_cnn_2.out_shape, args.num_filters).cuda() # Change channel dimensions of concatenated features 23 | 24 | assert shared_cnn_1.out_shape==shared_cnn_2.out_shape 25 | head_cnn = m.HeadCNN(shared_cnn_1.out_shape, args.num_head_layers, args.num_filters).cuda() # Pass concatenated features into the head 26 | 27 | actor_encoder = m.MultiViewEncoder( 28 | shared_cnn_1, 29 | shared_cnn_2, 30 | integrator, 31 | head_cnn, 32 | m.RLProjection(head_cnn.out_shape, args.projection_dim) 33 | ) 34 | critic_encoder = m.MultiViewEncoder( 35 | shared_cnn_1, 36 | shared_cnn_2, 37 | integrator, 38 | head_cnn, 39 | m.RLProjection(head_cnn.out_shape, args.projection_dim) 40 | ) 41 | 42 | self.actor = m.Actor(actor_encoder, action_shape, args.hidden_dim, args.actor_log_std_min, args.actor_log_std_max, multiview=True).cuda() 43 | self.critic = m.Critic(critic_encoder, action_shape, args.hidden_dim, multiview=True).cuda() 44 | self.critic_target = deepcopy(self.critic) 45 | 46 | 
self.log_alpha = torch.tensor(np.log(args.init_temperature)).cuda() 47 | self.log_alpha.requires_grad = True 48 | self.target_entropy = -np.prod(action_shape) 49 | 50 | self.actor_optimizer = torch.optim.Adam( 51 | self.actor.parameters(), lr=args.actor_lr, betas=(args.actor_beta, 0.999) 52 | ) 53 | self.critic_optimizer = torch.optim.Adam( 54 | self.critic.parameters(), lr=args.critic_lr, betas=(args.critic_beta, 0.999) 55 | ) 56 | self.log_alpha_optimizer = torch.optim.Adam( 57 | [self.log_alpha], lr=args.alpha_lr, betas=(args.alpha_beta, 0.999) 58 | ) 59 | 60 | self.train() 61 | self.critic_target.train() 62 | 63 | def train(self, training=True): 64 | self.training = training 65 | self.actor.train(training) 66 | self.critic.train(training) 67 | 68 | def eval(self): 69 | self.train(False) 70 | 71 | @property 72 | def alpha(self): 73 | return self.log_alpha.exp() 74 | 75 | def _obs_to_input(self, obs): 76 | if isinstance(obs, utils.LazyFrames): 77 | _obs = np.array(obs) 78 | else: 79 | _obs = obs 80 | _obs = torch.FloatTensor(_obs).cuda() 81 | _obs = _obs.unsqueeze(0) 82 | return _obs 83 | 84 | def select_action(self, obs): 85 | _obs = self._obs_to_input(obs) 86 | with torch.no_grad(): 87 | mu, _, _, _ = self.actor(_obs, compute_pi=False, compute_log_pi=False) 88 | return mu.cpu().data.numpy().flatten() 89 | 90 | def sample_action(self, obs): 91 | _obs = self._obs_to_input(obs) 92 | with torch.no_grad(): 93 | mu, pi, _, _ = self.actor(_obs, compute_log_pi=False) 94 | return pi.cpu().data.numpy().flatten() 95 | 96 | def update_critic(self, obs, action, reward, next_obs, L=None, step=None): 97 | with torch.no_grad(): 98 | _, policy_action, log_pi, _ = self.actor(next_obs) 99 | target_Q1, target_Q2 = self.critic_target(next_obs, policy_action) 100 | target_V = torch.min(target_Q1, 101 | target_Q2) - self.alpha.detach() * log_pi 102 | target_Q = reward + (self.discount * target_V) 103 | 104 | current_Q1, current_Q2 = self.critic(obs, action) 105 | critic_loss = F.mse_loss(current_Q1, 106 | target_Q) + F.mse_loss(current_Q2, target_Q) 107 | if L is not None: 108 | L.log('train_critic/loss', critic_loss, step) 109 | 110 | self.critic_optimizer.zero_grad() 111 | critic_loss.backward() 112 | self.critic_optimizer.step() 113 | 114 | def update_actor_and_alpha(self, obs, L=None, step=None, update_alpha=True): 115 | _, pi, log_pi, log_std = self.actor(obs, detach=True) 116 | actor_Q1, actor_Q2 = self.critic(obs, pi, detach=True) 117 | 118 | actor_Q = torch.min(actor_Q1, actor_Q2) 119 | actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean() 120 | 121 | if L is not None: 122 | L.log('train_actor/loss', actor_loss, step) 123 | entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi) 124 | ) + log_std.sum(dim=-1) 125 | 126 | self.actor_optimizer.zero_grad() 127 | actor_loss.backward() 128 | self.actor_optimizer.step() 129 | 130 | if update_alpha: 131 | self.log_alpha_optimizer.zero_grad() 132 | alpha_loss = (self.alpha * (-log_pi - self.target_entropy).detach()).mean() 133 | 134 | if L is not None: 135 | L.log('train_alpha/loss', alpha_loss, step) 136 | L.log('train_alpha/value', self.alpha, step) 137 | 138 | alpha_loss.backward() 139 | self.log_alpha_optimizer.step() 140 | 141 | def soft_update_critic_target(self): 142 | utils.soft_update_params( 143 | self.critic.Q1, self.critic_target.Q1, self.critic_tau 144 | ) 145 | utils.soft_update_params( 146 | self.critic.Q2, self.critic_target.Q2, self.critic_tau 147 | ) 148 | utils.soft_update_params( 149 | self.critic.encoder, 
self.critic_target.encoder, 150 | self.encoder_tau 151 | ) 152 | 153 | def update(self, replay_buffer, L, step): 154 | obs, action, reward, next_obs = replay_buffer.sample() 155 | 156 | self.update_critic(obs, action, reward, next_obs, L, step) 157 | 158 | if step % self.actor_update_freq == 0: 159 | self.update_actor_and_alpha(obs, L, step) 160 | 161 | if step % self.critic_target_update_freq == 0: 162 | self.soft_update_critic_target() -------------------------------------------------------------------------------- /src/algorithms/rot_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import math 6 | 7 | 8 | def euler2mat(angle, scaling=False, translation=False): 9 | """Convert euler angles to rotation matrix. 10 | Reference: https://github.com/pulkitag/pycaffe-utils/blob/master/rot_utils.py#L174 11 | Args: 12 | angle: rotation angle along 3 axis (a, b, y) in radians -- size = [B, 3] 13 | Returns: 14 | Rotation matrix corresponding to the euler angles -- size = [B, 3, 3] 15 | """ 16 | B = angle.size(0) 17 | euler_angle = 'bay' 18 | 19 | if euler_angle == 'bay': 20 | a, b, y = angle[:, 0], angle[:, 1], angle[:, 2] 21 | x, y, z = b, a, y 22 | elif euler_angle == 'aby': 23 | x, y, z = angle[:, 0], angle[:, 1], angle[:, 2] 24 | else: 25 | raise NotImplementedError 26 | 27 | cosz = torch.cos(z) 28 | sinz = torch.sin(z) 29 | 30 | zeros = z.detach() * 0 31 | ones = zeros.detach() + 1 32 | zmat = torch.stack([cosz, -sinz, zeros, 33 | sinz, cosz, zeros, 34 | zeros, zeros, ones], dim=1).reshape(B, 3, 3) 35 | 36 | cosy = torch.cos(y) 37 | siny = torch.sin(y) 38 | 39 | ymat = torch.stack([cosy, zeros, siny, 40 | zeros, ones, zeros, 41 | -siny, zeros, cosy], dim=1).reshape(B, 3, 3) 42 | 43 | cosx = torch.cos(x) 44 | sinx = torch.sin(x) 45 | 46 | xmat = torch.stack([ones, zeros, zeros, 47 | zeros, cosx, -sinx, 48 | zeros, sinx, cosx], dim=1).reshape(B, 3, 3) 49 | 50 | rotMat = xmat @ ymat @ zmat 51 | # rotMat = zmat 52 | # rotMat = ymat @ zmat 53 | # rotMat = xmat @ ymat 54 | # rotMat = xmat @ zmat 55 | 56 | if scaling: 57 | v_scale = angle[:,3] 58 | v_trans = angle[:,4:] 59 | else: 60 | v_trans = angle[:,3:] 61 | 62 | if scaling: 63 | # one = torch.ones_like(v_scale).detach() 64 | # t_scale = torch.stack([v_scale, one, one, 65 | # one, v_scale, one, 66 | # one, one, v_scale], dim=1).view(B, 3, 3) 67 | rotMat = rotMat * v_scale.unsqueeze(1).unsqueeze(1) 68 | 69 | if translation: 70 | rotMat = torch.cat([rotMat, v_trans.view([B, 3, 1]).cuda()], 2) # F.affine_grid takes 3x4 71 | else: 72 | rotMat = torch.cat([rotMat, torch.zeros([B, 3, 1]).cuda().detach()], 2) # F.affine_grid takes 3x4 73 | 74 | return rotMat 75 | 76 | 77 | def quat2mat(quat, scaling=False, translation=False): 78 | """Convert quaternion coefficients to rotation matrix. 79 | Args: 80 | quat: first three coeff of quaternion of rotation. 
fourht is then computed to have a norm of 1 -- size = [B, 3] 81 | Returns: 82 | Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] 83 | """ 84 | norm_quat = torch.cat([quat[:,:1].detach()*0 + 1, quat], dim=1) 85 | norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True) 86 | w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3] 87 | 88 | B = quat.size(0) 89 | 90 | w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) 91 | wx, wy, wz = w*x, w*y, w*z 92 | xy, xz, yz = x*y, x*z, y*z 93 | 94 | rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, 95 | 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, 96 | 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).reshape(B, 3, 3) 97 | rotMat = torch.cat([rotMat, torch.zeros([B, 3, 1]).cuda().detach()], 2) # F.affine_grid takes 3x4 98 | 99 | return rotMat 100 | 101 | -------------------------------------------------------------------------------- /src/algorithms/sac.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from copy import deepcopy 6 | import utils 7 | import algorithms.modules as m 8 | 9 | 10 | class SAC(object): 11 | def __init__(self, obs_shape, action_shape, args): 12 | self.discount = args.discount 13 | self.critic_tau = args.critic_tau 14 | self.encoder_tau = args.encoder_tau 15 | self.actor_update_freq = args.actor_update_freq 16 | self.critic_target_update_freq = args.critic_target_update_freq 17 | self.from_state = args.observation_type=='state' 18 | self.use_vit = args.use_vit 19 | 20 | if self.from_state: 21 | shared = m.Identity(obs_shape) 22 | head = m.Identity(obs_shape) 23 | elif self.use_vit: 24 | shared = m.SharedTransformer(obs_shape, args.num_shared_layers, args.svea_num_heads, args.svea_embed_dim).cuda() 25 | head = m.HeadCNN(shared.out_shape, 0, 0).cuda() 26 | else: 27 | shared = m.SharedCNN(obs_shape, args.num_shared_layers, args.num_filters, args.mean_zero).cuda() 28 | head = m.HeadCNN(shared.out_shape, args.num_head_layers, args.num_filters).cuda() 29 | actor_encoder = m.Encoder( 30 | shared, 31 | head, 32 | m.RLProjection(head.out_shape, args.projection_dim) 33 | ) 34 | critic_encoder = m.Encoder( 35 | shared, 36 | head, 37 | m.RLProjection(head.out_shape, args.projection_dim) 38 | ) 39 | 40 | self.actor = m.Actor(actor_encoder, action_shape, args.hidden_dim, args.actor_log_std_min, args.actor_log_std_max).cuda() 41 | self.critic = m.Critic(critic_encoder, action_shape, args.hidden_dim).cuda() 42 | self.critic_target = deepcopy(self.critic) 43 | 44 | self.log_alpha = torch.tensor(np.log(args.init_temperature)).cuda() 45 | self.log_alpha.requires_grad = True 46 | self.target_entropy = -np.prod(action_shape) 47 | 48 | self.actor_optimizer = torch.optim.Adam( 49 | self.actor.parameters(), lr=args.actor_lr, betas=(args.actor_beta, 0.999) 50 | ) 51 | self.critic_optimizer = torch.optim.Adam( 52 | self.critic.parameters(), lr=args.critic_lr, betas=(args.critic_beta, 0.999) 53 | ) 54 | self.log_alpha_optimizer = torch.optim.Adam( 55 | [self.log_alpha], lr=args.alpha_lr, betas=(args.alpha_beta, 0.999) 56 | ) 57 | 58 | self.train() 59 | self.critic_target.train() 60 | 61 | def train(self, training=True): 62 | self.training = training 63 | self.actor.train(training) 64 | self.critic.train(training) 65 | 66 | def eval(self): 67 | self.train(False) 68 | 69 | @property 70 | def alpha(self): 71 | return self.log_alpha.exp() 72 | 73 | def 
_obs_to_input(self, obs): 74 | if isinstance(obs, utils.LazyFrames): 75 | _obs = np.array(obs) 76 | else: 77 | _obs = obs 78 | _obs = torch.FloatTensor(_obs).cuda() 79 | _obs = _obs.unsqueeze(0) 80 | return _obs 81 | 82 | def select_action(self, obs, state=None): 83 | _obs = self._obs_to_input(obs) 84 | with torch.no_grad(): 85 | mu, _, _, _ = self.actor(_obs, compute_pi=False, compute_log_pi=False) 86 | return mu.cpu().data.numpy().flatten() 87 | 88 | def sample_action(self, obs, state=None, step=None): 89 | _obs = self._obs_to_input(obs) 90 | with torch.no_grad(): 91 | mu, pi, _, _ = self.actor(_obs, compute_log_pi=False) 92 | return pi.cpu().data.numpy().flatten() 93 | 94 | def update_critic(self, obs, action, reward, next_obs, L=None, step=None): 95 | with torch.no_grad(): 96 | _, policy_action, log_pi, _ = self.actor(next_obs) 97 | target_Q1, target_Q2 = self.critic_target(next_obs, policy_action) 98 | target_V = torch.min(target_Q1, 99 | target_Q2) - self.alpha.detach() * log_pi 100 | target_Q = reward + (self.discount * target_V) 101 | 102 | current_Q1, current_Q2 = self.critic(obs, action) 103 | critic_loss = F.mse_loss(current_Q1, 104 | target_Q) + F.mse_loss(current_Q2, target_Q) 105 | if L is not None: 106 | L.log('train_critic/loss', critic_loss, step) 107 | 108 | self.critic_optimizer.zero_grad() 109 | critic_loss.backward() 110 | self.critic_optimizer.step() 111 | 112 | def update_actor_and_alpha(self, obs, L=None, step=None, update_alpha=True): 113 | _, pi, log_pi, log_std = self.actor(obs, detach=True) 114 | actor_Q1, actor_Q2 = self.critic(obs, pi, detach=True) 115 | 116 | actor_Q = torch.min(actor_Q1, actor_Q2) 117 | actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean() 118 | 119 | if L is not None: 120 | L.log('train_actor/loss', actor_loss, step) 121 | entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi) 122 | ) + log_std.sum(dim=-1) 123 | 124 | self.actor_optimizer.zero_grad() 125 | actor_loss.backward() 126 | self.actor_optimizer.step() 127 | 128 | if update_alpha: 129 | self.log_alpha_optimizer.zero_grad() 130 | alpha_loss = (self.alpha * (-log_pi - self.target_entropy).detach()).mean() 131 | 132 | if L is not None: 133 | L.log('train_alpha/loss', alpha_loss, step) 134 | L.log('train_alpha/value', self.alpha, step) 135 | 136 | alpha_loss.backward() 137 | self.log_alpha_optimizer.step() 138 | 139 | def soft_update_critic_target(self): 140 | utils.soft_update_params( 141 | self.critic.Q1, self.critic_target.Q1, self.critic_tau 142 | ) 143 | utils.soft_update_params( 144 | self.critic.Q2, self.critic_target.Q2, self.critic_tau 145 | ) 146 | utils.soft_update_params( 147 | self.critic.encoder, self.critic_target.encoder, 148 | self.encoder_tau 149 | ) 150 | 151 | def update(self, replay_buffer, L, step): 152 | obs, _, action, reward, next_obs, _ = replay_buffer.sample() 153 | 154 | self.update_critic(obs, action, reward, next_obs, L, step) 155 | 156 | if step % self.actor_update_freq == 0: 157 | self.update_actor_and_alpha(obs, L, step) 158 | 159 | if step % self.critic_target_update_freq == 0: 160 | self.soft_update_critic_target() 161 | -------------------------------------------------------------------------------- /src/algorithms/svea_multiview.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from copy import deepcopy 6 | import utils 7 | import augmentations 8 | import algorithms.modules as m 9 | from 
algorithms.multiview import MultiView
10 | 
11 | 
12 | class SVEAMultiView(MultiView):
13 | 	def __init__(self, obs_shape, action_shape, args):
14 | 		super().__init__(obs_shape, action_shape, args)
15 | 		self.svea_alpha = args.svea_alpha
16 | 		self.svea_beta = args.svea_beta
17 | 		self.svea_augmentation = args.svea_augmentation
18 | 		raise NotImplementedError('use sveav2 instead')
19 | 
20 | 	def augment(self, obs):
21 | 		if self.svea_augmentation == 'colorjitter':
22 | 			return augmentations.random_color_jitter(obs.clone())
23 | 		elif self.svea_augmentation == 'conv':
24 | 			return augmentations.random_conv(obs.clone())
25 | 		else:
26 | 			raise NotImplementedError(f'Unsupported augmentation: {self.svea_augmentation}')
27 | 
28 | 	def update_critic(self, obs, action, reward, next_obs, L=None, step=None):
29 | 		with torch.no_grad():
30 | 			_, policy_action, log_pi, _ = self.actor(next_obs)
31 | 			target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
32 | 			target_V = torch.min(target_Q1,
33 | 								 target_Q2) - self.alpha.detach() * log_pi
34 | 			target_Q = reward + (self.discount * target_V)
35 | 
36 | 		if self.svea_alpha == self.svea_beta:
37 | 			obs = utils.cat(obs, self.augment(obs))
38 | 			action = utils.cat(action, action)
39 | 			target_Q = utils.cat(target_Q, target_Q)
40 | 
41 | 			current_Q1, current_Q2 = self.critic(obs, action)
42 | 			critic_loss = (self.svea_alpha + self.svea_beta) * \
43 | 				(F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q))
44 | 		else:
45 | 			current_Q1, current_Q2 = self.critic(obs, action)
46 | 			critic_loss = self.svea_alpha * \
47 | 				(F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q))
48 | 
49 | 			obs_aug = self.augment(obs)
50 | 			current_Q1_aug, current_Q2_aug = self.critic(obs_aug, action)
51 | 			critic_loss += self.svea_beta * \
52 | 				(F.mse_loss(current_Q1_aug, target_Q) + F.mse_loss(current_Q2_aug, target_Q))
53 | 
54 | 		if L is not None:
55 | 			L.log('train_critic/loss', critic_loss, step)
56 | 
57 | 		self.critic_optimizer.zero_grad()
58 | 		critic_loss.backward()
59 | 		self.critic_optimizer.step()
60 | 
61 | 	def update(self, replay_buffer, L, step):
62 | 		obs, action, reward, next_obs = replay_buffer.sample_drq(pad=6 if self.use_vit else 4)
63 | 
64 | 		self.update_critic(obs, action, reward, next_obs, L, step)
65 | 
66 | 		if step % self.actor_update_freq == 0:
67 | 			self.update_actor_and_alpha(obs, L, step)
68 | 
69 | 		if step % self.critic_target_update_freq == 0:
70 | 			self.soft_update_critic_target()
71 | 
--------------------------------------------------------------------------------
/src/algorithms/sveav2.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from copy import deepcopy
6 | import utils
7 | import augmentations
8 | import algorithms.modules as m
9 | from algorithms.sacv2 import SACv2
10 | 
11 | 
12 | class SVEAv2(SACv2):
13 | 	def __init__(self, obs_shape, action_shape, args):
14 | 		super().__init__(obs_shape, action_shape, args)
15 | 		self.svea_alpha = args.svea_alpha
16 | 		self.svea_beta = args.svea_beta
17 | 		self.svea_augmentation = args.svea_augmentation
18 | 		self.naive = args.naive
19 | 
20 | 	def augment(self, obs):
21 | 		if self.svea_augmentation == 'colorjitter':
22 | 			return augmentations.random_color_jitter(obs.clone())
23 | 		elif self.svea_augmentation == 'affine+colorjitter':
24 | 			return augmentations.random_color_jitter(augmentations.random_affine(obs.clone()))
25 | 		elif self.svea_augmentation == 'noise':
26 | 			return
augmentations.random_noise(obs.clone()) 27 | elif self.svea_augmentation == 'affine+noise': 28 | return augmentations.random_noise(augmentations.random_affine(obs.clone())) 29 | elif self.svea_augmentation == 'conv': 30 | return augmentations.random_conv(obs.clone()) 31 | elif self.svea_augmentation == 'affine+conv': 32 | return augmentations.random_conv(augmentations.random_affine(obs.clone())) 33 | elif self.svea_augmentation == 'overlay': 34 | return augmentations.random_overlay(obs.clone()) 35 | elif self.svea_augmentation == 'affine+overlay': 36 | return augmentations.random_overlay(augmentations.random_affine(obs.clone())) 37 | elif self.svea_augmentation == 'none': 38 | return obs 39 | else: 40 | raise NotImplementedError(f'Unsupported augmentation: {self.svea_augmentation}') 41 | 42 | def update_critic(self, obs, action, reward, next_obs, L=None, step=None): 43 | with torch.no_grad(): 44 | _, policy_action, log_pi, _ = self.actor(next_obs) 45 | target_Q1, target_Q2 = self.critic_target(next_obs, policy_action) 46 | target_V = torch.min(target_Q1, 47 | target_Q2) - self.alpha.detach() * log_pi 48 | target_Q = reward + (self.discount * target_V) 49 | 50 | if self.svea_augmentation != 'none' and not self.naive: 51 | action = utils.cat(action, action) 52 | target_Q = utils.cat(target_Q, target_Q) 53 | 54 | current_Q1, current_Q2 = self.critic(obs, action) 55 | critic_loss = (self.svea_alpha + self.svea_beta) * \ 56 | (F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q)) 57 | if L is not None: 58 | L.log('train_critic/loss', critic_loss, step) 59 | 60 | self.critic_optimizer.zero_grad(set_to_none=True) 61 | critic_loss.backward() 62 | self.critic_optimizer.step() 63 | 64 | def update(self, replay_buffer, L, step): 65 | if step % self.update_freq != 0: 66 | return 67 | 68 | obs, action, reward, next_obs = replay_buffer.sample() 69 | obs = self.aug(obs) # random shift 70 | if self.svea_augmentation != 'none': 71 | if self.naive: 72 | obs = self.augment(obs) # naively apply strong augmentation 73 | else: 74 | obs = utils.cat(obs, self.augment(obs)) # strong augmentation 75 | 76 | if self.multiview: 77 | obs = self.encoder(obs[:,:3,:,:], obs[:,3:6,:,:]) 78 | else: 79 | obs = self.encoder(obs) 80 | 81 | if self.svea_augmentation != 'none' and not self.naive: 82 | obs_unaug = obs[:obs.size(0)//2] # unaugmented observations 83 | else: 84 | obs_unaug = obs 85 | 86 | with torch.no_grad(): 87 | next_obs = self.aug(next_obs) 88 | if self.svea_augmentation != 'none' and self.naive: 89 | next_obs = self.augment(next_obs) # naively apply strong augmentation 90 | if self.multiview: 91 | next_obs = self.encoder(next_obs[:,:3,:,:], next_obs[:,3:6,:,:]) 92 | else: 93 | next_obs = self.encoder(next_obs) 94 | 95 | self.update_critic(obs, action, reward, next_obs, L, step) 96 | self.update_actor_and_alpha(obs_unaug.detach(), L, step) 97 | utils.soft_update_params(self.critic, self.critic_target, self.tau) 98 | -------------------------------------------------------------------------------- /src/env/__pycache__/wrappers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/__pycache__/wrappers.cpython-37.pyc -------------------------------------------------------------------------------- /src/env/robot/__pycache__/base.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/__pycache__/base.cpython-37.pyc -------------------------------------------------------------------------------- /src/env/robot/__pycache__/gym_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/__pycache__/gym_utils.cpython-37.pyc -------------------------------------------------------------------------------- /src/env/robot/__pycache__/push.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/__pycache__/push.cpython-37.pyc -------------------------------------------------------------------------------- /src/env/robot/__pycache__/registration.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/__pycache__/registration.cpython-37.pyc -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/base_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/base_link.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/base_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/base_link_collision.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/bellows_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/bellows_link_collision.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/elbow_flex_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/elbow_flex_link_collision.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/estop_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/estop_link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/forearm_roll_link_collision.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/forearm_roll_link_collision.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/gripper_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/gripper_link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/head_pan_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/head_pan_link_collision.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/head_tilt_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/head_tilt_link_collision.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/l_wheel_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/l_wheel_link_collision.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/laser_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/laser_link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/left_finger.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/left_finger.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/left_inner_knuckle.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/left_inner_knuckle.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/left_outer_knuckle.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/left_outer_knuckle.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/link1.STL: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/link1.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/link2.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/link2.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/link3.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/link3.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/link4.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/link4.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/link5.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/link5.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/link6.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/link6.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/link7.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/link7.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/link_base copy.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/link_base copy.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/link_base.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/link_base.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/r_wheel_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/r_wheel_link_collision.stl -------------------------------------------------------------------------------- 
/src/env/robot/assets/robot/fetch/right_finger.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/right_finger.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/right_inner_knuckle.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/right_inner_knuckle.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/right_outer_knuckle.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/right_outer_knuckle.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/shoulder_lift_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/shoulder_lift_link_collision.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/shoulder_pan_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/shoulder_pan_link_collision.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/torso_fixed_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/torso_fixed_link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/torso_lift_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/torso_lift_link_collision.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/upperarm_roll_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/upperarm_roll_link_collision.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/fetch/wrist_flex_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/wrist_flex_link_collision.stl -------------------------------------------------------------------------------- 
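Aside: a minimal, self-contained sketch of the batch-duplication trick used by SVEAv2.update_critic in src/algorithms/sveav2.py above: clean and strongly-augmented observations are stacked into one batch and regressed onto the same TD target. The toy critic, the random_conv stand-in, and all tensor shapes below are illustrative assumptions, not code from this repository.

import torch
import torch.nn as nn
import torch.nn.functional as F

def random_conv(x):
    # Stand-in for augmentations.random_conv: filter each frame with a freshly
    # sampled 3x3 convolution, one of the strong augmentations SVEA can use.
    weight = torch.randn(x.size(1), x.size(1), 3, 3, device=x.device)
    return torch.sigmoid(F.conv2d(x, weight, padding=1))

critic = nn.Linear(3 * 84 * 84 + 4, 1)  # toy Q(s, a) on flattened pixels

def q_value(obs, act):
    return critic(torch.cat([obs.flatten(1), act], dim=1))

obs = torch.rand(8, 3, 84, 84)   # clean observations
act = torch.rand(8, 4)           # actions from the replay buffer
target_q = torch.rand(8, 1)      # TD target computed from clean next_obs only

# SVEA: evaluate the critic on [clean; augmented] observations while reusing
# the same actions and the same targets for both halves of the batch.
obs_both = torch.cat([obs, random_conv(obs)], dim=0)
act_both = torch.cat([act, act], dim=0)
tgt_both = torch.cat([target_q, target_q], dim=0)

svea_alpha, svea_beta = 0.5, 0.5
critic_loss = (svea_alpha + svea_beta) * F.mse_loss(q_value(obs_both, act_both), tgt_both)
critic_loss.backward()  # gradients flow through both the clean and the augmented half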
/src/env/robot/assets/robot/fetch/wrist_roll_link_collision.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/wrist_roll_link_collision.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/golf.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | -------------------------------------------------------------------------------- /src/env/robot/assets/robot/hammer.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /src/env/robot/assets/robot/lift.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/arm/Base_Link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/arm/Base_Link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/arm/Bracelet_Link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/arm/Bracelet_Link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/arm/ForeArm_Link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/arm/ForeArm_Link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/arm/HalfArm1_Link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/arm/HalfArm1_Link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/arm/HalfArm2_Link.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/arm/HalfArm2_Link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/arm/Shoulder_Link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/arm/Shoulder_Link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/arm/SphericalWrist1_Link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/arm/SphericalWrist1_Link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/arm/SphericalWrist2_Link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/arm/SphericalWrist2_Link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq/inner_finger_coarse.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/inner_finger_coarse.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq/inner_knuckle_coarse.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/inner_knuckle_coarse.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq/kinova_robotiq_coupler.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/kinova_robotiq_coupler.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq/outer_finger_coarse.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/outer_finger_coarse.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq/outer_knuckle_coarse.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/outer_knuckle_coarse.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq/robotiq_85_base_link.stl: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/robotiq_85_base_link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq/robotiq_85_base_link_coarse.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/robotiq_85_base_link_coarse.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq/robotiq_85_finger_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/robotiq_85_finger_link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq/robotiq_85_finger_tip_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/robotiq_85_finger_tip_link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq/robotiq_85_inner_knuckle_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/robotiq_85_inner_knuckle_link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq/robotiq_85_knuckle_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/robotiq_85_knuckle_link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_base_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_base_link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_inner_finger.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_inner_finger.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_inner_finger_vis.stl: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_inner_finger_vis.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle_vis.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle_vis.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_outer_finger.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_outer_finger.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_outer_finger_vis.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_outer_finger_vis.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle_vis.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle_vis.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_pad_vis.dae: -------------------------------------------------------------------------------- 1 | 2 | 3 | 2016-07-17T22:25:43.361178 4 | 2016-07-17T22:25:43.361188 5 | Z_UP 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 0.0 0.0 0.0 1.0 14 | 15 | 16 | 0.0 0.0 0.0 1.0 17 | 18 | 19 | 0.7 0.7 0.7 1.0 20 | 21 | 22 | 1 1 1 1.0 23 | 24 | 25 | 0.0 26 | 27 | 28 | 0.0 0.0 0.0 1.0 29 | 30 | 31 | 0.0 32 | 33 | 34 | 0.0 0.0 0.0 1.0 35 | 36 | 37 | 1.0 38 | 39 | 40 | 41 | 42 | 43 | 0 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 4.38531e-14 -1 
105 | -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_pad_vis.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_pad_vis.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_base_link.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_base_link.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_gripper_coupling_vis.stl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_gripper_coupling_vis.stl -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/xarm/link1.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/xarm/link1.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/xarm/link2.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/xarm/link2.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/xarm/link3.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/xarm/link3.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/xarm/link4.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/xarm/link4.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/xarm/link5.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/xarm/link5.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/xarm/link6.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/xarm/link6.STL 
-------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/xarm/link7.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/xarm/link7.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/xarm/link_base.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/xarm/link_base.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/xarm_gripper/base_link.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/xarm_gripper/base_link.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/xarm_gripper/left_finger.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/xarm_gripper/left_finger.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/xarm_gripper/left_inner_knuckle.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/xarm_gripper/left_inner_knuckle.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/xarm_gripper/left_outer_knuckle.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/xarm_gripper/left_outer_knuckle.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/xarm_gripper/right_finger.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/xarm_gripper/right_finger.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/xarm_gripper/right_inner_knuckle.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/xarm_gripper/right_inner_knuckle.STL -------------------------------------------------------------------------------- /src/env/robot/assets/robot/mesh/xarm_gripper/right_outer_knuckle.STL: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/xarm_gripper/right_outer_knuckle.STL 
-------------------------------------------------------------------------------- /src/env/robot/assets/robot/pick_place.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /src/env/robot/assets/robot/push.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | -------------------------------------------------------------------------------- /src/env/robot/assets/robot/reach.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /src/env/robot/assets/robot/shelf_placing_far.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /src/env/robot/assets/robot/shelf_placing_near.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | -------------------------------------------------------------------------------- /src/env/robot/assets/robot/texture/block.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/texture/block.png -------------------------------------------------------------------------------- /src/env/robot/assets/robot/texture/block_hidden.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/texture/block_hidden.png -------------------------------------------------------------------------------- 
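Aside: the environment files dumped below (golf.py, push.py, peg_in_box.py, and the hammer tasks) all build their dense rewards from a goal distance plus a penalty on the commanded end-effector motion. The short numeric sketch below mirrors the dense branch of PushEnv.compute_reward with made-up values; the goal_distance helper is a stand-in for the BaseEnv method of the same name, and every number here is an assumption for illustration only.

import numpy as np

def goal_distance(a, b, use_xyz=False):
    # Stand-in for BaseEnv.goal_distance: Euclidean distance, optionally
    # ignoring the z axis when use_xyz is False.
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    if not use_xyz:
        a, b = a[:2], b[:2]
    return np.linalg.norm(a - b)

# Made-up object/goal positions and control magnitude, only to show the arithmetic.
object_pos = np.array([1.40, 0.32, 0.58])
goal_pos = np.array([1.60, 0.30, 0.58])
pos_ctrl_magnitude = 0.05

# PushEnv-style dense reward: scaled negative distance plus an action penalty,
# rounded to four decimals as in the environment code.
d = goal_distance(object_pos, goal_pos, use_xyz=False)
reward = np.around(-3 * d - 0.5 * np.square(pos_ctrl_magnitude), 4)
print(reward)  # approximately -0.6042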
/src/env/robot/assets/robot/texture/carpet-black.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/texture/carpet-black.png -------------------------------------------------------------------------------- /src/env/robot/assets/robot/texture/concrete.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/texture/concrete.png -------------------------------------------------------------------------------- /src/env/robot/assets/robot/texture/light_wood.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/texture/light_wood.png -------------------------------------------------------------------------------- /src/env/robot/assets/robot/texture/sponge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/texture/sponge.png -------------------------------------------------------------------------------- /src/env/robot/assets/robot/texture/table1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/texture/table1.png -------------------------------------------------------------------------------- /src/env/robot/golf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from gym import utils 4 | from env.robot.base import BaseEnv, get_full_asset_path 5 | 6 | 7 | class GolfXYEnv(BaseEnv, utils.EzPickle): 8 | def __init__(self, xml_path, n_substeps=20, observation_type='image', reward_type='dense', image_size=84): 9 | BaseEnv.__init__(self, 10 | get_full_asset_path(xml_path), 11 | n_substeps=n_substeps, 12 | observation_type=observation_type, 13 | reward_type=reward_type, 14 | image_size=image_size, 15 | reset_free=False, 16 | distance_threshold=0.035, 17 | use_xyz=False, 18 | action_scale=0.1, 19 | has_object=True 20 | ) 21 | utils.EzPickle.__init__(self) 22 | self.default_z_offset = 0.2 23 | 24 | def compute_reward(self, achieved_goal, goal, info): 25 | d = self.goal_distance(achieved_goal, goal, self.use_xyz) 26 | if self.reward_type == 'sparse': 27 | return -(d > self.distance_threshold).astype(np.float32) 28 | else: 29 | penalty = -self._pos_ctrl_magnitude * self.action_penalty 30 | if self.reward_bonus and d <= self.distance_threshold: 31 | return np.around(1-d, 4) + penalty 32 | return np.around(-d, 4) + penalty 33 | 34 | def _step_callback(self): 35 | pass 36 | 37 | def _get_achieved_goal(self): 38 | return np.squeeze(self.sim.data.get_site_xpos('object0').copy()) 39 | 40 | def _limit_gripper(self, gripper_pos, pos_ctrl): 41 | if gripper_pos[0] > 1.4: 42 | pos_ctrl[0] = min(pos_ctrl[0], 0) 43 | if gripper_pos[0] < 1.2: 44 | pos_ctrl[0] = max(pos_ctrl[0], 0) 45 | if gripper_pos[1] > 0.06: 46 | pos_ctrl[1] = min(pos_ctrl[1], 0) 47 | if gripper_pos[1] < -0.2: 48 | pos_ctrl[1] = max(pos_ctrl[1], 0) 49 | 
return pos_ctrl 50 | 51 | def _sample_object_pos(self): 52 | object_xpos = self.center_of_table.copy() 53 | object_xpos[0] += self.np_random.uniform(-0.04, 0.04, size=1) 54 | object_xpos[1] += self.np_random.uniform(-0.04, 0.04, size=1) 55 | object_xpos[2] += 0.025 56 | object_qpos = self.sim.data.get_joint_qpos('object0:joint') 57 | object_quat = object_qpos[-4:] 58 | object_quat[0] = self.np_random.uniform(-1, 1, size=1) 59 | object_quat[3] = self.np_random.uniform(-1, 1, size=1) 60 | 61 | assert object_qpos.shape == (7,) 62 | object_qpos[:3] = object_xpos[:3] 63 | object_qpos[-4:] = object_quat 64 | self.sim.data.set_joint_qpos('object0:joint', object_qpos) 65 | 66 | def _sample_goal(self): 67 | goal = np.array([1.3, 0.4, 0.7]) 68 | goal[1] += self.np_random.uniform(-0.05, 0.025, size=1) 69 | return BaseEnv._sample_goal(self, goal) 70 | 71 | def _sample_initial_pos(self): 72 | gripper_target = self.center_of_table.copy() 73 | gripper_target[0] += self.np_random.uniform(-0.06, 0.06, size=1) 74 | gripper_target[1] += self.np_random.uniform(-0.2, -0.15, size=1) 75 | gripper_target[2] += self.default_z_offset 76 | BaseEnv._sample_initial_pos(self, gripper_target) 77 | -------------------------------------------------------------------------------- /src/env/robot/gym_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from gym import error 4 | try: 5 | import mujoco_py 6 | except ImportError as e: 7 | raise error.DependencyNotInstalled("{}. (HINT: you need to install mujoco_py, and also perform the setup instructions here: https://github.com/openai/mujoco-py/.)".format(e)) 8 | 9 | 10 | def robot_get_obs(sim): 11 | """Returns all joint positions and velocities associated with 12 | a robot. 13 | """ 14 | if sim.data.qpos is not None and sim.model.joint_names: 15 | names = [n for n in sim.model.joint_names if n.startswith('robot')] 16 | return ( 17 | np.array([sim.data.get_joint_qpos(name) for name in names]), 18 | np.array([sim.data.get_joint_qvel(name) for name in names]), 19 | ) 20 | return np.zeros(0), np.zeros(0) 21 | 22 | 23 | def ctrl_set_action(sim, action): 24 | """For torque actuators it copies the action into the mujoco ctrl field. 25 | For position actuators it sets the target relative to the current qpos. 26 | """ 27 | if sim.model.nmocap > 0: 28 | _, action = np.split(action, (sim.model.nmocap * 7, )) 29 | if sim.data.ctrl is not None: 30 | for i in range(action.shape[0]): 31 | if sim.model.actuator_biastype[i] == 0: 32 | sim.data.ctrl[i] = action[i] 33 | else: 34 | idx = sim.model.jnt_qposadr[sim.model.actuator_trnid[i, 0]] 35 | sim.data.ctrl[i] = sim.data.qpos[idx] + action[i] 36 | 37 | 38 | def mocap_set_action(sim, action): 39 | """The action controls the robot using mocaps. Specifically, bodies 40 | on the robot (for example the gripper wrist) are controlled with 41 | mocap bodies. In this case the action is the desired difference 42 | in position and orientation (quaternion), in world coordinates, 43 | of the target body. The mocap is positioned relative to 44 | the target body according to the delta, and the MuJoCo equality 45 | constraint optimizer tries to center the welded body on the mocap.
46 | """ 47 | if sim.model.nmocap > 0: 48 | action, _ = np.split(action, (sim.model.nmocap * 7, )) 49 | action = action.reshape(sim.model.nmocap, 7) 50 | 51 | pos_delta = action[:, :3] 52 | quat_delta = action[:, 3:] 53 | 54 | reset_mocap2body_xpos(sim) 55 | sim.data.mocap_pos[:] = sim.data.mocap_pos + pos_delta 56 | sim.data.mocap_quat[:] = sim.data.mocap_quat + quat_delta 57 | 58 | 59 | def reset_mocap_welds(sim): 60 | """Resets the mocap welds that we use for actuation. 61 | """ 62 | if sim.model.nmocap > 0 and sim.model.eq_data is not None: 63 | for i in range(sim.model.eq_data.shape[0]): 64 | if sim.model.eq_type[i] == mujoco_py.const.EQ_WELD: 65 | sim.model.eq_data[i, :] = np.array( 66 | [0., 0., 0., 1., 0., 0., 0.]) 67 | sim.forward() 68 | 69 | 70 | def reset_mocap2body_xpos(sim): 71 | """Resets the position and orientation of the mocap bodies to the same 72 | values as the bodies they're welded to. 73 | """ 74 | 75 | if (sim.model.eq_type is None or 76 | sim.model.eq_obj1id is None or 77 | sim.model.eq_obj2id is None): 78 | return 79 | 80 | # For all weld constraints 81 | for eq_type, obj1_id, obj2_id in zip(sim.model.eq_type, 82 | sim.model.eq_obj1id, 83 | sim.model.eq_obj2id): 84 | if eq_type != mujoco_py.const.EQ_WELD: 85 | continue 86 | 87 | body2 = sim.model.body_id2name(obj2_id) 88 | if body2 == 'B0' or body2== 'B9' or body2 == 'B1': 89 | continue 90 | 91 | mocap_id = sim.model.body_mocapid[obj1_id] 92 | 93 | if mocap_id != -1: 94 | # obj1 is the mocap, obj2 is the welded body 95 | body_idx = obj2_id 96 | else: 97 | # obj2 is the mocap, obj1 is the welded body 98 | mocap_id = sim.model.body_mocapid[obj2_id] 99 | body_idx = obj1_id 100 | 101 | assert (mocap_id != -1) 102 | sim.data.mocap_pos[mocap_id][:] = sim.data.body_xpos[body_idx] 103 | sim.data.mocap_quat[mocap_id][:] = sim.data.body_xquat[body_idx] 104 | -------------------------------------------------------------------------------- /src/env/robot/hammer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from gym import utils 4 | from env.robot.base import BaseEnv, get_full_asset_path 5 | 6 | ''' 7 | Not completely updated, consider looking at hammer_all task for latest update 8 | ''' 9 | 10 | class HammerEnv(BaseEnv, utils.EzPickle): 11 | def __init__(self, xml_path, cameras, n_substeps=20, observation_type='image', reward_type='dense', image_size=84, use_xyz=False, render=False): 12 | BaseEnv.__init__(self, 13 | get_full_asset_path(xml_path), 14 | n_substeps=n_substeps, 15 | observation_type=observation_type, 16 | reward_type=reward_type, 17 | image_size=image_size, 18 | reset_free=False, 19 | cameras=cameras, 20 | render=render, 21 | use_xyz=use_xyz, 22 | has_object=True 23 | ) 24 | self.state_dim = (26,) if self.use_xyz else (20,) 25 | utils.EzPickle.__init__(self) 26 | 27 | def compute_reward(self, achieved_goal, goal, info): 28 | eef_pos = self.sim.data.get_site_xpos('tool').copy() 29 | object_pos = self.sim.data.get_site_xpos('nail_target').copy() 30 | gripper_angle = self.sim.data.get_joint_qpos('right_outer_knuckle_joint').copy() 31 | goal_pos = goal.copy() 32 | d_eef_obj = self.goal_distance(eef_pos, object_pos, self.use_xyz) 33 | d_eef_obj_xy = self.goal_distance(eef_pos, object_pos, use_xyz = False) 34 | #d_obj_goal_xy = self.goal_distance(object_pos, goal_pos, use_xyz=False) 35 | d_obj_goal_z = np.abs(object_pos[2] - goal_pos[2]) 36 | eef_z = eef_pos[2] - self.center_of_table.copy()[2] - self.default_z_offset 37 | obj_z = 
object_pos[2] - self.center_of_table.copy()[2] - self.default_z_offset 38 | 39 | reward = -0.1*np.square(self._pos_ctrl_magnitude) # action penalty 40 | # Get end effector close to the nail 41 | reward += -2 * d_eef_obj 42 | 43 | if d_obj_goal_z<=self.distance_threshold: 44 | reward += -2 * d_obj_goal_z 45 | 46 | return reward 47 | 48 | def _get_state_obs(self): 49 | cot_pos = self.center_of_table.copy() 50 | dt = self.sim.nsubsteps * self.sim.model.opt.timestep 51 | 52 | eef_pos = self.sim.data.get_site_xpos('tool') #- cot_pos 53 | eef_velp = self.sim.data.get_site_xvelp('tool') * dt 54 | goal_pos = self.goal# - cot_pos 55 | gripper_angle = self.sim.data.get_joint_qpos('right_outer_knuckle_joint') 56 | 57 | obj_pos = self.sim.data.get_site_xpos('nail_target') #- cot_pos 58 | obj_rot = self.sim.data.get_joint_qpos('nail_board:joint')[-4:] 59 | obj_velp = self.sim.data.get_site_xvelp('nail_target') * dt 60 | obj_velr = self.sim.data.get_site_xvelr('nail_target') * dt 61 | 62 | if not self.use_xyz: 63 | eef_pos = eef_pos[:2] 64 | eef_velp = eef_velp[:2] 65 | goal_pos = goal_pos[:2] 66 | obj_pos = obj_pos[:2] 67 | obj_velp = obj_velp[:2] 68 | obj_velr = obj_velr[:2] 69 | 70 | values = np.array([ 71 | self.goal_distance(eef_pos, goal_pos, self.use_xyz), 72 | self.goal_distance(obj_pos, goal_pos, self.use_xyz), 73 | self.goal_distance(eef_pos, obj_pos, self.use_xyz), 74 | gripper_angle 75 | ]) 76 | 77 | return np.concatenate([ 78 | eef_pos, eef_velp, goal_pos, obj_pos, obj_rot, obj_velp, obj_velr, values 79 | ], axis=0) 80 | 81 | def _reset_sim(self): 82 | 83 | return BaseEnv._reset_sim(self) 84 | 85 | def _get_achieved_goal(self): 86 | return np.squeeze(self.sim.data.get_site_xpos('nail_target').copy()) 87 | 88 | def _sample_object_pos(self): 89 | return None 90 | 91 | def _sample_goal(self): 92 | # Goal is the peg goal site position 93 | object_qpos = self.sim.data.get_joint_qpos('nail_board:joint') 94 | sampled = object_qpos[:3].copy() 95 | object_quat = object_qpos[-4:] 96 | sampled[0] += self.np_random.uniform(-0.05, 0.05, size=1) 97 | sampled[1] += self.np_random.uniform(-0.1, 0.1, size=1) 98 | object_qpos[:3] = sampled[:3].copy() 99 | object_qpos[-4:] = object_quat 100 | self.sim.data.set_joint_qpos('nail_board:joint', object_qpos) 101 | 102 | peg_site_xpos = self.sim.data.get_site_xpos('nail_goal') 103 | goal = peg_site_xpos.copy() 104 | 105 | return BaseEnv._sample_goal(self, goal) 106 | 107 | def _sample_initial_pos(self): 108 | gripper_target = self.center_of_table.copy() - np.array([0.3, 0, 0]) 109 | gripper_target[0] += self.np_random.uniform(-0.15, -0.05, size=1) 110 | gripper_target[1] += self.np_random.uniform(-0.05, 0.05, size=1) 111 | gripper_target[2] += self.default_z_offset 112 | if self.use_xyz: 113 | gripper_target[2] += self.np_random.uniform(0, 0.1, size=1) 114 | BaseEnv._sample_initial_pos(self, gripper_target) 115 | -------------------------------------------------------------------------------- /src/env/robot/hammer_all.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from gym import utils 4 | from env.robot.base import BaseEnv, get_full_asset_path 5 | 6 | 7 | class HammerAllEnv(BaseEnv, utils.EzPickle): 8 | def __init__(self, xml_path, cameras, n_substeps=20, observation_type='image', reward_type='dense', image_size=84, use_xyz=False, render=False): 9 | BaseEnv.__init__(self, 10 | get_full_asset_path(xml_path), 11 | n_substeps=n_substeps, 12 | observation_type=observation_type, 13 | 
reward_type=reward_type, 14 | image_size=image_size, 15 | reset_free=False, 16 | cameras=cameras, 17 | render=render, 18 | use_xyz=use_xyz, 19 | has_object=True 20 | ) 21 | self.state_dim = (26,) if self.use_xyz else (20,) 22 | utils.EzPickle.__init__(self) 23 | 24 | def compute_reward(self, achieved_goal, goal, info): 25 | eef_pos = self.sim.data.get_site_xpos('tool').copy() 26 | object_pos = self.sim.data.get_site_xpos('nail_target'+ str(self.nail_id)).copy() 27 | goal_pos = goal.copy() 28 | d_eef_obj = self.goal_distance(eef_pos, object_pos, self.use_xyz) 29 | d_obj_goal_z = self.goal_distance(object_pos, goal_pos, use_xyz=True) 30 | #d_obj_goal_z = np.abs(object_pos[2] - goal_pos[2]) 31 | 32 | reward = -0.1*np.square(self._pos_ctrl_magnitude) # action penalty 33 | # Get end effector close to the nail 34 | reward += -2 * d_eef_obj 35 | 36 | reward += -2 * d_obj_goal_z 37 | 38 | return reward 39 | 40 | def _get_state_obs(self): 41 | cot_pos = self.center_of_table.copy() 42 | dt = self.sim.nsubsteps * self.sim.model.opt.timestep 43 | 44 | eef_pos = self.sim.data.get_site_xpos('tool') #- cot_pos 45 | eef_velp = self.sim.data.get_site_xvelp('tool') * dt 46 | goal_pos = self.goal# - cot_pos 47 | gripper_angle = self.sim.data.get_joint_qpos('right_outer_knuckle_joint') 48 | 49 | obj_pos = self.sim.data.get_site_xpos('nail_target'+ str(self.nail_id)) #- cot_pos 50 | obj_rot = self.sim.data.get_joint_qpos('nail_board:joint')[-4:] 51 | obj_velp = self.sim.data.get_site_xvelp('nail_target'+ str(self.nail_id)) * dt 52 | obj_velr = self.sim.data.get_site_xvelr('nail_target'+ str(self.nail_id)) * dt 53 | 54 | if not self.use_xyz: 55 | eef_pos = eef_pos[:2] 56 | eef_velp = eef_velp[:2] 57 | goal_pos = goal_pos[:2] 58 | obj_pos = obj_pos[:2] 59 | obj_velp = obj_velp[:2] 60 | obj_velr = obj_velr[:2] 61 | 62 | values = np.array([ 63 | self.goal_distance(eef_pos, goal_pos, self.use_xyz), 64 | self.goal_distance(obj_pos, goal_pos, self.use_xyz), 65 | self.goal_distance(eef_pos, obj_pos, self.use_xyz), 66 | gripper_angle 67 | ]) 68 | 69 | return np.concatenate([ 70 | eef_pos, eef_velp, goal_pos, obj_pos, obj_rot, obj_velp, obj_velr, values 71 | ], axis=0) 72 | 73 | def _reset_sim(self): 74 | 75 | return BaseEnv._reset_sim(self) 76 | 77 | def _is_success(self, achieved_goal, desired_goal): 78 | d = self.goal_distance(achieved_goal, desired_goal, self.use_xyz) 79 | return (d < 0.01).astype(np.float32) 80 | 81 | def _get_achieved_goal(self): 82 | return np.squeeze(self.sim.data.get_site_xpos('nail_target'+ str(self.nail_id)).copy()) 83 | 84 | def _sample_object_pos(self): 85 | return None 86 | 87 | 88 | def _sample_goal(self, new=True): 89 | # Goal is the peg goal site position 90 | 91 | # Randomly sample the position of the nail box 92 | object_qpos = self.sim.data.get_joint_qpos('nail_board:joint') 93 | sampled = object_qpos[:3].copy() 94 | object_quat = object_qpos[-4:] 95 | if new: 96 | sampled[0] += self.np_random.uniform(-0.05, 0.05, size=1) 97 | sampled[1] += self.np_random.uniform(-0.1, 0.1, size=1) 98 | 99 | object_qpos[:3] = sampled[:3].copy() 100 | object_qpos[-4:] = object_quat 101 | self.sim.data.set_joint_qpos('nail_board:joint', object_qpos) 102 | 103 | if new: 104 | # Select 1 of the nails randomly 105 | self.nail_id = np.random.randint(4) + 1 106 | 107 | nail_qpos = self.sim.data.get_joint_qpos('nail_dir'+ str(self.nail_id)) 108 | self.sim.data.set_joint_qpos('nail_dir'+ str(self.nail_id), -0.046) 109 | 110 | peg_site_xpos = self.sim.data.get_site_xpos('nail_goal'+ str(self.nail_id)) 111 | 
goal = peg_site_xpos.copy() 112 | 113 | 114 | return BaseEnv._sample_goal(self, goal) 115 | 116 | def _sample_initial_pos(self): 117 | gripper_target = np.array([1.2561169, 0.3, 0.69603332]) 118 | gripper_target[0] += self.np_random.uniform(-0.05, 0.1, size=1) 119 | gripper_target[1] += self.np_random.uniform(-0.1, 0.1, size=1) 120 | if self.use_xyz: 121 | gripper_target[2] += self.np_random.uniform(-0.05, 0.1, size=1) 122 | BaseEnv._sample_initial_pos(self, gripper_target) 123 | -------------------------------------------------------------------------------- /src/env/robot/peg_in_box.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from gym import utils 4 | from env.robot.base import BaseEnv, get_full_asset_path 5 | 6 | 7 | class PegBoxEnv(BaseEnv, utils.EzPickle): 8 | def __init__(self, xml_path, cameras, n_substeps=20, observation_type='image', reward_type='dense', image_size=84, use_xyz=False, render=False): 9 | self.sample_large = 1 10 | BaseEnv.__init__(self, 11 | get_full_asset_path(xml_path), 12 | n_substeps=n_substeps, 13 | observation_type=observation_type, 14 | reward_type=reward_type, 15 | image_size=image_size, 16 | reset_free=False, 17 | cameras=cameras, 18 | render=render, 19 | use_xyz=use_xyz, 20 | has_object=True 21 | ) 22 | self.state_dim = (26,) if self.use_xyz else (20,) 23 | self.distance_threshold = 0.08 24 | self.distance_threshold_2 = 0.05 25 | self.distance_threshold_3 = 0.15 26 | 27 | utils.EzPickle.__init__(self) 28 | 29 | def _is_success(self, achieved_goal, desired_goal): 30 | 31 | achieved_goal = self.sim.data.get_site_xpos('object0').copy() 32 | 33 | return BaseEnv._is_success(self, achieved_goal, desired_goal) 34 | 35 | def _is_success_2(self, desired_goal): 36 | 37 | achieved_goal = self.sim.data.get_site_xpos('object0').copy() 38 | 39 | return BaseEnv._is_success_2(self, achieved_goal, desired_goal) 40 | 41 | def _is_success_3(self, desired_goal): 42 | 43 | achieved_goal = self.sim.data.get_site_xpos('object0').copy() 44 | 45 | return BaseEnv._is_success_3(self, achieved_goal, desired_goal) 46 | 47 | 48 | def _get_distance(self, desired_goal): 49 | 50 | achieved_goal = self.sim.data.get_site_xpos('object0') 51 | 52 | return BaseEnv.goal_distance(self, achieved_goal, desired_goal, self.use_xyz) 53 | 54 | def step(self, action): 55 | 56 | obs, env_reward, done, info = super(PegBoxEnv, self).step(action) 57 | 58 | info['success_rate_05'] = self._is_success_2(self.goal) 59 | info['success_rate_15'] = self._is_success_3(self.goal) 60 | info['distance'] = self._get_distance(self.goal) 61 | 62 | return obs, env_reward, done, info 63 | 64 | 65 | def compute_reward(self, achieved_goal, goal, info): 66 | object_pos = self.sim.data.get_site_xpos('object0').copy() 67 | goal_pos = goal.copy() 68 | 69 | d_obj_goal_xy = self.goal_distance(object_pos, goal_pos, use_xyz=False) 70 | d_obj_goal_xyz = self.goal_distance(object_pos, goal_pos, use_xyz=True) 71 | 72 | obj_z = object_pos[2] - self.center_of_table.copy()[2] 73 | 74 | reward = -1*np.square(self._pos_ctrl_magnitude) # action penalty 75 | 76 | # Staged rewards 77 | reward += -4*d_obj_goal_xy # move towards box 78 | 79 | if d_obj_goal_xy<=0.05: 80 | reward += 10-20*d_obj_goal_xyz # place object in box 81 | 82 | return reward 83 | 84 | def _get_state_obs(self): 85 | cot_pos = self.center_of_table.copy() 86 | dt = self.sim.nsubsteps * self.sim.model.opt.timestep 87 | 88 | eef_pos = self.sim.data.get_site_xpos('grasp') 89 | eef_velp = 
self.sim.data.get_site_xvelp('grasp') * dt 90 | goal_pos = self.goal 91 | gripper_angle = self.sim.data.get_joint_qpos('right_outer_knuckle_joint') 92 | 93 | obj_pos = self.sim.data.get_site_xpos('object0') 94 | obj_rot = self.sim.data.get_joint_qpos('object0:joint')[-4:] 95 | obj_velp = self.sim.data.get_site_xvelp('object0') * dt 96 | obj_velr = self.sim.data.get_site_xvelr('object0') * dt 97 | 98 | if not self.use_xyz: 99 | eef_pos = eef_pos[:2] 100 | eef_velp = eef_velp[:2] 101 | goal_pos = goal_pos[:2] 102 | obj_pos = obj_pos[:2] 103 | obj_velp = obj_velp[:2] 104 | obj_velr = obj_velr[:2] 105 | 106 | values = np.array([ 107 | self.goal_distance(eef_pos, goal_pos, self.use_xyz), 108 | self.goal_distance(obj_pos, goal_pos, self.use_xyz), 109 | self.goal_distance(eef_pos, obj_pos, self.use_xyz), 110 | gripper_angle 111 | ]) 112 | 113 | return np.concatenate([ 114 | eef_pos, eef_velp, goal_pos, obj_pos, obj_rot, obj_velp, obj_velr, values 115 | ], axis=0) 116 | 117 | def _reset_sim(self): 118 | return BaseEnv._reset_sim(self) 119 | 120 | def _get_achieved_goal(self): 121 | return np.squeeze(self.sim.data.get_site_xpos('object0').copy()) 122 | 123 | def _sample_object_pos(self): # Object pos similar to gripper 124 | object_qpos = self.sim.data.get_joint_qpos('object0:joint') 125 | object_quat = object_qpos[-4:] 126 | 127 | object_qpos[0:3] = self.gripper_target[0:3] 128 | object_qpos[2] += -0.08 129 | object_qpos[-4:] = object_quat.copy() 130 | 131 | self.sim.data.set_joint_qpos('object0:joint', object_qpos) 132 | 133 | 134 | def _sample_goal(self, new=True): 135 | object_qpos = self.sim.data.get_joint_qpos('box_hole:joint') 136 | object_quat = object_qpos[-4:] 137 | 138 | if new: 139 | goal = np.array([1.605, 0.3, 0.62]) 140 | goal[0] += self.np_random.uniform(-0.05 - 0.05 * self.sample_large, 0.05 + 0.05 * self.sample_large, size=1) 141 | goal[1] += self.np_random.uniform(-0.1 - 0.1 * self.sample_large, 0.1 + 0.1 * self.sample_large, size=1) 142 | else: 143 | goal = object_qpos[:3].copy() 144 | 145 | object_qpos[:3] = goal[:3].copy() 146 | object_qpos[-4:] = object_quat 147 | 148 | self.sim.data.set_joint_qpos('box_hole:joint', object_qpos) 149 | goal[1] += 0.075 150 | goal[2] -= 0.035 151 | self.lift_height = 0.15 152 | 153 | return BaseEnv._sample_goal(self, goal) 154 | 155 | def _sample_initial_pos(self): 156 | gripper_target = np.array([1.2561169, 0.3, 0.69603332]) 157 | gripper_target[0] += self.np_random.uniform(-0.05, 0.1, size=1) 158 | gripper_target[1] += self.np_random.uniform(-0.1, 0.1, size=1) 159 | gripper_target[2] += 0.1 160 | if self.use_xyz: 161 | gripper_target[2] += self.np_random.uniform(-0.05, 0.05, size=1) 162 | self.gripper_target = gripper_target 163 | BaseEnv._sample_initial_pos(self, gripper_target) -------------------------------------------------------------------------------- /src/env/robot/push.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from gym import utils 4 | from env.robot.base import BaseEnv, get_full_asset_path 5 | 6 | 7 | 8 | class PushEnv(BaseEnv, utils.EzPickle): 9 | def __init__(self, xml_path, cameras, n_substeps=20, observation_type='image', reward_type='dense', image_size=84, use_xyz=False, render=False): 10 | self.sample_large = 1 11 | BaseEnv.__init__(self, 12 | get_full_asset_path(xml_path), 13 | n_substeps=n_substeps, 14 | observation_type=observation_type, 15 | reward_type=reward_type, 16 | image_size=image_size, 17 | reset_free=False, 18 | cameras=cameras, 
19 | render=render, 20 | use_xyz=use_xyz, 21 | has_object=True 22 | ) 23 | self.state_dim = (26,) if self.use_xyz else (20,) 24 | self.max_z = 0.9 25 | 26 | self.distance_threshold = 0.1 27 | self.distance_threshold_2 = 0.05 28 | 29 | utils.EzPickle.__init__(self) 30 | 31 | def compute_reward(self, achieved_goal, goal, info): 32 | object_goal = self.sim.data.get_site_xpos('object0').copy() 33 | d = self.goal_distance(object_goal, goal, self.use_xyz) 34 | if self.reward_type == 'sparse': 35 | return -(d > self.distance_threshold).astype(np.float32) 36 | else: 37 | return np.around(-3*d - 0.5*np.square(self._pos_ctrl_magnitude), 4) 38 | 39 | def _get_state_obs(self): 40 | cot_pos = self.center_of_table.copy() 41 | dt = self.sim.nsubsteps * self.sim.model.opt.timestep 42 | 43 | eef_pos = self.sim.data.get_site_xpos('grasp') 44 | eef_velp = self.sim.data.get_site_xvelp('grasp') * dt 45 | goal_pos = self.goal 46 | gripper_angle = self.sim.data.get_joint_qpos('right_outer_knuckle_joint') 47 | 48 | obj_pos = self.sim.data.get_site_xpos('object0') 49 | obj_rot = self.sim.data.get_joint_qpos('object0:joint')[-4:] 50 | obj_velp = self.sim.data.get_site_xvelp('object0') * dt 51 | obj_velr = self.sim.data.get_site_xvelr('object0') * dt 52 | 53 | if not self.use_xyz: 54 | eef_pos = eef_pos[:2] 55 | eef_velp = eef_velp[:2] 56 | goal_pos = goal_pos[:2] 57 | obj_pos = obj_pos[:2] 58 | obj_velp = obj_velp[:2] 59 | obj_velr = obj_velr[:2] 60 | 61 | values = np.array([ 62 | self.goal_distance(eef_pos, goal_pos, self.use_xyz), 63 | self.goal_distance(obj_pos, goal_pos, self.use_xyz), 64 | self.goal_distance(eef_pos, obj_pos, self.use_xyz), 65 | gripper_angle 66 | ]) 67 | 68 | return np.concatenate([ 69 | eef_pos, eef_velp, goal_pos, obj_pos, obj_rot, obj_velp, obj_velr, values 70 | ], axis=0) 71 | 72 | def _get_achieved_goal(self): 73 | return np.squeeze(self.sim.data.get_site_xpos('object0').copy()) 74 | 75 | def _sample_object_pos(self): 76 | object_xpos = self.center_of_table.copy() - np.array([0.3, 0, 0]) 77 | 78 | object_xpos[0] += self.np_random.uniform(-0.05, 0.05 + 0.15 * self.sample_large, size=1) 79 | object_xpos[1] += self.np_random.uniform(-0.1 - 0.1 * self.sample_large, 0.1 + 0.1 * self.sample_large, size=1) 80 | object_xpos[2] += 0.08 81 | 82 | object_qpos = self.sim.data.get_joint_qpos('object0:joint') 83 | object_quat = object_qpos[-4:] 84 | 85 | assert object_qpos.shape == (7,) 86 | object_qpos[:3] = object_xpos[:3] 87 | object_qpos[-4:] = object_quat 88 | self.sim.data.set_joint_qpos('object0:joint', object_qpos) 89 | 90 | def _sample_goal(self, new=True): 91 | site_id = self.sim.model.site_name2id('target0') 92 | if new: 93 | goal = np.array([1.605, 0.3, 0.58]) 94 | goal[0] += self.np_random.uniform(-0.05 - 0.05 * self.sample_large, 0.05 + 0.05 * self.sample_large, size=1) 95 | goal[1] += self.np_random.uniform(-0.1 - 0.1 * self.sample_large, 0.1 + 0.1 * self.sample_large, size=1) 96 | else: 97 | goal = self.sim.data.get_site_xpos('target0') 98 | 99 | self.sim.model.site_pos[site_id] = goal 100 | self.sim.forward() 101 | 102 | return BaseEnv._sample_goal(self, goal) 103 | 104 | def _sample_initial_pos(self): 105 | gripper_target = np.array([1.2561169, 0.3, 0.62603332]) 106 | gripper_target[0] += self.np_random.uniform(-0.05, 0.1, size=1) 107 | gripper_target[1] += self.np_random.uniform(-0.1, 0.1, size=1) 108 | if self.use_xyz: 109 | gripper_target[2] += self.np_random.uniform(-0.05, 0.1, size=1) 110 | BaseEnv._sample_initial_pos(self, gripper_target) 111 | 112 | def _is_success(self, 
achieved_goal, desired_goal): 113 | achieved_goal = self.sim.data.get_site_xpos('object0') 114 | return BaseEnv._is_success(self, achieved_goal, desired_goal) 115 | 116 | def _is_success_2(self, desired_goal): 117 | achieved_goal = self.sim.data.get_site_xpos('object0') 118 | 119 | return BaseEnv._is_success_2(self, achieved_goal, desired_goal) 120 | 121 | def _get_distance(self, desired_goal): 122 | 123 | achieved_goal = self.sim.data.get_site_xpos('object0') 124 | 125 | return BaseEnv.goal_distance(self, achieved_goal, desired_goal, self.use_xyz) 126 | 127 | 128 | def step(self, action): 129 | 130 | obs, env_reward, done, info = super(PushEnv, self).step(action) 131 | info['success_rate_05'] = self._is_success_2(self.goal) 132 | info['distance'] = self._get_distance(self.goal) 133 | 134 | return obs, env_reward, done, info 135 | 136 | class PushNoGoalEnv(PushEnv): 137 | def __init__(self, *args, **kwargs): 138 | super().__init__(*args, **kwargs) 139 | 140 | def compute_reward(self, achieved_goal, goal, info): 141 | object_goal = self.sim.data.get_site_xpos('object0').copy() 142 | nogoal_goal = self.table_xpos.copy() 143 | nogoal_goal[0] += 0.2 144 | d = np.abs(object_goal[0] - nogoal_goal[0]) 145 | return np.around(-d, 4) 146 | 147 | def _sample_goal(self): 148 | goal = np.array([-10., -10., 0.]) 149 | self._pos_ctrl_magnitude = 0 # do not penalize at start of episode 150 | return goal 151 | -------------------------------------------------------------------------------- /src/env/robot/reach.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from gym import utils 4 | from env.robot.base import BaseEnv, get_full_asset_path 5 | 6 | 7 | 8 | class ReachEnv(BaseEnv, utils.EzPickle): 9 | def __init__(self, xml_path, cameras, n_substeps=20, observation_type='image', reward_type='dense', image_size=84, use_xyz=False, render=False): 10 | self.sample_large = 1 11 | BaseEnv.__init__(self, 12 | get_full_asset_path(xml_path), 13 | n_substeps=n_substeps, 14 | observation_type=observation_type, 15 | reward_type=reward_type, 16 | image_size=image_size, 17 | reset_free=False, 18 | cameras=cameras, 19 | render=render, 20 | use_xyz=use_xyz 21 | ) 22 | self.state_dim = (11,) if self.use_xyz else (8,) 23 | utils.EzPickle.__init__(self) 24 | 25 | def compute_reward(self, achieved_goal, goal, info): 26 | d = self.goal_distance(achieved_goal, goal, self.use_xyz) 27 | if self.reward_type == 'sparse': 28 | return -(d > self.distance_threshold).astype(np.float32) 29 | else: 30 | return np.around(-3*d - 0.5*np.square(self._pos_ctrl_magnitude), 4) 31 | 32 | def _get_state_obs(self): 33 | dt = self.sim.nsubsteps * self.sim.model.opt.timestep 34 | 35 | eef_pos = self.sim.data.get_site_xpos('grasp') 36 | eef_velp = self.sim.data.get_site_xvelp('grasp') * dt 37 | goal_pos = self.goal 38 | gripper_angle = self.sim.data.get_joint_qpos('right_outer_knuckle_joint') 39 | 40 | if not self.use_xyz: 41 | eef_pos = eef_pos[:2] 42 | eef_velp = eef_velp[:2] 43 | goal_pos = goal_pos[:2] 44 | 45 | values = np.array([ 46 | self.goal_distance(eef_pos, goal_pos, self.use_xyz), 47 | gripper_angle 48 | ]) 49 | 50 | return np.concatenate([ 51 | eef_pos, eef_velp, goal_pos, values 52 | ], axis=0) 53 | 54 | def _get_achieved_goal(self): 55 | return self.sim.data.get_site_xpos('grasp').copy() 56 | 57 | def _sample_goal(self, new=True): 58 | site_id = self.sim.model.site_name2id('target0') 59 | 60 | if new: 61 | goal = np.array([1.605, 0.3, 0.58]) 62 | goal[0] += 
self.np_random.uniform(-0.05 - 0.05 * self.sample_large, 0.05 + 0.05 * self.sample_large, size=1) 63 | goal[1] += self.np_random.uniform(-0.1 - 0.1 * self.sample_large, 0.1 + 0.1 * self.sample_large, size=1) 64 | else: 65 | goal = self.sim.data.get_site_xpos('target0') 66 | 67 | 68 | self.sim.model.site_pos[site_id] = goal 69 | self.sim.forward() 70 | 71 | return BaseEnv._sample_goal(self, goal) 72 | 73 | def _sample_initial_pos(self): 74 | gripper_target = np.array([1.2561169, 0.3, 0.62603332]) 75 | gripper_target[0] += self.np_random.uniform(-0.05, 0.1, size=1) 76 | gripper_target[1] += self.np_random.uniform(-0.1, 0.1, size=1) 77 | if self.use_xyz: 78 | gripper_target[2] += self.np_random.uniform(-0.05, 0.1, size=1) 79 | BaseEnv._sample_initial_pos(self, gripper_target) 80 | 81 | 82 | def _is_success(self, achieved_goal, desired_goal): 83 | achieved_goal = self.sim.data.get_site_xpos('grasp').copy() 84 | return BaseEnv._is_success(self, achieved_goal, desired_goal) 85 | 86 | def _is_success_2(self, desired_goal): 87 | achieved_goal = self.sim.data.get_site_xpos('grasp').copy() 88 | 89 | return BaseEnv._is_success_2(self, achieved_goal, desired_goal) 90 | 91 | def _get_distance(self, desired_goal): 92 | 93 | achieved_goal = self.sim.data.get_site_xpos('grasp').copy() 94 | 95 | return BaseEnv.goal_distance(self, achieved_goal, desired_goal, self.use_xyz) 96 | 97 | def step(self, action): 98 | 99 | obs, env_reward, done, info = super(ReachEnv, self).step(action) 100 | info['success_rate_05'] = self._is_success_2(self.goal) 101 | info['distance'] = self._get_distance(self.goal) 102 | 103 | return obs, env_reward, done, info 104 | 105 | class ReachMovingTargetEnv(ReachEnv): 106 | def __init__(self, *args, **kwargs): 107 | super().__init__(*args, **kwargs) 108 | self.set_velocity() 109 | 110 | def set_velocity(self): 111 | self.curr_vel = 0.0025 * np.ones(2) 112 | 113 | def _sample_goal(self): 114 | self.set_velocity() 115 | return ReachEnv._sample_goal(self) 116 | 117 | def _step_callback(self): 118 | self.set_goal() 119 | 120 | def set_goal(self): 121 | curr_goal = self.goal 122 | 123 | if (curr_goal[0] >= 1.4 and self.curr_vel[0] > 0) \ 124 | or curr_goal[0] <= 1.2 and self.curr_vel[0] < 0: 125 | self.curr_vel[0] = -1 * self.curr_vel[0] 126 | if (curr_goal[1] >= 0.2 and self.curr_vel[1] > 0) \ 127 | or curr_goal[1] <= -0.2 and self.curr_vel[1] < 0: 128 | self.curr_vel[1] = -1 * self.curr_vel[1] 129 | self.goal[0] += self.curr_vel[0] 130 | self.goal[1] += self.curr_vel[1] 131 | -------------------------------------------------------------------------------- /src/env/robot/registration.py: -------------------------------------------------------------------------------- 1 | from gym.envs.registration import register 2 | 3 | REGISTERED_ROBOT_ENVS = False 4 | 5 | 6 | def register_robot_envs(n_substeps=20, observation_type='image', reward_type='dense', image_size=84, use_xyz=False): 7 | global REGISTERED_ROBOT_ENVS 8 | if REGISTERED_ROBOT_ENVS: 9 | return 10 | 11 | register( 12 | id='RobotLift-v0', 13 | entry_point='env.robot.lift:LiftEnv', 14 | kwargs=dict( 15 | xml_path='robot/lift.xml', 16 | n_substeps=n_substeps, 17 | observation_type=observation_type, 18 | reward_type=reward_type, 19 | image_size=image_size, 20 | use_xyz=use_xyz 21 | ) 22 | ) 23 | 24 | register( 25 | id='RobotPickplace-v0', 26 | entry_point='env.robot.pick_place:PickPlaceEnv', 27 | kwargs=dict( 28 | xml_path='robot/pick_place.xml', 29 | n_substeps=n_substeps, 30 | observation_type=observation_type, 31 | 
reward_type=reward_type, 32 | image_size=image_size, 33 | use_xyz=use_xyz 34 | ) 35 | ) 36 | 37 | register( 38 | id='RobotPegbox-v0', 39 | entry_point='env.robot.peg_in_box:PegBoxEnv', 40 | kwargs=dict( 41 | xml_path='robot/peg_in_box.xml', 42 | n_substeps=n_substeps, 43 | observation_type=observation_type, 44 | reward_type=reward_type, 45 | image_size=image_size, 46 | use_xyz=use_xyz 47 | 48 | ) 49 | ) 50 | 51 | register( 52 | id='RobotHammer-v0', 53 | entry_point='env.robot.hammer:HammerEnv', 54 | kwargs=dict( 55 | xml_path='robot/hammer.xml', 56 | n_substeps=n_substeps, 57 | observation_type=observation_type, 58 | reward_type=reward_type, 59 | image_size=image_size, 60 | use_xyz=use_xyz 61 | ) 62 | ) 63 | 64 | register( 65 | id='RobotHammerall-v0', 66 | entry_point='env.robot.hammer_all:HammerAllEnv', 67 | kwargs=dict( 68 | xml_path='robot/hammer_all.xml', 69 | n_substeps=n_substeps, 70 | observation_type=observation_type, 71 | reward_type=reward_type, 72 | image_size=image_size, 73 | use_xyz=use_xyz 74 | ) 75 | ) 76 | 77 | register( 78 | id='RobotReach-v0', 79 | entry_point='env.robot.reach:ReachEnv', 80 | kwargs=dict( 81 | xml_path='robot/reach.xml', 82 | n_substeps=n_substeps, 83 | observation_type=observation_type, 84 | reward_type=reward_type, 85 | image_size=image_size, 86 | use_xyz=use_xyz 87 | ) 88 | ) 89 | 90 | register( 91 | id='RobotReachmovingtarget-v0', 92 | entry_point='env.robot.reach:ReachMovingTargetEnv', 93 | kwargs=dict( 94 | xml_path='robot/reach.xml', 95 | n_substeps=n_substeps, 96 | observation_type=observation_type, 97 | reward_type=reward_type, 98 | image_size=image_size, 99 | use_xyz=use_xyz 100 | ) 101 | ) 102 | 103 | register( 104 | id='RobotPush-v0', 105 | entry_point='env.robot.push:PushEnv', 106 | kwargs=dict( 107 | xml_path='robot/push.xml', 108 | n_substeps=n_substeps, 109 | observation_type=observation_type, 110 | reward_type=reward_type, 111 | image_size=image_size, 112 | use_xyz=use_xyz 113 | ) 114 | ) 115 | 116 | register( 117 | id='RobotPushnogoal-v0', 118 | entry_point='env.robot.push:PushNoGoalEnv', 119 | kwargs=dict( 120 | xml_path='robot/push.xml', 121 | n_substeps=n_substeps, 122 | observation_type=observation_type, 123 | reward_type=reward_type, 124 | image_size=image_size, 125 | use_xyz=use_xyz 126 | ) 127 | ) 128 | 129 | # --- Shelf Placing Task Class --- # 130 | 131 | # classic view 132 | register( 133 | id='RobotShelfplacing-v0', 134 | entry_point='env.robot.shelf_placing:ShelfPlacingEnv', 135 | kwargs=dict( 136 | xml_path='robot/shelf_placing_classic.xml', 137 | n_substeps=n_substeps, 138 | observation_type=observation_type, 139 | reward_type=reward_type, 140 | image_size=image_size, 141 | use_xyz=use_xyz 142 | ) 143 | ) 144 | 145 | # a near view 146 | register( 147 | id='RobotShelfplacingnear-v0', 148 | entry_point='env.robot.shelf_placing:ShelfPlacingEnv', 149 | kwargs=dict( 150 | xml_path='robot/shelf_placing_near.xml', 151 | n_substeps=n_substeps, 152 | observation_type=observation_type, 153 | reward_type=reward_type, 154 | image_size=image_size, 155 | use_xyz=use_xyz 156 | ) 157 | ) 158 | 159 | # a far view 160 | register( 161 | id='RobotShelfplacingfar-v0', 162 | entry_point='env.robot.shelf_placing:ShelfPlacingEnv', 163 | kwargs=dict( 164 | xml_path='robot/shelf_placing_far.xml', 165 | n_substeps=n_substeps, 166 | observation_type=observation_type, 167 | reward_type=reward_type, 168 | image_size=image_size, 169 | use_xyz=use_xyz 170 | ) 171 | ) 172 | 173 | # a task based on ShelfPlacing 174 | register( 175 | 
id='RobotShelfgoandback-v0', 176 | entry_point='env.robot.shelf_go_and_back:ShelfGoAndBackEnv', 177 | kwargs=dict( 178 | xml_path='robot/shelf_placing_classic.xml', # the same 179 | n_substeps=n_substeps, 180 | observation_type=observation_type, 181 | reward_type=reward_type, 182 | image_size=image_size, 183 | use_xyz=use_xyz 184 | ) 185 | ) 186 | 187 | REGISTERED_ROBOT_ENVS = True 188 | -------------------------------------------------------------------------------- /src/logger.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import json 3 | import os 4 | import torch 5 | from termcolor import colored 6 | 7 | FORMAT_CONFIG = { 8 | 'rl': { 9 | 'train': [ 10 | ('episode', 'E', 'int'), ('step', 'S', 'int'), 11 | ('duration', 'D', 'time'), ('episode_reward', 'R', 'float'), 12 | ('success_rate', 'S', 'float'), ('actor_loss', 'ALOSS', 'float'), 13 | ('critic_loss', 'CLOSS', 'float'), ('aux_loss', 'AUXLOSS', 'float'), ('predictor_loss', 'PLOSS', 'float') 14 | ], 15 | 'eval': [ 16 | ('step', 'S', 'int'), ('episode_reward', 'ER', 'float'), 17 | ('episode_reward_test_env', 'ERTEST', 'float') 18 | ] 19 | } 20 | } 21 | 22 | 23 | class AverageMeter(object): 24 | def __init__(self): 25 | self._sum = 0 26 | self._count = 0 27 | 28 | def update(self, value, n=1): 29 | self._sum += value 30 | self._count += n 31 | 32 | def value(self): 33 | return self._sum / max(1, self._count) 34 | 35 | 36 | class MetersGroup(object): 37 | def __init__(self, file_name, formating): 38 | self._file_name = file_name 39 | self._formating = formating 40 | self._meters = defaultdict(AverageMeter) 41 | 42 | def log(self, key, value, n=1): 43 | self._meters[key].update(value, n) 44 | 45 | def _prime_meters(self): 46 | data = dict() 47 | for key, meter in self._meters.items(): 48 | if key.startswith('train'): 49 | key = key[len('train') + 1:] 50 | else: 51 | key = key[len('eval') + 1:] 52 | key = key.replace('/', '_') 53 | data[key] = meter.value() 54 | return data 55 | 56 | def _dump_to_file(self, data): 57 | with open(self._file_name, 'a') as f: 58 | f.write(json.dumps(data) + '\n') 59 | 60 | def _format(self, key, value, ty): 61 | template = '%s: ' 62 | if ty == 'int': 63 | template += '%d' 64 | elif ty == 'float': 65 | template += '%.04f' 66 | elif ty == 'time': 67 | template += '%.01f s' 68 | else: 69 | raise 'invalid format type: %s' % ty 70 | return template % (key, value) 71 | 72 | def _dump_to_console(self, data, prefix): 73 | prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green') 74 | pieces = ['{:5}'.format(prefix)] 75 | for key, disp_key, ty in self._formating: 76 | value = data.get(key, 0) 77 | pieces.append(self._format(disp_key, value, ty)) 78 | print('| %s' % (' | '.join(pieces))) 79 | 80 | def dump(self, step, prefix): 81 | if len(self._meters) == 0: 82 | return 83 | data = self._prime_meters() 84 | data['step'] = step 85 | self._dump_to_file(data) 86 | self._dump_to_console(data, prefix) 87 | self._meters.clear() 88 | 89 | 90 | class Logger(object): 91 | def __init__(self, log_dir, config='rl'): 92 | self._log_dir = log_dir 93 | self._train_mg = MetersGroup( 94 | os.path.join(log_dir, 'train.log'), 95 | formating=FORMAT_CONFIG[config]['train'] 96 | ) 97 | self._eval_mg = MetersGroup( 98 | os.path.join(log_dir, 'eval.log'), 99 | formating=FORMAT_CONFIG[config]['eval'] 100 | ) 101 | 102 | def log(self, key, value, step, n=1): 103 | assert key.startswith('train') or key.startswith('eval') 104 | if type(value) == 
torch.Tensor: 105 | value = value.item() 106 | mg = self._train_mg if key.startswith('train') else self._eval_mg 107 | mg.log(key, value, n) 108 | 109 | def log_param(self, key, param, step): 110 | self.log_histogram(key + '_w', param.weight.data, step) 111 | if hasattr(param.weight, 'grad') and param.weight.grad is not None: 112 | self.log_histogram(key + '_w_g', param.weight.grad.data, step) 113 | if hasattr(param, 'bias'): 114 | self.log_histogram(key + '_b', param.bias.data, step) 115 | if hasattr(param.bias, 'grad') and param.bias.grad is not None: 116 | self.log_histogram(key + '_b_g', param.bias.grad.data, step) 117 | 118 | def dump(self, step): 119 | self._train_mg.dump(step, 'train') 120 | self._eval_mg.dump(step, 'eval') 121 | -------------------------------------------------------------------------------- /src/state.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | class StateRecorder(object): 5 | def __init__(self, dir_name, height=256, width=256, fps=30): 6 | self.dir_name = dir_name 7 | self.save_dir = dir_name if dir_name else None 8 | self.frames = [] 9 | 10 | def init(self, enabled=True): 11 | self.frames = [] 12 | self.enabled = self.save_dir is not None and enabled 13 | 14 | def record(self, state): 15 | if self.enabled: 16 | frame = state 17 | 18 | self.frames.append(frame) 19 | # frame = env.render(mode="rgb_array") 20 | # self.frames.append(frame) 21 | 22 | def save(self, file_name): 23 | if self.enabled: 24 | path = os.path.join(self.save_dir, file_name) 25 | print(f"Shape of frames {len(self.frames)}") 26 | np.save(path, self.frames) 27 | #imageio.mimsave(path, self.frames, fps=self.fps -------------------------------------------------------------------------------- /src/train_policy.py: -------------------------------------------------------------------------------- 1 | ''' 2 | learn a policy w/ learned reward 3 | ''' 4 | 5 | from absl import app 6 | from absl import flags 7 | from absl import logging as logger 8 | from configs.constants import * 9 | from ml_collections.config_flags import DEFINE_config_file 10 | import logging 11 | 12 | 13 | FLAGS = flags.FLAGS 14 | 15 | flags.DEFINE_string("device", "cuda:0", "The compute device.") 16 | 17 | formatter = logging.Formatter( 18 | fmt="[%(levelname)s] [%(asctime)s] [%(module)s.py:%(lineno)s] %(message)s", 19 | datefmt="%b-%d-%y %H:%M:%S" 20 | ) 21 | 22 | logger.get_absl_handler().setFormatter(formatter) 23 | 24 | def main(_): 25 | 26 | config = FLAGS.config 27 | 28 | 29 | if __name__ == "__main__": 30 | app.run(main) -------------------------------------------------------------------------------- /src/video.py: -------------------------------------------------------------------------------- 1 | import imageio 2 | import os 3 | import augmentations 4 | import torch 5 | import numpy as np 6 | class VideoRecorder(object): 7 | def __init__(self, dir_name, height=448, width=448, camera_id=0, fps=25): 8 | self.dir_name = dir_name 9 | self.height = height 10 | self.width = width 11 | self.camera_id = 1 12 | self.fps = fps 13 | self.frames = [] 14 | self.save_both_views = False 15 | 16 | def init(self, enabled=True): 17 | self.frames = [] 18 | #self.frames_2 = [] 19 | 20 | self.enabled = self.dir_name is not None and enabled 21 | 22 | def record(self, env, mode=None): 23 | if self.enabled: 24 | frame = env.unwrapped.render_obs( 25 | mode='rgb_array', 26 | height=self.height, 27 | width=self.width, 28 | camera_id=self.camera_id 29 | ) 30 | frame = 
frame[0] 31 | 32 | # Build the color-jittered second view only when it will actually be saved, to avoid an unnecessary per-frame CUDA allocation. 33 | if self.save_both_views: frame_2 = torch.FloatTensor(frame).permute(2,0,1).unsqueeze(0).cuda() 34 | if self.save_both_views: frame_2 = augmentations.random_color_jitter(frame_2).div(255).squeeze(0).cpu().permute(1, 2, 0) 35 | 36 | if mode is not None and 'video' in mode: 37 | _env = env 38 | while 'video' not in _env.__class__.__name__.lower(): 39 | _env = _env.env 40 | frame = _env.apply_to(frame) 41 | if self.save_both_views: frame_2 = _env.apply_to(frame_2) 42 | self.frames.append(frame) 43 | if self.save_both_views: self.frames_2.append(frame_2) 44 | 45 | def save(self, file_name): 46 | if self.enabled: 47 | path = os.path.join(self.dir_name, file_name) 48 | if self.save_both_views: path_2 = os.path.join(self.dir_name, '_x_'+file_name) 49 | imageio.mimsave(path, self.frames, fps=self.fps) 50 | if self.save_both_views: imageio.mimsave(path_2, self.frames_2, fps=self.fps) 51 | --------------------------------------------------------------------------------
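The state_dim values of (26,) with use_xyz and (20,) without, shared by the hammer, peg-in-box and push environments, follow directly from the concatenation order in their _get_state_obs methods. A small sanity check of that arithmetic (illustrative only, not part of the repository):

# Component sizes of _get_state_obs for the object-manipulation tasks.
# obj_rot (a quaternion) and the four scalar features are unaffected by use_xyz.
def state_obs_dim(use_xyz: bool) -> int:
    pos = 3 if use_xyz else 2
    return sum([
        pos,  # eef_pos
        pos,  # eef_velp
        pos,  # goal_pos
        pos,  # obj_pos
        4,    # obj_rot quaternion
        pos,  # obj_velp
        pos,  # obj_velr
        4,    # d(eef, goal), d(obj, goal), d(eef, obj), gripper_angle
    ])

assert state_obs_dim(True) == 26
assert state_obs_dim(False) == 20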
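PegBoxEnv.compute_reward uses a staged shaping: a quadratic action penalty, a planar term that pulls the peg over the box, and an insertion bonus that only switches on once the peg is within 5 cm of the goal in the xy-plane. A minimal standalone sketch of that shaping, with plain Euclidean distances standing in for BaseEnv.goal_distance and hypothetical input values:

import numpy as np

def staged_pegbox_reward(object_pos, goal_pos, pos_ctrl_magnitude):
    # The planar distance decides whether the insertion stage is active;
    # the full 3D distance shapes the bonus itself.
    d_xy = np.linalg.norm(object_pos[:2] - goal_pos[:2])
    d_xyz = np.linalg.norm(object_pos - goal_pos)
    reward = -1.0 * np.square(pos_ctrl_magnitude)  # action penalty
    reward += -4.0 * d_xy                          # stage 1: move the peg over the box
    if d_xy <= 0.05:
        reward += 10.0 - 20.0 * d_xyz              # stage 2: lower the peg into the hole
    return reward

# Far from the box only the distance penalty acts; directly above the hole the bonus dominates.
print(staged_pegbox_reward(np.array([1.30, 0.00, 0.70]), np.array([1.60, 0.30, 0.60]), 0.1))
print(staged_pegbox_reward(np.array([1.60, 0.30, 0.62]), np.array([1.60, 0.30, 0.60]), 0.1))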
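registration.py only registers the task ids; constructing an environment happens elsewhere, in wrappers and training scripts outside this excerpt. The following is a hedged sketch of how a registered task such as RobotPush-v0 might be instantiated and stepped, assuming a working MuJoCo install. The cameras value and observation_type='state' are guesses rather than values taken from the repository, and info['success_rate_05'] / info['distance'] are the extra keys attached by PushEnv.step.

import gym
from env.robot.registration import register_robot_envs

# Register every robot task once; the kwargs here mirror register_robot_envs' defaults.
register_robot_envs(n_substeps=20, observation_type='state', reward_type='dense',
                    image_size=84, use_xyz=True)

# cameras=[0] is a placeholder: the constructor requires it but the registration
# kwargs do not set it, so the real value comes from code outside this excerpt.
env = gym.make('RobotPush-v0', cameras=[0])
obs = env.reset()
for _ in range(10):
    obs, reward, done, info = env.step(env.action_space.sample())
    # PushEnv.step adds a tighter success indicator and the object-to-goal distance.
    print(reward, info['success_rate_05'], info['distance'])
env.close()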
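Logger and VideoRecorder are driven by the training scripts rather than by the environments, and train.py is not shown here, so this is only a rough sketch of how their interfaces compose around one evaluation episode, continuing from the env created above. It assumes the wrapped env still exposes BaseEnv's render_obs, which VideoRecorder.record calls internally; logged keys must start with 'train' or 'eval' to satisfy the assert in Logger.log, and dump() flushes the averaged meters to the console and to the matching .log file.

import os
from logger import Logger
from video import VideoRecorder

work_dir = './logs/example'                 # hypothetical output location
os.makedirs(work_dir, exist_ok=True)

L = Logger(work_dir, config='rl')
recorder = VideoRecorder(work_dir, fps=25)

recorder.init(enabled=True)
obs, done, episode_reward, step = env.reset(), False, 0.0, 0
while not done:
    obs, reward, done, info = env.step(env.action_space.sample())
    recorder.record(env)                    # grabs a frame via env.unwrapped.render_obs
    episode_reward += reward
    step += 1

L.log('eval/episode_reward', episode_reward, step)
L.dump(step)                                # prints a "| eval | S: ... | ER: ..." line and appends to eval.log
recorder.save('episode_0.mp4')              # written under work_dir by imageio.mimsave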