4 | #
5 | # License: BSD 3 clause
6 |
7 | import numpy as np
8 | from sklearn.base import BaseEstimator, RegressorMixin
9 | from sklearn.metrics.pairwise import pairwise_kernels
10 |
11 |
12 | class KernelRegression(BaseEstimator, RegressorMixin):
13 | """Nadaraya-Watson kernel regression with automatic bandwidth selection.
14 | This implements Nadaraya-Watson kernel regression with (optional) automatic
15 | bandwidth selection of the kernel via leave-one-out cross-validation. Kernel
16 | regression is a simple non-parametric kernelized technique for learning a
17 | non-linear relationship between input variable(s) and a target variable.
18 |
19 | Parameters
20 | ----------
21 | kernel : string or callable, default="rbf"
22 |     Kernel used to weight the training samples. A callable should accept
23 |     two arguments and the keyword arguments passed to this object as
24 |     kernel_params, and should return a floating point number.
25 | gamma : float, default=None
26 | Gamma parameter for the RBF ("bandwidth"), polynomial,
27 | exponential, chi2, and sigmoid kernels. Interpretation of the default
28 | value is left to the kernel; see the documentation for
29 | sklearn.metrics.pairwise. Ignored by other kernels. If a sequence of
30 | values is given, one of these values is selected which minimizes
31 | the mean-squared-error of leave-one-out cross-validation.
32 | See also
33 | --------
34 | sklearn.metrics.pairwise.kernel_metrics : List of built-in kernels.
35 | """
36 |
37 | def __init__(self, kernel="rbf", gamma=None):
38 | self.kernel = kernel
39 | self.gamma = gamma
40 |
41 | def fit(self, X, y):
42 | """Fit the model.
43 |
44 | Parameters
45 | ----------
46 | X : array-like of shape = [n_samples, n_features]
47 | The training input samples.
48 | y : array-like, shape = [n_samples]
49 |             The target values.
50 | Returns
51 | -------
52 | self : object
53 | Returns self.
54 | """
55 | self.X = X
56 | self.y = y
57 |
58 | if hasattr(self.gamma, "__iter__"):
59 | self.gamma = self._optimize_gamma(self.gamma)
60 |
61 | return self
62 |
63 | def predict(self, X):
64 | """Predict target values for X.
65 |
66 | Parameters
67 | ----------
68 | X : array-like of shape = [n_samples, n_features]
69 | The input samples.
70 | Returns
71 | -------
72 | y : array of shape = [n_samples]
73 | The predicted target value.
74 | """
75 | K = pairwise_kernels(self.X, X, metric=self.kernel, gamma=self.gamma)
76 | return (K * self.y[:, None]).sum(axis=0) / K.sum(axis=0)
77 |
78 | def _optimize_gamma(self, gamma_values):
79 | # Select specific value of gamma from the range of given gamma_values
80 | # by minimizing mean-squared error in leave-one-out cross validation
81 |         mse = np.empty_like(gamma_values, dtype=float)  # np.float was removed in NumPy 1.24
82 | for i, gamma in enumerate(gamma_values):
83 | K = pairwise_kernels(self.X, self.X, metric=self.kernel, gamma=gamma)
84 | np.fill_diagonal(K, 0) # leave-one-out
85 | Ky = K * self.y[:, np.newaxis]
86 | y_pred = Ky.sum(axis=0) / K.sum(axis=0)
87 | mse[i] = ((y_pred - self.y) ** 2).mean()
88 |         try:
89 |             return gamma_values[np.nanargmin(mse)]
90 |         except ValueError:  # all mse values are NaN; fall back to a degenerate bandwidth
91 |             return 0
92 |
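# A minimal usage sketch (hypothetical data): passing an iterable `gamma`
# triggers leave-one-out bandwidth selection inside `fit`.
if __name__ == "__main__":
    rng = np.random.RandomState(0)
    X = np.sort(5 * rng.rand(100, 1), axis=0)
    y = np.sin(X).ravel() + 0.1 * rng.randn(100)
    kr = KernelRegression(kernel="rbf", gamma=np.logspace(-2, 2, 10))
    y_pred = kr.fit(X, y).predict(X)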
--------------------------------------------------------------------------------
/scripts/tsne_visualization.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os.path
3 |
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import pandas as pd
7 | import seaborn as sns
8 | import torch
9 | from sklearn.manifold import TSNE
10 |
11 | from uvd.decomp.decomp import (
12 | embedding_decomp,
13 | )
14 | from uvd.models.preprocessors import *
15 | import uvd.utils as U
16 |
17 | from decord import VideoReader
18 |
19 |
20 | def vis_2d_tsne(embeddings: np.ndarray, labels: list):
21 | tsne = TSNE(n_components=2)
22 | tsne_result = tsne.fit_transform(embeddings)
23 | tsne_result_df = pd.DataFrame(
24 | {"tsne_1": tsne_result[:, 0], "tsne_2": tsne_result[:, 1], "label": labels}
25 | )
26 | fig, ax = plt.subplots(1)
27 | sns.scatterplot(x="tsne_1", y="tsne_2", hue="label", data=tsne_result_df, ax=ax, s=120)
28 | lim = (tsne_result.min() - 5, tsne_result.max() + 5)
29 | ax.set_xlim(lim)
30 | ax.set_ylim(lim)
31 | ax.set_aspect("equal")
32 |     ax.set_title(preprocessor.__class__.__name__)  # module-level `preprocessor` set in __main__
33 | plt.show()
34 |
35 |
36 | def vis_3d_tsne(embeddings: np.ndarray, labels: list):
37 | tsne = TSNE(n_components=3)
38 | tsne_result = tsne.fit_transform(embeddings)
39 | tsne_result_df = pd.DataFrame(
40 | {
41 | "tsne_1": tsne_result[:, 0],
42 | "tsne_2": tsne_result[:, 1],
43 | "tsne_3": tsne_result[:, 2],
44 | "label": labels,
45 | }
46 | )
47 |
48 | fig = plt.figure()
49 | ax = fig.add_subplot(111, projection="3d")
50 |
51 | palette = sns.color_palette("viridis", as_cmap=True)
52 | unique_labels = tsne_result_df["label"].unique()
53 | colors = palette(np.linspace(0, 1, len(unique_labels)))
54 | color_dict = dict(zip(unique_labels, colors))
55 |
56 | for label in unique_labels:
57 | subset = tsne_result_df[tsne_result_df["label"] == label]
58 | ax.scatter(
59 | subset["tsne_1"],
60 | subset["tsne_2"],
61 | subset["tsne_3"],
62 | c=[color_dict[label]],
63 | label=label,
64 | s=120,
65 | )
66 |     ax.set_title(preprocessor.__class__.__name__)  # module-level `preprocessor` set in __main__
67 | plt.show()
68 |
69 |
70 | if __name__ == "__main__":
71 | parser = argparse.ArgumentParser()
72 | parser.add_argument(
73 | "--video_file",
74 | default=U.f_join(
75 | os.path.dirname(__file__), "examples/microwave-bottom_burner-light_switch-slide_cabinet.mp4"
76 | )
77 | )
78 | parser.add_argument("--preprocessor_name", default="vip")
79 | args = parser.parse_args()
80 |
81 | use_gpu = torch.cuda.is_available()
82 | if not use_gpu:
83 | print("NO GPU FOUND")
84 |
85 | frames = VideoReader(args.video_file, height=224, width=224)[:].asnumpy()
86 | preprocessor = get_preprocessor(
87 | args.preprocessor_name, device="cuda" if use_gpu else None
88 | )
89 | embeddings = preprocessor.process(frames, return_numpy=True)
90 | _, decomp_meta = embedding_decomp(
91 | embeddings=embeddings,
92 | fill_embeddings=False,
93 | return_intermediate_curves=False,
94 | normalize_curve=False,
95 | min_interval=20,
96 | smooth_method="kernel",
97 | gamma=0.1,
98 | )
99 | milestone_indices = decomp_meta.milestone_indices
100 | milestone_rgbs = frames[milestone_indices]
101 |
102 |     # label each frame with the index of the milestone segment it falls in
103 |     labels = [
104 |         i
105 |         for i, end in enumerate(milestone_indices)
106 |         for _ in range(end - milestone_indices[i - 1] if i > 0 else end)
107 |     ]
108 |     labels = [labels[0]] + labels  # pad so len(labels) == len(embeddings)
108 |
109 | vis_2d_tsne(embeddings, labels)
110 | vis_3d_tsne(embeddings, labels)
111 |
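# Example invocation (a sketch; both flags default to the bundled example video
# and the VIP preprocessor):
#   python scripts/tsne_visualization.py --preprocessor_name vip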
--------------------------------------------------------------------------------
/uvd/models/preprocessors/vip_preprocessor.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Optional
4 |
5 | import hydra
6 | import omegaconf
7 | import torch
8 |
9 | _VIP_IMPORT_ERROR = None
10 | try:
11 | import vip
12 | except ImportError as e:
13 | _VIP_IMPORT_ERROR = e
14 |
15 | from torch import nn
16 | from torchvision import transforms as T
17 |
18 | import uvd.utils as U
19 | from uvd.models.preprocessors.base import Preprocessor
20 |
21 |
22 | class VipPreprocessor(Preprocessor):
23 | def __init__(
24 | self,
25 | model_type: str | None = None,
26 | device: torch.device | str | None = None,
27 | remove_bn: bool = False,
28 | bn_to_gn: bool = False,
29 | remove_pool: bool = False,
30 | preprocess_with_fc: bool = False,
31 | save_fc: bool = False,
32 | random_crop: bool = False,
33 | ckpt: str | None = None,
34 | **kwargs,
35 | ):
36 | if _VIP_IMPORT_ERROR is not None:
37 | raise ImportError(_VIP_IMPORT_ERROR)
38 | model_type = model_type or "resnet50"
39 | self.random_crop = random_crop
40 | self.ckpt = ckpt
41 | super().__init__(
42 | model_type=model_type,
43 | device=device,
44 | remove_bn=remove_bn,
45 | bn_to_gn=bn_to_gn,
46 | remove_pool=remove_pool,
47 | preprocess_with_fc=preprocess_with_fc,
48 | save_fc=save_fc,
49 | **kwargs,
50 | )
51 |
52 |     def _get_model_and_transform(
53 |         self, model_type: str | None = None
54 |     ) -> tuple[vip.VIP, Optional[T.Compose]]:
55 | if model_type is not None:
56 |             assert model_type == "resnet50", f"{model_type} not supported"
57 | vip.device = self.device
58 | vip_ = load_vip(modelid="resnet50", ckpt_path=self.ckpt).module
59 | resnet = vip_.convnet.to(self.device)
60 | if self.remove_pool:
61 |             # keep frozen handles to the stripped avgpool/fc layers
62 | self._pool = U.freeze_module(resnet.avgpool)
63 | self._fc = U.freeze_module(resnet.fc)
64 | model = nn.Sequential(*(list(resnet.children())[:-2]))
65 | else:
66 | model = resnet
67 | # crop_transform = T.RandomCrop(224) if self.random_crop else T.CenterCrop(224)
68 | transform = (
69 | # nn.Sequential(T.Resize(224), vip_.normlayer)
70 | T.Compose([T.Resize(224), vip_.normlayer])
71 | if not self.random_crop
72 | # else nn.Sequential(T.Resize(232), T.RandomCrop(224), vip_.normlayer)
73 | else T.Compose([T.Resize(232), T.RandomCrop(224), vip_.normlayer])
74 | )
75 | return model, transform
76 |
77 |
78 | def load_vip(modelid: str = "resnet50", ckpt_path: str | None = None):
79 | if ckpt_path is None:
80 | return vip.load_vip(modelid)
81 | home = U.f_join("~/.vip")
82 | folderpath = U.f_mkdir(home, modelid)
83 | configpath = U.f_join(home, modelid, "config.yaml")
84 | if not U.f_exists(configpath):
85 | try:
86 | configurl = "https://pytorch.s3.amazonaws.com/models/rl/vip/config.yaml"
87 | vip.load_state_dict_from_url(configurl, folderpath)
88 |         except Exception:  # S3 download failed; fall back to the Google Drive mirror
89 | configurl = (
90 | "https://drive.google.com/uc?id=1XSQE0gYm-djgueo8vwcNgAiYjwS43EG-"
91 | )
92 | vip.gdown.download(configurl, configpath, quiet=False)
93 |
94 | modelcfg = omegaconf.OmegaConf.load(configpath)
95 | cleancfg = vip.cleanup_config(modelcfg)
96 | rep = hydra.utils.instantiate(cleancfg)
97 | rep = torch.nn.DataParallel(rep)
98 | vip_state_dict = torch.load(ckpt_path, map_location="cpu")["vip"]
99 | rep.load_state_dict(vip_state_dict)
100 | return rep
101 |
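# Minimal usage sketch (assumes a CUDA device; `ckpt=None` loads the released
# VIP weights via `vip.load_vip`):
#   preprocessor = VipPreprocessor(device="cuda")
#   embeddings = preprocessor.process(frames, return_numpy=True)  # frames: (T, H, W, 3)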
--------------------------------------------------------------------------------
/uvd/envs/franka_kitchen/relay-policy-learning/third_party/franka/assets/assets.xml:
--------------------------------------------------------------------------------
<!-- MuJoCo XML markup not preserved in this dump -->
--------------------------------------------------------------------------------
/uvd/envs/franka_kitchen/relay-policy-learning/adept_envs/adept_envs/utils/config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2020 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import numpy as np
18 |
19 | try:
20 |     # the C-accelerated parser was removed from the stdlib in Python 3.9
21 |     import xml.etree.cElementTree as ET
22 | except ImportError:
23 |     import xml.etree.ElementTree as ET
27 |
28 | # sample config data; reconstructed to match the reads in the tests below
29 | CONFIG_XML_DATA = """
30 | <config name="test_config">
31 |     <limits low="0 0 0" high="1 1 1"/>
32 |     <scale joint="1 1 1"/>
33 |     <data type="test1 test2"/>
34 | </config>
35 | """
35 |
36 |
37 | # Read config from root
38 | def read_config_from_node(root_node, parent_name, child_name, dtype=int):
39 | # find parent
40 | parent_node = root_node.find(parent_name)
41 |     if parent_node is None:
42 | quit("Parent %s not found" % parent_name)
43 |
44 | # get child data
45 | child_data = parent_node.get(child_name)
46 |     if child_data is None:
47 | quit("Child %s not found" % child_name)
48 |
49 | config_val = np.array(child_data.split(), dtype=dtype)
50 | return config_val
51 |
52 |
53 | # get config from file or string
54 | def get_config_root_node(config_file_name=None, config_file_data=None):
55 |     try:
56 |         # get root
57 |         if config_file_data is None:
58 |             with open(config_file_name, "r") as config_file_content:
59 |                 config = ET.parse(config_file_content)
60 |             root_node = config.getroot()
61 |         else:
62 |             root_node = ET.fromstring(config_file_data)
63 |
64 |         # get root data
65 |         root_data = root_node.get("name")
66 |         root_name = np.array(root_data.split(), dtype=str)
67 |     except Exception:
68 |         quit("ERROR: Unable to process config file %s" % config_file_name)
69 |
70 | return root_node, root_name
71 |
72 |
73 | # Read config from config_file
74 | def read_config_from_xml(config_file_name, parent_name, child_name, dtype=int):
75 | root_node, root_name = get_config_root_node(config_file_name=config_file_name)
76 | return read_config_from_node(root_node, parent_name, child_name, dtype)
77 |
78 |
79 | # tests
80 | if __name__ == "__main__":
81 | print("Read config and parse -------------------------")
82 | root, root_name = get_config_root_node(config_file_data=CONFIG_XML_DATA)
83 | print("Root:name \t", root_name)
84 | print("limit:low \t", read_config_from_node(root, "limits", "low", float))
85 | print("limit:high \t", read_config_from_node(root, "limits", "high", float))
86 | print("scale:joint \t", read_config_from_node(root, "scale", "joint", float))
87 | print("data:type \t", read_config_from_node(root, "data", "type", str))
88 |
89 |     # read straight from xml (dump the XML data as duh.xml for this test)
90 | root, root_name = get_config_root_node(config_file_name="duh.xml")
91 | print("Read from xml --------------------------------")
92 | print("limit:low \t", read_config_from_xml("duh.xml", "limits", "low", float))
93 | print("limit:high \t", read_config_from_xml("duh.xml", "limits", "high", float))
94 | print("scale:joint \t", read_config_from_xml("duh.xml", "scale", "joint", float))
95 | print("data:type \t", read_config_from_xml("duh.xml", "data", "type", str))
96 |
--------------------------------------------------------------------------------
/uvd/envs/franka_kitchen/relay-policy-learning/third_party/franka/assets/chain0_overlay.xml:
--------------------------------------------------------------------------------
<!-- MuJoCo XML markup not preserved in this dump -->
--------------------------------------------------------------------------------
/uvd/utils/file_utils.py:
--------------------------------------------------------------------------------
1 | import errno
2 | import glob
3 | import json
4 | import os
5 | import pickle
6 | import shutil
7 |
8 |
9 | __all__ = [
10 | "f_expand",
11 | "f_exists",
12 | "f_join",
13 | "f_listdir",
14 | "f_mkdir",
15 | "f_remove",
16 | "save_pickle",
17 | "load_pickle",
18 | "load_json",
19 | "dump_json",
20 | "write_text",
21 | "ask_if_overwrite",
22 | ]
23 |
24 |
25 | def f_expand(fpath):
26 | return os.path.expandvars(os.path.expanduser(fpath))
27 |
28 |
29 | def f_exists(*fpaths):
30 | return os.path.exists(f_join(*fpaths))
31 |
32 |
33 | def f_join(*fpaths):
34 | """Join file paths and expand special symbols like `~` for home dir."""
35 | return f_expand(os.path.join(*fpaths))
36 |
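# e.g. (illustrative): f_join("~/.vip", "resnet50", "config.yaml")
#   -> "/home/<user>/.vip/resnet50/config.yaml"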
37 |
38 | def f_listdir(*fpaths, filter=None, sort=False, full_path=False, nonexist_ok=True):
39 | """
40 | Args:
41 | full_path: True to return full paths to the dir contents
42 | filter: function that takes in file name and returns True to include
43 | nonexist_ok: True to return [] if the dir is non-existent, False to raise
44 |         sort: sort the file names alphabetically
45 | """
46 | dir_path = f_join(*fpaths)
47 | if not os.path.exists(dir_path) and nonexist_ok:
48 | return []
49 | files = os.listdir(dir_path)
50 | if filter is not None:
51 | files = [f for f in files if filter(f)]
52 | if sort:
53 | files.sort()
54 | if full_path:
55 | return [os.path.join(dir_path, f) for f in files]
56 | else:
57 | return files
58 |
59 |
60 | def f_mkdir(*fpaths):
61 |     """Recursively create all subdirs; if they already exist, do nothing."""
62 | fpath = f_join(*fpaths)
63 | os.makedirs(fpath, exist_ok=True)
64 | return fpath
65 |
66 |
67 | def f_remove(fpath, verbose=False, dry_run=False):
68 | """If exist, remove.
69 |
70 | Supports both dir and file. Supports glob wildcard.
71 | """
72 | assert isinstance(verbose, bool)
73 | fpath = f_expand(fpath)
74 | if dry_run:
75 | print("Dry run, delete:", fpath)
76 | return
77 | for f in glob.glob(fpath):
78 | try:
79 | shutil.rmtree(f)
80 | except OSError as e:
81 | if e.errno == errno.ENOTDIR:
82 | try:
83 | os.remove(f)
84 | except: # final resort safeguard
85 | pass
86 | if verbose:
87 | print(f'Deleted "{fpath}"')
88 |
89 |
90 | def save_pickle(data, *fpaths):
91 | with open(f_join(*fpaths), "wb") as fp:
92 | pickle.dump(data, fp)
93 |
94 |
95 | def load_pickle(*fpaths):
96 | with open(f_join(*fpaths), "rb") as fp:
97 | return pickle.load(fp)
98 |
99 |
100 | def load_json(*file_path, **kwargs):
101 | file_path = f_join(*file_path)
102 |
103 | with open(file_path, "r") as fp:
104 | return json.load(fp, **kwargs)
105 |
106 |
107 | def dump_json(data, *file_path, convert_to_primitive=False, **kwargs):
108 | if convert_to_primitive:
109 | from .array_tensor_utils import any_to_primitive
110 |
111 | data = any_to_primitive(data)
112 | file_path = f_join(*file_path)
113 | with open(file_path, "w") as fp:
114 | json.dump(data, fp, **kwargs)
115 |
116 |
117 | def write_text(s, *fpaths):
118 | with open(f_join(*fpaths), "w") as fp:
119 | fp.write(s)
120 |
121 |
122 | def ask_if_overwrite(*fpaths, default_delete: bool = True):
123 | if f_exists(*fpaths):
124 |         conflict_path = f_join(*fpaths)
125 |         ans = input(
126 |             f"WARNING: directory ({conflict_path}) already exists! \noverwrite? "
127 |             + ("([Y]/n)\n" if default_delete else "(y/n)\n")
128 |         )
129 |         if (ans != "n" and default_delete) or ans == "y":
130 |             f_remove(conflict_path, verbose=True)
131 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | .idea/
161 |
162 | *.pt
163 | *.pth
164 | *pl
165 | *.patch
166 | *used_configs
167 | .allenact_last_start_time_string
168 | *.lock
169 | *wandb
170 |
--------------------------------------------------------------------------------
/uvd/envs/franka_kitchen/relay-policy-learning/adept_models/kitchen/assets/hingecabinet_chain.xml:
--------------------------------------------------------------------------------
<!-- MuJoCo XML markup not preserved in this dump -->
--------------------------------------------------------------------------------
/uvd/envs/franka_kitchen/relay-policy-learning/adept_envs/adept_envs/simulation/module.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2020 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Module for caching Python modules related to simulation."""
17 |
18 | import sys
19 |
20 | _MUJOCO_PY_MODULE = None
21 |
22 | _DM_MUJOCO_MODULE = None
23 | _DM_VIEWER_MODULE = None
24 | _DM_RENDER_MODULE = None
25 |
26 | _GLFW_MODULE = None
27 |
28 |
29 | def get_mujoco_py():
30 | """Returns the mujoco_py module."""
31 | global _MUJOCO_PY_MODULE
32 | if _MUJOCO_PY_MODULE:
33 | return _MUJOCO_PY_MODULE
34 | try:
35 | import mujoco_py
36 |
37 | # Override the warning function.
38 | from mujoco_py.builder import cymj
39 |
40 | cymj.set_warning_callback(_mj_warning_fn)
41 | except ImportError:
42 | print(
43 | "Failed to import mujoco_py. Ensure that mujoco_py (using MuJoCo "
44 | "v1.50) is installed.",
45 | file=sys.stderr,
46 | )
47 | sys.exit(1)
48 | _MUJOCO_PY_MODULE = mujoco_py
49 | return mujoco_py
50 |
51 |
52 | def get_mujoco_py_mjlib():
53 | """Returns the mujoco_py mjlib module."""
54 |
55 | class MjlibDelegate:
56 | """Wrapper that forwards mjlib calls."""
57 |
58 | def __init__(self, lib):
59 | self._lib = lib
60 |
61 | def __getattr__(self, name: str):
62 | if name.startswith("mj"):
63 | return getattr(self._lib, "_" + name)
64 | raise AttributeError(name)
65 |
66 | return MjlibDelegate(get_mujoco_py().cymj)
67 |
68 |
69 | def get_dm_mujoco():
70 | """Returns the DM Control mujoco module."""
71 | global _DM_MUJOCO_MODULE
72 | if _DM_MUJOCO_MODULE:
73 | return _DM_MUJOCO_MODULE
74 | try:
75 | from dm_control import mujoco
76 | except ImportError:
77 | print(
78 | "Failed to import dm_control.mujoco. Ensure that dm_control (using "
79 | "MuJoCo v2.00) is installed.",
80 | file=sys.stderr,
81 | )
82 | sys.exit(1)
83 | _DM_MUJOCO_MODULE = mujoco
84 | return mujoco
85 |
86 |
87 | def get_dm_viewer():
88 | """Returns the DM Control viewer module."""
89 | global _DM_VIEWER_MODULE
90 | if _DM_VIEWER_MODULE:
91 | return _DM_VIEWER_MODULE
92 | try:
93 | from dm_control import viewer
94 | except ImportError:
95 | print(
96 | "Failed to import dm_control.viewer. Ensure that dm_control (using "
97 | "MuJoCo v2.00) is installed.",
98 | file=sys.stderr,
99 | )
100 | sys.exit(1)
101 | _DM_VIEWER_MODULE = viewer
102 | return viewer
103 |
104 |
105 | def get_dm_render():
106 | """Returns the DM Control render module."""
107 | global _DM_RENDER_MODULE
108 | if _DM_RENDER_MODULE:
109 | return _DM_RENDER_MODULE
110 | try:
111 | try:
112 | from dm_control import _render
113 |
114 | render = _render
115 | except ImportError:
116 | print("Warning: DM Control is out of date.")
117 | from dm_control import render
118 | except ImportError:
119 | print(
120 | "Failed to import dm_control.render. Ensure that dm_control (using "
121 | "MuJoCo v2.00) is installed.",
122 | file=sys.stderr,
123 | )
124 | sys.exit(1)
125 | _DM_RENDER_MODULE = render
126 | return render
127 |
128 |
129 | def _mj_warning_fn(warn_data: bytes):
130 | """Warning function override for mujoco_py."""
131 | print(
132 | "WARNING: Mujoco simulation is unstable (has NaNs): {}".format(
133 | warn_data.decode()
134 | )
135 | )
136 |
--------------------------------------------------------------------------------
/uvd/envs/franka_kitchen/relay-policy-learning/adept_models/kitchen/assets/counters_chain.xml:
--------------------------------------------------------------------------------
<!-- MuJoCo XML markup not preserved in this dump -->
--------------------------------------------------------------------------------
/uvd/utils/video_utils.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from typing import Union, List, Optional
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from .array_tensor_utils import any_stack, any_to_torch_tensor, any_to_numpy
8 | from .file_utils import f_mkdir, f_join, f_remove
9 |
10 | __all__ = ["save_video", "ffmpeg_save_video", "compress_video", "VideoTensorWriter"]
11 |
12 |
13 | def save_video(
14 | video: Union[np.ndarray, torch.Tensor],
15 | fname: str,
16 | fps: Optional[int] = None,
17 | compress: bool = False,
18 | ):
19 | import torchvision.io
20 | from einops import rearrange
21 |
22 | fname = f_join(fname)
23 | video = any_to_torch_tensor(video)
24 | assert video.ndim == 4, f"must be 4D tensor, {video.shape}"
25 | assert (
26 | video.size(1) == 3 or video.size(3) == 3
27 | ), "shape should be either T3HW or THW3"
28 |
29 | if video.size(1) == 3:
30 | video = rearrange(video, "T C H W -> T H W C")
31 | output_fname = fname
32 |     if compress:
33 |         stem, ext = fname.rsplit(".", 1)  # split on the last dot so dotted dirs survive
34 |         fname = f"{stem}_raw.{ext}"
35 |     torchvision.io.write_video(fname, video, fps=fps or 30)
35 | if compress:
36 | compress_video(fname, output_fname, delete_input=True)
37 |
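# Minimal usage sketch (synthetic frames; torchvision's video writer needs
# `pip install av`):
#   frames = np.zeros((16, 64, 64, 3), dtype=np.uint8)  # THW3
#   save_video(frames, "demo.mp4", fps=30)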
38 |
39 | def ffmpeg_save_video(
40 | video: Union[np.ndarray, torch.Tensor], fname: str, fps: Optional[int] = None
41 | ):
42 | """if ffmpeg: error while loading shared libraries: libopenh264.so.5:
43 |
44 | cannot open shared object file: No such file or directory, do `conda
45 | update ffmpeg`
46 | """
47 | import ffmpeg # pip install ffmpeg-python
48 | from einops import rearrange
49 |
50 | video = any_to_numpy(video)
51 | assert video.ndim == 4, f"must be 4D array, {video.shape}"
52 | assert (
53 | video.shape[1] == 3 or video.shape[3] == 3
54 | ), "shape should be either T3HW or THW3"
55 | if video.shape[1] == 3:
56 | video = rearrange(video, "T C H W -> T H W C")
57 |
58 | out = ffmpeg.input(
59 | "pipe:",
60 | format="rawvideo",
61 | pix_fmt="rgb24",
62 |         s="{}x{}".format(video.shape[2], video.shape[1]),
63 |         r=fps or 30,  # framerate belongs to ffmpeg.input; inside str.format it was silently ignored
63 | ).output(
64 | fname,
65 | vcodec="libx264",
66 | crf=28,
67 | preset="fast",
68 | pix_fmt="yuv420p",
69 | loglevel="quiet",
70 | )
71 | process = out.run_async(pipe_stdin=True)
72 | try:
73 | for frame in video:
74 | process.stdin.write(frame.tobytes())
75 | except BrokenPipeError:
76 | pass
77 |
78 | process.stdin.close()
79 | process.wait()
80 |
81 |
82 | def compress_video(in_mp4_path: str, out_mp4_path: str, delete_input: bool = True):
83 |     ffmpeg = "/usr/bin/ffmpeg"  # hard-coded system ffmpeg; adjust if installed elsewhere
84 | commands_list = [
85 | ffmpeg,
86 | "-v",
87 | "quiet",
88 | "-y",
89 | "-i",
90 | in_mp4_path,
91 | "-vcodec",
92 | "libx264",
93 | "-crf",
94 | "28",
95 | out_mp4_path,
96 | ]
97 | assert subprocess.run(commands_list).returncode == 0, commands_list
98 | if delete_input:
99 | f_remove(in_mp4_path)
100 |
101 |
102 | class VideoTensorWriter:
103 | def __init__(self, folder=".", fps=40):
104 | self._folder = folder
105 | self._fps = fps
106 | self._frames = []
107 |
108 | @property
109 | def frames(self) -> List[np.ndarray]:
110 | return self._frames
111 |
112 | def add_frame(self, frame: Union[np.ndarray, torch.Tensor]):
113 | assert len(frame.shape) == 3
114 | self._frames.append(frame)
115 |
116 | def clear(self):
117 | self._frames = []
118 |
119 | def save(
120 | self,
121 | step: Union[int, str],
122 | save: bool = True,
123 | suffix: Optional[str] = None,
124 | fps: Optional[int] = None,
125 | compress: bool = True,
126 | ) -> str:
127 | """
128 | Requires:
129 | pip install av
130 | """
131 | fps = fps or self._fps
132 | fname = str(step) if suffix is None else f"{step}-{suffix}"
133 | in_fname = f"{fname}_raw.mp4" if compress else f"{fname}.mp4"
134 | in_path = f_join(self._folder, in_fname)
135 | out_path = f_join(self._folder, f"{fname}.mp4")
136 | if save:
137 | f_mkdir(self._folder)
138 | save_video(any_stack(self._frames, dim=0), in_path, fps=fps)
139 | if compress:
140 | compress_video(in_path, out_path, delete_input=True)
141 | self.clear()
142 | # clear in record env wrapper if not `save`
143 | return out_path
144 |
--------------------------------------------------------------------------------
/uvd/envs/franka_kitchen/relay-policy-learning/adept_envs/adept_envs/franka/assets/franka_kitchen_jntpos_act_ab.xml:
--------------------------------------------------------------------------------
<!-- MuJoCo XML markup not preserved in this dump -->
--------------------------------------------------------------------------------
/uvd/envs/franka_kitchen/relay-policy-learning/third_party/franka/assets/chain1.xml:
--------------------------------------------------------------------------------
<!-- MuJoCo XML markup not preserved in this dump -->
--------------------------------------------------------------------------------
/uvd/models/policy/lang_cond_mlp_policy.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import gym
4 | import numpy as np
5 | import torch
6 | from omegaconf import DictConfig
7 | from torch import nn
8 |
9 | import uvd.utils as U
10 | from uvd.models.preprocessors import get_preprocessor
11 | from .policy_base import PolicyBase
12 | from .. import MLP
13 | from ..distributions import DistributionBase
14 |
15 | __all__ = ["LanguageConditionedMLPPolicy"]
16 |
17 |
18 | class LanguageConditionedMLPPolicy(PolicyBase):
19 | def __init__(
20 | self,
21 | *,
22 | observation_space: gym.spaces.Dict,
23 | action_space: gym.Space,
24 | preprocessor: DictConfig | None = None,
25 | visual: DictConfig | None = None,
26 | obs_encoder: DictConfig,
27 | act_head: DictConfig | None = None,
28 | use_distribution: bool = False,
29 | bn_to_gn_all: bool = False,
30 | visual_as_attention_mask: bool = False,
31 | condition_embed_diff: bool = False,
32 | **kwargs,
33 | ):
34 | super().__init__(**U.prepare_locals_for_super(locals()))
35 |
36 | if visual is not None:
37 | raise NotImplementedError
38 | # assert preprocessor is not None
39 | # preprocessor = {**preprocessor, "remove_pool": True}
40 | else:
41 | # frozen embedding during training and/or preprocessor only used during rollout
42 | preprocessor = {**preprocessor, "remove_pool": False}
43 | self.preprocessor = get_preprocessor(
44 | device=torch.cuda.current_device(),
45 | **preprocessor,
46 | )
47 |
48 | obs_keys = observation_space.spaces.keys()
49 |         self.rgb_out_dim = 0
50 |         if "rgb" in obs_keys:
51 |             # (embed_dim, ); doubled since the goal embedding is concatenated
52 |             rgb_obs_dims = observation_space["rgb"].shape
53 |             self.rgb_out_dim = rgb_obs_dims[0] * 2
54 |
55 | self.proprio_dim = (
56 | observation_space["proprio"].shape[0] if "proprio" in obs_keys else 0
57 | )
58 |
59 | mlp_input_dim = self.rgb_out_dim
60 | self.mlp: MLP = U.hydra_instantiate(
61 | obs_encoder,
62 | input_dim=mlp_input_dim,
63 | proprio_dim=self.proprio_dim,
64 | output_dim=self.action_dim,
65 | actor_critic=False,
66 | )
67 | self.act_head = U.hydra_instantiate(
68 | act_head,
69 | action_dim=action_space if self.is_multi_discrete else self.action_dim,
70 | )
71 |
72 | if bn_to_gn_all:
73 | U.bn_to_gn(self)
74 |
75 | def forward(
76 | self,
77 | obs: dict[str, torch.Tensor] | torch.Tensor | np.ndarray,
78 | goal: torch.Tensor | np.ndarray | None,
79 | deterministic: bool = False,
80 | return_embeddings: bool = False,
81 | ) -> torch.Tensor | DistributionBase | tuple:
82 | if isinstance(obs, dict):
83 | rgb_embed = obs["rgb"]
84 | proprio = obs.get("proprio", None)
85 | else:
86 | rgb_embed = obs
87 | proprio = None
88 |
89 | if self.preprocessor is not None:
90 | preprocessor_output_dim = self.preprocessor.output_dim
91 | preprocessor_output_dim = (
92 | (preprocessor_output_dim,)
93 | if isinstance(preprocessor_output_dim, int)
94 | else preprocessor_output_dim
95 | )
96 | if rgb_embed.shape[1:] != preprocessor_output_dim:
97 | # B, H, W, 3 or B, 3, H, W after transformed
98 | assert (
99 | rgb_embed.ndim == 4
100 | ), f"{rgb_embed.shape}, {preprocessor_output_dim}"
101 | rgb_embed = self.preprocessor.process(rgb_embed, return_numpy=False)
102 |
103 |         # language goal; could differ per subtask or be shared across the entire task
104 | if goal is not None and goal.shape[1:] != preprocessor_output_dim:
105 | # goal = self.preprocessor.encode_text(goal)
106 | raise NotImplementedError(goal)
107 | if not torch.is_tensor(goal):
108 | goal = torch.as_tensor(
109 | goal, dtype=rgb_embed.dtype, device=rgb_embed.device
110 | )
111 |
112 | x = torch.cat([rgb_embed, goal], dim=-1) if goal is not None else rgb_embed
113 | assert x.shape[0] == rgb_embed.shape[0] and x.ndim == 2, x.shape
114 | # L, action_dim
115 | x = self.mlp(x=x, proprio=proprio)
116 | x = self.act_head(x)
117 | if deterministic:
118 | x = x.mode()
119 | if return_embeddings:
120 | # return "frozen" embeddings
121 | return x, rgb_embed, goal
122 | return x
123 |
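# Shape sketch (assuming a frozen D-dim visual embedding): obs["rgb"]: (B, D) and
# goal: (B, D) are concatenated to (B, 2 * D), matching rgb_out_dim = embed_dim * 2,
# then mapped by self.mlp (plus optional proprio) and self.act_head to (B, action_dim).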
--------------------------------------------------------------------------------
/uvd/models/preprocessors/voltron_preprocessor.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import numpy as np
4 | import torch
5 |
6 | _VOLTRON_IMPORT_ERROR = None
7 | try:
8 | import voltron
9 | from voltron import instantiate_extractor, load
10 | except ImportError as e:
11 | _VOLTRON_IMPORT_ERROR = e
12 |
13 | from torchvision import transforms as T
14 |
15 | import uvd.utils as U
16 | from uvd.models.preprocessors.base import Preprocessor
17 |
18 |
19 | AVAILABLE_VOLTRON_MODEL_TYPES = [
20 | # === Voltron ViT-Small (Sth-Sth) Models ===
21 | "v-cond",
22 | "v-dual",
23 | "v-gen",
24 | # === Voltron ViT-Base Model ===
25 | "v-cond-base",
26 | # === Data-Locked Reproductions ===
27 | # "r-mvp",
28 | # "r-r3m-vit",
29 | # "r-r3m-rn50",
30 | ]
31 |
32 |
33 | class VoltronPreprocessor(Preprocessor):
34 | def __init__(
35 | self,
36 | model_type: str | None = None,
37 | device: torch.device | str | None = None,
38 | remove_bn: bool = False,
39 | bn_to_gn: bool = False,
40 | remove_pool: bool = False,
41 | preprocess_with_fc: bool = False,
42 | save_fc: bool = False,
43 | random_crop: bool = False,
44 | ckpt: str | None = None,
45 | use_language_goal: bool = False,
46 | ):
47 | if _VOLTRON_IMPORT_ERROR is not None:
48 | raise ImportError(_VOLTRON_IMPORT_ERROR)
49 | model_type = model_type or "v-cond"
50 | assert model_type in AVAILABLE_VOLTRON_MODEL_TYPES, (
51 | model_type,
52 | AVAILABLE_VOLTRON_MODEL_TYPES,
53 | )
54 | self.random_crop = random_crop
55 | self.ckpt = ckpt
56 | if save_fc or preprocess_with_fc:
57 |             U.rank_zero_print("WARNING: Voltron has no fc to save", color="red")
58 | save_fc = False
59 | preprocess_with_fc = False
60 | bn_to_gn = False
61 | super().__init__(
62 | model_type=model_type,
63 | device=device,
64 | remove_bn=remove_bn,
65 | bn_to_gn=bn_to_gn,
66 | remove_pool=remove_pool,
67 | preprocess_with_fc=preprocess_with_fc,
68 | save_fc=save_fc,
69 | use_language_goal=use_language_goal,
70 | )
71 |
72 | self._cached_language_embedding = {}
73 |
74 | def _get_model_and_transform(self, model_type: str) -> tuple:
75 | vcond, preprocess = load(model_type, freeze=True)
76 | vector_extractor = instantiate_extractor(vcond)()
77 | self.vector_extractor = vector_extractor.to(self.device)
78 | preprocess: T.Compose
79 | normlayer = preprocess.transforms[-1]
80 | assert isinstance(normlayer, T.Normalize)
81 | transform = (
82 | T.Compose([T.Resize(224), normlayer])
83 | if not self.random_crop
84 | else T.Compose([T.Resize(232), T.RandomCrop(224), normlayer])
85 | )
86 | return vcond.to(self.device), transform
87 |
88 | def _encode_image(self, img_tensors: torch.Tensor) -> torch.FloatTensor:
89 | with torch.no_grad():
90 | return self.vector_extractor(self.model(img_tensors, mode="visual"))
91 |
92 | def _encode_text(
93 | self, text: str | np.ndarray | list | torch.Tensor
94 | ) -> torch.Tensor:
95 | raise NotImplementedError
96 |
97 | def encode_text(self, text: str | np.ndarray | list | torch.Tensor) -> torch.Tensor:
98 | return self.cached_language_embed(text)
99 |
100 | def cached_language_embed(self, text: str):
101 | if text in self._cached_language_embedding:
102 | return self._cached_language_embedding[text]
103 | text_embed = self._encode_text(text)
104 | self._cached_language_embedding[text] = text_embed
105 | return text_embed
106 |
107 |
108 | def sim(tensor1, tensor2, metric: str = "l2", device=None):
109 |     if isinstance(tensor1, np.ndarray):
110 |         tensor1 = torch.from_numpy(tensor1).to(device)
111 |         tensor2 = torch.from_numpy(tensor2).to(device)
112 | if metric == "l2":
113 | d = -torch.linalg.norm(tensor1 - tensor2, dim=-1)
114 | elif metric == "cos":
115 | tensor1 = tensor1 / tensor1.norm(dim=-1, keepdim=True)
116 | tensor2 = tensor2 / tensor2.norm(dim=-1, keepdim=True)
117 | d = torch.nn.CosineSimilarity(-1)(tensor1, tensor2)
118 | else:
119 | raise NotImplementedError
120 | return d
121 |
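# e.g. sim(np.ones(3), np.zeros(3), metric="l2") -> tensor(-1.7321, dtype=torch.float64),
# the negated L2 distance; metric="cos" returns cosine similarity in [-1, 1].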
122 |
123 | PROMPT_DICT = dict(
124 | microwave="open the microwave",
125 | kettle="move the kettle to the top left stove",
126 | light_switch="turn on the light",
127 | hinge_cabinet="open the left hinge cabinet",
128 | slide_cabinet="open the right slide cabinet",
129 | top_burner="turn on the top left burner",
130 | bottom_burner="turn on the bottom left burner",
131 | )
132 | PROMPT_DICT.update({k.replace("_", " "): v for k, v in PROMPT_DICT.items()})
133 |
--------------------------------------------------------------------------------
/uvd/envs/franka_kitchen/relay-policy-learning/adept_envs/adept_envs/base_robot.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2020 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | from collections import deque
18 |
19 | import numpy as np
20 |
21 |
22 | class BaseRobot(object):
23 | """Base class for all robot classes."""
24 |
25 | def __init__(
26 | self,
27 | n_jnt,
28 | n_obj,
29 | pos_bounds=None,
30 | vel_bounds=None,
31 | calibration_path=None,
32 | is_hardware=False,
33 | device_name=None,
34 | overlay=False,
35 | calibration_mode=False,
36 | observation_cache_maxsize=5,
37 | ):
38 | """Create a new robot.
39 |
40 | Args:
41 | n_jnt: The number of dofs in the robot.
42 | n_obj: The number of dofs in the object.
43 | pos_bounds: (n_jnt, 2)-shape matrix denoting the min and max joint
44 | position for each joint.
45 | vel_bounds: (n_jnt, 2)-shape matrix denoting the min and max joint
46 | velocity for each joint.
47 | calibration_path: File path to the calibration configuration file to
48 | use.
49 | is_hardware: Whether to run on hardware or not.
50 | device_name: The device path for the robot hardware. Only required
51 | in legacy mode.
52 | overlay: Whether to show a simulation overlay of the hardware.
53 |             calibration_mode: Start with motors disengaged.
54 |             observation_cache_maxsize: Maximum number of recent observations to cache.
55 |         """
55 |
56 | assert n_jnt > 0
57 | assert n_obj >= 0
58 |
59 | self._n_jnt = n_jnt
60 | self._n_obj = n_obj
61 | self._n_dofs = n_jnt + n_obj
62 |
63 | self._pos_bounds = None
64 | if pos_bounds is not None:
65 | pos_bounds = np.array(pos_bounds, dtype=np.float32)
66 | assert pos_bounds.shape == (self._n_dofs, 2)
67 | for low, high in pos_bounds:
68 | assert low < high
69 | self._pos_bounds = pos_bounds
70 | self._vel_bounds = None
71 | if vel_bounds is not None:
72 | vel_bounds = np.array(vel_bounds, dtype=np.float32)
73 | assert vel_bounds.shape == (self._n_dofs, 2)
74 | for low, high in vel_bounds:
75 | assert low < high
76 | self._vel_bounds = vel_bounds
77 |
78 | self._is_hardware = is_hardware
79 | self._device_name = device_name
80 | self._calibration_path = calibration_path
81 | self._overlay = overlay
82 | self._calibration_mode = calibration_mode
83 | self._observation_cache_maxsize = observation_cache_maxsize
84 |
85 | # Gets updated
86 | self._observation_cache = deque([], maxlen=self._observation_cache_maxsize)
87 |
88 | @property
89 | def n_jnt(self):
90 | return self._n_jnt
91 |
92 | @property
93 | def n_obj(self):
94 | return self._n_obj
95 |
96 | @property
97 | def n_dofs(self):
98 | return self._n_dofs
99 |
100 | @property
101 | def pos_bounds(self):
102 | return self._pos_bounds
103 |
104 | @property
105 | def vel_bounds(self):
106 | return self._vel_bounds
107 |
108 | @property
109 | def is_hardware(self):
110 | return self._is_hardware
111 |
112 | @property
113 | def device_name(self):
114 | return self._device_name
115 |
116 | @property
117 | def calibration_path(self):
118 | return self._calibration_path
119 |
120 | @property
121 | def overlay(self):
122 | return self._overlay
123 |
124 | @property
125 | def has_obj(self):
126 | return self._n_obj > 0
127 |
128 | @property
129 | def calibration_mode(self):
130 | return self._calibration_mode
131 |
132 | @property
133 | def observation_cache_maxsize(self):
134 | return self._observation_cache_maxsize
135 |
136 | @property
137 | def observation_cache(self):
138 | return self._observation_cache
139 |
140 | def clip_positions(self, positions):
141 | """Clips the given joint positions to the position bounds.
142 |
143 | Args:
144 | positions: The joint positions.
145 |
146 | Returns:
147 | The bounded joint positions.
148 | """
149 | if self.pos_bounds is None:
150 | return positions
151 | assert len(positions) == self.n_jnt or len(positions) == self.n_dofs
152 | pos_bounds = self.pos_bounds[: len(positions)]
153 | return np.clip(positions, pos_bounds[:, 0], pos_bounds[:, 1])
154 |
--------------------------------------------------------------------------------
/uvd/envs/franka_kitchen/relay-policy-learning/adept_envs/adept_envs/simulation/sim_robot.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2020 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """Module for loading MuJoCo models."""
17 |
18 | import os
19 | from typing import Dict, Optional
20 |
21 | from adept_envs.simulation import module
22 | from adept_envs.simulation.renderer import DMRenderer, MjPyRenderer, RenderMode
23 |
24 |
25 | class MujocoSimRobot:
26 | """Class that encapsulates a MuJoCo simulation.
27 |
28 | This class exposes methods that are agnostic to the simulation backend.
29 | Two backends are supported:
30 | 1. mujoco_py - MuJoCo v1.50
31 | 2. dm_control - MuJoCo v2.00
32 | """
33 |
34 | def __init__(
35 | self,
36 | model_file: str,
37 | use_dm_backend: bool = False,
38 | camera_settings: Optional[Dict] = None,
39 | ):
40 | """Initializes a new simulation.
41 |
42 | Args:
43 | model_file: The MuJoCo XML model file to load.
44 | use_dm_backend: If True, uses DM Control's Physics (MuJoCo v2.0) as
45 | the backend for the simulation. Otherwise, uses mujoco_py (MuJoCo
46 | v1.5) as the backend.
47 | camera_settings: Settings to initialize the renderer's camera. This
48 | can contain the keys `distance`, `azimuth`, and `elevation`.
49 | """
50 | self._use_dm_backend = use_dm_backend
51 |
52 | if not os.path.isfile(model_file):
53 | raise ValueError(
54 | "[MujocoSimRobot] Invalid model file path: {}".format(model_file)
55 | )
56 |
57 | if self._use_dm_backend:
58 | dm_mujoco = module.get_dm_mujoco()
59 | if model_file.endswith(".mjb"):
60 | self.sim = dm_mujoco.Physics.from_binary_path(model_file)
61 | else:
62 | self.sim = dm_mujoco.Physics.from_xml_path(model_file)
63 | self.model = self.sim.model
64 | self._patch_mjlib_accessors(self.model, self.sim.data)
65 | self.renderer = DMRenderer(self.sim, camera_settings=camera_settings)
66 | else: # Use mujoco_py
67 | mujoco_py = module.get_mujoco_py()
68 | self.model = mujoco_py.load_model_from_path(model_file)
69 | self.sim = mujoco_py.MjSim(self.model)
70 | self.renderer = MjPyRenderer(self.sim, camera_settings=camera_settings)
71 |
72 | self.data = self.sim.data
73 |
74 | def close(self):
75 | """Cleans up any resources being used by the simulation."""
76 | self.renderer.close()
77 |
78 | def save_binary(self, path: str):
79 | """Saves the loaded model to a binary .mjb file."""
80 | if os.path.exists(path):
81 | raise ValueError("[MujocoSimRobot] Path already exists: {}".format(path))
82 | if not path.endswith(".mjb"):
83 | path = path + ".mjb"
84 | if self._use_dm_backend:
85 | self.model.save_binary(path)
86 | else:
87 | with open(path, "wb") as f:
88 | f.write(self.model.get_mjb())
89 |
90 | def get_mjlib(self):
91 | """Returns an object that exposes the low-level MuJoCo API."""
92 | if self._use_dm_backend:
93 | return module.get_dm_mujoco().wrapper.mjbindings.mjlib
94 | else:
95 | return module.get_mujoco_py_mjlib()
96 |
97 | def _patch_mjlib_accessors(self, model, data):
98 | """Adds accessors to the DM Control objects to support mujoco_py
99 | API."""
100 | assert self._use_dm_backend
101 | mjlib = self.get_mjlib()
102 |
103 | def name2id(type_name, name):
104 | obj_id = mjlib.mj_name2id(
105 | model.ptr, mjlib.mju_str2Type(type_name.encode()), name.encode()
106 | )
107 | if obj_id < 0:
108 | raise ValueError('No {} with name "{}" exists.'.format(type_name, name))
109 | return obj_id
110 |
111 | if not hasattr(model, "body_name2id"):
112 | model.body_name2id = lambda name: name2id("body", name)
113 |
114 | if not hasattr(model, "geom_name2id"):
115 | model.geom_name2id = lambda name: name2id("geom", name)
116 |
117 | if not hasattr(model, "site_name2id"):
118 | model.site_name2id = lambda name: name2id("site", name)
119 |
120 | if not hasattr(model, "joint_name2id"):
121 | model.joint_name2id = lambda name: name2id("joint", name)
122 |
123 | if not hasattr(model, "actuator_name2id"):
124 | model.actuator_name2id = lambda name: name2id("actuator", name)
125 |
126 | if not hasattr(model, "camera_name2id"):
127 | model.camera_name2id = lambda name: name2id("camera", name)
128 |
129 | if not hasattr(data, "body_xpos"):
130 | data.body_xpos = data.xpos
131 |
132 | if not hasattr(data, "body_xquat"):
133 | data.body_xquat = data.xquat
134 |
--------------------------------------------------------------------------------
/scripts/benchmark_inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import time
4 |
5 | import gym
6 | import numpy as np
7 | import torch
8 | import yaml
9 | from omegaconf import DictConfig
10 |
11 | import uvd.utils as U
12 | from uvd.models.preprocessors import get_preprocessor
13 | from uvd.decomp.decomp import embedding_decomp, DEFAULT_DECOMP_KWARGS
14 | from uvd.envs.evaluator.inference_wrapper import InferenceWrapper
15 | from uvd.envs.franka_kitchen.franka_kitchen_base import KitchenBase
16 |
17 | MLP_CFG = """\
18 | policy:
19 | _target_: uvd.models.policy.MLPPolicy
20 | observation_space: ???
21 | action_space: ???
22 | preprocessor: ???
23 | obs_encoder:
24 | __target__: uvd.models.nn.MLP
25 | hidden_dims: [1024, 512, 256]
26 | activation: ReLU
27 | normalization: false
28 | input_normalization: BatchNorm1d
29 | input_normalization_full_obs: false
30 | proprio_output_dim: 512
31 | proprio_add_layernorm: true
32 | proprio_activation: Tanh
33 | proprio_add_noise_eval: false
34 | actor_act: Tanh
35 | act_head:
36 | __target__: uvd.models.distributions.DeterministicHead
37 | """
38 |
39 | GPT_CFG = """\
40 | policy:
41 | _target_: uvd.models.policy.GPTPolicy
42 | observation_space: ???
43 | action_space: ???
44 | preprocessor: ???
45 | use_kv_cache: true
46 | max_seq_length: 10
47 | obs_add: false
48 | proprio_hidden_dim: 512
49 | obs_encoder:
50 | __target__: uvd.models.nn.GPT
51 | use_wte: true
52 | gpt_config:
53 | block_size: 10
54 | vocab_size: null
55 | n_embd: 768
56 | n_layer: 8
57 | n_head: 8
58 | dropout: 0.1
59 | bias: false
60 | use_llama_impl: true
61 | position_embed: rotary
62 | act_head:
63 | __target__: uvd.models.distributions.DeterministicHead
64 | """
65 |
66 | if __name__ == "__main__":
67 | parser = argparse.ArgumentParser()
68 | parser.add_argument("--policy", default="gpt")
69 | parser.add_argument("--preprocessor_name", default="vip")
70 | parser.add_argument("--use_uvd", action="store_true")
71 | parser.add_argument("--n", type=int, default=100)
72 | args = parser.parse_args()
73 |
74 | use_gpu = torch.cuda.is_available()
75 | if not use_gpu:
76 | print("NO GPU FOUND")
77 | preprocessor = get_preprocessor(
78 | args.preprocessor_name, device="cuda" if use_gpu else None
79 | )
80 | policy_name = args.policy.lower()
81 | assert policy_name in ["mlp", "gpt"]
82 | is_causal = policy_name == "gpt"
83 |
84 | env = KitchenBase(frame_height=224, frame_width=224)
85 | env = InferenceWrapper(env, dummy_rtn=is_causal)
86 | env.reset()
87 |
88 | observation_space = gym.spaces.Dict(
89 | rgb=gym.spaces.Box(-np.inf, np.inf, preprocessor.output_dim, np.float32),
90 | proprio=gym.spaces.Box(-1, 1, (9,), np.float32),
91 | milestones=gym.spaces.Box(
92 | -np.inf, np.inf, (6,) + preprocessor.output_dim, np.float32
93 | ),
94 | )
95 | action_space = env.action_space
96 |
97 | cfg = yaml.safe_load(MLP_CFG if policy_name == "mlp" else GPT_CFG)
98 | cfg = DictConfig(cfg)
99 | policy = U.hydra_instantiate(
100 | cfg.policy,
101 | observation_space=observation_space,
102 | action_space=action_space,
103 | preprocessor=preprocessor,
104 | )
105 | policy = policy.to(preprocessor.device).eval()
106 | U.debug_model_info(policy)
107 | if is_causal:
108 | assert policy.causal and policy.use_kv_cache
109 |
110 | preprocessor = policy.preprocessor
111 |     # or load FrankaKitchen dummy data
112 | dummy_data = np.random.random((300, 224, 224, 3)).astype(np.float32)
113 | emb = preprocessor.process(dummy_data, return_numpy=True)
114 | if args.use_uvd:
115 | _, decomp_meta = embedding_decomp(
116 | embeddings=emb,
117 | fill_embeddings=False,
118 | return_intermediate_curves=False,
119 | **DEFAULT_DECOMP_KWARGS["embed"],
120 | )
121 |         milestones = emb[decomp_meta.milestone_indices]  # (n_milestones, embed_dim)
122 | else:
123 | milestones = emb[-1][None, ...]
124 | env.milestones = milestones
125 |
126 | MAX_HORIZON = 300
127 | totals = []
128 | for _ in range(args.n):
129 | obs = env.reset()
130 | if is_causal:
131 | policy.reset_cache()
132 |
133 | times = []
134 | for st in range(MAX_HORIZON):
135 | t = time.time()
136 | obs = copy.deepcopy(obs)
137 | batchify_obs = U.batch_observations([obs], device=policy.device)
138 | if is_causal:
139 | # B, T, ...
140 | cur_milestone = env.current_milestone[None, None, ...]
141 | for k in batchify_obs:
142 | batchify_obs[k] = batchify_obs[k][:, None, ...]
143 | else:
144 | # B, ...
145 | cur_milestone = env.current_milestone[None, ...]
146 | with torch.no_grad():
147 | action, obs_embed, goal_embed = policy(
148 | batchify_obs,
149 | goal=torch.as_tensor(cur_milestone, device=policy.device),
150 | deterministic=True,
151 | return_embeddings=True,
152 | input_pos=torch.tensor([st], device=policy.device)
153 | if is_causal
154 | else None,
155 | )
156 | env.current_obs_embedding = obs_embed[0].cpu().numpy()
157 | obs, r, done, info = env.step(action[0].cpu().numpy())
158 | step_t = time.time() - t
159 | times.append(step_t)
160 |         total_t = np.sum(times)
161 |         print(total_t)
162 |         totals.append(total_t)
163 | print(np.mean(totals))
164 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Universal Visual Decomposer: Long-Horizon Manipulation Made Easy
2 |
3 |
4 |
5 | [[Website]](https://zcczhang.github.io/UVD/)
6 | [[arXiv]](https://arxiv.org/abs/2310.08581)
7 | [[PDF]](https://zcczhang.github.io/UVD/assets/pdf/full_paper.pdf)
8 | [[Installation]](#Installation)
9 | [[Usage]](#Usage)
10 | [[BibTex]](#Citation)
11 | ______________________________________________________________________
12 |
13 |
14 |
15 | https://github.com/zcczhang/UVD/assets/52727818/5555b99a-76eb-4d76-966f-787af763573a
16 |
17 |
18 |
19 |
20 |
21 |
22 | # Installation
23 |
24 | - Follow the [instructions](https://github.com/openai/mujoco-py#install-mujoco) for installing `mujoco-py`, and install the following apt packages if using Ubuntu:
25 | ```commandline
26 | sudo apt install -y libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf
27 | ```
28 | - Create a conda env with Python 3.9
29 | ```commandline
30 | conda create -n uvd python==3.9 -y && conda activate uvd
31 | ```
32 | - Install any/all standalone visual foundation models from their repos separately *before* setting up UVD, to avoid dependency conflicts, e.g.:
33 |
34 | **VIP**
35 |
36 |
37 |
38 | ```commandline
39 | git clone https://github.com/facebookresearch/vip.git
40 | cd vip && pip install -e .
41 | python -c "from vip import load_vip; vip = load_vip()"
42 | ```
43 |
44 |
45 |
46 |
47 |
48 | **R3M**
49 |
50 |
51 |
52 | ```commandline
53 | git clone https://github.com/facebookresearch/r3m.git
54 | cd r3m && pip install -e .
55 | python -c "from r3m import load_r3m; r3m = load_r3m('resnet50')"
56 | ```
57 |
58 |
59 |
60 |
61 |
62 | **LIV (& CLIP)**
63 |
64 |
65 |
66 | ```commandline
67 | git clone https://github.com/penn-pal-lab/LIV.git
68 | cd LIV && pip install -e . && cd liv/models/clip && pip install -e .
69 | python -c "from liv import load_liv; liv = load_liv()"
70 | ```
71 |
72 |
73 |
74 |
75 |
76 |
77 | **VC1**
78 |
79 |
80 |
81 | ```commandline
82 | git clone https://github.com/facebookresearch/eai-vc.git
83 | cd eai-vc && pip install -e vc_models
84 | ```
85 |
86 |
87 |
88 |
89 |
90 | DINOv2 and ResNet pretrained on ImageNet-1k are loaded directly via torch hub and torchvision.
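
For reference, a minimal sketch of loading both (the variant names `dinov2_vits14` and `IMAGENET1K_V1` are illustrative choices, not necessarily the ones UVD uses; UVD's preprocessors handle this internally):

```python
import torch
import torchvision

# DINOv2 via torch hub; dinov2_vits14 is one of the published variants.
dinov2 = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
# ResNet-50 pretrained on ImageNet-1k via torchvision.
resnet = torchvision.models.resnet50(weights="IMAGENET1K_V1")
```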
91 |
92 |
93 | - Under *this* UVD repo directory, install the remaining dependencies:
94 | ```commandline
95 | pip install -e .
96 | ```
97 |
98 | # Usage
99 |
100 | We provide a simple API for decomposing RGB videos:
101 |
102 | ```python
103 | import torch
104 | import uvd
105 |
106 | # (N sub-goals, *video frame shape)
107 | subgoals = uvd.get_uvd_subgoals(
108 | "/PATH/TO/VIDEO.*", # video filename or (L, *video frame shape) video numpy array
109 | preprocessor_name="vip", # Literal["vip", "r3m", "liv", "clip", "vc1", "dinov2"]
110 | device="cuda" if torch.cuda.is_available() else "cpu", # device for loading frozen preprocessor
111 | return_indices=False, # True if only want the list of subgoal timesteps
112 | )
113 | ```
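
If only the decomposition points are needed, the same call can return the subgoal timesteps instead; a minimal sketch using the `return_indices` flag documented above (the random `video` array is a stand-in for real frames):

```python
import numpy as np
import torch
import uvd

# Stand-in for a real (L, H, W, 3) RGB video.
video = np.random.random((300, 224, 224, 3)).astype(np.float32)

indices = uvd.get_uvd_subgoals(
    video,
    preprocessor_name="vip",
    device="cuda" if torch.cuda.is_available() else "cpu",
    return_indices=True,  # list of subgoal timesteps instead of frames
)
subgoal_frames = video[np.asarray(indices)]  # recover the subgoal frames
```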
114 |
115 | or run
116 | ```commandline
117 | python demo.py
118 | ```
119 | to host a Gradio demo locally with different choices of visual representations.
120 |
121 | ## Simulation Data
122 |
123 | We post-process the data released from the original [Relay-Policy-Learning](https://github.com/google-research/relay-policy-learning/tree/master) repo, keeping only the successful trajectories and adapting the control and observations used in our paper:
124 | ```commandline
125 | python datasets/data_gen.py raw_data_path=/PATH/TO/RAW_DATA
126 | ```
127 |
128 | Also consider force-setting `Builder = LinuxCPUExtensionBuilder` to `Builder = LinuxGPUExtensionBuilder` in `PATH/TO/CONDA/envs/uvd/lib/python3.9/site-packages/mujoco_py/builder.py` to enable (multi-)GPU acceleration.
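
If editing the file by hand is inconvenient, a one-off patch along these lines may help (a sketch only: it assumes the quoted assignment appears verbatim in your installed `builder.py`, and note that importing `mujoco_py` may trigger its first build):

```python
import pathlib

import mujoco_py

# Locate the installed builder.py next to the mujoco_py package.
builder_py = pathlib.Path(mujoco_py.__file__).parent / "builder.py"
source = builder_py.read_text()
# Swap the CPU extension builder for the GPU one, as described above.
builder_py.write_text(
    source.replace(
        "Builder = LinuxCPUExtensionBuilder",
        "Builder = LinuxGPUExtensionBuilder",
    )
)
```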
129 |
130 |
131 | ## Runtime Benchmark
132 |
133 | Since UVD is meant to be an off-the-shelf method that applies to *any* existing policy learning framework and model, across BC and RL, we provide minimal scripts under the `./scripts` directory to benchmark its runtime and show that the overhead is negligible:
134 | ```commandline
135 | python scripts/benchmark_decomp.py /PATH/TO/VIDEO
136 | ```
137 | Pass `--preprocessor_name` to benchmark other preprocessors (default `vip`) and `--n` to set the number of repeated iterations (default `100`).
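
At its core, the script simply times repeated decomposition calls; a minimal sketch of the measurement (the actual script's bookkeeping may differ):

```python
import time

import numpy as np
import uvd

video = "/PATH/TO/VIDEO.*"  # video filename or (L, *frame shape) numpy array
elapsed = []
for _ in range(100):
    start = time.perf_counter()
    uvd.get_uvd_subgoals(video, preprocessor_name="vip")
    elapsed.append(time.perf_counter() - start)
print(f"mean decomposition time: {np.mean(elapsed):.4f}s")
```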
138 |
139 | For inference or rollouts, we benchmark the runtime by
140 | ```commandline
141 | python scripts/benchmark_inference.py
142 | ```
143 | Pass `--policy` to choose the MLP or causal GPT policy; `--preprocessor_name` to use other preprocessors (default `vip`); `--use_uvd` as a boolean flag toggling between UVD and no decomposition (i.e., conditioning on the final goal only); and `--n` to set the number of repeated iterations (default `100`). The default episode horizon is 300. We found that running from a terminal is almost 2s slower per episode than running from a Python IDE (e.g., PyCharm, with the script directory as the working directory, run as a script rather than a module), but the general trend holds: including UVD introduces negligible extra runtime.
144 |
145 | # Citation
146 | If you find this project useful in your research, please consider citing:
147 |
148 | ```bibtex
149 | @inproceedings{zhang2024universal,
150 | title={Universal visual decomposer: Long-horizon manipulation made easy},
151 | author={Zhang, Zichen and Li, Yunshuang and Bastani, Osbert and Gupta, Abhishek and Jayaraman, Dinesh and Ma, Yecheng Jason and Weihs, Luca},
152 | booktitle={2024 IEEE International Conference on Robotics and Automation (ICRA)},
153 | pages={6973--6980},
154 | year={2024},
155 | organization={IEEE}
156 | }
157 | ```
158 |
--------------------------------------------------------------------------------
/uvd/envs/franka_kitchen/relay-policy-learning/adept_envs/adept_envs/robot_env.py:
--------------------------------------------------------------------------------
1 | """Base class for robotics environments."""
2 |
3 | #!/usr/bin/python
4 | #
5 | # Copyright 2020 Google LLC
6 | #
7 | # Licensed under the Apache License, Version 2.0 (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at
10 | #
11 | # http://www.apache.org/licenses/LICENSE-2.0
12 | #
13 | # Unless required by applicable law or agreed to in writing, software
14 | # distributed under the License is distributed on an "AS IS" BASIS,
15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 | # See the License for the specific language governing permissions and
17 | # limitations under the License.
18 |
19 | import os
20 | from typing import Dict, Optional
21 |
22 | import numpy as np
23 |
24 | from adept_envs import mujoco_env
25 | from adept_envs.base_robot import BaseRobot
26 | from adept_envs.utils.configurable import import_class_from_path
27 | from adept_envs.utils.constants import MODELS_PATH
28 |
29 |
30 | class RobotEnv(mujoco_env.MujocoEnv):
31 | """Base environment for all adept robots."""
32 |
33 | # Mapping of robot name to fully qualified class path.
34 | # e.g. 'robot': 'adept_envs.dclaw.robot.Robot'
35 | # Subclasses should override this to specify the Robot classes they support.
36 | ROBOTS = {}
37 |
38 | # Mapping of device path to the calibration file to use. If the device path
39 | # is not found, the 'default' key is used.
40 |     # This can be overridden by subclasses.
41 | CALIBRATION_PATHS = {}
42 |
43 | def __init__(
44 | self,
45 | model_path: str,
46 | robot: BaseRobot,
47 | frame_skip: int,
48 | camera_settings: Optional[Dict] = None,
49 | ):
50 | """Initializes a robotics environment.
51 |
52 | Args:
53 | model_path: The path to the model to run. Relative paths will be
54 | interpreted as relative to the 'adept_models' folder.
55 | robot: The Robot object to use.
56 | frame_skip: The number of simulation steps per environment step. On
57 | hardware this influences the duration of each environment step.
58 | camera_settings: Settings to initialize the simulation camera. This
59 | can contain the keys `distance`, `azimuth`, and `elevation`.
60 | """
61 | self._robot = robot
62 |
63 | # Initial pose for first step.
64 | self.desired_pose = np.zeros(self.n_jnt)
65 |
66 | if not model_path.startswith("/"):
67 | model_path = os.path.abspath(os.path.join(MODELS_PATH, model_path))
68 |
69 | self.remote_viz = None
70 |
71 | try:
72 | from adept_envs.utils.remote_viz import RemoteViz
73 |
74 | self.remote_viz = RemoteViz(model_path)
75 | except ImportError:
76 | pass
77 |
78 | self._initializing = True
79 | super(RobotEnv, self).__init__(
80 | model_path, frame_skip, camera_settings=camera_settings
81 | )
82 | self._initializing = False
83 |
84 | @property
85 | def robot(self):
86 | return self._robot
87 |
88 | @property
89 | def n_jnt(self):
90 | return self._robot.n_jnt
91 |
92 | @property
93 | def n_obj(self):
94 | return self._robot.n_obj
95 |
96 | @property
97 | def skip(self):
98 | """Alias for frame_skip.
99 |
100 | Needed for MJRL.
101 | """
102 | return self.frame_skip
103 |
104 | @property
105 | def initializing(self):
106 | return self._initializing
107 |
108 | def close_env(self):
109 | if self._robot is not None:
110 | self._robot.close()
111 |
112 | def make_robot(
113 | self,
114 | n_jnt,
115 | n_obj=0,
116 | is_hardware=False,
117 | device_name=None,
118 | legacy=False,
119 | **kwargs
120 | ):
121 | """Creates a new robot for the environment.
122 |
123 | Args:
124 | n_jnt: The number of joints in the robot.
125 | n_obj: The number of object joints in the robot environment.
126 | is_hardware: Whether to run on hardware or not.
127 | device_name: The device path for the robot hardware.
128 | legacy: If true, runs using direct dynamixel communication rather
129 | than DDS.
130 | kwargs: See BaseRobot for other parameters.
131 |
132 | Returns:
133 | A Robot object.
134 | """
135 | if not self.ROBOTS:
136 | raise NotImplementedError("Subclasses must override ROBOTS.")
137 |
138 | if is_hardware and not device_name:
139 | raise ValueError("Must provide device name if running on hardware.")
140 |
141 | robot_name = "dds_robot" if not legacy and is_hardware else "robot"
142 | if robot_name not in self.ROBOTS:
143 | raise KeyError(
144 | "Unsupported robot '{}', available: {}".format(
145 | robot_name, list(self.ROBOTS.keys())
146 | )
147 | )
148 |
149 | cls = import_class_from_path(self.ROBOTS[robot_name])
150 |
151 | calibration_path = None
152 | if self.CALIBRATION_PATHS:
153 | if not device_name:
154 | calibration_name = "default"
155 | elif device_name not in self.CALIBRATION_PATHS:
156 | print(
157 | 'Device "{}" not in CALIBRATION_PATHS; using default.'.format(
158 | device_name
159 | )
160 | )
161 | calibration_name = "default"
162 | else:
163 | calibration_name = device_name
164 |
165 | calibration_path = self.CALIBRATION_PATHS[calibration_name]
166 | if not os.path.isfile(calibration_path):
167 | raise OSError(
168 | "Could not find calibration file at: {}".format(calibration_path)
169 | )
170 |
171 | return cls(
172 | n_jnt,
173 | n_obj,
174 | is_hardware=is_hardware,
175 | device_name=device_name,
176 | calibration_path=calibration_path,
177 | **kwargs
178 | )
179 |
--------------------------------------------------------------------------------
/uvd/models/preprocessors/liv_preprocessor.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Optional
4 |
5 | import hydra
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | import omegaconf
9 | import torch
10 |
11 | _LIV_IMPORT_ERROR = None
12 | try:
13 | import liv
14 | except ImportError as e:
15 | _LIV_IMPORT_ERROR = e
16 |
17 | _CLIP_IMPORT_ERROR = None
18 | try:
19 | import clip
20 | except ImportError as e:
21 | _CLIP_IMPORT_ERROR = e
22 |
23 | from torch import nn
24 | from torchvision import transforms as T
25 |
26 | import uvd.utils as U
27 | from uvd.models.preprocessors.base import Preprocessor
28 |
29 |
30 | class LIVPreprocessor(Preprocessor):
31 | def __init__(
32 | self,
33 | model_type: str | None = None,
34 | device: torch.device | str | None = None,
35 | remove_bn: bool = False,
36 | bn_to_gn: bool = False,
37 | remove_pool: bool = False,
38 | preprocess_with_fc: bool = False,
39 | save_fc: bool = False,
40 | random_crop: bool = False,
41 | ckpt: str | None = None,
42 | use_language_goal: bool = False,
43 | ):
44 | if _LIV_IMPORT_ERROR is not None:
45 | raise ImportError(_LIV_IMPORT_ERROR)
46 | model_type = model_type or "resnet50"
47 | self.random_crop = random_crop
48 | self.ckpt = ckpt
49 | if save_fc or preprocess_with_fc:
50 |             U.rank_zero_print("WARNING: LIV has no fc layer to save", color="red")
51 | save_fc = False
52 | preprocess_with_fc = False
53 | bn_to_gn = False
54 | super().__init__(
55 | model_type=model_type,
56 | device=device,
57 | remove_bn=remove_bn,
58 | bn_to_gn=bn_to_gn,
59 | remove_pool=remove_pool,
60 | preprocess_with_fc=preprocess_with_fc,
61 | save_fc=save_fc,
62 | use_language_goal=use_language_goal,
63 | )
64 |
65 | self._cached_language_embedding = {}
66 |
67 | def _get_model_and_transform(
68 | self, model_type: str | None = None
69 |     ) -> tuple[liv.LIV, Optional[T.Compose]]:
70 | if model_type is not None:
71 |             assert model_type == "resnet50", f"{model_type} not supported"
72 | liv.device = self.device
73 | liv_ = load_liv(modelid="resnet50", ckpt_path=self.ckpt).module
74 |         clip_model = liv_.model.to(self.device)  # avoid shadowing the `clip` module
75 |         if self.remove_pool:
76 |             self._pool = U.freeze_module(clip_model.visual.attnpool)
77 |             self._fc = None
78 |             clip_model.visual.attnpool = nn.Identity()
79 |         model = clip_model
80 | normlayer = liv_.transforms_tensor[-1]
81 | transform = (
82 | T.Compose([T.Resize(224), normlayer])
83 | if not self.random_crop
84 | else T.Compose([T.Resize(232), T.RandomCrop(224), normlayer])
85 | )
86 | return model, transform
87 |
88 | def _encode_image(self, img_tensors: torch.Tensor) -> torch.FloatTensor:
89 | with torch.no_grad():
90 | return self.model.encode_image(img_tensors)
91 |
92 | def _encode_text(
93 | self, text: str | np.ndarray | list | torch.Tensor
94 | ) -> torch.Tensor:
95 | if _CLIP_IMPORT_ERROR is not None:
96 | raise ImportError(_CLIP_IMPORT_ERROR)
97 | if not torch.is_tensor(text):
98 | if isinstance(text, str):
99 | text = [text]
100 | else:
101 | assert isinstance(text, (np.ndarray, list)), type(text)
102 | assert isinstance(text[0], str)
103 | text = clip.tokenize(text).to(self.device)
104 | with torch.no_grad():
105 | return self.model.encode_text(text)
106 |
107 | def encode_text(self, text: str | np.ndarray | list | torch.Tensor) -> torch.Tensor:
108 | return self.cached_language_embed(text)
109 |
110 | def cached_language_embed(self, text: str):
111 | if text in self._cached_language_embedding:
112 | return self._cached_language_embedding[text]
113 | text_embed = self._encode_text(text)
114 | self._cached_language_embedding[text] = text_embed
115 | return text_embed
116 |
117 |
118 | def load_liv(modelid: str = "resnet50", ckpt_path: str | None = None):
119 | if ckpt_path is None:
120 | return liv.load_liv(modelid)
121 | home = U.f_join("~/.liv")
122 | folderpath = U.f_mkdir(home, modelid)
123 | configpath = U.f_join(home, modelid, "config.yaml")
124 | if not U.f_exists(configpath):
125 | try:
126 | liv.hf_hub_download(
127 | repo_id="jasonyma/LIV", filename="config.yaml", local_dir=folderpath
128 | )
129 |         except Exception:
130 | configurl = (
131 | "https://drive.google.com/uc?id=1GWA5oSJDuHGB2WEdyZZmkro83FNmtaWl"
132 | )
133 | liv.gdown.download(configurl, configpath, quiet=False)
134 |
135 | modelcfg = omegaconf.OmegaConf.load(configpath)
136 | cleancfg = liv.cleanup_config(modelcfg)
137 | rep = hydra.utils.instantiate(cleancfg)
138 | rep = torch.nn.DataParallel(rep)
139 | vip_state_dict = torch.load(ckpt_path, map_location="cpu")["vip"]
140 | rep.load_state_dict(vip_state_dict)
141 | return rep
142 |
143 |
144 | def sim(tensor1, tensor2, metric: str = "l2", device=None):
145 |     if isinstance(tensor1, np.ndarray):
146 | tensor1 = torch.from_numpy(tensor1).to(device)
147 | tensor2 = torch.from_numpy(tensor2).to(device)
148 | if metric == "l2":
149 | d = -torch.linalg.norm(tensor1 - tensor2, dim=-1)
150 | elif metric == "cos":
151 | tensor1 = tensor1 / tensor1.norm(dim=-1, keepdim=True)
152 | tensor2 = tensor2 / tensor2.norm(dim=-1, keepdim=True)
153 | d = torch.nn.CosineSimilarity(-1)(tensor1, tensor2)
154 | else:
155 | raise NotImplementedError
156 | return d
157 |
158 |
159 | PROMPT_DICT = dict(
160 | microwave="open the microwave",
161 | kettle="move the kettle to the top left stove",
162 | light_switch="turn on the light",
163 | hinge_cabinet="open the left hinge cabinet",
164 | slide_cabinet="open the right slide cabinet",
165 | top_burner="turn on the top left burner",
166 | bottom_burner="turn on the bottom left burner",
167 | )
168 | PROMPT_DICT.update({k.replace("_", " "): v for k, v in PROMPT_DICT.items()})
169 |
--------------------------------------------------------------------------------
/uvd/envs/franka_kitchen/relay-policy-learning/adept_envs/adept_envs/utils/configurable.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | #
3 | # Copyright 2020 Google LLC
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import importlib
18 | import inspect
19 |
20 | from gym.envs.registration import registry as gym_registry
21 |
22 |
23 | def import_class_from_path(class_path):
24 | """Given 'path.to.module:object', imports and returns the object."""
25 | module_path, class_name = class_path.split(":")
26 | module = importlib.import_module(module_path)
27 | return getattr(module, class_name)
28 |
29 |
30 | class ConfigCache(object):
31 | """Configuration class to store constructor arguments.
32 |
33 | This is used to store parameters to pass to Gym environments at init
34 | time.
35 | """
36 |
37 | def __init__(self):
38 | self._configs = {}
39 | self._default_config = {}
40 |
41 | def set_default_config(self, config):
42 | """Sets the default configuration used for all RobotEnv envs."""
43 | self._default_config = dict(config)
44 |
45 | def set_config(self, cls_or_env_id, config):
46 | """Sets the configuration for the given environment within a context.
47 |
48 | Args:
49 | cls_or_env_id (Class | str): A class type or Gym environment ID to
50 | configure.
51 | config (dict): The configuration parameters.
52 | """
53 | config_key = self._get_config_key(cls_or_env_id)
54 | self._configs[config_key] = dict(config)
55 |
56 | def get_config(self, cls_or_env_id):
57 | """Returns the configuration for the given env name.
58 |
59 | Args:
60 | cls_or_env_id (Class | str): A class type or Gym environment ID to
61 | get the configuration of.
62 | """
63 | config_key = self._get_config_key(cls_or_env_id)
64 | config = dict(self._default_config)
65 | config.update(self._configs.get(config_key, {}))
66 | return config
67 |
68 | def clear_config(self, cls_or_env_id):
69 | """Clears the configuration for the given ID."""
70 | config_key = self._get_config_key(cls_or_env_id)
71 | if config_key in self._configs:
72 | del self._configs[config_key]
73 |
74 | def _get_config_key(self, cls_or_env_id):
75 | if inspect.isclass(cls_or_env_id):
76 | return cls_or_env_id
77 | env_id = cls_or_env_id
78 | assert isinstance(env_id, str)
79 | if env_id not in gym_registry.env_specs:
80 | raise ValueError("Unregistered environment name {}.".format(env_id))
81 | entry_point = gym_registry.env_specs[env_id]._entry_point
82 | if callable(entry_point):
83 | return entry_point
84 | else:
85 | return import_class_from_path(entry_point)
86 |
87 |
88 | # Global robot config.
89 | global_config = ConfigCache()
90 |
91 |
92 | def configurable(config_id=None, pickleable=False, config_cache=global_config):
93 | """Class decorator to allow injection of constructor arguments.
94 |
95 | This allows constructor arguments to be passed via ConfigCache.
96 | Example usage:
97 |
98 | @configurable()
99 | class A:
100 |       def __init__(self, b=None, c=2, d='Wow'):
101 | ...
102 |
103 | global_config.set_config(A, {'b': 10, 'c': 20})
104 | a = A() # b=10, c=20, d='Wow'
105 | a = A(b=30) # b=30, c=20, d='Wow'
106 |
107 | Args:
108 | config_id: ID of the config to use. This defaults to the class type.
109 | pickleable: Whether this class is pickleable. If true, causes the pickle
110 | state to include the config and constructor arguments.
111 | config_cache: The ConfigCache to use to read config data from. Uses
112 | the global ConfigCache by default.
113 | """
114 |
115 | def cls_decorator(cls):
116 | assert inspect.isclass(cls)
117 |
118 | # Overwrite the class constructor to pass arguments from the config.
119 | base_init = cls.__init__
120 |
121 | def __init__(self, *args, **kwargs):
122 | config = config_cache.get_config(config_id or type(self))
123 | # Allow kwargs to override the config.
124 | kwargs = {**config, **kwargs}
125 |
126 | # print('Initializing {} with params: {}'.format(type(self).__name__,
127 | # kwargs))
128 |
129 | if pickleable:
130 | self._pkl_env_args = args
131 | self._pkl_env_kwargs = kwargs
132 |
133 | base_init(self, *args, **kwargs)
134 |
135 | cls.__init__ = __init__
136 |
137 | # If the class is pickleable, overwrite the state methods to save
138 | # the constructor arguments and config.
139 | if pickleable:
140 | # Use same pickle keys as gym.utils.ezpickle for backwards compat.
141 | PKL_ARGS_KEY = "_ezpickle_args"
142 | PKL_KWARGS_KEY = "_ezpickle_kwargs"
143 |
144 | def __getstate__(self):
145 | return {
146 | PKL_ARGS_KEY: self._pkl_env_args,
147 | PKL_KWARGS_KEY: self._pkl_env_kwargs,
148 | }
149 |
150 | cls.__getstate__ = __getstate__
151 |
152 | def __setstate__(self, data):
153 | saved_args = data[PKL_ARGS_KEY]
154 | saved_kwargs = data[PKL_KWARGS_KEY]
155 |
156 | # Override the saved state with the current config.
157 | config = config_cache.get_config(config_id or type(self))
158 | # Allow kwargs to override the config.
159 | kwargs = {**saved_kwargs, **config}
160 |
161 | inst = type(self)(*saved_args, **kwargs)
162 | self.__dict__.update(inst.__dict__)
163 |
164 | cls.__setstate__ = __setstate__
165 |
166 | return cls
167 |
168 | return cls_decorator
169 |
--------------------------------------------------------------------------------