15 |
16 |
17 | ## Method
18 |
19 | **GraphIRL** is a self-supervised method for learning a visually invariant reward function directly from a set of diverse third-person video demonstrations. Our framework builds an object-centric graph abstraction from the demonstrations and learns an embedding space that captures task progression by exploiting the temporal cues in the videos. This embedding space is then used to construct a domain- and embodiment-invariant reward function that can be used to train any standard reinforcement learning algorithm.
20 |
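In practice, the learned encoder maps each observation into the embedding space, and the reward reflects how close the current observation is to a goal embedding computed from the demonstration videos. The snippet below is a minimal, illustrative sketch of that idea (scaled negative distance to the goal), not the exact training code; `obs_embedding` stands in for the output of the trained encoder, while `goal_emb` and `distance_scale` correspond to the quantities produced by `compute_goal_embedding.py`.

```
import numpy as np

def embedding_reward(obs_embedding, goal_emb, distance_scale):
  """Illustrative reward: scaled negative distance to the mean goal embedding."""
  dist = np.linalg.norm(obs_embedding - goal_emb)
  return -distance_scale * dist
```
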
26 | ## Citation
27 |
28 | If you use our method or code in your research, please consider citing the paper as follows:
29 |
30 | ```
31 | @article{kumar2022inverse,
32 | title={Graph Inverse Reinforcement Learning from Diverse Videos},
33 | author={Kumar, Sateesh and Zamora, Jonathan and Hansen, Nicklas and Jangir, Rishabh and Wang, Xiaolong},
34 | journal={arXiv preprint arXiv:2207.14299},
35 | year={2022}
36 | }
37 | ```
38 |
39 | ## Instructions
40 |
41 | We assume you have installed [MuJoCo](http://www.mujoco.org) on your machine. You can install dependencies using `conda`:
42 |
43 | ```
44 | conda env create -f environment.yaml
45 | conda activate graphirl
46 | ```
47 |
48 | You can use the trained reward models to compute rewards and train a policy by running:
49 |
50 | ```
51 | bash scripts/run_${Task_Name} /path/to/trained_model
52 | ```
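
If you also want to recompute the mean goal embedding and distance scale from a pretrained model, the included `compute_goal_embedding.py` script loads the latest checkpoint and writes `goal_emb.pkl` and `distance_scale.pkl` into the experiment directory:

```
python compute_goal_embedding.py --experiment_path /path/to/trained_model
```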
53 |
54 | ## Dataset
55 |
56 | The dataset for GraphIRL can be found in this [Google Drive Folder](https://drive.google.com/drive/folders/1mNJmnyzIoCudRcTdRVrN3WAiuWIM8355?usp=share_link). We have also released the trained reward models for Reach, Push and Peg in Box [here](https://drive.google.com/drive/folders/1O69YtAmq7hqU6kmutqGr-I0vBo7bu8f3?usp=sharing).
57 | We include a [script](scripts/download_data.sh) for downloading all data with the `gdown` package, though feel free to directly download the files from our drive folder if you run into issues with the script.
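
For example, all files can be fetched in one go from the repository root with:

```
bash scripts/download_data.sh
```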
58 |
59 | ## License & Acknowledgements
60 |
61 | GraphIRL is licensed under the MIT license. [MuJoCo](https://github.com/deepmind/mujoco) is licensed under the Apache 2.0 license. We thank the [XIRL](https://x-irl.github.io/) authors for open-sourcing their codebase to the community; our work is built on top of their engineering efforts.
62 |
--------------------------------------------------------------------------------
/compute_goal_embedding.py:
--------------------------------------------------------------------------------
1 | """Compute and store the mean goal embedding using a trained model."""
2 |
3 | import os
4 | import pickle
5 | import typing
6 |
7 | from absl import app
8 | from absl import flags
9 | import numpy as np
10 | import torch
11 | from torchkit import checkpoint
12 | from graphirl import common
13 | import tqdm
14 |
15 | FLAGS = flags.FLAGS
16 |
17 | flags.DEFINE_string("experiment_path", None, "Path to model checkpoint.")
18 | flags.DEFINE_boolean("restore_checkpoint", True, "Restore model checkpoint.")
19 |
20 | flags.mark_flag_as_required("experiment_path")
21 |
22 | ModelType = torch.nn.Module
23 | DataLoaderType = typing.Dict[str, torch.utils.data.DataLoader]
24 |
25 |
26 | def embed(
27 | model,
28 | downstream_loader,
29 | device,
30 | ):
31 | """Embed the stored trajectories and compute mean goal embedding."""
32 | goal_embs = []
33 | init_embs = []
34 | for class_name, class_loader in downstream_loader.items():
35 | print(f"\tEmbedding class: {class_name}...")
36 |     for batch in class_loader:
40 | out = model.infer(batch["frames"].to(device))
41 | emb = out.numpy().embs
42 | goal_embs.append(emb[-1, :])
43 | init_embs.append(emb[0, :])
44 |
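  # Average the final-frame embeddings to form the goal embedding, and scale
  # distances so that the average start-to-goal distance maps to 1.0.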
45 | goal_emb = np.mean(np.stack(goal_embs, axis=0), axis=0, keepdims=True)
46 | dist_to_goal = np.linalg.norm(
47 | np.stack(init_embs, axis=0) - goal_emb, axis=-1).mean()
48 | distance_scale = 1.0 / dist_to_goal
49 | return goal_emb, distance_scale
50 |
51 |
52 | def setup(device):
53 | """Load the latest embedder checkpoint and dataloaders."""
54 | config = common.load_config_from_dir(FLAGS.experiment_path)
55 | model = common.get_model(config)
56 | downstream_loaders = common.get_downstream_dataloaders(config, False)["train"]
57 | checkpoint_dir = os.path.join(FLAGS.experiment_path, "checkpoints")
58 | if FLAGS.restore_checkpoint:
59 | checkpoint_manager = checkpoint.CheckpointManager(
60 | checkpoint.Checkpoint(model=model),
61 | checkpoint_dir,
62 | device,
63 | )
64 | global_step = checkpoint_manager.restore_or_initialize()
65 | print(f"Restored model from checkpoint {global_step}.")
66 | else:
67 | print("Skipping checkpoint restore.")
68 | return model, downstream_loaders
69 |
70 |
71 | def main(_):
72 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
73 | model, downstream_loader = setup(device)
74 | model.to(device).eval()
75 | goal_emb, distance_scale = embed(model, downstream_loader, device)
76 | with open(os.path.join(FLAGS.experiment_path, "goal_emb.pkl"), "wb") as fp:
77 | pickle.dump(goal_emb, fp)
78 | with open(os.path.join(FLAGS.experiment_path, "distance_scale.pkl"), "wb") as fp:
79 | pickle.dump(distance_scale, fp)
80 |
81 |
82 | if __name__ == "__main__":
83 | app.run(main)
84 |
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: graphirl
2 | channels:
3 | - defaults
4 | - pytorch
5 | dependencies:
6 | - python=3.8
7 | - pytorch=1.7.1
8 | - torchvision=0.8.2
9 | - cudatoolkit=11.0
10 | - pip
11 | - pip:
12 | - albumentations==0.5.2
13 | - gym==0.17.*
14 | - pymunk==5.6.0
15 | - ml-collections==0.1.0
16 | - scikit-learn==0.24.1
17 | - tensorflow-cpu==2.6.0
18 | - imageio
19 | - imageio-ffmpeg
20 | - tqdm
21 | - absl-py
22 | - numpy
23 | - scipy
24 | - pandas
25 | - gdown==4.4.0
26 | - protobuf~=3.19.0
27 | - wandb
28 | - natsort
29 | - pillow==6.2.0
30 | - termcolor
31 | - opencv-python
32 | - xmltodict
33 | - scikit-image
34 | - seaborn
35 | - mujoco-py
36 | - kornia
37 | - einops
38 |
--------------------------------------------------------------------------------
/graphirl/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 |
--------------------------------------------------------------------------------
/graphirl/common.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Functionality common to pretraining and evaluation."""
17 |
18 | import os
19 | from typing import Dict
20 |
21 | from ml_collections import ConfigDict
22 | import torch
23 | from graphirl import factory
24 | import yaml
25 |
26 | DataLoadersDict = Dict[str, torch.utils.data.DataLoader]
27 |
28 |
29 | def get_pretraining_dataloaders(
30 | config,
31 | debug = False,
32 | ):
33 | """Construct a train/valid pair of pretraining dataloaders.
34 |
35 | Args:
36 | config: ConfigDict object with config parameters.
37 | debug: When set to True, the following happens: 1. Data augmentation is
38 | disabled regardless of config values. 2. Sequential sampling of videos is
39 | turned on. 3. The number of dataloader workers is set to 0.
40 |
41 | Returns:
42 | A dict of train/valid pretraining dataloaders.
43 | """
44 |
45 | def _loader(split):
46 | dataset = factory.dataset_from_config(config, False, split, debug)
47 | batch_sampler = factory.video_sampler_from_config(
48 | config, dataset.dir_tree, downstream=False, sequential=debug)
49 | return torch.utils.data.DataLoader(
50 | dataset,
51 | collate_fn=dataset.collate_fn,
52 | batch_sampler=batch_sampler,
53 | num_workers=4 if torch.cuda.is_available() and not debug else 0,
54 | pin_memory=torch.cuda.is_available() and not debug,
55 | )
56 |
57 | return {
58 | "train": _loader("train"),
59 | "valid": _loader("valid"),
60 | }
61 |
62 |
63 | def get_downstream_dataloaders(
64 | config,
65 | debug,
66 | ):
67 | """Construct a train/valid pair of downstream dataloaders.
68 |
69 | Args:
70 | config: ConfigDict object with config parameters.
71 | debug: When set to True, the following happens: 1. Data augmentation is
72 | disabled regardless of config values. 2. Sequential sampling of videos is
73 | turned on. 3. The number of dataloader workers is set to 0.
74 |
75 | Returns:
76 | A dict of train/valid downstream dataloaders
77 | """
78 |
79 | def _loader(split):
80 | datasets = factory.dataset_from_config(config, True, split, debug)
81 | loaders = {}
82 | for action_class, dataset in datasets.items():
83 | batch_sampler = factory.video_sampler_from_config(
84 | config, dataset.dir_tree, downstream=True, sequential=debug)
85 | loaders[action_class] = torch.utils.data.DataLoader(
86 | dataset,
87 | collate_fn=dataset.collate_fn,
88 | batch_sampler=batch_sampler,
89 | num_workers=4 if torch.cuda.is_available() and not debug else 0,
90 | pin_memory=torch.cuda.is_available() and not debug,
91 | )
92 | return loaders
93 |
94 | return {
95 | "train": _loader("train"),
96 | "valid": _loader("valid"),
97 | }
98 |
99 |
100 | def get_factories(
101 | config,
102 | device,
103 | debug = False,
104 | ):
105 | """Feed config to factories and return objects."""
106 | pretrain_loaders = get_pretraining_dataloaders(config, debug)
107 | downstream_loaders = get_downstream_dataloaders(config, debug)
108 | model = factory.model_from_config(config)
109 | optimizer = factory.optim_from_config(config, model)
110 | trainer = factory.trainer_from_config(config, model, optimizer, device)
111 | eval_manager = factory.evaluator_from_config(config)
112 | return (
113 | model,
114 | optimizer,
115 | pretrain_loaders,
116 | downstream_loaders,
117 | trainer,
118 | eval_manager,
119 | )
120 |
121 |
122 | def get_model(config):
123 | """Construct a model from a config."""
124 | return factory.model_from_config(config)
125 |
126 |
127 | def load_config_from_dir(exp_dir):
128 | """Load experiment config."""
129 |   with open(os.path.join(exp_dir, "config.yaml"), "r") as fp:
130 |     cfg = yaml.load(fp, Loader=yaml.FullLoader)
131 |   return ConfigDict(cfg)
135 |
--------------------------------------------------------------------------------
/graphirl/evaluators/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Evaluators."""
17 |
18 | from .cycle_consistency import ThreeWayCycleConsistency
19 | from .cycle_consistency import TwoWayCycleConsistency
20 | from .emb_visualizer import EmbeddingVisualizer
21 | from .kendalls_tau import KendallsTau
22 | from .manager import EvalManager
23 | from .nn_visualizer import NearestNeighbourVisualizer
24 | from .reconstruction_visualizer import ReconstructionVisualizer
25 | from .reward_visualizer import RewardVisualizer
26 |
27 | __all__ = [
28 | "EvalManager",
29 | "KendallsTau",
30 | "TwoWayCycleConsistency",
31 | "ThreeWayCycleConsistency",
32 | "NearestNeighbourVisualizer",
33 | "RewardVisualizer",
34 | "EmbeddingVisualizer",
35 | "ReconstructionVisualizer",
36 | ]
37 |
--------------------------------------------------------------------------------
/graphirl/evaluators/base.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Base evaluator."""
17 |
18 | import abc
19 | from typing import List, Optional, Union
20 |
21 | import dataclasses
22 | import numpy as np
23 | from torchkit import Logger
24 | from xirl.models import SelfSupervisedOutput
25 |
26 |
27 | @dataclasses.dataclass
28 | class EvaluatorOutput:
29 | """The output of an evaluator."""
30 |
31 | # An evaluator does not necessarily generate all fields below. For example,
32 | # some evaluators like Kendalls Tau return a scalar and image metric, while
33 | # TwoWayCycleConsistency only generates a scalar metric.
34 | scalar: Optional[Union[float, List[float]]] = None
35 | image: Optional[Union[np.ndarray, List[np.ndarray]]] = None
36 | video: Optional[Union[np.ndarray, List[np.ndarray]]] = None
37 |
38 | @staticmethod
39 | def _assert_same_attrs(list_out):
40 | """Ensures a list of this class instance have the same attributes."""
41 |
42 | def _not_none(o):
43 | return [getattr(o, a) is not None for a in ["scalar", "image", "video"]]
44 |
45 | expected = _not_none(list_out[0])
46 | for o in list_out[1:]:
47 | actual = _not_none(o)
48 | assert np.array_equal(expected, actual)
49 |
50 | @staticmethod
51 | def merge(list_out):
52 | """Merge a list of this class instance into one."""
53 | # We need to make sure that all elements of the list have the same
54 | # non-empty (i.e. != None) attributes.
55 | EvaluatorOutput._assert_same_attrs(list_out)
56 | # At this point, we're confident that we only need to check the
57 | # attributes of the first member of the list to guarantee the same
58 | # availability for *all* other members of the list.
59 | scalars = None
60 | if list_out[0].scalar is not None:
61 | scalars = [o.scalar for o in list_out]
62 | images = None
63 | if list_out[0].image is not None:
64 | images = [o.image for o in list_out]
65 | videos = None
66 | if list_out[0].video is not None:
67 | videos = [o.video for o in list_out]
68 | return EvaluatorOutput(scalars, images, videos)
69 |
70 | def log(self, logger, global_step, name,
71 | prefix):
72 | """Log the attributes to tensorboard."""
73 | if self.scalar is not None:
74 | if isinstance(self.scalar, list):
75 | self.scalar = np.mean(self.scalar)
76 | logger.log_scalar(self.scalar, global_step, name, prefix)
77 | if self.image is not None:
78 | if isinstance(self.image, list):
79 | for i, image in enumerate(self.image):
80 | logger.log_image(image, global_step, name + f"_{i}", prefix)
81 | else:
82 | logger.log_image(self.image, global_step, name, prefix)
83 | if self.video is not None:
84 | if isinstance(self.video, list):
85 | for i, video in enumerate(self.video):
86 | logger.log_video(video, global_step, name + f"_{i}", prefix)
87 | else:
88 | logger.log_video(self.video, global_step, name, prefix)
89 |
90 |
91 | class Evaluator(abc.ABC):
92 | """Base class for evaluating a self-supervised model on downstream tasks.
93 |
94 | Subclasses must implement the `_evaluate` method.
95 | """
96 |
97 | def __init__(self, inter_class):
98 | self.inter_class = inter_class
99 |
100 | @abc.abstractmethod
101 | def evaluate(self, outs):
102 | """Evaluate the downstream task in embedding space.
103 |
104 | Args:
105 | outs: A list of outputs generated by the model on the downstream dataset.
106 | :meta public:
107 | """
108 | pass
109 |
--------------------------------------------------------------------------------
/graphirl/evaluators/cycle_consistency.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Cycle-consistency evaluator."""
17 |
18 | import itertools
19 | from typing import List
20 |
21 | from .base import Evaluator
22 | from .base import EvaluatorOutput
23 | import numpy as np
24 | from scipy.spatial.distance import cdist
25 | from xirl.models import SelfSupervisedOutput
26 |
27 |
28 | class _CycleConsistency(Evaluator):
29 | """Base class for cycle consistency evaluation."""
30 |
31 | def __init__(self, n_way, stride, distance):
32 | """Constructor.
33 |
34 | Args:
35 | n_way: The number of cycle-consistency ways.
36 | stride: Controls how many frames are skipped in each video sequence. For
37 | example, if the embedding vector of the first video is (100, 128), a
38 | stride of 5 reduces it to (20, 128).
39 | distance: The distance metric to use when calculating nearest-neighbours.
40 |
41 | Raises:
42 | ValueError: If the distance metric is invalid or the
43 | mode is invalid.
44 | """
45 | super().__init__(inter_class=False)
46 |
47 | assert n_way in [2, 3], "n_way must be 2 or 3."
48 | assert isinstance(stride, int), "stride must be an integer."
49 | if distance not in ["sqeuclidean", "cosine"]:
50 | raise ValueError(
51 | "{} is not a supported distance metric.".format(distance))
52 |
53 | self.n_way = n_way
54 | self.stride = stride
55 | self.distance = distance
56 |
57 | def _evaluate_two_way(self, embs):
58 | """Two-way cycle consistency."""
59 | num_embs = len(embs)
60 | total_combinations = num_embs * (num_embs - 1)
61 | ccs = np.zeros((total_combinations))
62 | idx = 0
63 | for i in range(num_embs):
64 | query_emb = embs[i][::self.stride]
65 | ground_truth = np.arange(len(embs[i]))[::self.stride]
66 | for j in range(num_embs):
67 | if i == j:
68 | continue
69 | candidate_emb = embs[j][::self.stride]
70 | dists = cdist(query_emb, candidate_emb, self.distance)
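        # Cycle back: for each query frame, find its nearest candidate frame,
        # then find which query frame is nearest to that candidate frame.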
71 | nns = np.argmin(dists[:, np.argmin(dists, axis=1)], axis=0)
72 | ccs[idx] = np.mean(np.abs(nns - ground_truth) <= 1)
73 | idx += 1
74 | ccs = ccs[~np.isnan(ccs)]
75 | return EvaluatorOutput(scalar=np.mean(ccs))
76 |
77 | def _evaluate_three_way(self, embs):
78 | """Three-way cycle consistency."""
79 | num_embs = len(embs)
80 | cycles = np.stack(list(itertools.permutations(np.arange(num_embs), 3)))
81 | total_combinations = len(cycles)
82 | ccs = np.zeros((total_combinations))
83 | for c_idx, cycle in enumerate(cycles):
84 | # Forward consistency check. Each cycle will be a length 3
85 | # permutation, e.g. U - V - W. We compute nearest neighbours across
86 | # consecutive pairs in the cycle and loop back to the first cycle
87 | # index to obtain: U - V - W - U.
88 | query_emb = None
89 | for i in range(len(cycle)):
90 | if query_emb is None:
91 | query_emb = embs[cycle[i]][::self.stride]
92 | candidate_emb = embs[cycle[(i + 1) % len(cycle)]][::self.stride]
93 | dists = cdist(query_emb, candidate_emb, self.distance)
94 | nns_forward = np.argmin(dists, axis=1)
95 | query_emb = candidate_emb[nns_forward]
96 | ground_truth_forward = np.arange(len(embs[cycle[0]]))[::self.stride]
97 | cc_forward = np.abs(nns_forward - ground_truth_forward) <= 1
98 | # Backward consistency check. A backward check is equivalent to
99 | # reversing the middle pair V - W and performing a forward check,
100 | # e.g. U - W - V - U.
101 | cycle[1:] = cycle[1:][::-1]
102 | query_emb = None
103 | for i in range(len(cycle)):
104 | if query_emb is None:
105 | query_emb = embs[cycle[i]][::self.stride]
106 | candidate_emb = embs[cycle[(i + 1) % len(cycle)]][::self.stride]
107 | dists = cdist(query_emb, candidate_emb, self.distance)
108 | nns_backward = np.argmin(dists, axis=1)
109 | query_emb = candidate_emb[nns_backward]
110 | ground_truth_backward = np.arange(len(embs[cycle[0]]))[::self.stride]
111 | cc_backward = np.abs(nns_backward - ground_truth_backward) <= 1
112 | # Require consistency both ways.
113 | cc = np.logical_and(cc_forward, cc_backward)
114 | ccs[c_idx] = np.mean(cc)
115 | ccs = ccs[~np.isnan(ccs)]
116 | return EvaluatorOutput(scalar=np.mean(ccs))
117 |
118 | def evaluate(self, outs):
119 | embs = [o.embs for o in outs]
120 | if self.n_way == 2:
121 | return self._evaluate_two_way(embs)
122 | return self._evaluate_three_way(embs)
123 |
124 |
125 | class TwoWayCycleConsistency(_CycleConsistency):
126 | """2-way cycle consistency evaluator [1].
127 |
128 | References:
129 | [1]: https://arxiv.org/abs/1805.11592
130 | """
131 |
132 | def __init__(self, stride, distance):
133 | super().__init__(2, stride, distance)
134 |
135 |
136 | class ThreeWayCycleConsistency(_CycleConsistency):
137 |   """3-way cycle consistency evaluator [1].
138 |
139 | References:
140 | [1]: https://arxiv.org/abs/1805.11592
141 | """
142 |
143 | def __init__(self, stride, distance):
144 | super().__init__(3, stride, distance)
145 |
--------------------------------------------------------------------------------
/graphirl/evaluators/emb_visualizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """2D embedding visualizer."""
17 |
18 | from typing import List
19 |
20 | from .base import Evaluator
21 | from .base import EvaluatorOutput
22 | import matplotlib.pyplot as plt
23 | import numpy as np
24 | from sklearn.decomposition import PCA
25 | from xirl.models import SelfSupervisedOutput
26 |
27 |
28 | class EmbeddingVisualizer(Evaluator):
29 | """Visualize PCA of the embeddings."""
30 |
31 | def __init__(self, num_seqs):
32 | """Constructor.
33 |
34 | Args:
35 | num_seqs: How many embedding sequences to visualize.
36 |     """
40 | super().__init__(inter_class=True)
41 |
42 | self.num_seqs = num_seqs
43 |
44 | def _gen_emb_plot(self, embs):
45 | """Create a pyplot plot and save to buffer."""
46 | fig = plt.figure()
47 | for emb in embs:
48 | plt.scatter(emb[:, 0], emb[:, 1])
49 | fig.canvas.draw()
50 | img_arr = np.array(fig.canvas.renderer.buffer_rgba())[:, :, :3]
51 | plt.close()
52 | return img_arr
53 |
54 | def evaluate(self, outs):
55 | embs = [o.embs for o in outs]
56 |
57 | # Randomly sample the embedding sequences we'd like to plot.
58 | seq_idxs = np.random.choice(
59 | np.arange(len(embs)), size=self.num_seqs, replace=False)
60 | seq_embs = [embs[idx] for idx in seq_idxs]
61 |
62 | # Subsample embedding sequences to make them the same length.
63 | seq_lens = [s.shape[0] for s in seq_embs]
64 | min_len = np.min(seq_lens)
65 | same_length_embs = []
66 | for emb in seq_embs:
67 | emb_len = len(emb)
68 | stride = emb_len / min_len
69 | idxs = np.arange(0.0, emb_len, stride).round().astype(int)
70 | idxs = np.clip(idxs, a_min=0, a_max=emb_len - 1)
71 | idxs = idxs[:min_len]
72 | same_length_embs.append(emb[idxs])
73 |
74 | # Flatten embeddings to perform PCA.
75 | same_length_embs = np.stack(same_length_embs)
76 | num_seqs, seq_len, emb_dim = same_length_embs.shape
77 | embs_flat = same_length_embs.reshape(-1, emb_dim)
78 | embs_2d = PCA(n_components=2, random_state=0).fit_transform(embs_flat)
79 | embs_2d = embs_2d.reshape(num_seqs, seq_len, 2)
80 |
81 | image = self._gen_emb_plot(embs_2d)
82 | return EvaluatorOutput(image=image)
83 |
--------------------------------------------------------------------------------
/graphirl/evaluators/kendalls_tau.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Kendall rank correlation coefficient evaluator."""
17 |
18 | from typing import List
19 |
20 | from .base import Evaluator
21 | from .base import EvaluatorOutput
22 | import numpy as np
23 | from scipy.spatial.distance import cdist
24 | from scipy.stats import kendalltau
25 | from xirl.models import SelfSupervisedOutput
26 |
27 |
28 | def softmax(dists, temp = 1.0):
29 | dists_ = np.array(dists - np.max(dists))
30 | exp = np.exp(dists_ / temp)
31 | return exp / np.sum(exp)
32 |
33 |
34 | class KendallsTau(Evaluator):
35 | """Kendall rank correlation coefficient [1].
36 |
37 | References:
38 | [1]: https://en.wikipedia.org/wiki/Kendall_rank_correlation_coefficient
39 | """
40 |
41 | def __init__(self, stride, distance):
42 | """Constructor.
43 |
44 | Args:
45 | stride: Controls how many frames are skipped in each video sequence. For
46 | example, if the embedding vector of the first video is (100, 128), a
47 | stride of 5 reduces it to (20, 128).
48 | distance: The distance metric to use when calculating nearest-neighbours.
49 |
50 | Raises:
51 | ValueError: If the distance metric is invalid.
52 | """
53 | super().__init__(inter_class=False)
54 |
55 | assert isinstance(stride, int), "stride must be an integer."
56 | if distance not in ["sqeuclidean", "cosine"]:
57 | raise ValueError(
58 | "{} is not a supported distance metric.".format(distance))
59 |
60 | self.stride = stride
61 | self.distance = distance
62 |
63 | def evaluate(self, outs):
64 | """Get pairwise nearest-neighbours then compute KT."""
65 | embs = [o.embs for o in outs]
66 | num_embs = len(embs)
67 | total_combinations = num_embs * (num_embs - 1)
68 | taus = np.zeros((total_combinations))
69 | idx = 0
70 | img = None
71 | for i in range(num_embs):
72 | query_emb = embs[i][::self.stride]
73 | for j in range(num_embs):
74 | if i == j:
75 | continue
76 | candidate_emb = embs[j][::self.stride]
77 | dists = cdist(query_emb, candidate_emb, self.distance)
78 | if i == 0 and j == 1:
79 | sim_matrix = []
80 | for k in range(len(query_emb)):
81 | sim_matrix.append(softmax(-dists[k]))
82 | img = np.array(sim_matrix, dtype=np.float32)[Ellipsis, None]
83 | nns = np.argmin(dists, axis=1)
84 | taus[idx] = kendalltau(np.arange(len(nns)), nns).correlation
85 | idx += 1
86 | taus = taus[~np.isnan(taus)]
87 | if taus.size == 0:
88 | tau = 0.0
89 | else:
90 | tau = np.mean(taus)
91 | return EvaluatorOutput(scalar=tau, image=img)
92 |
--------------------------------------------------------------------------------
/graphirl/evaluators/manager.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """An evaluator manager."""
17 |
18 | from typing import Dict, List, Mapping, Optional
19 |
20 | from .base import Evaluator
21 | from .base import EvaluatorOutput
22 | import torch
23 | import torch.nn as nn
24 | from xirl.models import SelfSupervisedOutput
25 |
26 |
27 | class EvalManager:
28 | """Manage a bunch of downstream task evaluators and aggregate their results.
29 |
30 | Specifically, the manager embeds the downstream dataset *once*, and shares
31 | the embeddings across all evaluators for more efficient evaluation.
32 | """
33 |
34 | def __init__(self, evaluators):
35 | """Constructor.
36 |
37 | Args:
38 | evaluators: A mapping from evaluator name to Evaluator instance.
39 | """
40 | self._evaluators = evaluators
41 |
42 | @staticmethod
43 | @torch.no_grad()
44 | def embed(
45 | model,
46 | downstream_loader,
47 | device,
48 | eval_iters,
49 | ):
50 | """Run the model on the downstream data and generate embeddings."""
51 | loader_to_output = {}
52 | for action_name, valid_loader in downstream_loader.items():
53 | outs = []
54 | for batch_idx, batch in enumerate(valid_loader):
55 | if eval_iters is not None and batch_idx >= eval_iters:
56 | break
57 | outs.append(model.infer(batch["frames"].to(device)).numpy())
58 | loader_to_output[action_name] = outs
59 | return loader_to_output
60 |
61 | @torch.no_grad()
62 | def evaluate(
63 | self,
64 | model,
65 | downstream_loader,
66 | device,
67 | eval_iters=None,
68 | ):
69 | """Evaluate the model on the validation data.
70 |
71 | Args:
72 | model: The self-supervised model that will embed the frames in the
73 | downstream loader.
74 | downstream_loader: A downstream dataloader. Has a batch size of 1 and
75 | loads all frames of the video.
76 | device: The compute device.
77 |       eval_iters: The number of times to call `next()` on the downstream
78 | iterator. Set to None to evaluate on the entire iterator.
79 |
80 | Returns:
81 | A dict mapping from evaluator name to EvaluatorOutput.
82 | """
83 | model.eval()
84 | print("Embedding downstream dataset...")
85 | downstream_outputs = EvalManager.embed(model, downstream_loader, device,
86 | eval_iters)
87 | eval_to_metric = {}
88 | for evaluator_name, evaluator in self._evaluators.items():
89 | print(f"\tRunning {evaluator_name} evaluator...")
90 | if evaluator.inter_class:
91 | # Merge all downstream classes into a single list and do one
92 | # eval computation.
93 | outs = [
94 | o for out in downstream_outputs.values() for o in out # pylint: disable=g-complex-comprehension
95 | ]
96 | metric = evaluator.evaluate(outs)
97 | else:
98 | # Loop and evaluate over downstream classes separately, then
99 | # merge into a single EvaluatorOutput whose fields are lists.
100 | metrics = []
101 | for outs in downstream_outputs.values():
102 | metrics.append(evaluator.evaluate(outs))
103 | metric = EvaluatorOutput.merge(metrics)
104 | eval_to_metric[evaluator_name] = metric
105 | return eval_to_metric
106 |
--------------------------------------------------------------------------------
/graphirl/evaluators/nn_visualizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Nearest-neighbor evaluator."""
17 |
18 | from typing import List
19 |
20 | from .base import Evaluator
21 | from .base import EvaluatorOutput
22 | import numpy as np
23 | from scipy.spatial.distance import cdist
24 | from xirl.models import SelfSupervisedOutput
25 |
26 |
27 | class NearestNeighbourVisualizer(Evaluator):
28 | """Nearest-neighbour frame visualizer."""
29 |
30 | def __init__(
31 | self,
32 | distance,
33 | num_videos,
34 | num_ctx_frames,
35 | ):
36 | """Constructor.
37 |
38 | Args:
39 | distance: The distance metric to use when calculating nearest-neighbours.
40 | num_videos: The number of video sequences to display.
41 | num_ctx_frames: The number of context frames stacked together for each
42 | individual video frame.
43 |
44 | Raises:
45 | ValueError: If the distance metric is invalid.
46 | """
47 | super().__init__(inter_class=True)
48 |
49 | if distance not in ["sqeuclidean", "cosine"]:
50 | raise ValueError(
51 | "{} is not a supported distance metric.".format(distance))
52 |
53 | self.distance = distance
54 | self.num_videos = num_videos
55 | self.num_ctx_frames = num_ctx_frames
56 |
57 | def evaluate(self, outs):
58 | """Sample source and target sequences and plot nn frames."""
59 |
60 | def _reshape(frame):
61 | s, h, w, c = frame.shape
62 | seq_len = s // self.num_ctx_frames
63 | return frame.reshape(seq_len, self.num_ctx_frames, h, w, c)
64 |
65 | embs = [o.embs for o in outs]
66 | frames = [o.frames for o in outs]
67 |
68 | # Randomly sample the video sequences we'd like to plot.
69 | seq_idxs = np.random.choice(
70 | np.arange(len(embs)), size=self.num_videos, replace=False)
71 |
72 | # Perform nearest-neighbor lookup in embedding space and retrieve the
73 | # frames associated with those embeddings.
74 | cand_frames = [_reshape(frames[seq_idxs[0]])[:, -1]]
75 | for cand_idx in seq_idxs[1:]:
76 | dists = cdist(embs[seq_idxs[0]], embs[cand_idx], self.distance)
77 | nn_ids = np.argmin(dists, axis=1)
78 | c_frames = _reshape(frames[cand_idx])
79 | nn_frames = [c_frames[idx, -1] for idx in nn_ids]
80 | cand_frames.append(np.stack(nn_frames))
81 |
82 | video = np.stack(cand_frames)
83 | return EvaluatorOutput(video=video)
84 |
--------------------------------------------------------------------------------
/graphirl/evaluators/reconstruction_visualizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Frame reconstruction visualizer."""
17 |
18 | from typing import List
19 |
20 | from .base import Evaluator
21 | from .base import EvaluatorOutput
22 | import numpy as np
23 | import torch
24 | import torch.nn.functional as F
25 | from torchvision.utils import make_grid
26 | from xirl.models import SelfSupervisedReconOutput
27 |
28 |
29 | class ReconstructionVisualizer(Evaluator):
30 | """Frame reconstruction visualizer."""
31 |
32 | def __init__(self, num_frames, num_ctx_frames):
33 | """Constructor.
34 |
35 | Args:
36 | num_frames: The number of reconstructed frames in a sequence to display.
37 | num_ctx_frames: The number of context frames stacked together for each
38 | individual video frame.
39 | """
40 | super().__init__(inter_class=False)
41 |
42 | self.num_frames = num_frames
43 | self.num_ctx_frames = num_ctx_frames
44 |
45 | def evaluate(self, outs):
46 | """Plot a frame along with its reconstruction."""
47 |
48 | def _remove_ctx_frames(frame):
49 | s, h, w, c = frame.shape
50 | seq_len = s // self.num_ctx_frames
51 | frame = frame.reshape(seq_len, self.num_ctx_frames, h, w, c)
52 | return frame[:, -1]
53 |
54 | frames = [o.frames for o in outs]
55 | recons = [o.reconstruction for o in outs]
56 |
57 | r_idx = np.random.randint(0, len(frames))
58 | frame = _remove_ctx_frames(frames[r_idx])
59 | recon = _remove_ctx_frames(recons[r_idx])
60 |
61 | # Select which frames we want to plot from the sequence.
62 | frame_idxs = np.random.choice(
63 | np.arange(frame.shape[0]), size=self.num_frames, replace=False)
64 | frame = frame[frame_idxs]
65 | recon = recon[frame_idxs]
66 |
67 | # Downsample the frame.
68 | _, _, sh, _ = recon.shape
69 | _, _, h, _ = frame.shape
70 | scale_factor = sh / h
71 | frame_ds = F.interpolate(
72 | torch.from_numpy(frame).permute(0, 3, 1, 2),
73 | mode='bilinear',
74 | scale_factor=scale_factor,
75 | recompute_scale_factor=False,
76 | align_corners=True).permute(0, 2, 3, 1).numpy()
77 |
78 | # Clip reconstruction between 0 and 1.
79 | recon = np.clip(recon, 0.0, 1.0)
80 |
81 | imgs = np.concatenate([frame_ds, recon], axis=0)
82 | img = make_grid(torch.from_numpy(imgs).permute(0, 3, 1, 2), nrow=2)
83 |
84 | return EvaluatorOutput(image=img.permute(1, 2, 0).numpy())
85 |
--------------------------------------------------------------------------------
/graphirl/evaluators/reward_visualizer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Reward visualizer."""
17 |
18 | from typing import List
19 |
20 | from .base import Evaluator
21 | from .base import EvaluatorOutput
22 | import matplotlib.pyplot as plt
23 | import numpy as np
24 | from scipy.spatial.distance import cdist
25 | from xirl.models import SelfSupervisedOutput
26 |
27 |
28 | class RewardVisualizer(Evaluator):
29 | """Distance to goal state visualizer."""
30 |
31 | def __init__(self, distance, num_plots):
32 | """Constructor.
33 |
34 | Args:
35 | distance: The distance metric to use when calculating nearest-neighbours.
36 | num_plots: The number of reward plots to display.
37 |
38 | Raises:
39 | ValueError: If the distance metric is invalid.
40 | """
41 | super().__init__(inter_class=False)
42 |
43 | if distance not in ["sqeuclidean", "cosine"]:
44 | raise ValueError(
45 | "{} is not a supported distance metric.".format(distance))
46 |
47 | self.distance = distance
48 | self.num_plots = num_plots
49 |
50 | def _gen_reward_plot(self, rewards):
51 | """Create a pyplot plot and save to buffer."""
52 | fig, axes = plt.subplots(1, len(rewards), figsize=(6.4 * len(rewards), 4.8))
53 | if len(rewards) == 1:
54 | axes = [axes]
55 | for i, rew in enumerate(rewards):
56 | axes[i].plot(rew)
57 | fig.text(0.5, 0.04, "Timestep", ha="center")
58 | fig.text(0.04, 0.5, "Reward", va="center", rotation="vertical")
59 | fig.canvas.draw()
60 | img_arr = np.array(fig.canvas.renderer.buffer_rgba())[:, :, :3]
61 | plt.close()
62 | return img_arr
63 |
64 | def _compute_goal_emb(self, embs):
65 | """Compute the mean of all last frame embeddings."""
66 | goal_emb = [emb[-1, :] for emb in embs]
67 | goal_emb = np.stack(goal_emb, axis=0)
68 | goal_emb = np.mean(goal_emb, axis=0, keepdims=True)
69 | return goal_emb
70 |
71 | def evaluate(self, outs):
72 | embs = [o.embs for o in outs]
73 | goal_emb = self._compute_goal_emb(embs)
74 |
75 | # Make sure we sample only as many as are available.
76 | num_plots = min(len(embs), self.num_plots)
77 | rand_idxs = np.random.choice(
78 | np.arange(len(embs)), size=num_plots, replace=False)
79 |
80 | # Compute rewards as distances to the goal embedding.
81 | rewards = []
82 | for idx in rand_idxs:
83 | emb = embs[idx]
84 | dists = cdist(emb, goal_emb, self.distance)
85 | rewards.append(-dists)
86 |
87 | image = self._gen_reward_plot(rewards)
88 | return EvaluatorOutput(image=image)
89 |
--------------------------------------------------------------------------------
/graphirl/tensorizers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Tensorizers convert a packet of video data into a packet of video tensors."""
17 |
18 | import abc
19 | from typing import Any, Dict, Union
20 |
21 | import numpy as np
22 | import torch
23 | import torchvision.transforms.functional as TF
24 |
25 | from graphirl.types import SequenceType
26 |
27 | DataArrayPacket = Dict[SequenceType, Union[np.ndarray, str, int]]
28 | DataTensorPacket = Dict[SequenceType, Union[torch.Tensor, str]]
29 |
30 |
31 | class Tensorizer(abc.ABC):
32 | """Base tensorizer class.
33 |
34 | Custom tensorizers must subclass this class.
35 | """
36 |
37 | @abc.abstractmethod
38 | def __call__(self, x):
39 | pass
40 |
41 |
42 | class IdentityTensorizer(Tensorizer):
43 | """Outputs the input as is."""
44 |
45 | def __call__(self, x):
46 | return x
47 |
48 |
49 | class LongTensorizer(Tensorizer):
50 | """Converts the input to a LongTensor."""
51 |
52 | def __call__(self, x):
53 | return torch.from_numpy(np.asarray(x)).long()
54 |
55 | class FloatTensorizer(Tensorizer):
56 |   """Converts the input to a FloatTensor."""
56 |
57 | def __call__(self, x):
58 | return torch.from_numpy(np.asarray(x)).float()
59 |
60 | class FramesTensorizer(Tensorizer):
61 | """Converts a sequence of video frames to a batched FloatTensor."""
62 |
63 | def __call__(self, x):
64 | assert x.ndim == 4, "Input must be a 4D sequence of frames."
65 | frames = []
66 | for frame in x:
67 | frames.append(TF.to_tensor(frame))
68 | return torch.stack(frames, dim=0)
69 |
70 |
71 | class ToTensor:
72 | """Convert video data to video tensors."""
73 |
74 | MAP = {
75 | SequenceType.FRAMES: FramesTensorizer,
76 | SequenceType.FRAME_IDXS: LongTensorizer,
77 | SequenceType.VIDEO_NAME: IdentityTensorizer,
78 | SequenceType.VIDEO_LEN: LongTensorizer,
79 | }
80 |
81 | def __call__(self, data):
82 | """Iterate and transform the data values.
83 |
84 | Args:
85 | data: A dictionary containing key, value pairs where the key is an enum
86 | member of `SequenceType` and the value is either an int, a string or an
87 | ndarray respecting the key type.
88 |
89 | Raises:
90 | ValueError: If the input is not a dictionary or one of its keys is
91 | not a supported sequence type.
92 |
93 | Returns:
94 | The dictionary with the values tensorized.
95 | """
96 | tensors = {}
97 | for key, np_arr in data.items():
98 | tensors[key] = ToTensor.MAP[key]()(np_arr)
99 | return tensors
100 |
101 |
102 |
103 | class ToTensor_bbox:
104 | """Convert video data to bbox tensors."""
105 |
106 | MAP = {
107 | SequenceType.FRAMES: FloatTensorizer,
108 | SequenceType.FRAME_IDXS: LongTensorizer,
109 | SequenceType.VIDEO_NAME: IdentityTensorizer,
110 | SequenceType.VIDEO_LEN: LongTensorizer,
111 | }
112 |
113 | def __call__(self, data):
114 | """Iterate and transform the data values.
115 |
116 | Args:
117 | data: A dictionary containing key, value pairs where the key is an enum
118 | member of `SequenceType` and the value is either an int, a string or an
119 | ndarray respecting the key type.
120 |
121 | Raises:
122 | ValueError: If the input is not a dictionary or one of its keys is
123 | not a supported sequence type.
124 |
125 | Returns:
126 | The dictionary with the values tensorized.
127 | """
128 | tensors = {}
129 | for key, np_arr in data.items():
130 | tensors[key] = ToTensor_bbox.MAP[key]()(np_arr)
131 | return tensors
132 |
--------------------------------------------------------------------------------
/graphirl/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Trainers."""
17 |
18 | from .base import Trainer
19 | from .classification import GoalFrameClassifierTrainer
20 | from .lifs import LIFSTrainer
21 | from .tcc import TCCTrainer
22 | from .tcn import TCNTrainer
23 |
24 | __all__ = [
25 | "Trainer",
26 | "TCCTrainer",
27 | "TCNTrainer",
28 | "LIFSTrainer",
29 | "GoalFrameClassifierTrainer",
30 | ]
31 |
--------------------------------------------------------------------------------
/graphirl/trainers/base.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Base class for defining training algorithms."""
17 |
18 | import abc
19 | from typing import Dict, List, Optional, Union
20 | from ml_collections.config_dict.config_dict import ConfigDict
21 |
22 | import torch
23 | import torch.nn as nn
24 |
25 | from graphirl.models import SelfSupervisedOutput, SelfSupervisedOutputBbox
26 |
27 | BatchType = Dict[str, Union[torch.Tensor, List[str]]]
28 |
29 |
30 | class Trainer(abc.ABC):
31 | """Base trainer abstraction.
32 |
33 | Subclasses should override `compute_loss`.
34 | """
35 |
36 | def __init__(
37 | self,
38 | model,
39 | optimizer,
40 | device,
41 | config,
42 | ):
43 | """Constructor.
44 |
45 | Args:
46 | model: The model to train.
47 | optimizer: The optimizer to use.
48 | device: The computing device.
49 | config: The config dict.
50 | """
51 | self._model = model
52 | self._optimizer = optimizer
53 | self._device = device
54 | self._config = config
55 |
56 | self._model.to(self._device).train()
57 |
58 | @abc.abstractmethod
59 | def compute_loss(
60 | self,
61 | embs,
62 | batch,
63 | ):
64 | """Compute the loss on a single batch.
65 |
66 | Args:
67 | embs: The output of the embedding network.
68 | batch: The output of a VideoDataset dataloader.
69 |
70 | Returns:
71 | A tensor corresponding to the value of the loss function evaluated
72 | on the given batch.
73 | """
74 | pass
75 |
76 | def compute_auxiliary_loss(
77 | self,
78 | out, # pylint: disable=unused-argument
79 | batch, # pylint: disable=unused-argument
80 | ):
81 | """Compute an auxiliary loss on a single batch.
82 |
83 | Args:
84 | out: The output of the self-supervised model.
85 | batch: The output of a VideoDataset dataloader.
86 |
87 | Returns:
88 | A tensor corresponding to the value of the auxiliary loss function
89 | evaluated on the given batch.
90 | """
91 | return 0.0
92 |
93 | def train_one_iter(self, batch):
94 | """Single forward + backward pass of the model.
95 |
96 | Args:
97 | batch: The output of a VideoDataset dataloader.
98 |
99 | Returns:
100 | A dict of loss values.
101 | """
102 | self._model.train()
103 |
104 | self._optimizer.zero_grad()
105 |
106 | # Forward pass to compute embeddings.
107 | frames = batch["frames"].to(self._device)
108 | out = self._model(frames)
109 |
110 | # Compute losses.
111 | loss = self.compute_loss(out.embs, batch)
112 | aux_loss = self.compute_auxiliary_loss(out, batch)
113 | total_loss = loss + aux_loss
114 |
115 | # Backwards pass + optimization step.
116 | total_loss.backward()
117 | self._optimizer.step()
118 |
119 | return {
120 | "train/base_loss": loss,
121 | "train/auxiliary_loss": aux_loss,
122 | "train/total_loss": total_loss,
123 | }
124 |
125 | @torch.no_grad()
126 | def eval_num_iters(
127 | self,
128 | valid_loader,
129 | eval_iters = None,
130 | ):
131 | """Compute the loss with the model in `eval()` mode.
132 |
133 | Args:
134 | valid_loader: The validation data loader.
135 |       eval_iters: The number of times to call `next()` on the data iterator. Set
136 | to None to evaluate on the whole validation set.
137 |
138 | Returns:
139 | A dict of validation losses.
140 | """
141 | self._model.eval()
142 |
143 | val_base_loss = 0.0
144 | val_aux_loss = 0.0
145 | it_ = 0
146 | for batch_idx, batch in enumerate(valid_loader):
147 | if eval_iters is not None and batch_idx >= eval_iters:
148 | break
149 |
150 | frames = batch["frames"].to(self._device)
151 | out = self._model(frames)
152 |
153 | val_base_loss += self.compute_loss(out.embs, batch)
154 | val_aux_loss += self.compute_auxiliary_loss(out, batch)
155 | it_ += 1
156 | val_base_loss /= it_
157 | val_aux_loss /= it_
158 |
159 | return {
160 | "valid/base_loss": val_base_loss,
161 | "valid/auxiliary_loss": val_aux_loss,
162 | "valid/total_loss": val_base_loss + val_aux_loss,
163 | }
164 |
--------------------------------------------------------------------------------
/graphirl/trainers/classification.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Goal classifier trainer."""
17 |
18 | from typing import Dict, List, Union
19 |
20 | import torch
21 | import torch.nn.functional as F
22 | from graphirl.trainers.base import Trainer
23 |
24 | BatchType = Dict[str, Union[torch.Tensor, List[str]]]
25 |
26 |
27 | class GoalFrameClassifierTrainer(Trainer):
28 |   """A trainer that learns to classify whether an image is a goal frame.
29 | 
30 |   This should be used in conjunction with the LastFrameAndRandomFrames frame
31 |   sampler, which ensures each frame sequence in the batch consists of one
32 |   goal frame followed by N - 1 random other frames.
33 | """
34 |
35 | def compute_loss(
36 | self,
37 | embs,
38 | batch,
39 | ):
40 | del batch
41 |
42 | batch_size, num_cc_frames, _ = embs.shape
43 |
44 | # Create the labels tensor.
45 | row_tensor = torch.FloatTensor([1] + [0] * (num_cc_frames - 1))
46 | label_tensor = row_tensor.unsqueeze(0).repeat(batch_size, 1)
47 | label_tensor = label_tensor.to(self._device)
48 |
49 | return F.binary_cross_entropy_with_logits(
50 | embs.view(batch_size * num_cc_frames),
51 | label_tensor.view(batch_size * num_cc_frames),
52 | )
53 |
--------------------------------------------------------------------------------
/graphirl/trainers/lifs.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """LIFS trainer."""
17 |
18 | from typing import Dict, List, Union
19 |
20 | from ml_collections import ConfigDict
21 | import torch
22 | import torch.nn as nn
23 | import torch.nn.functional as F
24 | from graphirl.models import SelfSupervisedReconOutput
25 | from graphirl.trainers.base import Trainer
26 |
27 | BatchType = Dict[str, Union[torch.Tensor, List[str]]]
28 |
29 |
30 | class LIFSTrainer(Trainer):
31 | """A trainer that implements LIFS from [1].
32 |
33 | This should be used in conjunction with the VariableStridedSampler frame
34 | sampler, which assumes rough alignment between pairs of sequences and hence
35 | a time index can be used to correspond frames across sequences.
36 |
37 | Note that the authors of [1] do not implement a negative term in the
38 | contrastive loss. It is just a similarity (l2) loss with an autoencoding
39 | loss to prevent the embeddings from collapsing to trivial constants.
40 |
41 | References:
42 | [1]: https://arxiv.org/abs/1703.02949
43 | """
44 |
45 | def __init__(
46 | self,
47 | model,
48 | optimizer,
49 | device,
50 | config,
51 | ):
52 | super().__init__(model, optimizer, device, config)
53 |
54 | self.temperature = config.LOSS.LIFS.TEMPERATURE
55 |
56 | def compute_auxiliary_loss(
57 | self,
58 | out,
59 | batch,
60 | ):
61 | reconstruction = out.reconstruction
62 | frames = batch["frames"].to(self._device)
63 | b, t, _, _, _ = reconstruction.shape
64 | reconstruction = reconstruction.view((b * t, *reconstruction.shape[2:]))
65 | frames = frames.view((b * t, *frames.shape[2:]))
66 | _, _, sh, _ = reconstruction.shape
67 | _, _, h, _ = frames.shape
68 | scale_factor = sh / h
69 | frames_ds = F.interpolate(
70 | frames,
71 | mode="bilinear",
72 | scale_factor=scale_factor,
73 | recompute_scale_factor=False,
74 | align_corners=True,
75 | )
76 | return F.mse_loss(reconstruction, frames_ds)
77 |
78 | def compute_loss(
79 | self,
80 | embs,
81 | batch,
82 | ):
83 | del batch
84 |
85 | batch_size, num_cc_frames, num_dims = embs.shape
86 |
87 | # Compute pairwise squared L2 distances between embeddings.
88 | embs_flat = embs.view(-1, num_dims)
89 | distances = torch.cdist(embs_flat, embs_flat).pow(2)
90 | distances = distances / self.temperature
91 |
92 | # Each row in a batch corresponds to a frame sequence. Since this
93 | # baseline assumes rough alignment between sequences, we want columns,
94 | # i.e. frames in each row that belong to the same index to be close
95 | # together in embedding space.
96 | labels = torch.arange(num_cc_frames).unsqueeze(0).repeat(batch_size, 1)
97 | labels = labels.to(self._device)
98 | mask = labels.flatten()[:, None] == labels.flatten()[None, :]
99 | return (distances * mask.float()).sum(dim=-1).mean()
100 |
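# --- Illustrative sketch (not part of the original file). A toy example of the
# time-alignment mask built in compute_loss above: with 2 sequences of 3 frames
# each, only pairs of frames sharing the same time index are marked True, so
# only their pairwise distances contribute to the loss.
if __name__ == "__main__":
  batch_size, num_cc_frames = 2, 3
  labels = torch.arange(num_cc_frames).unsqueeze(0).repeat(batch_size, 1)
  mask = labels.flatten()[:, None] == labels.flatten()[None, :]
  print(mask.int())  # 6x6 pattern: True wherever the time indices match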
--------------------------------------------------------------------------------
/graphirl/trainers/tcc.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """TCC trainer."""
17 |
18 | from typing import Dict, List, Union
19 |
20 | from ml_collections import ConfigDict
21 | import torch
22 | from graphirl.losses import compute_tcc_loss
23 | from graphirl.trainers.base import Trainer
24 |
25 | BatchType = Dict[str, Union[torch.Tensor, List[str]]]
26 |
27 |
28 | class TCCTrainer(Trainer):
29 | """A trainer for Temporal Cycle Consistency Learning [1].
30 |
31 | References:
32 | [1]: arxiv.org/abs/1904.07846
33 | """
34 |
35 | def __init__(
36 | self,
37 | model,
38 | optimizer,
39 | device,
40 | config,
41 | ):
42 | super().__init__(model, optimizer, device, config)
43 |
44 | self.normalize_embeddings = config.MODEL.NORMALIZE_EMBEDDINGS
45 | self.stochastic_matching = config.LOSS.TCC.STOCHASTIC_MATCHING
46 | self.loss_type = config.LOSS.TCC.LOSS_TYPE
47 | self.similarity_type = config.LOSS.TCC.SIMILARITY_TYPE
48 | self.cycle_length = config.LOSS.TCC.CYCLE_LENGTH
49 | self.temperature = config.LOSS.TCC.SOFTMAX_TEMPERATURE
50 | self.label_smoothing = config.LOSS.TCC.LABEL_SMOOTHING
51 | self.variance_lambda = config.LOSS.TCC.VARIANCE_LAMBDA
52 | self.huber_delta = config.LOSS.TCC.HUBER_DELTA
53 | self.normalize_indices = config.LOSS.TCC.NORMALIZE_INDICES
54 |
55 | def compute_loss(
56 | self,
57 | embs,
58 | batch,
59 | ):
60 | steps = batch["frame_idxs"].to(self._device)
61 | seq_lens = batch["video_len"].to(self._device)
62 |
63 | # Dynamically determine the number of cycles if using stochastic
64 | # matching.
65 | batch_size, num_cc_frames = embs.shape[:2]
66 | num_cycles = int(batch_size * num_cc_frames)
67 |
68 | return compute_tcc_loss(
69 | embs=embs,
70 | idxs=steps,
71 | seq_lens=seq_lens,
72 | stochastic_matching=self.stochastic_matching,
73 | normalize_embeddings=self.normalize_embeddings,
74 | loss_type=self.loss_type,
75 | similarity_type=self.similarity_type,
76 | num_cycles=num_cycles,
77 | cycle_length=self.cycle_length,
78 | temperature=self.temperature,
79 | label_smoothing=self.label_smoothing,
80 | variance_lambda=self.variance_lambda,
81 | huber_delta=self.huber_delta,
82 | normalize_indices=self.normalize_indices,
83 | )
84 |
--------------------------------------------------------------------------------
/graphirl/trainers/tcn.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """TCN trainer."""
17 |
18 | from typing import Dict, List, Union
19 |
20 | import numpy as np
21 | import torch
22 | from graphirl.trainers.base import Trainer
23 |
24 | BatchType = Dict[str, Union[torch.Tensor, List[str]]]
25 |
26 |
27 | class TCNTrainer(Trainer):
28 | """A trainer that implements a single-view Time Contrastive Network [1].
29 |
30 | This should be used in conjunction with the WindowSampler frame sampler.
31 |
32 | References:
33 | [1]: https://arxiv.org/abs/1704.06888
34 | """
35 |
36 | def __init__(
37 | self,
38 | model,
39 | optimizer,
40 | device,
41 | config,
42 | ):
43 | super().__init__(model, optimizer, device, config)
44 |
45 | self.temperature = config.LOSS.TCN.TEMPERATURE
46 | self.num_pairs = config.LOSS.TCN.NUM_PAIRS
47 | self.pos_radius = config.LOSS.TCN.POS_RADIUS
48 | self.neg_radius = config.LOSS.TCN.NEG_RADIUS
49 |
50 | def compute_loss(
51 | self,
52 | embs,
53 | batch,
54 | ):
55 | del batch
56 |
57 | batch_size, num_cc_frames, _ = embs.shape
58 |
59 | # A positive is centered at the current index and can extend pos_radius
60 | # forward or backward.
61 | # A negative can be sampled anywhere outside a negative radius centered
62 | # around the current index.
63 | batch_pos = []
64 | batch_neg = []
65 | idxs = np.arange(num_cc_frames)
66 | for i in range(batch_size):
67 | pos_delta = np.random.choice(
68 | [-self.pos_radius, self.pos_radius],
69 | size=(num_cc_frames, self.num_pairs),
70 | )
71 | pos_idxs = np.clip(idxs[:, None] + pos_delta, 0, num_cc_frames - 1)
72 | batch_pos.append(torch.LongTensor(pos_idxs))
73 |
74 | negatives = []
75 | for idx in idxs:
76 | allowed = (idxs > (idx + self.neg_radius)) | (
77 | idxs < (idx - self.neg_radius))
78 | neg_idxs = np.random.choice(idxs[allowed], size=self.num_pairs)
79 | negatives.append(neg_idxs)
80 | batch_neg.append(torch.LongTensor(np.vstack(negatives)))
81 |
82 | pos_losses = 0.0
83 | neg_losses = 0.0
84 | for i, (positives, negatives) in enumerate(zip(batch_pos, batch_neg)):
85 | row_idx = torch.arange(num_cc_frames).unsqueeze(1)
86 |
87 | # Compute pairwise squared L2 distances between the embeddings in a
88 | # sequence.
89 | emb_seq = embs[i]
90 | distances = torch.cdist(emb_seq, emb_seq).pow(2)
91 | distances = distances / self.temperature
92 |
93 | # For every embedding in the sequence, we need to minimize its
94 | # distance to every positive we sampled for it.
95 | pos_loss = distances[row_idx, positives]
96 | pos_losses += pos_loss.sum()
97 |
98 | # And for negatives, we need to ensure they are at least a distance
99 | # M apart. We use the squared hinge loss to express this constraint.
100 | neg_margin = 1 - distances[row_idx, negatives]
101 | neg_loss = torch.clamp(neg_margin, min=0).pow(2)
102 | neg_losses += neg_loss.sum()
103 |
104 | total_loss = (pos_losses + neg_losses) / (batch_size * num_cc_frames)
105 | return total_loss
106 |
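# --- Illustrative sketch (not part of the original file). Reproduces the
# positive/negative index sampling in compute_loss above for a single toy
# sequence so the windows are easy to inspect; the radii and sequence length
# below are illustrative values only.
if __name__ == "__main__":
  num_cc_frames, num_pairs, pos_radius, neg_radius = 8, 2, 1, 3
  idxs = np.arange(num_cc_frames)
  pos_delta = np.random.choice([-pos_radius, pos_radius],
                               size=(num_cc_frames, num_pairs))
  pos_idxs = np.clip(idxs[:, None] + pos_delta, 0, num_cc_frames - 1)
  negatives = []
  for idx in idxs:
    allowed = (idxs > (idx + neg_radius)) | (idxs < (idx - neg_radius))
    negatives.append(np.random.choice(idxs[allowed], size=num_pairs))
  print(pos_idxs)              # positives stay within +/- pos_radius
  print(np.vstack(negatives))  # negatives lie outside the neg_radius window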
--------------------------------------------------------------------------------
/graphirl/transforms.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Transformations for video data."""
17 |
18 | import enum
19 | from typing import Callable, Dict, Mapping, Sequence, Tuple
20 | import warnings
21 |
22 | import albumentations as alb
23 | import numpy as np
24 | import torch
25 | from graphirl.types import SequenceType
26 |
27 |
28 | @enum.unique
29 | class PretrainedMeans(enum.Enum):
30 | """Pretrained mean normalization values."""
31 |
32 | IMAGENET = (0.485, 0.456, 0.406)
33 |
34 |
35 | @enum.unique
36 | class PretrainedStds(enum.Enum):
37 | """Pretrained std deviation normalization values."""
38 |
39 | IMAGENET = (0.229, 0.224, 0.225)
40 |
41 |
42 | class UnNormalize:
43 | """Unnormalize a batch of images that have been normalized.
44 |
45 |   Specifically, re-multiply by the standard deviation and add back the mean.
46 | """
47 |
48 | def __init__(
49 | self,
50 | mean,
51 | std,
52 | ):
53 | """Constructor.
54 |
55 | Args:
56 | mean: The color channel means.
57 | std: The color channel standard deviation.
58 | """
59 | if np.asarray(mean).shape:
60 | self.mean = torch.tensor(mean)[Ellipsis, :, None, None]
61 | if np.asarray(std).shape:
62 | self.std = torch.tensor(std)[Ellipsis, :, None, None]
63 |
64 | def __call__(self, tensor):
65 | return (tensor * self.std) + self.mean
66 |
67 |
68 | def augment_video(
69 | frames,
70 | pipeline, # pylint: disable=g-bare-generic
71 | ):
72 | """Apply the same augmentation pipeline to all frames in a video.
73 |
74 | Args:
75 | frames: A numpy array of shape (T, H, W, 3), where T is the number of frames
76 | in the video.
77 | pipeline (list): A list containing albumentation augmentations.
78 |
79 | Returns:
80 | The augmented frames of shape (T, H, W, 3).
81 |
82 | Raises:
83 | ValueError: If the input video doesn't have the correct shape.
84 | """
85 | if frames.ndim != 4:
86 | raise ValueError("Input video must be a 4D sequence of frames.")
87 |
88 | transform = alb.ReplayCompose(pipeline, p=1.0)
89 |
90 | # Apply a transformation to the first frame and record the parameters
91 | # that were sampled in a replay, then use the parameters stored in the
92 | # replay to apply an identical transform to the remaining frames in the
93 | # sequence.
94 | with warnings.catch_warnings():
95 |     # This suppresses albumentations' warning related to ReplayCompose.
96 | warnings.simplefilter("ignore")
97 |
98 | replay, frames_aug = None, []
99 | for frame in frames:
100 | if replay is None:
101 | aug = transform(image=frame)
102 | replay = aug.pop("replay")
103 | else:
104 | aug = transform.replay(replay, image=frame)
105 | frames_aug.append(aug["image"])
106 |
107 | return np.stack(frames_aug, axis=0)
108 |
109 |
110 | class VideoAugmentor:
111 | """Data augmentation for videos.
112 |
113 | Augmentor consistently augments data across the time dimension (i.e. dim 0).
114 | In other words, the same transformation is applied to every single frame in
115 | a video sequence.
116 |
117 | Currently, only image frames, i.e. SequenceType.FRAMES in a video can be
118 | augmented.
119 | """
120 |
121 | MAP = {
122 | SequenceType.FRAMES: augment_video,
123 | }
124 |
125 | def __init__(
126 | self,
127 | params, # pylint: disable=g-bare-generic
128 | ):
129 | """Constructor.
130 |
131 | Args:
132 |       params: A dict mapping a sequence type to the list of augmentations to
133 |         apply to values of that type.
133 |
134 | Raises:
135 | ValueError: If params contains an unsupported data augmentation.
136 | """
137 | for key in params.keys():
138 | if key not in SequenceType:
139 | raise ValueError(f"{key} is not a supported SequenceType.")
140 | self._params = params
141 |
142 | def __call__(
143 | self,
144 | data,
145 | ):
146 | """Iterate and transform the data values.
147 |
148 | Currently, data augmentation is only applied to video frames, i.e. the
149 |     value of the data dict associated with the SequenceType.FRAMES key.
150 |
151 | Args:
152 | data: A dict mapping from sequence type to sequence value.
153 |
154 | Returns:
155 |       An augmented dict.
156 | """
157 | for key, transforms in self._params.items():
158 | data[key] = VideoAugmentor.MAP[key](data[key], transforms)
159 | return data
160 |
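# --- Illustrative sketch (not part of the original file). A minimal usage
# example of VideoAugmentor: the same randomly sampled horizontal flip is
# applied to every frame of a toy 4-frame video. The single-transform pipeline
# below is just an assumption for illustration; any list of albumentations
# transforms works.
if __name__ == "__main__":
  augmentor = VideoAugmentor({
      SequenceType.FRAMES: [alb.HorizontalFlip(p=0.5)],
  })
  video = np.random.randint(0, 256, size=(4, 64, 64, 3), dtype=np.uint8)
  out = augmentor({SequenceType.FRAMES: video})
  print(out[SequenceType.FRAMES].shape)  # (4, 64, 64, 3), consistently flipped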
--------------------------------------------------------------------------------
/graphirl/types.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Types shared across modules."""
17 |
18 | import enum
19 |
20 |
21 | @enum.unique
22 | class SequenceType(enum.Enum):
23 | """Sequence data types we know how to preprocess.
24 |
25 | If you need to preprocess additional video data, you must add it here.
26 | """
27 |
28 | FRAMES = "frames"
29 | FRAME_IDXS = "frame_idxs"
30 | VIDEO_NAME = "video_name"
31 | VIDEO_LEN = "video_len"
32 |
33 | def __str__(self): # pylint: disable=invalid-str-returned
34 | return self.value
35 |
36 |
37 | @enum.unique
38 | class ImageTransformationType(enum.Enum):
39 | """Transformations we know how to run on images.
40 |
41 | If you want to add additional augmentations, you must add them here.
42 | """
43 |
44 | RANDOM_RESIZED_CROP = "random_resized_crop"
45 | CENTER_CROP = "center_crop"
46 | GLOBAL_RESIZE = "global_resize"
47 | VERTICAL_FLIP = "vertical_flip"
48 | HORIZONTAL_FLIP = "horizontal_flip"
49 | COLOR_JITTER = "color_jitter"
50 | ROTATE = "rotate"
51 | DROPOUT = "dropout"
52 | NORMALIZE = "normalize"
53 | UNNORMALIZE = "unnormalize"
54 |
55 | def __str__(self): # pylint: disable=invalid-str-returned
56 | return self.value
57 |
--------------------------------------------------------------------------------
/graphirl/video_samplers.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2021 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | """Video samplers for mini-batch creation."""
17 |
18 | import abc
19 | from typing import Dict, Iterator, List, Tuple
20 |
21 | import numpy as np
22 | import torch
23 | from torch.utils.data import Sampler
24 |
25 | ClassIdxVideoIdx = Tuple[int, int]
26 | DirTreeIndices = List[List[ClassIdxVideoIdx]]
27 | VideoBatchIter = Iterator[List[ClassIdxVideoIdx]]
28 |
29 |
30 | class VideoBatchSampler(abc.ABC, Sampler):
31 | """Base class for all video samplers."""
32 |
33 | def __init__(
34 | self,
35 | dir_tree,
36 | batch_size,
37 | sequential = False,
38 | ):
39 | """Constructor.
40 |
41 | Args:
42 | dir_tree: The directory tree of a `datasets.VideoDataset`.
43 | batch_size: The number of videos in a batch.
44 | sequential: Set to `True` to disable any shuffling or randomness.
45 | """
46 | assert isinstance(batch_size, int)
47 |
48 | self._batch_size = batch_size
49 | self._dir_tree = dir_tree
50 | self._sequential = sequential
51 |
52 | @abc.abstractmethod
53 | def _generate_indices(self):
54 | """Generate batch chunks containing (class idx, video_idx) tuples."""
55 | pass
56 |
57 | def __iter__(self):
58 | idxs = self._generate_indices()
59 | if self._sequential:
60 | return iter(idxs)
61 | return iter(idxs[i] for i in torch.randperm(len(idxs)))
62 |
63 | def __len__(self):
64 | num_vids = 0
65 | for vids in self._dir_tree.values():
66 | num_vids += len(vids)
67 | return num_vids // self.batch_size
68 |
69 | @property
70 | def batch_size(self):
71 | return self._batch_size
72 |
73 | @property
74 | def dir_tree(self):
75 | return self._dir_tree
76 |
77 |
78 | class RandomBatchSampler(VideoBatchSampler):
79 | """Randomly samples videos from different classes into the same batch.
80 |
81 | Note the `sequential` arg is disabled here.
82 | """
83 |
84 | def _generate_indices(self):
85 | # Generate a list of video indices for every class.
86 | all_idxs = []
87 | for k, v in enumerate(self._dir_tree.values()):
88 | seq = list(range(len(v)))
89 | all_idxs.extend([(k, s) for s in seq])
90 | # Shuffle the indices.
91 | all_idxs = [all_idxs[i] for i in torch.randperm(len(all_idxs))]
92 |     # If we have fewer total videos than the batch size, we pad with clones
93 | # until we reach a length of batch_size.
94 | if len(all_idxs) < self._batch_size:
95 | while len(all_idxs) < self._batch_size:
96 | all_idxs.append(all_idxs[np.random.randint(0, len(all_idxs))])
97 | # Split the list of indices into chunks of len `batch_size`.
98 | idxs = []
99 | end = self._batch_size * (len(all_idxs) // self._batch_size)
100 | for i in range(0, end, self._batch_size):
101 | batch_idxs = all_idxs[i:i + self._batch_size]
102 | idxs.append(batch_idxs)
103 | return idxs
104 |
105 |
106 | class SameClassBatchSampler(VideoBatchSampler):
107 | """Ensures all videos in a batch belong to the same class."""
108 |
109 | def _generate_indices(self):
110 | idxs = []
111 | for k, v in enumerate(self._dir_tree.values()):
112 | # Generate a list of indices for every video in the class.
113 | len_v = len(v)
114 | seq = list(range(len_v))
115 | if not self._sequential:
116 | seq = [seq[i] for i in torch.randperm(len(seq))]
117 | # Split the list of indices into chunks of len `batch_size`,
118 | # ensuring we drop the last chunk if it is not of adequate length.
119 | batch_idxs = []
120 | end = self._batch_size * (len_v // self._batch_size)
121 | for i in range(0, end, self._batch_size):
122 | xs = seq[i:i + self._batch_size]
123 | # Add the class index to the video index.
124 | xs = [(k, x) for x in xs]
125 | batch_idxs.append(xs)
126 | idxs.extend(batch_idxs)
127 | return idxs
128 |
129 |
130 | class SameClassBatchSamplerDownstream(SameClassBatchSampler):
131 | """A same class batch sampler with a batch size of 1.
132 |
133 | This batch sampler is used for downstream datasets. Since such datasets
134 | typically load a variable number of frames per video, we are forced to use
135 | a batch size of 1.
136 | """
137 |
138 | def __init__(
139 | self,
140 | dir_tree,
141 | sequential = False,
142 | ):
143 | super().__init__(dir_tree, batch_size=1, sequential=sequential)
144 |
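# --- Illustrative sketch (not part of the original file). A toy dir_tree
# (class name -> list of videos) showing how SameClassBatchSampler yields
# batches of (class_idx, video_idx) tuples drawn from a single class. The
# entries are placeholders; only the list lengths matter here.
if __name__ == "__main__":
  dir_tree = {
      "reach": ["vid_0", "vid_1", "vid_2", "vid_3"],
      "push": ["vid_0", "vid_1"],
  }
  sampler = SameClassBatchSampler(dir_tree, batch_size=2, sequential=True)
  for batch in sampler:
    print(batch)  # [(0, 0), (0, 1)], [(0, 2), (0, 3)], [(1, 0), (1, 1)]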
--------------------------------------------------------------------------------
/media/preview.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/media/preview.gif
--------------------------------------------------------------------------------
/media/summary.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/media/summary.png
--------------------------------------------------------------------------------
/scripts/pegbox_helper.sh:
--------------------------------------------------------------------------------
1 | export MUJOCO_GL=egl
2 | python3 src/train.py \
3 | --algorithm ${12} \
4 | --domain_name robot \
5 | --task_name $1 \
6 | --episode_length 50 \
7 | --exp_suffix ${11} \
8 | --eval_mode none \
9 | --save_video \
10 | --eval_freq 10k \
11 | --train_steps 1500k \
12 | --buffer_size 500k \
13 | --save_freq 50k \
14 | --frame_stack 1 \
15 | --log_dir /data/sateesh/logs_pegbox_graphirl \
16 | --eval_episodes 50 \
17 | --image_size 84 \
18 | --render_image_size 84 \
19 | --seed $4 \
20 | --cameras $2 \
21 | --camera_dropout $3 \
22 | --action_space $5 \
23 | --attention $7 \
24 | --num_head_layers $6 \
25 | --concat $8 \
26 | --observation_type ${13} \
27 | --context1 $9 \
28 | --context2 ${10} \
29 | --wandb_project pegbox_graphirl \
30 | --save_video \
31 | --pretrained_path ${15} \
32 | --reward_wrapper gil_pegbox \
33 | --apply_wrapper \
34 | --wandb
--------------------------------------------------------------------------------
/scripts/push_helper.sh:
--------------------------------------------------------------------------------
1 | export MUJOCO_GL=egl
2 | python3 src/train.py \
3 | --algorithm ${12} \
4 | --domain_name robot \
5 | --task_name $1 \
6 | --episode_length 50 \
7 | --exp_suffix ${11} \
8 | --eval_mode none \
9 | --save_video \
10 | --eval_freq 10k \
11 | --train_steps 800k \
12 | --buffer_size 500k \
13 | --save_freq 50k \
14 | --frame_stack 1 \
15 | --log_dir push_graphirl \
16 | --eval_episodes 50 \
17 | --image_size 84 \
18 | --render_image_size 84 \
19 | --seed $4 \
20 | --cameras $2 \
21 | --camera_dropout $3 \
22 | --action_space $5 \
23 | --attention $7 \
24 | --num_head_layers $6 \
25 | --concat $8 \
26 | --observation_type ${13} \
27 | --context1 $9 \
28 | --context2 ${10} \
29 | --wandb_project push_graphirl \
30 | --save_video \
31 | --pretrained_path ${15} \
32 | --reward_wrapper gil \
33 | --apply_wrapper \
34 | --wandb
35 |
--------------------------------------------------------------------------------
/scripts/reach_helper.sh:
--------------------------------------------------------------------------------
1 | export MUJOCO_GL=egl
2 | python3 src/train.py \
3 | --algorithm ${12} \
4 | --domain_name robot \
5 | --task_name $1 \
6 | --episode_length 50 \
7 | --exp_suffix ${11} \
8 | --eval_mode none \
9 | --save_video \
10 | --eval_freq 10k \
11 | --train_steps 300k \
12 | --save_freq 50k \
13 | --frame_stack 1 \
14 | --log_dir /data/sateesh/reach_graphirl_logs \
15 | --eval_episodes 50 \
16 | --image_size 84 \
17 | --render_image_size 224 \
18 | --seed $4 \
19 | --cameras $2 \
20 | --camera_dropout $3 \
21 | --action_space $5 \
22 | --attention $7 \
23 | --num_head_layers $6 \
24 | --concat $8 \
25 | --observation_type ${13} \
26 | --context1 $9 \
27 | --context2 ${10} \
28 | --wandb_project reach_graphirl \
29 | --save_video \
30 | --pretrained_path ${15} \
31 | --reward_wrapper gil_reach \
32 | --apply_wrapper \
33 | --wandb
--------------------------------------------------------------------------------
/scripts/run_pegbox.sh:
--------------------------------------------------------------------------------
1 | sh scripts/pegbox_helper.sh push 2 0 0 xy 3 1 0 1 1 push_graphirl sacv2 statefull+image 0 $1
--------------------------------------------------------------------------------
/scripts/run_push.sh:
--------------------------------------------------------------------------------
1 | sh scripts/push_helper.sh push 2 0 0 xy 3 1 0 1 1 push_graphirl sacv2 statefull+image 0 $1
--------------------------------------------------------------------------------
/scripts/run_reach.sh:
--------------------------------------------------------------------------------
1 | sh scripts/reach_helper.sh reach 2 0 0 xy 3 1 0 1 1 reach_graphirl sacv2 statefull+image 0 $1
--------------------------------------------------------------------------------
/src/algorithms/drq.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from copy import deepcopy
6 | import utils
7 | import algorithms.modules as m
8 | from algorithms.sac import SAC
9 |
10 |
11 | class DrQ(SAC): # [K=1, M=1]
12 | def __init__(self, obs_shape, state_space, action_shape, args):
13 | super().__init__(obs_shape, action_shape, args)
14 |
15 | def update(self, replay_buffer, L, step):
16 | obs, _, action, reward, next_obs, _ = replay_buffer.sample()
17 |
18 | self.update_critic(obs, action, reward, next_obs, L, step)
19 |
20 | if step % self.actor_update_freq == 0:
21 | self.update_actor_and_alpha(obs, L, step)
22 |
23 | if step % self.critic_target_update_freq == 0:
24 | self.soft_update_critic_target()
25 |
--------------------------------------------------------------------------------
/src/algorithms/drq_multiview.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from copy import deepcopy
6 | import utils
7 | import algorithms.modules as m
8 | from algorithms.sac import SAC
9 | from algorithms.multiview import MultiView
10 |
11 |
12 | class DrQMultiView(MultiView): # [K=1, M=1]
13 | def __init__(self, obs_shape, action_shape, args):
14 | super().__init__(obs_shape, action_shape, args)
15 |
16 | def update(self, replay_buffer, L, step):
17 | obs, action, reward, next_obs = replay_buffer.sample_drq()
18 |
19 | self.update_critic(obs, action, reward, next_obs, L, step)
20 |
21 | if step % self.actor_update_freq == 0:
22 | self.update_actor_and_alpha(obs, L, step)
23 |
24 | if step % self.critic_target_update_freq == 0:
25 | self.soft_update_critic_target()
26 |
--------------------------------------------------------------------------------
/src/algorithms/factory.py:
--------------------------------------------------------------------------------
1 | from algorithms.sac import SAC
2 | from algorithms.sacv2 import SACv2
3 | from algorithms.drq import DrQ
4 | from algorithms.sveav2 import SVEAv2
5 | from algorithms.multiview import MultiView
6 | from algorithms.drq_multiview import DrQMultiView
7 | from algorithms.drqv2 import DrQv2
8 | from algorithms.sacv2_3d import SACv2_3D
9 |
10 | algorithm = {
11 | 'sac': SAC,
12 | 'sacv2': SACv2,
13 | 	'sacv2_3d': SACv2_3D,
14 | 'drq': DrQ,
15 | 'sveav2': SVEAv2,
16 | 'multiview': MultiView,
17 | 'drq_multiview': DrQMultiView,
18 | 'drqv2': DrQv2
19 | }
20 |
21 |
22 | def make_agent(obs_shape, state_shape, action_shape, args):
23 | if args.algorithm=='sacv2_3d':
24 | if args.use_latent:
25 | a_obs_shape = (args.bottleneck*32, 32, 32)
26 | elif args.use_impala:
27 | a_obs_shape = (32, 8, 8)
28 | else:
29 | a_obs_shape = (32, 26, 26)
30 | return algorithm[args.algorithm](a_obs_shape, (3, 64, 64), action_shape, args)
31 | else:
32 | return algorithm[args.algorithm](obs_shape, state_shape, action_shape, args)
--------------------------------------------------------------------------------
/src/algorithms/multiview.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from copy import deepcopy
6 | import utils
7 | import algorithms.modules as m
8 |
9 |
10 | class MultiView(object):
11 | def __init__(self, obs_shape, action_shape, args):
12 | self.discount = args.discount
13 | self.critic_tau = args.critic_tau
14 | self.encoder_tau = args.encoder_tau
15 | self.actor_update_freq = args.actor_update_freq
16 | self.critic_target_update_freq = args.critic_target_update_freq
17 | self.use_vit = args.use_vit
18 |
19 | shared_cnn_1 = m.SharedCNN(obs_shape, args.num_shared_layers, args.num_filters, args.mean_zero).cuda() # Third Person Image
20 | shared_cnn_2 = m.SharedCNN(obs_shape, args.num_shared_layers, args.num_filters, args.mean_zero).cuda() # First Person Image
21 |
22 | integrator = m.Integrator(shared_cnn_1.out_shape, shared_cnn_2.out_shape, args.num_filters).cuda() # Change channel dimensions of concatenated features
23 |
24 | assert shared_cnn_1.out_shape==shared_cnn_2.out_shape
25 | head_cnn = m.HeadCNN(shared_cnn_1.out_shape, args.num_head_layers, args.num_filters).cuda() # Pass concatenated features into the head
26 |
27 | actor_encoder = m.MultiViewEncoder(
28 | shared_cnn_1,
29 | shared_cnn_2,
30 | integrator,
31 | head_cnn,
32 | m.RLProjection(head_cnn.out_shape, args.projection_dim)
33 | )
34 | critic_encoder = m.MultiViewEncoder(
35 | shared_cnn_1,
36 | shared_cnn_2,
37 | integrator,
38 | head_cnn,
39 | m.RLProjection(head_cnn.out_shape, args.projection_dim)
40 | )
41 |
42 | self.actor = m.Actor(actor_encoder, action_shape, args.hidden_dim, args.actor_log_std_min, args.actor_log_std_max, multiview=True).cuda()
43 | self.critic = m.Critic(critic_encoder, action_shape, args.hidden_dim, multiview=True).cuda()
44 | self.critic_target = deepcopy(self.critic)
45 |
46 | self.log_alpha = torch.tensor(np.log(args.init_temperature)).cuda()
47 | self.log_alpha.requires_grad = True
48 | self.target_entropy = -np.prod(action_shape)
49 |
50 | self.actor_optimizer = torch.optim.Adam(
51 | self.actor.parameters(), lr=args.actor_lr, betas=(args.actor_beta, 0.999)
52 | )
53 | self.critic_optimizer = torch.optim.Adam(
54 | self.critic.parameters(), lr=args.critic_lr, betas=(args.critic_beta, 0.999)
55 | )
56 | self.log_alpha_optimizer = torch.optim.Adam(
57 | [self.log_alpha], lr=args.alpha_lr, betas=(args.alpha_beta, 0.999)
58 | )
59 |
60 | self.train()
61 | self.critic_target.train()
62 |
63 | def train(self, training=True):
64 | self.training = training
65 | self.actor.train(training)
66 | self.critic.train(training)
67 |
68 | def eval(self):
69 | self.train(False)
70 |
71 | @property
72 | def alpha(self):
73 | return self.log_alpha.exp()
74 |
75 | def _obs_to_input(self, obs):
76 | if isinstance(obs, utils.LazyFrames):
77 | _obs = np.array(obs)
78 | else:
79 | _obs = obs
80 | _obs = torch.FloatTensor(_obs).cuda()
81 | _obs = _obs.unsqueeze(0)
82 | return _obs
83 |
84 | def select_action(self, obs):
85 | _obs = self._obs_to_input(obs)
86 | with torch.no_grad():
87 | mu, _, _, _ = self.actor(_obs, compute_pi=False, compute_log_pi=False)
88 | return mu.cpu().data.numpy().flatten()
89 |
90 | def sample_action(self, obs):
91 | _obs = self._obs_to_input(obs)
92 | with torch.no_grad():
93 | mu, pi, _, _ = self.actor(_obs, compute_log_pi=False)
94 | return pi.cpu().data.numpy().flatten()
95 |
96 | def update_critic(self, obs, action, reward, next_obs, L=None, step=None):
97 | with torch.no_grad():
98 | _, policy_action, log_pi, _ = self.actor(next_obs)
99 | target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
100 | target_V = torch.min(target_Q1,
101 | target_Q2) - self.alpha.detach() * log_pi
102 | target_Q = reward + (self.discount * target_V)
103 |
104 | current_Q1, current_Q2 = self.critic(obs, action)
105 | critic_loss = F.mse_loss(current_Q1,
106 | target_Q) + F.mse_loss(current_Q2, target_Q)
107 | if L is not None:
108 | L.log('train_critic/loss', critic_loss, step)
109 |
110 | self.critic_optimizer.zero_grad()
111 | critic_loss.backward()
112 | self.critic_optimizer.step()
113 |
114 | def update_actor_and_alpha(self, obs, L=None, step=None, update_alpha=True):
115 | _, pi, log_pi, log_std = self.actor(obs, detach=True)
116 | actor_Q1, actor_Q2 = self.critic(obs, pi, detach=True)
117 |
118 | actor_Q = torch.min(actor_Q1, actor_Q2)
119 | actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean()
120 |
121 | if L is not None:
122 | L.log('train_actor/loss', actor_loss, step)
123 | entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi)
124 | ) + log_std.sum(dim=-1)
125 |
126 | self.actor_optimizer.zero_grad()
127 | actor_loss.backward()
128 | self.actor_optimizer.step()
129 |
130 | if update_alpha:
131 | self.log_alpha_optimizer.zero_grad()
132 | alpha_loss = (self.alpha * (-log_pi - self.target_entropy).detach()).mean()
133 |
134 | if L is not None:
135 | L.log('train_alpha/loss', alpha_loss, step)
136 | L.log('train_alpha/value', self.alpha, step)
137 |
138 | alpha_loss.backward()
139 | self.log_alpha_optimizer.step()
140 |
141 | def soft_update_critic_target(self):
142 | utils.soft_update_params(
143 | self.critic.Q1, self.critic_target.Q1, self.critic_tau
144 | )
145 | utils.soft_update_params(
146 | self.critic.Q2, self.critic_target.Q2, self.critic_tau
147 | )
148 | utils.soft_update_params(
149 | self.critic.encoder, self.critic_target.encoder,
150 | self.encoder_tau
151 | )
152 |
153 | def update(self, replay_buffer, L, step):
154 | obs, action, reward, next_obs = replay_buffer.sample()
155 |
156 | self.update_critic(obs, action, reward, next_obs, L, step)
157 |
158 | if step % self.actor_update_freq == 0:
159 | self.update_actor_and_alpha(obs, L, step)
160 |
161 | if step % self.critic_target_update_freq == 0:
162 | self.soft_update_critic_target()
--------------------------------------------------------------------------------
/src/algorithms/rot_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import math
6 |
7 |
8 | def euler2mat(angle, scaling=False, translation=False):
9 | """Convert euler angles to rotation matrix.
10 | Reference: https://github.com/pulkitag/pycaffe-utils/blob/master/rot_utils.py#L174
11 | Args:
12 |         angle: rotation angles along 3 axes (a, b, y) in radians -- size = [B, 3]
13 | Returns:
14 | Rotation matrix corresponding to the euler angles -- size = [B, 3, 3]
15 | """
16 | B = angle.size(0)
17 | euler_angle = 'bay'
18 |
19 | if euler_angle == 'bay':
20 | a, b, y = angle[:, 0], angle[:, 1], angle[:, 2]
21 | x, y, z = b, a, y
22 | elif euler_angle == 'aby':
23 | x, y, z = angle[:, 0], angle[:, 1], angle[:, 2]
24 | else:
25 | raise NotImplementedError
26 |
27 | cosz = torch.cos(z)
28 | sinz = torch.sin(z)
29 |
30 | zeros = z.detach() * 0
31 | ones = zeros.detach() + 1
32 | zmat = torch.stack([cosz, -sinz, zeros,
33 | sinz, cosz, zeros,
34 | zeros, zeros, ones], dim=1).reshape(B, 3, 3)
35 |
36 | cosy = torch.cos(y)
37 | siny = torch.sin(y)
38 |
39 | ymat = torch.stack([cosy, zeros, siny,
40 | zeros, ones, zeros,
41 | -siny, zeros, cosy], dim=1).reshape(B, 3, 3)
42 |
43 | cosx = torch.cos(x)
44 | sinx = torch.sin(x)
45 |
46 | xmat = torch.stack([ones, zeros, zeros,
47 | zeros, cosx, -sinx,
48 | zeros, sinx, cosx], dim=1).reshape(B, 3, 3)
49 |
50 | rotMat = xmat @ ymat @ zmat
51 | # rotMat = zmat
52 | # rotMat = ymat @ zmat
53 | # rotMat = xmat @ ymat
54 | # rotMat = xmat @ zmat
55 |
56 | if scaling:
57 | v_scale = angle[:,3]
58 | v_trans = angle[:,4:]
59 | else:
60 | v_trans = angle[:,3:]
61 |
62 | if scaling:
63 | # one = torch.ones_like(v_scale).detach()
64 | # t_scale = torch.stack([v_scale, one, one,
65 | # one, v_scale, one,
66 | # one, one, v_scale], dim=1).view(B, 3, 3)
67 | rotMat = rotMat * v_scale.unsqueeze(1).unsqueeze(1)
68 |
69 | if translation:
70 | rotMat = torch.cat([rotMat, v_trans.view([B, 3, 1]).cuda()], 2) # F.affine_grid takes 3x4
71 | else:
72 | rotMat = torch.cat([rotMat, torch.zeros([B, 3, 1]).cuda().detach()], 2) # F.affine_grid takes 3x4
73 |
74 | return rotMat
75 |
76 |
77 | def quat2mat(quat, scaling=False, translation=False):
78 | """Convert quaternion coefficients to rotation matrix.
79 | Args:
80 |         quat: first three coefficients of the rotation quaternion; the fourth is then computed to have a norm of 1 -- size = [B, 3]
81 | Returns:
82 | Rotation matrix corresponding to the quaternion -- size = [B, 3, 3]
83 | """
84 | norm_quat = torch.cat([quat[:,:1].detach()*0 + 1, quat], dim=1)
85 | norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True)
86 | w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3]
87 |
88 | B = quat.size(0)
89 |
90 | w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
91 | wx, wy, wz = w*x, w*y, w*z
92 | xy, xz, yz = x*y, x*z, y*z
93 |
94 | rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz,
95 | 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx,
96 | 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).reshape(B, 3, 3)
97 | rotMat = torch.cat([rotMat, torch.zeros([B, 3, 1]).cuda().detach()], 2) # F.affine_grid takes 3x4
98 |
99 | return rotMat
100 |
101 |
--------------------------------------------------------------------------------
/src/algorithms/sac.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from copy import deepcopy
6 | import utils
7 | import algorithms.modules as m
8 |
9 |
10 | class SAC(object):
11 | def __init__(self, obs_shape, action_shape, args):
12 | self.discount = args.discount
13 | self.critic_tau = args.critic_tau
14 | self.encoder_tau = args.encoder_tau
15 | self.actor_update_freq = args.actor_update_freq
16 | self.critic_target_update_freq = args.critic_target_update_freq
17 | self.from_state = args.observation_type=='state'
18 | self.use_vit = args.use_vit
19 |
20 | if self.from_state:
21 | shared = m.Identity(obs_shape)
22 | head = m.Identity(obs_shape)
23 | elif self.use_vit:
24 | shared = m.SharedTransformer(obs_shape, args.num_shared_layers, args.svea_num_heads, args.svea_embed_dim).cuda()
25 | head = m.HeadCNN(shared.out_shape, 0, 0).cuda()
26 | else:
27 | shared = m.SharedCNN(obs_shape, args.num_shared_layers, args.num_filters, args.mean_zero).cuda()
28 | head = m.HeadCNN(shared.out_shape, args.num_head_layers, args.num_filters).cuda()
29 | actor_encoder = m.Encoder(
30 | shared,
31 | head,
32 | m.RLProjection(head.out_shape, args.projection_dim)
33 | )
34 | critic_encoder = m.Encoder(
35 | shared,
36 | head,
37 | m.RLProjection(head.out_shape, args.projection_dim)
38 | )
39 |
40 | self.actor = m.Actor(actor_encoder, action_shape, args.hidden_dim, args.actor_log_std_min, args.actor_log_std_max).cuda()
41 | self.critic = m.Critic(critic_encoder, action_shape, args.hidden_dim).cuda()
42 | self.critic_target = deepcopy(self.critic)
43 |
44 | self.log_alpha = torch.tensor(np.log(args.init_temperature)).cuda()
45 | self.log_alpha.requires_grad = True
46 | self.target_entropy = -np.prod(action_shape)
47 |
48 | self.actor_optimizer = torch.optim.Adam(
49 | self.actor.parameters(), lr=args.actor_lr, betas=(args.actor_beta, 0.999)
50 | )
51 | self.critic_optimizer = torch.optim.Adam(
52 | self.critic.parameters(), lr=args.critic_lr, betas=(args.critic_beta, 0.999)
53 | )
54 | self.log_alpha_optimizer = torch.optim.Adam(
55 | [self.log_alpha], lr=args.alpha_lr, betas=(args.alpha_beta, 0.999)
56 | )
57 |
58 | self.train()
59 | self.critic_target.train()
60 |
61 | def train(self, training=True):
62 | self.training = training
63 | self.actor.train(training)
64 | self.critic.train(training)
65 |
66 | def eval(self):
67 | self.train(False)
68 |
69 | @property
70 | def alpha(self):
71 | return self.log_alpha.exp()
72 |
73 | def _obs_to_input(self, obs):
74 | if isinstance(obs, utils.LazyFrames):
75 | _obs = np.array(obs)
76 | else:
77 | _obs = obs
78 | _obs = torch.FloatTensor(_obs).cuda()
79 | _obs = _obs.unsqueeze(0)
80 | return _obs
81 |
82 | def select_action(self, obs, state=None):
83 | _obs = self._obs_to_input(obs)
84 | with torch.no_grad():
85 | mu, _, _, _ = self.actor(_obs, compute_pi=False, compute_log_pi=False)
86 | return mu.cpu().data.numpy().flatten()
87 |
88 | def sample_action(self, obs, state=None, step=None):
89 | _obs = self._obs_to_input(obs)
90 | with torch.no_grad():
91 | mu, pi, _, _ = self.actor(_obs, compute_log_pi=False)
92 | return pi.cpu().data.numpy().flatten()
93 |
94 | def update_critic(self, obs, action, reward, next_obs, L=None, step=None):
95 | with torch.no_grad():
96 | _, policy_action, log_pi, _ = self.actor(next_obs)
97 | target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
98 | target_V = torch.min(target_Q1,
99 | target_Q2) - self.alpha.detach() * log_pi
100 | target_Q = reward + (self.discount * target_V)
101 |
102 | current_Q1, current_Q2 = self.critic(obs, action)
103 | critic_loss = F.mse_loss(current_Q1,
104 | target_Q) + F.mse_loss(current_Q2, target_Q)
105 | if L is not None:
106 | L.log('train_critic/loss', critic_loss, step)
107 |
108 | self.critic_optimizer.zero_grad()
109 | critic_loss.backward()
110 | self.critic_optimizer.step()
111 |
112 | def update_actor_and_alpha(self, obs, L=None, step=None, update_alpha=True):
113 | _, pi, log_pi, log_std = self.actor(obs, detach=True)
114 | actor_Q1, actor_Q2 = self.critic(obs, pi, detach=True)
115 |
116 | actor_Q = torch.min(actor_Q1, actor_Q2)
117 | actor_loss = (self.alpha.detach() * log_pi - actor_Q).mean()
118 |
119 | if L is not None:
120 | L.log('train_actor/loss', actor_loss, step)
121 | entropy = 0.5 * log_std.shape[1] * (1.0 + np.log(2 * np.pi)
122 | ) + log_std.sum(dim=-1)
123 |
124 | self.actor_optimizer.zero_grad()
125 | actor_loss.backward()
126 | self.actor_optimizer.step()
127 |
128 | if update_alpha:
129 | self.log_alpha_optimizer.zero_grad()
130 | alpha_loss = (self.alpha * (-log_pi - self.target_entropy).detach()).mean()
131 |
132 | if L is not None:
133 | L.log('train_alpha/loss', alpha_loss, step)
134 | L.log('train_alpha/value', self.alpha, step)
135 |
136 | alpha_loss.backward()
137 | self.log_alpha_optimizer.step()
138 |
139 | def soft_update_critic_target(self):
140 | utils.soft_update_params(
141 | self.critic.Q1, self.critic_target.Q1, self.critic_tau
142 | )
143 | utils.soft_update_params(
144 | self.critic.Q2, self.critic_target.Q2, self.critic_tau
145 | )
146 | utils.soft_update_params(
147 | self.critic.encoder, self.critic_target.encoder,
148 | self.encoder_tau
149 | )
150 |
151 | def update(self, replay_buffer, L, step):
152 | obs, _, action, reward, next_obs, _ = replay_buffer.sample()
153 |
154 | self.update_critic(obs, action, reward, next_obs, L, step)
155 |
156 | if step % self.actor_update_freq == 0:
157 | self.update_actor_and_alpha(obs, L, step)
158 |
159 | if step % self.critic_target_update_freq == 0:
160 | self.soft_update_critic_target()
161 |
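# --- Illustrative sketch (not part of the original file). The soft target
# update in soft_update_critic_target above is assumed to be the usual Polyak
# average, target_param <- tau * param + (1 - tau) * target_param; the helper
# below is a standalone stand-in for utils.soft_update_params written only to
# show that behavior on a tiny module.
if __name__ == "__main__":
  def polyak_update(net, target_net, tau):
    # Move each target parameter a small step toward the online parameter.
    for p, tp in zip(net.parameters(), target_net.parameters()):
      tp.data.copy_(tau * p.data + (1 - tau) * tp.data)

  q = nn.Linear(4, 1)
  q_target = deepcopy(q)
  with torch.no_grad():
    q.weight.add_(1.0)            # pretend the online network has been updated
  polyak_update(q, q_target, tau=0.01)
  print((q_target.weight - q.weight).abs().mean().item())  # small but nonzero gap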
--------------------------------------------------------------------------------
/src/algorithms/svea_multiview.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from copy import deepcopy
6 | import utils
7 | import augmentations
8 | import algorithms.modules as m
9 | from algorithms.multiview import MultiView
10 |
11 |
12 | class SVEAMultiView(MultiView):
13 | def __init__(self, obs_shape, action_shape, args):
14 | super().__init__(obs_shape, action_shape, args)
15 | self.svea_alpha = args.svea_alpha
16 | self.svea_beta = args.svea_beta
17 | self.svea_augmentation = args.svea_augmentation
18 | raise NotImplementedError('use sveav2 instead')
19 |
20 | def augment(self, obs):
21 | if self.svea_augmentation == 'colorjitter':
22 | return augmentations.random_color_jitter(obs.clone())
23 | elif self.svea_augmentation == 'conv':
24 | return augmentations.random_conv(obs.clone())
25 | else:
26 | raise NotImplementedError(f'Unsupported augmentation: {self.svea_augmentation}')
27 |
28 | def update_critic(self, obs, action, reward, next_obs, L=None, step=None):
29 | with torch.no_grad():
30 | _, policy_action, log_pi, _ = self.actor(next_obs)
31 | target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
32 | target_V = torch.min(target_Q1,
33 | target_Q2) - self.alpha.detach() * log_pi
34 | target_Q = reward + (self.discount * target_V)
35 |
36 | if self.svea_alpha == self.svea_beta:
37 | obs = utils.cat(obs, self.augment(obs))
38 | action = utils.cat(action, action)
39 | target_Q = utils.cat(target_Q, target_Q)
40 |
41 | current_Q1, current_Q2 = self.critic(obs, action)
42 | critic_loss = (self.svea_alpha + self.svea_beta) * \
43 | (F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q))
44 | else:
45 | current_Q1, current_Q2 = self.critic(obs, action)
46 | critic_loss = self.svea_alpha * \
47 | (F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q))
48 |
49 | obs_aug = self.augment(obs)
50 | current_Q1_aug, current_Q2_aug = self.critic(obs_aug, action)
51 | critic_loss += self.svea_beta * \
52 | (F.mse_loss(current_Q1_aug, target_Q) + F.mse_loss(current_Q2_aug, target_Q))
53 |
54 | if L is not None:
55 | L.log('train_critic/loss', critic_loss, step)
56 |
57 | self.critic_optimizer.zero_grad()
58 | critic_loss.backward()
59 | self.critic_optimizer.step()
60 |
61 | def update(self, replay_buffer, L, step):
62 | obs, action, reward, next_obs = replay_buffer.sample_drq(pad=6 if self.use_vit else 4)
63 |
64 | self.update_critic(obs, action, reward, next_obs, L, step)
65 |
66 | if step % self.actor_update_freq == 0:
67 | self.update_actor_and_alpha(obs, L, step)
68 |
69 | if step % self.critic_target_update_freq == 0:
70 | self.soft_update_critic_target()
71 |
--------------------------------------------------------------------------------
/src/algorithms/sveav2.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from copy import deepcopy
6 | import utils
7 | import augmentations
8 | import algorithms.modules as m
9 | from algorithms.sacv2 import SACv2
10 |
11 |
12 | class SVEAv2(SACv2):
13 | def __init__(self, obs_shape, action_shape, args):
14 | super().__init__(obs_shape, action_shape, args)
15 | self.svea_alpha = args.svea_alpha
16 | self.svea_beta = args.svea_beta
17 | self.svea_augmentation = args.svea_augmentation
18 | self.naive = args.naive
19 |
20 | def augment(self, obs):
21 | if self.svea_augmentation == 'colorjitter':
22 | return augmentations.random_color_jitter(obs.clone())
23 | elif self.svea_augmentation == 'affine+colorjitter':
24 | return augmentations.random_color_jitter(augmentations.random_affine(obs.clone()))
25 | elif self.svea_augmentation == 'noise':
26 | return augmentations.random_noise(obs.clone())
27 | elif self.svea_augmentation == 'affine+noise':
28 | return augmentations.random_noise(augmentations.random_affine(obs.clone()))
29 | elif self.svea_augmentation == 'conv':
30 | return augmentations.random_conv(obs.clone())
31 | elif self.svea_augmentation == 'affine+conv':
32 | return augmentations.random_conv(augmentations.random_affine(obs.clone()))
33 | elif self.svea_augmentation == 'overlay':
34 | return augmentations.random_overlay(obs.clone())
35 | elif self.svea_augmentation == 'affine+overlay':
36 | return augmentations.random_overlay(augmentations.random_affine(obs.clone()))
37 | elif self.svea_augmentation == 'none':
38 | return obs
39 | else:
40 | raise NotImplementedError(f'Unsupported augmentation: {self.svea_augmentation}')
41 |
42 | def update_critic(self, obs, action, reward, next_obs, L=None, step=None):
43 | with torch.no_grad():
44 | _, policy_action, log_pi, _ = self.actor(next_obs)
45 | target_Q1, target_Q2 = self.critic_target(next_obs, policy_action)
46 | target_V = torch.min(target_Q1,
47 | target_Q2) - self.alpha.detach() * log_pi
48 | target_Q = reward + (self.discount * target_V)
49 |
50 | if self.svea_augmentation != 'none' and not self.naive:
51 | action = utils.cat(action, action)
52 | target_Q = utils.cat(target_Q, target_Q)
53 |
54 | current_Q1, current_Q2 = self.critic(obs, action)
55 | critic_loss = (self.svea_alpha + self.svea_beta) * \
56 | (F.mse_loss(current_Q1, target_Q) + F.mse_loss(current_Q2, target_Q))
57 | if L is not None:
58 | L.log('train_critic/loss', critic_loss, step)
59 |
60 | self.critic_optimizer.zero_grad(set_to_none=True)
61 | critic_loss.backward()
62 | self.critic_optimizer.step()
63 |
64 | def update(self, replay_buffer, L, step):
65 | if step % self.update_freq != 0:
66 | return
67 |
68 | obs, action, reward, next_obs = replay_buffer.sample()
69 | obs = self.aug(obs) # random shift
70 | if self.svea_augmentation != 'none':
71 | if self.naive:
72 | obs = self.augment(obs) # naively apply strong augmentation
73 | else:
74 | obs = utils.cat(obs, self.augment(obs)) # strong augmentation
75 |
76 | if self.multiview:
77 | obs = self.encoder(obs[:,:3,:,:], obs[:,3:6,:,:])
78 | else:
79 | obs = self.encoder(obs)
80 |
81 | if self.svea_augmentation != 'none' and not self.naive:
82 | obs_unaug = obs[:obs.size(0)//2] # unaugmented observations
83 | else:
84 | obs_unaug = obs
85 |
86 | with torch.no_grad():
87 | next_obs = self.aug(next_obs)
88 | if self.svea_augmentation != 'none' and self.naive:
89 | next_obs = self.augment(next_obs) # naively apply strong augmentation
90 | if self.multiview:
91 | next_obs = self.encoder(next_obs[:,:3,:,:], next_obs[:,3:6,:,:])
92 | else:
93 | next_obs = self.encoder(next_obs)
94 |
95 | self.update_critic(obs, action, reward, next_obs, L, step)
96 | self.update_actor_and_alpha(obs_unaug.detach(), L, step)
97 | utils.soft_update_params(self.critic, self.critic_target, self.tau)
98 |
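# --- Illustrative sketch (not part of the original file). Shows the
# batch-concatenation trick used in update() above with plain torch.cat: the
# critic consumes clean and strongly augmented observations stacked along the
# batch dimension, while the actor only ever sees the clean half. Here
# utils.cat is assumed to behave like torch.cat along dim 0.
if __name__ == "__main__":
  clean = torch.zeros(4, 8)       # stand-in for encoded clean observations
  augmented = torch.ones(4, 8)    # stand-in for encoded augmented copies
  both = torch.cat([clean, augmented], dim=0)  # critic input, batch of 8
  actor_input = both[:both.size(0) // 2]       # first half == clean batch
  print(torch.equal(actor_input, clean))       # True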
--------------------------------------------------------------------------------
/src/env/__pycache__/wrappers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/__pycache__/wrappers.cpython-37.pyc
--------------------------------------------------------------------------------
/src/env/robot/__pycache__/base.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/__pycache__/base.cpython-37.pyc
--------------------------------------------------------------------------------
/src/env/robot/__pycache__/gym_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/__pycache__/gym_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/src/env/robot/__pycache__/push.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/__pycache__/push.cpython-37.pyc
--------------------------------------------------------------------------------
/src/env/robot/__pycache__/registration.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/__pycache__/registration.cpython-37.pyc
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/base_link.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/base_link.STL
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/base_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/base_link_collision.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/bellows_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/bellows_link_collision.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/elbow_flex_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/elbow_flex_link_collision.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/estop_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/estop_link.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/forearm_roll_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/forearm_roll_link_collision.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/gripper_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/gripper_link.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/head_pan_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/head_pan_link_collision.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/head_tilt_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/head_tilt_link_collision.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/l_wheel_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/l_wheel_link_collision.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/laser_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/laser_link.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/left_finger.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/left_finger.STL
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/left_inner_knuckle.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/left_inner_knuckle.STL
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/left_outer_knuckle.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/left_outer_knuckle.STL
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/link1.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/link1.STL
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/link2.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/link2.STL
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/link3.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/link3.STL
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/link4.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/link4.STL
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/link5.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/link5.STL
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/link6.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/link6.STL
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/link7.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/link7.STL
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/link_base copy.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/link_base copy.STL
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/link_base.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/link_base.STL
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/r_wheel_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/r_wheel_link_collision.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/right_finger.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/right_finger.STL
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/right_inner_knuckle.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/right_inner_knuckle.STL
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/right_outer_knuckle.STL:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/right_outer_knuckle.STL
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/shoulder_lift_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/shoulder_lift_link_collision.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/shoulder_pan_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/shoulder_pan_link_collision.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/torso_fixed_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/torso_fixed_link.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/torso_lift_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/torso_lift_link_collision.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/upperarm_roll_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/upperarm_roll_link_collision.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/wrist_flex_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/wrist_flex_link_collision.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/fetch/wrist_roll_link_collision.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/fetch/wrist_roll_link_collision.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/golf.xml:
--------------------------------------------------------------------------------
[MuJoCo MJCF model for the golf task environment. The XML element content was stripped in this dump and is not recoverable here; only the original file's line numbering survived.]
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/hammer.xml:
--------------------------------------------------------------------------------
[MuJoCo MJCF model for the hammer task environment. The XML element content was stripped in this dump and is not recoverable here; only the original file's line numbering survived.]
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/lift.xml:
--------------------------------------------------------------------------------
[MuJoCo MJCF model for the lift task environment. The XML element content was stripped in this dump and is not recoverable here; only the original file's line numbering survived. A brief illustrative loading sketch follows this entry.]
--------------------------------------------------------------------------------
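Note: the golf.xml, hammer.xml, and lift.xml entries above are MuJoCo MJCF scene definitions for the corresponding manipulation tasks. As a minimal sketch only, assuming the official `mujoco` Python bindings and a hypothetical checkout path (this is not necessarily how the GraphIRL codebase loads its environments), such a file can be loaded and stepped as follows:

```python
# Illustrative sketch (assumption): load one of the MJCF task files with the
# official `mujoco` Python bindings and run a short passive-dynamics rollout.
# The path below is hypothetical and assumes mesh assets resolve relative to it.
import mujoco

model = mujoco.MjModel.from_xml_path("src/env/robot/assets/robot/lift.xml")
data = mujoco.MjData(model)

# Step the simulation a few times as a smoke test that the model compiles.
for _ in range(100):
    mujoco.mj_step(model, data)

print("qpos:", data.qpos)
```
--------------------------------------------------------------------------------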
/src/env/robot/assets/robot/mesh/arm/Base_Link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/arm/Base_Link.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/arm/Bracelet_Link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/arm/Bracelet_Link.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/arm/ForeArm_Link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/arm/ForeArm_Link.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/arm/HalfArm1_Link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/arm/HalfArm1_Link.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/arm/HalfArm2_Link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/arm/HalfArm2_Link.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/arm/Shoulder_Link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/arm/Shoulder_Link.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/arm/SphericalWrist1_Link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/arm/SphericalWrist1_Link.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/arm/SphericalWrist2_Link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/arm/SphericalWrist2_Link.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq/inner_finger_coarse.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/inner_finger_coarse.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq/inner_knuckle_coarse.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/inner_knuckle_coarse.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq/kinova_robotiq_coupler.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/kinova_robotiq_coupler.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq/outer_finger_coarse.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/outer_finger_coarse.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq/outer_knuckle_coarse.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/outer_knuckle_coarse.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq/robotiq_85_base_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/robotiq_85_base_link.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq/robotiq_85_base_link_coarse.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/robotiq_85_base_link_coarse.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq/robotiq_85_finger_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/robotiq_85_finger_link.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq/robotiq_85_finger_tip_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/robotiq_85_finger_tip_link.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq/robotiq_85_inner_knuckle_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/robotiq_85_inner_knuckle_link.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq/robotiq_85_knuckle_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq/robotiq_85_knuckle_link.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_base_link.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_base_link.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_inner_finger.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_inner_finger.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_inner_finger_vis.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_inner_finger_vis.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle_vis.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_inner_knuckle_vis.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_outer_finger.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_outer_finger.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_outer_finger_vis.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_outer_finger_vis.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle_vis.stl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SateeshKumar21/graph-inverse-rl/7d06634a0946fa87e7b138b45e889d549ae7dd3e/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_outer_knuckle_vis.stl
--------------------------------------------------------------------------------
/src/env/robot/assets/robot/mesh/robotiq_85_gripper/robotiq_arg2f_85_pad_vis.dae:
--------------------------------------------------------------------------------
[COLLADA (.dae) visual mesh for the Robotiq 85 gripper pad. The XML markup was stripped in this dump; the recoverable information is the asset timestamp (created/modified 2016-07-17T22:25:43), the Z_UP axis convention, a grey Phong material (diffuse 0.7 0.7 0.7 1.0, specular 1 1 1 1.0), and the two surviving float arrays below.]
normals: 4.38531e-14 -1 -4.451336e-05 4.38531e-14 -1 -4.451336e-05 -4.38531e-14 1 4.451336e-05 -4.38531e-14 1 4.451336e-05 -1 -4.385301e-14 -2.011189e-15 -1 -4.385301e-14 -2.011189e-15 -2.009237e-15 -4.451336e-05 1 -2.009237e-15 -4.451336e-05 1 1 4.385301e-14 2.011189e-15 1 4.385301e-14 2.011189e-15 2.009237e-15 4.451336e-05 -1 2.009237e-15 4.451336e-05 -1
positions: -10 -23.90175 13.51442 10 -23.9033 48.51442 -10 -23.9033 48.51442 10 -23.90175 13.51442 -10 -18.90175 13.51464 -10 -18.9033 48.51464 10 -18.90175 13.51464 10 -18.9033 48.51464