├── .gitignore ├── README.md ├── configs ├── eval_base.yaml ├── eval_novel.yaml ├── train_invariant_region.yaml ├── train_region_match.yaml └── train_region_match_fine.yaml ├── data.py ├── eval_base.py ├── eval_novel.py ├── geometry_lib.py ├── heuristics.py ├── network.py ├── requirements.txt ├── train.py └── utils ├── __init__.py ├── ckpt.py ├── clip.py ├── color_remap.py ├── dist.py ├── env.py ├── icp.py ├── layers.py ├── match.py ├── math3d.py ├── metric.py ├── object.py ├── optim.py ├── rollout.py ├── str.py ├── structure.py ├── transfer.py └── vis.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | output 163 | outputs 164 | .vscode 165 | cub-1.10.0 166 | *.avi 167 | *.tar 168 | *.html 169 | .virtual_documents 170 | local_cache 171 | data 172 | wandb 173 | *.pkl 174 | *.prof 175 | *.novel 176 | *.origin_base 177 | weights 178 | datasets 179 | *.tar -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # One-Shot Imitation Learning with Invariance Matching for Robotic Manipulation 2 | 3 | [Paper](https://arxiv.org/abs/2405.13178) [Website](https://mlzxy.github.io/imop) 4 | 5 | ## Install Dependencies 6 | 7 | Use python 3.9 8 | 9 | ```bash 10 | pip install -r requirements.txt 11 | # note that it installs torch==1.13.0 12 | # which is necessary for some Open3D torch api 13 | ``` 14 | 15 | Then install [Open3D](https://www.open3d.org/docs/release/compilation.html) and [Pytorch3D](https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md) from source. The tested commits are `7c0acac0a50293c52d2adb70967e729f98fa5018` and `2f11ddc5ee7d6bd56f2fb6744a16776fab6536f7` for Open3D and Pytorch3D, respectively. 16 | 17 | Then install torch_geometric==2.4.0, torch_cluster==1.6.1, and torch_scatter==2.1.1. You may need to download the precompiled wheel https://data.pyg.org/whl/. 18 | 19 | > Please also install RLBench from https://github.com/mlzxy/RLBench. This branch does not read test images during evaluation (as they are not used anyway), provide color information (e.g., when stacking blocks, it makes sure the one-shot demonstration and testing scene has the same target block color), and fixes a small bug of the success criterion of the place cup task. 20 | 21 | 22 | ## Evaluation and Training 23 | 24 | Download datasets and weights from [https://rutgers.box.com/s/icwvszhcb5jvp8zr33htpqboebk2aupq](https://rutgers.box.com/s/icwvszhcb5jvp8zr33htpqboebk2aupq), extract them as `weights` and `datasets` folders. 25 | 26 | 27 | ```bash 28 | # Run evaluation on novel tasks 29 | python3 ./eval_novel.py config=configs/eval_novel.yaml 30 | 31 | # Run evaluation on base tasks 32 | python3 ./eval_base.py config=configs/eval_base.yaml 33 | 34 | # Run training 35 | python3 ./train.py config=configs/train_invariant_region.yaml 36 | python3 ./train.py config=configs/train_{region_match, region_match_fine}.yaml 37 | ``` 38 | 39 | Note this released version simplifies the original implementation by removing the state routing (just using the next key pose), and directly using the RLBench instance mask name to determine the groundtruth of invariant regions. Some clustering heuristics are also used during evaluation to improve the accuracy for certain tasks (details in the [eval_base/novel.py](eval_novel.py)). 40 | 41 | The network implementation is in [network.py](network.py) and [geometry_lib.py](geometry_lib.py). The [geometry_lib.py](geometry_lib.py) is a single file implementation for: 42 | 43 | 1. A set of primitives of batching point cloud of different sizes (`to_dense_batch / to_flat_batch / ...`) 44 | 2. Point transformer (`PointTransformerNetwork`) 45 | 3. KNN graph transformer (`KnnTransformerNetwork / make_knn_transformer_layers / ...`) 46 | 4. DualSoftmaxMatching (`DualSoftmaxReposition`) for the correspondence-based regression. 47 | 5. Utilities such as knn search (`knn/knn_gather/resample`), farthest point sampling (`fps_by_sizes`) and etc. 48 | 49 | 50 | ## Some comments 51 | 52 | I believe region matching is a good idea for solving one-shot manipulation learning. However, many things could be improved. For example: 53 | 54 | 1. It would be better to match image regions, instead of point cloud regions. 55 | 2. Region matching inherently contains multi-modality, which needs to be considered from the beginning. I suggest to checkout the failure cases in the supplementary 56 | https://mlzxy.github.io/imop/supplementary.pdf. 57 | 3. My initial version is based on object-centric representation (selecting masks as invariant regions). Later I remove the object-centric assumption for a supposely more general setting, but this could be wrong because the object-centric version works better and way more robust. 58 | 59 | 60 | 61 | 62 | # Citation 63 | 64 | In case this work is useful 65 | 66 | ```bibtex 67 | @inproceedings{zhang2024oneshot, 68 | title={One-Shot Imitation Learning with Invariance Matching for Robotic Manipulation}, 69 | author={Xinyu Zhang and Abdeslam Boularias}, 70 | booktitle = {Proceedings of Robotics: Science and Systems}, 71 | address = {Delft, Netherlands}, 72 | year = {2024} 73 | } 74 | ``` -------------------------------------------------------------------------------- /configs/eval_base.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | job: 3 | name: eval_base 4 | chdir: false 5 | run: 6 | dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M} 7 | 8 | output_dir: ${hydra:run.dir} 9 | # ---------------------------------------------------------# 10 | 11 | 12 | clear_output: True 13 | 14 | eval: 15 | episode_num: 1 16 | episode_length: 25 17 | start_episode: 0 18 | headless: true 19 | device: 0 20 | 21 | model_paths: 22 | region_match: "./weights/region_match_44999.pth" 23 | invariant_region: "./weights/invariant_region_44999.pth" 24 | region_match_fine: "./weights/region_match_fine_40000.pth" 25 | 26 | agent: 27 | min_episodes_per_desc: -1 28 | support_episode: -1 # just for debug 29 | debug: False 30 | 31 | 32 | testset_path: ./datasets/base_tests 33 | demoset_path: ./datasets/base_demonstrations -------------------------------------------------------------------------------- /configs/eval_novel.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | job: 3 | name: eval_novel 4 | chdir: false 5 | run: 6 | dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M} 7 | 8 | output_dir: ${hydra:run.dir} 9 | # ---------------------------------------------------------# 10 | 11 | clear_output: True 12 | 13 | eval: 14 | episode_num: 25 15 | episode_length: 50 16 | start_episode: 0 17 | headless: true 18 | device: 0 19 | 20 | model_paths: 21 | region_match: "./weights/region_match_44999.pth" 22 | invariant_region: "./weights/invariant_region_44999.pth" 23 | region_match_fine: "./weights/region_match_fine_40000.pth" 24 | 25 | agent: 26 | min_episodes_per_desc: -1 27 | support_episode: -1 # debug 28 | debug: False 29 | cache_to: "./datasets/episodes.pkl.novel" 30 | 31 | 32 | testset_path: ./datasets/novel_tests 33 | demoset_path: ./datasets/novel_demonstrations -------------------------------------------------------------------------------- /configs/train_invariant_region.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | job: 3 | name: train_invariant_region 4 | chdir: false 5 | run: 6 | dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M} 7 | 8 | output_dir: ${hydra:run.dir} 9 | # ---------------------------------------------------------# 10 | 11 | notes: "" 12 | 13 | model: 14 | type: "invariant_region" 15 | 16 | 17 | train: 18 | bs: 5 19 | epochs: 18 20 | num_gpus: 1 21 | num_workers: 8 22 | 23 | num_transitions_per_epoch: 12500 24 | 25 | log_freq: 20 26 | save_freq: 5000 27 | 28 | lr: 1e-3 # per sample, will multiply with bs and world_size, 5e-4 29 | warmup_steps: 2000 30 | 31 | grad_clip_after: 1000 32 | grad_clip_value: 33 | overall: 20.0 34 | 35 | checkpoint: "" 36 | 37 | wandb: False 38 | wandb_alert: False 39 | tensorboard: False 40 | 41 | 42 | data: 43 | grid_size: 0.005 44 | db_path: ./datasets/base_training_set_raw 45 | db_cache: ./datasets/base_training_set_cache # can set to empty 46 | pairs_cache: ./datasets/pairs.pkl 47 | aug: True 48 | correspondence: False 49 | align_twice: False 50 | max_pts: 5000 51 | color_only_instructions: False 52 | include_T: False 53 | noisy_mask: 0.00 54 | -------------------------------------------------------------------------------- /configs/train_region_match.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | job: 3 | name: train_region_match 4 | chdir: false 5 | run: 6 | dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M} 7 | 8 | output_dir: ${hydra:run.dir} 9 | # ---------------------------------------------------------# 10 | 11 | notes: "" 12 | 13 | model: 14 | type: "region_match" 15 | 16 | 17 | train: 18 | bs: 4 19 | epochs: 15 20 | num_gpus: 1 21 | num_workers: 8 22 | 23 | num_transitions_per_epoch: 12500 24 | 25 | log_freq: 20 26 | save_freq: 5000 27 | 28 | lr: 1e-3 29 | warmup_steps: 2000 30 | 31 | grad_clip_after: 1000 32 | grad_clip_value: 33 | overall: 10.0 34 | 35 | checkpoint: "" 36 | 37 | wandb: False 38 | wandb_alert: False 39 | tensorboard: False 40 | 41 | 42 | data: 43 | grid_size: 0.005 44 | db_path: ./datasets/base_training_set_raw 45 | db_cache: ./datasets/base_training_set_cache # can set to empty 46 | pairs_cache: ./datasets/pairs.pkl 47 | aug: False 48 | correspondence: True 49 | align_twice: True 50 | max_pts: 5000 51 | color_only_instructions: True 52 | include_T: False 53 | noisy_mask: 0.02 -------------------------------------------------------------------------------- /configs/train_region_match_fine.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | job: 3 | name: train_region_match_fine 4 | chdir: false 5 | run: 6 | dir: outputs/${hydra.job.name}/${now:%Y-%m-%d_%H-%M} 7 | 8 | output_dir: ${hydra:run.dir} 9 | # ---------------------------------------------------------# 10 | 11 | notes: "" 12 | 13 | model: 14 | type: "region_match_fine" 15 | 16 | train: 17 | bs: 4 18 | epochs: 15 19 | num_gpus: 1 20 | num_workers: 8 21 | 22 | num_transitions_per_epoch: 12500 23 | 24 | log_freq: 100 25 | save_freq: 5000 26 | 27 | lr: 1e-3 # per sample, will multiply with bs and world_size, 5e-4 28 | warmup_steps: 2000 29 | 30 | grad_clip_after: 500 31 | grad_clip_value: 32 | overall: 8.0 33 | 34 | checkpoint: "" 35 | 36 | wandb: False 37 | wandb_alert: False 38 | tensorboard: False 39 | 40 | 41 | data: 42 | grid_size: 0.005 43 | db_path: ./datasets/base_training_set_raw 44 | db_cache: ./datasets/base_training_set_cache # can set to empty 45 | pairs_cache: ./datasets/pairs.pkl 46 | aug: True 47 | correspondence: True 48 | align_twice: True 49 | max_pts: 5000 50 | color_only_instructions: False 51 | include_T: False 52 | noisy_mask: 0.00 53 | -------------------------------------------------------------------------------- /eval_base.py: -------------------------------------------------------------------------------- 1 | import textdistance 2 | import warnings 3 | warnings.filterwarnings("ignore") 4 | from typing import List 5 | import random 6 | import os.path as osp 7 | import torch 8 | torch.set_grad_enabled(False) 9 | import numpy as np 10 | from tqdm import tqdm 11 | from utils.env import rlbench_obs_config, EndEffectorPoseViaPlanning, CustomMultiTaskRLBenchEnv 12 | from rlbench.backend.utils import task_file_to_task_class 13 | from rlbench.action_modes.gripper_action_modes import Discrete 14 | from rlbench.action_modes.action_mode import MoveArmThenGripper 15 | from collections import defaultdict 16 | from termcolor import colored 17 | from utils import configurable, DictConfig, config_to_dict 18 | from utils.structure import BASE_RLBENCH_TASKS, ActResult 19 | from utils.vis import * 20 | from dataclasses import dataclass 21 | 22 | from hdbscan import HDBSCAN 23 | 24 | from utils.ckpt import remove_dict_prefix 25 | from utils.object import Section 26 | import data as dlib 27 | import utils.icp as icplib 28 | cat = dlib.cat 29 | 30 | __dirname = osp.dirname(__file__) 31 | 32 | import heuristics as heu 33 | from network import InvariantRegionNetwork, RegionMatchingNetwork, RegionMatchingNetwork_fine 34 | 35 | 36 | def to_np_obs(obs): 37 | def _get_type(x): 38 | if not hasattr(x, 'dtype'): return np.float32 39 | if x.dtype == np.float64: 40 | return np.float32 41 | return x.dtype 42 | return {k: np.array(v, dtype=_get_type(v)) if not isinstance(v, dict) else v for k, v in obs.items()} 43 | 44 | 45 | def parse_number(s): 46 | digits = [char for char in s if char.isdigit() and len(char) == 1] 47 | if len(digits) > 0: return digits[0] 48 | else: return "" 49 | 50 | 51 | 52 | def load_models(model_paths, dev: torch.device): 53 | m1 = RegionMatchingNetwork() 54 | m1.load_state_dict(remove_dict_prefix( 55 | torch.load(model_paths['region_match'], map_location=dev)['model'], prefix="module.")) 56 | m1 = m1.to(dev).eval() 57 | 58 | m2 = InvariantRegionNetwork() 59 | m2.load_state_dict(remove_dict_prefix( 60 | torch.load(model_paths['invariant_region'], map_location='cpu')['model'], prefix="module.")) 61 | m2 = m2.to(dev).eval() 62 | 63 | m3 = RegionMatchingNetwork_fine() 64 | m3.load_state_dict(remove_dict_prefix( 65 | torch.load(model_paths['region_match_fine'], map_location=dev)['model'], prefix="module.")) 66 | m3 = m3.to(dev).eval() 67 | 68 | return {'region_match': m1, 'invariant_region': m2, 'region_match_fine': m3} 69 | 70 | 71 | def get_datasets(data_path): 72 | db = dlib.RLBenchDataset(tasks= BASE_RLBENCH_TASKS, path=data_path, 73 | grid_size=0.005, min_max_pts_per_obj=5000, 74 | max_episode_num=100) 75 | collate_fn = dlib.RLBenchCollator(use_segmap=False, training=False) 76 | return db, collate_fn 77 | 78 | 79 | @dataclass 80 | class KeyFrame: 81 | type: str = "" 82 | task: str = "" 83 | pcd = None 84 | rgb = None 85 | 86 | cluster_ids = None 87 | cluster_id_set = None 88 | assigned_cluster_id = -1 89 | cluster_map = None 90 | 91 | key_prob_map = None 92 | 93 | item = None 94 | next_item = None 95 | 96 | key_region_not_found = False 97 | 98 | 99 | def get_key_mask(kf: KeyFrame): 100 | if kf.key_region_not_found: 101 | return np.ones_like(kf.key_prob_map).astype(bool) 102 | 103 | if kf.task in tasks_need_clustering: 104 | assert kf.assigned_cluster_id != -1 105 | return kf.cluster_ids == kf.assigned_cluster_id 106 | else: 107 | return kf.key_prob_map > 0.5 108 | 109 | 110 | def smoothen_key_prob(kf: KeyFrame, neighbors=5): 111 | _, idxs = icplib.knn(kf.pcd, kf.pcd, k=neighbors) 112 | prob_map_neighbor = kf.key_prob_map[idxs.flatten()].reshape(-1, neighbors) 113 | kf.key_prob_map = prob_map_neighbor.max(axis=1) 114 | 115 | 116 | def build_cluster_map(prev_kf: KeyFrame, curr_kf: KeyFrame, threshold=0.035): 117 | if curr_kf.task not in tasks_need_clustering: return None 118 | 119 | def compute_centers(clst, clst_map, pcd): 120 | centers = [] 121 | for a in clst: 122 | centers.append(pcd[clst_map == a].mean(axis=0)) 123 | return np.stack(centers) 124 | 125 | prev_centers = compute_centers(prev_kf.cluster_id_set, prev_kf.cluster_ids, prev_kf.pcd) 126 | curr_centers = compute_centers(curr_kf.cluster_id_set, curr_kf.cluster_ids, curr_kf.pcd) 127 | 128 | dists = np.linalg.norm(prev_centers[:, None, :] - curr_centers[None, :, :], axis=-1, ord=1) 129 | closest_ids = dists.argmin(axis=1).flatten() 130 | mapping = {} 131 | for a, b in enumerate(closest_ids): 132 | if dists[a, b] < threshold: 133 | mapping[a] = b 134 | return mapping 135 | 136 | 137 | def find_largest_cluster(key_frame: KeyFrame): 138 | largest_cid = -1 139 | largest_size = -1 140 | for cid in key_frame.cluster_id_set: 141 | size = (key_frame.cluster_ids == cid).sum() 142 | if size > largest_size: 143 | largest_size = size 144 | largest_cid = cid 145 | assert largest_cid != -1 146 | return largest_cid 147 | 148 | 149 | def find_most_salient_cluster(key_frame: KeyFrame, min_cluster_size=10, score_margin=0.1): 150 | prob_map, cluster_map, cluster_id_set = key_frame.key_prob_map, key_frame.cluster_ids, key_frame.cluster_id_set 151 | max_score = -1 152 | max_clst_id = -1 153 | prob_mask = prob_map > 0.5 154 | mean_scores = [(prob_map[prob_mask & (cluster_map == clst_id)]).mean() for clst_id in cluster_id_set] 155 | max_mean_score = max([v for v in mean_scores if not np.isnan(v)]) 156 | for cid in cluster_id_set: 157 | mask = cluster_map == cid 158 | score = prob_map[mask & prob_mask].sum() 159 | if score > max_score and mask.sum() >= min_cluster_size and score + score_margin > max_mean_score: 160 | max_score = score 161 | max_clst_id = cid 162 | return max_clst_id 163 | 164 | 165 | tasks_need_clustering = {'place_cups', 'stack_cups', 'stack_blocks', 'sweep_to_dustpan_of_size', 'slide_block_to_color_target', 'turn_tap'} 166 | 167 | 168 | class EvaluationModelWrapper: 169 | def __init__(self, model_dict, db: dlib.RLBenchDataset, collate_fn, logger=print, 170 | support_episode=-1, min_episodes_per_desc=-1, debug=False): 171 | if support_episode >= 0: logger("Warning: support_episode is SET, this shall only be set in debug") 172 | self.invariant_region, self.region_match, self.region_match_fine = model_dict['invariant_region'], model_dict['region_match'], model_dict['region_match_fine'] 173 | self.db = db 174 | self.device = next(self.region_match_fine.parameters()).device 175 | self.collate = collate_fn 176 | self.debug = debug 177 | self.logger = logger 178 | self.support_episode = support_episode 179 | 180 | self.demo_db = {} 181 | for t in self.db.tasks: 182 | self.demo_db[t] = {} 183 | for e in self.db.get_episodes(t): 184 | desc, vn = self.db.get_desc_and_vn(t, e) 185 | if desc not in self.demo_db[t]: self.demo_db[t][desc] = [] 186 | self.demo_db[t][desc].append({'episode': e }) 187 | if min_episodes_per_desc > 0: 188 | for desc in list(self.demo_db[t].keys()): 189 | if len(self.demo_db[t][desc]) >= min_episodes_per_desc: 190 | self.demo_db[t][desc] = np.random.choice(self.demo_db[t][desc], min_episodes_per_desc).tolist() 191 | self.counter = defaultdict(lambda: 0) 192 | self._init() 193 | 194 | 195 | def _init(self): 196 | self.references: List[KeyFrame] = [] 197 | self.pose_history = [] 198 | self.cursor = 0 199 | 200 | self.last_action = None 201 | 202 | self.current_task = "" 203 | self.current_episode_description = "" 204 | 205 | self.prev_tgt_frame: KeyFrame = None 206 | self.color_instruction = None 207 | self.cluster_id_mapping = {} 208 | 209 | 210 | def reset(self, task, desc, color_instruction=None): 211 | self._init() 212 | self.color_instruction = color_instruction 213 | assert task is not None 214 | self.current_task = task 215 | self.current_episode_description = desc 216 | 217 | if self.support_episode < 0: 218 | if desc in self.demo_db[task]: 219 | candidates = self.demo_db[task][desc] 220 | else: 221 | demo_desc = sorted([(textdistance.levenshtein.distance(desc, demo_desc), demo_desc) for demo_desc in self.demo_db[task]])[0][1] 222 | self.logger(f'\t{desc} not found in existing demonstrations, use demonstrations of `{demo_desc}`') 223 | candidates = self.demo_db[task][demo_desc] 224 | 225 | ref_e = random.choice(candidates)['episode'] 226 | else: 227 | ref_e = self.support_episode 228 | self.logger(f'\tselect support episode = {ref_e}') 229 | 230 | for ind, kf in enumerate(self.db.get_kfs(task, ref_e, exclude_last=False)): 231 | src_t = self.db.get(task, ref_e, kf, training=False) 232 | src_t1 = self.db.get(task, ref_e, src_t['kf_t+1'], training=False) 233 | sample = {'src': {'t': src_t, 't+1': src_t1}, 234 | 'tgt': { 't': src_t, 't+1': src_t1}, 235 | 'match': None, 'index': None} 236 | batch = self.collate([sample, ]) 237 | batch = dlib.to_device(batch, self.device) 238 | iv_region = self.invariant_region(batch, debug=False) 239 | prob = iv_region['output']['prob_map'].flatten().cpu().numpy() 240 | 241 | data = KeyFrame(type="src", task=task) 242 | data.item = src_t 243 | data.next_item = src_t1 244 | data.pcd, data.rgb = src_t['pcd'], src_t['rgb'] 245 | data.key_prob_map = prob 246 | src_key_mask = prob > 0.5 247 | 248 | if src_key_mask.sum() < 15: 249 | if ind > 0: 250 | prev = self.references[ind - 1] 251 | prev_src_key_mask = get_key_mask(prev) 252 | if prev_src_key_mask.sum() < 15: # if previous frame has no key region, then use all, shall be very rare corner case 253 | data.key_prob_map[:] = 1.0 254 | else: 255 | key_pcd = prev.pcd[prev_src_key_mask] # propagate the invariant region a little bit through k-nearst neighbor 256 | _, idxs = icplib.knn(key_pcd, data.pcd, k=3) 257 | data.key_prob_map[:] = 0 258 | data.key_prob_map[idxs.flatten()] = 1.0 259 | smoothen_key_prob(data) 260 | else: 261 | data.key_prob_map[:] = 1.0 # all activated 262 | 263 | if task in tasks_need_clustering: 264 | data.cluster_ids = HDBSCAN(min_cluster_size=15).fit_predict(src_t['pcd']) 265 | data.cluster_id_set = set(np.unique(data.cluster_ids).tolist()) - {-1} 266 | data.assigned_cluster_id = find_most_salient_cluster(data) 267 | else: 268 | if np.all(data.key_prob_map < 0.1): 269 | data.key_region_not_found = True 270 | 271 | self.references.append(data) 272 | 273 | if task in tasks_need_clustering: 274 | for i in range(len(self.references) - 1): 275 | d1, d2 = self.references[i], self.references[i+1] 276 | d2.cluster_map = build_cluster_map(d1, d2) 277 | 278 | self.counter[task] += 1 279 | 280 | 281 | def act(self, obs): 282 | task = self.current_task 283 | self.pose_history.append(obs['gripper_pose']) 284 | 285 | if self.cursor == len(self.references) - 1: 286 | return self.last_action 287 | elif self.cursor >= len(self.references): 288 | return None 289 | 290 | src_frame = self.references[self.cursor] 291 | tgt = self.db.prepare_obs({**obs, 'task': self.current_task, 292 | 'desc': self.current_episode_description}, pose0=self.pose_history[0]) 293 | 294 | tgt_frame = KeyFrame(type='tgt', task=task) 295 | tgt_frame.pcd, tgt_frame.rgb = tgt['pcd'], tgt['rgb'] 296 | tgt_frame.item = tgt 297 | if task in tasks_need_clustering: 298 | tgt_frame.cluster_ids = HDBSCAN(min_cluster_size=15).fit_predict(tgt_frame.pcd) 299 | tgt_frame.cluster_id_set = set(np.unique(tgt_frame.cluster_ids).tolist()) - {-1} 300 | 301 | sample = {'src': {'t': src_frame.item, 't+1': src_frame.next_item}, 302 | 'tgt': { 't': tgt_frame.item, 't+1': tgt_frame.item}, 303 | 'match': None, 'index': None} 304 | src_frame.item['key_mask'] = get_key_mask(src_frame) 305 | 306 | key_region_propagated = False 307 | if self.prev_tgt_frame is not None and task in tasks_need_clustering: 308 | # if clustering is used, then we can propagate the key cluster (region) from previous frame instead of re-estimate 309 | tgt_frame.cluster_map = build_cluster_map(self.prev_tgt_frame, tgt_frame) 310 | 311 | new_cluster_id_mapping = {} 312 | for a, b in self.cluster_id_mapping.items(): 313 | if a in src_frame.cluster_map and b in tgt_frame.cluster_map: 314 | new_cluster_id_mapping[src_frame.cluster_map[a]] = tgt_frame.cluster_map[b] 315 | self.cluster_id_mapping = new_cluster_id_mapping 316 | 317 | if src_frame.assigned_cluster_id in self.cluster_id_mapping: 318 | tgt_frame.assigned_cluster_id = self.cluster_id_mapping[src_frame.assigned_cluster_id] 319 | tgt_frame.key_prob_map = (tgt_frame.cluster_ids == tgt_frame.assigned_cluster_id).astype(np.float32) 320 | key_region_propagated = True 321 | 322 | if not key_region_propagated: 323 | # this piece code is to assist the region matching through a position mask... with icp between invariant region and clusters on target frame 324 | # I kind of forget the exact purpose 325 | if task in {'place_shape_in_shape_sorter', 'put_groceries_in_cupboard'}: 326 | _, pos_id = heu.parse_instructions(task, src_frame.item['desc'], color_only=False)[0] 327 | cluster_ids = HDBSCAN(min_cluster_size=15).fit_predict(tgt_frame.pcd) 328 | cluster_id_set = list(set(np.unique(cluster_ids).tolist()) - {-1}) 329 | src_key_mask = src_frame.key_prob_map > 0.5 330 | src_key_pcd = src_frame.pcd[src_key_mask] 331 | src_key_rgb = src_frame.rgb[src_key_mask] 332 | 333 | src_position_mask = np.full([len(src_frame.pcd)], fill_value=-1, dtype=np.float32) 334 | src_position_mask[src_key_mask] = pos_id 335 | src_frame.item['position_mask'] = src_position_mask 336 | 337 | rgb_distance = [] 338 | for cid in cluster_id_set: 339 | cluster_mask = cluster_ids == cid 340 | tgt_clst_pcd = tgt_frame.pcd[cluster_mask] 341 | tgt_clst_rgb = tgt_frame.rgb[cluster_mask] 342 | X = icplib.icp(src_key_pcd, tgt_clst_pcd).transformation 343 | src_key_pcd_warped = icplib.h_transform(X, src_key_pcd) 344 | 345 | dists, idxs = icplib.knn(src_key_pcd_warped, tgt_clst_pcd) 346 | rgb_dists = np.linalg.norm(tgt_clst_rgb[idxs] - src_key_rgb, axis=-1, ord=1) 347 | 348 | inv_dists, inv_idxs = icplib.knn(tgt_clst_pcd, src_key_pcd_warped) 349 | inv_rgb_dists = np.linalg.norm(tgt_clst_rgb - src_key_rgb[inv_idxs], axis=-1, ord=1) 350 | 351 | rgb_distance.append(rgb_dists.mean() + inv_rgb_dists.mean()) 352 | 353 | tgt_cluster_id = cluster_id_set[np.argmin(rgb_distance)] 354 | 355 | tgt_position_mask = np.full([len(tgt_frame.pcd)], fill_value=-1, dtype=np.float32) 356 | tgt_position_mask[cluster_ids == tgt_cluster_id] = pos_id 357 | tgt_frame.item['position_mask'] = tgt_position_mask 358 | else: 359 | src_frame.item['position_mask'] = heu.get_color_position_mask(task, src_frame.item['desc'], src_frame.item['id2names'], 360 | src_frame.item['rgb'], src_frame.item['mask'], **self.color_instruction) 361 | tgt_frame.item['position_mask'] = heu.get_color_position_mask(task, tgt_frame.item['desc'], tgt_frame.item['id2names'], 362 | tgt_frame.item['rgb'], tgt_frame.item['mask'], **self.color_instruction) 363 | batch = self.collate([sample, ]) 364 | batch = dlib.to_device(batch, self.device) 365 | matched_result = self.region_match(batch) 366 | prob_map = matched_result['output']['conf_matrix'].reshape(-1, len(batch['tgt']['t']['pcd'])).sum(dim=0) 367 | tgt_frame.key_prob_map = prob_map.flatten().cpu().numpy() 368 | smoothen_key_prob(tgt_frame) 369 | 370 | if task in tasks_need_clustering: 371 | try: 372 | tgt_frame.assigned_cluster_id = find_most_salient_cluster(tgt_frame) 373 | except ValueError: 374 | # this branch never reaches 375 | tgt_frame.assigned_cluster_id = find_largest_cluster(tgt_frame) 376 | self.cluster_id_mapping[src_frame.assigned_cluster_id] = tgt_frame.assigned_cluster_id 377 | else: 378 | if np.all(tgt_frame.key_prob_map < 0.1): 379 | tgt_frame.key_region_not_found = True 380 | 381 | 382 | sample = {'src': {'t': src_frame.item, 't+1': src_frame.item}, 383 | 'tgt': { 't': tgt_frame.item, 't+1': tgt_frame.item}, 384 | 'match': None, 'index': None} 385 | tgt_frame.key_region_not_found = src_frame.key_region_not_found = tgt_frame.key_region_not_found | src_frame.key_region_not_found 386 | 387 | src_frame.item['key_mask'] = get_key_mask(src_frame) 388 | tgt_frame.item['key_mask'] = get_key_mask(tgt_frame) 389 | batch = self.collate([sample, ]) 390 | batch = dlib.to_device(batch, self.device) 391 | try: 392 | matched_result_fine = self.region_match_fine(batch) 393 | except Exception as e: 394 | self.logger(f'Exception: {e}') 395 | return None 396 | 397 | estimated_frame = matched_result_fine['output']['predict_frame'][0].cpu().numpy() 398 | 399 | # transfer actions from reference to current frame 400 | frame0 = icplib.pose7_to_frame(self.pose_history[0]) 401 | X_02t = icplib.pose7_to_X(obs['gripper_pose']) @ icplib.inv(icplib.pose7_to_X(self.pose_history[0])) 402 | frame_t = icplib.h_transform(X_02t, frame0) 403 | X_t2tp1 = icplib.Rt_2_X(*icplib.arun(frame_t, estimated_frame)) 404 | X_02tp1 = X_t2tp1 @ X_02t 405 | 406 | next_pose_X = X_02tp1 @ icplib.pose7_to_X(self.pose_history[0]) 407 | next_pose = icplib.X_to_pose7(next_pose_X) 408 | 409 | self.cursor += 1 410 | 411 | self.last_action = ActResult(np.array(list(next_pose) + [src_frame.item['open_t+1'], float(src_frame.item['ignore_col_t+1'])])) 412 | self.prev_tgt_frame = tgt_frame 413 | return self.last_action 414 | 415 | 416 | 417 | def evaluate(agent: EvaluationModelWrapper, episode_length=25, 418 | tasks=BASE_RLBENCH_TASKS, num_episodes=5, headless=True, 419 | testset_path="/home/xinyu/Workspace/RLBench/test", 420 | logger=print, 421 | start_episode=0): 422 | if isinstance(tasks, str): tasks = [tasks] 423 | obs_config = rlbench_obs_config(["front", "left_shoulder", "right_shoulder", "wrist"], [128, 128], method_name="") 424 | 425 | gripper_mode = Discrete() 426 | arm_action_mode = EndEffectorPoseViaPlanning() 427 | action_mode = MoveArmThenGripper(arm_action_mode, gripper_mode) 428 | 429 | task_classes = [task_file_to_task_class(task) for task in tasks] 430 | 431 | try: 432 | eval_env = CustomMultiTaskRLBenchEnv( 433 | task_classes=task_classes, 434 | observation_config=obs_config, 435 | action_mode=action_mode, 436 | dataset_root=testset_path, 437 | episode_length=episode_length, 438 | headless=headless, 439 | swap_task_every=num_episodes, 440 | include_lang_goal_in_obs=True 441 | ) 442 | eval_env.eval = True 443 | eval_env.launch() 444 | scores = defaultdict(list) 445 | 446 | for task_name in tasks: 447 | for ep in range(start_episode, start_episode + num_episodes): 448 | logger(f"{task_name} - {ep}") 449 | episode_rollout = [] 450 | obs = to_np_obs(eval_env.reset_to_demo(ep)) 451 | lang_goal = eval_env._lang_goal 452 | # assuming the color information is available, like when stacking blue blocks, we use a blue block demonstration 453 | color_information = eval_env.get_color_information() 454 | 455 | agent.reset(task=task_name, desc=lang_goal, color_instruction=color_information) 456 | 457 | for step in range(episode_length): 458 | action = agent.act({**obs, 'task': task_name}) 459 | if action is None: 460 | episode_rollout.append(0.0) 461 | break 462 | transition = eval_env.step(action) 463 | obs = dict(transition.observation) 464 | if step == episode_length - 1: 465 | transition.terminal = True 466 | episode_rollout.append(transition.reward) 467 | if transition.terminal: break 468 | 469 | reward = episode_rollout[-1] 470 | scores[task_name].append(reward) 471 | txt = colored(f"\tEvaluating {task_name} | Episode {ep} | Score: {reward} | Episode Length: {len(episode_rollout)} | Lang Goal: {lang_goal}", 'red') 472 | logger(txt) 473 | 474 | for k, values in scores.items(): 475 | logger(f'{k}, {np.mean(values):.02f}') 476 | mean_score = np.mean(list(scores.values())) 477 | logger(f'Average Score: {mean_score}') 478 | return scores 479 | finally: 480 | eval_env.shutdown() 481 | 482 | 483 | @configurable() 484 | def main(cfg: DictConfig): 485 | dev = torch.device(cfg.eval.device) 486 | logfile = open(osp.join(cfg.output_dir, 'log.eval.txt'), "w") 487 | 488 | if cfg.clear_output: 489 | import shutil 490 | if osp.exists('./outputs/eval_vis'): 491 | shutil.rmtree('./outputs/eval_vis/') 492 | 493 | def log(msg, printer=print): 494 | print(msg, file=logfile, flush=True) 495 | printer(msg) 496 | 497 | tasks = BASE_RLBENCH_TASKS 498 | db, collate_fn = get_datasets(cfg.demoset_path) 499 | model_dict = load_models(cfg.eval.model_paths, dev) 500 | 501 | agent = EvaluationModelWrapper(model_dict, db, collate_fn, logger=log, 502 | **config_to_dict(cfg.eval.agent)) 503 | 504 | evaluate(agent, 505 | tasks=tasks, 506 | num_episodes=cfg.eval.episode_num, 507 | headless=cfg.eval.headless, 508 | logger=log, 509 | start_episode=cfg.eval.start_episode, 510 | episode_length=cfg.eval.episode_length, 511 | testset_path=cfg.testset_path) 512 | 513 | if __name__ == "__main__": 514 | main() 515 | -------------------------------------------------------------------------------- /eval_novel.py: -------------------------------------------------------------------------------- 1 | import textdistance 2 | import warnings 3 | warnings.filterwarnings("ignore") 4 | from typing import List 5 | import random 6 | import os.path as osp 7 | import torch 8 | torch.set_grad_enabled(False) 9 | import numpy as np 10 | from tqdm import tqdm 11 | from utils.env import rlbench_obs_config, EndEffectorPoseViaPlanning, CustomMultiTaskRLBenchEnv 12 | 13 | from rlbench.backend.utils import task_file_to_task_class 14 | from rlbench.action_modes.gripper_action_modes import Discrete 15 | from rlbench.action_modes.action_mode import MoveArmThenGripper 16 | from collections import defaultdict 17 | from termcolor import colored 18 | from utils import configurable, DictConfig, config_to_dict 19 | from utils.structure import BASE_RLBENCH_TASKS, NOVEL_RLBENCH_TASKS, load_pkl, dump_pkl, ActResult 20 | from utils.vis import * 21 | from dataclasses import dataclass 22 | 23 | from hdbscan import HDBSCAN 24 | 25 | from utils.ckpt import remove_dict_prefix 26 | 27 | import data as dlib 28 | import utils.icp as icplib 29 | cat = dlib.cat 30 | 31 | __dirname = osp.dirname(__file__) 32 | 33 | import heuristics as heu 34 | from network import InvariantRegionNetwork, RegionMatchingNetwork, RegionMatchingNetwork_fine 35 | 36 | 37 | def to_np_obs(obs): 38 | def _get_type(x): 39 | if not hasattr(x, 'dtype'): return np.float32 40 | if x.dtype == np.float64: 41 | return np.float32 42 | return x.dtype 43 | return {k: np.array(v, dtype=_get_type(v)) if not isinstance(v, dict) else v for k, v in obs.items()} 44 | 45 | 46 | def load_models(model_paths, dev: torch.device): 47 | m1 = RegionMatchingNetwork() 48 | m1.load_state_dict(remove_dict_prefix( 49 | torch.load(model_paths['region_match'], map_location=dev)['model'], prefix="module.")) 50 | m1 = m1.to(dev).eval() 51 | 52 | m2 = InvariantRegionNetwork() 53 | m2.load_state_dict(remove_dict_prefix( 54 | torch.load(model_paths['invariant_region'], map_location='cpu')['model'], prefix="module.")) 55 | m2 = m2.to(dev).eval() 56 | 57 | m3 = RegionMatchingNetwork_fine() 58 | m3.load_state_dict(remove_dict_prefix( 59 | torch.load(model_paths['region_match_fine'], map_location=dev)['model'], prefix="module.")) 60 | m3 = m3.to(dev).eval() 61 | 62 | return {'region_match': m1, 'invariant_region': m2, 'region_match_fine': m3} 63 | 64 | 65 | def get_datasets(demoset_path): 66 | db = dlib.RLBenchDataset(tasks=NOVEL_RLBENCH_TASKS , path=demoset_path, 67 | grid_size=0.005, min_max_pts_per_obj=5000, 68 | max_episode_num=5) 69 | collate_fn = dlib.RLBenchCollator(use_segmap=False, training=False) 70 | return db, collate_fn 71 | 72 | 73 | # ======================================================== # 74 | 75 | @dataclass 76 | class KeyFrame: 77 | type: str = "" 78 | task: str = "" 79 | pcd = None 80 | rgb = None 81 | 82 | cluster_ids = None 83 | cluster_id_set = None 84 | assigned_cluster_id = -1 85 | cluster_map = None 86 | 87 | key_prob_map = None 88 | 89 | item = None 90 | next_item = None 91 | 92 | key_region_not_found = False 93 | 94 | 95 | 96 | def get_key_mask(kf: KeyFrame): 97 | if kf.key_region_not_found: 98 | return np.ones_like(kf.key_prob_map).astype(bool) 99 | 100 | if kf.task in tasks_need_clustering: 101 | assert kf.assigned_cluster_id != -1 102 | return kf.cluster_ids == kf.assigned_cluster_id 103 | else: 104 | return kf.key_prob_map > 0.1 105 | 106 | 107 | def smoothen_key_prob(kf: KeyFrame, neighbors=5): 108 | _, idxs = icplib.knn(kf.pcd, kf.pcd, k=neighbors) 109 | prob_map_neighbor = kf.key_prob_map[idxs.flatten()].reshape(-1, neighbors) 110 | kf.key_prob_map = prob_map_neighbor.max(axis=1) 111 | 112 | 113 | 114 | def build_cluster_map(prev_kf: KeyFrame, curr_kf: KeyFrame, threshold=0.035): 115 | if curr_kf.task not in tasks_need_clustering: return None 116 | 117 | def compute_centers(clst, clst_map, pcd): 118 | centers = [] 119 | for a in clst: 120 | centers.append(pcd[clst_map == a].mean(axis=0)) 121 | return np.stack(centers) 122 | 123 | prev_centers = compute_centers(prev_kf.cluster_id_set, prev_kf.cluster_ids, prev_kf.pcd) 124 | curr_centers = compute_centers(curr_kf.cluster_id_set, curr_kf.cluster_ids, curr_kf.pcd) 125 | 126 | dists = np.linalg.norm(prev_centers[:, None, :] - curr_centers[None, :, :], axis=-1, ord=1) 127 | closest_ids = dists.argmin(axis=1).flatten() 128 | mapping = {} 129 | for a, b in enumerate(closest_ids): 130 | if dists[a, b] < threshold: 131 | mapping[a] = b 132 | return mapping 133 | 134 | 135 | def find_most_salient_cluster(key_frame: KeyFrame, min_cluster_size=10, score_margin=0.1): 136 | prob_map, cluster_map, cluster_id_set = key_frame.key_prob_map, key_frame.cluster_ids, key_frame.cluster_id_set 137 | max_score = -1 138 | max_clst_id = -1 139 | prob_mask = prob_map > 0.1 140 | mean_scores = [(prob_map[prob_mask & (cluster_map == clst_id)]).mean() for clst_id in cluster_id_set] 141 | max_mean_score = max([v for v in mean_scores if not np.isnan(v)]) 142 | for cid in cluster_id_set: 143 | mask = cluster_map == cid 144 | score = prob_map[mask & prob_mask].sum() 145 | if score > max_score and mask.sum() >= min_cluster_size and score + score_margin > max_mean_score: 146 | max_score = score 147 | max_clst_id = cid 148 | return max_clst_id 149 | 150 | 151 | tasks_need_clustering = {'block_pyramid', "place_hanger_on_rack", 'lamp_on', 'phone_on_base' } 152 | 153 | 154 | class EvaluationModelWrapper: 155 | def __init__(self, model_dict, db: dlib.RLBenchDataset, collate_fn, logger=print, is_novel=False, 156 | support_episode=-1, min_episodes_per_desc=-1, cache_to="", debug=False): 157 | if support_episode >= 0: logger("Warning: support_episode is SET, this shall only be set in debug") 158 | self.invariant_region, self.region_match, self.region_match_fine = model_dict['invariant_region'], model_dict['region_match'], model_dict['region_match_fine'] 159 | self.db = db 160 | self.device = next(self.region_match_fine.parameters()).device 161 | self.collate = collate_fn 162 | self.debug = debug 163 | self.logger = logger 164 | self.is_novel = is_novel 165 | self.support_episode = support_episode 166 | if is_novel: min_episodes_per_desc = 1 # single demo 167 | 168 | if osp.exists(cache_to): 169 | self.demo_db = load_pkl(cache_to) 170 | else: 171 | self.demo_db = {} 172 | for t in tqdm(self.db.tasks): 173 | self.demo_db[t] = {} 174 | for e in self.db.get_episodes(t): 175 | kfs = self.db.get_kfs(t, e) 176 | desc, vn = self.db.get_desc_and_vn(t, e) 177 | if t == 'stack_blocks': 178 | if '4' not in desc: continue 179 | if t == 'place_cups': 180 | if '3' not in desc: continue 181 | if desc not in self.demo_db[t]: self.demo_db[t][desc] = [] 182 | self.demo_db[t][desc].append({'episode': e, 'ratio': random.random()}) 183 | for desc in list(self.demo_db[t].keys()): 184 | num_full_episodes = len([v['ratio'] for v in self.demo_db[t][desc] if v['ratio'] >= 1.0]) 185 | if num_full_episodes >= min_episodes_per_desc: 186 | self.demo_db[t][desc] = [a for a in self.demo_db[t][desc] if a['ratio'] >= 1.0] 187 | else: 188 | self.demo_db[t][desc] = sorted(self.demo_db[t][desc], reverse=True, 189 | key=lambda x: x['ratio'])[:min_episodes_per_desc] 190 | if cache_to: dump_pkl(cache_to, self.demo_db) 191 | 192 | self.counter = defaultdict(lambda: 0) 193 | self._init() 194 | 195 | 196 | def _init(self): 197 | self.references: List[KeyFrame] = [] # store src information! 198 | self.pose_history = [] 199 | self.cursor = 0 200 | 201 | self.last_action = None 202 | 203 | self.current_task = "" 204 | self.current_episode_description = "" 205 | 206 | self.prev_tgt_frame: KeyFrame = None 207 | self.color_instruction = None 208 | self.cluster_id_mapping = {} 209 | 210 | 211 | def reset(self, task, desc, color_instruction=None): 212 | if task != self.current_task: 213 | self._init() 214 | self._init() 215 | self.color_instruction = color_instruction 216 | assert task is not None 217 | self.current_task = task 218 | self.current_episode_description = desc 219 | 220 | if self.support_episode < 0: 221 | if desc in self.demo_db[task]: 222 | candidates = self.demo_db[task][desc] 223 | else: 224 | demo_desc = sorted([(textdistance.levenshtein.distance(desc, demo_desc), demo_desc) for demo_desc in self.demo_db[task]])[0][1] 225 | self.logger(f'\t{desc} not found in existing demonstrations, use demonstrations of `{demo_desc}`') 226 | candidates = self.demo_db[task][demo_desc] 227 | 228 | ref_e = random.choice(candidates)['episode'] 229 | else: 230 | ref_e = self.support_episode 231 | self.logger(f'\tselect support episode = {ref_e}') 232 | 233 | for ind, kf in enumerate(self.db.get_kfs(task, ref_e, exclude_last=False)): 234 | src_t = self.db.get(task, ref_e, kf, training=False) 235 | src_t1 = self.db.get(task, ref_e, src_t['kf_t+1'], training=False) 236 | sample = {'src': {'t': src_t, 't+1': src_t1}, 237 | 'tgt': { 't': src_t, 't+1': src_t1}, 238 | 'match': None, 'index': None} 239 | batch = self.collate([sample, ]) 240 | batch = dlib.to_device(batch, self.device) 241 | iv_region = self.invariant_region(batch, debug=False) 242 | prob = iv_region['output']['prob_map'].flatten().cpu().numpy() 243 | 244 | data = KeyFrame(type="src", task=task) 245 | data.item = src_t 246 | data.next_item = src_t1 247 | data.pcd, data.rgb = src_t['pcd'], src_t['rgb'] 248 | data.key_prob_map = prob 249 | src_key_mask = prob > 0.1 250 | 251 | if src_key_mask.sum() < 15: 252 | if ind > 0: 253 | prev = self.references[ind - 1] 254 | prev_src_key_mask = get_key_mask(prev) 255 | if prev_src_key_mask.sum() < 15: 256 | data.key_prob_map[:] = 1.0 257 | else: 258 | key_pcd = prev.pcd[prev_src_key_mask] 259 | _, idxs = icplib.knn(key_pcd, data.pcd, k=3) 260 | data.key_prob_map[:] = 0 261 | data.key_prob_map[idxs.flatten()] = 1.0 262 | smoothen_key_prob(data) 263 | else: 264 | data.key_prob_map[:] = 1.0 # all activated 265 | 266 | if task in tasks_need_clustering: 267 | data.cluster_ids = HDBSCAN(min_cluster_size=15).fit_predict(src_t['pcd']) 268 | data.cluster_id_set = set(np.unique(data.cluster_ids).tolist()) - {-1} 269 | data.assigned_cluster_id = find_most_salient_cluster(data) 270 | data.key_prob_map = (data.cluster_ids == data.assigned_cluster_id).astype(np.float32) 271 | 272 | self.references.append(data) 273 | 274 | 275 | if task in tasks_need_clustering: 276 | for i in range(len(self.references) - 1): 277 | d1, d2 = self.references[i], self.references[i+1] 278 | d2.cluster_map = build_cluster_map(d1, d2) 279 | 280 | self.counter[task] += 1 281 | 282 | 283 | def act(self, obs): 284 | task = self.current_task 285 | self.pose_history.append(obs['gripper_pose']) 286 | 287 | if self.cursor == len(self.references) - 1: 288 | if task == 'slide_cabinet_open_and_place_cups': 289 | self.last_action.action[-2] = 1.0 290 | return self.last_action 291 | elif self.cursor >= len(self.references): 292 | return None 293 | 294 | src_frame = self.references[self.cursor] 295 | tgt = self.db.prepare_obs({**obs, 'task': self.current_task, 296 | 'desc': self.current_episode_description}, pose0=self.pose_history[0]) 297 | 298 | tgt_frame = KeyFrame(type='tgt', task=task) 299 | tgt_frame.pcd, tgt_frame.rgb = tgt['pcd'], tgt['rgb'] 300 | tgt_frame.item = tgt 301 | if task in tasks_need_clustering: 302 | tgt_frame.cluster_ids = HDBSCAN(min_cluster_size=15).fit_predict(tgt_frame.pcd) 303 | tgt_frame.cluster_id_set = set(np.unique(tgt_frame.cluster_ids).tolist()) - {-1} 304 | 305 | sample = {'src': {'t': src_frame.item, 't+1': src_frame.next_item}, 306 | 'tgt': { 't': tgt_frame.item, 't+1': tgt_frame.item}, 307 | 'match': None, 'index': None} 308 | src_frame.item['key_mask'] = get_key_mask(src_frame) 309 | 310 | key_region_propagated = False 311 | if self.prev_tgt_frame is not None and task in tasks_need_clustering: 312 | # sometimes find the most salient cluster performs better 313 | tgt_frame.cluster_map = build_cluster_map(self.prev_tgt_frame, tgt_frame) 314 | 315 | new_cluster_id_mapping = {} 316 | for a, b in self.cluster_id_mapping.items(): 317 | if a in src_frame.cluster_map and b in tgt_frame.cluster_map: 318 | new_cluster_id_mapping[src_frame.cluster_map[a]] = tgt_frame.cluster_map[b] 319 | self.cluster_id_mapping = new_cluster_id_mapping 320 | 321 | if src_frame.assigned_cluster_id in self.cluster_id_mapping: 322 | tgt_frame.assigned_cluster_id = self.cluster_id_mapping[src_frame.assigned_cluster_id] 323 | tgt_frame.key_prob_map = (tgt_frame.cluster_ids == tgt_frame.assigned_cluster_id).astype(np.float32) 324 | key_region_propagated = True 325 | 326 | if not key_region_propagated: 327 | src_frame.item['position_mask'] = heu.get_color_position_mask(task, src_frame.item['desc'], src_frame.item['id2names'], 328 | src_frame.item['rgb'], src_frame.item['mask'], **self.color_instruction) 329 | tgt_frame.item['position_mask'] = heu.get_color_position_mask(task, tgt_frame.item['desc'], tgt_frame.item['id2names'], 330 | tgt_frame.item['rgb'], tgt_frame.item['mask'], **self.color_instruction) 331 | batch = self.collate([sample, ]) 332 | batch = dlib.to_device(batch, self.device) 333 | matched_result = self.region_match(batch) 334 | prob_map = matched_result['output']['conf_matrix'].reshape(-1, len(batch['tgt']['t']['pcd'])).sum(dim=0) 335 | tgt_frame.key_prob_map = prob_map.flatten().cpu().numpy() 336 | 337 | smoothen_key_prob(tgt_frame) 338 | 339 | if task in tasks_need_clustering: 340 | tgt_frame.assigned_cluster_id = find_most_salient_cluster(tgt_frame) 341 | self.cluster_id_mapping[src_frame.assigned_cluster_id] = tgt_frame.assigned_cluster_id 342 | else: 343 | if np.all(tgt_frame.key_prob_map < 0.1): 344 | tgt_frame.key_region_not_found = True 345 | 346 | sample = {'src': {'t': src_frame.item, 't+1': src_frame.item}, 347 | 'tgt': { 't': tgt_frame.item, 't+1': tgt_frame.item}, 348 | 'match': None, 'index': None} 349 | tgt_frame.key_region_not_found = src_frame.key_region_not_found = tgt_frame.key_region_not_found | src_frame.key_region_not_found 350 | 351 | src_frame.item['key_mask'] = get_key_mask(src_frame) 352 | tgt_frame.item['key_mask'] = get_key_mask(tgt_frame) 353 | batch = self.collate([sample, ]) 354 | batch = dlib.to_device(batch, self.device) 355 | try: 356 | matched_result_fine = self.region_match_fine(batch) 357 | except Exception as e: 358 | self.logger(f'Exception: {e}') 359 | return None 360 | 361 | 362 | estimated_frame = matched_result_fine['output']['predict_frame'][0].cpu().numpy() 363 | t = estimated_frame[0] 364 | if t[2] < 0.75: 365 | estimated_frame[:, 2] = (0.75 - estimated_frame[:, 2]) + 0.75 366 | 367 | frame0 = icplib.pose7_to_frame(self.pose_history[0]) 368 | X_02t = icplib.pose7_to_X(obs['gripper_pose']) @ icplib.inv(icplib.pose7_to_X(self.pose_history[0])) 369 | frame_t = icplib.h_transform(X_02t, frame0) 370 | X_t2tp1 = icplib.Rt_2_X(*icplib.arun(frame_t, estimated_frame)) 371 | X_02tp1 = X_t2tp1 @ X_02t 372 | 373 | next_pose_X = X_02tp1 @ icplib.pose7_to_X(self.pose_history[0]) 374 | next_pose = icplib.X_to_pose7(next_pose_X) 375 | 376 | self.cursor += 1 377 | if task in ['stack_blocks', 'stack_cups']: # 378 | src_frame.item['ignore_col_t+1'] = False 379 | 380 | self.last_action = ActResult(np.array(list(next_pose) + [src_frame.item['open_t+1'], float(src_frame.item['ignore_col_t+1'])])) 381 | self.prev_tgt_frame = tgt_frame 382 | return self.last_action 383 | 384 | 385 | 386 | def evaluate(agent: EvaluationModelWrapper, episode_length=25, 387 | tasks=BASE_RLBENCH_TASKS, num_episodes=5, headless=True, 388 | testset_path="", 389 | logger=print, 390 | start_episode=0): 391 | if isinstance(tasks, str): tasks = [tasks] 392 | obs_config = rlbench_obs_config(["front", "left_shoulder", "right_shoulder", "wrist"], [128, 128], method_name="") 393 | PCD, RGB, MASK = 0, 1, 2 394 | 395 | gripper_mode = Discrete() 396 | arm_action_mode = EndEffectorPoseViaPlanning() 397 | action_mode = MoveArmThenGripper(arm_action_mode, gripper_mode) 398 | 399 | task_classes = [task_file_to_task_class(task) for task in tasks] 400 | 401 | try: 402 | eval_env = CustomMultiTaskRLBenchEnv( 403 | task_classes=task_classes, 404 | observation_config=obs_config, 405 | action_mode=action_mode, 406 | dataset_root=testset_path, 407 | episode_length=episode_length, 408 | headless=headless, 409 | swap_task_every=num_episodes, 410 | include_lang_goal_in_obs=True 411 | ) 412 | eval_env.eval = True 413 | eval_env.launch() 414 | scores = defaultdict(list) 415 | 416 | for task_name in tasks: 417 | for ep in range(start_episode, start_episode + num_episodes): 418 | logger(f"{task_name} - {ep}") 419 | episode_rollout = [] 420 | # transitions 421 | obs = to_np_obs(eval_env.reset_to_demo(ep)) 422 | lang_goal = eval_env._lang_goal 423 | color_information = eval_env.get_color_information() 424 | 425 | agent.reset(task=task_name, desc=lang_goal, color_instruction=color_information) 426 | 427 | for step in range(episode_length): 428 | action = agent.act({**obs, 'task': task_name}) 429 | if action is None: 430 | episode_rollout.append(0.0) 431 | break 432 | transition = eval_env.step(action) 433 | obs = dict(transition.observation) 434 | if step == episode_length - 1: 435 | transition.terminal = True 436 | episode_rollout.append(transition.reward) 437 | if transition.terminal: break 438 | 439 | reward = episode_rollout[-1] 440 | scores[task_name].append(reward) 441 | txt = colored(f"\tEvaluating {task_name} | Episode {ep} | Score: {reward} | Episode Length: {len(episode_rollout)} | Lang Goal: {lang_goal}", 'red') 442 | logger(txt) 443 | 444 | for k, values in scores.items(): 445 | logger(f'{k}, {np.mean(values):.02f}') 446 | mean_score = np.mean(list(scores.values())) 447 | logger(f'Average Score: {mean_score}') 448 | return scores 449 | finally: 450 | # pass 451 | eval_env.shutdown() 452 | 453 | 454 | @configurable() 455 | def main(cfg: DictConfig): 456 | dev = torch.device(cfg.eval.device) 457 | logfile = open(osp.join(cfg.output_dir, 'log.eval.txt'), "w") 458 | 459 | if cfg.clear_output: 460 | import shutil 461 | if osp.exists('./outputs/eval_vis'): 462 | shutil.rmtree('./outputs/eval_vis/') 463 | 464 | def log(msg, printer=print): 465 | print(msg, file=logfile, flush=True) 466 | printer(msg) 467 | 468 | tasks = NOVEL_RLBENCH_TASKS 469 | db, collate_fn = get_datasets(cfg.demoset_path) 470 | model_dict = load_models(cfg.eval.model_paths, dev) 471 | 472 | agent = EvaluationModelWrapper(model_dict, db, collate_fn, logger=log, is_novel=True, 473 | **config_to_dict(cfg.eval.agent)) 474 | evaluate(agent, 475 | tasks=tasks, 476 | num_episodes=cfg.eval.episode_num, 477 | headless=cfg.eval.headless, 478 | testset_path=cfg.testset_path, 479 | logger=log, 480 | start_episode=cfg.eval.start_episode, 481 | episode_length=cfg.eval.episode_length) 482 | 483 | if __name__ == "__main__": 484 | main() 485 | -------------------------------------------------------------------------------- /heuristics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | """ 3 | put_item_in_drawer, 4 | reach_and_drag, 5 | turn_tap, 6 | slide_block_to_color_target, 7 | open_drawer, 8 | put_groceries_in_cupboard, 9 | place_shape_in_shape_sorter, 10 | put_money_in_safe, 11 | push_buttons, 12 | close_jar, 13 | stack_blocks, 14 | place_cups, 15 | place_wine_at_rack_location, 16 | light_bulb_in, 17 | sweep_to_dustpan_of_size, 18 | insert_onto_square_peg, 19 | meat_off_grill, 20 | stack_cups 21 | """ 22 | 23 | SPATIAL_DIRECTIVES = ['top', 'middle', 'bottom', 'left', 'right'] 24 | 25 | def parse_spatial_directive(desc): 26 | for n in SPATIAL_DIRECTIVES: 27 | if n in desc: 28 | return n 29 | 30 | def parse_number(desc): 31 | for i in range(1, 5): 32 | if str(i) in desc: return i 33 | 34 | 35 | def object_shall_be_movable(task, desc, oname): 36 | exact = False 37 | if task == 'put_item_in_drawer': 38 | words = ['item', parse_spatial_directive(desc)] 39 | elif task == 'reach_and_drag': 40 | words = ['cube', 'stick'] 41 | elif task in ['turn_tap', 'open_drawer']: 42 | words = [parse_spatial_directive(desc)] 43 | elif task == 'slide_block_to_color_target': 44 | words = ['block'] 45 | elif task == 'put_groceries_in_cupboard': 46 | oname = oname.split("_")[0] 47 | return oname in desc and oname != 'cupboard' 48 | elif task == 'place_shape_in_shape_sorter': 49 | oname = oname.split("_")[0] 50 | return oname in desc and oname != 'shape' 51 | elif task == 'put_money_in_safe': 52 | words = ['dollar_stack'] 53 | elif task == 'push_buttons': 54 | return False 55 | elif task == 'close_jar': 56 | words = ['jar_lid0'] 57 | elif task == 'stack_blocks': 58 | return [f'target{i}' for i in range(parse_number(desc))] 59 | elif task == 'place_cups': 60 | words = [f'mug_visual{i}' for i in range(parse_number(desc))] 61 | elif task == 'place_wine_at_rack_location': 62 | words = ['wine_bottle'] 63 | elif task == 'light_bulb_in': 64 | words = ['bulb1', 'bulb0'] 65 | exact = True 66 | elif task == 'sweep_to_dustpan_of_size': 67 | words = ['broom_visual', 'dirt0'] 68 | elif task == 'insert_onto_square_peg': 69 | words = ['square_ring'] 70 | elif task == 'meat_off_grill': 71 | words = ['chicken', 'steak'] 72 | elif task == 'stack_cups': 73 | words = ['cup'] 74 | else: 75 | raise KeyError('unrecognized task: ' + task) 76 | if exact: 77 | return any([w == oname for w in words]) 78 | else: 79 | return any([w in oname for w in words]) 80 | 81 | 82 | def number_of_movable_objects_at_once(task, desc): 83 | if task in ['reach_and_drag', 'sweep_to_dustpan_of_size']: 84 | return 2 85 | else: 86 | return 1 87 | 88 | 89 | colors = dict([ 90 | ('red', (1.0, 0.0, 0.0)), 91 | ('maroon', (0.5, 0.0, 0.0)), 92 | ('lime', (0.0, 1.0, 0.0)), 93 | ('green', (0.0, 0.5, 0.0)), 94 | ('blue', (0.0, 0.0, 1.0)), 95 | ('navy', (0.0, 0.0, 0.5)), 96 | ('yellow', (1.0, 1.0, 0.0)), 97 | ('cyan', (0.0, 1.0, 1.0)), 98 | ('magenta', (1.0, 0.0, 1.0)), 99 | ('silver', (0.75, 0.75, 0.75)), 100 | ('gray', (0.5, 0.5, 0.5)), 101 | ('orange', (1.0, 0.5, 0.0)), 102 | ('olive', (0.5, 0.5, 0.0)), 103 | ('purple', (0.5, 0.0, 0.5)), 104 | ('teal', (0, 0.5, 0.5)), 105 | ('azure', (0.0, 0.5, 1.0)), 106 | ('violet', (0.5, 0.0, 1.0)), 107 | ('rose', (1.0, 0.0, 0.5)), 108 | ('black', (0.0, 0.0, 0.0)), 109 | ('white', (1.0, 1.0, 1.0)), 110 | ('pink', (0.95, 0.075, 0.54)) 111 | ]) 112 | 113 | COLOR_NAMES = list(colors.keys()) 114 | 115 | def find_color_directive(desc, all_colors=None): 116 | cs = [] 117 | if all_colors is None: all_colors = colors.keys() 118 | for c in all_colors: 119 | c = f' {c} ' 120 | if c in desc: 121 | cs.append((desc.index(c), c.strip())) 122 | cs = sorted(cs) 123 | return [c for _, c in cs] 124 | 125 | 126 | GROCERY_NAMES = [ 127 | 'crackers', 128 | 'chocolate jello', 129 | 'strawberry jello', 130 | 'soup', 131 | 'tuna', 132 | 'spam', 133 | 'coffee', 134 | 'mustard', 135 | 'sugar', 136 | ] 137 | 138 | SHAPE_NAMES = ['cube', 'cylinder', 'triangular prism', 'star', 'moon'] 139 | 140 | COLOR_RVT_TASKS = [ 141 | "reach_and_drag", 142 | "push_buttons", 143 | "close_jar", 144 | "stack_blocks", 145 | "light_bulb_in", 146 | "insert_onto_square_peg", 147 | "stack_cups", 148 | "block_pyramid", 149 | "slide_block_to_color_target" 150 | ] 151 | 152 | def all_instructions(color_only=False): 153 | if color_only: 154 | return COLOR_NAMES 155 | else: 156 | ALL_INSTRUCTIONS = GROCERY_NAMES + SHAPE_NAMES + COLOR_NAMES + SPATIAL_DIRECTIVES 157 | return ALL_INSTRUCTIONS 158 | 159 | def find_tag_indexes(desc, tags): 160 | indexes = [] 161 | for t in tags: 162 | indexes.append(desc.index(t)) 163 | return indexes 164 | 165 | def parse_instructions(task, desc, color_only=False): 166 | tags = [] 167 | if color_only: 168 | if task in COLOR_RVT_TASKS: 169 | tags = find_color_directive(desc, all_instructions(color_only)) 170 | else: 171 | if task == 'place_shape_in_shape_sorter': 172 | for g in SHAPE_NAMES: 173 | if g in desc: 174 | tags.append(g) 175 | elif task == 'put_groceries_in_cupboard': 176 | for g in GROCERY_NAMES: 177 | if g in desc: 178 | tags.append(g) 179 | elif task in ['open_drawer', 'put_item_in_drawer', 'turn_tap']: 180 | tags = [parse_spatial_directive(desc), ] 181 | else: 182 | tags = find_color_directive(desc) 183 | indexes = find_tag_indexes(desc, tags) 184 | return list(zip(tags, indexes)) 185 | 186 | 187 | def list_index(lst, v): 188 | if v in lst: 189 | return lst.index(v) 190 | else: 191 | return -1 192 | 193 | 194 | def assign_instruction_class_to_object(object_names, task, desc, targets=None, color_only=False): 195 | instructions = parse_instructions(task, desc, color_only=color_only) 196 | if len(instructions) == 0: return [-1] * len(object_names) 197 | ALL_INSTRUCTIONS = all_instructions(color_only) 198 | if color_only and task not in COLOR_RVT_TASKS: 199 | return [-1] * len(object_names) 200 | 201 | if task == 'put_groceries_in_cupboard': 202 | grocery_names = [a.replace(' ', '_') + '_visual' for a in GROCERY_NAMES] 203 | indexes = [list_index(grocery_names, obj) for obj in object_names] 204 | class_indexes = [ALL_INSTRUCTIONS.index(GROCERY_NAMES[i]) if i != -1 else i for i in indexes] 205 | 206 | elif task == 'place_shape_in_shape_sorter': 207 | shape_names = [(a.replace(' ', '_') + '_visual') if a in ('star', 'moon') else a for a in SHAPE_NAMES] 208 | indexes = [list_index(shape_names, obj) for obj in object_names] 209 | class_indexes = [ALL_INSTRUCTIONS.index(SHAPE_NAMES[i]) if i != -1 else i for i in indexes] 210 | 211 | elif task == 'slide_block_to_color_target': 212 | class_indexes = [] 213 | for o in object_names: 214 | ind = list_index(['target1', 'target2', 'target3', 'target4'], o) 215 | if ind != -1: 216 | ind = ALL_INSTRUCTIONS.index(['green', 'blue', 'pink', 'yellow'][ind]) 217 | class_indexes.append(ind) 218 | 219 | elif task in ['open_drawer', 'put_item_in_drawer',]: 220 | class_indexes = [] 221 | for o in object_names: 222 | ind = list_index(['drawer_top', 'drawer_middle', 'drawer_bottom'], o) 223 | if ind != -1: 224 | ind = ALL_INSTRUCTIONS.index(['top', 'middle', 'bottom'][ind]) 225 | class_indexes.append(ind) 226 | 227 | elif task == 'turn_tap': 228 | class_indexes = [] 229 | for o in object_names: 230 | ind = list_index(['tap_right_visual', 'tap_left_visual'], o) 231 | if ind != -1: 232 | ind = ALL_INSTRUCTIONS.index(['right', 'left'][ind]) 233 | class_indexes.append(ind) 234 | 235 | elif task == 'stack_cups': 236 | assert len(instructions) == 1 237 | target_color = instructions[0][0] 238 | class_indexes = [ALL_INSTRUCTIONS.index(target_color) if o == 'cup2_visual' else -1 for o in object_names] 239 | 240 | elif task == 'push_buttons': 241 | assert len(instructions) >= 1 242 | target_colors = [c for c, _ in instructions] 243 | target_names = [f'push_buttons_target{i}' for i in range(len(instructions))] 244 | class_indexes = [] 245 | for o in object_names: 246 | ind = list_index(target_names, o) 247 | if ind != -1: 248 | ind = ALL_INSTRUCTIONS.index(target_colors[ind]) 249 | class_indexes.append(ind) 250 | 251 | elif task == 'stack_blocks': 252 | class_indexes = [] 253 | assert len(instructions) == 1 254 | target_color = instructions[0][0] 255 | target_color_clsind = ALL_INSTRUCTIONS.index(target_color) 256 | for name in object_names: 257 | if 'target' in name and 'target_plane' not in name: 258 | class_indexes.append(target_color_clsind) 259 | else: 260 | class_indexes.append(-1) 261 | 262 | elif task == 'reach_and_drag': 263 | assert len(instructions) == 1 264 | target_color = instructions[0][0] 265 | target_color_clsind = ALL_INSTRUCTIONS.index(target_color) 266 | class_indexes = [target_color_clsind if 'target' in o else -1 for o in object_names] 267 | 268 | elif task in [ 'light_bulb_in', 'close_jar', 'insert_onto_square_peg']: 269 | assert len(instructions) == 1 270 | assert targets is not None 271 | target_color = instructions[0][0] 272 | target_color_clsind = ALL_INSTRUCTIONS.index(target_color) 273 | class_indexes = [target_color_clsind if o in targets else -1 for o in object_names] 274 | 275 | else: 276 | raise KeyError(task) 277 | 278 | return class_indexes 279 | 280 | 281 | 282 | def extend_key_objects(item): 283 | task, desc = item['task'], item['desc'] 284 | if item['key_id'] == -1: return [-1] 285 | result = [item['key_id'], ] 286 | def find_ids(names): return [item['name2ids'][n] for n in names] 287 | 288 | if task in [ "push_buttons", "meat_off_grill",]: 289 | pass 290 | elif task == "sweep_to_dustpan_of_size": 291 | pass 292 | # if "dustpan" in item['key_name'] and 'broom' not in item['key_name']: 293 | # result.append(item['name2ids']['dirt0']) 294 | elif task == "put_money_in_safe": 295 | if 'safe' in item['key_name']: 296 | result = find_ids(['safe_body']) 297 | # center_body = item['pcd'][item['mask'] == item['name2ids']['safe_body']].max(axis=0) 298 | # center_dollar = item['pcd'][item['mask'] == item['name2ids']['dollar_stack']].max(axis=0) 299 | # if center_body[-1] > center_dollar[-1]: 300 | # result += find_ids(['dollar_stack']) 301 | elif task in ["put_item_in_drawer", "open_drawer"]: 302 | entire_drawer = [i for i, k in item['id2names'].items() if 'drawer' in k] 303 | if task == 'open_drawer': 304 | result = entire_drawer 305 | else: 306 | if item['key_name'] == 'drawer_frame': 307 | result = entire_drawer 308 | elif 'drawer' in item['key_name']: 309 | if item['kf_t'] < 200: 310 | result = entire_drawer 311 | 312 | elif task == "reach_and_drag": 313 | pass 314 | # if item['key_name'] == 'target0': 315 | # result = find_ids(['target0', 'cube']) 316 | elif task == "slide_block_to_color_target": 317 | if 'target' in item['key_name']: 318 | targets = ['target1', 'target2', 'target3', 'target4'] 319 | colors = ['green', 'blue', 'pink', 'yellow'] 320 | for i, (c, t) in enumerate(zip(colors, targets)): 321 | if c in desc: 322 | result = [item['name2ids'][t],] 323 | break 324 | elif task == "turn_tap": 325 | if 'right' in desc: 326 | result = find_ids(['tap_right_visual']) 327 | else: 328 | result = find_ids(['tap_left_visual']) 329 | elif task == "put_groceries_in_cupboard": 330 | for g in GROCERY_NAMES: 331 | if g in desc: 332 | g = g.replace(' ', '_') + '_visual' 333 | break 334 | if 'cupboard' not in item['key_name']: 335 | result = find_ids([g]) 336 | else: 337 | pass 338 | # center_cupboard =item['pcd'][item['mask'] == item['name2ids']['cupboard']].mean(axis=0) 339 | # center_obj = item['pcd'][item['mask'] == item['name2ids'][g]].mean(axis=0) 340 | # if abs(center_obj[-1] - center_cupboard[-1]) < 0.05: 341 | # result.append(item['name2ids'][g]) 342 | elif task == "place_shape_in_shape_sorter": 343 | for g in SHAPE_NAMES: 344 | if g in desc: 345 | g = g.replace(' ', '_') 346 | g = (g + '_visual') if g in ('star', 'moon') else g 347 | break 348 | if 'shape_sorter' in item['key_name']: 349 | result = find_ids(['shape_sorter_visual', 'shape_sorter']) 350 | # center_sorter = item['pcd'][item['mask'] == item['name2ids']['shape_sorter_visual']].mean(axis=0) 351 | # center_obj = item['pcd'][item['mask'] == item['name2ids'][g]].mean(axis=0) 352 | # if center_obj[-1] > center_sorter[-1]: 353 | # result.append(item['name2ids'][g]) 354 | else: 355 | result = find_ids([g]) 356 | elif task == "close_jar": 357 | pass 358 | # if 'lid' not in item['key_name']: 359 | # center_lid = item['pcd'][item['mask'] == item['name2ids']['jar_lid0']].mean(axis=0)[:2] 360 | # center_jar = item['pcd'][item['mask'] == item['key_id']].mean(axis=0)[:2] 361 | # if np.linalg.norm(center_jar - center_lid) <= 0.03: 362 | # result = [item['key_id'], item['name2ids']['jar_lid0']] 363 | elif task == "stack_blocks": 364 | # ['stack_blocks_target'] 365 | center_plane = item['pcd'][item['mask'] == item['name2ids']['stack_blocks_target_plane']].mean(axis=0)[:2] 366 | center_key = item['pcd'][item['mask'] == item['key_id']].mean(axis=0)[:2] 367 | if np.linalg.norm(center_key - center_plane) <= 0.03: # stacking mode 368 | names = ['stack_blocks_target_plane'] 369 | for i in range(4): 370 | _n = f'stack_blocks_target{i}' 371 | center_block = item['pcd'][item['mask'] == item['name2ids'][_n]].mean(axis=0)[:2] 372 | if np.linalg.norm(center_block - center_plane) <= 0.03 and _n != item['grasp_name']: 373 | names.append(_n) 374 | result = find_ids(names) 375 | elif task == "place_wine_at_rack_location": 376 | if "rack" in item['key_name']: 377 | result = find_ids(['rack_top_visual', 'rack_bottom_visual']) 378 | # if item['kf_t'] > 145: 379 | # result.append(item['name2ids']['wine_bottle_visual']) 380 | elif task == "light_bulb_in": 381 | if 'lamp' in item['key_name']: 382 | center_screw = item['pcd'][item['mask'] == item['name2ids']['lamp_screw']].mean(axis=0)[:2] 383 | center_bulb = item['pcd'][item['mask'] == item['grasp_id']].mean(axis=0)[:2] if item['grasp_id'] != -1 else 0 384 | if np.linalg.norm(center_bulb - center_screw) <= 0.03: 385 | result = find_ids(['lamp_base', 'lamp_screw']) + [item['grasp_id'],] 386 | else: 387 | result = find_ids(['lamp_base', 'lamp_screw']) 388 | elif 'holder' in item['key_name']: 389 | bulb_name = item['key_name'].replace('_holder', '') 390 | center_holder = item['pcd'][item['mask'] == item['key_id']].mean(axis=0) 391 | center_bulb = item['pcd'][item['mask'] == item['name2ids'][bulb_name]].mean(axis=0) 392 | if np.linalg.norm(center_bulb - center_holder) <= 0.1: 393 | result = [item['key_id'], item['name2ids'][bulb_name]] 394 | else: # bulb0/1 395 | holder_name = item['key_name'].replace('bulb', 'bulb_holder') 396 | center_bulb = item['pcd'][item['mask'] == item['key_id']].mean(axis=0) 397 | center_holder = item['pcd'][item['mask'] == item['name2ids'][holder_name]].mean(axis=0) 398 | if np.linalg.norm(center_bulb - center_holder) <= 0.1: 399 | result = [item['key_id'], item['name2ids'][holder_name]] 400 | else: 401 | center_screw = item['pcd'][item['mask'] == item['name2ids']['lamp_screw']].mean(axis=0)[:2] 402 | if np.linalg.norm(center_bulb[:2] - center_screw) <= 0.03: 403 | result = [item['key_id']] + find_ids(['lamp_base', 'lamp_screw']) 404 | # for k in list(result): 405 | # name = item['id2names'][k] 406 | # if 'bulb' in name and 'holder' not in name: 407 | # if ('light_' + name) in item['name2ids']: 408 | # result.append(item['name2ids']['light_' + name]) 409 | elif task == 'place_cups': 410 | if item['grasp_id'] != -1: 411 | if 'mug' in item['key_name'] and item['key_name'] != item['grasp_name']: 412 | result = [item['grasp_id']] 413 | 414 | if 'poke' in item['key_name'] or item['key_name'] == 'mug_visual3': 415 | poke_name = 'place_cups_holder_spoke' + item['grasp_name'][-1] 416 | cup_tree = ['place_cups_holder_base'] + [f'place_cups_holder_spoke{j}' for j in range(3)] 417 | # center_poke = item['pcd'][item['mask'] == item['name2ids'][poke_name]].mean(axis=0) 418 | # center_grasp = item['pcd'][item['mask'] == item['grasp_id']].mean(axis=0) 419 | # if np.linalg.norm(center_poke - center_grasp) < 0.12: 420 | # result = find_ids([item['grasp_name']] + cup_tree) 421 | # else: 422 | result = find_ids(cup_tree) 423 | elif task == "insert_onto_square_peg": 424 | pass 425 | # if "pillar" in item['key_name']: 426 | # center_pillar = item['pcd'][item['mask'] == item['key_id']].mean(axis=0)[:2] 427 | # center_ring = item['pcd'][item['mask'] == item['name2ids']['square_ring']].mean(axis=0)[:2] 428 | # if np.linalg.norm(center_ring - center_pillar) <= 0.03: 429 | # result = [item['key_id'], item['name2ids']['square_ring']] 430 | elif task == "stack_cups": 431 | if item['grasp_id'] != -1: 432 | if item['key_name'] in ['cup1_visual', 'cup2_visual']: 433 | center_cup1 = item['pcd'][item['mask'] == item['name2ids']['cup1_visual']].mean(axis=0) 434 | center_cup2 = item['pcd'][item['mask'] == item['name2ids']['cup2_visual']].mean(axis=0) 435 | if np.linalg.norm(center_cup1 - center_cup2) <= 0.05: 436 | result = find_ids(['cup1_visual', 'cup2_visual']) 437 | # center_grasp = item['pcd'][item['mask'] == item['grasp_id']].mean(axis=0) 438 | # if center_grasp[-1] > center_cup2[-1] and np.linalg.norm(center_cup2[:2] - center_grasp[:2]) <= 0.03: 439 | # if item['grasp_id'] not in result: 440 | # result.append(item['grasp_id']) 441 | else: 442 | raise KeyError() 443 | 444 | return result 445 | 446 | 447 | def get_color_position_mask(task, desc, id2names, rgb, mask, target=None, distractors=None): 448 | position_mask = np.full([len(mask)], fill_value=-1, dtype=np.float32) 449 | tags = parse_instructions(task, desc, color_only=True) # [('red', 8), ('white', 20)] 450 | name2ids = {v:k for k,v in id2names.items()} 451 | if len(tags) > 0: 452 | if task == 'push_buttons': 453 | targets = sorted([k for k in id2names.values() if 'push_buttons_target' in k]) 454 | for i, tname in enumerate(targets): 455 | if i <= len(tags) - 1: 456 | position_mask[mask == name2ids[tname]] = tags[i][1] 457 | elif task == 'stack_blocks': 458 | for mask_id, name in id2names.items(): 459 | if 'target' in name and 'target_plane' not in name: 460 | position_mask[mask == mask_id] = tags[0][1] 461 | elif task == 'reach_and_drag': 462 | for mask_id, name in id2names.items(): 463 | if 'target' in name: 464 | position_mask[mask == mask_id] = tags[0][1] 465 | elif task == 'stack_cups': 466 | for mask_id, name in id2names.items(): 467 | if 'cup2' in name: 468 | position_mask[mask == mask_id] = tags[0][1] 469 | elif task == 'block_pyramid': 470 | for mask_id, name in id2names.items(): 471 | if 'block_pyramid_block' in name: 472 | position_mask[mask == mask_id] = tags[0][1] 473 | elif task in ['close_jar', 'insert_onto_square_peg', 'light_bulb_in']: 474 | position_mask[mask == name2ids[target]] = tags[0][1] 475 | elif task == 'slide_block_to_color_target': 476 | targets = ['target1', 'target2', 'target3', 'target4'] 477 | colors = ['green', 'blue', 'pink', 'yellow'] 478 | mask_name = dict(zip(colors, targets))[tags[0][0]] 479 | position_mask[mask == name2ids[mask_name]] = tags[0][1] 480 | else: 481 | raise KeyError() 482 | 483 | return position_mask -------------------------------------------------------------------------------- /network.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from collections import defaultdict 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from copy import deepcopy as dc 7 | import geometry_lib as glib 8 | from utils.object import Section 9 | from torch_geometric.nn.pool import global_mean_pool 10 | 11 | 12 | class InvariantRegionNetwork(nn.Module): 13 | def __init__(self, reason_depth=4): 14 | super().__init__() 15 | self.backbone = glib.PointTransformerNetwork(grid_sizes=(0.015, 0.03), 16 | depths=(2, 4, 2), 17 | dec_depths=None, 18 | hidden_dims=(64, 128, 256), 19 | n_heads=(4, 8, 8), ks=(16, 24, 32), in_dim=14, skip_dec=True) 20 | self.reason_stages = nn.ModuleList() 21 | layers = ['cross(a, b, b, a)', 'self(a, b)', 'cross(a, b, b, a)'] 22 | self.key_linears = nn.ModuleList() 23 | for i in range(reason_depth): 24 | _layers = dc(layers) 25 | if i == reason_depth - 1: 26 | _layers[-1] = 'cross(a, b)' 27 | 28 | block = glib.KnnTransformerNetwork(_layers, glib.make_knn_transformer_layers(_layers, 256, 8)) 29 | self.reason_stages.append(block) 30 | self.key_linears.append(nn.Linear(256, 1)) 31 | 32 | self.temperature = 0.25 33 | self.k = 16 34 | 35 | 36 | def forward(self, batch, debug=False): 37 | bsize = len(batch['src']['t']['X_to_robot_frame']) 38 | dev = batch['src']['t']['pcd'].device 39 | output, loss_dict, metric_dict = defaultdict(list), {}, {} 40 | t1, t2 = batch['src']['t'], batch['src']['t+1'] 41 | 42 | for item in [t1, t2]: 43 | item['offset'] = glib.batch2offset(item['batch_index']) 44 | item['robot_pcd'] = glib.batch_X_transform_flat(item['pcd'], item['batch_index'], item['X_to_robot_frame']) 45 | 46 | item['feat'] = torch.cat([item[k] for k in ['pcd', 'rgb', 'normal', 'robot_pcd']], dim=1) 47 | item['open'], item['ignore_col'] = item['open'].reshape(-1, 1).float(), item['ignore_col'].reshape(-1, 1).float() 48 | item['feat'] = torch.cat([item['feat'], 49 | glib.expand(item['open'], item['batch_index']), 50 | glib.expand(item['ignore_col'], item['batch_index'])], dim=1) 51 | 52 | for item in [t1, t2]: 53 | item['coarse_pcd'], item['coarse_feat'], item['coarse_offset'] = self.backbone([item['pcd'], item['feat'], item['offset']]) 54 | 55 | if self.training or debug: 56 | t1['coarse_key_mask'] = self.get_coarse_mask(t1['pcd'], t1['offset'], t1['coarse_pcd'], t1['coarse_offset'], 57 | key_mask=t1['key_mask'])['key_mask'] 58 | 59 | knn_indexes = { 60 | 'a2a': glib.knn(t1['coarse_pcd'], t1['coarse_pcd'], self.k, query_offset=t1['coarse_offset'])[0], 61 | 'b2b': glib.knn(t2['coarse_pcd'], t2['coarse_pcd'], self.k, query_offset=t2['coarse_offset'])[0], 62 | 'a2b': glib.knn(t1['coarse_pcd'], t2['coarse_pcd'], self.k, query_offset=t1['coarse_offset'], base_offset=t2['coarse_offset'])[0], 63 | 'b2a': glib.knn(t2['coarse_pcd'], t1['coarse_pcd'], self.k, query_offset=t2['coarse_offset'], base_offset=t1['coarse_offset'])[0] 64 | } 65 | 66 | for i, stage in enumerate(self.reason_stages): 67 | tmp = stage(feat={'a': t1['coarse_feat'], 'b': t2['coarse_feat']}, 68 | coord={'a': t1['coarse_pcd'], 'b': t2['coarse_pcd']}, 69 | knn_indexes=knn_indexes) 70 | t1['coarse_feat'], t2['coarse_feat'] = tmp['a'], tmp['b'] 71 | logits = self.key_linears[i](t1['coarse_feat']) / self.temperature 72 | output['coarse_prob_map'].append(logits.sigmoid()) 73 | 74 | if self.training or debug: 75 | _loss_dict, _metric_dict = self.get_loss(logits.squeeze(-1), t1['coarse_key_mask'], t1['coarse_offset']) 76 | loss_dict.update({f'{k}{i}': v for k, v in _loss_dict.items()}) 77 | metric_dict.update({f'{k}{i}': v for k, v in _metric_dict.items()}) 78 | 79 | if not self.training: 80 | prob_map = self.to_fine_map(t1['pcd'], t1['offset'], t1['coarse_pcd'], t1['coarse_offset'], 81 | coarse_prob_map=output['coarse_prob_map'][-1])['coarse_prob_map'] 82 | output['prob_map'] = prob_map 83 | 84 | return {'output': dict(output), 85 | 'loss_dict': loss_dict, 'metric_dict': metric_dict} 86 | 87 | def compute_iou(self, inputs, gt_mask, batch_mask): 88 | pred_mask = inputs.sigmoid() > 0.5 89 | pred_mask *= batch_mask 90 | iou = (2 * (pred_mask * gt_mask).sum(1)) / (pred_mask.sum(1) + gt_mask.sum(1) + 1e-3) 91 | return iou.mean() 92 | 93 | 94 | def compute_focal_loss(self, inputs, targets, gamma=2.0, input_sigmoid=False): 95 | if input_sigmoid: 96 | p = inputs 97 | else: 98 | p = F.sigmoid(inputs) 99 | p = torch.clamp(p, 1e-7, 1-1e-7) 100 | p_t = p * targets + (1 - p) * (1 - targets) 101 | 102 | if input_sigmoid: 103 | ce_loss = F.binary_cross_entropy(inputs, targets, reduction="none") 104 | else: 105 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 106 | loss = ce_loss * ((1 - p_t) ** gamma) 107 | return loss 108 | 109 | def compute_dice_loss( 110 | self, 111 | inputs: torch.Tensor, 112 | targets: torch.Tensor, 113 | batch_mask, 114 | input_sigmoid=False 115 | ): 116 | if not input_sigmoid: inputs = inputs.sigmoid() 117 | inputs = inputs * batch_mask 118 | numerator = 2 * (inputs * targets).sum(-1) 119 | denominator = inputs.sum(-1) + targets.sum(-1) 120 | loss = 1 - (numerator + 1) / (denominator + 1) 121 | return loss.mean() 122 | 123 | def get_coarse_mask(self, fine_pcd, fine_offset, coarse_pcd, coarse_offset, **mask_dict): 124 | indexes, _ = glib.knn(coarse_pcd, fine_pcd, k=1, query_offset=coarse_offset, base_offset=fine_offset) 125 | indexes = indexes.flatten() 126 | result = {} 127 | for k, mask in mask_dict.items(): 128 | result[k] = mask[indexes] 129 | return result 130 | 131 | 132 | def to_fine_map(self, fine_pcd, fine_offset, coarse_pcd, coarse_offset, **map_dict): 133 | indexes, _ = glib.knn(fine_pcd, coarse_pcd, k=1, query_offset=fine_offset, base_offset=coarse_offset) 134 | indexes = indexes.flatten() 135 | result = {} 136 | for k, v in map_dict.items(): 137 | result[k] = v[indexes] 138 | return result 139 | 140 | def get_loss(self, logits, label_mask, offset, want_iou=True, input_sigmoid=False): 141 | batch_logits, batch_mask = glib.to_dense_batch(logits, offset, input_offset=True) 142 | batch_label_mask, _ = glib.to_dense_batch(label_mask, offset, input_offset=True) 143 | 144 | dice_loss = self.compute_dice_loss(batch_logits, batch_label_mask, batch_mask, input_sigmoid=input_sigmoid) 145 | focal_loss = self.compute_focal_loss(batch_logits, batch_label_mask.float(), input_sigmoid=input_sigmoid) # (B, L) 146 | focal_loss = focal_loss * batch_mask 147 | focal_loss = focal_loss.sum(dim=1) / batch_mask.sum(dim=1) 148 | 149 | metric_dict = {} 150 | if want_iou: 151 | metric_dict['iou'] = self.compute_iou(batch_logits, batch_label_mask, batch_mask) 152 | 153 | return {'focal_loss': focal_loss.mean(), 'dice_loss': dice_loss}, metric_dict 154 | 155 | 156 | 157 | class RegionMatchingNetwork(nn.Module): 158 | 159 | def __init__(self, k=16, 160 | in_dim=14, 161 | hidden_dim=128, n_heads=4, matching_temperature=1.0, max_condition_num=-1, 162 | focal_gamma=2.0, #stage1_query_layers=DEFAULT_STAGE1_QUERY_LAYERS, 163 | 164 | stage1_grid_sizes=(0.015, 0.03), 165 | stage1_depths=(2, 3, 2), 166 | stage1_dec_depths=(1, 1), 167 | stage1_hidden_dims=(256, 384), 168 | 169 | stage2_layers=None, **kwargs): 170 | super().__init__() 171 | stage2_layers = stage2_layers or [ 172 | "positioning(src,tgt)", 173 | "cross(src,tgt,tgt,src)", 174 | "self(src,tgt)", 175 | "cross(src,tgt,tgt,src)", 176 | "positioning(src,tgt)", 177 | "cross(src,tgt,tgt,src)", 178 | "self(src,tgt)", 179 | "cross(src,tgt,tgt,src)", 180 | "positioning(src,tgt):no_emb", 181 | ] 182 | match_block = glib.DualSoftmaxReposition(hidden_dim, matching_temperature, max_condition_num=max_condition_num, 183 | focal_gamma=focal_gamma, one_way=True) 184 | self.in_dim = in_dim 185 | self.k = k 186 | self.hidden_dim = hidden_dim 187 | self.n_heads = n_heads 188 | self.reposition = glib.DualSoftmaxReposition(hidden_dim, matching_temperature, max_condition_num=max_condition_num, 189 | focal_gamma=focal_gamma, use_projection=False, one_way=True) 190 | 191 | # self.stage1 = glib.KnnTransformerNetwork(stage1_layers, glib.make_knn_transformer_layers(stage1_layers, hidden_dim, n_heads)) 192 | self.stage1 = glib.PointTransformerNetwork(grid_sizes=stage1_grid_sizes, 193 | depths=stage1_depths, 194 | dec_depths=stage1_dec_depths, 195 | hidden_dims=(hidden_dim,) + stage1_hidden_dims, 196 | n_heads=(4, 8, 8), ks=(16, 24, 32), in_dim=in_dim) 197 | 198 | def make_position_layer(typ): 199 | if 'no_emb' in typ: 200 | return dc(match_block) 201 | else: 202 | return (dc(match_block), nn.Linear(3, hidden_dim)) 203 | 204 | def clean_names(layers): 205 | return [l.split(":")[0] for l in layers] 206 | 207 | self.stage2 = glib.KnnTransformerNetwork(clean_names(stage2_layers), [ 208 | make_position_layer(l) if 'positioning' in l else glib.make_knn_transformer_one_layer(l, hidden_dim, n_heads) 209 | for l in stage2_layers if "embedding" not in l 210 | ]) 211 | 212 | self.sequence_embed = nn.Linear(hidden_dim, hidden_dim) 213 | 214 | 215 | def forward(self, batch): 216 | meta_data = batch['meta'] 217 | bsize = len(meta_data['ko_correspondence']) 218 | dev = batch['src']['t']['pcd'].device 219 | output, loss_dict, metric_dict = {}, {}, {} 220 | 221 | src, tgt = frame_data = batch['src']['t'], batch['tgt']['t'] 222 | 223 | with Section("Initial Data Preparation"): 224 | for item in [src, tgt]: 225 | item['offset'] = glib.batch2offset(item['batch_index']) 226 | item['robot_pcd'] = glib.batch_X_transform_flat(item['pcd'], item['batch_index'], item['X_to_robot_frame']) 227 | 228 | item['feat'] = torch.cat([item[k] for k in ['pcd', 'rgb', 'normal', 'robot_pcd']], dim=1) 229 | item['open'], item['ignore_col'] = item['open'].reshape(-1, 1).float(), item['ignore_col'].reshape(-1, 1).float() 230 | item['feat'] = torch.cat([item['feat'], 231 | glib.expand(item['open'], item['batch_index']), 232 | glib.expand(item['ignore_col'], item['batch_index'])], dim=1) 233 | 234 | # item['feat'] = self.input_embed(item['feat']) 235 | item['knn'] = glib.knn(item['pcd'], item['pcd'], self.k, query_offset=item['offset'])[0] 236 | 237 | with Section("Stage 1. base"): 238 | for item in [src, tgt]: 239 | # item['feat'] = self.stage1(feat={'p': item['feat']}, coord={'p': item['pcd']}, knn_indexes={'p2p': item['knn']})['p'] 240 | item['feat'] = self.stage1([item['pcd'], item['feat'], item['offset']])[1] 241 | 242 | 243 | with Section("Creating key feature"): 244 | for name, mask in [('ko', src['key_mask']),]: # ('ctx', ~src['key_mask']) 245 | src[name + '_batch_index'] = src['batch_index'][mask] 246 | src[name + '_offset'] = glib.batch2offset(src[name + '_batch_index']) 247 | src[name + '_pcd'] = src['pcd'][mask] 248 | src[name + '_feat'] = src['feat'][mask] 249 | 250 | src['ko2ko'], _ = glib.knn(src['ko_pcd'], src['ko_pcd'], self.k, query_offset=src['ko_offset']) 251 | 252 | with Section("add instruction (optional)"): 253 | for k, item in [('src', src), ('tgt', tgt)]: 254 | position_mask = item['noisy_position_mask' if self.training and ('noisy_position_mask' in item) else 'position_mask'] 255 | seq_triangular = glib.distance_embed(position_mask[:, None], scale=1., num_pos_feats=self.hidden_dim) 256 | seq_embed = self.sequence_embed(seq_triangular[:, 0, :]) 257 | seq_embed = (position_mask != -1)[:, None] * seq_embed 258 | if k == 'src': 259 | item['ko_feat'] = item['ko_feat'] + seq_embed[item['key_mask']] 260 | else: 261 | item['feat'] = item['feat'] + seq_embed 262 | 263 | src2tgt_kindexes, _ = glib.knn(src['ko_pcd'], tgt['pcd'], self.k, query_offset=src['ko_offset'], base_offset=tgt['offset']) 264 | tgt2src_kindexes, _ = glib.knn(tgt['pcd'], src['ko_pcd'], self.k, query_offset=tgt['offset'], base_offset=src['ko_offset']) 265 | 266 | with Section("Stage 2. registration"): 267 | coord, feat, knn_indexes, position_outputs = self.stage2(feat={'src': src['ko_feat'], 'tgt':tgt['feat']}, 268 | coord={'src': src['ko_pcd'], 'tgt': tgt['pcd']}, 269 | batch_index={'src': src['ko_batch_index'], 'tgt': tgt['batch_index']}, 270 | knn_indexes={'src2src': src['ko2ko'], 'tgt2tgt': tgt['knn'], 271 | 'src2tgt': src2tgt_kindexes, 'tgt2src': tgt2src_kindexes}) 272 | 273 | src_tp1_position = src['robot_position_t+1'] 274 | conf_matrix = position_outputs[-1]['conf_matrix'].detach() 275 | R0, t0, cond = self.reposition.arun(conf_matrix, src['ko_pcd'], src['ko_batch_index'], tgt['pcd'], tgt['batch_index']) 276 | tgt_tp1_position_hat = glib.batch_Rt_transform(src_tp1_position, R0, t0) 277 | output['Rt'] = [R0, t0] 278 | output['predict_frame'] = tgt_tp1_position_hat 279 | 280 | if self.training: 281 | conf_matrix = position_outputs[-1]['conf_matrix'] 282 | correspondence = [] 283 | for m1, m2 in zip(meta_data['correspondence'], meta_data['ko_correspondence']): 284 | correspondence.append(torch.cat([m2[:, 0].reshape(-1, 1), m1[:, 1].reshape(-1, 1)], dim=1)) 285 | 286 | gt_matrix = self.reposition.to_gt_correspondence_matrix(conf_matrix, correspondence) 287 | for i, out in enumerate(position_outputs): 288 | corr_loss = self.reposition.compute_matching_loss(out['conf_matrix'], gt_matrix=gt_matrix) 289 | loss_dict[f'position_corr_loss_{i}'] = corr_loss 290 | 291 | if tgt.get('robot_position_t+1', None) is not None: 292 | reg_action_l1dist = torch.abs(tgt_tp1_position_hat - tgt['robot_position_t+1']).sum(dim=-1) 293 | metric_dict['action(reg)_l1_t'] = reg_action_l1dist[:, 0].mean() 294 | metric_dict['action(reg)_l1_xyz'] = reg_action_l1dist[:, 1:].mean() 295 | 296 | output['conf_matrix'] = conf_matrix 297 | return {'output': output, 298 | 'loss_dict': loss_dict, 'metric_dict': metric_dict} 299 | 300 | 301 | 302 | class RegionMatchingNetwork_fine(nn.Module): 303 | def __init__(self, k=16, 304 | in_dim=14, 305 | hidden_dim=128, n_heads=4, matching_temperature=1.0, max_condition_num=-1, 306 | focal_gamma=2.0, 307 | 308 | stage1_grid_sizes=(0.015, 0.03), 309 | stage1_depths=(2, 3, 2), 310 | stage1_dec_depths=(1, 1), 311 | stage1_hidden_dims=(256, 384), 312 | 313 | stage2_layers=None): 314 | super().__init__() 315 | stage2_layers = [ 316 | "positioning(src,tgt)", 317 | "cross(src,tgt,tgt,src)", 318 | "self(src,tgt)", 319 | "cross(src,tgt,tgt,src)", 320 | "positioning(src,tgt)", 321 | "cross(src,tgt,tgt,src)", 322 | "self(src,tgt)", 323 | "cross(src,tgt,tgt,src)", 324 | "positioning(src,tgt):no_emb", 325 | ] 326 | 327 | match_block = glib.DualSoftmaxReposition(hidden_dim, matching_temperature, max_condition_num=max_condition_num, 328 | focal_gamma=focal_gamma) 329 | self.in_dim = in_dim 330 | self.k = k 331 | self.hidden_dim = hidden_dim 332 | self.n_heads = n_heads 333 | self.reposition = glib.DualSoftmaxReposition(hidden_dim, matching_temperature, max_condition_num=max_condition_num, 334 | focal_gamma=focal_gamma, use_projection=False) 335 | 336 | self.stage1 = glib.PointTransformerNetwork(grid_sizes=stage1_grid_sizes, 337 | depths=stage1_depths, 338 | dec_depths=stage1_dec_depths, 339 | hidden_dims=(hidden_dim,) + stage1_hidden_dims, 340 | n_heads=(4, 8, 8), ks=(16, 24, 32), in_dim=in_dim) 341 | 342 | def make_position_layer(typ): 343 | if 'no_emb' in typ: 344 | return dc(match_block) 345 | else: 346 | return (dc(match_block), nn.Linear(3, hidden_dim)) 347 | 348 | def clean_names(layers): 349 | return [l.split(":")[0] for l in layers] 350 | 351 | self.stage2 = glib.KnnTransformerNetwork(clean_names(stage2_layers), [ 352 | make_position_layer(l) if 'positioning' in l else glib.make_knn_transformer_one_layer(l, hidden_dim, n_heads) 353 | for l in stage2_layers if "embedding" not in l 354 | ]) 355 | 356 | def forward(self, batch): 357 | meta_data = batch['meta'] 358 | bsize = len(meta_data['ko_correspondence']) 359 | dev = batch['src']['t']['pcd'].device 360 | output, loss_dict, metric_dict = {}, {}, {} 361 | src, tgt = frame_data = batch['src']['t'], batch['tgt']['t'] 362 | 363 | for item in [src, tgt]: 364 | item['offset'] = glib.batch2offset(item['batch_index']) 365 | item['robot_pcd'] = glib.batch_X_transform_flat(item['pcd'], item['batch_index'], item['X_to_robot_frame']) 366 | 367 | item['feat'] = torch.cat([item[k] for k in ['pcd', 'rgb', 'normal', 'robot_pcd']], dim=1) 368 | item['open'], item['ignore_col'] = item['open'].reshape(-1, 1).float(), item['ignore_col'].reshape(-1, 1).float() 369 | item['feat'] = torch.cat([item['feat'], 370 | glib.expand(item['open'], item['batch_index']), 371 | glib.expand(item['ignore_col'], item['batch_index'])], dim=1) 372 | 373 | for item in [src, tgt]: 374 | item['feat'] = self.stage1([item['pcd'], item['feat'], item['offset']])[1] 375 | 376 | for item in [src, tgt]: 377 | item['key_batch_index'] = item['batch_index'][item['key_mask']] 378 | item['key_offset'] = glib.batch2offset(item['key_batch_index']) 379 | item['key_pcd'] = item['pcd'][item['key_mask']] 380 | item['key_pcd(origin)'] = item['key_pcd'].clone() 381 | item['key_feat'] = item['feat'][item['key_mask']] 382 | 383 | item['key_center'] = global_mean_pool(item['key_pcd'], item['key_batch_index']) 384 | item['key_pcd'] -= glib.expand(item['key_center'], item['key_batch_index']) 385 | 386 | item['robot_position(origin)'] = item['robot_position'].clone() 387 | item['robot_position'] -= item['key_center'][:, None, :] 388 | if 'robot_position_t+1' in item: 389 | item['robot_position_t+1(origin)'] = item['robot_position_t+1'].clone() 390 | item['robot_position_t+1'] -= item['key_center'][:, None, :] 391 | 392 | 393 | knn_indexes = dict(src2tgt=glib.knn(src['key_pcd'], tgt['key_pcd'], self.k, query_offset=src['key_offset'], base_offset=tgt['key_offset'])[0], 394 | tgt2src=glib.knn(tgt['key_pcd'], src['key_pcd'], self.k, query_offset=tgt['key_offset'], base_offset=src['key_offset'])[0], 395 | src2src=glib.knn(src['key_pcd'], src['key_pcd'], self.k, query_offset=src['key_offset'])[0], 396 | tgt2tgt=glib.knn(tgt['key_pcd'], tgt['key_pcd'], self.k, query_offset=tgt['key_offset'])[0]) 397 | 398 | 399 | coord, feat, knn_indexes, position_outputs = self.stage2(feat={'src': src['key_feat'], 'tgt':tgt['key_feat']}, 400 | coord={'src': src['key_pcd'], 'tgt': tgt['key_pcd']}, 401 | batch_index={'src': src['key_batch_index'], 'tgt': tgt['key_batch_index']}, 402 | knn_indexes=knn_indexes) 403 | 404 | src_tp1_position = src['robot_position_t+1'] 405 | conf_matrix = position_outputs[-1]['conf_matrix'].detach() 406 | R0, t0, cond = self.reposition.arun(conf_matrix, src['key_pcd'], src['key_batch_index'], tgt['key_pcd'], tgt['key_batch_index']) 407 | tgt_tp1_position_hat = glib.batch_Rt_transform(src_tp1_position, R0, t0) 408 | output['transformation'] = self.reposition.arun(conf_matrix, src['key_pcd(origin)'], src['key_batch_index'], 409 | tgt['key_pcd(origin)'], tgt['key_batch_index']) 410 | tgt_tp1_position_hat_origin = glib.batch_Rt_transform(src['robot_position_t+1(origin)'], *output['transformation'][:2]) 411 | output['predict_frame'] = tgt_tp1_position_hat_origin 412 | output['conf_matrix'] = position_outputs[-1]['conf_matrix'] 413 | if self.training: 414 | conf_matrix = position_outputs[-1]['conf_matrix'] 415 | gt_matrix = self.reposition.to_gt_correspondence_matrix(conf_matrix, meta_data['ko_correspondence']) 416 | for i, out in enumerate(position_outputs): 417 | corr_loss = self.reposition.compute_matching_loss(out['conf_matrix'], gt_matrix=gt_matrix) 418 | loss_dict[f'position_corr_loss_{i}'] = corr_loss 419 | 420 | if tgt.get('robot_position_t+1', None) is not None: 421 | reg_action_l1dist = torch.abs(tgt_tp1_position_hat - tgt['robot_position_t+1']).sum(dim=-1) 422 | metric_dict['action(reg)_l1_t'] = reg_action_l1dist[:, 0].mean() 423 | metric_dict['action(reg)_l1_xyz'] = reg_action_l1dist[:, 1:].mean() 424 | 425 | return {'output': output, 426 | 'loss_dict': loss_dict, 'metric_dict': metric_dict} -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | hydra-core 2 | einops 3 | clip @ git+https://github.com/openai/CLIP.git 4 | scipy 5 | cloudpickle 6 | runstats 7 | ipyvolume 8 | scikit-learn 9 | fire 10 | faiss-cpu 11 | bidict 12 | textdistance 13 | hdbscan 14 | termcolor 15 | torch==1.13.0 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import multiprocessing 3 | import traceback 4 | from termcolor import colored 5 | from tqdm import tqdm 6 | import os.path as osp 7 | from copy import copy 8 | from omegaconf import OmegaConf 9 | from utils import configurable, DictConfig, config_to_dict 10 | from utils.structure import load_pkl 11 | import torch.multiprocessing as mp 12 | from utils.dist import find_free_port 13 | from utils.ckpt import remove_dict_prefix, compute_grad_norm 14 | from time import time 15 | from torch.optim.lr_scheduler import CosineAnnealingLR 16 | from utils.optim import GradualWarmupScheduler 17 | from torch.nn.parallel import DistributedDataParallel 18 | from runstats import Statistics 19 | import torch.distributed as dist 20 | from torch.utils.data import DataLoader 21 | import os 22 | from torch.utils.tensorboard import SummaryWriter 23 | from utils.object import flat2d, to_item, color_terms, detach 24 | 25 | from data import RLBenchDataset, RLBenchTransitionPairDataset, RLBenchCollator, to_device 26 | from network import InvariantRegionNetwork, RegionMatchingNetwork, RegionMatchingNetwork_fine 27 | 28 | 29 | def main_single(rank: int, cfg: DictConfig, port: int, log_dir:str): 30 | world_size = cfg.train.num_gpus 31 | if world_size == 0: world_size = 1 32 | ddp, on_master = world_size > 1, rank == 0 33 | print(f'Rank - {rank}, master = {on_master}') 34 | if ddp: 35 | os.environ["MASTER_ADDR"] = "localhost" 36 | os.environ["MASTER_PORT"] = str(port) 37 | dist.init_process_group("nccl", rank=rank, world_size=world_size) 38 | device = rank 39 | if cfg.train.num_gpus == 0: device = 'cpu' 40 | else: 41 | torch.cuda.set_device(device) 42 | torch.cuda.empty_cache() 43 | 44 | if on_master: 45 | logfile = open(osp.join(log_dir, 'log.txt'), "w") 46 | 47 | def log(msg, printer=print): 48 | if on_master: 49 | print(msg, file=logfile, flush=False) 50 | printer(msg) 51 | 52 | log(f"「 {cfg.notes} 」") 53 | 54 | if cfg.train.tensorboard and rank == 0: 55 | writer = SummaryWriter(log_dir=osp.join(log_dir, 'tensorboard'), max_queue=10000) 56 | if cfg.train.wandb and rank == 0: 57 | import wandb 58 | wandb.init(project="imop", config=config_to_dict(cfg), 59 | notes=cfg.notes, 60 | name=f'{osp.basename(osp.dirname(log_dir))}_{osp.basename(log_dir)}') 61 | 62 | def log_metrics(stats): 63 | if cfg.train.tensorboard and rank == 0: 64 | for k, v in stats.items(): writer.add_scalar(k, v, i) 65 | if cfg.train.wandb and rank == 0: 66 | wandb.log(stats) 67 | 68 | lr = cfg.train.lr * (world_size * cfg.train.bs) 69 | cos_dec_max_step = cfg.train.epochs * cfg.train.num_transitions_per_epoch // cfg.train.bs 70 | log(f'cosine learning rate - max steps {cos_dec_max_step}') 71 | 72 | collate_fn = RLBenchCollator(use_segmap=False) 73 | model_kwargs = config_to_dict(cfg.model) 74 | model_type = model_kwargs.pop('type') 75 | 76 | if model_type == 'invariant_region': 77 | model = InvariantRegionNetwork(**model_kwargs) 78 | elif model_type == 'region_match': 79 | model = RegionMatchingNetwork(**model_kwargs) 80 | elif model_type == 'region_match_fine': 81 | model = RegionMatchingNetwork_fine(**model_kwargs) 82 | else: 83 | raise KeyError(model_type) 84 | 85 | if cfg.train.checkpoint: 86 | log(f"loading checkpoint from {cfg.train.checkpoint}") 87 | load_result = model.load_state_dict(remove_dict_prefix(torch.load(cfg.train.checkpoint), prefix="module."), strict=False) 88 | log(f"load result: {load_result}") 89 | 90 | model = model.train().to(device) 91 | if ddp: 92 | model = DistributedDataParallel(model, device_ids=[device]) 93 | 94 | optimizer = torch.optim.AdamW(model.parameters(), lr=lr) 95 | 96 | after_scheduler = CosineAnnealingLR(optimizer, T_max=cos_dec_max_step, eta_min=lr / 100) 97 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=cfg.train.warmup_steps, after_scheduler=after_scheduler) 98 | 99 | start_step = 0 100 | total_sample_num = cfg.train.num_transitions_per_epoch * cfg.train.epochs 101 | total_sample_num -= (start_step * world_size * cfg.train.bs) 102 | 103 | db = RLBenchDataset(grid_size=cfg.data.grid_size, cache_to=cfg.data.db_cache, path=cfg.data.db_path, cache_mode='read', 104 | color_only_instructions=cfg.data.color_only_instructions, 105 | min_max_pts_per_obj=getattr(cfg.data, 'max_pts', 5000)) # just use the default parameters 106 | 107 | assert osp.exists(cfg.data.pairs_cache) 108 | pair_db = RLBenchTransitionPairDataset(db, cache_to=cfg.data.pairs_cache, size=total_sample_num, use_aug=cfg.data.aug, 109 | correspondence=cfg.data.correspondence, 110 | align_twice=cfg.data.align_twice, include_T=cfg.data.include_T, noisy_mask=cfg.data.noisy_mask) 111 | dataloader = DataLoader(pair_db, 112 | batch_size=cfg.train.bs, 113 | shuffle=True, 114 | pin_memory=True, 115 | num_workers=cfg.train.num_workers, 116 | multiprocessing_context=multiprocessing.get_context("spawn"), 117 | collate_fn=collate_fn, 118 | drop_last=False) 119 | 120 | start = time() 121 | run_stats = {} 122 | total_steps = len(dataloader) 123 | 124 | for i, batch in enumerate(tqdm(dataloader, disable=rank != 0)): 125 | batch = {k: to_device(v, device) for k,v in batch.items()} 126 | result = model(batch) 127 | loss_dict = result['loss_dict'] 128 | 129 | optimizer.zero_grad(set_to_none=True) 130 | if 'total' not in loss_dict: 131 | if hasattr(cfg.train, 'loss_weight'): 132 | loss_dict['total'] = 0 133 | for k, v in loss_dict.items(): 134 | weight = 1.0 135 | for loss_name, loss_weight in config_to_dict(cfg.train.loss_weight).items(): 136 | if loss_name in k: 137 | weight = loss_weight 138 | break 139 | loss_dict['total'] += (v * weight) 140 | else: 141 | loss_dict['total'] = sum(loss_dict.values()) 142 | loss_dict['total'].backward() 143 | 144 | overall_grad_norm = compute_grad_norm(model) 145 | if cfg.train.grad_clip_after >= 0 and i >= cfg.train.grad_clip_after: 146 | if i == cfg.train.grad_clip_after: log("Start gradient clipping") 147 | torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.train.grad_clip_value.overall) 148 | 149 | optimizer.step() 150 | scheduler.step() 151 | 152 | 153 | if rank == 0: 154 | loss_dict = {**loss_dict, **result.get('metric_dict', {})} 155 | 156 | for k, v in loss_dict.items(): 157 | if k not in run_stats: 158 | run_stats[k] = Statistics() 159 | 160 | stat_dict = copy(loss_dict) 161 | for k in run_stats: 162 | if k in loss_dict: 163 | run_stats[k].push(detach(loss_dict[k])) 164 | stat_dict[k] = run_stats[k].mean() 165 | 166 | loss_dict['lr'] = stat_dict['lr'] = scheduler.get_last_lr()[0] 167 | loss_dict['grad_norm'] = stat_dict['grad_norm'] = overall_grad_norm.detach() 168 | 169 | log_metrics(loss_dict) 170 | 171 | if i % cfg.train.log_freq == 0: 172 | msg = f"[step:{str(i + start_step).zfill(8)} time:{time()-start:.01f}s] " + " ".join([f"{k}:{to_item(v):.04f}" for k, v in sorted(stat_dict.items())]) 173 | log(msg, printer=tqdm.write) 174 | if i != 0 and (i % cfg.train.save_freq == 0 or i == total_steps - 1): 175 | log(f"checkpoint to {log_dir} at step {i + start_step} and reset running metrics", printer=tqdm.write) 176 | torch.save({'model':model.state_dict(), 'optimizer':optimizer.state_dict(), 177 | 'scheduler':scheduler, 'step': i}, f'{log_dir}/{str(i + start_step).zfill(8)}.pth') 178 | run_stats = {} 179 | 180 | 181 | @configurable() 182 | def main(cfg: DictConfig): 183 | if cfg.train.num_gpus <= 1: 184 | main_single(0, cfg, -1, cfg.output_dir) 185 | else: 186 | port = find_free_port() 187 | mp.spawn(main_single, args=(cfg, port, cfg.output_dir), nprocs=cfg.train.num_gpus, join=True) 188 | 189 | 190 | if __name__ == "__main__": 191 | main() -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | import hydra 2 | import os.path as osp 3 | import sys 4 | from omegaconf import DictConfig, OmegaConf 5 | 6 | def configurable(config_path="config/default.yaml"): 7 | def wrapper(main_func): 8 | config_path_arg = None 9 | i = 1 10 | for a in sys.argv[1:]: 11 | if a.startswith("config="): 12 | config_path_arg = a.split("=")[-1] 13 | sys.argv.pop(i) 14 | i += 1 15 | break 16 | if config_path_arg is None: 17 | config_path_arg = config_path 18 | assert config_path_arg, "config file must be given by `config=path/to/file`" 19 | main_wrapper = hydra.main(config_path=osp.abspath(osp.dirname(config_path_arg)), 20 | config_name=osp.splitext(osp.basename(config_path_arg))[0], 21 | version_base=None) 22 | return main_wrapper(main_func) 23 | return wrapper 24 | 25 | 26 | 27 | def load_hydra_config(config_path, overrides=[]): 28 | 29 | from hydra import compose, initialize 30 | from omegaconf import OmegaConf 31 | 32 | with initialize(version_base=None, config_path=osp.dirname(config_path), job_name="load_config"): 33 | cfg = compose(config_name=osp.splitext(osp.basename(config_path))[0], overrides=overrides) 34 | return cfg 35 | 36 | 37 | def config_to_dict(cfg): 38 | return OmegaConf.to_container(cfg) 39 | 40 | -------------------------------------------------------------------------------- /utils/ckpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def remove_dict_prefix(state_dict, prefix="module."): 5 | result = {} 6 | for k, v in state_dict.items(): 7 | if k.startswith(prefix): 8 | k = k[len(prefix):] 9 | result[k] = v 10 | return result 11 | 12 | 13 | def get_model(mod): 14 | if hasattr(mod, 'module'): 15 | return mod.module 16 | else: 17 | return mod 18 | 19 | 20 | def freeze_model(mod): 21 | mod.eval() 22 | for p in mod.parameters(): 23 | p.requires_grad = False 24 | 25 | 26 | def compute_grad_norm(model): 27 | if isinstance(model, nn.Module): 28 | grads = [ 29 | param.grad.detach().flatten() 30 | for param in model.parameters() 31 | if param.grad is not None 32 | ] 33 | else: 34 | grads = [param.grad.detach().flatten() for param in model if param.grad is not None] 35 | norm = torch.cat(grads).norm() 36 | return norm -------------------------------------------------------------------------------- /utils/clip.py: -------------------------------------------------------------------------------- 1 | import clip 2 | import torch 3 | 4 | 5 | # extract CLIP language features for goal string 6 | def clip_encode_text(clip_model, text): 7 | x = clip_model.token_embedding(text).type( 8 | clip_model.dtype 9 | ) # [batch_size, n_ctx, d_model] 10 | 11 | x = x + clip_model.positional_embedding.type(clip_model.dtype) 12 | x = x.permute(1, 0, 2) # NLD -> LND 13 | x = clip_model.transformer(x) 14 | x = x.permute(1, 0, 2) # LND -> NLD 15 | x = clip_model.ln_final(x).type(clip_model.dtype) 16 | 17 | emb = x.clone() 18 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ clip_model.text_projection 19 | 20 | return x, emb -------------------------------------------------------------------------------- /utils/color_remap.py: -------------------------------------------------------------------------------- 1 | import re 2 | import numpy as np 3 | from scipy import stats 4 | import itertools 5 | 6 | def filter_products(pds): 7 | out = set() 8 | for a in pds: 9 | if len(a) == len(set(a)): 10 | out.add(tuple(sorted(a))) 11 | return list(out) 12 | 13 | 14 | def generate_object_pairs(mask_ids, k): 15 | return filter_products(itertools.product([int(k) for k in mask_ids], repeat=k)) 16 | 17 | 18 | colors = dict([ 19 | ('red', (1.0, 0.0, 0.0)), 20 | ('maroon', (0.5, 0.0, 0.0)), 21 | ('lime', (0.0, 1.0, 0.0)), 22 | ('green', (0.0, 0.5, 0.0)), 23 | ('blue', (0.0, 0.0, 1.0)), 24 | ('navy', (0.0, 0.0, 0.5)), 25 | ('yellow', (1.0, 1.0, 0.0)), 26 | ('cyan', (0.0, 1.0, 1.0)), 27 | ('magenta', (1.0, 0.0, 1.0)), 28 | ('silver', (0.75, 0.75, 0.75)), 29 | ('gray', (0.5, 0.5, 0.5)), 30 | ('lightgray', (0.35, 0.35, 0.35)), 31 | ('orange', (1.0, 0.5, 0.0)), 32 | ('olive', (0.5, 0.5, 0.0)), 33 | ('purple', (0.5, 0.0, 0.5)), 34 | # ('pink', (0.95, 0.075, 0.54)), 35 | ('teal', (0, 0.5, 0.5)), 36 | ('azure', (0.0, 0.5, 1.0)), 37 | ('violet', (0.5, 0.0, 1.0)), 38 | ('rose', (1.0, 0.0, 0.5)), 39 | ('black', (0.0, 0.0, 0.0)), 40 | ('white', (1.0, 1.0, 1.0)), 41 | ]) 42 | 43 | # colors = {'blue': [0.211, 0.348, 0.699], 44 | # 'teal': [0.211, 0.606, 0.459], 45 | # 'lime': [0.21, 0.842, 0.203], 46 | # 'yellow': [0.704, 0.842, 0.202], 47 | # 'navy': [0.204, 0.34, 0.458], 48 | # 'orange': [0.701, 0.606, 0.204], 49 | # 'purple': [0.465, 0.362, 0.457], 50 | # 'white': [0.704, 0.844, 0.699], 51 | # 'green': [0.21, 0.608, 0.203], 52 | # 'azure': [0.21, 0.602, 0.693], 53 | # 'olive': [0.465, 0.602, 0.2], 54 | # 'rose': [0.702, 0.362, 0.461], 55 | # 'red': [0.709, 0.341, 0.199], 56 | # 'cyan': [0.214, 0.843, 0.694], 57 | # 'gray': [0.467, 0.606, 0.46], 58 | # 'silver': [0.601, 0.732, 0.595], 59 | # 'maroon': [0.471, 0.354, 0.206], 60 | # 'black': [0.208, 0.346, 0.201], 61 | # 'magenta': [0.701, 0.356, 0.692], 62 | # 'violet': [0.469, 0.347, 0.7]} 63 | 64 | 65 | # for k in list(colors.keys()): 66 | # v = colors[k] 67 | # colors[k] = (v[0] * 0.8, v[1] * 0.8, v[2] * 0.8) 68 | 69 | color_names = [] 70 | color_values = [] 71 | 72 | for c, cv in colors.items(): 73 | color_names.append(c) 74 | color_values.append(cv) 75 | 76 | color_names = np.array(color_names) 77 | color_values = np.array(color_values) 78 | 79 | 80 | def find_color_directive(desc): 81 | cs = [] 82 | for c in colors: 83 | c = f' {c} ' 84 | if c in desc: 85 | cs.append((desc.index(c), c.strip())) 86 | 87 | cs = sorted(cs) 88 | return [c for _, c in cs] 89 | 90 | 91 | def remap_colors(task, rgb, new_mask, id2names, desc): 92 | other_color = 'blue' 93 | name2ids = {v: k for k,v in id2names.items()} 94 | 95 | if task == 'slide_block_to_color_target': 96 | return rgb 97 | elif task == 'push_buttons': 98 | changes = [(e, ['red', 'green', 'blue'][i]) for i, e in enumerate(find_color_directive(desc))] 99 | targets = sorted([k for k in id2names.values() if 'push_buttons_target' in k]) 100 | for t, (_, cto) in zip(targets, changes): 101 | rgb[new_mask == name2ids[t]] = (np.array(colors[cto]) * 255).astype(int) 102 | return rgb 103 | elif task == 'stack_blocks': 104 | for mask_id, name in id2names.items(): 105 | if 'target' in name and 'target_plane' not in name: 106 | rgb[new_mask == mask_id] = np.array([255, 0, 0]) 107 | elif 'distractor' in name: 108 | rgb[new_mask == mask_id] = (np.array(colors[other_color]) * 255).astype(int) 109 | return rgb 110 | elif task == 'reach_and_drag': 111 | for mask_id, name in id2names.items(): 112 | if 'target' in name: 113 | rgb[new_mask == mask_id] = np.array([255, 0, 0]) 114 | elif 'distractor' in name: 115 | rgb[new_mask == mask_id] = (np.array(colors[other_color]) * 255).astype(int) 116 | return rgb 117 | else: 118 | targets = [] 119 | changes = {} 120 | for k in id2names.values(): 121 | if 'target' in k or \ 122 | 'distractor' in k or \ 123 | re.match(r'jar\d', k) or \ 124 | re.match(r'cup\d_visual', k) or \ 125 | 'bulb_holder' in k or \ 126 | 'pillar' in k: 127 | existing_colors = find_color_directive(desc) 128 | assert len(existing_colors) == 1, desc 129 | changes[existing_colors[0]] = 'red' 130 | targets.append(k) 131 | 132 | if len(targets) > 0: 133 | def norm(c): 134 | return c / np.linalg.norm(c) 135 | 136 | def find_by_color(cfrom): 137 | # could not handle black color! 138 | if cfrom != 'black': 139 | color_sims = [(np.dot( norm((rgb[new_mask == name2ids[t]] / 255.).mean(axis=0)), norm(np.array(colors[cfrom]))), t) for t in targets] 140 | color_sims = sorted(color_sims, reverse=True) 141 | remaining_colors = [color_sims[0][1]] + [c for v, c in color_sims[1:] if abs((color_sims[0][0] - v) / color_sims[0][0]) < 0.03] 142 | color_diffs = [(np.sum(np.abs((rgb[new_mask == name2ids[t]] / 255.).mean(axis=0) - np.array(colors[cfrom]))), t) for t in remaining_colors] 143 | t = min(color_diffs)[1] 144 | else: 145 | color_diffs = [(np.sum(np.abs((rgb[new_mask == name2ids[t]] / 255.).mean(axis=0) - np.array(colors[cfrom]))), t) for t in targets] 146 | t = min(color_diffs)[1] 147 | return t 148 | # singular changes 149 | for cfrom, cto in changes.items(): 150 | t = find_by_color(cfrom) 151 | if cfrom == 'gray': 152 | t2 = find_by_color('lightgray') 153 | if t2 != t: t = t2 154 | rgb[new_mask == name2ids[t]] = (np.array(colors[cto]) * 255).astype(int) 155 | targets.remove(t) 156 | 157 | for t in targets: 158 | rgb[new_mask == name2ids[t]] = (np.array(colors[other_color]) * 255).astype(int) 159 | 160 | return rgb 161 | 162 | 163 | -------------------------------------------------------------------------------- /utils/dist.py: -------------------------------------------------------------------------------- 1 | import socket 2 | from contextlib import closing 3 | 4 | def find_free_port(): 5 | with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: 6 | s.bind(('', 0)) 7 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 8 | return s.getsockname()[1] -------------------------------------------------------------------------------- /utils/env.py: -------------------------------------------------------------------------------- 1 | from clip import tokenize 2 | from abc import ABC, abstractmethod 3 | from typing import Any, List, Type 4 | import numpy as np 5 | from utils.structure import ObservationElement, Transition, ROBOT_STATE_KEYS, ActResult, \ 6 | Summary, VideoSummary, TextSummary, ImageSummary, Env 7 | from utils.str import insert_uline_before_cap 8 | try: 9 | from rlbench import ObservationConfig, Environment, CameraConfig 10 | except (ModuleNotFoundError, ImportError) as e: 11 | print("You need to install RLBench: 'https://github.com/stepjam/RLBench'") 12 | raise e 13 | from rlbench.action_modes.action_mode import ActionMode 14 | from rlbench.backend.observation import Observation 15 | from rlbench.backend.task import Task 16 | from pyrep.objects import VisionSensor, Dummy 17 | from pyrep.const import RenderMode 18 | from rlbench.action_modes.arm_action_modes import ( 19 | EndEffectorPoseViaPlanning as _EndEffectorPoseViaPlanning, 20 | Scene, 21 | ) 22 | from pyrep.errors import IKError, ConfigurationPathError 23 | from rlbench.backend.exceptions import InvalidActionError 24 | 25 | 26 | class EndEffectorPoseViaPlanning(_EndEffectorPoseViaPlanning): 27 | def __init__(self, *args, **kwargs): 28 | super().__init__(*args, **kwargs) 29 | 30 | def action(self, scene: Scene, action: np.ndarray, ignore_collisions: bool = True): 31 | action[:3] = np.clip( 32 | action[:3], 33 | np.array( 34 | [scene._workspace_minx, scene._workspace_miny, scene._workspace_minz] 35 | ) 36 | + 1e-7, 37 | np.array( 38 | [scene._workspace_maxx, scene._workspace_maxy, scene._workspace_maxz] 39 | ) 40 | - 1e-7, 41 | ) 42 | super().action(scene, action, ignore_collisions) 43 | 44 | 45 | 46 | def rlbench_obs_config(camera_names: List[str], 47 | camera_resolution: List[int], 48 | method_name: str): 49 | unused_cams = CameraConfig() 50 | unused_cams.set_all(False) 51 | used_cams = CameraConfig( 52 | rgb=True, 53 | point_cloud=True, 54 | mask=True, 55 | depth=False, 56 | image_size=camera_resolution, 57 | render_mode=RenderMode.OPENGL) 58 | 59 | cam_obs = [] 60 | kwargs = {} 61 | for n in camera_names: 62 | kwargs[n] = used_cams 63 | cam_obs.append('%s_rgb' % n) 64 | cam_obs.append('%s_pointcloud' % n) 65 | 66 | obs_config = ObservationConfig( 67 | front_camera=kwargs.get('front', unused_cams), 68 | left_shoulder_camera=kwargs.get('left_shoulder', unused_cams), 69 | right_shoulder_camera=kwargs.get('right_shoulder', unused_cams), 70 | wrist_camera=kwargs.get('wrist', unused_cams), 71 | overhead_camera=kwargs.get('overhead', unused_cams), 72 | joint_forces=False, 73 | joint_positions=True, 74 | joint_velocities=True, 75 | task_low_dim_state=False, 76 | gripper_touch_forces=False, 77 | gripper_pose=True, 78 | gripper_open=True, 79 | gripper_matrix=True, 80 | gripper_joint_positions=True, 81 | ) 82 | 83 | obs_config.left_shoulder_camera.masks_as_one_channel = False 84 | obs_config.right_shoulder_camera.masks_as_one_channel = False 85 | obs_config.overhead_camera.masks_as_one_channel = False 86 | obs_config.wrist_camera.masks_as_one_channel = False 87 | obs_config.front_camera.masks_as_one_channel = False 88 | return obs_config 89 | 90 | 91 | 92 | def _get_cam_observation_elements(camera: CameraConfig, prefix: str, channels_last): 93 | elements = [] 94 | img_s = list(camera.image_size) 95 | shape = img_s + [3] if channels_last else [3] + img_s 96 | if camera.rgb: 97 | elements.append( 98 | ObservationElement('%s_rgb' % prefix, shape, np.uint8)) 99 | if camera.point_cloud: 100 | elements.append( 101 | ObservationElement('%s_point_cloud' % prefix, shape, np.float32)) 102 | elements.append( 103 | ObservationElement('%s_camera_extrinsics' % prefix, (4, 4), 104 | np.float32)) 105 | elements.append( 106 | ObservationElement('%s_camera_intrinsics' % prefix, (3, 3), 107 | np.float32)) 108 | if camera.depth: 109 | shape = img_s + [1] if schannels_last else [1] + img_s 110 | elements.append( 111 | ObservationElement('%s_depth' % prefix, shape, np.float32)) 112 | if camera.mask: 113 | raise NotImplementedError() 114 | 115 | return elements 116 | 117 | 118 | def _observation_elements(observation_config, channels_last) -> List[ObservationElement]: 119 | elements = [] 120 | robot_state_len = 0 121 | if observation_config.joint_velocities: 122 | robot_state_len += 7 123 | if observation_config.joint_positions: 124 | robot_state_len += 7 125 | if observation_config.joint_forces: 126 | robot_state_len += 7 127 | if observation_config.gripper_open: 128 | robot_state_len += 1 129 | if observation_config.gripper_pose: 130 | robot_state_len += 7 131 | if observation_config.gripper_joint_positions: 132 | robot_state_len += 2 133 | if observation_config.gripper_touch_forces: 134 | robot_state_len += 2 135 | if observation_config.task_low_dim_state: 136 | raise NotImplementedError() 137 | if robot_state_len > 0: 138 | elements.append(ObservationElement( 139 | 'low_dim_state', (robot_state_len,), np.float32)) 140 | elements.extend(_get_cam_observation_elements( 141 | observation_config.left_shoulder_camera, 'left_shoulder', channels_last)) 142 | elements.extend(_get_cam_observation_elements( 143 | observation_config.right_shoulder_camera, 'right_shoulder', channels_last)) 144 | elements.extend(_get_cam_observation_elements( 145 | observation_config.front_camera, 'front', channels_last)) 146 | elements.extend(_get_cam_observation_elements( 147 | observation_config.wrist_camera, 'wrist', channels_last)) 148 | return elements 149 | 150 | 151 | def _extract_obs(obs: Observation, channels_last: bool, observation_config: ObservationConfig): 152 | misc = obs.misc 153 | obs_dict = vars(obs) 154 | obs_dict = {k: v for k, v in obs_dict.items() if v is not None} 155 | robot_state = obs.get_low_dim_data() 156 | # **Remove** all of the individual state elements 157 | obs_dict = {k: v for k, v in obs_dict.items() 158 | if k not in ROBOT_STATE_KEYS} 159 | if not channels_last: 160 | # Swap channels from last dim to 1st dim 161 | obs_dict = {k: np.transpose( 162 | v, [2, 0, 1]) if v.ndim == 3 else np.expand_dims(v, 0) 163 | for k, v in obs_dict.items()} 164 | else: 165 | # Add extra dim to depth data 166 | obs_dict = {k: v if v.ndim == 3 else np.expand_dims(v, -1) 167 | for k, v in obs_dict.items()} 168 | obs_dict['low_dim_state'] = np.array(robot_state, dtype=np.float32) 169 | for (k, v) in [(k, v) for k, v in obs_dict.items() if 'point_cloud' in k]: 170 | obs_dict[k] = v.astype(np.float32) 171 | 172 | for config, name in [ 173 | (observation_config.left_shoulder_camera, 'left_shoulder'), 174 | (observation_config.right_shoulder_camera, 'right_shoulder'), 175 | (observation_config.front_camera, 'front'), 176 | (observation_config.wrist_camera, 'wrist'), 177 | (observation_config.overhead_camera, 'overhead')]: 178 | if config.point_cloud: 179 | obs_dict['%s_camera_extrinsics' % name] = obs.misc['%s_camera_extrinsics' % name] 180 | obs_dict['%s_camera_intrinsics' % name] = obs.misc['%s_camera_intrinsics' % name] 181 | 182 | if 'object_ids' in misc: 183 | obs_dict['object_ids'] = misc['object_ids'] 184 | return obs_dict 185 | 186 | 187 | class MultiTaskRLBenchEnv(Env): 188 | 189 | def __init__(self, 190 | task_classes: List[Type[Task]], 191 | observation_config: ObservationConfig, 192 | action_mode: ActionMode, 193 | dataset_root: str = '', 194 | channels_last=False, 195 | headless=True, 196 | swap_task_every: int = 1, include_lang_goal_in_obs=False): 197 | 198 | self._eval_env = False 199 | self._include_lang_goal_in_obs = include_lang_goal_in_obs 200 | self._task_classes = task_classes 201 | self._observation_config = observation_config 202 | self._channels_last = channels_last 203 | self._rlbench_env = Environment( 204 | action_mode=action_mode, obs_config=observation_config, 205 | dataset_root=dataset_root, headless=headless) 206 | self._task = None 207 | self._lang_goal = 'unknown goal' 208 | self._swap_task_every = swap_task_every 209 | self._rlbench_env 210 | self._episodes_this_task = 0 211 | self._active_task_id = -1 212 | 213 | self._task_name_to_idx = {insert_uline_before_cap(tc.__name__): i 214 | for i, tc in enumerate(self._task_classes)} 215 | 216 | 217 | @property 218 | def eval(self): 219 | return self._eval_env 220 | 221 | @eval.setter 222 | def eval(self, is_eval): 223 | self._eval_env = is_eval 224 | 225 | @property 226 | def active_task_id(self) -> int: 227 | return self._active_task_id 228 | 229 | def _set_new_task(self, shuffle=False): 230 | if shuffle: 231 | self._active_task_id = np.random.randint(0, len(self._task_classes)) 232 | else: 233 | self._active_task_id = (self._active_task_id + 1) % len(self._task_classes) 234 | task = self._task_classes[self._active_task_id] 235 | self._task = self._rlbench_env.get_task(task) 236 | 237 | def set_task(self, task_name: str): 238 | self._active_task_id = self._task_name_to_idx[task_name] 239 | task = self._task_classes[self._active_task_id] 240 | self._task = self._rlbench_env.get_task(task) 241 | 242 | descriptions, _ = self._task.reset() 243 | self._lang_goal = descriptions[0] # first description variant 244 | 245 | def extract_obs(self, obs: Observation): 246 | extracted_obs = _extract_obs(obs, self._channels_last, self._observation_config) 247 | if self._include_lang_goal_in_obs: 248 | extracted_obs['lang_goal_tokens'] = tokenize([self._lang_goal])[0].numpy() 249 | return extracted_obs 250 | 251 | def launch(self): 252 | self._rlbench_env.launch() 253 | self._set_new_task() 254 | 255 | def shutdown(self): 256 | self._rlbench_env.shutdown() 257 | 258 | def reset(self) -> dict: 259 | self._episodes_this_task += 1 260 | if self._episodes_this_task == self._swap_task_every: 261 | self._set_new_task() 262 | self._episodes_this_task = 0 263 | 264 | descriptions, obs = self._task.reset() 265 | self._lang_goal = descriptions[0] # first description variant 266 | return self.extract_obs(obs) 267 | 268 | def step(self, action: np.ndarray) -> Transition: 269 | obs, reward, terminal = self._task.step(action) 270 | obs = self.extract_obs(obs) 271 | return Transition(obs, reward, terminal) 272 | 273 | @property 274 | def observation_elements(self) -> List[ObservationElement]: 275 | """ return the specification of observable data """ 276 | return _observation_elements(self._observation_config, self._channels_last) 277 | 278 | @property 279 | def action_shape(self): 280 | return (self._rlbench_env.action_size, ) 281 | 282 | @property 283 | def env(self) -> Environment: 284 | return self._rlbench_env 285 | 286 | @property 287 | def num_tasks(self) -> int: 288 | return len(self._task_classes) 289 | 290 | 291 | 292 | class CustomMultiTaskRLBenchEnv(MultiTaskRLBenchEnv): 293 | def __init__(self, 294 | task_classes: List[Type[Task]], 295 | observation_config: ObservationConfig, 296 | action_mode: ActionMode, 297 | episode_length: int, 298 | dataset_root: str = '', 299 | channels_last: bool = False, 300 | reward_scale=100.0, 301 | headless: bool = True, 302 | swap_task_every: int = 1, 303 | time_in_state: bool = False, 304 | include_lang_goal_in_obs: bool = False, 305 | record: bool = False): 306 | super().__init__( 307 | task_classes, observation_config, action_mode, dataset_root, 308 | channels_last, headless=headless, swap_task_every=swap_task_every, 309 | include_lang_goal_in_obs=include_lang_goal_in_obs) 310 | 311 | self._reward_scale = reward_scale 312 | self._episode_index = 0 313 | self._record_current_episode = False 314 | self._record_cam = None 315 | self._previous_obs, self._previous_obs_dict = None, None 316 | self._episode_length = episode_length 317 | self._time_in_state = time_in_state 318 | # self._record_every_n = record_every_n 319 | self._record = record 320 | self._i = 0 321 | self._error_type_counts = { 322 | 'IKError': 0, 323 | 'ConfigurationPathError': 0, 324 | 'InvalidActionError': 0, 325 | } 326 | self._last_exception = None 327 | 328 | @property 329 | def observation_elements(self) -> List[ObservationElement]: 330 | obs_elems = super().observation_elements 331 | for oe in obs_elems: 332 | if oe.name == 'low_dim_state': 333 | oe.shape = (oe.shape[0] - 7 * 3 + int(self._time_in_state),) # remove pose and joint velocities as they will not be included 334 | self.low_dim_state_len = oe.shape[0] 335 | return obs_elems 336 | 337 | def extract_obs(self, obs: Observation, t=None, prev_action=None): 338 | obs.joint_velocities = None 339 | grip_mat = obs.gripper_matrix 340 | grip_pose = obs.gripper_pose 341 | joint_pos = obs.joint_positions 342 | obs.gripper_pose = None 343 | obs.gripper_matrix = None 344 | obs.wrist_camera_matrix = None 345 | obs.joint_positions = None 346 | if obs.gripper_joint_positions is not None: 347 | obs.gripper_joint_positions = np.clip( 348 | obs.gripper_joint_positions, 0., 0.04) 349 | 350 | obs_dict = super().extract_obs(obs) 351 | 352 | if self._time_in_state: 353 | time = (1. - ((self._i if t is None else t) / float( 354 | self._episode_length - 1))) * 2. - 1. 355 | obs_dict['low_dim_state'] = np.concatenate( 356 | [obs_dict['low_dim_state'], [time]]).astype(np.float32) 357 | 358 | obs.gripper_matrix = grip_mat 359 | obs.joint_positions = joint_pos 360 | obs.gripper_pose = grip_pose 361 | obs_dict['gripper_pose'] = grip_pose 362 | obs_dict['gripper_open'] = obs.gripper_open 363 | return obs_dict 364 | 365 | def reset(self) -> dict: 366 | self._i = 0 367 | self._previous_obs_dict = super().reset() 368 | self._episode_index += 1 369 | return self._previous_obs_dict 370 | 371 | def launch(self): 372 | super().launch() 373 | 374 | def step(self, act_result: ActResult) -> Transition: 375 | action = act_result.action # from model 376 | success = False 377 | obs = self._previous_obs_dict # in case action fails. 378 | 379 | try: 380 | obs, reward, terminal = self._task.step(action) 381 | if reward >= 1: 382 | success = True 383 | reward *= self._reward_scale 384 | else: 385 | reward = 0.0 386 | obs = self.extract_obs(obs) 387 | self._previous_obs_dict = obs 388 | except (IKError, ConfigurationPathError, InvalidActionError) as e: 389 | print(e) 390 | terminal = True 391 | reward = 0.0 392 | 393 | if isinstance(e, IKError): 394 | self._error_type_counts['IKError'] += 1 395 | elif isinstance(e, ConfigurationPathError): 396 | self._error_type_counts['ConfigurationPathError'] += 1 397 | elif isinstance(e, InvalidActionError): 398 | self._error_type_counts['InvalidActionError'] += 1 399 | 400 | self._last_exception = e 401 | 402 | summaries = [] 403 | self._i += 1 404 | return Transition(obs, reward, terminal, summaries=summaries) 405 | 406 | 407 | def reset_to_demo(self, i, variation_number=-1, start_new=False): 408 | if self._episodes_this_task == self._swap_task_every or start_new: 409 | self._set_new_task() 410 | self._episodes_this_task = 0 411 | self._episodes_this_task += 1 412 | 413 | self._i = 0 414 | self._task.set_variation(-1) 415 | d = self._task.get_demos( 416 | 1, live_demos=False, random_selection=False, from_episode_number=i, 417 | 418 | image_paths=True # skip rgb/pcd loading, as they are not used 419 | )[0] # from dataset_root 420 | 421 | self._task.set_variation(d.variation_number) 422 | desc, obs = self._task.reset_to_demo(d) 423 | self._lang_goal = desc[0] 424 | 425 | self._previous_obs_dict = self.extract_obs(obs) 426 | self._episode_index += 1 427 | return self._previous_obs_dict 428 | 429 | 430 | def get_color_information(self): 431 | """ directly use the color information in rlbench to avoid the work of building an external color detector or collecting 432 | demonstrations for all colors 433 | """ 434 | task = self._task._scene.task 435 | return { 436 | 'target': getattr(task, '_color_target', None), 437 | 'distractors': getattr(task, '_color_distractors', []) 438 | } 439 | 440 | 441 | 442 | def get_active_task_of_env(env): 443 | return {v:k for k,v in env._task_name_to_idx.items()}[env._active_task_id] -------------------------------------------------------------------------------- /utils/icp.py: -------------------------------------------------------------------------------- 1 | import open3d as o3d 2 | from scipy.spatial.transform import Rotation 3 | import numpy as np 4 | from numpy.linalg import inv 5 | 6 | 7 | 8 | def knn(query, points, k=1): 9 | import faiss 10 | if len(query) == 0 or len(points) == 0: 11 | raise ValueError("Found array with 0 sample(s) (shape=(0, 3)) while a minimum of 1 is required by NearestNeighbors.") 12 | 13 | index = faiss.IndexFlatL2(points.shape[1]) 14 | index.add(points.astype(np.float32)) 15 | distances, indices = index.search(query.astype(np.float32), k=k) 16 | distances = np.sqrt(distances) 17 | if k == 1: 18 | return distances.flatten(), indices.flatten() 19 | else: 20 | return distances, indices 21 | 22 | 23 | def icp(source_pts, target_pts, source_rgb=None, target_rgb=None, max_corr_dist=0.2, max_iteration=30, rotation_hint=False, init_X=None, plane=False): 24 | if init_X is None: 25 | source_center = source_pts.mean(axis=0, keepdims=True) 26 | target_center = target_pts.mean(axis=0, keepdims=True) 27 | source = o3d.geometry.PointCloud() 28 | # ICP works better with normalized points 29 | source.points = o3d.utility.Vector3dVector(source_pts - source_center) 30 | target = o3d.geometry.PointCloud() 31 | target.points = o3d.utility.Vector3dVector(target_pts - target_center) 32 | has_color = source_rgb is not None 33 | if has_color: 34 | source.colors = o3d.utility.Vector3dVector(source_rgb) 35 | target.colors = o3d.utility.Vector3dVector(target_rgb) 36 | 37 | source.estimate_normals() 38 | target.estimate_normals() 39 | 40 | H_init = np.asarray([[1, 0, 0, 0.], 41 | [0, 1, 0., 0.], 42 | [0, 0, 1, 0], 43 | [0.0, 0.0, 0.0, 1.0]]) 44 | trans_init = H_init.copy() 45 | 46 | def register(init): 47 | try: 48 | reg_p2l = getattr(o3d.pipelines.registration, "registration_colored_icp" if has_color else "registration_icp")( 49 | source, target, max_corr_dist, init, 50 | o3d.pipelines.registration.TransformationEstimationForColoredICP() 51 | if has_color else 52 | getattr(o3d.pipelines.registration, 'TransformationEstimationPointTo' + ('Plane' if plane else 'Point'))(), 53 | o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=max_iteration)) # TransformationEstimationPointToPlane 54 | except RuntimeError: 55 | reg_p2l = None 56 | return reg_p2l 57 | 58 | if not rotation_hint: 59 | reg_p2l = register(trans_init) 60 | else: 61 | # clever initialization to improve ICP accuracy 62 | inits = [] 63 | for rad in [0, np.pi/2, np.pi]: 64 | X = H_init.copy() 65 | X[:3, :3] = Rotation.from_euler('z', rad).as_matrix() 66 | inits.append(X) 67 | results = [register(init) for init in inits] 68 | max_fitness = max([a.fitness for a in results if a is not None] + [0]) 69 | reg_p2l = None 70 | min_mse = 100 71 | for r in results: 72 | if r is not None and r.fitness == max_fitness: 73 | v = r.inlier_rmse 74 | if v < min_mse: 75 | reg_p2l = r 76 | min_mse = v 77 | 78 | if reg_p2l is not None: 79 | # transform X back to non-normalized space 80 | X = reg_p2l.transformation.copy() 81 | X[:3, 3] += (target_center.flatten() - (X[:3, :3] @ source_center.flatten())) 82 | reg_p2l.transformation = X 83 | 84 | return reg_p2l 85 | else: 86 | source = o3d.geometry.PointCloud() 87 | source.points = o3d.utility.Vector3dVector(source_pts) 88 | target = o3d.geometry.PointCloud() 89 | target.points = o3d.utility.Vector3dVector(target_pts) 90 | has_color = source_rgb is not None 91 | if has_color: 92 | source.colors = o3d.utility.Vector3dVector(source_rgb) 93 | target.colors = o3d.utility.Vector3dVector(target_rgb) 94 | source.estimate_normals() 95 | target.estimate_normals() 96 | 97 | try: 98 | reg_p2l = getattr(o3d.pipelines.registration, "registration_colored_icp" if has_color else "registration_icp")( 99 | source, target, max_corr_dist, init_X, 100 | o3d.pipelines.registration.TransformationEstimationForColoredICP() 101 | if has_color else 102 | getattr(o3d.pipelines.registration, 'TransformationEstimationPointTo' + ('Plane' if plane else 'Point'))(), 103 | o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=max_iteration)) # TransformationEstimationPointToPlane 104 | except RuntimeError: 105 | reg_p2l = None 106 | return reg_p2l 107 | 108 | 109 | def estimate_pca_box(pcd): 110 | return to_o3d_pcd(pcd).get_minimal_oriented_bounding_box() 111 | 112 | 113 | def to_o3d_pcd(pcd): 114 | source = o3d.geometry.PointCloud() 115 | source.points = o3d.utility.Vector3dVector(pcd) 116 | return source 117 | 118 | 119 | def to_np_pcd(o3d_pcd): 120 | return np.asarray(o3d_pcd.points) 121 | 122 | 123 | def box_volume(pcd): 124 | try: 125 | return estimate_pca_box(pcd).volume() 126 | except Exception as e: 127 | msg = str(e) 128 | print(f'Exception while estimating point cloud volume: {msg}, return zero volume') 129 | return 0.0 130 | 131 | 132 | def voxel_grid(pcd, voxel_size=0.02): 133 | source = to_o3d_pcd(pcd) 134 | voxel_volume = o3d.geometry.VoxelGrid.create_from_point_cloud(source, voxel_size) 135 | return np.asarray([voxel_volume.origin + pt.grid_index*voxel_volume.voxel_size for pt in voxel_volume.get_voxels()]) 136 | 137 | 138 | def normalize_point_cloud_to_origin(pcd): 139 | center = pcd.mean(axis=0, keepdims=True) 140 | pcd = pcd - center 141 | return pcd, center.flatten() 142 | 143 | 144 | def to_unit_length(v): 145 | return v / np.linalg.norm(v) 146 | 147 | 148 | def R_2_X(R): 149 | X = np.eye(4) 150 | X[:3, :3] = R 151 | return X 152 | 153 | 154 | def t_2_X(t): 155 | X = np.eye(4) 156 | if isinstance(t, np.ndarray): t = t.flatten() 157 | X[:3, -1] = t 158 | return X 159 | 160 | 161 | def Rt_2_X(R, t): 162 | X = np.eye(4) 163 | X[:3, :3] = R 164 | if isinstance(t, np.ndarray): t = t.flatten() 165 | X[:3, -1] = t 166 | return X 167 | 168 | def X_2_Rt(X): 169 | return X[:3, :3], X[:3, -1] 170 | 171 | def to_homo_axis(pts): 172 | return np.concatenate([pts, np.ones((len(pts), 1))], axis=1) 173 | 174 | 175 | def h_transform(T, pts): 176 | return (T @ to_homo_axis(pts).T).T[:, :3] 177 | 178 | def r_transform(R, pts): 179 | return (R @ pts.T).T 180 | 181 | def axis_angle_rotate(axis, radian): 182 | assert axis.shape == (3,) 183 | rot = Rotation.from_rotvec(to_unit_length(axis) * radian) 184 | return rot.as_matrix() 185 | 186 | def rotate_X(X, Pc, axis, angle): 187 | R, t = X_2_Rt(X) 188 | Rot_inv = inv(axis_angle_rotate(axis, angle)) 189 | Rnew = Rot_inv @ R 190 | tnew = Rot_inv @ (t - Pc) + Pc 191 | return Rt_2_X(Rnew, tnew) 192 | 193 | 194 | def get_matching_ratio(P_from, P_to, threshold=0.01): 195 | dist, _ = knn(P_from, P_to) 196 | matched_ratio = (dist <= threshold).sum() / len(P_from) 197 | return matched_ratio 198 | 199 | 200 | def resolve_rotation_ambiguity(X, ref_points, points, ref_context_points, context_points, ambiguity_threshold=0.95): 201 | P = points 202 | P_n, P_c = normalize_point_cloud_to_origin(P) 203 | bbox = estimate_pca_box(P_n) 204 | 205 | P_ref = h_transform(X, ref_points) 206 | P_ref_n, P_ref_c = normalize_point_cloud_to_origin(P_ref) 207 | 208 | X_alternatives = [X, ] 209 | X_AXIS, Y_AXIS, Z_AXIS = 0, 1, 2 210 | 211 | for radian in [np.pi, np.pi/2, -np.pi/2]: 212 | for axis in [X_AXIS, Y_AXIS, Z_AXIS]: 213 | ratio = get_matching_ratio(P_ref_n, r_transform(axis_angle_rotate(bbox.R[axis], radian), P_n)) 214 | # print('ambiguity ratio -> ', ratio) 215 | if ratio >= ambiguity_threshold: 216 | _X = rotate_X(X, P_c, bbox.R[axis], radian) 217 | if check_X_validity(_X): 218 | X_alternatives.append(_X) 219 | 220 | if len(X_alternatives) == 1: 221 | return X_alternatives[0] 222 | 223 | distances = [] 224 | Xs = [] 225 | for _X in X_alternatives: 226 | distance = knn(h_transform(_X, ref_context_points), context_points)[0].mean() 227 | distances.append(distance) 228 | Xs.append(_X) 229 | 230 | ind = np.argmin(distances) 231 | # if distances[0] / (distances[ind] + 1e-6) > 2: 232 | if ind != 0: 233 | print('ambiguity resolve to another X!') 234 | else: 235 | ind = 0 236 | chosen_X = Xs[ind] 237 | return chosen_X 238 | 239 | 240 | 241 | def check_X_validity(X): 242 | plane = np.zeros((10, 10, 3)) 243 | plane[:, :, 0], plane[:, :, 1] = np.meshgrid(np.linspace(0, 1, 10), np.linspace(0, 1, 10)) 244 | plane[:, :, 2] = np.random.uniform(0.74, 0.76, size=(10, 10)) 245 | plane = plane.reshape(-1, 3) 246 | X_plane = h_transform(X, plane) 247 | m = X_plane[:, 2].mean() 248 | return 0.73 <= m <= 0.77 249 | 250 | 251 | def pose7_to_X(pose): 252 | R = Rotation.from_quat(pose[3:]).as_matrix() 253 | t = pose[:3] 254 | X = np.zeros((4, 4)) 255 | X[:3, :3] = R 256 | X[-1, -1] = 1 257 | X[:3, -1] = t 258 | return X 259 | 260 | 261 | def X_to_pose7(X): 262 | t = X[:3, -1] 263 | q = Rotation.from_matrix(X[:3, :3]).as_quat() 264 | return np.concatenate([t, q]) 265 | 266 | 267 | def pose7_to_frame(pose, scale=0.1): 268 | pose = pose.copy() 269 | R = Rotation.from_quat(pose[3:]).as_matrix() * scale 270 | t = pose[:3] 271 | return np.array([t, R[0] + t, R[1] + t, R[2] + t]) 272 | 273 | 274 | def X_to_frame(X): 275 | return pose7_to_frame(X_to_pose7(X)) 276 | 277 | 278 | def frame_to_X(frame): 279 | frame = np.copy(frame) 280 | t, x, y, z = frame 281 | X = np.eye(4) 282 | X[:3, -1] = t 283 | x -= t 284 | y -= t 285 | z -= t 286 | x = x / np.linalg.norm(x) 287 | y = y / np.linalg.norm(y) 288 | z = z / np.linalg.norm(z) 289 | X[0, :3] = x 290 | X[1, :3] = y 291 | X[2, :3] = z 292 | return X 293 | 294 | 295 | def h_transform_X(T, X): 296 | return frame_to_X(h_transform(T, np.array(X_to_frame(X)))) 297 | 298 | def h_transform_pose(T, pose): 299 | return X_to_pose7(frame_to_X(h_transform(T, np.array(pose7_to_frame(pose))))) 300 | 301 | 302 | def rotate_from_origin(pts, matrix): 303 | center = pts.mean(axis=0, keepdims=True) 304 | pts -= center 305 | pts = r_transform(matrix, pts) 306 | pts += center 307 | return pts 308 | 309 | 310 | def fps_sample_to(pts, N): 311 | if N >= len(pts): 312 | return pts 313 | else: 314 | return to_np_pcd(to_o3d_pcd(pts).farthest_point_down_sample(N)) 315 | 316 | 317 | 318 | def arun(A, B): 319 | """Solve 3D registration using Arun's method: B = RA + t 320 | """ 321 | A, B = A.T, B.T 322 | 323 | N = A.shape[1] 324 | assert B.shape[1] == N 325 | 326 | # calculate centroids 327 | A_centroid = np.reshape(1/N * (np.sum(A, axis=1)), (3,1)) 328 | B_centroid = np.reshape(1/N * (np.sum(B, axis=1)), (3,1)) 329 | 330 | # calculate the vectors from centroids 331 | A_prime = A - A_centroid 332 | B_prime = B - B_centroid 333 | 334 | # rotation estimation 335 | H = np.zeros([3, 3]) 336 | for i in range(N): 337 | ai = A_prime[:, i] 338 | bi = B_prime[:, i] 339 | H = H + np.outer(ai, bi) 340 | U, S, V_transpose = np.linalg.svd(H) 341 | V = np.transpose(V_transpose) 342 | U_transpose = np.transpose(U) 343 | R = V @ np.diag([1, 1, np.linalg.det(V) * np.linalg.det(U_transpose)]) @ U_transpose 344 | 345 | # translation estimation 346 | t = B_centroid - R @ A_centroid 347 | 348 | return R, t.flatten() -------------------------------------------------------------------------------- /utils/layers.py: -------------------------------------------------------------------------------- 1 | from functools import wraps 2 | import torch 3 | import torch.nn.functional as F 4 | from torch import nn, einsum 5 | from einops import rearrange, repeat 6 | 7 | LRELU_SLOPE = 0.02 8 | 9 | def exists(val): 10 | return val is not None 11 | 12 | def default(val, d): 13 | return val if exists(val) else d 14 | 15 | def cache_fn(f): 16 | cache = None 17 | 18 | @wraps(f) 19 | def cached_fn(*args, _cache=True, **kwargs): 20 | if not _cache: 21 | return f(*args, **kwargs) 22 | nonlocal cache 23 | if cache is not None: 24 | return cache 25 | cache = f(*args, **kwargs) 26 | return cache 27 | 28 | return cached_fn 29 | 30 | 31 | class PreNorm(nn.Module): 32 | def __init__(self, dim, fn, context_dim=None): 33 | super().__init__() 34 | self.fn = fn 35 | self.norm = nn.LayerNorm(dim) 36 | self.norm_context = nn.LayerNorm(context_dim) if exists(context_dim) else None 37 | 38 | def forward(self, x, **kwargs): 39 | x = self.norm(x) 40 | 41 | if exists(self.norm_context): 42 | context = kwargs["context"] 43 | normed_context = self.norm_context(context) 44 | kwargs.update(context=normed_context) 45 | 46 | return self.fn(x, **kwargs) 47 | 48 | 49 | class GEGLU(nn.Module): 50 | def forward(self, x): 51 | x, gates = x.chunk(2, dim=-1) 52 | return x * F.gelu(gates) 53 | 54 | 55 | class FeedForward(nn.Module): 56 | def __init__(self, dim, mult=4): 57 | super().__init__() 58 | self.net = nn.Sequential( 59 | nn.Linear(dim, dim * mult * 2), 60 | GEGLU(), 61 | nn.Linear(dim * mult, dim) 62 | ) 63 | 64 | def forward(self, x): 65 | return self.net(x) 66 | 67 | 68 | class Attention(nn.Module): # is all you need. Living up to its name. 69 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0): 70 | 71 | super().__init__() 72 | inner_dim = dim_head * heads 73 | context_dim = default(context_dim, query_dim) 74 | self.scale = dim_head**-0.5 75 | self.heads = heads 76 | 77 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 78 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=False) 79 | self.to_out = nn.Linear(inner_dim, query_dim) 80 | 81 | self.dropout_p = dropout 82 | # dropout left in use_fast for backward compatibility 83 | self.dropout = nn.Dropout(self.dropout_p) 84 | 85 | def forward(self, x, context=None, mask=None): 86 | h = self.heads 87 | 88 | q = self.to_q(x) 89 | context = default(context, x) 90 | k, v = self.to_kv(context).chunk(2, dim=-1) 91 | 92 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 93 | sim = einsum("b i d, b j d -> b i j", q, k) * self.scale 94 | if exists(mask): 95 | mask = rearrange(mask, "b ... -> b (...)") 96 | max_neg_value = -torch.finfo(sim.dtype).max 97 | mask = repeat(mask, "b j -> (b h) () j", h=h) 98 | sim.masked_fill_(~mask, max_neg_value) 99 | # attention 100 | attn = sim.softmax(dim=-1) 101 | # dropout 102 | attn = self.dropout(attn) 103 | out = einsum("b i j, b j d -> b i d", attn, v) 104 | 105 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 106 | out = self.to_out(out) 107 | return out 108 | 109 | 110 | def act_layer(act): 111 | if act == "relu": 112 | return nn.ReLU() 113 | elif act == "lrelu": 114 | return nn.LeakyReLU(LRELU_SLOPE) 115 | elif act == "elu": 116 | return nn.ELU() 117 | elif act == "tanh": 118 | return nn.Tanh() 119 | elif act == "prelu": 120 | return nn.PReLU() 121 | else: 122 | raise ValueError("%s not recognized." % act) 123 | 124 | 125 | def norm_layer2d(norm, channels): 126 | if norm == "batch": 127 | return nn.BatchNorm2d(channels) 128 | elif norm == "instance": 129 | return nn.InstanceNorm2d(channels, affine=True) 130 | elif norm == "layer": 131 | return nn.GroupNorm(1, channels, affine=True) 132 | elif norm == "group": 133 | return nn.GroupNorm(4, channels, affine=True) 134 | else: 135 | raise ValueError("%s not recognized." % norm) 136 | 137 | 138 | def norm_layer1d(norm, num_channels): 139 | if norm == "batch": 140 | return nn.BatchNorm1d(num_channels) 141 | elif norm == "instance": 142 | return nn.InstanceNorm1d(num_channels, affine=True) 143 | elif norm == "layer": 144 | return nn.LayerNorm(num_channels) 145 | elif norm == "group": 146 | return nn.GroupNorm(4, num_channels, affine=True) 147 | else: 148 | raise ValueError("%s not recognized." % norm) 149 | 150 | 151 | class Conv2DBlock(nn.Module): 152 | def __init__( 153 | self, 154 | in_channels, 155 | out_channels, 156 | kernel_sizes=3, 157 | strides=1, 158 | norm=None, 159 | activation=None, 160 | padding_mode="replicate", 161 | padding=None, 162 | ): 163 | super().__init__() 164 | padding = kernel_sizes // 2 if padding is None else padding 165 | self.conv2d = nn.Conv2d( 166 | in_channels, 167 | out_channels, 168 | kernel_sizes, 169 | strides, 170 | padding=padding, 171 | padding_mode=padding_mode, 172 | ) 173 | 174 | if activation is None: 175 | nn.init.xavier_uniform_( 176 | self.conv2d.weight, gain=nn.init.calculate_gain("linear") 177 | ) 178 | nn.init.zeros_(self.conv2d.bias) 179 | elif activation == "tanh": 180 | nn.init.xavier_uniform_( 181 | self.conv2d.weight, gain=nn.init.calculate_gain("tanh") 182 | ) 183 | nn.init.zeros_(self.conv2d.bias) 184 | elif activation == "lrelu": 185 | nn.init.kaiming_uniform_( 186 | self.conv2d.weight, a=LRELU_SLOPE, nonlinearity="leaky_relu" 187 | ) 188 | nn.init.zeros_(self.conv2d.bias) 189 | elif activation == "relu": 190 | nn.init.kaiming_uniform_(self.conv2d.weight, nonlinearity="relu") 191 | nn.init.zeros_(self.conv2d.bias) 192 | else: 193 | raise ValueError() 194 | 195 | self.activation = None 196 | if norm is not None: 197 | self.norm = norm_layer2d(norm, out_channels) 198 | else: 199 | self.norm = None 200 | if activation is not None: 201 | self.activation = act_layer(activation) 202 | self.out_channels = out_channels 203 | 204 | def forward(self, x): 205 | x = self.conv2d(x) 206 | x = self.norm(x) if self.norm is not None else x 207 | x = self.activation(x) if self.activation is not None else x 208 | return x 209 | 210 | 211 | class Conv2DUpsampleBlock(nn.Module): 212 | def __init__( 213 | self, 214 | in_channels, 215 | out_channels, 216 | strides, 217 | kernel_sizes=3, 218 | norm=None, 219 | activation=None, 220 | ): 221 | super().__init__() 222 | layer = [ 223 | Conv2DBlock(in_channels, out_channels, kernel_sizes, 1, norm, activation) 224 | ] 225 | if strides > 1: 226 | layer.append( 227 | nn.Upsample(scale_factor=strides, mode="bilinear", align_corners=False) 228 | ) 229 | convt_block = Conv2DBlock( 230 | out_channels, out_channels, kernel_sizes, 1, norm, activation 231 | ) 232 | layer.append(convt_block) 233 | self.conv_up = nn.Sequential(*layer) 234 | 235 | def forward(self, x): 236 | return self.conv_up(x) 237 | 238 | 239 | class DenseBlock(nn.Module): 240 | def __init__(self, in_features, out_features, norm=None, activation=None): 241 | super(DenseBlock, self).__init__() 242 | self.linear = nn.Linear(in_features, out_features) 243 | 244 | if activation is None: 245 | nn.init.xavier_uniform_( 246 | self.linear.weight, gain=nn.init.calculate_gain("linear") 247 | ) 248 | nn.init.zeros_(self.linear.bias) 249 | elif activation == "tanh": 250 | nn.init.xavier_uniform_( 251 | self.linear.weight, gain=nn.init.calculate_gain("tanh") 252 | ) 253 | nn.init.zeros_(self.linear.bias) 254 | elif activation == "lrelu": 255 | nn.init.kaiming_uniform_( 256 | self.linear.weight, a=LRELU_SLOPE, nonlinearity="leaky_relu" 257 | ) 258 | nn.init.zeros_(self.linear.bias) 259 | elif activation == "relu": 260 | nn.init.kaiming_uniform_(self.linear.weight, nonlinearity="relu") 261 | nn.init.zeros_(self.linear.bias) 262 | else: 263 | raise ValueError() 264 | 265 | self.activation = None 266 | self.norm = None 267 | if norm is not None: 268 | self.norm = norm_layer1d(norm, out_features) 269 | if activation is not None: 270 | self.activation = act_layer(activation) 271 | 272 | def forward(self, x): 273 | x = self.linear(x) 274 | x = self.norm(x) if self.norm is not None else x 275 | x = self.activation(x) if self.activation is not None else x 276 | return x 277 | -------------------------------------------------------------------------------- /utils/match.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def partition_arg_topK(matrix, K, axis=0): 4 | """ find index of K smallest entries along a axis 5 | perform topK based on np.argpartition 6 | :param matrix: to be sorted 7 | :param K: select and sort the top K items 8 | :param axis: 0 or 1. dimension to be sorted. 9 | :return: 10 | """ 11 | a_part = np.argpartition(matrix, K, axis=axis) 12 | if axis == 0: 13 | row_index = np.arange(matrix.shape[1 - axis]) 14 | a_sec_argsort_K = np.argsort(matrix[a_part[0:K, :], row_index], axis=axis) 15 | return a_part[0:K, :][a_sec_argsort_K, row_index] 16 | else: 17 | column_index = np.arange(matrix.shape[1 - axis])[:, None] 18 | a_sec_argsort_K = np.argsort(matrix[column_index, a_part[:, 0:K]], axis=axis) 19 | return a_part[:, 0:K][column_index, a_sec_argsort_K] 20 | 21 | 22 | def knn_point_np(k, reference_pts, query_pts): 23 | ''' 24 | :param k: number of k in k-nn search 25 | :param reference_pts: (N, 3) float32 array, input points 26 | :param query_pts: (M, 3) float32 array, query points 27 | :return: 28 | val: (batch_size, npoint, k) float32 array, L2 distances 29 | idx: (batch_size, npoint, k) int32 array, indices to input points 30 | ''' 31 | 32 | N, _ = reference_pts.shape 33 | M, _ = query_pts.shape 34 | origin_N = N 35 | if N <= k: 36 | reference_pts = np.concatenate([reference_pts, np.full([k + 1 - N, 3], fill_value=float('inf'))]) 37 | N = k+1 38 | 39 | reference_pts = reference_pts.reshape(1, N, -1).repeat(M, axis=0) 40 | query_pts = query_pts.reshape(M, 1, -1).repeat(N, axis=1) 41 | dist = np.sum((reference_pts - query_pts) ** 2, -1) 42 | idx = partition_arg_topK(dist, K=k, axis=1) 43 | val = np.take_along_axis ( dist , idx, axis=1) 44 | if origin_N < N: 45 | idx[idx >= origin_N] = -1 46 | return np.sqrt(val), idx 47 | 48 | 49 | def mutual_neighbor_correspondence(src_pcd_deformed, tgt_pcd, search_radius=0.3, knn=1): 50 | src_idx = np.arange(src_pcd_deformed.shape[0]) 51 | 52 | s2t_dists, ref_tgt_idx = knn_point_np (knn, tgt_pcd, src_pcd_deformed) 53 | s2t_dists, ref_tgt_idx = s2t_dists[:,0], ref_tgt_idx [:, 0] 54 | valid_distance = s2t_dists < search_radius 55 | 56 | _, ref_src_idx = knn_point_np(knn, src_pcd_deformed, tgt_pcd) 57 | _, ref_src_idx = _, ref_src_idx[:, 0] 58 | 59 | cycle_src_idx = ref_src_idx[ref_tgt_idx] 60 | 61 | is_mutual_nn = cycle_src_idx == src_idx 62 | 63 | mutual_nn = np.logical_and( is_mutual_nn, valid_distance) 64 | correspondences = np.stack([src_idx[mutual_nn], ref_tgt_idx[mutual_nn]] , axis=0) 65 | 66 | return correspondences 67 | 68 | 69 | if __name__ == "__main__": 70 | 71 | query = np.random.rand(10, 3) 72 | reference = np.random.rand(20, 3) 73 | 74 | a = knn_point_np(3, reference, query) 75 | b = knn_point_np(100, reference, query) 76 | 77 | print(1) 78 | -------------------------------------------------------------------------------- /utils/math3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.neighbors import NearestNeighbors 4 | from scipy.spatial.transform import Rotation 5 | 6 | 7 | def point_to_voxel_index( 8 | point: np.ndarray, 9 | voxel_size: np.ndarray, 10 | coord_bounds: np.ndarray): 11 | bb_mins = np.array(coord_bounds[0:3]) 12 | bb_maxs = np.array(coord_bounds[3:]) 13 | dims_m_one = np.array([voxel_size] * 3) - 1 14 | bb_ranges = bb_maxs - bb_mins 15 | res = bb_ranges / (np.array([voxel_size] * 3) + 1e-12) 16 | voxel_indicy = np.minimum( 17 | np.floor((point - bb_mins) / (res + 1e-12)).astype( 18 | np.int32), dims_m_one) 19 | return voxel_indicy 20 | 21 | def stack_on_channel(x): 22 | # expect (B, T, C, ...) -> (B, T*C, ...) 23 | return torch.cat(torch.split(x, 1, dim=1), dim=2).squeeze(1) 24 | 25 | def normalize_quaternion(quat): 26 | return np.array(quat) / np.linalg.norm(quat, axis=-1, keepdims=True) 27 | 28 | def quaternion_to_discrete_euler(quaternion, resolution): 29 | euler = Rotation.from_quat(quaternion).as_euler('xyz', degrees=True) + 180 30 | assert np.min(euler) >= 0 and np.max(euler) <= 360 31 | disc = np.around((euler / resolution)).astype(int) 32 | disc[disc == int(360 / resolution)] = 0 33 | return disc 34 | 35 | def discrete_euler_to_quaternion(discrete_euler, resolution): 36 | euluer = (discrete_euler * resolution) - 180 37 | return Rotation.from_euler('xyz', euluer, degrees=True).as_quat() 38 | 39 | 40 | def sensitive_gimble_fix(euler): 41 | """ 42 | :param euler: euler angles in degree as np.ndarray in shape either [3] or 43 | [b, 3] 44 | """ 45 | # selecting sensitive angle, y-axis 46 | select1 = (89 < euler[..., 1]) & (euler[..., 1] < 91) 47 | euler[select1, 1] = 90 48 | # selecting sensitive angle 49 | select2 = (-91 < euler[..., 1]) & (euler[..., 1] < -89) 50 | euler[select2, 1] = -90 51 | 52 | # recalulating the euler angles, see assert 53 | r = Rotation.from_euler("xyz", euler, degrees=True) 54 | euler = r.as_euler("xyz", degrees=True) 55 | 56 | select = select1 | select2 57 | assert (euler[select][..., 2] == 0).all(), euler # z-axis for the fixed ones 58 | return euler 59 | 60 | 61 | ############################## 62 | # ICP # 63 | ############################## 64 | 65 | 66 | def best_fit_transform(A, B): 67 | ''' 68 | Calculates the least-squares best-fit transform that maps corresponding points A to B in m spatial dimensions 69 | Input: 70 | A: Nxm numpy array of corresponding points 71 | B: Nxm numpy array of corresponding points 72 | Returns: 73 | T: (m+1)x(m+1) homogeneous transformation matrix that maps A on to B 74 | R: mxm rotation matrix 75 | t: mx1 translation vector 76 | ''' 77 | 78 | assert A.shape == B.shape 79 | 80 | # get number of dimensions 81 | m = A.shape[1] 82 | 83 | # translate points to their centroids 84 | centroid_A = np.mean(A, axis=0) 85 | centroid_B = np.mean(B, axis=0) 86 | AA = A - centroid_A 87 | BB = B - centroid_B 88 | 89 | # rotation matrix 90 | H = np.dot(AA.T, BB) 91 | U, S, Vt = np.linalg.svd(H) 92 | R = np.dot(Vt.T, U.T) 93 | 94 | # special reflection case 95 | if np.linalg.det(R) < 0: 96 | Vt[m-1,:] *= -1 97 | R = np.dot(Vt.T, U.T) 98 | 99 | # translation 100 | t = centroid_B.T - np.dot(R,centroid_A.T) 101 | 102 | # homogeneous transformation 103 | T = np.identity(m+1) 104 | T[:m, :m] = R 105 | T[:m, m] = t 106 | 107 | return T, R, t 108 | 109 | 110 | def nearest_neighbor(src, dst): 111 | ''' 112 | Find the nearest (Euclidean) neighbor in dst for each point in src 113 | Input: 114 | src: Nxm array of points 115 | dst: Nxm array of points 116 | Output: 117 | distances: Euclidean distances of the nearest neighbor 118 | indices: dst indices of the nearest neighbor 119 | ''' 120 | 121 | assert src.shape == dst.shape 122 | 123 | neigh = NearestNeighbors(n_neighbors=1) 124 | neigh.fit(dst) 125 | distances, indices = neigh.kneighbors(src, return_distance=True) 126 | return distances.ravel(), indices.ravel() 127 | 128 | 129 | def icp(A, B, init_pose=None, max_iterations=100, tolerance=0.001): 130 | ''' 131 | The Iterative Closest Point method: finds best-fit transform that maps points A on to points B 132 | Input: 133 | A: Nxm numpy array of source mD points 134 | B: Nxm numpy array of destination mD point 135 | init_pose: (m+1)x(m+1) homogeneous transformation 136 | max_iterations: exit algorithm after max_iterations 137 | tolerance: convergence criteria 138 | Output: 139 | T: final homogeneous transformation that maps A on to B 140 | distances: Euclidean distances (errors) of the nearest neighbor 141 | i: number of iterations to converge 142 | ''' 143 | 144 | assert A.shape == B.shape 145 | 146 | # get number of dimensions 147 | m = A.shape[1] 148 | 149 | # make points homogeneous, copy them to maintain the originals 150 | src = np.ones((m+1,A.shape[0])) 151 | dst = np.ones((m+1,B.shape[0])) 152 | src[:m,:] = np.copy(A.T) 153 | dst[:m,:] = np.copy(B.T) 154 | 155 | # apply the initial pose estimation 156 | if init_pose is not None: 157 | src = np.dot(init_pose, src) 158 | 159 | prev_error = 0 160 | 161 | for i in range(max_iterations): 162 | # find the nearest neighbors between the current source and destination points 163 | distances, indices = nearest_neighbor(src[:m,:].T, dst[:m,:].T) 164 | 165 | # compute the transformation between the current source and nearest destination points 166 | T,_,_ = best_fit_transform(src[:m,:].T, dst[:m,indices].T) 167 | 168 | # update the current source 169 | src = np.dot(T, src) 170 | 171 | # check error 172 | mean_error = np.mean(distances) 173 | if np.abs(prev_error - mean_error) < tolerance: 174 | break 175 | prev_error = mean_error 176 | 177 | # calculate final transformation 178 | T,_,_ = best_fit_transform(A, src[:m,:].T) 179 | 180 | return T, distances, i 181 | -------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Lock 2 | import numpy as np 3 | from typing import List 4 | from utils.structure import Summary, FullTransition, ScalarSummary 5 | 6 | class MetricAccumulator(object): 7 | 8 | def __init__(self): 9 | self._previous = [] 10 | self._current = 0 11 | 12 | def update(self, value): 13 | self._current += value 14 | 15 | def next(self): 16 | self._previous.append(self._current) 17 | self._current = 0 18 | 19 | def reset(self): 20 | self._previous.clear() 21 | 22 | def min(self): 23 | return np.min(self._previous) 24 | 25 | def max(self): 26 | return np.max(self._previous) 27 | 28 | def mean(self): 29 | return np.mean(self._previous) 30 | 31 | def median(self): 32 | return np.median(self._previous) 33 | 34 | def std(self): 35 | return np.std(self._previous) 36 | 37 | def __len__(self): 38 | return len(self._previous) 39 | 40 | def __getitem__(self, i): 41 | return self._previous[i] 42 | 43 | 44 | 45 | class SimpleAccumulator: 46 | 47 | def __init__(self, prefix, mean_only: bool = True): 48 | self._prefix = prefix 49 | self._mean_only = mean_only 50 | self._lock = Lock() 51 | self._episode_returns = MetricAccumulator() 52 | self._episode_lengths = MetricAccumulator() 53 | self._summaries = [] 54 | self._transitions = 0 55 | 56 | def _reset_data(self): 57 | with self._lock: 58 | self._episode_returns.reset() 59 | self._episode_lengths.reset() 60 | self._summaries.clear() 61 | 62 | def step(self, transition: FullTransition, eval: bool): 63 | with self._lock: 64 | self._transitions += 1 65 | self._episode_returns.update(transition.reward) 66 | self._episode_lengths.update(1) 67 | if transition.terminal: 68 | self._episode_returns.next() 69 | self._episode_lengths.next() 70 | self._summaries.extend(list(transition.summaries)) 71 | 72 | def _get(self) -> List[Summary]: 73 | sums = [] 74 | 75 | if self._mean_only: 76 | stat_keys = ["mean"] 77 | else: 78 | stat_keys = ["min", "max", "mean", "median", "std"] 79 | names = ["return", "length"] 80 | metrics = [self._episode_returns, self._episode_lengths] 81 | for name, metric in zip(names, metrics): 82 | for stat_key in stat_keys: 83 | if self._mean_only: 84 | assert stat_key == "mean" 85 | sum_name = '%s/%s' % (self._prefix, name) 86 | else: 87 | sum_name = '%s/%s/%s' % (self._prefix, name, stat_key) 88 | sums.append( 89 | ScalarSummary(sum_name, getattr(metric, stat_key)())) 90 | sums.append(ScalarSummary( 91 | '%s/total_transitions' % self._prefix, self._transitions)) 92 | sums.extend(self._summaries) 93 | return sums 94 | 95 | def pop(self) -> List[Summary]: 96 | data = [] 97 | if len(self._episode_returns) > 1: 98 | data = self._get() 99 | self._reset_data() 100 | return data 101 | 102 | def peak(self) -> List[Summary]: 103 | return self._get() 104 | 105 | def reset(self): 106 | self._transitions = 0 107 | self._reset_data() 108 | 109 | 110 | class StatAccumulator: 111 | 112 | def __init__(self, mean_only: bool = True): 113 | self._train_acc = SimpleAccumulator( 114 | 'train_envs', mean_only=mean_only) 115 | self._eval_acc = SimpleAccumulator( 116 | 'eval_envs', mean_only=mean_only) 117 | 118 | def step(self, transition: FullTransition, eval: bool): 119 | if eval: 120 | self._eval_acc.step(transition, eval) 121 | else: 122 | self._train_acc.step(transition, eval) 123 | 124 | def pop(self) -> List[Summary]: 125 | return self._train_acc.pop() + self._eval_acc.pop() 126 | 127 | def peak(self) -> List[Summary]: 128 | return self._train_acc.peak() + self._eval_acc.peak() 129 | 130 | def reset(self) -> None: 131 | self._train_acc.reset() 132 | self._eval_acc.reset() -------------------------------------------------------------------------------- /utils/object.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | 4 | def flat2d(lst): 5 | return sum(lst, []) 6 | 7 | 8 | class SkipWithBlock(Exception): 9 | pass 10 | 11 | 12 | class Section: 13 | def __init__(self, *args, skip=False): 14 | self.skip = skip 15 | 16 | def __enter__(self): 17 | if self.skip: 18 | sys.settrace(lambda *args, **keys: None) 19 | frame = sys._getframe(1) 20 | frame.f_trace = self.trace 21 | 22 | def trace(self, frame, event, arg): 23 | raise SkipWithBlock() 24 | 25 | def __exit__(self, exc_type, exc_value, traceback): 26 | if exc_type is None: 27 | return # No exception 28 | if issubclass(exc_type, SkipWithBlock): 29 | return True # Suppress special SkipWithBlock exception 30 | 31 | def step(self): 32 | pass 33 | 34 | 35 | def Todo(*args): 36 | pass 37 | 38 | 39 | def split_array_into_chunks(arr, n): 40 | if n <= 0: 41 | return "Number of chunks should be greater than 0" 42 | 43 | chunk_size = len(arr) // n 44 | chunks = [arr[i:i + chunk_size] for i in range(0, len(arr), chunk_size)] 45 | 46 | # Adjusting last chunk in case of uneven division 47 | if len(chunks) > n: 48 | chunks[-2] += chunks[-1] 49 | chunks = chunks[:-1] 50 | 51 | return chunks 52 | 53 | 54 | def to_item(v): 55 | if hasattr(v, 'item'): 56 | return v.item() 57 | else: 58 | return v 59 | 60 | def detach(v): 61 | if hasattr(v, 'detach'): 62 | return v.detach() 63 | else: 64 | return v 65 | 66 | 67 | def color_terms(s, words): 68 | for w in words: 69 | s = s.replace(w, f"\033[41m\033[97m{w}\033[0m") 70 | return s 71 | 72 | 73 | 74 | def simple_mean(v): 75 | if len(v) == 0: return 0 76 | else: return sum(v) / len(v) -------------------------------------------------------------------------------- /utils/optim.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py 2 | 3 | """Lamb optimizer.""" 4 | 5 | import collections 6 | import math 7 | 8 | import torch 9 | from torch.optim import Optimizer 10 | 11 | from torch.optim.lr_scheduler import _LRScheduler 12 | from torch.optim.lr_scheduler import ReduceLROnPlateau 13 | 14 | 15 | class Lamb(Optimizer): 16 | r"""Implements Lamb algorithm. 17 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 18 | Arguments: 19 | params (iterable): iterable of parameters to optimize or dicts defining 20 | parameter groups 21 | lr (float, optional): learning rate (default: 1e-3) 22 | betas (Tuple[float, float], optional): coefficients used for computing 23 | running averages of gradient and its square (default: (0.9, 0.999)) 24 | eps (float, optional): term added to the denominator to improve 25 | numerical stability (default: 1e-8) 26 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 27 | adam (bool, optional): always use trust ratio = 1, which turns this into 28 | Adam. Useful for comparison purposes. 29 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: 30 | https://arxiv.org/abs/1904.00962 31 | """ 32 | 33 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, 34 | weight_decay=0, adam=False): 35 | if not 0.0 <= lr: 36 | raise ValueError("Invalid learning rate: {}".format(lr)) 37 | if not 0.0 <= eps: 38 | raise ValueError("Invalid epsilon value: {}".format(eps)) 39 | if not 0.0 <= betas[0] < 1.0: 40 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 41 | if not 0.0 <= betas[1] < 1.0: 42 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 43 | defaults = dict(lr=lr, betas=betas, eps=eps, 44 | weight_decay=weight_decay) 45 | self.adam = adam 46 | super(Lamb, self).__init__(params, defaults) 47 | 48 | def step(self, closure=None): 49 | """Performs a single optimization step. 50 | Arguments: 51 | closure (callable, optional): A closure that reevaluates the model 52 | and returns the loss. 53 | """ 54 | loss = None 55 | if closure is not None: 56 | loss = closure() 57 | 58 | for group in self.param_groups: 59 | for p in group['params']: 60 | if p.grad is None: 61 | continue 62 | grad = p.grad.data 63 | if grad.is_sparse: 64 | raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instad.') 65 | 66 | state = self.state[p] 67 | 68 | # State initialization 69 | if len(state) == 0: 70 | state['step'] = 0 71 | # Exponential moving average of gradient values 72 | state['exp_avg'] = torch.zeros_like(p.data) 73 | # Exponential moving average of squared gradient values 74 | state['exp_avg_sq'] = torch.zeros_like(p.data) 75 | 76 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 77 | beta1, beta2 = group['betas'] 78 | 79 | state['step'] += 1 80 | 81 | # Decay the first and second moment running average coefficient 82 | # m_t 83 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 84 | # v_t 85 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 86 | 87 | # Paper v3 does not use debiasing. 88 | # bias_correction1 = 1 - beta1 ** state['step'] 89 | # bias_correction2 = 1 - beta2 ** state['step'] 90 | # Apply bias to lr to avoid broadcast. 91 | step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1 92 | 93 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 94 | 95 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) 96 | if group['weight_decay'] != 0: 97 | adam_step.add_(p.data, alpha=group['weight_decay']) 98 | 99 | adam_norm = adam_step.pow(2).sum().sqrt() 100 | if weight_norm == 0 or adam_norm == 0: 101 | trust_ratio = 1 102 | else: 103 | trust_ratio = weight_norm / adam_norm 104 | state['weight_norm'] = weight_norm 105 | state['adam_norm'] = adam_norm 106 | state['trust_ratio'] = trust_ratio 107 | if self.adam: 108 | trust_ratio = 1 109 | 110 | p.data.add_(adam_step, alpha=-step_size * trust_ratio) 111 | 112 | return loss 113 | 114 | 115 | # source: https://github.com/ildoonet/pytorch-gradual-warmup-lr/blob/master/warmup_scheduler/scheduler.py 116 | # updated such that it is suitable for cases when epoch number start from 0 117 | # lr constantly increases from "epoch 0" to "epoch (total_epoch - 1)" such that 118 | # lr at epoch is same as the base_lr for the after_scheduler 119 | # Only tested for case when multiplier is 1.0 and after schduler is a 120 | # MultiStepLR 121 | 122 | class GradualWarmupScheduler(_LRScheduler): 123 | """Gradually warm-up(increasing) learning rate in optimizer. 124 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 125 | Args: 126 | optimizer (Optimizer): Wrapped optimizer. 127 | multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr. 128 | total_epoch: target learning rate is reached at total_epoch, gradually 129 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) 130 | """ 131 | 132 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 133 | self.multiplier = multiplier 134 | if self.multiplier < 1.0: 135 | raise ValueError("multiplier should be greater thant or equal to 1.") 136 | self.total_epoch = total_epoch 137 | self.after_scheduler = after_scheduler 138 | self.finished = False 139 | super(GradualWarmupScheduler, self).__init__(optimizer) 140 | 141 | def get_lr(self): 142 | if (self.last_epoch + 1) > self.total_epoch: 143 | if self.after_scheduler: 144 | if not self.finished: 145 | self.after_scheduler.base_lrs = [ 146 | base_lr * self.multiplier for base_lr in self.base_lrs 147 | ] 148 | self.finished = True 149 | return self.after_scheduler.get_last_lr() 150 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 151 | 152 | if self.multiplier == 1.0: 153 | return [ 154 | base_lr * ((float(self.last_epoch) + 1) / self.total_epoch) 155 | for base_lr in self.base_lrs 156 | ] 157 | else: 158 | return [ 159 | base_lr 160 | * ( 161 | (self.multiplier - 1.0) * (self.last_epoch + 1) / self.total_epoch 162 | + 1.0 163 | ) 164 | for base_lr in self.base_lrs 165 | ] 166 | 167 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 168 | if epoch is None: 169 | epoch = self.last_epoch + 1 170 | self.last_epoch = ( 171 | epoch if epoch != 0 else 1 172 | ) # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 173 | if self.last_epoch <= self.total_epoch: 174 | warmup_lr = [ 175 | base_lr 176 | * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0) 177 | for base_lr in self.base_lrs 178 | ] 179 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 180 | param_group["lr"] = lr 181 | else: 182 | if epoch is None: 183 | self.after_scheduler.step(metrics, None) 184 | else: 185 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 186 | 187 | def step(self, epoch=None, metrics=None): 188 | if type(self.after_scheduler) != ReduceLROnPlateau: 189 | if self.finished and self.after_scheduler: 190 | if epoch is None: 191 | self.after_scheduler.step(None) 192 | else: 193 | self.after_scheduler.step(epoch) 194 | self._last_lr = self.after_scheduler.get_last_lr() 195 | else: 196 | return super(GradualWarmupScheduler, self).step(epoch) 197 | else: 198 | self.step_ReduceLROnPlateau(metrics, epoch) 199 | 200 | def state_dict(self): 201 | state_dict = { 202 | key: value 203 | for key, value in self.__dict__.items() 204 | if key not in ["optimizer", "after_scheduler"] 205 | } 206 | 207 | if not (self.after_scheduler is None): 208 | state_dict["after_scheduler_state_dict"] = self.after_scheduler.state_dict() 209 | 210 | return state_dict 211 | 212 | def load_state_dict(self, state_dict): 213 | if self.after_scheduler is None: 214 | assert not ("after_scheduler_state_dict" in state_dict) 215 | else: 216 | self.after_scheduler.load_state_dict( 217 | state_dict["after_scheduler_state_dict"] 218 | ) 219 | del state_dict["after_scheduler_state_dict"] 220 | 221 | self.__dict__.update(state_dict) 222 | -------------------------------------------------------------------------------- /utils/rollout.py: -------------------------------------------------------------------------------- 1 | from multiprocessing import Value 2 | import numpy as np 3 | import torch 4 | from utils.structure import Agent, Env, FullTransition, ActResult 5 | 6 | 7 | class RolloutGenerator(object): 8 | 9 | def __init__(self, device = 'cuda:0'): 10 | self._env_device = device 11 | 12 | def _get_type(self, x): 13 | if not hasattr(x, 'dtype'): return np.float32 14 | if x.dtype == np.float64: 15 | return np.float32 16 | return x.dtype 17 | 18 | def generator(self, step_signal: Value, 19 | env: Env, agent: Agent, 20 | episode_length: int, 21 | eval: bool, eval_seed: int = 0, 22 | record_enabled: bool = False): 23 | 24 | if eval: 25 | obs = env.reset_to_demo(eval_seed) 26 | else: 27 | obs = env.reset() 28 | agent.reset(task={v:k for k,v in env._task_name_to_idx.items()}[env._active_task_id], desc=env._lang_goal) 29 | obs_history = {k: np.array(v, dtype=self._get_type(v)) if not isinstance(v, dict) else v for k, v in obs.items()} 30 | for step in range(episode_length): 31 | # add batch dimension 32 | prepped_data = {k: torch.tensor(v, device=self._env_device)[None, ...] if not isinstance(v, dict) else v 33 | for k, v in obs_history.items()} 34 | act_result = agent.act(step_signal.value, prepped_data) 35 | 36 | if act_result is None: 37 | return 38 | 39 | # Convert to np if not already 40 | agent_obs_elems = {k: np.array(v) for k, v in act_result.observation_elements.items()} 41 | extra_replay_elements = {k: np.array(v) for k, v in act_result.replay_elements.items()} 42 | 43 | transition = env.step(act_result) 44 | obs_tp1 = dict(transition.observation) 45 | timeout = False 46 | if step == episode_length - 1: 47 | # If last transition, and not terminal, then we timed out 48 | timeout = not transition.terminal 49 | if timeout: 50 | transition.terminal = True 51 | if "needs_reset" in transition.info: 52 | transition.info["needs_reset"] = True 53 | 54 | obs_and_replay_elems = {} 55 | obs_and_replay_elems.update(obs) 56 | obs_and_replay_elems.update(agent_obs_elems) 57 | obs_and_replay_elems.update(extra_replay_elements) 58 | 59 | for k in obs_history.keys(): 60 | obs_history[k] = transition.observation[k] 61 | 62 | transition.info["active_task_id"] = env.active_task_id 63 | 64 | replay_transition = FullTransition( 65 | obs_and_replay_elems, act_result.action, transition.reward, 66 | transition.terminal, timeout, summaries=transition.summaries, 67 | info=transition.info) 68 | 69 | if transition.terminal or timeout: 70 | # If the agent gives us observations then we need to call act 71 | # one last time (i.e. acting in the terminal state). 72 | if len(act_result.observation_elements) > 0: 73 | prepped_data = {k: torch.tensor([v], device=self._env_device) if not isinstance(v, dict) else v 74 | for k, v in obs_history.items()} 75 | act_result = agent.act(step_signal.value, prepped_data) 76 | agent_obs_elems_tp1 = {k: np.array(v) for k, v in act_result.observation_elements.items()} 77 | obs_tp1.update(agent_obs_elems_tp1) 78 | replay_transition.final_observation = obs_tp1 79 | 80 | if record_enabled and transition.terminal or timeout or step == episode_length - 1: 81 | env.env._action_mode.arm_action_mode.record_end(env.env._scene, steps=60, step_scene=True) 82 | obs = dict(transition.observation) 83 | 84 | yield replay_transition 85 | 86 | if transition.info.get("needs_reset", transition.terminal): 87 | return -------------------------------------------------------------------------------- /utils/str.py: -------------------------------------------------------------------------------- 1 | from functools import reduce 2 | 3 | 4 | def insert_uline_before_cap(str): 5 | return reduce(lambda x, y: x + ('_' if y.isupper() else '') + y, str).lower() -------------------------------------------------------------------------------- /utils/structure.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | import lzma as lzmalib 3 | import os 4 | import os.path as osp 5 | import numpy as np 6 | from dataclasses import dataclass, field 7 | from enum import Enum 8 | import pickle 9 | import json 10 | from abc import ABC, abstractmethod 11 | import base64 12 | from hashlib import blake2b 13 | 14 | 15 | LOW_DIM_PICKLE = 'low_dim_obs.pkl' 16 | KEYPOINT_JSON = "keypoints.json" 17 | VARIATION_NUMBER_PICKLE = 'variation_number.pkl' 18 | LANG_GOAL_EMB = "lang_emb.pkl" 19 | DESC_PICKLE = "variation_descriptions.pkl" 20 | 21 | 22 | BASE_RLBENCH_TASKS = [ 23 | "put_item_in_drawer", 24 | "reach_and_drag", 25 | "turn_tap", 26 | "slide_block_to_color_target", 27 | "open_drawer", 28 | "put_groceries_in_cupboard", 29 | "place_shape_in_shape_sorter", 30 | "put_money_in_safe", 31 | "push_buttons", 32 | "close_jar", 33 | "stack_blocks", 34 | "place_cups", 35 | "place_wine_at_rack_location", 36 | "light_bulb_in", 37 | "sweep_to_dustpan_of_size", 38 | "insert_onto_square_peg", 39 | "meat_off_grill", 40 | "stack_cups", 41 | ] 42 | 43 | NOVEL_RLBENCH_TASKS = [ 44 | "basketball_in_hoop", 45 | "put_rubbish_in_bin", 46 | "scoop_with_spatula", 47 | "place_hanger_on_rack", 48 | "hit_ball_with_queue", 49 | "block_pyramid", 50 | "take_lid_off_saucepan", 51 | "lamp_on", 52 | "phone_on_base", 53 | "open_box", 54 | "close_laptop_lid", 55 | "beat_the_buzz", 56 | "remove_cups", 57 | "play_jenga", 58 | "put_knife_on_chopping_board", 59 | "straighten_rope", 60 | "change_clock", 61 | "open_wine_bottle", 62 | "open_door", 63 | "take_money_out_safe", 64 | "close_microwave", 65 | "slide_cabinet_open_and_place_cups" 66 | ] 67 | 68 | 69 | 70 | def load_pkl(fp, lzma=False): 71 | if lzma: 72 | with lzmalib.open(fp, 'rb') as f: 73 | return pickle.load(f) 74 | else: 75 | with open(fp, 'rb') as f: 76 | return pickle.load(f) 77 | 78 | def as_list(x): 79 | if isinstance(x, list): 80 | return x 81 | elif isinstance(x, tuple): 82 | return list(x) 83 | elif hasattr(x, 'tolist'): 84 | return x.tolist() 85 | else: 86 | return [x] 87 | 88 | 89 | def load_json(fp): 90 | with open(fp, 'r') as f: 91 | return json.load(f) 92 | 93 | 94 | def ensure_dir(d): 95 | if d: os.makedirs(d, exist_ok=True) 96 | 97 | def dump_pkl(fp, obj, lzma=False): 98 | ensure_dir(osp.dirname(fp)) 99 | if lzma: 100 | with lzmalib.open(fp, 'wb') as f: 101 | return pickle.dump(obj, f) 102 | else: 103 | with open(fp, 'wb') as f: 104 | return pickle.dump(obj, f) 105 | 106 | 107 | def dump_json(fp, obj, **kwargs): 108 | ensure_dir(osp.dirname(fp)) 109 | with open(fp, 'w') as f: 110 | return json.dump(obj, f, **kwargs) 111 | 112 | @dataclass 113 | class ActResult: 114 | action: Any 115 | observation_elements: dict = field(default_factory=dict) 116 | replay_elements: dict = field(default_factory=dict) 117 | info: dict = field(default_factory=dict) 118 | 119 | @dataclass 120 | class ObservationElement: 121 | name: str 122 | shape: tuple 123 | type: Any 124 | 125 | 126 | @dataclass 127 | class Summary: 128 | name: str 129 | value: Any 130 | 131 | @dataclass 132 | class TextSummary(Summary): pass 133 | 134 | ScalarSummary = TextSummary 135 | 136 | @dataclass 137 | class ImageSummary(Summary): pass 138 | 139 | @dataclass 140 | class VideoSummary: 141 | name: str 142 | value: Any 143 | fps: int = 30 144 | 145 | 146 | @dataclass 147 | class Transition: 148 | observation: dict 149 | reward: float 150 | terminal: bool 151 | info: dict = field(default_factory=dict) 152 | summaries: List[Summary] = field(default_factory=list) 153 | 154 | @dataclass 155 | class FullTransition: 156 | observation: dict 157 | action: np.ndarray 158 | reward: float 159 | terminal: bool 160 | timeout: bool 161 | summaries: List[Summary] = field(default_factory=list) 162 | info: dict = field(default_factory=dict) 163 | 164 | 165 | ROBOT_STATE_KEYS = ['joint_velocities', 'joint_positions', 'joint_forces', 166 | 'gripper_open', 'gripper_pose', 167 | 'gripper_joint_positions', 'gripper_touch_forces', 168 | 'task_low_dim_state', 'misc'] 169 | 170 | 171 | 172 | 173 | class Env(ABC): 174 | 175 | def __init__(self): 176 | self._active_task_id = 0 177 | self._eval_env = False 178 | 179 | @property 180 | def eval(self): 181 | return self._eval_env 182 | 183 | @eval.setter 184 | def eval(self, eval): 185 | self._eval_env = eval 186 | 187 | @property 188 | def active_task_id(self) -> int: 189 | return self._active_task_id 190 | 191 | @abstractmethod 192 | def launch(self) -> None: 193 | pass 194 | 195 | def shutdown(self) -> None: 196 | pass 197 | 198 | @abstractmethod 199 | def reset(self) -> dict: 200 | pass 201 | 202 | @abstractmethod 203 | def step(self, action: np.ndarray) -> Transition: 204 | pass 205 | 206 | @property 207 | @abstractmethod 208 | def observation_elements(self) -> List[ObservationElement]: 209 | pass 210 | 211 | @property 212 | @abstractmethod 213 | def action_shape(self) -> tuple: 214 | pass 215 | 216 | @property 217 | @abstractmethod 218 | def env(self) -> Any: 219 | pass 220 | 221 | @property 222 | @abstractmethod 223 | def num_tasks(self) -> int: 224 | pass 225 | 226 | 227 | 228 | 229 | class Agent(ABC): 230 | 231 | @abstractmethod 232 | def build(self, training: bool, device=None) -> None: 233 | pass 234 | 235 | @abstractmethod 236 | def update(self, step: int, replay_sample: dict, **kwargs) -> dict: 237 | pass 238 | 239 | @abstractmethod 240 | def act(self, step: int, observation: dict, **kwargs) -> ActResult: 241 | # returns dict of values that get put in the replay. 242 | # One of these must be 'action'. 243 | pass 244 | 245 | def reset(self) -> None: 246 | pass 247 | 248 | def reset_to_demo(self, i: int, variation_number: int=-1) -> None: 249 | pass 250 | 251 | 252 | 253 | @dataclass 254 | class DataElement: 255 | name: str 256 | shape: tuple 257 | type: Any 258 | is_observation: bool = False 259 | 260 | 261 | 262 | def hash_object(*objs): 263 | h = blake2b(digest_size=20) 264 | h.update(pickle.dumps(objs)) 265 | return h.hexdigest() 266 | 267 | -------------------------------------------------------------------------------- /utils/transfer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from tqdm import tqdm 4 | import numpy as np 5 | from scipy.spatial.transform import Rotation 6 | from typing import List, Optional 7 | 8 | arm_mask_codes = [31, 34, 35, 39, 40, 41, 42, 43, 44, 45, 46] 9 | table_mask_codes = [48, 52] 10 | bg_mask_codes = [10, 55] 11 | scene_bounds = [-0.3, -0.5, 0.6, 0.7, 0.5, 1.6] 12 | 13 | 14 | def clean_mask(mask): 15 | for a in arm_mask_codes: 16 | mask[mask == a] = 0 17 | for a in table_mask_codes: 18 | mask[mask == a] = 1 19 | return mask 20 | 21 | def keep_valid_pcd(pc, other, scene_bounds): 22 | x_min, y_min, z_min, x_max, y_max, z_max = scene_bounds 23 | inv_pnt = ( # invalid points 24 | (pc[:, 0] < x_min) 25 | | (pc[:, 0] > x_max) 26 | | (pc[:, 1] < y_min) 27 | | (pc[:, 1] > y_max) 28 | | (pc[:, 2] < z_min) 29 | | (pc[:, 2] > z_max) 30 | | np.isnan(pc[:, 0]) 31 | | np.isnan(pc[:, 1]) 32 | | np.isnan(pc[:, 2]) 33 | ) 34 | return pc[~inv_pnt], tuple([x[~inv_pnt] for x in other]) 35 | 36 | 37 | def normalize_within_bounds(pc, scene_bounds): 38 | x_min, y_min, z_min, x_max, y_max, z_max = scene_bounds 39 | pc = pc.copy() 40 | pc[:, 0] = (pc[:, 0] - x_min) / (x_max - x_min) 41 | pc[:, 1] = (pc[:, 1] - y_min) / (y_max - y_min) 42 | pc[:, 2] = (pc[:, 2] - z_min) / (z_max - z_min) 43 | return pc 44 | 45 | 46 | def gripper_pose_2_frame(gripper_pose, scale=0.25): 47 | t = normalize_within_bounds(gripper_pose[None, :3], scene_bounds).flatten() 48 | dcm = Rotation.from_quat(gripper_pose[3:]).as_matrix() 49 | dcm *= scale 50 | return t, dcm[0, :] + t, dcm[1, :] + t, dcm[2, :] + t 51 | 52 | 53 | def assemble_point_cloud(m): 54 | pcd = [] 55 | rgb = [] 56 | mask = [] 57 | for c in CAMERAS: 58 | pcd.append(m[f"{c}_point_cloud"].reshape(-1, 3)) 59 | rgb.append(m[f"{c}_rgb"].reshape(-1, 3)) 60 | mask.append(m[f"{c}_mask"].reshape(-1, 3)) 61 | pcd = np.concatenate(pcd) 62 | rgb = np.concatenate(rgb) 63 | mask = np.concatenate(mask) 64 | pcd, (rgb, mask) = keep_valid_pcd(pcd, (rgb, mask), scene_bounds) 65 | pcd = normalize_within_bounds(pcd, scene_bounds) 66 | return pcd, rgb, clean_mask(mask[:, 0]) 67 | 68 | 69 | def transfer_gripper_pose( 70 | pcd1, 71 | pcd2, 72 | mask1, 73 | mask2, 74 | object_id_1, 75 | object_id_2, 76 | gripper_pose_1, 77 | origin_frame_scale=0.25, 78 | ): 79 | obj_pts_1 = pcd1[mask1 == object_id_1] 80 | obj_pts_2 = pcd2[mask2 == object_id_2] 81 | 82 | min_len = min(len(obj_pts_1), len(obj_pts_2)) 83 | 84 | if len(obj_pts_1) > min_len: 85 | obj_pts_1 = obj_pts_1[np.random.permutation(len(obj_pts_1))[:min_len]] 86 | 87 | if len(obj_pts_2) > min_len: 88 | obj_pts_2 = obj_pts_2[np.random.permutation(len(obj_pts_2))[:min_len]] 89 | 90 | T, _, _ = icp(obj_pts_1, obj_pts_2) 91 | 92 | # translation 93 | pose1_t = normalize_within_bounds(gripper_pose_1[None, :3], scene_bounds).flatten() 94 | new_t = (T @ np.concatenate([pose1_t, np.ones(1)]))[:3] 95 | 96 | new_R = T[:3, :3] @ Rotation.from_quat(gripper_pose_1[3:]).as_matrix() 97 | 98 | result = { 99 | "gripper_pose": np.concatenate([new_t, Rotation.from_matrix(new_R).as_quat()]), 100 | "gripper_frame": [ 101 | new_t, 102 | new_R[0] * origin_frame_scale + new_t, 103 | new_R[1] * origin_frame_scale + new_t, 104 | new_R[2] * origin_frame_scale + new_t, 105 | ], 106 | } 107 | return result 108 | 109 | 110 | def get_applicable_frame_idxes(kps, win_size=10): 111 | frames = [] 112 | prev_kp = 0 113 | 114 | for k in kps: 115 | if k - win_size > prev_kp: 116 | frames += list(range(prev_kp, k - win_size)) 117 | else: 118 | frames.append(prev_kp) 119 | prev_kp = k 120 | 121 | frames.append(prev_kp) 122 | return frames 123 | 124 | 125 | if __name__ == "__main__": 126 | from PIL import Image 127 | from utils.structure import BASE_RLBENCH_TASKS 128 | root = "/scratch/xz653/code/RVT/data/train" 129 | 130 | print(get_applicable_frame_idxes([43, 80, 85, 100])) 131 | 132 | # if False: 133 | # for task in tqdm(BASE_RLBENCH_TASKS): 134 | # print(f'task - {task}') 135 | # episode_path = osp.join(root, task, 'all_variations/episodes/episode0') 136 | # mask = np.array(Image.open(osp.join(episode_path, f"front_mask", '1.png'))).reshape(-1, 3)[:, 0] 137 | # print(f'\t before: {set(np.unique(mask).tolist())}') 138 | # mask = clean_mask(mask) 139 | # mask_inds = np.unique(mask).tolist() 140 | # for m in mask_inds: 141 | # if len(str(m)) < 3 and m not in [0, 1]: 142 | # print(m) 143 | # print(f'\t after: {set(mask_inds)}') 144 | 145 | -------------------------------------------------------------------------------- /utils/vis.py: -------------------------------------------------------------------------------- 1 | import random 2 | import ipyvolume as ipv 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | import matplotlib.colors as mcolors 6 | 7 | 8 | def get_color_map(arr, shuffle=True): 9 | colors = dict(mcolors.CSS4_COLORS) 10 | colors.pop("white") 11 | colors = sorted(list(colors.items())) 12 | colors = [c for _, c in colors] 13 | if shuffle: 14 | random.shuffle(colors) 15 | nums = np.unique(arr) 16 | return dict(zip(nums, colors)) 17 | 18 | 19 | def show_pcd(pcd, rgb=None, frame=None, mask=None, autoscale=True, autoscale_rgb="auto", frame_color=None, save=None, title="", save_kwargs={}, return_view_scale=False, vis_kwargs={}, with_axis=False): 20 | x, y, z = pcd[:, 0], pcd[:, 1], pcd[:, 2] 21 | 22 | if isinstance(autoscale, list): 23 | x_min, y_min, z_min, x_max, y_max, z_max = autoscale 24 | else: 25 | x_min, x_max = x.min(), x.max() 26 | y_min, y_max = y.min(), y.max() 27 | z_min, z_max = z.min(), z.max() 28 | 29 | if autoscale_rgb == "auto" and rgb is not None: 30 | autoscale_rgb = rgb.max() >= 10 or rgb.min() <= -0.5 # then it is probably 255-based color 31 | 32 | 33 | kwargs = {} 34 | if rgb is not None: 35 | kwargs["color"] = rgb 36 | if autoscale_rgb: 37 | if rgb.max() >= 10: 38 | kwargs["color"] = kwargs["color"].astype(float) / 255.0 39 | if rgb.min() <= -0.5: 40 | kwargs["color"] = (kwargs["color"] + 1) / 2 41 | 42 | if autoscale: 43 | # import warnings 44 | 45 | # with warnings.catch_warnings(): 46 | # warnings.filterwarnings('error') 47 | # try: 48 | x, y, z = (x - x_min) / (x_max - x_min), (y - y_min) / (y_max - y_min), (z - z_min) / (z_max - z_min) 49 | # except Warning as e: 50 | # print(1) 51 | 52 | if mask is not None: 53 | fig = ipv.figure() 54 | cmap = get_color_map(mask) 55 | for mask_ind, color in cmap.items(): 56 | inds = mask == mask_ind 57 | if rgb is not None: 58 | if isinstance(rgb, dict): 59 | color = rgb[mask_ind] 60 | else: 61 | color = rgb[inds] 62 | ipv.scatter( 63 | x[inds], 64 | y[inds], 65 | z[inds], 66 | size=1, 67 | color=color, 68 | description=str(mask_ind), 69 | **vis_kwargs 70 | ) 71 | else: 72 | fig = ipv.quickscatter(x, y, z, size=1, description=title.split('/')[-1] or "Pointcloud", **kwargs, **vis_kwargs) 73 | 74 | min_v = np.array([x_min, y_min, z_min]) 75 | max_v = np.array([x_max, y_max, z_max]) 76 | if frame is not None: 77 | if isinstance(frame[0][0], np.ndarray): 78 | # list of frames 79 | if autoscale: 80 | frame = [[(a - min_v) / (max_v - min_v) for a in f] for f in frame] 81 | if isinstance(frame_color, list): 82 | it = zip(frame_color, frame) 83 | else: 84 | it = zip([frame_color] * len(frame), frame) 85 | for fc, f in it: 86 | draw_frame(*f, color=fc) 87 | else: 88 | if autoscale: 89 | frame = [(a - min_v) / (max_v - min_v) for a in frame] 90 | draw_frame(*frame, color=frame_color) 91 | 92 | 93 | if not with_axis: 94 | ipv.style.use('minimal') 95 | 96 | if save: 97 | ipv.save(save, title=title or save, **save_kwargs) 98 | 99 | 100 | if return_view_scale: 101 | return ipv.gcc(), [min_v, max_v] 102 | else: 103 | return ipv.gcc() 104 | 105 | 106 | # def scale_pts_for_view(pts, view_scale): 107 | # min_v, max_v = [a.reshape(1, -1) for a in view_scale] 108 | # return (pts - min_v) / (max_v - min_v) 109 | 110 | 111 | def draw_frame(o, x, y, z, color=None): 112 | # draw_frame([0.1, 0.1, 0.1], [0.5, 0.1, 0.1], [0.1, 0.5, 0.1], [0.1, 0.1, 0.5]) 113 | default_colors = [ 114 | np.array([[255, 255, 0], [255, 0, 0], [0, 255, 0], [0, 0, 255]]).astype( 115 | float) / 255, 116 | "red", "green", "blue" 117 | ] 118 | if color is None: 119 | colors = default_colors 120 | elif isinstance(color, (float, int)): 121 | colors = [c * color if isinstance(c, np.ndarray) else c for c in default_colors] 122 | else: 123 | colors = [color, ] * 4 124 | 125 | ipv.scatter( 126 | np.array([o[0], x[0], y[0], z[0]]), 127 | np.array([o[1], x[1], y[1], z[1]]), 128 | np.array([o[2], x[2], y[2], z[2]]), 129 | color=colors[0], 130 | marker="sphere", 131 | size=3, 132 | ) 133 | ipv.plot( 134 | np.array([o[0], x[0]]), 135 | np.array([o[1], x[1]]), 136 | np.array([o[2], x[2]]), 137 | color=colors[1], 138 | ) 139 | ipv.plot( 140 | np.array([o[0], y[0]]), 141 | np.array([o[1], y[1]]), 142 | np.array([o[2], y[2]]), 143 | color=colors[2], 144 | ) 145 | ipv.plot( 146 | np.array([o[0], z[0]]), 147 | np.array([o[1], z[1]]), 148 | np.array([o[2], z[2]]), 149 | color=colors[3], 150 | ) 151 | 152 | 153 | def draw_knn_point(source, endpoints, color='red', size=2): 154 | """source: (3), endpoints: (k,3) """ 155 | 156 | ipv.scatter( 157 | np.concatenate([source[:1], endpoints[:, 0]]), 158 | np.concatenate([source[1:2], endpoints[:, 1]]), 159 | np.concatenate([source[2:], endpoints[:, 2]]), 160 | color=color, 161 | marker="sphere", 162 | size=3, 163 | ) 164 | 165 | for ep in endpoints: 166 | ipv.plot( 167 | np.array([source[0], ep[0]]), 168 | np.array([source[1], ep[1]]), 169 | np.array([source[2], ep[2]]), 170 | color=color, 171 | ) 172 | 173 | 174 | 175 | def draw_ball(center, radius, N=100): 176 | x0, y0, z0 = center 177 | fig = ipv.gcf() 178 | (xmin, xmax), (ymin, ymax), (zmin, zmax) = fig.xlim, fig.ylim, fig.zlim 179 | xstep, ystep, zstep = (xmax - xmin) / N, (ymax - ymin) / N, (zmax - zmin) / N 180 | x, y, z = np.ogrid[xmin:xmax:xstep, ymin:ymax:ystep, zmin:zmax:zstep] 181 | r = np.sqrt((x - x0)**2 + (y - y0)**2 + (z - z0)**2) 182 | r[r > radius] = 0 183 | ipv.volshow(r.T, extent=[(xmin, xmax), (ymin, ymax), (zmin, zmax)]) 184 | 185 | 186 | 187 | def to_named_masks(mask, id2names, key = None): 188 | vis_mask = np.empty((len(mask), ), dtype=object) 189 | for mask_id, name in id2names.items(): 190 | tag = f'{name} ({mask_id})' 191 | if mask_id == key: tag += '【*】' 192 | vis_mask[mask == mask_id] = tag 193 | return vis_mask --------------------------------------------------------------------------------