├── .gitignore ├── .gitmodules ├── LICENSE ├── configs └── train.yaml ├── dataloaders ├── __init__.py ├── open_classification.py ├── real_dataset.py ├── record3d.py └── scannet_200_classes.py ├── demo ├── 1 - parse rgbd.ipynb ├── 2 - label with pretrained models.ipynb ├── 3 - training a CLIP field.ipynb └── 4 - test model.ipynb ├── docs ├── css │ └── simple-grid.css ├── index.html ├── kitchen.html ├── mfiles │ ├── 1.mp4 │ ├── 2.mp4 │ ├── 3.mp4 │ ├── Exp Table.001.png │ ├── arch bet.001.png │ ├── arch.avif │ ├── arch.jpeg │ ├── arch.png │ ├── behavior_transformers.mp4 │ ├── data_processing.avif │ ├── data_processing.jpeg │ ├── data_processing.png │ ├── env │ │ ├── carla │ │ │ ├── 1_obs.mp4 │ │ │ ├── 1_over.mp4 │ │ │ ├── 2_obs.mp4 │ │ │ └── 2_over.mp4 │ │ ├── kitchen │ │ │ ├── 1.mp4 │ │ │ ├── 2.mp4 │ │ │ └── all │ │ │ │ ├── kblh_small.mp4 │ │ │ │ ├── kbls_small.mp4 │ │ │ │ ├── kbth_small.mp4 │ │ │ │ ├── kbtl_small.mp4 │ │ │ │ ├── mklh_small.mp4 │ │ │ │ ├── mkls_small.mp4 │ │ │ │ ├── mkth_small.mp4 │ │ │ │ └── mktl_small.mp4 │ │ └── pushblock │ │ │ ├── 1.mp4 │ │ │ ├── 2.mp4 │ │ │ └── all │ │ │ ├── g1gg.mp4 │ │ │ ├── g1gg2.mp4 │ │ │ ├── g1gg3.mp4 │ │ │ ├── g1gr.mp4 │ │ │ ├── g1gr2.mp4 │ │ │ ├── g1gr3.mp4 │ │ │ ├── r1rg.mp4 │ │ │ ├── r1rg2.mp4 │ │ │ ├── r1rg3.mp4 │ │ │ ├── r1rr.mp4 │ │ │ ├── r1rr2.mp4 │ │ │ └── r1rr3.mp4 │ ├── exp.png │ ├── kitchen_1cm.pcd │ ├── model_nyu_kitchen_fill_out_water_bottle.pcd │ ├── model_nyu_kitchen_make_some_coffee.pcd │ ├── model_nyu_kitchen_warm_up_my_lunch.pcd │ ├── multimodal_colorbar_flipped-1.png │ ├── multimodal_colorbar_flipped.pdf │ ├── nyu_kitchen_throw_my_trash.pcd │ ├── nyu_robot_run_clipped_small.mp4 │ ├── pit_robot_run_clipped_small.mp4 │ ├── query_navigation.avif │ ├── query_navigation.jpg │ └── training.mp4 └── more │ ├── bibtex.txt │ ├── blockpush │ └── index.html │ └── kitchen │ └── index.html ├── grid_hash_model.py ├── gridencoder ├── __init__.py ├── backend.py ├── grid.py ├── setup.py └── src │ ├── bindings.cpp │ ├── gridencoder.cu │ └── gridencoder.h ├── misc.py ├── readme.md ├── requirements.txt └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__ 3 | notebooks 4 | outputs 5 | wandb 6 | training.log 7 | core 8 | 9 | # Data files 10 | *.r3d 11 | *.zip 12 | *.pt 13 | *.ply 14 | *.pcd 15 | *.jpg 16 | *.pt 17 | *.pth 18 | 19 | # Slurm files 20 | *.sbatch 21 | wrapper.sh 22 | multirun 23 | 24 | # C/C++ bindings files 25 | *.so 26 | *.o 27 | *.egg-info 28 | build 29 | dist 30 | 31 | .cache 32 | .ipynb_checkpoints 33 | .pytest_cache -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "Detic"] 2 | path = Detic 3 | url = git@github.com:notmahi/Detic.git 4 | [submodule "LSeg"] 5 | path = LSeg 6 | url = git@github.com:notmahi/LSeg.git 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Nur Muhammad "Mahi" Shafiullah 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to 
permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/train.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - _self_ 3 | - override hydra/launcher: submitit_slurm 4 | 5 | project: clip_field 6 | deterministic_id: false 7 | device: cuda 8 | use_cache: false 9 | batch_size: 12544 10 | 11 | # Dataset details 12 | dataset_path: nyu.r3d 13 | cache_result: true 14 | cache_path: detic_labeled_dataset.pt 15 | saved_dataset_path: detic_labeled_dataset.pt 16 | 17 | # Data loading and labelling specs 18 | sample_freq: 5 19 | detic_threshold: 0.4 20 | subsample_prob: 0.2 21 | use_lseg: false 22 | use_extra_classes: false 23 | use_gt_classes_in_detic: true 24 | 25 | # Neural field specs 26 | model_type: hash 27 | num_grid_levels: 18 28 | level_dim: 8 # So total dimension 144 29 | per_level_scale: 2 30 | mlp_depth: 1 31 | mlp_width: 600 32 | log2_hashmap_size: 20 33 | 34 | # Training specs 35 | seed: 42 36 | epochs: 100 37 | exp_decay_coeff: 0.5 38 | image_to_label_loss_ratio: 1.0 39 | label_to_image_loss_ratio: 1.0 40 | instance_loss_scale: 100.0 41 | epoch_size: 3e6 42 | dataparallel: false 43 | num_workers: 10 44 | 45 | # Debug purposes, visualize Detic results 46 | visualize_detic_results: false 47 | detic_visualization_path: "detic_debug" 48 | 49 | # Cache only runs are for building per-dataset caches, which can be used for multi-run later. 
50 | cache_only_run: false 51 | 52 | # Learning rate data 53 | lr: 1e-4 54 | weight_decay: 0.003 55 | betas: 56 | - 0.9 57 | - 0.999 58 | 59 | save_directory: "clip_implicit_model" 60 | 61 | web_models: 62 | clip: "ViT-B/32" 63 | sentence: "all-mpnet-base-v2" 64 | 65 | hydra: 66 | callbacks: 67 | log_job_return: 68 | _target_: hydra.experimental.callbacks.LogJobReturnCallback 69 | 70 | launcher: 71 | timeout_min: 180 72 | cpus_per_task: 10 73 | gpus_per_node: 1 74 | tasks_per_node: 1 75 | mem_gb: 128 76 | nodes: 1 77 | name: ${hydra.job.name} 78 | _target_: hydra_plugins.hydra_submitit_launcher.submitit_launcher.SlurmLauncher 79 | signal_delay_s: 120 80 | max_num_timeout: 1 81 | additional_parameters: {} 82 | array_parallelism: 256 83 | setup: null 84 | partition: learnfair 85 | 86 | # Add any custom labels you want here 87 | custom_labels: 88 | - kitchen counter 89 | - kitchen cabinet 90 | - stove 91 | - cabinet 92 | - bathroom counter 93 | - refrigerator 94 | - microwave 95 | - oven 96 | - fireplace 97 | - door 98 | - sink 99 | - furniture 100 | - dish rack 101 | - dining table 102 | - shelf 103 | - bar 104 | - dishwasher 105 | - toaster oven 106 | - toaster 107 | - mini fridge 108 | - soap dish 109 | - coffee maker 110 | - table 111 | - bowl 112 | - rack 113 | - bulletin board 114 | - water cooler 115 | - coffee kettle 116 | - lamp 117 | - plate 118 | - window 119 | - dustpan 120 | - trash bin 121 | - ceiling 122 | - doorframe 123 | - trash can 124 | - basket 125 | - wall 126 | - bottle 127 | - broom 128 | - bin 129 | - paper 130 | - storage container 131 | - box 132 | - tray 133 | - whiteboard 134 | - decoration 135 | - board 136 | - cup 137 | - windowsill 138 | - potted plant 139 | - light 140 | - machine 141 | - fire extinguisher 142 | - bag 143 | - paper towel roll 144 | - chair 145 | - book 146 | - fire alarm 147 | - blinds 148 | - crate 149 | - tissue box 150 | - towel 151 | - paper bag 152 | - column 153 | - fan 154 | - object 155 | - range hood 156 | - plant 157 | - structure 158 | - poster 159 | - mat 160 | - water bottle 161 | - power outlet 162 | - storage bin 163 | - radiator 164 | - picture 165 | - water pitcher 166 | - pillar 167 | - light switch 168 | - bucket 169 | - storage organizer 170 | - vent 171 | - counter 172 | - ceiling light 173 | - case of water bottles 174 | - pipe 175 | - scale 176 | - recycling bin 177 | - clock 178 | - sign 179 | - folded chair 180 | - power strip 181 | 182 | # Or just comment it out for SCANNET 200 labels. 
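# Illustrative only (assuming train.py is the Hydra entry point for this config):
# individual fields can also be overridden from the command line instead of editing
# this file, e.g. to point at your own scan and fall back to the ScanNet-200 label set:
#
#   python train.py dataset_path=path/to/your_scan.r3d custom_labels=null use_lseg=false
#
# The flags above are examples; check train.py for the options it actually consumes.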
183 | # custom_labels: null -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from dataloaders.record3d import R3DSemanticDataset 2 | from dataloaders.open_classification import ClassificationExtractor 3 | from dataloaders.real_dataset import DeticDenseLabelledDataset 4 | -------------------------------------------------------------------------------- /dataloaders/open_classification.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import clip 4 | from sentence_transformers import SentenceTransformer 5 | from typing import List 6 | 7 | 8 | class ClassificationExtractor: 9 | PROMPT = "A " 10 | EMPTY_CLASS = "Other" 11 | LOGIT_TEMP = 100.0 12 | 13 | def __init__( 14 | self, 15 | clip_model_name: str, 16 | sentence_model_name: str, 17 | class_names: List[str], 18 | device: str = "cuda", 19 | image_weight: float = 1.0, 20 | label_weight: float = 5.0, 21 | ): 22 | clip_model, _ = clip.load(clip_model_name, device=device) 23 | sentence_model = SentenceTransformer(sentence_model_name, device=device) 24 | 25 | # Adding this class in the beginning since the labels are 1-indexed. 26 | text_strings = [] 27 | for name in class_names: 28 | text_strings.append(self.PROMPT + name.replace("-", " ").replace("_", " ")) 29 | with torch.no_grad(): 30 | all_embedded_text = sentence_model.encode(text_strings) 31 | all_embedded_text = torch.from_numpy(all_embedded_text).float().to(device) 32 | 33 | with torch.no_grad(): 34 | text = clip.tokenize(text_strings).to(device) 35 | clip_encoded_text = clip_model.encode_text(text).float().to(device) 36 | 37 | del clip_model 38 | del sentence_model 39 | 40 | self.class_names = text_strings 41 | self.total_label_classes = len(text_strings) 42 | self._sentence_embed_size = all_embedded_text.size(-1) 43 | self._clip_embed_size = clip_encoded_text.size(-1) 44 | 45 | self._sentence_features = F.normalize(all_embedded_text, p=2, dim=-1) 46 | self._clip_text_features = F.normalize(clip_encoded_text, p=2, dim=-1) 47 | 48 | self._image_weight = image_weight 49 | self._label_weight = label_weight 50 | 51 | def calculate_classifications( 52 | self, model_text_features: torch.Tensor, model_image_features: torch.Tensor 53 | ): 54 | # Figure out the classification given the learned embedding of the objects. 55 | assert model_text_features.size(-1) == self._sentence_embed_size 56 | assert model_image_features.size(-1) == self._clip_embed_size 57 | 58 | # Now do the softmax over the classes. 59 | model_text_features = F.normalize(model_text_features, p=2, dim=-1) 60 | model_image_features = F.normalize(model_image_features, p=2, dim=-1) 61 | 62 | with torch.no_grad(): 63 | text_logits = model_text_features @ self._sentence_features.T 64 | image_logits = model_image_features @ self._clip_text_features.T 65 | 66 | assert text_logits.size(-1) == self.total_label_classes 67 | assert image_logits.size(-1) == self.total_label_classes 68 | 69 | # Figure out weighted sum of probabilities. 
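        # The two heads are fused as a weighted average of per-head softmax distributions:
        #   P(class) = (w_label * softmax(T * text_logits) + w_image * softmax(T * image_logits))
        #              / (w_label + w_image)
        # where T = LOGIT_TEMP, w_label = label_weight, and w_image = image_weight.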
70 | return ( 71 | self._label_weight * F.softmax(self.LOGIT_TEMP * text_logits, dim=-1) 72 | + self._image_weight * F.softmax(self.LOGIT_TEMP * image_logits, dim=-1) 73 | ) / (self._label_weight + self._image_weight) 74 | -------------------------------------------------------------------------------- /dataloaders/real_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Optional, Union 3 | import clip 4 | import einops 5 | import os 6 | import torch 7 | import tqdm 8 | import cv2 9 | 10 | import numpy as np 11 | from pathlib import Path 12 | from torch.utils.data import Dataset, DataLoader, Subset 13 | from dataloaders.record3d import R3DSemanticDataset 14 | from dataloaders.scannet_200_classes import SCANNET_COLOR_MAP_200, CLASS_LABELS_200 15 | 16 | 17 | # Setup detectron2 logger 18 | from detectron2.utils.logger import setup_logger 19 | from sentence_transformers import SentenceTransformer 20 | from torch.utils.data import Dataset 21 | 22 | setup_logger() 23 | d2_logger = logging.getLogger("detectron2") 24 | d2_logger.setLevel(level=logging.WARNING) 25 | 26 | # import some common libraries 27 | import sys 28 | 29 | # import some common detectron2 utilities 30 | from detectron2.config import get_cfg 31 | from detectron2.data import MetadataCatalog 32 | from detectron2.engine import DefaultPredictor 33 | 34 | 35 | DETIC_PATH = os.environ.get("DETIC_PATH", Path(__file__).parent / "../Detic") 36 | LSEG_PATH = os.environ.get("LSEG_PATH", Path(__file__).parent / "../LSeg/") 37 | 38 | sys.path.insert(0, f"{LSEG_PATH}/") 39 | from encoding.models.sseg import BaseNet 40 | from additional_utils.models import LSeg_MultiEvalModule 41 | from modules.lseg_module import LSegModule 42 | import torchvision.transforms as transforms 43 | 44 | # Detic libraries 45 | sys.path.insert(0, f"{DETIC_PATH}/third_party/CenterNet2/") 46 | sys.path.insert(0, f"{DETIC_PATH}/") 47 | from centernet.config import add_centernet_config 48 | from detic.config import add_detic_config 49 | from detic.modeling.utils import reset_cls_test 50 | from detic.modeling.text.text_encoder import build_text_encoder 51 | 52 | cfg = get_cfg() 53 | add_centernet_config(cfg) 54 | add_detic_config(cfg) 55 | cfg.merge_from_file( 56 | f"{DETIC_PATH}/configs/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.yaml" 57 | ) 58 | cfg.MODEL.WEIGHTS = "https://dl.fbaipublicfiles.com/detic/Detic_LCOCOI21k_CLIP_SwinB_896b32_4x_ft4x_max-size.pth" 59 | cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5 # set threshold for this model 60 | cfg.MODEL.ROI_BOX_HEAD.ZEROSHOT_WEIGHT_PATH = "rand" 61 | cfg.MODEL.ROI_HEADS.ONE_CLASS_PER_PROPOSAL = ( 62 | False # For better visualization purpose. Set to False for all classes. 63 | ) 64 | cfg.MODEL.ROI_BOX_HEAD.CAT_FREQ_PATH = ( 65 | f"{DETIC_PATH}/datasets/metadata/lvis_v1_train_cat_info.json" 66 | ) 67 | # cfg.MODEL.DEVICE='cpu' # uncomment this to use cpu-only mode. 68 | 69 | 70 | def get_clip_embeddings(vocabulary, prompt="a "): 71 | text_encoder = build_text_encoder(pretrain=True) 72 | text_encoder.eval() 73 | texts = [prompt + x.replace("-", " ") for x in vocabulary] 74 | emb = text_encoder(texts).detach().permute(1, 0).contiguous().cpu() 75 | return emb 76 | 77 | 78 | # New visualizer class to disable jitter. 
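# detectron2's stock Visualizer jitters instance colours noticeably; the subclass below
# limits the jitter to ~1% so debug renders keep (approximately) the assigned palette,
# e.g. the ScanNet-200 colours when use_scannet_colors is set.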
79 | from detectron2.utils.visualizer import Visualizer 80 | from detectron2.utils.visualizer import ColorMode 81 | import matplotlib.colors as mplc 82 | 83 | 84 | class LowJitterVisualizer(Visualizer): 85 | def _jitter(self, color): 86 | """ 87 | Randomly modifies given color to produce a slightly different color than the color given. 88 | 89 | Args: 90 | color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color 91 | picked. The values in the list are in the [0.0, 1.0] range. 92 | 93 | Returns: 94 | jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the 95 | color after being jittered. The values in the list are in the [0.0, 1.0] range. 96 | """ 97 | color = mplc.to_rgb(color) 98 | vec = np.random.rand(3) 99 | # better to do it in another color space 100 | vec = vec / np.linalg.norm(vec) 101 | vec *= 0.01 # 1% noise in the color 102 | res = np.clip(vec + color, 0, 1) 103 | return tuple(res) 104 | 105 | 106 | SCANNET_NAME_TO_COLOR = { 107 | x: np.array(c) for x, c in zip(CLASS_LABELS_200, SCANNET_COLOR_MAP_200.values()) 108 | } 109 | 110 | SCANNET_ID_TO_COLOR = { 111 | i: np.array(c) for i, c in enumerate(SCANNET_COLOR_MAP_200.values()) 112 | } 113 | 114 | 115 | class DeticDenseLabelledDataset(Dataset): 116 | LSEG_LABEL_WEIGHT = 0.1 117 | LSEG_IMAGE_DISTANCE = 10.0 118 | 119 | def __init__( 120 | self, 121 | view_dataset: Union[R3DSemanticDataset, Subset[R3DSemanticDataset]], 122 | clip_model_name: str = "ViT-B/32", 123 | sentence_encoding_model_name="all-mpnet-base-v2", 124 | device: str = "cuda", 125 | batch_size: int = 1, 126 | detic_threshold: float = 0.3, 127 | num_images_to_label: int = -1, 128 | subsample_prob: float = 0.2, 129 | use_lseg: bool = False, 130 | use_extra_classes: bool = False, 131 | use_gt_classes: bool = True, 132 | exclude_gt_images: bool = False, 133 | gt_inst_images: Optional[List[int]] = None, 134 | gt_sem_images: Optional[List[int]] = None, 135 | visualize_results: bool = False, 136 | visualization_path: Optional[str] = None, 137 | use_scannet_colors: bool = True, 138 | ): 139 | dataset = view_dataset 140 | view_data = ( 141 | view_dataset.dataset if isinstance(view_dataset, Subset) else view_dataset 142 | ) 143 | self._image_width, self._image_height = view_data.image_size 144 | clip_model, _ = clip.load(clip_model_name, device=device) 145 | sentence_model = SentenceTransformer(sentence_encoding_model_name) 146 | 147 | self._batch_size = batch_size 148 | self._device = device 149 | self._detic_threshold = detic_threshold 150 | self._subsample_prob = subsample_prob 151 | 152 | self._label_xyz = [] 153 | self._label_rgb = [] 154 | self._label_weight = [] 155 | self._label_idx = [] 156 | self._text_ids = [] 157 | self._text_id_to_feature = {} 158 | self._image_features = [] 159 | self._distance = [] 160 | 161 | self._exclude_gt_image = exclude_gt_images 162 | images_to_label = self.get_best_sem_segmented_images( 163 | dataset, num_images_to_label, gt_inst_images, gt_sem_images 164 | ) 165 | self._use_lseg = use_lseg 166 | self._use_extra_classes = use_extra_classes 167 | self._use_gt_classes = use_gt_classes 168 | self._use_scannet_colors = use_scannet_colors 169 | 170 | self._visualize = visualize_results 171 | if self._visualize: 172 | assert visualization_path is not None 173 | self._visualization_path = Path(visualization_path) 174 | os.makedirs(self._visualization_path, exist_ok=True) 175 | # First, setup detic with the combined classes. 
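        # Labelling happens in two steps: (1) build a Detic predictor whose vocabulary is
        # the union of the scan's own class names and, optionally, the extra ScanNet-200
        # labels; (2) run it (plus LSeg, if enabled) over the selected frames to accumulate
        # per-point positions, colours, class ids, confidence weights, and text/image features.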
176 | self._setup_detic_all_classes(view_data) 177 | self._setup_detic_dense_labels( 178 | dataset, images_to_label, clip_model, sentence_model 179 | ) 180 | 181 | del clip_model 182 | del sentence_model 183 | 184 | def get_best_sem_segmented_images( 185 | self, 186 | dataset, 187 | num_images_to_label: int, 188 | gt_inst_images: Optional[List[int]] = None, 189 | gt_sem_images: Optional[List[int]] = None, 190 | ): 191 | # Using depth as a proxy for object diversity in a scene. 192 | if self._exclude_gt_image: 193 | assert gt_inst_images is not None 194 | assert gt_sem_images is not None 195 | num_objects_and_images = [] 196 | for idx in range(len(dataset)): 197 | if self._exclude_gt_image: 198 | if idx in gt_inst_images or idx in gt_sem_images: 199 | continue 200 | num_objects_and_images.append( 201 | (dataset[idx]["depth"].max() - dataset[idx]["depth"].min(), idx) 202 | ) 203 | 204 | sorted_num_object_and_img = sorted( 205 | num_objects_and_images, key=lambda x: x[0], reverse=True 206 | ) 207 | return [x[1] for x in sorted_num_object_and_img[:num_images_to_label]] 208 | 209 | @torch.no_grad() 210 | def _setup_detic_dense_labels( 211 | self, dataset, images_to_label, clip_model, sentence_model 212 | ): 213 | # Now just iterate over the images and do Detic preprocessing. 214 | dataloader = DataLoader(dataset, batch_size=1, shuffle=False, pin_memory=False) 215 | label_idx = 0 216 | for idx, data_dict in tqdm.tqdm( 217 | enumerate(dataloader), total=len(dataset), desc="Calculating Detic features" 218 | ): 219 | if idx not in images_to_label: 220 | continue 221 | rgb = einops.rearrange(data_dict["rgb"][..., :3], "b h w c -> b c h w") 222 | xyz = data_dict["xyz_position"] 223 | for image, coordinates in zip(rgb, xyz): 224 | # Now calculate the Detic classification for this. 225 | with torch.no_grad(): 226 | result = self._predictor.model( 227 | [ 228 | { 229 | "image": image * 255, 230 | "height": self._image_height, 231 | "width": self._image_width, 232 | } 233 | ] 234 | )[0] 235 | # Now extract the results from the image and store them 236 | instance = result["instances"] 237 | reshaped_rgb = einops.rearrange(image, "c h w -> h w c") 238 | ( 239 | reshaped_coordinates, 240 | valid_mask, 241 | ) = self._reshape_coordinates_and_get_valid(coordinates, data_dict) 242 | if self._visualize: 243 | v = LowJitterVisualizer( 244 | reshaped_rgb, 245 | self.metadata, 246 | instance_mode=ColorMode.SEGMENTATION, 247 | ) 248 | out = v.draw_instance_predictions(instance.to("cpu")) 249 | cv2.imwrite( 250 | str(self._visualization_path / f"{idx}.jpg"), 251 | out.get_image()[:, :, ::-1], 252 | [int(cv2.IMWRITE_JPEG_QUALITY), 80], 253 | ) 254 | for pred_class, pred_mask, pred_score, feature in zip( 255 | instance.pred_classes.cpu(), 256 | instance.pred_masks.cpu(), 257 | instance.scores.cpu(), 258 | instance.features.cpu(), 259 | ): 260 | real_mask = pred_mask[valid_mask] 261 | real_mask_rect = valid_mask & pred_mask 262 | # Go over each instance and add it to the DB. 
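                        # For each detected instance: take the 3D points under its mask, keep
                        # each with probability subsample_prob, and record XYZ/RGB, the mapped
                        # class id, the detection score as a per-point label weight, and the
                        # instance's image feature repeated for every kept point. label_idx
                        # groups all points of one detection for the contrastive masks at
                        # training time.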
263 | total_points = len(reshaped_coordinates[real_mask]) 264 | resampled_indices = torch.rand(total_points) < self._subsample_prob 265 | self._label_xyz.append( 266 | reshaped_coordinates[real_mask][resampled_indices] 267 | ) 268 | self._label_rgb.append( 269 | reshaped_rgb[real_mask_rect][resampled_indices] 270 | ) 271 | self._text_ids.append( 272 | torch.ones(total_points)[resampled_indices] 273 | * self._new_class_to_old_class_mapping[pred_class.item()] 274 | ) 275 | self._label_weight.append( 276 | torch.ones(total_points)[resampled_indices] * pred_score 277 | ) 278 | self._image_features.append( 279 | einops.repeat(feature, "d -> b d", b=total_points)[ 280 | resampled_indices 281 | ] 282 | ) 283 | self._label_idx.append( 284 | torch.ones(total_points)[resampled_indices] * label_idx 285 | ) 286 | self._distance.append(torch.zeros(total_points)[resampled_indices]) 287 | label_idx += 1 288 | 289 | # First delete leftover Detic predictors 290 | del self._predictor 291 | 292 | if self._use_lseg: 293 | # Now, get to LSeg 294 | self._setup_lseg() 295 | for idx, data_dict in tqdm.tqdm( 296 | enumerate(dataloader), 297 | total=len(dataset), 298 | desc="Calculating LSeg features", 299 | ): 300 | if idx not in images_to_label: 301 | continue 302 | rgb = einops.rearrange(data_dict["rgb"][..., :3], "b h w c -> b c h w") 303 | xyz = data_dict["xyz_position"] 304 | for image, coordinates in zip(rgb, xyz): 305 | # Now figure out the LSeg lables. 306 | with torch.no_grad(): 307 | unsqueezed_image = image.unsqueeze(0).float().cuda() 308 | resized_image = self.resize(image).unsqueeze(0).cuda() 309 | tfm_image = self.transform(unsqueezed_image) 310 | outputs = self.evaluator.parallel_forward( 311 | tfm_image, self._all_lseg_classes 312 | ) 313 | image_feature = clip_model.encode_image(resized_image).squeeze( 314 | 0 315 | ) 316 | image_feature = image_feature.cpu() 317 | predicts = [torch.max(output, 1)[1].cpu() for output in outputs] 318 | predict = predicts[0] 319 | 320 | ( 321 | reshaped_coordinates, 322 | valid_mask, 323 | ) = self._reshape_coordinates_and_get_valid(coordinates, data_dict) 324 | reshaped_rgb = einops.rearrange(image, "c h w -> h w c") 325 | 326 | for label in range(len(self._all_classes)): 327 | pred_mask = predict.squeeze(0) == label 328 | real_mask = pred_mask[valid_mask] 329 | real_mask_rect = valid_mask & pred_mask 330 | total_points = len(reshaped_coordinates[real_mask]) 331 | resampled_indices = ( 332 | torch.rand(total_points) < self._subsample_prob 333 | ) 334 | if total_points: 335 | self._label_xyz.append( 336 | reshaped_coordinates[real_mask][resampled_indices] 337 | ) 338 | self._label_rgb.append( 339 | reshaped_rgb[real_mask_rect][resampled_indices] 340 | ) 341 | # Ideally, this should give all classes their true class label. 
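                            # (_new_class_to_old_class_mapping is the identity when ground-truth
                            # classes are used; see _setup_detic_all_classes for the general case.)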
342 | self._text_ids.append( 343 | torch.ones(total_points)[resampled_indices] 344 | * self._new_class_to_old_class_mapping[label] 345 | ) 346 | # Uniform label confidence of LSEG_LABEL_WEIGHT 347 | self._label_weight.append( 348 | torch.ones(total_points)[resampled_indices] 349 | * self.LSEG_LABEL_WEIGHT 350 | ) 351 | self._image_features.append( 352 | einops.repeat( 353 | image_feature, "d -> b d", b=total_points 354 | )[resampled_indices] 355 | ) 356 | self._label_idx.append( 357 | torch.ones(total_points)[resampled_indices] * label_idx 358 | ) 359 | self._distance.append( 360 | torch.ones(total_points)[resampled_indices] 361 | * self.LSEG_IMAGE_DISTANCE 362 | ) 363 | # Since they all get the same image, here label idx is increased once 364 | # at the very end. 365 | label_idx += 1 366 | 367 | # Now, delete the module and the evaluator 368 | del self.evaluator 369 | del self.module 370 | del self.transform 371 | 372 | # Now, get all the sentence encoding for all the labels. 373 | text_strings = [ 374 | DeticDenseLabelledDataset.process_text(x) for x in self._all_classes 375 | ] 376 | text_strings += self._all_classes 377 | with torch.no_grad(): 378 | all_embedded_text = sentence_model.encode(text_strings) 379 | all_embedded_text = torch.from_numpy(all_embedded_text).float() 380 | 381 | for i, feature in enumerate(all_embedded_text): 382 | self._text_id_to_feature[i] = feature 383 | 384 | # Now, we map from label to text using this model. 385 | self._label_xyz = torch.cat(self._label_xyz).float() 386 | self._label_rgb = torch.cat(self._label_rgb).float() 387 | self._label_weight = torch.cat(self._label_weight).float() 388 | self._image_features = torch.cat(self._image_features).float() 389 | self._text_ids = torch.cat(self._text_ids).long() 390 | self._label_idx = torch.cat(self._label_idx).long() 391 | self._distance = torch.cat(self._distance).float() 392 | self._instance = ( 393 | torch.ones_like(self._text_ids) * -1 394 | ).long() # We don't have instance ID from this dataset. 395 | 396 | def _resample(self): 397 | resampled_indices = torch.rand(len(self._label_xyz)) < self._subsample_prob 398 | logging.info( 399 | f"Resampling dataset down from {len(self._label_xyz)} points to {resampled_indices.long().sum().item()} points." 400 | ) 401 | self._label_xyz = self._label_xyz[resampled_indices] 402 | self._label_rgb = self._label_rgb[resampled_indices] 403 | self._label_weight = self._label_weight[resampled_indices] 404 | self._image_features = self._image_features[resampled_indices] 405 | self._text_ids = self._text_ids[resampled_indices] 406 | self._label_idx = self._label_idx[resampled_indices] 407 | self._distance = self._distance[resampled_indices] 408 | self._instance = self._instance[resampled_indices] 409 | 410 | def _reshape_coordinates_and_get_valid(self, coordinates, data_dict): 411 | if "conf" in data_dict: 412 | # Real world data, find valid mask 413 | valid_mask = ( 414 | torch.as_tensor( 415 | (~np.isnan(data_dict["depth"]) & (data_dict["conf"] == 2)) 416 | & (data_dict["depth"] < 3.0) 417 | ) 418 | .squeeze(0) 419 | .bool() 420 | ) 421 | reshaped_coordinates = torch.as_tensor(coordinates) 422 | return reshaped_coordinates, valid_mask 423 | else: 424 | reshaped_coordinates = einops.rearrange(coordinates, "c h w -> (h w) c") 425 | valid_mask = torch.ones_like(coordinates).mean(dim=0).bool() 426 | return reshaped_coordinates, valid_mask 427 | 428 | def __getitem__(self, idx): 429 | # Create a dictionary with all relevant results. 
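        # Each item pairs one 3D point with its supervision targets:
        #   xyz / rgb           - point position and colour
        #   label               - semantic text id (indexes into the S-BERT features built above)
        #   instance            - always -1 for this dataset (no instance ids)
        #   img_idx / distance  - which detection/view produced the point, and a distance term
        #                         (0 for Detic points, LSEG_IMAGE_DISTANCE for LSeg points) that
        #                         is turned into an exponential image weight at training time
        #   clip_vector         - S-BERT embedding of the class name
        #   clip_image_vector   - image-side feature of the source detection (whole-frame CLIP
        #                         embedding for LSeg points)
        #   semantic_weight     - label confidence (Detic score, or the fixed LSeg weight)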
430 | return { 431 | "xyz": self._label_xyz[idx].float(), 432 | "rgb": self._label_rgb[idx].float(), 433 | "label": self._text_ids[idx].long(), 434 | "instance": self._instance[idx].long(), 435 | "img_idx": self._label_idx[idx].long(), 436 | "distance": self._distance[idx].float(), 437 | "clip_vector": self._text_id_to_feature.get( 438 | self._text_ids[idx].item() 439 | ).float(), 440 | "clip_image_vector": self._image_features[idx].float(), 441 | "semantic_weight": self._label_weight[idx].float(), 442 | } 443 | 444 | def __len__(self): 445 | return len(self._label_xyz) 446 | 447 | @staticmethod 448 | def process_text(x: str) -> str: 449 | return x.replace("-", " ").replace("_", " ").lstrip().rstrip().lower() 450 | 451 | def _setup_detic_all_classes(self, view_data: R3DSemanticDataset): 452 | # Unifying all the class labels. 453 | predictor = DefaultPredictor(cfg) 454 | prebuilt_class_names = [ 455 | DeticDenseLabelledDataset.process_text(x) 456 | for x in view_data._id_to_name.values() 457 | ] 458 | prebuilt_class_set = ( 459 | set(prebuilt_class_names) if self._use_gt_classes else set() 460 | ) 461 | filtered_new_classes = ( 462 | [x for x in CLASS_LABELS_200 if x not in prebuilt_class_set] 463 | if self._use_extra_classes 464 | else [] 465 | ) 466 | 467 | self._all_classes = prebuilt_class_names + filtered_new_classes 468 | 469 | if self._use_gt_classes: 470 | self._new_class_to_old_class_mapping = { 471 | x: x for x in range(len(self._all_classes)) 472 | } 473 | else: 474 | # We are not using all classes, so we should map which new/extra class maps 475 | # to which old class. 476 | for class_idx, class_name in enumerate(self._all_classes): 477 | if class_name in prebuilt_class_set: 478 | old_idx = prebuilt_class_names.index(class_name) 479 | else: 480 | old_idx = len(prebuilt_class_names) + filtered_new_classes.index( 481 | class_name 482 | ) 483 | self._new_class_to_old_class_mapping[class_idx] = old_idx 484 | 485 | self._all_classes = [ 486 | DeticDenseLabelledDataset.process_text(x) for x in self._all_classes 487 | ] 488 | new_metadata = MetadataCatalog.get("__unused") 489 | new_metadata.thing_classes = self._all_classes 490 | if self._use_scannet_colors: 491 | new_metadata.thing_colors = SCANNET_ID_TO_COLOR 492 | self.metadata = new_metadata 493 | classifier = get_clip_embeddings(new_metadata.thing_classes) 494 | num_classes = len(new_metadata.thing_classes) 495 | reset_cls_test(predictor.model, classifier, num_classes) 496 | # Reset visualization threshold 497 | output_score_threshold = self._detic_threshold 498 | for cascade_stages in range(len(predictor.model.roi_heads.box_predictor)): 499 | predictor.model.roi_heads.box_predictor[ 500 | cascade_stages 501 | ].test_score_thresh = output_score_threshold 502 | self._predictor = predictor 503 | 504 | def find_in_class(self, classname): 505 | try: 506 | return self._all_classes.index(classname) 507 | except ValueError: 508 | ret_value = len(self._all_classes) + self._unfound_offset 509 | self._unfound_offset += 1 510 | return ret_value 511 | 512 | def _setup_lseg(self): 513 | self._lseg_classes = self._all_classes 514 | self._num_true_lseg_classes = len(self._lseg_classes) 515 | self._all_lseg_classes = self._all_classes # + ["Other"] 516 | 517 | # We will try to classify all the classes, but will use LSeg labels for classes that 518 | # are not identified by Detic. 
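        # LSeg is loaded from a local checkpoint (see the FileNotFoundError below for the
        # download pointer) and wrapped in a multi-scale evaluator that aggregates
        # predictions over self.scales with horizontal flipping.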
519 | LSEG_MODEL_PATH = f"{LSEG_PATH}/checkpoints/demo_e200.ckpt" 520 | try: 521 | self.module = LSegModule.load_from_checkpoint( 522 | checkpoint_path=LSEG_MODEL_PATH, 523 | data_path="", 524 | dataset="ade20k", 525 | backbone="clip_vitl16_384", 526 | aux=False, 527 | num_features=256, 528 | aux_weight=0, 529 | se_loss=False, 530 | se_weight=0, 531 | base_lr=0, 532 | batch_size=1, 533 | max_epochs=0, 534 | ignore_index=255, 535 | dropout=0.0, 536 | scale_inv=False, 537 | augment=False, 538 | no_batchnorm=False, 539 | widehead=True, 540 | widehead_hr=False, 541 | map_locatin=self._device, 542 | arch_option=0, 543 | block_depth=0, 544 | activation="lrelu", 545 | ) 546 | except FileNotFoundError: 547 | LSEG_URL = "https://github.com/isl-org/lang-seg" 548 | raise FileNotFoundError( 549 | "LSeg model not found. Please download it from {} and place it in {}".format( 550 | LSEG_URL, LSEG_MODEL_PATH 551 | ) 552 | ) 553 | if isinstance(self.module.net, BaseNet): 554 | model = self.module.net 555 | else: 556 | model = self.module 557 | 558 | model = model.eval() 559 | model = model.to(self._device) 560 | self.scales = [0.5, 0.75, 1.0, 1.25, 1.5, 1.75] 561 | 562 | model.mean = [0.5, 0.5, 0.5] 563 | model.std = [0.5, 0.5, 0.5] 564 | 565 | self.transform = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 566 | self.resize = transforms.Resize((224, 224)) 567 | 568 | self.evaluator = LSeg_MultiEvalModule(model, scales=self.scales, flip=True).to( 569 | self._device 570 | ) 571 | self.evaluator = self.evaluator.eval() 572 | -------------------------------------------------------------------------------- /dataloaders/record3d.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import List, Optional 4 | from zipfile import ZipFile 5 | 6 | import liblzfse 7 | import numpy as np 8 | import open3d as o3d 9 | import tqdm 10 | from PIL import Image 11 | from quaternion import as_rotation_matrix, quaternion 12 | from torch.utils.data import Dataset 13 | 14 | from dataloaders.scannet_200_classes import CLASS_LABELS_200 15 | 16 | 17 | class R3DSemanticDataset(Dataset): 18 | def __init__( 19 | self, 20 | path: str, 21 | custom_classes: Optional[List[str]] = CLASS_LABELS_200, 22 | ): 23 | if path.endswith((".zip", ".r3d")): 24 | self._path = ZipFile(path) 25 | else: 26 | self._path = Path(path) 27 | 28 | if custom_classes: 29 | self._classes = custom_classes 30 | else: 31 | self._classes = CLASS_LABELS_200 32 | 33 | self._reshaped_depth = [] 34 | self._reshaped_conf = [] 35 | self._depth_images = [] 36 | self._rgb_images = [] 37 | self._confidences = [] 38 | 39 | self._metadata = self._read_metadata() 40 | self.global_xyzs = [] 41 | self.global_pcds = [] 42 | self._load_data() 43 | self._reshape_all_depth_and_conf() 44 | self.calculate_all_global_xyzs() 45 | 46 | def _read_metadata(self): 47 | with self._path.open("metadata", "r") as f: 48 | metadata_dict = json.load(f) 49 | 50 | # Now figure out the details from the metadata dict. 
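        # The Record3D metadata bundles the RGB image size ("w", "h"), the capture fps,
        # the 3x3 intrinsics "K" (note the transpose applied below), one camera pose per
        # frame as a (qx, qy, qz, qw, px, py, pz) tuple, and an "initPose" that is applied
        # on top of the per-frame poses in get_global_xyz.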
51 | self.rgb_width = metadata_dict["w"] 52 | self.rgb_height = metadata_dict["h"] 53 | self.fps = metadata_dict["fps"] 54 | self.camera_matrix = np.array(metadata_dict["K"]).reshape(3, 3).T 55 | 56 | self.image_size = (self.rgb_width, self.rgb_height) 57 | self.poses = np.array(metadata_dict["poses"]) 58 | self.init_pose = np.array(metadata_dict["initPose"]) 59 | self.total_images = len(self.poses) 60 | 61 | self._id_to_name = {i: x for (i, x) in enumerate(self._classes)} 62 | 63 | def load_image(self, filepath): 64 | with self._path.open(filepath, "r") as image_file: 65 | return np.asarray(Image.open(image_file)) 66 | 67 | def load_depth(self, filepath): 68 | with self._path.open(filepath, "r") as depth_fh: 69 | raw_bytes = depth_fh.read() 70 | decompressed_bytes = liblzfse.decompress(raw_bytes) 71 | depth_img: np.ndarray = np.frombuffer(decompressed_bytes, dtype=np.float32) 72 | 73 | if depth_img.shape[0] == 960 * 720: 74 | depth_img = depth_img.reshape((960, 720)) # For a FaceID camera 3D Video 75 | else: 76 | depth_img = depth_img.reshape((256, 192)) # For a LiDAR 3D Video 77 | return depth_img 78 | 79 | def load_conf(self, filepath): 80 | with self._path.open(filepath, "r") as depth_fh: 81 | raw_bytes = depth_fh.read() 82 | decompressed_bytes = liblzfse.decompress(raw_bytes) 83 | depth_img = np.frombuffer(decompressed_bytes, dtype=np.uint8) 84 | if depth_img.shape[0] == 960 * 720: 85 | depth_img = depth_img.reshape((960, 720)) # For a FaceID camera 3D Video 86 | else: 87 | depth_img = depth_img.reshape((256, 192)) # For a LiDAR 3D Video 88 | return depth_img 89 | 90 | def _load_data(self): 91 | assert self.fps # Make sure metadata is read correctly first. 92 | for i in tqdm.trange(self.total_images, desc="Loading data"): 93 | # Read up the RGB and depth images first. 94 | rgb_filepath = f"rgbd/{i}.jpg" 95 | depth_filepath = f"rgbd/{i}.depth" 96 | conf_filepath = f"rgbd/{i}.conf" 97 | 98 | depth_img = self.load_depth(depth_filepath) 99 | confidence = self.load_conf(conf_filepath) 100 | rgb_img = self.load_image(rgb_filepath) 101 | 102 | # Now, convert depth image to real world XYZ pointcloud. 103 | self._depth_images.append(depth_img) 104 | self._rgb_images.append(rgb_img) 105 | self._confidences.append(confidence) 106 | 107 | def _reshape_all_depth_and_conf(self): 108 | for index in tqdm.trange(len(self.poses), desc="Upscaling depth and conf"): 109 | depth_image = self._depth_images[index] 110 | # Upscale depth image. 
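            # Depth and confidence come at sensor resolution (256x192 for LiDAR captures,
            # 960x720 for FaceID ones) and are resized here to the RGB resolution so that
            # depth, confidence, and colour line up per pixel.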
111 | pil_img = Image.fromarray(depth_image) 112 | reshaped_img = pil_img.resize((self.rgb_width, self.rgb_height)) 113 | reshaped_img = np.asarray(reshaped_img) 114 | self._reshaped_depth.append(reshaped_img) 115 | 116 | # Upscale confidence as well 117 | confidence = self._confidences[index] 118 | conf_img = Image.fromarray(confidence) 119 | reshaped_conf = conf_img.resize((self.rgb_width, self.rgb_height)) 120 | reshaped_conf = np.asarray(reshaped_conf) 121 | self._reshaped_conf.append(reshaped_conf) 122 | 123 | def get_global_xyz(self, index, depth_scale=1000.0, only_confident=True): 124 | reshaped_img = np.copy(self._reshaped_depth[index]) 125 | # If only confident, replace not confident points with nans 126 | if only_confident: 127 | reshaped_img[self._reshaped_conf[index] != 2] = np.nan 128 | 129 | depth_o3d = o3d.geometry.Image( 130 | np.ascontiguousarray(depth_scale * reshaped_img).astype(np.float32) 131 | ) 132 | rgb_o3d = o3d.geometry.Image( 133 | np.ascontiguousarray(self._rgb_images[index]).astype(np.uint8) 134 | ) 135 | 136 | rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth( 137 | rgb_o3d, depth_o3d, convert_rgb_to_intensity=False 138 | ) 139 | 140 | camera_intrinsics = o3d.camera.PinholeCameraIntrinsic( 141 | width=int(self.rgb_width), 142 | height=int(self.rgb_height), 143 | fx=self.camera_matrix[0, 0], 144 | fy=self.camera_matrix[1, 1], 145 | cx=self.camera_matrix[0, 2], 146 | cy=self.camera_matrix[1, 2], 147 | ) 148 | pcd = o3d.geometry.PointCloud.create_from_rgbd_image( 149 | rgbd_image, camera_intrinsics 150 | ) 151 | # Flip the pcd 152 | pcd.transform([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 153 | 154 | extrinsic_matrix = np.eye(4) 155 | qx, qy, qz, qw, px, py, pz = self.poses[index] 156 | extrinsic_matrix[:3, :3] = as_rotation_matrix(quaternion(qw, qx, qy, qz)) 157 | extrinsic_matrix[:3, -1] = [px, py, pz] 158 | pcd.transform(extrinsic_matrix) 159 | 160 | # Now transform everything by init pose. 
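        # init_pose is applied the same way as the per-frame pose above: build a 4x4
        # transform from its (qx, qy, qz, qw, px, py, pz) tuple and apply it to the point
        # cloud, so every frame lands in one shared coordinate frame.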
161 | init_matrix = np.eye(4) 162 | qx, qy, qz, qw, px, py, pz = self.init_pose 163 | init_matrix[:3, :3] = as_rotation_matrix(quaternion(qw, qx, qy, qz)) 164 | init_matrix[:3, -1] = [px, py, pz] 165 | pcd.transform(init_matrix) 166 | 167 | return pcd 168 | 169 | def calculate_all_global_xyzs(self, only_confident=True): 170 | if len(self.global_xyzs): 171 | return self.global_xyzs, self.global_pcds 172 | for i in tqdm.trange(len(self.poses), desc="Calculating global XYZs"): 173 | global_xyz_pcd = self.get_global_xyz(i, only_confident=only_confident) 174 | global_xyz = np.asarray(global_xyz_pcd.points) 175 | self.global_xyzs.append(global_xyz) 176 | self.global_pcds.append(global_xyz_pcd) 177 | return self.global_xyzs, self.global_pcds 178 | 179 | def __len__(self): 180 | return len(self.poses) 181 | 182 | def __getitem__(self, idx): 183 | result = { 184 | "xyz_position": self.global_xyzs[idx], 185 | "rgb": self._rgb_images[idx], 186 | "depth": self._reshaped_depth[idx], 187 | "conf": self._reshaped_conf[idx], 188 | } 189 | return result 190 | -------------------------------------------------------------------------------- /dataloaders/scannet_200_classes.py: -------------------------------------------------------------------------------- 1 | # Credits to ScanNet 200 dataset 2 | # https://github.com/ScanNet/ScanNet/tree/master/BenchmarkScripts/ScanNet200 3 | 4 | CLASS_LABELS_200 = ( 5 | "wall", 6 | "chair", 7 | "floor", 8 | "table", 9 | "door", 10 | "couch", 11 | "cabinet", 12 | "shelf", 13 | "desk", 14 | "office chair", 15 | "bed", 16 | "pillow", 17 | "sink", 18 | "picture", 19 | "window", 20 | "toilet", 21 | "bookshelf", 22 | "monitor", 23 | "curtain", 24 | "book", 25 | "armchair", 26 | "coffee table", 27 | "box", 28 | "refrigerator", 29 | "lamp", 30 | "kitchen cabinet", 31 | "towel", 32 | "clothes", 33 | "tv", 34 | "nightstand", 35 | "counter", 36 | "dresser", 37 | "stool", 38 | "cushion", 39 | "plant", 40 | "ceiling", 41 | "bathtub", 42 | "end table", 43 | "dining table", 44 | "keyboard", 45 | "bag", 46 | "backpack", 47 | "toilet paper", 48 | "printer", 49 | "tv stand", 50 | "whiteboard", 51 | "blanket", 52 | "shower curtain", 53 | "trash can", 54 | "closet", 55 | "stairs", 56 | "microwave", 57 | "stove", 58 | "shoe", 59 | "computer tower", 60 | "bottle", 61 | "bin", 62 | "ottoman", 63 | "bench", 64 | "board", 65 | "washing machine", 66 | "mirror", 67 | "copier", 68 | "basket", 69 | "sofa chair", 70 | "file cabinet", 71 | "fan", 72 | "laptop", 73 | "shower", 74 | "paper", 75 | "person", 76 | "paper towel dispenser", 77 | "oven", 78 | "blinds", 79 | "rack", 80 | "plate", 81 | "blackboard", 82 | "piano", 83 | "suitcase", 84 | "rail", 85 | "radiator", 86 | "recycling bin", 87 | "container", 88 | "wardrobe", 89 | "soap dispenser", 90 | "telephone", 91 | "bucket", 92 | "clock", 93 | "stand", 94 | "light", 95 | "laundry basket", 96 | "pipe", 97 | "clothes dryer", 98 | "guitar", 99 | "toilet paper holder", 100 | "seat", 101 | "speaker", 102 | "column", 103 | "bicycle", 104 | "ladder", 105 | "bathroom stall", 106 | "shower wall", 107 | "cup", 108 | "jacket", 109 | "storage bin", 110 | "coffee maker", 111 | "dishwasher", 112 | "paper towel roll", 113 | "machine", 114 | "mat", 115 | "windowsill", 116 | "bar", 117 | "toaster", 118 | "bulletin board", 119 | "ironing board", 120 | "fireplace", 121 | "soap dish", 122 | "kitchen counter", 123 | "doorframe", 124 | "toilet paper dispenser", 125 | "mini fridge", 126 | "fire extinguisher", 127 | "ball", 128 | "hat", 129 | "shower curtain rod", 130 | 
"water cooler", 131 | "paper cutter", 132 | "tray", 133 | "shower door", 134 | "pillar", 135 | "ledge", 136 | "toaster oven", 137 | "mouse", 138 | "toilet seat cover dispenser", 139 | "furniture", 140 | "cart", 141 | "storage container", 142 | "scale", 143 | "tissue box", 144 | "light switch", 145 | "crate", 146 | "power outlet", 147 | "decoration", 148 | "sign", 149 | "projector", 150 | "closet door", 151 | "vacuum cleaner", 152 | "candle", 153 | "plunger", 154 | "stuffed animal", 155 | "headphones", 156 | "dish rack", 157 | "broom", 158 | "guitar case", 159 | "range hood", 160 | "dustpan", 161 | "hair dryer", 162 | "water bottle", 163 | "handicap bar", 164 | "purse", 165 | "vent", 166 | "shower floor", 167 | "water pitcher", 168 | "mailbox", 169 | "bowl", 170 | "paper bag", 171 | "alarm clock", 172 | "music stand", 173 | "projector screen", 174 | "divider", 175 | "laundry detergent", 176 | "bathroom counter", 177 | "object", 178 | "bathroom vanity", 179 | "closet wall", 180 | "laundry hamper", 181 | "bathroom stall door", 182 | "ceiling light", 183 | "trash bin", 184 | "dumbbell", 185 | "stair rail", 186 | "tube", 187 | "bathroom cabinet", 188 | "cd case", 189 | "closet rod", 190 | "coffee kettle", 191 | "structure", 192 | "shower head", 193 | "keyboard piano", 194 | "case of water bottles", 195 | "coat rack", 196 | "storage organizer", 197 | "folded chair", 198 | "fire alarm", 199 | "power strip", 200 | "calendar", 201 | "poster", 202 | "potted plant", 203 | "luggage", 204 | "mattress", 205 | ) 206 | 207 | SCANNET_COLOR_MAP_200 = { 208 | 0: (0.0, 0.0, 0.0), 209 | 1: (174.0, 199.0, 232.0), 210 | 2: (188.0, 189.0, 34.0), 211 | 3: (152.0, 223.0, 138.0), 212 | 4: (255.0, 152.0, 150.0), 213 | 5: (214.0, 39.0, 40.0), 214 | 6: (91.0, 135.0, 229.0), 215 | 7: (31.0, 119.0, 180.0), 216 | 8: (229.0, 91.0, 104.0), 217 | 9: (247.0, 182.0, 210.0), 218 | 10: (91.0, 229.0, 110.0), 219 | 11: (255.0, 187.0, 120.0), 220 | 13: (141.0, 91.0, 229.0), 221 | 14: (112.0, 128.0, 144.0), 222 | 15: (196.0, 156.0, 148.0), 223 | 16: (197.0, 176.0, 213.0), 224 | 17: (44.0, 160.0, 44.0), 225 | 18: (148.0, 103.0, 189.0), 226 | 19: (229.0, 91.0, 223.0), 227 | 21: (219.0, 219.0, 141.0), 228 | 22: (192.0, 229.0, 91.0), 229 | 23: (88.0, 218.0, 137.0), 230 | 24: (58.0, 98.0, 137.0), 231 | 26: (177.0, 82.0, 239.0), 232 | 27: (255.0, 127.0, 14.0), 233 | 28: (237.0, 204.0, 37.0), 234 | 29: (41.0, 206.0, 32.0), 235 | 31: (62.0, 143.0, 148.0), 236 | 32: (34.0, 14.0, 130.0), 237 | 33: (143.0, 45.0, 115.0), 238 | 34: (137.0, 63.0, 14.0), 239 | 35: (23.0, 190.0, 207.0), 240 | 36: (16.0, 212.0, 139.0), 241 | 38: (90.0, 119.0, 201.0), 242 | 39: (125.0, 30.0, 141.0), 243 | 40: (150.0, 53.0, 56.0), 244 | 41: (186.0, 197.0, 62.0), 245 | 42: (227.0, 119.0, 194.0), 246 | 44: (38.0, 100.0, 128.0), 247 | 45: (120.0, 31.0, 243.0), 248 | 46: (154.0, 59.0, 103.0), 249 | 47: (169.0, 137.0, 78.0), 250 | 48: (143.0, 245.0, 111.0), 251 | 49: (37.0, 230.0, 205.0), 252 | 50: (14.0, 16.0, 155.0), 253 | 51: (196.0, 51.0, 182.0), 254 | 52: (237.0, 80.0, 38.0), 255 | 54: (138.0, 175.0, 62.0), 256 | 55: (158.0, 218.0, 229.0), 257 | 56: (38.0, 96.0, 167.0), 258 | 57: (190.0, 77.0, 246.0), 259 | 58: (208.0, 49.0, 84.0), 260 | 59: (208.0, 193.0, 72.0), 261 | 62: (55.0, 220.0, 57.0), 262 | 63: (10.0, 125.0, 140.0), 263 | 64: (76.0, 38.0, 202.0), 264 | 65: (191.0, 28.0, 135.0), 265 | 66: (211.0, 120.0, 42.0), 266 | 67: (118.0, 174.0, 76.0), 267 | 68: (17.0, 242.0, 171.0), 268 | 69: (20.0, 65.0, 247.0), 269 | 70: (208.0, 61.0, 222.0), 270 | 71: (162.0, 62.0, 
60.0), 271 | 72: (210.0, 235.0, 62.0), 272 | 73: (45.0, 152.0, 72.0), 273 | 74: (35.0, 107.0, 149.0), 274 | 75: (160.0, 89.0, 237.0), 275 | 76: (227.0, 56.0, 125.0), 276 | 77: (169.0, 143.0, 81.0), 277 | 78: (42.0, 143.0, 20.0), 278 | 79: (25.0, 160.0, 151.0), 279 | 80: (82.0, 75.0, 227.0), 280 | 82: (253.0, 59.0, 222.0), 281 | 84: (240.0, 130.0, 89.0), 282 | 86: (123.0, 172.0, 47.0), 283 | 87: (71.0, 194.0, 133.0), 284 | 88: (24.0, 94.0, 205.0), 285 | 89: (134.0, 16.0, 179.0), 286 | 90: (159.0, 32.0, 52.0), 287 | 93: (213.0, 208.0, 88.0), 288 | 95: (64.0, 158.0, 70.0), 289 | 96: (18.0, 163.0, 194.0), 290 | 97: (65.0, 29.0, 153.0), 291 | 98: (177.0, 10.0, 109.0), 292 | 99: (152.0, 83.0, 7.0), 293 | 100: (83.0, 175.0, 30.0), 294 | 101: (18.0, 199.0, 153.0), 295 | 102: (61.0, 81.0, 208.0), 296 | 103: (213.0, 85.0, 216.0), 297 | 104: (170.0, 53.0, 42.0), 298 | 105: (161.0, 192.0, 38.0), 299 | 106: (23.0, 241.0, 91.0), 300 | 107: (12.0, 103.0, 170.0), 301 | 110: (151.0, 41.0, 245.0), 302 | 112: (133.0, 51.0, 80.0), 303 | 115: (184.0, 162.0, 91.0), 304 | 116: (50.0, 138.0, 38.0), 305 | 118: (31.0, 237.0, 236.0), 306 | 120: (39.0, 19.0, 208.0), 307 | 121: (223.0, 27.0, 180.0), 308 | 122: (254.0, 141.0, 85.0), 309 | 125: (97.0, 144.0, 39.0), 310 | 128: (106.0, 231.0, 176.0), 311 | 130: (12.0, 61.0, 162.0), 312 | 131: (124.0, 66.0, 140.0), 313 | 132: (137.0, 66.0, 73.0), 314 | 134: (250.0, 253.0, 26.0), 315 | 136: (55.0, 191.0, 73.0), 316 | 138: (60.0, 126.0, 146.0), 317 | 139: (153.0, 108.0, 234.0), 318 | 140: (184.0, 58.0, 125.0), 319 | 141: (135.0, 84.0, 14.0), 320 | 145: (139.0, 248.0, 91.0), 321 | 148: (53.0, 200.0, 172.0), 322 | 154: (63.0, 69.0, 134.0), 323 | 155: (190.0, 75.0, 186.0), 324 | 156: (127.0, 63.0, 52.0), 325 | 157: (141.0, 182.0, 25.0), 326 | 159: (56.0, 144.0, 89.0), 327 | 161: (64.0, 160.0, 250.0), 328 | 163: (182.0, 86.0, 245.0), 329 | 165: (139.0, 18.0, 53.0), 330 | 166: (134.0, 120.0, 54.0), 331 | 168: (49.0, 165.0, 42.0), 332 | 169: (51.0, 128.0, 133.0), 333 | 170: (44.0, 21.0, 163.0), 334 | 177: (232.0, 93.0, 193.0), 335 | 180: (176.0, 102.0, 54.0), 336 | 185: (116.0, 217.0, 17.0), 337 | 188: (54.0, 209.0, 150.0), 338 | 191: (60.0, 99.0, 204.0), 339 | 193: (129.0, 43.0, 144.0), 340 | 195: (252.0, 100.0, 106.0), 341 | 202: (187.0, 196.0, 73.0), 342 | 208: (13.0, 158.0, 40.0), 343 | 213: (52.0, 122.0, 152.0), 344 | 214: (128.0, 76.0, 202.0), 345 | 221: (187.0, 50.0, 115.0), 346 | 229: (180.0, 141.0, 71.0), 347 | 230: (77.0, 208.0, 35.0), 348 | 232: (72.0, 183.0, 168.0), 349 | 233: (97.0, 99.0, 203.0), 350 | 242: (172.0, 22.0, 158.0), 351 | 250: (155.0, 64.0, 40.0), 352 | 261: (118.0, 159.0, 30.0), 353 | 264: (69.0, 252.0, 148.0), 354 | 276: (45.0, 103.0, 173.0), 355 | 283: (111.0, 38.0, 149.0), 356 | 286: (184.0, 9.0, 49.0), 357 | 300: (188.0, 174.0, 67.0), 358 | 304: (53.0, 206.0, 53.0), 359 | 312: (97.0, 235.0, 252.0), 360 | 323: (66.0, 32.0, 182.0), 361 | 325: (236.0, 114.0, 195.0), 362 | 331: (241.0, 154.0, 83.0), 363 | 342: (133.0, 240.0, 52.0), 364 | 356: (16.0, 205.0, 144.0), 365 | 370: (75.0, 101.0, 198.0), 366 | 392: (237.0, 95.0, 251.0), 367 | 395: (191.0, 52.0, 49.0), 368 | 399: (227.0, 254.0, 54.0), 369 | 408: (49.0, 206.0, 87.0), 370 | 417: (48.0, 113.0, 150.0), 371 | 488: (125.0, 73.0, 182.0), 372 | 540: (229.0, 32.0, 114.0), 373 | 562: (158.0, 119.0, 28.0), 374 | 570: (60.0, 205.0, 27.0), 375 | 572: (18.0, 215.0, 201.0), 376 | 581: (79.0, 76.0, 153.0), 377 | 609: (134.0, 13.0, 116.0), 378 | 748: (192.0, 97.0, 63.0), 379 | 776: (108.0, 163.0, 18.0), 380 | 
1156: (95.0, 220.0, 156.0), 381 | 1163: (98.0, 141.0, 208.0), 382 | 1164: (144.0, 19.0, 193.0), 383 | 1165: (166.0, 36.0, 57.0), 384 | 1166: (212.0, 202.0, 34.0), 385 | 1167: (23.0, 206.0, 34.0), 386 | 1168: (91.0, 211.0, 236.0), 387 | 1169: (79.0, 55.0, 137.0), 388 | 1170: (182.0, 19.0, 117.0), 389 | 1171: (134.0, 76.0, 14.0), 390 | 1172: (87.0, 185.0, 28.0), 391 | 1173: (82.0, 224.0, 187.0), 392 | 1174: (92.0, 110.0, 214.0), 393 | 1175: (168.0, 80.0, 171.0), 394 | 1176: (197.0, 63.0, 51.0), 395 | 1178: (175.0, 199.0, 77.0), 396 | 1179: (62.0, 180.0, 98.0), 397 | 1180: (8.0, 91.0, 150.0), 398 | 1181: (77.0, 15.0, 130.0), 399 | 1182: (154.0, 65.0, 96.0), 400 | 1183: (197.0, 152.0, 11.0), 401 | 1184: (59.0, 155.0, 45.0), 402 | 1185: (12.0, 147.0, 145.0), 403 | 1186: (54.0, 35.0, 219.0), 404 | 1187: (210.0, 73.0, 181.0), 405 | 1188: (221.0, 124.0, 77.0), 406 | 1189: (149.0, 214.0, 66.0), 407 | 1190: (72.0, 185.0, 134.0), 408 | 1191: (42.0, 94.0, 198.0), 409 | } 410 | 411 | ### For instance segmentation the non-object categories ### 412 | VALID_PANOPTIC_IDS = (1, 3) 413 | 414 | CLASS_LABELS_PANOPTIC = ("wall", "floor") 415 | -------------------------------------------------------------------------------- /demo/3 - training a CLIP field.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "a6aa8354-6fec-4b7b-be0e-432b571fa354", 6 | "metadata": {}, 7 | "source": [ 8 | "# 3. Training a CLIP-Field\n", 9 | "\n", 10 | "In this tutorial, we are going to create a CLIP-Field from our saved data. CLIP-Field is an implicit neural field that maps from 3D XYZ coordinates to higher dimensional representations such as CLIP visual features and Sentence-BERT semantic embeddings." 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "678eca98", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "import logging\n", 21 | "import os\n", 22 | "import pprint\n", 23 | "import random\n", 24 | "from typing import Dict, Union\n", 25 | "\n", 26 | "import hydra\n", 27 | "import numpy as np\n", 28 | "import torch\n", 29 | "import torch.nn.functional as F\n", 30 | "import torchmetrics\n", 31 | "import tqdm\n", 32 | "from omegaconf import OmegaConf\n", 33 | "from torch.utils.data import DataLoader, Subset\n", 34 | "\n", 35 | "import wandb\n", 36 | "import sys\n", 37 | "sys.path.append('..')" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "id": "522b1c6c", 44 | "metadata": {}, 45 | "outputs": [ 46 | { 47 | "name": "stdout", 48 | "output_type": "stream", 49 | "text": [ 50 | "Jupyter environment detected. Enabling Open3D WebVisualizer.\n", 51 | "[Open3D INFO] WebRTC GUI backend enabled.\n", 52 | "[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.\n" 53 | ] 54 | } 55 | ], 56 | "source": [ 57 | "from dataloaders import (\n", 58 | " R3DSemanticDataset,\n", 59 | " DeticDenseLabelledDataset,\n", 60 | " ClassificationExtractor,\n", 61 | ")\n", 62 | "from misc import ImplicitDataparallel\n", 63 | "from grid_hash_model import GridCLIPModel" 64 | ] 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "id": "bcd413f6", 69 | "metadata": {}, 70 | "source": [ 71 | "## Load the data and create a model\n", 72 | "\n", 73 | "Now, we will set up the constants and create the models." 
74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 3, 79 | "id": "2432180d", 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "# Set up the constants\n", 84 | "\n", 85 | "SAVE_DIRECTORY = \"../clip_implicit_model\"\n", 86 | "DEVICE = \"cuda\"\n", 87 | "IMAGE_TO_LABEL_CLIP_LOSS_SCALE = 1.0\n", 88 | "LABEL_TO_IMAGE_LOSS_SCALE = 1.0\n", 89 | "EXP_DECAY_COEFF = 0.5\n", 90 | "SAVE_EVERY = 5\n", 91 | "METRICS = {\n", 92 | " \"accuracy\": torchmetrics.Accuracy,\n", 93 | "}\n", 94 | "\n", 95 | "BATCH_SIZE = 11000\n", 96 | "NUM_WORKERS = 10\n", 97 | "\n", 98 | "CLIP_MODEL_NAME = \"ViT-B/32\"\n", 99 | "SBERT_MODEL_NAME = \"all-mpnet-base-v2\"" 100 | ] 101 | }, 102 | { 103 | "cell_type": "code", 104 | "execution_count": 4, 105 | "id": "f6f5fb4e", 106 | "metadata": {}, 107 | "outputs": [], 108 | "source": [ 109 | "# Load the data and create the dataloader created in the previous tutorial notebook\n", 110 | "\n", 111 | "training_data = torch.load(\"../detic_labeled_dataset.pt\")\n", 112 | "max_coords, _ = training_data._label_xyz.max(dim=0)\n", 113 | "min_coords, _ = training_data._label_xyz.min(dim=0)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 5, 119 | "id": "d8e96bcf", 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "# Set up the model\n", 124 | "\n", 125 | "label_model = GridCLIPModel(\n", 126 | " image_rep_size=training_data[0][\"clip_image_vector\"].shape[-1],\n", 127 | " text_rep_size=training_data[0][\"clip_vector\"].shape[-1],\n", 128 | " mlp_depth=1,\n", 129 | " mlp_width=600,\n", 130 | " log2_hashmap_size=20,\n", 131 | " num_levels=18,\n", 132 | " level_dim=8,\n", 133 | " per_level_scale=2,\n", 134 | " max_coords=max_coords,\n", 135 | " min_coords=min_coords,\n", 136 | ").to(DEVICE)" 137 | ] 138 | }, 139 | { 140 | "cell_type": "markdown", 141 | "id": "3a1f6f4d", 142 | "metadata": {}, 143 | "source": [ 144 | "## Training and evaulation code\n", 145 | "\n", 146 | "Now, we will set up the training and the evaluation code. We will train the model to predict the CLIP/SBert features from the 3D coordinates with a contrastive loss. For evaluation, we will measure the zero-shot label accuracy of the model." 
147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 6, 152 | "id": "ea5bd590", 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "@torch.no_grad()\n", 157 | "def zero_shot_eval(\n", 158 | " classifier: ClassificationExtractor, \n", 159 | " predicted_label_latents: torch.Tensor, \n", 160 | " predicted_image_latents: torch.Tensor, \n", 161 | " language_label_index: torch.Tensor, \n", 162 | " metric_calculators: Dict[str, Dict[str, torchmetrics.Metric]]\n", 163 | "):\n", 164 | " \"\"\"Evaluate the model on the zero-shot classification task.\"\"\"\n", 165 | " class_probs = classifier.calculate_classifications(\n", 166 | " model_text_features=predicted_label_latents,\n", 167 | " model_image_features=predicted_image_latents,\n", 168 | " )\n", 169 | " # Now figure out semantic accuracy and loss.\n", 170 | " # Semseg mask is necessary for the boundary case where all the points in the batch are \"unlabeled\"\n", 171 | " semseg_mask = torch.logical_and(\n", 172 | " language_label_index != -1,\n", 173 | " language_label_index < classifier.total_label_classes,\n", 174 | " ).squeeze(-1)\n", 175 | " if not torch.any(semseg_mask):\n", 176 | " classification_loss = torch.zeros_like(semseg_mask).mean(dim=-1)\n", 177 | " else:\n", 178 | " # Figure out the right classes.\n", 179 | " masked_class_prob = class_probs[semseg_mask]\n", 180 | " masked_labels = language_label_index[semseg_mask].squeeze(-1).long()\n", 181 | " classification_loss = F.cross_entropy(\n", 182 | " torch.log(masked_class_prob),\n", 183 | " masked_labels,\n", 184 | " )\n", 185 | " if metric_calculators.get(\"semantic\"):\n", 186 | " for _, calculators in metric_calculators[\"semantic\"].items():\n", 187 | " _ = calculators(masked_class_prob, masked_labels)\n", 188 | " return classification_loss" 189 | ] 190 | }, 191 | { 192 | "cell_type": "code", 193 | "execution_count": 7, 194 | "id": "2aeb16cd", 195 | "metadata": {}, 196 | "outputs": [], 197 | "source": [ 198 | "def train(\n", 199 | " clip_train_loader: DataLoader,\n", 200 | " labelling_model: Union[GridCLIPModel, ImplicitDataparallel],\n", 201 | " optim: torch.optim.Optimizer,\n", 202 | " epoch: int,\n", 203 | " classifier: ClassificationExtractor,\n", 204 | " device: Union[str, torch.device] = DEVICE,\n", 205 | " exp_decay_coeff: float = EXP_DECAY_COEFF,\n", 206 | " image_to_label_loss_ratio: float = IMAGE_TO_LABEL_CLIP_LOSS_SCALE,\n", 207 | " label_to_image_loss_ratio: float = LABEL_TO_IMAGE_LOSS_SCALE,\n", 208 | " disable_tqdm: bool = False,\n", 209 | " metric_calculators: Dict[str, Dict[str, torchmetrics.Metric]] = {},\n", 210 | "):\n", 211 | " \"\"\"\n", 212 | " Train the model for one epoch.\n", 213 | " \"\"\"\n", 214 | " total_loss = 0\n", 215 | " label_loss = 0\n", 216 | " image_loss = 0\n", 217 | " classification_loss = 0\n", 218 | " total_samples = 0\n", 219 | " total_classification_loss = 0\n", 220 | " labelling_model.train()\n", 221 | " total = len(clip_train_loader)\n", 222 | " for clip_data_dict in tqdm.tqdm(\n", 223 | " clip_train_loader,\n", 224 | " total=total,\n", 225 | " disable=disable_tqdm,\n", 226 | " desc=f\"Training epoch {epoch}\",\n", 227 | " ):\n", 228 | " xyzs = clip_data_dict[\"xyz\"].to(device)\n", 229 | " clip_labels = clip_data_dict[\"clip_vector\"].to(device)\n", 230 | " clip_image_labels = clip_data_dict[\"clip_image_vector\"].to(device)\n", 231 | " image_weights = torch.exp(-exp_decay_coeff * clip_data_dict[\"distance\"]).to(\n", 232 | " device\n", 233 | " )\n", 234 | " label_weights = 
clip_data_dict[\"semantic_weight\"].to(device)\n", 235 | " image_label_index: torch.Tensor = (\n", 236 | " clip_data_dict[\"img_idx\"].to(device).reshape(-1, 1)\n", 237 | " )\n", 238 | " language_label_index: torch.Tensor = (\n", 239 | " clip_data_dict[\"label\"].to(device).reshape(-1, 1)\n", 240 | " )\n", 241 | "\n", 242 | " (predicted_label_latents, predicted_image_latents) = labelling_model(xyzs)\n", 243 | " # Calculate the loss from the image to label side.\n", 244 | " batch_size = len(image_label_index)\n", 245 | " image_label_mask: torch.Tensor = (\n", 246 | " image_label_index != image_label_index.t()\n", 247 | " ).float() + torch.eye(batch_size, device=device)\n", 248 | " language_label_mask: torch.Tensor = (\n", 249 | " language_label_index != language_label_index.t()\n", 250 | " ).float() + torch.eye(batch_size, device=device)\n", 251 | "\n", 252 | " # For logging purposes, keep track of negative samples per point.\n", 253 | " image_label_mask.requires_grad = False\n", 254 | " language_label_mask.requires_grad = False\n", 255 | " contrastive_loss_labels = labelling_model.compute_loss(\n", 256 | " predicted_label_latents,\n", 257 | " clip_labels,\n", 258 | " label_mask=language_label_mask,\n", 259 | " weights=label_weights,\n", 260 | " )\n", 261 | " contrastive_loss_images = labelling_model.compute_loss(\n", 262 | " predicted_image_latents,\n", 263 | " clip_image_labels,\n", 264 | " label_mask=image_label_mask,\n", 265 | " weights=image_weights,\n", 266 | " )\n", 267 | " del (\n", 268 | " image_label_mask,\n", 269 | " image_label_index,\n", 270 | " language_label_mask,\n", 271 | " )\n", 272 | "\n", 273 | " # Mostly for evaluation purposes, calculate the classification loss.\n", 274 | " classification_loss = zero_shot_eval(\n", 275 | " classifier, predicted_label_latents, predicted_image_latents, language_label_index, metric_calculators\n", 276 | " )\n", 277 | "\n", 278 | " contrastive_loss = (\n", 279 | " image_to_label_loss_ratio * contrastive_loss_images\n", 280 | " + label_to_image_loss_ratio * contrastive_loss_labels\n", 281 | " )\n", 282 | "\n", 283 | " optim.zero_grad(set_to_none=True)\n", 284 | " contrastive_loss.backward()\n", 285 | " optim.step()\n", 286 | " # Clip the temperature term for stability\n", 287 | " labelling_model.temperature.data = torch.clamp(\n", 288 | " labelling_model.temperature.data, max=np.log(100.0)\n", 289 | " )\n", 290 | " label_loss += contrastive_loss_labels.detach().cpu().item()\n", 291 | " image_loss += contrastive_loss_images.detach().cpu().item()\n", 292 | " total_classification_loss += classification_loss.detach().cpu().item()\n", 293 | " total_loss += contrastive_loss.detach().cpu().item()\n", 294 | " total_samples += 1\n", 295 | "\n", 296 | " to_log = {\n", 297 | " \"train_avg/contrastive_loss_labels\": label_loss / total_samples,\n", 298 | " \"train_avg/contrastive_loss_images\": image_loss / total_samples,\n", 299 | " \"train_avg/semseg_loss\": total_classification_loss / total_samples,\n", 300 | " \"train_avg/loss_sum\": total_loss / total_samples,\n", 301 | " \"train_avg/labelling_temp\": torch.exp(labelling_model.temperature.data.detach())\n", 302 | " .cpu()\n", 303 | " .item(),\n", 304 | " }\n", 305 | " for metric_dict in metric_calculators.values():\n", 306 | " for metric_name, metric in metric_dict.items():\n", 307 | " try:\n", 308 | " to_log[f\"train_avg/{metric_name}\"] = (\n", 309 | " metric.compute().detach().cpu().item()\n", 310 | " )\n", 311 | " except RuntimeError as e:\n", 312 | " to_log[f\"train_avg/{metric_name}\"] = 
0.0\n", 313 | " metric.reset()\n", 314 | " wandb.log(to_log)\n", 315 | " logging.debug(pprint.pformat(to_log, indent=4, width=1))\n", 316 | " return total_loss" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 8, 322 | "id": "a84d1638", 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [ 326 | "def save(\n", 327 | " labelling_model: Union[ImplicitDataparallel, GridCLIPModel],\n", 328 | " optim: torch.optim.Optimizer,\n", 329 | " epoch: int,\n", 330 | " save_directory: str = SAVE_DIRECTORY,\n", 331 | " saving_dataparallel: bool = False,\n", 332 | "):\n", 333 | " if saving_dataparallel:\n", 334 | " to_save = labelling_model.module\n", 335 | " else:\n", 336 | " to_save = labelling_model\n", 337 | " state_dict = {\n", 338 | " \"model\": to_save.state_dict(),\n", 339 | " \"optim\": optim.state_dict(),\n", 340 | " \"epoch\": epoch,\n", 341 | " }\n", 342 | " torch.save(\n", 343 | " state_dict,\n", 344 | " f\"{save_directory}/implicit_scene_label_model_latest.pt\",\n", 345 | " )\n", 346 | " return 0" 347 | ] 348 | }, 349 | { 350 | "cell_type": "markdown", 351 | "id": "290e034a", 352 | "metadata": {}, 353 | "source": [ 354 | "## Set up the auxilary classes\n", 355 | "\n", 356 | "Like zero-shot classifier, dataloader, evaluators, optimizer, etc." 357 | ] 358 | }, 359 | { 360 | "cell_type": "code", 361 | "execution_count": 9, 362 | "id": "5492e990", 363 | "metadata": {}, 364 | "outputs": [ 365 | { 366 | "name": "stderr", 367 | "output_type": "stream", 368 | "text": [ 369 | "INFO - 2022-10-11 10:25:47,753 - SentenceTransformer - Load pretrained SentenceTransformer: all-mpnet-base-v2\n" 370 | ] 371 | }, 372 | { 373 | "data": { 374 | "application/vnd.jupyter.widget-view+json": { 375 | "model_id": "ba7d7a73c9cf46548f6b3c8bafebaa15", 376 | "version_major": 2, 377 | "version_minor": 0 378 | }, 379 | "text/plain": [ 380 | "Batches: 0%| | 0/3 [00:00" 529 | ] 530 | }, 531 | "metadata": {}, 532 | "output_type": "display_data" 533 | }, 534 | { 535 | "data": { 536 | "text/html": [ 537 | "Tracking run with wandb version 0.13.3" 538 | ], 539 | "text/plain": [ 540 | "" 541 | ] 542 | }, 543 | "metadata": {}, 544 | "output_type": "display_data" 545 | }, 546 | { 547 | "data": { 548 | "text/html": [ 549 | "Run data is saved locally in /home/mahi/code/clip-fields/demo/wandb/run-20221011_102550-j12j175e" 550 | ], 551 | "text/plain": [ 552 | "" 553 | ] 554 | }, 555 | "metadata": {}, 556 | "output_type": "display_data" 557 | }, 558 | { 559 | "data": { 560 | "text/html": [ 561 | "Syncing run gallant-river-1 to Weights & Biases (docs)
" 562 | ], 563 | "text/plain": [ 564 | "" 565 | ] 566 | }, 567 | "metadata": {}, 568 | "output_type": "display_data" 569 | } 570 | ], 571 | "source": [ 572 | "wandb.init(\n", 573 | " project=\"clipfields\",\n", 574 | ")\n", 575 | "# Set the extra parameters.\n", 576 | "wandb.config.web_labelled_points = len(training_data)" 577 | ] 578 | }, 579 | { 580 | "cell_type": "code", 581 | "execution_count": 14, 582 | "id": "9fb3671d", 583 | "metadata": {}, 584 | "outputs": [ 585 | { 586 | "name": "stderr", 587 | "output_type": "stream", 588 | "text": [ 589 | "Training epoch 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:34<00:00, 7.25it/s]\n", 590 | "Training epoch 1: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:34<00:00, 7.34it/s]\n", 591 | "Training epoch 2: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:34<00:00, 7.33it/s]\n", 592 | "Training epoch 3: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:34<00:00, 7.35it/s]\n", 593 | "Training epoch 4: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:34<00:00, 7.34it/s]\n", 594 | "Training epoch 5: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:34<00:00, 7.33it/s]\n" 595 | ] 596 | } 597 | ], 598 | "source": [ 599 | "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\" # Just to reduce excessive logging from sbert\n", 600 | "\n", 601 | "epoch = 0\n", 602 | "NUM_EPOCHS = 5\n", 603 | "\n", 604 | "while epoch <= NUM_EPOCHS:\n", 605 | " train(\n", 606 | " clip_train_loader,\n", 607 | " label_model,\n", 608 | " optim,\n", 609 | " epoch,\n", 610 | " train_classifier,\n", 611 | " metric_calculators=train_metric_calculators,\n", 612 | " )\n", 613 | " epoch += 1\n", 614 | " if epoch % SAVE_EVERY == 0:\n", 615 | " save(label_model, optim, epoch)" 616 | ] 617 | }, 618 | { 619 | "cell_type": "markdown", 620 | "id": "f54ff760-c0eb-4ace-b1e7-fc92a03c62be", 621 | "metadata": {}, 622 | "source": [ 623 | "This is saved already in `../clip_implicit_model`, so we don't have to save our trained model again. You can see the run data on [Weights and biases](https://wandb.ai/mahi/clipfields/runs/j12j175e). On our next tutorial episode, we will evaluate our model." 
624 | ] 625 | } 626 | ], 627 | "metadata": { 628 | "kernelspec": { 629 | "display_name": "Python 3 (ipykernel)", 630 | "language": "python", 631 | "name": "python3" 632 | }, 633 | "language_info": { 634 | "codemirror_mode": { 635 | "name": "ipython", 636 | "version": 3 637 | }, 638 | "file_extension": ".py", 639 | "mimetype": "text/x-python", 640 | "name": "python", 641 | "nbconvert_exporter": "python", 642 | "pygments_lexer": "ipython3", 643 | "version": "3.8.13" 644 | }, 645 | "vscode": { 646 | "interpreter": { 647 | "hash": "db1f79a380f3c855f0bd6a6f1ad2b80d271c32dcf53749ce06204dd6e0a63f81" 648 | } 649 | } 650 | }, 651 | "nbformat": 4, 652 | "nbformat_minor": 5 653 | } 654 | -------------------------------------------------------------------------------- /docs/css/simple-grid.css: -------------------------------------------------------------------------------- 1 | @import url(https://fonts.googleapis.com/css?family=Lato:400,300,300italic,400italic,700,700italic); 2 | 3 | /* UNIVERSAL */ 4 | 5 | html, 6 | body { 7 | height: 100%; 8 | width: 100%; 9 | margin: 0; 10 | padding: 0; 11 | left: 0; 12 | top: 0; 13 | font-size: 100%; 14 | } 15 | 16 | /* ROOT FONT STYLES */ 17 | img { 18 | max-width: 100%; 19 | } 20 | 21 | .img { 22 | max-width: 100%; 23 | } 24 | 25 | 26 | * { 27 | font-family: 'Lato', Helvetica, sans-serif; 28 | color: #333447; 29 | line-height: 1.5; 30 | } 31 | 32 | .auto-resizable-iframe { 33 | max-width: 100%; 34 | margin: 0px auto; 35 | } 36 | 37 | .auto-resizable-iframe>div { 38 | position: relative; 39 | padding-bottom: 60%; 40 | height: 0px; 41 | } 42 | 43 | .auto-resizable-iframe iframe { 44 | position: absolute; 45 | top: 0px; 46 | left: 0px; 47 | width: 100%; 48 | height: 100%; 49 | } 50 | 51 | /* TYPOGRAPHY */ 52 | .line { 53 | height: 1px; 54 | width: 100%; 55 | background: #D6D6D8; 56 | margin: 20px 0; 57 | } 58 | 59 | h1 { 60 | font-size: 3.0rem; 61 | } 62 | 63 | h2 { 64 | font-size: 2rem; 65 | } 66 | 67 | h3 { 68 | font-size: 1.375rem; 69 | } 70 | 71 | h4 { 72 | font-size: 1.125rem; 73 | } 74 | 75 | h5 { 76 | font-size: 1rem; 77 | } 78 | 79 | h6 { 80 | font-size: 0.875rem; 81 | } 82 | 83 | p { 84 | font-size: 1.125rem; 85 | font-weight: 200; 86 | line-height: 1.8; 87 | } 88 | 89 | .font-light { 90 | font-weight: 300; 91 | } 92 | 93 | .font-regular { 94 | font-weight: 400; 95 | } 96 | 97 | .font-heavy { 98 | font-weight: 700; 99 | } 100 | 101 | /* POSITIONING */ 102 | 103 | .left { 104 | text-align: left; 105 | } 106 | 107 | .right { 108 | text-align: right; 109 | } 110 | 111 | .center { 112 | text-align: center; 113 | margin-left: auto; 114 | margin-right: auto; 115 | } 116 | 117 | .justify { 118 | text-align: justify; 119 | } 120 | 121 | /* ==== GRID SYSTEM ==== */ 122 | 123 | .container { 124 | width: 90%; 125 | margin-left: auto; 126 | margin-right: auto; 127 | } 128 | 129 | .row { 130 | position: relative; 131 | width: 100%; 132 | } 133 | 134 | 135 | 136 | .row [class^="col"] { 137 | float: left; 138 | margin: 0.5rem 2%; 139 | min-height: 0.125rem; 140 | } 141 | 142 | .col-1, 143 | .col-2, 144 | .col-3, 145 | .col-4, 146 | .col-5, 147 | .col-6, 148 | .col-7, 149 | .col-8, 150 | .col-9, 151 | .col-10, 152 | .col-11, 153 | .col-12 { 154 | width: 96%; 155 | } 156 | 157 | .col-1-sm { 158 | width: 4.33%; 159 | } 160 | 161 | .col-2-sm { 162 | width: 12.66%; 163 | } 164 | 165 | .col-3-sm { 166 | width: 21%; 167 | } 168 | 169 | .col-4-sm { 170 | width: 29.33%; 171 | } 172 | 173 | .col-5-sm { 174 | width: 37.66%; 175 | } 176 | 177 | .col-6-sm { 178 | width: 46%; 179 | } 
180 | 181 | .col-7-sm { 182 | width: 54.33%; 183 | } 184 | 185 | .col-8-sm { 186 | width: 62.66%; 187 | } 188 | 189 | .col-9-sm { 190 | width: 71%; 191 | } 192 | 193 | .col-10-sm { 194 | width: 79.33%; 195 | } 196 | 197 | .col-11-sm { 198 | width: 87.66%; 199 | } 200 | 201 | .col-12-sm { 202 | width: 96%; 203 | } 204 | 205 | .row::after { 206 | content: ""; 207 | display: table; 208 | clear: both; 209 | } 210 | 211 | .hidden-sm { 212 | display: none; 213 | } 214 | 215 | @media only screen and (min-width: 33.75em) { 216 | 217 | /* 540px */ 218 | .container { 219 | width: 80%; 220 | } 221 | } 222 | 223 | @media only screen and (min-width: 45em) { 224 | 225 | /* 720px */ 226 | .col-1 { 227 | width: 4.33%; 228 | } 229 | 230 | .col-2 { 231 | width: 12.66%; 232 | } 233 | 234 | .col-3 { 235 | width: 21%; 236 | } 237 | 238 | .col-4 { 239 | width: 29.33%; 240 | } 241 | 242 | .col-5 { 243 | width: 37.66%; 244 | } 245 | 246 | .col-6 { 247 | width: 46%; 248 | } 249 | 250 | .col-7 { 251 | width: 54.33%; 252 | } 253 | 254 | .col-8 { 255 | width: 62.66%; 256 | } 257 | 258 | .col-9 { 259 | width: 71%; 260 | } 261 | 262 | .col-10 { 263 | width: 79.33%; 264 | } 265 | 266 | .col-11 { 267 | width: 87.66%; 268 | } 269 | 270 | .col-12 { 271 | width: 96%; 272 | } 273 | 274 | .col-13 { 275 | width: 100%; 276 | padding-top: 50px; 277 | } 278 | 279 | .col-14 { 280 | width: 65%; 281 | padding-top: 10px; 282 | padding-left: 100px; 283 | } 284 | 285 | .col-15 { 286 | width: 60%; 287 | } 288 | 289 | .hidden-sm { 290 | display: block; 291 | } 292 | } 293 | 294 | .code { 295 | /* background-color: lightgray; */ 296 | font-family: 'Courier New', Courier, monospace; 297 | font-size: small; 298 | overflow-wrap: normal; 299 | } 300 | 301 | @media only screen and (min-width: 60em) { 302 | 303 | /* 960px */ 304 | .container { 305 | width: 75%; 306 | max-width: 60rem; 307 | } 308 | } 309 | 310 | .seek_button { 311 | background-color: #FFF; 312 | border-top: #FFF; 313 | border-width: 1px; 314 | border-bottom-style: dashed; 315 | text-align: center; 316 | text-decoration: none; 317 | font-size: 16px; 318 | margin: 4px 0px; 319 | cursor: pointer; 320 | } 321 | 322 | /* Fallback stuff */ 323 | progress[value] { 324 | appearance: none; 325 | /* Needed for Safari */ 326 | border: none; 327 | /* Needed for Firefox */ 328 | color: #57068c; 329 | width: 100%; 330 | /* Fallback to a solid color */ 331 | } 332 | 333 | /* WebKit styles */ 334 | progress[value]::-webkit-progress-value { 335 | background-image: linear-gradient(to right, 336 | #57068c, #57068c); 337 | transition: width; 338 | transition-timing-function: ease; 339 | transition-duration: 500ms; 340 | width: 100%; 341 | } 342 | 343 | /* Firefox styles */ 344 | progress[value]::-moz-progress-bar { 345 | background-image: -moz-linear-gradient(right, 346 | #57068c, #57068c); 347 | width: 100%; 348 | } 349 | 350 | .seek_button { 351 | color: #00A2FF; 352 | } -------------------------------------------------------------------------------- /docs/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 13 | 14 | 15 | 16 | 17 | 18 | 20 | 21 | 22 | 23 | 24 | CLIP-Fields: Weakly Supervised Semantic Fields for Robotic Memory 25 | 79 | 86 | 87 | 88 | 89 |
90 |
91 |
92 |
93 |

CLIP-Fields: Weakly Supervised Semantic Fields for Robotic Memory

94 |
95 | 96 | 101 |
102 | 103 |

Code

104 |
105 |
106 |
107 | 108 |

Data

109 |
110 |
111 | 116 | 121 |
122 |
123 |
124 |

Mahi Shafiullah1

125 |
126 |
127 |

Chris Paxton2

128 |
129 |
130 |

Lerrel Pinto1

131 |
132 |
133 |

134 |
135 |
136 |

Soumith Chintala2

137 |
138 |
139 |

Arthur Szlam2

140 |
141 | 142 |
143 |
144 |
145 |

1: New York University

146 |
147 |
148 |

2: Meta AI

149 |
150 |
151 | 152 |
153 |
154 |

155 | Tl;dr CLIP-Field is a novel weakly supervised approach for learning a semantic robot 156 | memory that can respond to natural language queries solely from raw RGB-D and odometry data with no extra 157 | human labelling. 158 | It combines the image and language understanding capabilities of novel vision-language models (VLMs) like 159 | CLIP, large language models like Sentence-BERT, and open-label object detection models like Detic 160 | with the spatial understanding capabilities of neural radiance field (NeRF) style architectures to 161 | build a spatial database that holds semantic information. 162 |

163 |
164 |
165 | 166 | 167 |
168 |
169 |

Abstract

170 |

171 | We propose CLIP-Fields, an implicit scene model that can be trained with no direct human supervision. 172 | This model learns a mapping from spatial locations to semantic embedding vectors. 173 | The mapping can then be used for a variety of tasks, such as segmentation, instance identification, semantic 174 | search over space, and view localization. 175 | Most importantly, the mapping can be trained with supervision coming only from web-image and web-text 176 | trained models such as CLIP, Detic, and Sentence-BERT. 177 | When compared to baselines like Mask-RCNN, our method outperforms them on few-shot instance identification or 178 | semantic segmentation on the HM3D dataset with only a fraction of the examples. 179 | Finally, we show that using CLIP-Fields as a scene memory, robots can perform semantic navigation in 180 | real-world environments. 181 |

182 |
183 |
184 |
185 | 186 | 187 |
188 |
189 |
190 |

Real World Robot Experiments

191 |

192 | In these experiments, the robot navigates real-world environments to "go and look at" the objects 193 | described by the query, which we expect to make many downstream tasks achievable 194 | simply from natural language queries. 195 |

196 |
197 |

Robot queries in a real lab kitchen setup.

198 | 201 | Progress 202 |
203 | 205 | 208 | 211 | 214 | 216 |
217 |
218 |
219 |

Robot queries in a real lounge/library setup.

220 | 223 | 224 | Progress 225 |
226 | 228 | 231 | 233 | 236 | 239 |
240 |
241 |
242 |
243 |
244 | 245 |
246 |
247 |
248 |

Interactive Demonstration

249 |

250 | In this interactive demo, we show a heatmap, computed by a trained CLIP-Field, of the association between environment points and natural language 251 | queries. Note that this model was trained without any human labels, and none of 252 | these phrases ever appeared in the training set. 253 |

254 | 255 |
256 |
257 |
258 | 259 | 260 |
261 |
262 |
263 |

Method

264 |

CLIP-Fields is based on a series of simple ideas: 265 |

    266 |
  • Webscale models like CLIP and Detic provide lots of semantic information about objects that can be used 267 | for robot tasks, but don't encode spatial qualities of this information. 268 |
  • 269 |
  • NeRF-like approaches, on the other hand, have recently shown that they can capture very detailed scene 270 | information. 271 |
  • 272 |
  • We can combine these two, using a novel contrastive loss in order to capture scene-specific embeddings. 273 | We supervise multiple "heads," including object detection and CLIP, based on these webscale vision models, 274 | which allows us to do open-vocabulary queries at test time. 275 |
  • 276 |
277 |

278 |
279 |
280 |
281 | 282 |
283 |
284 |
285 | 286 | 287 | 288 | 289 |

290 | We collect our real-world data using an iPhone 13 Pro, whose LiDAR scanner gives us RGB-D 291 | and odometry 292 | information, which we use to establish a pixel-to-real-world coordinate correspondence. 293 |

294 |

295 | We use pre-trained models such as Detic and LSeg to 297 | extract open-label semantic annotations from the RGB images; we use the predicted labels to get Sentence-BERT encodings, and the proposed bounding 299 | boxes to get CLIP visual encodings. 300 | Note that we need no human labelling at all to train our models, and all 301 | of our supervision comes from pre-trained web-scale language models or VLMs. 302 |

303 |
304 |
305 |
306 | 307 |
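To make the labelling pipeline above concrete, here is a minimal sketch (not the repository's exact code) of how a single RGB frame could be turned into weak supervision: an open-vocabulary detector proposes (label, box) pairs, Sentence-BERT encodes the label text, and CLIP encodes the cropped box. The detections list and the frame path are illustrative stand-ins; the real dataloaders in dataloaders/ also back-project every labelled pixel to a 3D point using the depth image and odometry.

# Hedged sketch of the weak-labelling step for one RGB frame.
# `detections` stands in for Detic/LSeg output; the frame path is hypothetical.
import clip
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer

device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model, preprocess = clip.load("ViT-B/32", device=device)
sbert = SentenceTransformer("all-mpnet-base-v2")

image = Image.open("frame_0000.jpg")                   # one RGB frame (illustrative path)
detections = [("coffee machine", (40, 10, 300, 220))]  # (label, xyxy box) from an open-vocab detector

point_labels = []
for label, box in detections:
    # Language-side target: Sentence-BERT embedding of the detected label text.
    text_vec = torch.from_numpy(sbert.encode(label))
    # Vision-side target: CLIP embedding of the cropped bounding box.
    crop = preprocess(image.crop(box)).unsqueeze(0).to(device)
    with torch.no_grad():
        image_vec = clip_model.encode_image(crop).squeeze(0).float().cpu()
    point_labels.append((label, text_vec, image_vec))
# Every pixel inside `box` inherits (text_vec, image_vec) and is back-projected to
# an (x, y, z) point using the depth image and the camera odometry.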
308 |
309 |
310 | 311 | 312 | 313 | 314 |

315 | Our model is an implicit function that maps each 3-dimensional spatial location to a higher dimensional 316 | representation vector. Each of the vectors contain both the language-based and vision-based semantic 317 | embeddings of the content of location (x, y, z). 318 |

319 |

320 | The trunk of our model is an instant neural 321 | graphics primitives based hash-grid architecture that serves as our scene representation; we use MLPs to map its features 322 | to higher dimensions that match the output dimensions of embedding models such as Sentence-BERT or CLIP. 323 |

324 |
325 |
326 |
327 | 328 |
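Schematically, the field is a multi-resolution hash-grid encoder followed by an MLP whose output is split into a language head and a vision head; the full implementation is GridCLIPModel in grid_hash_model.py in this repository. The stripped-down sketch below (which omits coordinate bounds handling and the learned temperature) only illustrates the data flow; the 768/512 output sizes match Sentence-BERT all-mpnet-base-v2 and CLIP ViT-B/32.

# Simplified data flow: xyz -> hash-grid features -> MLP -> (label latent, image latent).
# A sketch only; see GridCLIPModel in grid_hash_model.py for the real model.
import torch
import torch.nn as nn

class TinySemanticField(nn.Module):
    def __init__(self, encoder: nn.Module, feat_dim: int, text_dim: int = 768, image_dim: int = 512):
        super().__init__()
        self.encoder = encoder  # e.g. a GridEncoder over normalized 3D coordinates
        self.head = nn.Sequential(
            nn.Linear(feat_dim, 600),
            nn.ReLU(),
            nn.Linear(600, text_dim + image_dim),
        )
        self.text_dim = text_dim

    def forward(self, xyz: torch.Tensor):
        feats = self.encoder(xyz)   # [N, feat_dim] multi-resolution grid features
        out = self.head(feats)      # [N, text_dim + image_dim]
        return out[..., : self.text_dim], out[..., self.text_dim :]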
329 |
330 |
331 | 334 | replay 335 |

336 |

337 | We train with a contrastive loss that pushes the model's learned embeddings to be close to similar 338 | embeddings in the labeled datasets and far away from dissimilar ones. The contrastive training also 339 | helps us denoise the (sometimes) noisy labels produced by the pre-trained labelling models. 340 |

341 |
342 |
343 |
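At its core this is a symmetric InfoNCE-style objective over a batch of points: each point's predicted embedding should match its own CLIP or Sentence-BERT target more closely than any other point's target in the batch. Below is a minimal sketch of that loss; the version in grid_hash_model.py additionally masks out pairs that share a label and applies per-point confidence weights.

# Minimal symmetric contrastive (InfoNCE-style) loss between predicted and target embeddings.
import torch
import torch.nn.functional as F

def contrastive_loss(predicted: torch.Tensor, target: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    predicted = F.normalize(predicted, p=2, dim=-1)
    target = F.normalize(target, p=2, dim=-1)
    logits = predicted @ target.t() / temperature  # [B, B] cosine similarities scaled by temperature
    labels = torch.arange(len(predicted), device=predicted.device)
    # Each point's positive is its own target; every other point in the batch acts as a negative.
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) / 2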
344 | 345 | 346 |
347 |
348 |
349 |

Experiments

350 | 351 | 352 | 353 | 354 |

355 | On our robot, we load the CLIP-Field to help with localization and navigation. 356 | When the robot gets a new text query, we first convert it to a representation vector by encoding it with 357 | the Sentence-BERT and CLIP text encoders. 358 | Then, we compare these representations with the representations of the XYZ coordinates in the scene and find 359 | the location in space maximizing their similarity. 360 | We use the robot’s navigation stack to navigate to that region, and point the robot camera at the XYZ 361 | coordinate where the dot product was highest. 362 | We consider the navigation task successful if the robot can navigate to and point the camera at an object 363 | that satisfies the query. 364 |

365 |
366 |
367 |
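Concretely, answering a query reduces to a dot-product search over the scene's points. The hedged sketch below assumes `field` is a trained GridCLIPModel-style network and `scene_xyz` is an [N, 3] tensor of scene points; the equal weighting of the two heads is an illustrative choice, not necessarily the exact scoring used on the robot.

# Score every scene point against a natural-language query and return the best XYZ location.
import clip
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

@torch.no_grad()
def locate(query: str, field, scene_xyz: torch.Tensor, device: str = "cuda") -> torch.Tensor:
    sbert = SentenceTransformer("all-mpnet-base-v2")
    clip_model, _ = clip.load("ViT-B/32", device=device)

    # Encode the query with both text encoders.
    text_vec = F.normalize(torch.from_numpy(sbert.encode(query)).to(device), dim=-1)
    clip_vec = F.normalize(
        clip_model.encode_text(clip.tokenize([query]).to(device)).float().squeeze(0), dim=-1
    )

    # Query the field at every scene point and score by cosine similarity to both query embeddings.
    label_latents, image_latents = field(scene_xyz.to(device))
    scores = F.normalize(label_latents, dim=-1) @ text_vec + F.normalize(image_latents, dim=-1) @ clip_vec
    return scene_xyz[scores.argmax()]  # hand this XYZ goal to the navigation stack

# e.g. goal_xyz = locate("warm up my lunch", field, scene_xyz)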
368 | 369 | 370 |
371 |
372 |
373 |

Future Work

374 |

We showed that CLIP-Fields can learn 3D semantic scene representations from little or no labeled data, 375 | relying on weakly-supervised web-image trained models, and that we can use these models to perform a 376 | simple “look-at” task. CLIP-Fields allow us to answer queries of varying levels of complexity. We expect 377 | this kind of 3D representation to be generally useful for robotics. For example, it may be enriched with 378 | affordances for planning; the geometric database can be readily combined with end-to-end differentiable 379 | planners. In future work, we also hope to explore models that share parameters across scenes, and can handle 380 | dynamic scenes and objects. 381 |

382 | 383 |
384 |
385 |
386 | 387 |
388 |
389 |
390 |
391 | † This work was done while the first author was interning at Facebook AI Research. 392 |
393 |
394 |
395 |
396 | 397 |
398 |
399 |
400 | 401 | 419 | 420 | -------------------------------------------------------------------------------- /docs/kitchen.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | three.js webgl - PCD 6 | 7 | 8 | 157 | 158 | 159 | 160 |
161 | 162 | 163 |
164 |
165 | 166 | 167 | 168 | 169 |
170 | 171 | 172 | 173 | 180 | 181 | 494 | 495 | 496 | -------------------------------------------------------------------------------- /docs/mfiles/1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/1.mp4 -------------------------------------------------------------------------------- /docs/mfiles/2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/2.mp4 -------------------------------------------------------------------------------- /docs/mfiles/3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/3.mp4 -------------------------------------------------------------------------------- /docs/mfiles/Exp Table.001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/Exp Table.001.png -------------------------------------------------------------------------------- /docs/mfiles/arch bet.001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/arch bet.001.png -------------------------------------------------------------------------------- /docs/mfiles/arch.avif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/arch.avif -------------------------------------------------------------------------------- /docs/mfiles/arch.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/arch.jpeg -------------------------------------------------------------------------------- /docs/mfiles/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/arch.png -------------------------------------------------------------------------------- /docs/mfiles/behavior_transformers.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/behavior_transformers.mp4 -------------------------------------------------------------------------------- /docs/mfiles/data_processing.avif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/data_processing.avif -------------------------------------------------------------------------------- /docs/mfiles/data_processing.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/data_processing.jpeg -------------------------------------------------------------------------------- 
/docs/mfiles/data_processing.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/data_processing.png -------------------------------------------------------------------------------- /docs/mfiles/env/carla/1_obs.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/carla/1_obs.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/carla/1_over.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/carla/1_over.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/carla/2_obs.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/carla/2_obs.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/carla/2_over.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/carla/2_over.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/kitchen/1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/kitchen/1.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/kitchen/2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/kitchen/2.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/kitchen/all/kblh_small.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/kitchen/all/kblh_small.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/kitchen/all/kbls_small.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/kitchen/all/kbls_small.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/kitchen/all/kbth_small.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/kitchen/all/kbth_small.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/kitchen/all/kbtl_small.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/kitchen/all/kbtl_small.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/kitchen/all/mklh_small.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/kitchen/all/mklh_small.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/kitchen/all/mkls_small.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/kitchen/all/mkls_small.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/kitchen/all/mkth_small.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/kitchen/all/mkth_small.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/kitchen/all/mktl_small.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/kitchen/all/mktl_small.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/pushblock/1.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/pushblock/1.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/pushblock/2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/pushblock/2.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/pushblock/all/g1gg.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/pushblock/all/g1gg.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/pushblock/all/g1gg2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/pushblock/all/g1gg2.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/pushblock/all/g1gg3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/pushblock/all/g1gg3.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/pushblock/all/g1gr.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/pushblock/all/g1gr.mp4 
-------------------------------------------------------------------------------- /docs/mfiles/env/pushblock/all/g1gr2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/pushblock/all/g1gr2.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/pushblock/all/g1gr3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/pushblock/all/g1gr3.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/pushblock/all/r1rg.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/pushblock/all/r1rg.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/pushblock/all/r1rg2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/pushblock/all/r1rg2.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/pushblock/all/r1rg3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/pushblock/all/r1rg3.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/pushblock/all/r1rr.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/pushblock/all/r1rr.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/pushblock/all/r1rr2.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/pushblock/all/r1rr2.mp4 -------------------------------------------------------------------------------- /docs/mfiles/env/pushblock/all/r1rr3.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/env/pushblock/all/r1rr3.mp4 -------------------------------------------------------------------------------- /docs/mfiles/exp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/exp.png -------------------------------------------------------------------------------- /docs/mfiles/kitchen_1cm.pcd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/kitchen_1cm.pcd -------------------------------------------------------------------------------- /docs/mfiles/model_nyu_kitchen_fill_out_water_bottle.pcd: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/model_nyu_kitchen_fill_out_water_bottle.pcd -------------------------------------------------------------------------------- /docs/mfiles/model_nyu_kitchen_make_some_coffee.pcd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/model_nyu_kitchen_make_some_coffee.pcd -------------------------------------------------------------------------------- /docs/mfiles/model_nyu_kitchen_warm_up_my_lunch.pcd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/model_nyu_kitchen_warm_up_my_lunch.pcd -------------------------------------------------------------------------------- /docs/mfiles/multimodal_colorbar_flipped-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/multimodal_colorbar_flipped-1.png -------------------------------------------------------------------------------- /docs/mfiles/multimodal_colorbar_flipped.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/multimodal_colorbar_flipped.pdf -------------------------------------------------------------------------------- /docs/mfiles/nyu_kitchen_throw_my_trash.pcd: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/nyu_kitchen_throw_my_trash.pcd -------------------------------------------------------------------------------- /docs/mfiles/nyu_robot_run_clipped_small.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/nyu_robot_run_clipped_small.mp4 -------------------------------------------------------------------------------- /docs/mfiles/pit_robot_run_clipped_small.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/pit_robot_run_clipped_small.mp4 -------------------------------------------------------------------------------- /docs/mfiles/query_navigation.avif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/query_navigation.avif -------------------------------------------------------------------------------- /docs/mfiles/query_navigation.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/query_navigation.jpg -------------------------------------------------------------------------------- /docs/mfiles/training.mp4: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/notmahi/clip-fields/7edfa8562b1337d7b73bf700968e8b74323b8b9e/docs/mfiles/training.mp4 -------------------------------------------------------------------------------- /docs/more/bibtex.txt: -------------------------------------------------------------------------------- 1 | @article{shafiullah2022clipfields, 2 | title = {CLIP-Fields: Weakly Supervised Semantic Fields for Robotic Memory}, 3 | author = {Nur Muhammad Mahi Shafiullah and Chris Paxton and Lerrel Pinto and Soumith Chintala and Arthur Szlam}, 4 | year = {2022}, 5 | journal = {arXiv preprint arXiv: Arxiv-2210.05663} 6 | } 7 | -------------------------------------------------------------------------------- /docs/more/blockpush/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 19 | 20 | 21 | 22 | 23 | Behavior Transformers: Cloning k modes with one stone 24 | 36 | 37 | 38 | 39 |
40 |
41 |
42 |
43 |

Behavior Transformers:
Cloning 44 | k 45 | modes with one stone

46 |
47 | 48 | 53 |
54 |
55 | 56 |
57 |
58 |
59 |

Unconditional Rollouts of BeT on Block Pushing environment

60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 | 71 |
72 |
73 | 76 |
77 |
78 | 81 |
82 |
83 | 84 |
85 |
86 | 89 |
90 |
91 | 94 |
95 |
96 | 99 |
100 |
101 | 102 |
103 |
104 | 107 |
108 |
109 | 112 |
113 |
114 | 117 |
118 |
119 | 120 |
121 |
122 | 125 |
126 |
127 | 130 |
131 |
132 | 135 |
136 |
137 | 138 |
139 |
140 |
141 | 142 |
143 |
144 |
145 | 146 | 165 | 166 | -------------------------------------------------------------------------------- /docs/more/kitchen/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 19 | 20 | 21 | 22 | 23 | Behavior Transformers: Cloning k modes with one stone 24 | 36 | 37 | 38 | 39 |
40 |
41 |
42 |
43 |

Behavior Transformers:
Cloning 44 | k 45 | modes with one stone

46 |
47 | 48 | 53 |
54 |
55 | 56 |
57 |
58 |
59 |

Unconditional Rollouts of BeT on Franka Kitchen

60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 | 71 |
72 |
73 | 76 |
77 |
78 | 79 |
80 |
81 | 84 |
85 |
86 | 89 |
90 |
91 | 92 |
93 |
94 | 97 |
98 |
99 | 102 |
103 |
104 | 105 |
106 |
107 | 110 |
111 |
112 | 115 |
116 |
117 | 118 |
119 |
120 |
121 | 122 |
123 |
124 |
125 | 126 | 145 | 146 | -------------------------------------------------------------------------------- /grid_hash_model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from gridencoder import GridEncoder 8 | from misc import MLP 9 | 10 | 11 | class GridCLIPModel(nn.Module): 12 | def __init__( 13 | self, 14 | max_coords: Optional[torch.Tensor] = None, 15 | min_coords: Optional[torch.Tensor] = None, 16 | mlp_depth: int = 2, 17 | mlp_width: int = 256, 18 | batchnorm: bool = False, 19 | num_levels: int = 16, 20 | level_dim: int = 8, 21 | log2_hashmap_size: int = 24, 22 | per_level_scale: float = 2.0, 23 | device: str = "cuda", 24 | image_rep_size: int = 512, 25 | text_rep_size: int = 512, 26 | bounds: float = 10.0, 27 | ): 28 | super().__init__() 29 | 30 | self._grid_model = GridEncoder( 31 | input_dim=3, 32 | num_levels=num_levels, 33 | level_dim=level_dim, 34 | base_resolution=16, 35 | log2_hashmap_size=log2_hashmap_size, 36 | per_level_scale=per_level_scale, 37 | desired_resolution=None, 38 | gridtype="hash", 39 | align_corners=False, 40 | ) 41 | # Now convert the output with an MLP 42 | self._post_grid = MLP( 43 | input_dim=num_levels * level_dim, 44 | hidden_dim=mlp_width, 45 | hidden_depth=mlp_depth, 46 | output_dim=image_rep_size + text_rep_size, 47 | batchnorm=batchnorm, 48 | ) 49 | # Mini MLP for extra storage for image loss 50 | self._image_head = nn.Identity() 51 | # Magic value adviced by @imisra 52 | self.temperature = nn.Parameter(torch.log(torch.tensor(1.0 / 0.07))) 53 | 54 | self._image_rep_size = image_rep_size 55 | self._text_rep_size = text_rep_size 56 | 57 | if not (max_coords is not None and min_coords is not None): 58 | self._max_bounds, self._min_bounds = ( 59 | torch.ones(3) * bounds, 60 | torch.ones(3) * -bounds, 61 | ) 62 | else: 63 | assert len(max_coords) == len(min_coords) 64 | self._max_bounds, self._min_bounds = max_coords, min_coords 65 | 66 | self._grid_model = self._grid_model.to(device) 67 | self._post_grid = self._post_grid.to(device) 68 | self._image_head = self._image_head.to(device) 69 | self.temperature.data = self.temperature.data.to(device) 70 | self._max_bounds = self._max_bounds.to(device) 71 | self._min_bounds = self._min_bounds.to(device) 72 | 73 | def forward(self, x: torch.Tensor, bounds: Optional[float] = None): 74 | if bounds is None: 75 | max_bounds, min_bounds = self._max_bounds.to(x.device), self._min_bounds.to( 76 | x.device 77 | ) 78 | else: 79 | max_bounds, min_bounds = ( 80 | torch.ones(3, device=x.device) * bounds, 81 | torch.ones(3, device=x.device) * -bounds, 82 | ) 83 | bounded_x = (x - min_bounds) / (max_bounds - min_bounds) 84 | grid_hash = self._grid_model(bounded_x, bound=1.0) 85 | result = self._post_grid(grid_hash) 86 | # label_latent, image_latent = torch.chunk(result, chunks=2, dim=-1) 87 | label_latent, image_latent = ( 88 | result[..., : self._text_rep_size], 89 | result[ 90 | ..., self._text_rep_size : self._text_rep_size + self._image_rep_size 91 | ], 92 | ) 93 | image_latent = self._image_head(image_latent) 94 | return label_latent, image_latent 95 | 96 | def to(self, device): 97 | self._grid_model = self._grid_model.to(device) 98 | self._post_grid = self._post_grid.to(device) 99 | self._image_head = self._image_head.to(device) 100 | self._max_bounds = self._max_bounds.to(device) 101 | self._min_bounds = self._min_bounds.to(device) 102 | self.temperature.data = 
self.temperature.data.to(device) 103 | return self 104 | 105 | def compute_loss( 106 | self, predicted_latents, actual_latents, label_mask=None, weights=None 107 | ): 108 | normalized_predicted_latents = F.normalize(predicted_latents, p=2, dim=-1) 109 | normalized_actual_latents = F.normalize(actual_latents, p=2, dim=-1) 110 | temp = torch.exp(self.temperature) 111 | sim = ( 112 | torch.einsum( 113 | "i d, j d -> i j", 114 | normalized_predicted_latents, 115 | normalized_actual_latents, 116 | ) 117 | * temp 118 | ) 119 | # Zero out the cells where the labels are same. 120 | if label_mask is not None: 121 | sim = sim * label_mask 122 | del label_mask 123 | labels = torch.arange(len(predicted_latents), device=predicted_latents.device) 124 | if weights is None: 125 | loss = (F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)) / 2 126 | else: 127 | loss = ( 128 | F.cross_entropy(sim, labels, reduction="none") 129 | + F.cross_entropy(sim.t(), labels, reduction="none") 130 | ) / 2 131 | loss = (loss * weights).mean() 132 | return loss 133 | -------------------------------------------------------------------------------- /gridencoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .grid import GridEncoder 2 | -------------------------------------------------------------------------------- /gridencoder/backend.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.cpp_extension import load 3 | 4 | _src_path = os.path.dirname(os.path.abspath(__file__)) 5 | 6 | nvcc_flags = [ 7 | "-O3", 8 | "-std=c++14", 9 | "-U__CUDA_NO_HALF_OPERATORS__", 10 | "-U__CUDA_NO_HALF_CONVERSIONS__", 11 | "-U__CUDA_NO_HALF2_OPERATORS__", 12 | ] 13 | 14 | if os.name == "posix": 15 | c_flags = ["-O3", "-std=c++14"] 16 | elif os.name == "nt": 17 | c_flags = ["/O2", "/std:c++17"] 18 | 19 | # find cl.exe 20 | def find_cl_path(): 21 | import glob 22 | 23 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 24 | paths = sorted( 25 | glob.glob( 26 | r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" 27 | % edition 28 | ), 29 | reverse=True, 30 | ) 31 | if paths: 32 | return paths[0] 33 | 34 | # If cl.exe is not on path, try to find it. 
35 | if os.system("where cl.exe >nul 2>nul") != 0: 36 | cl_path = find_cl_path() 37 | if cl_path is None: 38 | raise RuntimeError( 39 | "Could not locate a supported Microsoft Visual C++ installation" 40 | ) 41 | os.environ["PATH"] += ";" + cl_path 42 | 43 | _backend = load( 44 | name="_grid_encoder", 45 | extra_cflags=c_flags, 46 | extra_cuda_cflags=nvcc_flags, 47 | sources=[ 48 | os.path.join(_src_path, "src", f) 49 | for f in [ 50 | "gridencoder.cu", 51 | "bindings.cpp", 52 | ] 53 | ], 54 | ) 55 | 56 | __all__ = ["_backend"] 57 | -------------------------------------------------------------------------------- /gridencoder/grid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.autograd import Function 6 | from torch.autograd.function import once_differentiable 7 | from torch.cuda.amp import custom_bwd, custom_fwd 8 | 9 | try: 10 | import _gridencoder as _backend 11 | except ImportError: 12 | from .backend import _backend 13 | 14 | _gridtype_to_id = { 15 | "hash": 0, 16 | "tiled": 1, 17 | } 18 | 19 | 20 | class _grid_encode(Function): 21 | @staticmethod 22 | @custom_fwd 23 | def forward( 24 | ctx, 25 | inputs, 26 | embeddings, 27 | offsets, 28 | per_level_scale, 29 | base_resolution, 30 | calc_grad_inputs=False, 31 | gridtype=0, 32 | align_corners=False, 33 | ): 34 | # inputs: [B, D], float in [0, 1] 35 | # embeddings: [sO, C], float 36 | # offsets: [L + 1], int 37 | # RETURN: [B, F], float 38 | 39 | inputs = inputs.contiguous() 40 | 41 | B, D = inputs.shape # batch size, coord dim 42 | L = offsets.shape[0] - 1 # level 43 | C = embeddings.shape[1] # embedding dim for each level 44 | S = np.log2( 45 | per_level_scale 46 | ) # resolution multiplier at each level, apply log2 for later CUDA exp2f 47 | H = base_resolution # base resolution 48 | 49 | # manually handle autocast (only use half precision embeddings, inputs must be float for enough precision) 50 | # if C % 2 != 0, force float, since half for atomicAdd is very slow. 51 | if torch.is_autocast_enabled() and C % 2 == 0: 52 | embeddings = embeddings.to(torch.half) 53 | 54 | # L first, optimize cache for cuda kernel, but needs an extra permute later 55 | outputs = torch.empty(L, B, C, device=inputs.device, dtype=embeddings.dtype) 56 | 57 | if calc_grad_inputs: 58 | dy_dx = torch.empty( 59 | B, L * D * C, device=inputs.device, dtype=embeddings.dtype 60 | ) 61 | else: 62 | dy_dx = torch.empty( 63 | 1, device=inputs.device, dtype=embeddings.dtype 64 | ) # placeholder... TODO: a better way? 
65 | 66 | _backend.grid_encode_forward( 67 | inputs, 68 | embeddings, 69 | offsets, 70 | outputs, 71 | B, 72 | D, 73 | C, 74 | L, 75 | S, 76 | H, 77 | calc_grad_inputs, 78 | dy_dx, 79 | gridtype, 80 | align_corners, 81 | ) 82 | 83 | # permute back to [B, L * C] 84 | outputs = outputs.permute(1, 0, 2).reshape(B, L * C) 85 | 86 | ctx.save_for_backward(inputs, embeddings, offsets, dy_dx) 87 | ctx.dims = [B, D, C, L, S, H, gridtype] 88 | ctx.calc_grad_inputs = calc_grad_inputs 89 | ctx.align_corners = align_corners 90 | 91 | return outputs 92 | 93 | @staticmethod 94 | # @once_differentiable 95 | @custom_bwd 96 | def backward(ctx, grad): 97 | 98 | inputs, embeddings, offsets, dy_dx = ctx.saved_tensors 99 | B, D, C, L, S, H, gridtype = ctx.dims 100 | calc_grad_inputs = ctx.calc_grad_inputs 101 | align_corners = ctx.align_corners 102 | 103 | # grad: [B, L * C] --> [L, B, C] 104 | grad = grad.view(B, L, C).permute(1, 0, 2).contiguous() 105 | 106 | grad_embeddings = torch.zeros_like(embeddings) 107 | 108 | if calc_grad_inputs: 109 | grad_inputs = torch.zeros_like(inputs, dtype=embeddings.dtype) 110 | else: 111 | grad_inputs = torch.zeros(1, device=inputs.device, dtype=embeddings.dtype) 112 | 113 | _backend.grid_encode_backward( 114 | grad, 115 | inputs, 116 | embeddings, 117 | offsets, 118 | grad_embeddings, 119 | B, 120 | D, 121 | C, 122 | L, 123 | S, 124 | H, 125 | calc_grad_inputs, 126 | dy_dx, 127 | grad_inputs, 128 | gridtype, 129 | align_corners, 130 | ) 131 | 132 | if calc_grad_inputs: 133 | grad_inputs = grad_inputs.to(inputs.dtype) 134 | return grad_inputs, grad_embeddings, None, None, None, None, None, None 135 | else: 136 | return None, grad_embeddings, None, None, None, None, None, None 137 | 138 | 139 | grid_encode = _grid_encode.apply 140 | 141 | 142 | class GridEncoder(nn.Module): 143 | def __init__( 144 | self, 145 | input_dim=3, 146 | num_levels=16, 147 | level_dim=2, 148 | per_level_scale=2, 149 | base_resolution=16, 150 | log2_hashmap_size=19, 151 | desired_resolution=None, 152 | gridtype="hash", 153 | align_corners=False, 154 | ): 155 | super().__init__() 156 | 157 | # the finest resolution desired at the last level, if provided, overridee per_level_scale 158 | if desired_resolution is not None: 159 | per_level_scale = np.exp2( 160 | np.log2(desired_resolution / base_resolution) / (num_levels - 1) 161 | ) 162 | 163 | self.input_dim = input_dim # coord dims, 2 or 3 164 | self.num_levels = num_levels # num levels, each level multiply resolution by 2 165 | self.level_dim = level_dim # encode channels per level 166 | self.per_level_scale = ( 167 | per_level_scale # multiply resolution by this scale at each level. 
168 | ) 169 | self.log2_hashmap_size = log2_hashmap_size 170 | self.base_resolution = base_resolution 171 | self.output_dim = num_levels * level_dim 172 | self.gridtype = gridtype 173 | self.gridtype_id = _gridtype_to_id[gridtype] # "tiled" or "hash" 174 | self.align_corners = align_corners 175 | 176 | # allocate parameters 177 | offsets = [] 178 | offset = 0 179 | self.max_params = 2 ** log2_hashmap_size 180 | for i in range(num_levels): 181 | resolution = int(np.ceil(base_resolution * per_level_scale ** i)) 182 | params_in_level = min( 183 | self.max_params, 184 | (resolution if align_corners else resolution + 1) ** input_dim, 185 | ) # limit max number 186 | params_in_level = int(np.ceil(params_in_level / 8) * 8) # make divisible 187 | offsets.append(offset) 188 | offset += params_in_level 189 | offsets.append(offset) 190 | offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) 191 | self.register_buffer("offsets", offsets) 192 | 193 | self.n_params = offsets[-1] * level_dim 194 | 195 | # parameters 196 | self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) 197 | 198 | self.reset_parameters() 199 | 200 | def reset_parameters(self): 201 | std = 1e-4 202 | self.embeddings.data.uniform_(-std, std) 203 | 204 | def __repr__(self): 205 | return f"GridEncoder: input_dim={self.input_dim} num_levels={self.num_levels} level_dim={self.level_dim} resolution={self.base_resolution} -> {int(round(self.base_resolution * self.per_level_scale ** (self.num_levels - 1)))} per_level_scale={self.per_level_scale:.4f} params={tuple(self.embeddings.shape)} gridtype={self.gridtype} align_corners={self.align_corners}" 206 | 207 | def forward(self, inputs, bound=1): 208 | # inputs: [..., input_dim], normalized real world positions in [-bound, bound] 209 | # return: [..., num_levels * level_dim] 210 | 211 | inputs = (inputs + bound) / (2 * bound) # map to [0, 1] 212 | 213 | # print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item()) 214 | 215 | prefix_shape = list(inputs.shape[:-1]) 216 | inputs = inputs.view(-1, self.input_dim) 217 | 218 | outputs = grid_encode( 219 | inputs, 220 | self.embeddings, 221 | self.offsets, 222 | self.per_level_scale, 223 | self.base_resolution, 224 | inputs.requires_grad, 225 | self.gridtype_id, 226 | self.align_corners, 227 | ) 228 | outputs = outputs.view(prefix_shape + [self.output_dim]) 229 | 230 | # print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) 231 | 232 | return outputs 233 | -------------------------------------------------------------------------------- /gridencoder/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 4 | 5 | _src_path = os.path.dirname(os.path.abspath(__file__)) 6 | 7 | nvcc_flags = [ 8 | "-O3", 9 | "-std=c++17", 10 | "-U__CUDA_NO_HALF_OPERATORS__", 11 | "-U__CUDA_NO_HALF_CONVERSIONS__", 12 | "-U__CUDA_NO_HALF2_OPERATORS__", 13 | ] 14 | 15 | if os.name == "posix": 16 | c_flags = ["-O3", "-std=c++17"] 17 | elif os.name == "nt": 18 | c_flags = ["/O2", "/std:c++17"] 19 | 20 | # find cl.exe 21 | def find_cl_path(): 22 | import glob 23 | 24 | for edition in ["Enterprise", "Professional", "BuildTools", "Community"]: 25 | paths = sorted( 26 | glob.glob( 27 | r"C:\\Program Files (x86)\\Microsoft Visual Studio\\*\\%s\\VC\\Tools\\MSVC\\*\\bin\\Hostx64\\x64" 28 | % edition 29 | ), 30 | reverse=True, 31 | ) 32 | if paths: 33 | 
return paths[0] 34 | 35 | # If cl.exe is not on path, try to find it. 36 | if os.system("where cl.exe >nul 2>nul") != 0: 37 | cl_path = find_cl_path() 38 | if cl_path is None: 39 | raise RuntimeError( 40 | "Could not locate a supported Microsoft Visual C++ installation" 41 | ) 42 | os.environ["PATH"] += ";" + cl_path 43 | 44 | setup( 45 | name="gridencoder", # package name, import this to use python API 46 | ext_modules=[ 47 | CUDAExtension( 48 | name="_gridencoder", # extension name, import this to use CUDA API 49 | sources=[ 50 | os.path.join(_src_path, "src", f) 51 | for f in [ 52 | "gridencoder.cu", 53 | "bindings.cpp", 54 | ] 55 | ], 56 | extra_compile_args={ 57 | "cxx": c_flags, 58 | "nvcc": nvcc_flags, 59 | }, 60 | ), 61 | ], 62 | cmdclass={ 63 | "build_ext": BuildExtension, 64 | }, 65 | ) 66 | -------------------------------------------------------------------------------- /gridencoder/src/bindings.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "gridencoder.h" 4 | 5 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 6 | m.def("grid_encode_forward", &grid_encode_forward, "grid_encode_forward (CUDA)"); 7 | m.def("grid_encode_backward", &grid_encode_backward, "grid_encode_backward (CUDA)"); 8 | } -------------------------------------------------------------------------------- /gridencoder/src/gridencoder.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | #include 10 | 11 | #include 12 | #include 13 | 14 | 15 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 16 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be a contiguous tensor") 17 | #define CHECK_IS_INT(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Int, #x " must be an int tensor") 18 | #define CHECK_IS_FLOATING(x) TORCH_CHECK(x.scalar_type() == at::ScalarType::Float || x.scalar_type() == at::ScalarType::Half || x.scalar_type() == at::ScalarType::Double, #x " must be a floating tensor") 19 | 20 | 21 | // just for compatability of half precision in AT_DISPATCH_FLOATING_TYPES_AND_HALF... 22 | static inline __device__ at::Half atomicAdd(at::Half *address, at::Half val) { 23 | // requires CUDA >= 10 and ARCH >= 70 24 | // this is very slow compared to float or __half2, and never used. 25 | //return atomicAdd(reinterpret_cast<__half*>(address), val); 26 | } 27 | 28 | 29 | template 30 | static inline __host__ __device__ T div_round_up(T val, T divisor) { 31 | return (val + divisor - 1) / divisor; 32 | } 33 | 34 | 35 | template 36 | __device__ uint32_t fast_hash(const uint32_t pos_grid[D]) { 37 | static_assert(D <= 7, "fast_hash can only hash up to 7 dimensions."); 38 | 39 | // While 1 is technically not a good prime for hashing (or a prime at all), it helps memory coherence 40 | // and is sufficient for our use case of obtaining a uniformly colliding index from high-dimensional 41 | // coordinates. 
42 | constexpr uint32_t primes[7] = { 1, 2654435761, 805459861, 3674653429, 2097192037, 1434869437, 2165219737 }; 43 | 44 | uint32_t result = 0; 45 | #pragma unroll 46 | for (uint32_t i = 0; i < D; ++i) { 47 | result ^= pos_grid[i] * primes[i]; 48 | } 49 | 50 | return result; 51 | } 52 | 53 | 54 | template 55 | __device__ uint32_t get_grid_index(const uint32_t gridtype, const bool align_corners, const uint32_t ch, const uint32_t hashmap_size, const uint32_t resolution, const uint32_t pos_grid[D]) { 56 | uint32_t stride = 1; 57 | uint32_t index = 0; 58 | 59 | #pragma unroll 60 | for (uint32_t d = 0; d < D && stride <= hashmap_size; d++) { 61 | index += pos_grid[d] * stride; 62 | stride *= align_corners ? resolution: (resolution + 1); 63 | } 64 | 65 | // NOTE: for NeRF, the hash is in fact not necessary. Check https://github.com/NVlabs/instant-ngp/issues/97. 66 | // gridtype: 0 == hash, 1 == tiled 67 | if (gridtype == 0 && stride > hashmap_size) { 68 | index = fast_hash(pos_grid); 69 | } 70 | 71 | return (index % hashmap_size) * C + ch; 72 | } 73 | 74 | 75 | template 76 | __global__ void kernel_grid( 77 | const float * __restrict__ inputs, 78 | const scalar_t * __restrict__ grid, 79 | const int * __restrict__ offsets, 80 | scalar_t * __restrict__ outputs, 81 | const uint32_t B, const uint32_t L, const float S, const uint32_t H, 82 | const bool calc_grad_inputs, 83 | scalar_t * __restrict__ dy_dx, 84 | const uint32_t gridtype, 85 | const bool align_corners 86 | ) { 87 | const uint32_t b = blockIdx.x * blockDim.x + threadIdx.x; 88 | 89 | if (b >= B) return; 90 | 91 | const uint32_t level = blockIdx.y; 92 | 93 | // locate 94 | grid += (uint32_t)offsets[level] * C; 95 | inputs += b * D; 96 | outputs += level * B * C + b * C; 97 | 98 | // check input range (should be in [0, 1]) 99 | bool flag_oob = false; 100 | #pragma unroll 101 | for (uint32_t d = 0; d < D; d++) { 102 | if (inputs[d] < 0 || inputs[d] > 1) { 103 | flag_oob = true; 104 | } 105 | } 106 | // if input out of bound, just set output to 0 107 | if (flag_oob) { 108 | #pragma unroll 109 | for (uint32_t ch = 0; ch < C; ch++) { 110 | outputs[ch] = 0; 111 | } 112 | if (calc_grad_inputs) { 113 | dy_dx += b * D * L * C + level * D * C; // B L D C 114 | #pragma unroll 115 | for (uint32_t d = 0; d < D; d++) { 116 | #pragma unroll 117 | for (uint32_t ch = 0; ch < C; ch++) { 118 | dy_dx[d * C + ch] = 0; 119 | } 120 | } 121 | } 122 | return; 123 | } 124 | 125 | const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; 126 | const float scale = exp2f(level * S) * H - 1.0f; 127 | const uint32_t resolution = (uint32_t)ceil(scale) + 1; 128 | 129 | // calculate coordinate 130 | float pos[D]; 131 | uint32_t pos_grid[D]; 132 | 133 | #pragma unroll 134 | for (uint32_t d = 0; d < D; d++) { 135 | pos[d] = inputs[d] * scale + (align_corners ? 
0.0f : 0.5f); 136 | pos_grid[d] = floorf(pos[d]); 137 | pos[d] -= (float)pos_grid[d]; 138 | } 139 | 140 | //printf("[b=%d, l=%d] pos=(%f, %f)+(%d, %d)\n", b, level, pos[0], pos[1], pos_grid[0], pos_grid[1]); 141 | 142 | // interpolate 143 | scalar_t results[C] = {0}; // temp results in register 144 | 145 | #pragma unroll 146 | for (uint32_t idx = 0; idx < (1 << D); idx++) { 147 | float w = 1; 148 | uint32_t pos_grid_local[D]; 149 | 150 | #pragma unroll 151 | for (uint32_t d = 0; d < D; d++) { 152 | if ((idx & (1 << d)) == 0) { 153 | w *= 1 - pos[d]; 154 | pos_grid_local[d] = pos_grid[d]; 155 | } else { 156 | w *= pos[d]; 157 | pos_grid_local[d] = pos_grid[d] + 1; 158 | } 159 | } 160 | 161 | uint32_t index = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); 162 | 163 | // writing to register (fast) 164 | #pragma unroll 165 | for (uint32_t ch = 0; ch < C; ch++) { 166 | results[ch] += w * grid[index + ch]; 167 | } 168 | 169 | //printf("[b=%d, l=%d] int %d, idx %d, w %f, val %f\n", b, level, idx, index, w, grid[index]); 170 | } 171 | 172 | // writing to global memory (slow) 173 | #pragma unroll 174 | for (uint32_t ch = 0; ch < C; ch++) { 175 | outputs[ch] = results[ch]; 176 | } 177 | 178 | // prepare dy_dx for calc_grad_inputs 179 | // differentiable (soft) indexing: https://discuss.pytorch.org/t/differentiable-indexing/17647/9 180 | if (calc_grad_inputs) { 181 | 182 | dy_dx += b * D * L * C + level * D * C; // B L D C 183 | 184 | #pragma unroll 185 | for (uint32_t gd = 0; gd < D; gd++) { 186 | 187 | scalar_t results_grad[C] = {0}; 188 | 189 | #pragma unroll 190 | for (uint32_t idx = 0; idx < (1 << (D - 1)); idx++) { 191 | float w = scale; 192 | uint32_t pos_grid_local[D]; 193 | 194 | #pragma unroll 195 | for (uint32_t nd = 0; nd < D - 1; nd++) { 196 | const uint32_t d = (nd >= gd) ? 
(nd + 1) : nd; 197 | 198 | if ((idx & (1 << nd)) == 0) { 199 | w *= 1 - pos[d]; 200 | pos_grid_local[d] = pos_grid[d]; 201 | } else { 202 | w *= pos[d]; 203 | pos_grid_local[d] = pos_grid[d] + 1; 204 | } 205 | } 206 | 207 | pos_grid_local[gd] = pos_grid[gd]; 208 | uint32_t index_left = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); 209 | pos_grid_local[gd] = pos_grid[gd] + 1; 210 | uint32_t index_right = get_grid_index(gridtype, align_corners, 0, hashmap_size, resolution, pos_grid_local); 211 | 212 | #pragma unroll 213 | for (uint32_t ch = 0; ch < C; ch++) { 214 | results_grad[ch] += w * (grid[index_right + ch] - grid[index_left + ch]); 215 | } 216 | } 217 | 218 | #pragma unroll 219 | for (uint32_t ch = 0; ch < C; ch++) { 220 | dy_dx[gd * C + ch] = results_grad[ch]; 221 | } 222 | } 223 | } 224 | } 225 | 226 | 227 | template 228 | __global__ void kernel_grid_backward( 229 | const scalar_t * __restrict__ grad, 230 | const float * __restrict__ inputs, 231 | const scalar_t * __restrict__ grid, 232 | const int * __restrict__ offsets, 233 | scalar_t * __restrict__ grad_grid, 234 | const uint32_t B, const uint32_t L, const float S, const uint32_t H, 235 | const uint32_t gridtype, 236 | const bool align_corners 237 | ) { 238 | const uint32_t b = (blockIdx.x * blockDim.x + threadIdx.x) * N_C / C; 239 | if (b >= B) return; 240 | 241 | const uint32_t level = blockIdx.y; 242 | const uint32_t ch = (blockIdx.x * blockDim.x + threadIdx.x) * N_C - b * C; 243 | 244 | // locate 245 | grad_grid += offsets[level] * C; 246 | inputs += b * D; 247 | grad += level * B * C + b * C + ch; // L, B, C 248 | 249 | const uint32_t hashmap_size = offsets[level + 1] - offsets[level]; 250 | const float scale = exp2f(level * S) * H - 1.0f; 251 | const uint32_t resolution = (uint32_t)ceil(scale) + 1; 252 | 253 | // check input range (should be in [0, 1]) 254 | #pragma unroll 255 | for (uint32_t d = 0; d < D; d++) { 256 | if (inputs[d] < 0 || inputs[d] > 1) { 257 | return; // grad is init as 0, so we simply return. 258 | } 259 | } 260 | 261 | // calculate coordinate 262 | float pos[D]; 263 | uint32_t pos_grid[D]; 264 | 265 | #pragma unroll 266 | for (uint32_t d = 0; d < D; d++) { 267 | pos[d] = inputs[d] * scale + (align_corners ? 
0.0f : 0.5f); 268 | pos_grid[d] = floorf(pos[d]); 269 | pos[d] -= (float)pos_grid[d]; 270 | } 271 | 272 | scalar_t grad_cur[N_C] = {0}; // fetch to register 273 | #pragma unroll 274 | for (uint32_t c = 0; c < N_C; c++) { 275 | grad_cur[c] = grad[c]; 276 | } 277 | 278 | // interpolate 279 | #pragma unroll 280 | for (uint32_t idx = 0; idx < (1 << D); idx++) { 281 | float w = 1; 282 | uint32_t pos_grid_local[D]; 283 | 284 | #pragma unroll 285 | for (uint32_t d = 0; d < D; d++) { 286 | if ((idx & (1 << d)) == 0) { 287 | w *= 1 - pos[d]; 288 | pos_grid_local[d] = pos_grid[d]; 289 | } else { 290 | w *= pos[d]; 291 | pos_grid_local[d] = pos_grid[d] + 1; 292 | } 293 | } 294 | 295 | uint32_t index = get_grid_index(gridtype, align_corners, ch, hashmap_size, resolution, pos_grid_local); 296 | 297 | // atomicAdd for __half is slow (especially for large values), so we use __half2 if N_C % 2 == 0 298 | // TODO: use float which is better than __half, if N_C % 2 != 0 299 | if (std::is_same::value && N_C % 2 == 0) { 300 | #pragma unroll 301 | for (uint32_t c = 0; c < N_C; c += 2) { 302 | // process two __half at once (by interpreting as a __half2) 303 | __half2 v = {(__half)(w * grad_cur[c]), (__half)(w * grad_cur[c + 1])}; 304 | atomicAdd((__half2*)&grad_grid[index + c], v); 305 | } 306 | // float, or __half when N_C % 2 != 0 (which means C == 1) 307 | } else { 308 | #pragma unroll 309 | for (uint32_t c = 0; c < N_C; c++) { 310 | atomicAdd(&grad_grid[index + c], w * grad_cur[c]); 311 | } 312 | } 313 | } 314 | } 315 | 316 | 317 | template 318 | __global__ void kernel_input_backward( 319 | const scalar_t * __restrict__ grad, 320 | const scalar_t * __restrict__ dy_dx, 321 | scalar_t * __restrict__ grad_inputs, 322 | uint32_t B, uint32_t L 323 | ) { 324 | const uint32_t t = threadIdx.x + blockIdx.x * blockDim.x; 325 | if (t >= B * D) return; 326 | 327 | const uint32_t b = t / D; 328 | const uint32_t d = t - b * D; 329 | 330 | dy_dx += b * L * D * C; 331 | 332 | scalar_t result = 0; 333 | 334 | # pragma unroll 335 | for (int l = 0; l < L; l++) { 336 | # pragma unroll 337 | for (int ch = 0; ch < C; ch++) { 338 | result += grad[l * B * C + b * C + ch] * dy_dx[l * D * C + d * C + ch]; 339 | } 340 | } 341 | 342 | grad_inputs[t] = result; 343 | } 344 | 345 | 346 | template 347 | void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) { 348 | static constexpr uint32_t N_THREAD = 512; 349 | const dim3 blocks_hashgrid = { div_round_up(B, N_THREAD), L, 1 }; 350 | switch (C) { 351 | case 1: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; 352 | case 2: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; 353 | case 4: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; 354 | case 8: kernel_grid<<>>(inputs, embeddings, offsets, outputs, B, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; 355 | default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; 356 | } 357 | } 358 | 359 | // inputs: [B, D], float, in [0, 1] 360 | // embeddings: [sO, C], float 361 | // offsets: [L + 1], uint32_t 362 | // outputs: [L, B, C], float (L first, so 
only one level of hashmap needs to fit into cache at a time.) 363 | // H: base resolution 364 | // dy_dx: [B, L * D * C] 365 | template 366 | void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) { 367 | switch (D) { 368 | case 2: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; 369 | case 3: kernel_grid_wrapper(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break; 370 | default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; 371 | } 372 | 373 | } 374 | 375 | template 376 | void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) { 377 | static constexpr uint32_t N_THREAD = 256; 378 | const uint32_t N_C = std::min(2u, C); // n_features_per_thread 379 | const dim3 blocks_hashgrid = { div_round_up(B * C / N_C, N_THREAD), L, 1 }; 380 | switch (C) { 381 | case 1: 382 | kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners); 383 | if (calc_grad_inputs) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); 384 | break; 385 | case 2: 386 | kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners); 387 | if (calc_grad_inputs) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); 388 | break; 389 | case 4: 390 | kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners); 391 | if (calc_grad_inputs) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); 392 | break; 393 | case 8: 394 | kernel_grid_backward<<>>(grad, inputs, embeddings, offsets, grad_embeddings, B, L, S, H, gridtype, align_corners); 395 | if (calc_grad_inputs) kernel_input_backward<<>>(grad, dy_dx, grad_inputs, B, L); 396 | break; 397 | default: throw std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; 398 | } 399 | } 400 | 401 | 402 | // grad: [L, B, C], float 403 | // inputs: [B, D], float, in [0, 1] 404 | // embeddings: [sO, C], float 405 | // offsets: [L + 1], uint32_t 406 | // grad_embeddings: [sO, C] 407 | // H: base resolution 408 | template 409 | void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) { 410 | switch (D) { 411 | case 2: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break; 412 | case 3: kernel_grid_backward_wrapper(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break; 413 | default: throw 
std::runtime_error{"GridEncoding: C must be 1, 2, 4, or 8."}; 414 | } 415 | } 416 | 417 | 418 | 419 | void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners) { 420 | CHECK_CUDA(inputs); 421 | CHECK_CUDA(embeddings); 422 | CHECK_CUDA(offsets); 423 | CHECK_CUDA(outputs); 424 | CHECK_CUDA(dy_dx); 425 | 426 | CHECK_CONTIGUOUS(inputs); 427 | CHECK_CONTIGUOUS(embeddings); 428 | CHECK_CONTIGUOUS(offsets); 429 | CHECK_CONTIGUOUS(outputs); 430 | CHECK_CONTIGUOUS(dy_dx); 431 | 432 | CHECK_IS_FLOATING(inputs); 433 | CHECK_IS_FLOATING(embeddings); 434 | CHECK_IS_INT(offsets); 435 | CHECK_IS_FLOATING(outputs); 436 | CHECK_IS_FLOATING(dy_dx); 437 | 438 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 439 | embeddings.scalar_type(), "grid_encode_forward", ([&] { 440 | grid_encode_forward_cuda(inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), outputs.data_ptr(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr(), gridtype, align_corners); 441 | })); 442 | } 443 | 444 | void grid_encode_backward(const at::Tensor grad, const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners) { 445 | CHECK_CUDA(grad); 446 | CHECK_CUDA(inputs); 447 | CHECK_CUDA(embeddings); 448 | CHECK_CUDA(offsets); 449 | CHECK_CUDA(grad_embeddings); 450 | CHECK_CUDA(dy_dx); 451 | CHECK_CUDA(grad_inputs); 452 | 453 | CHECK_CONTIGUOUS(grad); 454 | CHECK_CONTIGUOUS(inputs); 455 | CHECK_CONTIGUOUS(embeddings); 456 | CHECK_CONTIGUOUS(offsets); 457 | CHECK_CONTIGUOUS(grad_embeddings); 458 | CHECK_CONTIGUOUS(dy_dx); 459 | CHECK_CONTIGUOUS(grad_inputs); 460 | 461 | CHECK_IS_FLOATING(grad); 462 | CHECK_IS_FLOATING(inputs); 463 | CHECK_IS_FLOATING(embeddings); 464 | CHECK_IS_INT(offsets); 465 | CHECK_IS_FLOATING(grad_embeddings); 466 | CHECK_IS_FLOATING(dy_dx); 467 | CHECK_IS_FLOATING(grad_inputs); 468 | 469 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 470 | grad.scalar_type(), "grid_encode_backward", ([&] { 471 | grid_encode_backward_cuda(grad.data_ptr(), inputs.data_ptr(), embeddings.data_ptr(), offsets.data_ptr(), grad_embeddings.data_ptr(), B, D, C, L, S, H, calc_grad_inputs, dy_dx.data_ptr(), grad_inputs.data_ptr(), gridtype, align_corners); 472 | })); 473 | 474 | } 475 | -------------------------------------------------------------------------------- /gridencoder/src/gridencoder.h: -------------------------------------------------------------------------------- 1 | #ifndef _HASH_ENCODE_H 2 | #define _HASH_ENCODE_H 3 | 4 | #include 5 | #include 6 | 7 | // inputs: [B, D], float, in [0, 1] 8 | // embeddings: [sO, C], float 9 | // offsets: [L + 1], uint32_t 10 | // outputs: [B, L * C], float 11 | // H: base resolution 12 | void grid_encode_forward(const at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, at::Tensor dy_dx, const uint32_t gridtype, const bool align_corners); 13 | void grid_encode_backward(const at::Tensor grad, const 
at::Tensor inputs, const at::Tensor embeddings, const at::Tensor offsets, at::Tensor grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, const at::Tensor dy_dx, at::Tensor grad_inputs, const uint32_t gridtype, const bool align_corners); 14 | 15 | #endif -------------------------------------------------------------------------------- /misc.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def mlp( 5 | input_dim, 6 | hidden_dim, 7 | output_dim, 8 | hidden_depth, 9 | output_mod=None, 10 | batchnorm=False, 11 | activation=nn.ReLU, 12 | ): 13 | if hidden_depth == 0: 14 | mods = [nn.Linear(input_dim, output_dim)] 15 | else: 16 | mods = ( 17 | [nn.Linear(input_dim, hidden_dim), activation(inplace=True)] 18 | if not batchnorm 19 | else [ 20 | nn.Linear(input_dim, hidden_dim), 21 | nn.BatchNorm1d(hidden_dim), 22 | activation(inplace=True), 23 | ] 24 | ) 25 | for _ in range(hidden_depth - 1): 26 | mods += ( 27 | [nn.Linear(hidden_dim, hidden_dim), activation(inplace=True)] 28 | if not batchnorm 29 | else [ 30 | nn.Linear(hidden_dim, hidden_dim), 31 | nn.BatchNorm1d(hidden_dim), 32 | activation(inplace=True), 33 | ] 34 | ) 35 | mods.append(nn.Linear(hidden_dim, output_dim)) 36 | if output_mod is not None: 37 | mods.append(output_mod) 38 | trunk = nn.Sequential(*mods) 39 | return trunk 40 | 41 | 42 | def weight_init(m): 43 | """Custom weight init for Conv2D and Linear layers.""" 44 | if isinstance(m, nn.Linear): 45 | nn.init.orthogonal_(m.weight.data) 46 | if hasattr(m.bias, "data"): 47 | m.bias.data.fill_(0.0) 48 | 49 | 50 | class MLP(nn.Module): 51 | def __init__( 52 | self, 53 | input_dim, 54 | hidden_dim, 55 | output_dim, 56 | hidden_depth, 57 | output_mod=None, 58 | batchnorm=False, 59 | activation=nn.ReLU, 60 | ): 61 | super().__init__() 62 | self.trunk = mlp( 63 | input_dim, 64 | hidden_dim, 65 | output_dim, 66 | hidden_depth, 67 | output_mod, 68 | batchnorm=batchnorm, 69 | activation=activation, 70 | ) 71 | self.apply(weight_init) 72 | 73 | def forward(self, x): 74 | return self.trunk(x) 75 | 76 | 77 | class ImplicitDataparallel(nn.DataParallel): 78 | def compute_loss(self, *args, **kwargs): 79 | return self.module.compute_loss(*args, **kwargs) 80 | 81 | @property 82 | def temperature(self): 83 | return self.module.temperature 84 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # CLIP-Fields: Weakly Supervised Semantic Fields for Robotic Memory 2 | ### Teaching robots in the real world to respond to natural language queries with zero human labels — using pretrained large language models (LLMs), visual language models (VLMs), and neural fields. 3 | 4 | [[Paper]](https://arxiv.org/abs/2210.05663) [[Website]](https://mahis.life/clip-fields/) [[Code]](https://github.com/notmahi/clip-fields) [[Data]](https://osf.io/famgv) [[Video]](https://youtu.be/bKu7GvRiSQU) 5 | 6 | Authors: [Mahi Shafiullah](https://mahis.life), [Chris Paxton](https://cpaxton.github.io/), [Lerrel Pinto](https://lerrelpinto.com), [Soumith Chintala](https://soumith.ch), Arthur Szlam. 
7 | 8 | https://user-images.githubusercontent.com/3000253/195213301-43eae6e8-4516-4b8d-98e7-633c607c6616.mp4 9 | 10 | **Tl;dr** CLIP-Field is a novel weakly supervised approach for learning a semantic robot memory that can respond to natural language queries solely from raw RGB-D and odometry data with no extra human labelling. It combines the image and language understanding capabilities of novel vision-language models (VLMs) like CLIP, large language models like Sentence-BERT, and open-label object detection models like Detic, with the spatial understanding capabilities of neural radiance field (NeRF) style architectures to build a spatial database that holds semantic information. 11 | 12 | ## Installation 13 | To install this repo and all its dependencies, follow these instructions. 14 | 15 | ``` 16 | # Clone this repo. 17 | git clone --recursive https://github.com/notmahi/clip-fields 18 | cd clip-fields 19 | 20 | # Create conda environment and install the dependencies. 21 | conda create -n cf python=3.8 22 | conda activate cf 23 | conda install -y pytorch torchvision torchaudio cudatoolkit=11.8 -c pytorch-lts -c nvidia 24 | pip install -r requirements.txt 25 | 26 | # Install the hashgrid encoder with the relevant CUDA module. 27 | cd gridencoder 28 | # For this part, it may be necessary to find out what your nvcc path is and use that. 29 | # For me, `which nvcc` gives /public/apps/cuda/11.8/bin/nvcc, so I used the following: 30 | # export CUDA_HOME=/public/apps/cuda/11.8 31 | python setup.py install 32 | cd .. 33 | ``` 34 | ## Interactive Tutorial and Evaluation 35 | We have interactive tutorial and evaluation notebooks that you can use to explore the model and evaluate it on your own data. You can find them in the [`demo/`](https://github.com/notmahi/clip-fields/tree/main/demo) directory, which you can run after installing the dependencies. 36 | 37 | ## Training a CLIP-Field directly 38 | Once you have the dependencies installed, you can run the training script `train.py` with any [.r3d](https://record3d.app/) files that you have! If you just want to try out a sample, download the [sample data](https://osf.io/famgv) `nyu.r3d` and run the following command. 39 | 40 | ``` 41 | python train.py dataset_path=nyu.r3d 42 | ``` 43 | 44 | If you want to use LSeg as an additional source of open-label annotations, you should download the [LSeg demo model](https://github.com/isl-org/lang-seg#-try-demo-now) and place it at `path_to_LSeg/checkpoints/demo_e200.ckpt`. Then, you can run the following command. 45 | 46 | ``` 47 | python train.py dataset_path=nyu.r3d use_lseg=true 48 | ``` 49 | 50 | You can check out `configs/train.yaml` for a list of possible configuration options. In particular, if you want to train with a particular set of labels, you can specify them in the `custom_labels` field in `configs/train.yaml`. 51 | 52 | 53 | ## Acknowledgements 54 | We would like to thank the following projects for making their code and models available, which we relied upon heavily in this work.
55 | * [CLIP](https://github.com/openai/CLIP) with [MIT License](https://github.com/openai/CLIP/blob/main/LICENSE) 56 | * [Detic](https://github.com/facebookresearch/Detic/) with [Apache License 2.0](https://github.com/facebookresearch/Detic/blob/main/LICENSE) 57 | * [Torch NGP](https://github.com/ashawkey/torch-ngp) with [MIT License](https://github.com/ashawkey/torch-ngp/blob/main/LICENSE) 58 | * [LSeg](https://github.com/isl-org/lang-seg) with [MIT License](https://github.com/isl-org/lang-seg/blob/main/LICENSE) 59 | * [Sentence BERT](https://www.sbert.net/) with [Apache License 2.0](https://github.com/UKPLab/sentence-transformers/blob/master/LICENSE) 60 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | timm==0.5.4 2 | hydra-core==1.2.0 3 | git+https://github.com/openai/CLIP.git 4 | ftfy==6.1.1 5 | tqdm==4.64.0 6 | regex==2019.11.1 7 | point-transformer-pytorch==0.1.5 8 | sentence_transformers==2.2.2 9 | numpy-quaternion==2022.4.2 10 | git+https://github.com/facebookresearch/detectron2.git 11 | pandas==1.4.3 12 | pyntcloud==0.3.1 13 | hydra-submitit-launcher==1.2.0 14 | torch-encoding @ git+https://github.com/zhanghang1989/PyTorch-Encoding@c959dab8312b637fcc7edce83607acb4b0f82645 15 | torchmetrics==0.6.0 16 | sentence-transformers==2.2.2 17 | opencv-python==4.5.5.64 18 | imageio==2.19.3 19 | altair==4.2.0 20 | streamlit==1.12.2 21 | protobuf==3.20.1 22 | matplotlib==3.5.2 23 | test-tube==0.7.5 24 | wandb==0.13.3 25 | omegaconf==2.2.2 26 | numpy==1.21.4 27 | einops==0.4.1 28 | pytorch-lightning==1.3.5 29 | pyquaternion==0.9.9 30 | pyliblzfse==0.4.1 31 | open3d==0.15.2 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import pprint 4 | import random 5 | from typing import Dict, Union 6 | 7 | import hydra 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | import torchmetrics 12 | import tqdm 13 | from omegaconf import OmegaConf 14 | from torch.utils.data import DataLoader, Subset 15 | 16 | import wandb 17 | from dataloaders import ( 18 | R3DSemanticDataset, 19 | DeticDenseLabelledDataset, 20 | ClassificationExtractor, 21 | ) 22 | from misc import ImplicitDataparallel 23 | from grid_hash_model import GridCLIPModel 24 | 25 | 26 | SAVE_DIRECTORY = "clip_implicit_model" 27 | DEVICE = "cuda" 28 | IMAGE_TO_LABEL_CLIP_LOSS_SCALE = 1.0 29 | LABEL_TO_IMAGE_LOSS_SCALE = 1.0 30 | EXP_DECAY_COEFF = 0.5 31 | SAVE_EVERY = 5 32 | METRICS = { 33 | "accuracy": torchmetrics.Accuracy, 34 | } 35 | 36 | logger = logging.getLogger(__name__) 37 | logger.setLevel(logging.DEBUG) 38 | 39 | 40 | def seed_everything(seed: int): 41 | random.seed(seed) 42 | os.environ["PYTHONHASHSEED"] = str(seed) 43 | np.random.seed(seed) 44 | torch.manual_seed(seed) 45 | torch.cuda.manual_seed(seed) 46 | 47 | 48 | def train( 49 | clip_train_loader: DataLoader, 50 | labelling_model: Union[GridCLIPModel, ImplicitDataparallel], 51 | optim: torch.optim.Optimizer, 52 | epoch: int, 53 | classifier: ClassificationExtractor, 54 | device: Union[str, torch.device] = DEVICE, 55 | exp_decay_coeff: float = EXP_DECAY_COEFF, 56 | image_to_label_loss_ratio: float = IMAGE_TO_LABEL_CLIP_LOSS_SCALE, 57 | label_to_image_loss_ratio: float = LABEL_TO_IMAGE_LOSS_SCALE, 58 | disable_tqdm: bool = False, 59 | metric_calculators: Dict[str, 
Dict[str, torchmetrics.Metric]] = {}, 60 | ): 61 | total_loss = 0 62 | label_loss = 0 63 | image_loss = 0 64 | classification_loss = 0 65 | total_samples = 0 66 | total_classification_loss = 0 67 | labelling_model.train() 68 | total = len(clip_train_loader) 69 | for clip_data_dict in tqdm.tqdm( 70 | clip_train_loader, 71 | total=total, 72 | disable=disable_tqdm, 73 | desc=f"Training epoch {epoch}", 74 | ): 75 | xyzs = clip_data_dict["xyz"].to(device) 76 | clip_labels = clip_data_dict["clip_vector"].to(device) 77 | clip_image_labels = clip_data_dict["clip_image_vector"].to(device) 78 | image_weights = torch.exp(-exp_decay_coeff * clip_data_dict["distance"]).to( 79 | device 80 | ) 81 | label_weights = clip_data_dict["semantic_weight"].to(device) 82 | image_label_index: torch.Tensor = ( 83 | clip_data_dict["img_idx"].to(device).reshape(-1, 1) 84 | ) 85 | language_label_index: torch.Tensor = ( 86 | clip_data_dict["label"].to(device).reshape(-1, 1) 87 | ) 88 | 89 | (predicted_label_latents, predicted_image_latents) = labelling_model(xyzs) 90 | # Calculate the loss from the image to label side. 91 | batch_size = len(image_label_index) 92 | image_label_mask: torch.Tensor = ( 93 | image_label_index != image_label_index.t() 94 | ).float() + torch.eye(batch_size, device=device) 95 | language_label_mask: torch.Tensor = ( 96 | language_label_index != language_label_index.t() 97 | ).float() + torch.eye(batch_size, device=device) 98 | 99 | # For logging purposes, keep track of negative samples per point. 100 | image_label_mask.requires_grad = False 101 | language_label_mask.requires_grad = False 102 | contrastive_loss_labels = labelling_model.compute_loss( 103 | predicted_label_latents, 104 | clip_labels, 105 | label_mask=language_label_mask, 106 | weights=label_weights, 107 | ) 108 | contrastive_loss_images = labelling_model.compute_loss( 109 | predicted_image_latents, 110 | clip_image_labels, 111 | label_mask=image_label_mask, 112 | weights=image_weights, 113 | ) 114 | del ( 115 | image_label_mask, 116 | image_label_index, 117 | language_label_mask, 118 | ) 119 | 120 | # Now figure out semantic segmentation. 121 | with torch.no_grad(): 122 | class_probs = classifier.calculate_classifications( 123 | model_text_features=predicted_label_latents, 124 | model_image_features=predicted_image_latents, 125 | ) 126 | # Now figure out semantic accuracy and loss. 127 | semseg_mask = torch.logical_and( 128 | language_label_index != -1, 129 | language_label_index < classifier.total_label_classes, 130 | ).squeeze(-1) 131 | if not torch.any(semseg_mask): 132 | classification_loss = torch.zeros_like(contrastive_loss_images) 133 | else: 134 | # Figure out the right classes. 
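# `semseg_mask` keeps only the points whose language label is a valid class index
# (not -1 and below classifier.total_label_classes), so the selected rows of
# `class_probs` line up one-to-one with their integer labels. Expected shapes, roughly:
# class_probs [batch, total_label_classes], semseg_mask [batch] (bool),
# masked_class_prob [num_selected, total_label_classes], masked_labels [num_selected].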
135 | masked_class_prob = class_probs[semseg_mask] 136 | masked_labels = language_label_index[semseg_mask].squeeze(-1).long() 137 | classification_loss = F.cross_entropy( 138 | torch.log(masked_class_prob), 139 | masked_labels, 140 | ) 141 | if metric_calculators.get("semantic"): 142 | for _, calculators in metric_calculators["semantic"].items(): 143 | _ = calculators(masked_class_prob, masked_labels) 144 | 145 | contrastive_loss = ( 146 | image_to_label_loss_ratio * contrastive_loss_images 147 | + label_to_image_loss_ratio * contrastive_loss_labels 148 | ) 149 | 150 | optim.zero_grad(set_to_none=True) 151 | contrastive_loss.backward() 152 | optim.step() 153 | # Clip the temperature term for stability 154 | labelling_model.temperature.data = torch.clamp( 155 | labelling_model.temperature.data, max=np.log(100.0) 156 | ) 157 | label_loss += contrastive_loss_labels.detach().cpu().item() 158 | image_loss += contrastive_loss_images.detach().cpu().item() 159 | total_classification_loss += classification_loss.detach().cpu().item() 160 | total_loss += contrastive_loss.detach().cpu().item() 161 | total_samples += 1 162 | 163 | to_log = { 164 | "train_avg/contrastive_loss_labels": label_loss / total_samples, 165 | "train_avg/contrastive_loss_images": image_loss / total_samples, 166 | "train_avg/semseg_loss": total_classification_loss / total_samples, 167 | "train_avg/loss_sum": total_loss / total_samples, 168 | "train_avg/labelling_temp": torch.exp(labelling_model.temperature.data.detach()) 169 | .cpu() 170 | .item(), 171 | } 172 | for metric_dict in metric_calculators.values(): 173 | for metric_name, metric in metric_dict.items(): 174 | try: 175 | to_log[f"train_avg/{metric_name}"] = ( 176 | metric.compute().detach().cpu().item() 177 | ) 178 | except RuntimeError as e: 179 | to_log[f"train_avg/{metric_name}"] = 0.0 180 | metric.reset() 181 | wandb.log(to_log) 182 | logger.debug(pprint.pformat(to_log, indent=4, width=1)) 183 | return total_loss 184 | 185 | 186 | def save( 187 | labelling_model: Union[ImplicitDataparallel, GridCLIPModel], 188 | optim: torch.optim.Optimizer, 189 | epoch: int, 190 | save_directory: str = SAVE_DIRECTORY, 191 | saving_dataparallel: bool = False, 192 | ): 193 | if saving_dataparallel: 194 | to_save = labelling_model.module 195 | else: 196 | to_save = labelling_model 197 | state_dict = { 198 | "model": to_save.state_dict(), 199 | "optim": optim.state_dict(), 200 | "epoch": epoch, 201 | } 202 | torch.save( 203 | state_dict, 204 | f"{save_directory}/implicit_scene_label_model_latest.pt", 205 | ) 206 | return 0 207 | 208 | 209 | def get_real_dataset(cfg): 210 | if cfg.use_cache: 211 | location_train_dataset = torch.load(cfg.saved_dataset_path) 212 | else: 213 | view_dataset = R3DSemanticDataset(cfg.dataset_path, cfg.custom_labels) 214 | if cfg.sample_freq != 1: 215 | view_dataset = Subset( 216 | view_dataset, 217 | torch.arange(0, len(view_dataset), cfg.sample_freq), 218 | ) 219 | location_train_dataset = DeticDenseLabelledDataset( 220 | view_dataset, 221 | clip_model_name=cfg.web_models.clip, 222 | sentence_encoding_model_name=cfg.web_models.sentence, 223 | device=cfg.device, 224 | detic_threshold=cfg.detic_threshold, 225 | subsample_prob=cfg.subsample_prob, 226 | use_lseg=cfg.use_lseg, 227 | use_extra_classes=cfg.use_extra_classes, 228 | use_gt_classes=cfg.use_gt_classes_in_detic, 229 | visualize_results=cfg.visualize_detic_results, 230 | visualization_path=cfg.detic_visualization_path, 231 | ) 232 | if cfg.cache_result: 233 | torch.save(location_train_dataset, 
cfg.cache_path) 234 | return location_train_dataset 235 | 236 | 237 | @hydra.main(version_base="1.2", config_path="configs", config_name="train.yaml") 238 | def main(cfg): 239 | seed_everything(cfg.seed) 240 | # Set up single thread tokenizer. 241 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 242 | real_dataset: DeticDenseLabelledDataset = get_real_dataset(cfg) 243 | # Setup our model with min and max coordinates. 244 | max_coords, _ = real_dataset._label_xyz.max(dim=0) 245 | min_coords, _ = real_dataset._label_xyz.min(dim=0) 246 | logger.debug(f"Environment bounds: max {max_coords} min {min_coords}") 247 | train_classifier = ClassificationExtractor( 248 | clip_model_name=cfg.web_models.clip, 249 | sentence_model_name=cfg.web_models.sentence, 250 | class_names=real_dataset._all_classes, 251 | device=cfg.device, 252 | ) 253 | 254 | # Set up our metrics on this dataset. 255 | train_metric_calculators = {} 256 | train_class_count = {"semantic": train_classifier.total_label_classes} 257 | average_style = ["micro", "macro", "weighted"] 258 | for classes, counts in train_class_count.items(): 259 | train_metric_calculators[classes] = {} 260 | for metric_name, metric_cls in METRICS.items(): 261 | for avg in average_style: 262 | if "accuracy" in metric_name: 263 | new_metric = metric_cls( 264 | num_classes=counts, average=avg, multiclass=True 265 | ).to(cfg.device) 266 | train_metric_calculators[classes][ 267 | f"{classes}_{metric_name}_{avg}" 268 | ] = new_metric 269 | 270 | if torch.cuda.device_count() > 1 and cfg.dataparallel: 271 | batch_multiplier = torch.cuda.device_count() 272 | else: 273 | batch_multiplier = 1 274 | 275 | clip_train_loader = DataLoader( 276 | real_dataset, 277 | batch_size=batch_multiplier * cfg.batch_size, 278 | shuffle=True, 279 | pin_memory=True, 280 | num_workers=cfg.num_workers, 281 | ) 282 | logger.debug(f"Total train dataset sizes: {len(real_dataset)}") 283 | 284 | labelling_model = GridCLIPModel( 285 | image_rep_size=real_dataset[0]["clip_image_vector"].shape[-1], 286 | text_rep_size=real_dataset[0]["clip_vector"].shape[-1], 287 | mlp_depth=cfg.mlp_depth, 288 | mlp_width=cfg.mlp_width, 289 | log2_hashmap_size=cfg.log2_hashmap_size, 290 | num_levels=cfg.num_grid_levels, 291 | level_dim=cfg.level_dim, 292 | per_level_scale=cfg.per_level_scale, 293 | max_coords=max_coords, 294 | min_coords=min_coords, 295 | ).to(cfg.device) 296 | optim = torch.optim.Adam( 297 | labelling_model.parameters(), 298 | lr=cfg.lr, 299 | betas=tuple(cfg.betas), 300 | weight_decay=cfg.weight_decay, 301 | ) 302 | 303 | save_directory = cfg.save_directory 304 | state_dict = "{}/implicit_scene_label_model_latest.pt".format(save_directory) 305 | 306 | if os.path.exists("{}/".format(save_directory)) and os.path.exists(state_dict): 307 | logger.info(f"Resuming job from: {state_dict}") 308 | loaded_dict = torch.load(state_dict) 309 | labelling_model.load_state_dict(loaded_dict["model"]) 310 | optim.load_state_dict(loaded_dict["optim"]) 311 | epoch = loaded_dict["epoch"] 312 | resume = "allow" 313 | del loaded_dict 314 | else: 315 | logger.info("Could not find old runs, starting fresh...") 316 | os.makedirs("{}/".format(save_directory), exist_ok=True) 317 | resume = False 318 | epoch = 0 319 | 320 | dataparallel = False 321 | if torch.cuda.device_count() > 1 and cfg.dataparallel: 322 | labelling_model = ImplicitDataparallel(labelling_model) 323 | dataparallel = True 324 | 325 | wandb.init( 326 | project=cfg.project, 327 | tags=[f"model/{cfg.model_type}"], 328 | config=OmegaConf.to_container(cfg, 
resolve=True), 329 | resume=resume, 330 | ) 331 | # Set the extra parameters. 332 | wandb.config.web_labelled_points = len(real_dataset) 333 | 334 | # Disable tqdm if we are running inside slurm 335 | disable_tqdm = os.environ.get("SLURM_JOB_ID") is not None 336 | while epoch <= cfg.epochs: 337 | train( 338 | clip_train_loader, 339 | labelling_model, 340 | optim, 341 | epoch, 342 | train_classifier, 343 | cfg.device, 344 | exp_decay_coeff=cfg.exp_decay_coeff, 345 | image_to_label_loss_ratio=cfg.image_to_label_loss_ratio, 346 | label_to_image_loss_ratio=cfg.label_to_image_loss_ratio, 347 | disable_tqdm=disable_tqdm, 348 | metric_calculators=train_metric_calculators, 349 | ) 350 | epoch += 1 351 | if epoch % SAVE_EVERY == 0: 352 | save( 353 | labelling_model, 354 | optim, 355 | epoch, 356 | save_directory=save_directory, 357 | saving_dataparallel=dataparallel, 358 | ) 359 | 360 | 361 | if __name__ == "__main__": 362 | main() 363 | --------------------------------------------------------------------------------
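
As a quick sanity check outside the demo notebooks, the checkpoint that `save()` in `train.py` writes can be reloaded and queried with free-form text. The sketch below is illustrative only: the `GridCLIPModel` hyperparameters and coordinate bounds are placeholders that must match the `configs/train.yaml` used for training, the checkpoint path assumes the default `clip_implicit_model` save directory, and `all-mpnet-base-v2` is a stand-in for whatever text encoder produced the training targets (`cfg.web_models.sentence`).

```python
# Minimal query sketch for a trained CLIP-Field -- illustrative only; all values below are placeholders.
import torch
import torch.nn.functional as F
from sentence_transformers import SentenceTransformer

from grid_hash_model import GridCLIPModel

DEVICE = "cuda"

# Rebuild the model exactly as it was configured in main(); these numbers are
# placeholders and must match the configs/train.yaml that produced the checkpoint.
model = GridCLIPModel(
    image_rep_size=512,
    text_rep_size=768,
    mlp_depth=1,
    mlp_width=600,
    log2_hashmap_size=20,
    num_levels=18,
    level_dim=8,
    per_level_scale=2,
    max_coords=torch.tensor([5.0, 5.0, 3.0]),
    min_coords=torch.tensor([-5.0, -5.0, -1.0]),
).to(DEVICE)

# save() stores {"model", "optim", "epoch"}; adjust the directory to your cfg.save_directory.
checkpoint = torch.load("clip_implicit_model/implicit_scene_label_model_latest.pt")
model.load_state_dict(checkpoint["model"])
model.eval()

# Encode a natural-language query; use the same text encoder as during training.
text_encoder = SentenceTransformer("all-mpnet-base-v2")
query = text_encoder.encode(["warm up my lunch"], convert_to_tensor=True).to(DEVICE)

# Score candidate XYZ locations with the label (text) head of the field,
# mirroring the (label_latents, image_latents) call in train().
xyzs = torch.rand(10_000, 3, device=DEVICE) * 10 - 5  # stand-in for real scene points
with torch.no_grad():
    label_latents, image_latents = model(xyzs)
    scores = F.cosine_similarity(label_latents, query, dim=-1)

print("Most query-relevant point:", xyzs[scores.argmax()].tolist())
```

In practice, you would evaluate real scene points (for example, the reconstructed point cloud) rather than random samples.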