├── algos ├── __init__.py ├── goal_algos.py ├── contact_algos.py ├── traj_algos.py └── traj_optimizer.py ├── models ├── __init__.py ├── clip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── README.md │ ├── interpolate.py │ ├── simple_tokenizer.py │ ├── clip.py │ └── model.py ├── contact.py ├── perceiver.py ├── temporal.py ├── goal.py ├── attention.py ├── feature_extractors.py └── helpers.py ├── assets └── teaser.png ├── requirements.txt ├── .gitignore ├── scripts ├── download_ckpt_testdata.sh ├── prepare_third_party_modules.sh └── test_demo.sh ├── .gitmodules ├── config └── test_config.yaml ├── LICENSE ├── diffuser_utils └── guidance_params.py ├── README.md └── demos ├── optimize_affordance.py └── infer_affordance.py /algos/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethz-mrl/VidBot/HEAD/assets/teaser.png -------------------------------------------------------------------------------- /models/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | 3 | """ 4 | Modified from https://github.com/openai/CLIP 5 | """ 6 | -------------------------------------------------------------------------------- /models/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethz-mrl/VidBot/HEAD/models/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /models/clip/README.md: -------------------------------------------------------------------------------- 1 | # CLIP 2 | Modified version of [CLIP](https://github.com/openai/CLIP) with support for dense patch-level feature extraction 3 | (based on [MaskCLIP](https://arxiv.org/abs/2112.01071) parametrization) and interpolation of the positional encoding. 
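The positional-encoding interpolation mentioned above is implemented in `models/clip/interpolate.py` (included later in this listing). Below is a minimal sketch of how that function might be exercised for a non-square input; the ViT-B/16-style shapes (14×14 pretraining grid, 768-dim tokens, patch size 16) and the random tensors are illustrative assumptions, not values taken from the repository.

```python
import torch

from models.clip.interpolate import interpolate_positional_embedding

# Assumed ViT-B/16-like pretraining grid: 14x14 patches plus one class token, 768-dim.
pos_embed = torch.randn(14 * 14 + 1, 768)  # (197, 768)

# Dummy token sequence for a 320x224 input: 20x14 = 280 patches plus the class token.
patch_size, w, h = 16, 320, 224
x = torch.randn(1, (w // patch_size) * (h // patch_size) + 1, 768)

# Bicubically resamples the 14x14 positional grid to 20x14; the class-token
# embedding is kept as-is and re-attached in front.
new_pos_embed = interpolate_positional_embedding(pos_embed, x, patch_size=patch_size, w=w, h=h)
print(new_pos_embed.shape)  # torch.Size([281, 768])
```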
4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu117 2 | torch==1.13.1+cu117 3 | torchvision==0.14.1+cu117 4 | numpy==1.26.3 5 | transformers==4.26.1 6 | open3d 7 | opencv-python 8 | omegaconf 9 | easydict 10 | flow_vis 11 | numba 12 | scikit-image 13 | transformations 14 | ftfy 15 | einops 16 | gdown 17 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | wandb/ 2 | out/ 3 | ckpt/ 4 | 5 | # Byte-compiled / optimized / DLL files 6 | __pycache__/ 7 | *.py[cod] 8 | 9 | .vscode/* 10 | 11 | *.egg-info 12 | wandb/* 13 | logs/* 14 | datasets/*.json 15 | datasets/labels/* 16 | *.npz 17 | *.npy 18 | .tmp/ 19 | results/ 20 | # third_party/ 21 | pretrained/ 22 | *.zip 23 | vis/ 24 | inference_results/ 25 | *.blend* 26 | manip_render/ 27 | datasets 28 | pretrained_weights/ -------------------------------------------------------------------------------- /scripts/download_ckpt_testdata.sh: -------------------------------------------------------------------------------- 1 | # mkdir -p datasets 2 | gdown https://drive.google.com/uc?id=1IDCJ-wB05sMVKdLiG0IO_OajsY0ihvpi 3 | unzip vidbot_data_demo.zip -d datasets/ 4 | rm vidbot_data_demo.zip 5 | 6 | gdown https://drive.google.com/uc?id=14ByEGX4zKB7VjIE7fybXiq11kpRvvMqz 7 | unzip epickitchens_traj_demo.zip -d datasets/ 8 | rm epickitchens_traj_demo.zip 9 | 10 | gdown https://drive.google.com/uc?id=1pBZHU65WDwqcZTAm2TpBsLmhn6S9M82l 11 | unzip pretrained.zip 12 | rm pretrained.zip 13 | 14 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "third_party/GroundingDINO"] 2 | path = third_party/GroundingDINO 3 | url = https://github.com/IDEA-Research/GroundingDINO.git 4 | [submodule "third_party/EfficientSAM"] 5 | path = third_party/EfficientSAM 6 | url = https://github.com/yformer/EfficientSAM.git 7 | [submodule "third_party/graspnetAPI"] 8 | path = third_party/graspnetAPI 9 | url = git@github.com:HanzhiC/graspnetAPI.git 10 | [submodule "third_party/graspness_unofficial"] 11 | path = third_party/graspness_unofficial 12 | url = https://github.com/HanzhiC/graspness_unofficial.git 13 | -------------------------------------------------------------------------------- /config/test_config.yaml: -------------------------------------------------------------------------------- 1 | # Dataset and Gripper Mesh 2 | dataset_dir: ./datasets/ 3 | gripper_mesh_file: ./assets/panda_hand_mesh.obj 4 | 5 | # VidBot Affordance Prediction Models 6 | config_traj: ./pretrained/traj/config.yaml 7 | config_goal: ./pretrained/goal/config.yaml 8 | config_contact: ./pretrained/contact/config.yaml 9 | traj_ckpt: ./pretrained/traj/final.ckpt 10 | goal_ckpt: ./pretrained/goal/final.ckpt 11 | contact_ckpt: ./pretrained/contact/final.ckpt 12 | 13 | # Optional: Open-world Detectors and Grasp Detector 14 | config_detector: ./third_party/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py 15 | esam_ckpt: ./third_party/EfficientSAM/weights/efficient_sam_vitt.pt 16 | graspnet_ckpt: ./third_party/graspness_unofficial/weights/minkuresunet_kinect.tar 17 | detector_ckpt: ./third_party/GroundingDINO/weights/groundingdino_swint_ogc.pth 18 | 
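Before running the demos, the configuration above can be loaded with `omegaconf` (listed in `requirements.txt`) to check that the checkpoint paths resolve. This is only a convenience sketch, not a script shipped with the repository; the keys come from `config/test_config.yaml` above.

```python
import os

from omegaconf import OmegaConf

cfg = OmegaConf.load("config/test_config.yaml")

# Core VidBot checkpoints; the third-party detector/grasp checkpoints are optional.
for key in ("traj_ckpt", "goal_ckpt", "contact_ckpt"):
    path = cfg[key]
    print(f"{key}: {path} (exists: {os.path.exists(path)})")

print("dataset_dir:", cfg.dataset_dir)
```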
-------------------------------------------------------------------------------- /scripts/prepare_third_party_modules.sh: -------------------------------------------------------------------------------- 1 | # Clone the third-party modules 2 | echo "Cloning third-party modules" 3 | git submodule update --init --recursive 4 | 5 | # EfficientSAM weight is already downloaded via cloning the EfficientSAM repo 6 | echo "EfficientSAM weight is already downloaded via cloning the EfficientSAM repo" 7 | 8 | # Download the weight of the GroundingDINO 9 | echo "Downloading GroundingDINO weight" 10 | mkdir -p third_party/GroundingDINO/weights && cd third_party/GroundingDINO/weights 11 | wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth 12 | cd ../../../ 13 | 14 | # Download the weight of the GraspNet 15 | echo "Downloading GraspNet weight" 16 | mkdir -p third_party/graspness_unofficial/weights && cd third_party/graspness_unofficial/weights 17 | gdown https://drive.google.com/uc?id=10o5fc8LQsbI8H0pIC2RTJMNapW9eczqF 18 | cd ../../../ 19 | 20 | echo "Done with preparing third-party modules" 21 | echo "Follow their instructions for installation!" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Hanzhi Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
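The optional third-party checkpoints downloaded by `scripts/prepare_third_party_modules.sh` above end up at the paths referenced in `config/test_config.yaml`. A small, hypothetical pre-flight check (not part of the repository) that reports which of them are present:

```python
import os

# Checkpoint locations as referenced in config/test_config.yaml; all are optional.
OPTIONAL_WEIGHTS = {
    "GroundingDINO": "third_party/GroundingDINO/weights/groundingdino_swint_ogc.pth",
    "EfficientSAM": "third_party/EfficientSAM/weights/efficient_sam_vitt.pt",
    "GraspNet": "third_party/graspness_unofficial/weights/minkuresunet_kinect.tar",
}

for name, path in OPTIONAL_WEIGHTS.items():
    status = "found" if os.path.exists(path) else "missing (optional)"
    print(f"{name}: {path} -> {status}")
```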
-------------------------------------------------------------------------------- /diffuser_utils/guidance_params.py: -------------------------------------------------------------------------------- 1 | PARAMS1 = { 2 | "goal_weight": 100.0, 3 | "noncollide_weight": 200.0, 4 | "normal_weight": 200.0, 5 | "contact_weight": 0.0, 6 | "fine_voxel_resolution": 32, 7 | "exclude_object_points": True, 8 | } 9 | 10 | PARAMS2 = { 11 | "goal_weight": 100.0, 12 | "noncollide_weight": 200.0, 13 | "normal_weight": 10.0, 14 | "contact_weight": 0.0, 15 | "fine_voxel_resolution": 128, 16 | "exclude_object_points": False, 17 | } 18 | 19 | PARAMS3 = { 20 | "goal_weight": 100.0, 21 | "noncollide_weight": 500.0, 22 | "normal_weight": 0.0, 23 | "contact_weight": 500.0, 24 | "fine_voxel_resolution": 128, 25 | "exclude_object_points": True, 26 | } 27 | 28 | GUIDANCE_PARAMS_DICT = { 29 | "open": PARAMS1, 30 | "close": PARAMS1, 31 | "pull": PARAMS1, 32 | "push": PARAMS1, 33 | "press": PARAMS1, 34 | "pick": PARAMS2, 35 | "pickup": PARAMS2, 36 | "take": PARAMS2, 37 | "get": PARAMS2, 38 | "put": PARAMS2, 39 | "place": PARAMS2, 40 | "putdown": PARAMS2, 41 | "drop": PARAMS2, 42 | "wipe": PARAMS3, 43 | "move": PARAMS3, 44 | "other": PARAMS2, 45 | 46 | } 47 | 48 | COMMON_ACTIONS = [ 49 | "open", 50 | "close", 51 | "pull", 52 | "push", 53 | "press", 54 | "pick", 55 | "pickup", 56 | "take", 57 | "get", 58 | "put", 59 | "place", 60 | "putdown", 61 | "drop", 62 | "wipe", 63 | "move", 64 | ] 65 | -------------------------------------------------------------------------------- /algos/goal_algos.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | 4 | import torch 5 | import torch.nn as nn 6 | import pytorch_lightning as pl 7 | import torch.nn.functional as F 8 | from models.goal import GoalPredictor 9 | import diffuser_utils.dataset_utils as DatasetUtils 10 | from diffuser_utils.guidance_params import COMMON_ACTIONS 11 | import pandas as pd 12 | 13 | 14 | class GoalPredictorModule(pl.LightningModule): 15 | def __init__(self, algo_config): 16 | super(GoalPredictorModule, self).__init__() 17 | self.algo_config = algo_config 18 | self.nets = nn.ModuleDict() 19 | policy_kwargs = algo_config.model 20 | 21 | self.nets["policy"] = GoalPredictor(**policy_kwargs) 22 | 23 | @torch.no_grad() 24 | def encode_action(self, data_batch, clip_model, max_length=20): 25 | action_tokens, action_feature = DatasetUtils.encode_text_clip( 26 | clip_model, 27 | [data_batch["action_text"]], 28 | max_length=max_length, 29 | device="cuda", 30 | ) 31 | 32 | action_tokens.to(self.device) 33 | action_feature.to(self.device) 34 | 35 | action_text = data_batch["action_text"] 36 | verb_text = action_text.split(" ")[0] 37 | if verb_text not in COMMON_ACTIONS: 38 | verb_text = "other" 39 | else: 40 | verb_text = verb_text.replace("-", "") 41 | verb_text = [verb_text] 42 | 43 | verb_tokens, verb_feature = DatasetUtils.encode_text_clip( 44 | clip_model, 45 | verb_text, 46 | max_length=max_length, 47 | device="cuda", 48 | ) 49 | 50 | verb_tokens.to(self.device) 51 | verb_feature.to(self.device) 52 | 53 | data_batch.update({"action_feature": action_feature.float()}) 54 | data_batch.update({"verb_feature": verb_feature.float()}) 55 | 56 | def forward(self, data_batch, training=False): 57 | # self.encode_action(data_batch) 58 | curr_policy = self.nets["policy"] 59 | outputs = curr_policy(data_batch, training) 60 | return outputs 61 | 
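`GoalPredictorModule.encode_action` above and `GUIDANCE_PARAMS_DICT` in `diffuser_utils/guidance_params.py` rely on the same verb convention: the first word of the instruction is the verb, and verbs outside `COMMON_ACTIONS` fall back to `"other"`. The helper below is an illustrative sketch (not part of the repository) that applies this convention to look up guidance parameters for an instruction.

```python
from diffuser_utils.guidance_params import COMMON_ACTIONS, GUIDANCE_PARAMS_DICT


def guidance_params_for(instruction: str) -> dict:
    """Map an instruction such as 'pickup sponge' to its guidance parameter set."""
    verb = instruction.split(" ")[0]
    if verb not in COMMON_ACTIONS:
        verb = "other"  # unknown verbs share the generic parameter set (PARAMS2)
    else:
        verb = verb.replace("-", "")  # same normalization as encode_action
    return GUIDANCE_PARAMS_DICT[verb]


print(guidance_params_for("pickup sponge")["fine_voxel_resolution"])  # 128 (PARAMS2)
print(guidance_params_for("wipe table")["contact_weight"])            # 500.0 (PARAMS3)
```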
-------------------------------------------------------------------------------- /algos/contact_algos.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | 4 | import torch 5 | import torch.nn as nn 6 | import pytorch_lightning as pl 7 | import torch.nn.functional as F 8 | 9 | import diffuser_utils.dataset_utils as DatasetUtils 10 | from models.contact import ContactPredictor 11 | from diffuser_utils.guidance_params import COMMON_ACTIONS 12 | 13 | 14 | class ContactPredictorModule(pl.LightningModule): 15 | def __init__(self, algo_config): 16 | super(ContactPredictorModule, self).__init__() 17 | self.algo_config = algo_config 18 | self.nets = nn.ModuleDict() 19 | 20 | # Initialize the contact former 21 | policy_kwargs = algo_config.model 22 | 23 | self.nets["policy"] = ContactPredictor(**policy_kwargs) 24 | 25 | @torch.no_grad() 26 | def encode_action(self, data_batch, clip_model, max_length=20): 27 | action_tokens, action_feature = DatasetUtils.encode_text_clip( 28 | clip_model, 29 | [data_batch["action_text"]], 30 | max_length=max_length, 31 | device="cuda", 32 | ) 33 | 34 | action_tokens.to(self.device) 35 | action_feature.to(self.device) 36 | 37 | action_text = data_batch["action_text"] 38 | verb_text = action_text.split(" ")[0] 39 | if verb_text not in COMMON_ACTIONS: 40 | verb_text = "other" 41 | else: 42 | verb_text = verb_text.replace("-", "") 43 | verb_text = [verb_text] 44 | 45 | verb_tokens, verb_feature = DatasetUtils.encode_text_clip( 46 | clip_model, 47 | verb_text, 48 | max_length=max_length, 49 | device="cuda", 50 | ) 51 | 52 | verb_tokens.to(self.device) 53 | verb_feature.to(self.device) 54 | 55 | data_batch.update({"action_feature": action_feature.float()}) 56 | data_batch.update({"verb_feature": verb_feature.float()}) 57 | 58 | def forward(self, data_batch, training=False): 59 | # self.encode_action(data_batch) 60 | curr_policy = self.nets["policy"] 61 | outputs = curr_policy(data_batch, training) 62 | return outputs 63 | -------------------------------------------------------------------------------- /scripts/test_demo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Exit on error, undefined variables, and pipeline failures 4 | set -euo pipefail 5 | 6 | # Check if the --use_graspnet flag is passed as an argument 7 | USE_GRASPNET=false 8 | if [[ "$#" -gt 0 && "$1" == "--use_graspnet" ]]; then 9 | USE_GRASPNET=true 10 | fi 11 | 12 | # Define the --use_graspnet option based on the variable 13 | graspnet_option="" 14 | if [ "$USE_GRASPNET" = true ]; then 15 | graspnet_option="--use_graspnet" 16 | fi 17 | 18 | # Define array of commands with conditional --use_graspnet 19 | commands=( 20 | "python demos/infer_affordance.py -v -f 0 -i 'pickup sponge' --load_results --no_save $graspnet_option" 21 | "python demos/infer_affordance.py -v -f 0 -i 'pickup brush' --load_results --no_save $graspnet_option" 22 | "python demos/infer_affordance.py -v -f 1 -i 'place sponge' --load_results --no_save $graspnet_option" 23 | "python demos/infer_affordance.py -v -f 2 -i 'take mug' --load_results --no_save $graspnet_option" 24 | "python demos/infer_affordance.py -v -f 3 -i 'take kettle' --load_results --no_save $graspnet_option" 25 | "python demos/infer_affordance.py -v -f 4 -i 'take paper' --load_results --no_save $graspnet_option" 26 | "python demos/infer_affordance.py -v -f 5 -i 'pickup bottle' --load_results --no_save $graspnet_option" 27 | "python 
demos/infer_affordance.py -v -f 6 -i 'push door' --load_results --no_save $graspnet_option" 28 | "python demos/infer_affordance.py -v -f 7 -i 'open cabinet' --load_results --no_save $graspnet_option" 29 | "python demos/infer_affordance.py -v -f 8 -i 'close cabinet' --load_results --no_save $graspnet_option" 30 | "python demos/infer_affordance.py -v -f 9 -i 'pickup driller' --load_results --no_save $graspnet_option" 31 | "python demos/infer_affordance.py -v -f 10 -i 'place driller' --load_results --no_save $graspnet_option" 32 | ) 33 | 34 | # Get number of commands 35 | num_commands=${#commands[@]} 36 | 37 | # Generate shuffled indices 38 | mapfile -t shuffled_indices < <(shuf -i 0-$((num_commands - 1))) 39 | 40 | # Execute commands in shuffled order 41 | for index in "${shuffled_indices[@]}"; do 42 | echo ">> Running: ${commands[$index]}" 43 | bash -c "${commands[$index]}" 44 | done -------------------------------------------------------------------------------- /models/clip/interpolate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def interpolate_positional_embedding( 6 | positional_embedding: torch.Tensor, x: torch.Tensor, patch_size: int, w: int, h: int 7 | ): 8 | """ 9 | Interpolate the positional encoding for CLIP to the number of patches in the image given width and height. 10 | Modified from DINO ViT `interpolate_pos_encoding` method. 11 | https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L174 12 | """ 13 | assert positional_embedding.ndim == 2, "pos_encoding must be 2D" 14 | 15 | # Number of patches in input 16 | num_patches = x.shape[1] - 1 17 | # Original number of patches for square images 18 | num_og_patches = positional_embedding.shape[0] - 1 19 | 20 | if num_patches == num_og_patches and w == h: 21 | # No interpolation needed 22 | return positional_embedding.to(x.dtype) 23 | 24 | dim = x.shape[-1] 25 | class_pos_embed = positional_embedding[:1] # (1, dim) 26 | patch_pos_embed = positional_embedding[1:] # (num_og_patches, dim) 27 | 28 | # Compute number of tokens 29 | w0 = w // patch_size 30 | h0 = h // patch_size 31 | assert w0 * h0 == num_patches, "Number of patches does not match" 32 | 33 | # Add a small number to avoid floating point error in the interpolation 34 | # see discussion at https://github.com/facebookresearch/dino/issues/8 35 | w0, h0 = w0 + 0.1, h0 + 0.1 36 | 37 | # Interpolate 38 | patch_per_ax = int(np.sqrt(num_og_patches)) 39 | patch_pos_embed_interp = torch.nn.functional.interpolate( 40 | patch_pos_embed.reshape(1, patch_per_ax, patch_per_ax, dim).permute(0, 3, 1, 2), 41 | # (1, dim, patch_per_ax, patch_per_ax) 42 | scale_factor=(w0 / patch_per_ax, h0 / patch_per_ax), 43 | mode="bicubic", 44 | align_corners=False, 45 | recompute_scale_factor=False, 46 | ) # (1, dim, w0, h0) 47 | assert ( 48 | int(w0) == patch_pos_embed_interp.shape[-2] and int(h0) == patch_pos_embed_interp.shape[-1] 49 | ), "Interpolation error." 
50 | 51 | patch_pos_embed_interp = patch_pos_embed_interp.permute(0, 2, 3, 1).reshape(-1, dim) # (w0 * h0, dim) 52 | # Concat class token embedding and interpolated patch embeddings 53 | pos_embed_interp = torch.cat([class_pos_embed, patch_pos_embed_interp], dim=0) # (w0 * h0 + 1, dim) 54 | return pos_embed_interp.to(x.dtype) 55 | -------------------------------------------------------------------------------- /models/contact.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import einops 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import tqdm 7 | import numpy as np 8 | 9 | from models.layers_2d import ResNet50Encoder, ResNet50Decoder, PositionalEmbeddingV2 10 | import einops 11 | import torchvision.transforms as transforms 12 | from diffuser_utils.dataset_utils import compute_model_size 13 | from models.perceiver import FeaturePerceiver 14 | 15 | 16 | class ContactPredictor(pl.LightningModule): 17 | def __init__( 18 | self, 19 | in_channels=3, 20 | out_channels=2, 21 | use_skip=True, 22 | encode_action=False, 23 | use_min_loss=False, 24 | **kwargs, 25 | ): 26 | super(ContactPredictor, self).__init__() 27 | self.in_channels = in_channels 28 | self.out_channels = out_channels 29 | self.use_skip = use_skip 30 | self.encode_action = encode_action 31 | self.use_min_loss = use_min_loss 32 | self.transform = transforms.Compose( 33 | [ 34 | transforms.Normalize([0.485, 0.456, 0.406], [ 35 | 0.229, 0.224, 0.225]), 36 | ] 37 | ) 38 | self.visual = ResNet50Encoder(input_channels=in_channels) 39 | self.decoder = ResNet50Decoder( 40 | output_channels=out_channels, use_skip=use_skip) 41 | self.bottleneck_feature_key = self.decoder.bottleneck_feature_key 42 | self.latent_dim = self.decoder.latent_dim 43 | self.visual_proj = nn.Linear(self.latent_dim, 512) 44 | # self.proj = nn.Linear(self.latent_dim, 512) 45 | 46 | if self.encode_action: 47 | self.action_proj = nn.Linear(512, 512) 48 | self.action_fuser = FeaturePerceiver( 49 | transition_dim=512, condition_dim=512, time_emb_dim=0 50 | ) 51 | self.final_proj = nn.Linear(self.action_fuser.last_dim, 512) 52 | else: 53 | print("No action encoding") 54 | 55 | self.fuser = nn.Sequential( 56 | nn.TransformerEncoderLayer( 57 | d_model=512, nhead=4, dim_feedforward=512, batch_first=True 58 | ), 59 | nn.Linear(512, self.latent_dim), 60 | ) 61 | self.positional_embedding = PositionalEmbeddingV2( 62 | d_model=512, max_len=400) 63 | 64 | def forward(self, data_batch, training=False): 65 | outputs = {} 66 | object_color_key = "object_color" 67 | object_depth_key = "object_depth" 68 | if training: 69 | object_color_key += "_aug" 70 | 71 | inputs = self.transform(data_batch[object_color_key]) 72 | if self.in_channels == 4: 73 | object_depth = data_batch[object_depth_key][:, None] 74 | inputs = torch.cat([inputs, object_depth], dim=1) 75 | features = self.visual(inputs) 76 | features = self.forward_latent(data_batch, features) 77 | pred = self.decoder(features) 78 | 79 | # Post-process the prediction 80 | pred_final = [] 81 | pred_vfs = pred[:, : self.out_channels - 1] # [B, 8, H, W] 82 | pred_mask = pred[:, self.out_channels - 1:] # [B, 1, H, W] 83 | for hi in range(0, self.out_channels - 1, 2): 84 | pred_vf = pred_vfs[:, hi: hi + 2] # [B, 2, H, W] 85 | pred_vf = F.normalize(pred_vf, p=2, dim=1) # [-1, 1] 86 | pred_vf = pred_vf.clamp(-1, 1) 87 | pred_final.append(pred_vf) 88 | pred_final.append(pred_mask) 89 | pred_final = torch.cat(pred_final, dim=1) 90 | 
outputs["pred"] = pred_final # [B, 8+1, H, W] 91 | return outputs 92 | 93 | def forward_latent(self, data_batch, features): 94 | latent = features[self.bottleneck_feature_key] 95 | h, w = latent.shape[-2:] 96 | latent = einops.rearrange(latent, "b c h w -> b (h w) c") 97 | latent = self.visual_proj(latent) 98 | # latent = self.proj(latent) 99 | if self.encode_action: 100 | action_feature = data_batch["action_feature"][:, None] 101 | action_feature = self.action_proj(action_feature) 102 | latent = self.action_fuser(latent, action_feature) 103 | latent = self.final_proj(latent) 104 | 105 | latent = self.positional_embedding(latent) 106 | latent = self.fuser(latent) 107 | latent = einops.rearrange(latent, "b (h w) c -> b c h w", h=h, w=w) 108 | features[self.bottleneck_feature_key] = latent 109 | return features 110 | -------------------------------------------------------------------------------- /models/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from collections.abc import Sequence 5 | from functools import lru_cache 6 | 7 | import ftfy 8 | import regex as re 9 | 10 | 11 | @lru_cache() 12 | def default_bpe(): 13 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 14 | 15 | 16 | @lru_cache() 17 | def bytes_to_unicode(): 18 | """ 19 | Returns list of utf-8 byte and a corresponding list of unicode strings. 20 | The reversible bpe codes work on unicode strings. 21 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 22 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 23 | This is a signficant percentage of your normal, say, 32K bpe vocab. 24 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 25 | And avoids mapping to whitespace/control characters the bpe code barfs on. 26 | """ 27 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 28 | cs = bs[:] 29 | n = 0 30 | for b in range(2**8): 31 | if b not in bs: 32 | bs.append(b) 33 | cs.append(2**8+n) 34 | n += 1 35 | cs = [chr(n) for n in cs] 36 | return dict(zip(bs, cs)) 37 | 38 | 39 | def get_pairs(word): 40 | """Return set of symbol pairs in a word. 41 | Word is represented as tuple of symbols (symbols being variable-length strings). 42 | """ 43 | pairs = set() 44 | prev_char = word[0] 45 | for char in word[1:]: 46 | pairs.add((prev_char, char)) 47 | prev_char = char 48 | return pairs 49 | 50 | 51 | def basic_clean(text): 52 | # note: pretty hacky but it is okay! 
53 | # ge: bad.this is used by the cli_multi_label.py script 54 | if not isinstance(text, str): 55 | text = ', '.join(text) 56 | 57 | text = ftfy.fix_text(text) 58 | text = html.unescape(html.unescape(text)) 59 | return text.strip() 60 | 61 | 62 | def whitespace_clean(text): 63 | text = re.sub(r'\s+', ' ', text) 64 | text = text.strip() 65 | return text 66 | 67 | 68 | class SimpleTokenizer(object): 69 | def __init__(self, bpe_path: str = default_bpe()): 70 | self.byte_encoder = bytes_to_unicode() 71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 72 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 73 | merges = merges[1:49152-256-2+1] 74 | merges = [tuple(merge.split()) for merge in merges] 75 | vocab = list(bytes_to_unicode().values()) 76 | vocab = vocab + [v+'' for v in vocab] 77 | for merge in merges: 78 | vocab.append(''.join(merge)) 79 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 80 | self.encoder = dict(zip(vocab, range(len(vocab)))) 81 | self.decoder = {v: k for k, v in self.encoder.items()} 82 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 83 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 84 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 85 | 86 | def bpe(self, token): 87 | if token in self.cache: 88 | return self.cache[token] 89 | word = tuple(token[:-1]) + ( token[-1] + '',) 90 | pairs = get_pairs(word) 91 | 92 | if not pairs: 93 | return token+'' 94 | 95 | while True: 96 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 97 | if bigram not in self.bpe_ranks: 98 | break 99 | first, second = bigram 100 | new_word = [] 101 | i = 0 102 | while i < len(word): 103 | try: 104 | j = word.index(first, i) 105 | new_word.extend(word[i:j]) 106 | i = j 107 | except: 108 | new_word.extend(word[i:]) 109 | break 110 | 111 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 112 | new_word.append(first+second) 113 | i += 2 114 | else: 115 | new_word.append(word[i]) 116 | i += 1 117 | new_word = tuple(new_word) 118 | word = new_word 119 | if len(word) == 1: 120 | break 121 | else: 122 | pairs = get_pairs(word) 123 | word = ' '.join(word) 124 | self.cache[token] = word 125 | return word 126 | 127 | def encode(self, text): 128 | bpe_tokens = [] 129 | text = whitespace_clean(basic_clean(text)).lower() 130 | for token in re.findall(self.pat, text): 131 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 132 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 133 | return bpe_tokens 134 | 135 | def decode(self, tokens): 136 | text = ''.join([self.decoder[token] for token in tokens]) 137 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 138 | return text 139 | -------------------------------------------------------------------------------- /algos/traj_algos.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import pytorch_lightning as pl 8 | import torch.nn.functional as F 9 | import diffuser_utils.dataset_utils as DatasetUtils 10 | from models.diffuser import DiffuserModel 11 | from models.helpers import EMA 12 | import open3d as o3d 13 | from diffuser_utils.guidance_params import COMMON_ACTIONS 14 | import 
torchvision 15 | 16 | 17 | class TrajectoryDiffusionModule(pl.LightningModule): 18 | def __init__(self, algo_config): 19 | super(TrajectoryDiffusionModule, self).__init__() 20 | self.algo_config = algo_config 21 | self.nets = nn.ModuleDict() 22 | 23 | # Initialize the diffuser 24 | policy_kwargs = algo_config.model 25 | self.nets["policy"] = DiffuserModel(**policy_kwargs) 26 | 27 | 28 | @torch.no_grad() 29 | def encode_action(self, data_batch, clip_model, max_length=20): 30 | action_tokens, action_feature = DatasetUtils.encode_text_clip( 31 | clip_model, 32 | [data_batch["action_text"]], 33 | max_length=max_length, 34 | device="cuda", 35 | ) 36 | 37 | action_tokens.to(self.device) 38 | action_feature.to(self.device) 39 | 40 | action_text = data_batch["action_text"] 41 | verb_text = action_text.split(" ")[0] 42 | if verb_text not in COMMON_ACTIONS: 43 | verb_text = "other" 44 | else: 45 | verb_text = verb_text.replace("-", "") 46 | verb_text = [verb_text] 47 | 48 | verb_tokens, verb_feature = DatasetUtils.encode_text_clip( 49 | clip_model, 50 | verb_text, 51 | max_length=max_length, 52 | device="cuda", 53 | ) 54 | 55 | verb_tokens.to(self.device) 56 | verb_feature.to(self.device) 57 | 58 | data_batch.update({"action_feature": action_feature.float()}) 59 | data_batch.update({"verb_feature": verb_feature.float()}) 60 | 61 | 62 | def forward( 63 | self, 64 | data_batch, 65 | num_samp=1, 66 | return_diffusion=False, 67 | return_guidance_losses=False, 68 | apply_guidance=False, 69 | class_free_guide_w=0.0, 70 | guide_clean=False, 71 | ): 72 | curr_policy = self.nets["policy"] 73 | return curr_policy( 74 | data_batch, 75 | num_samp, 76 | return_diffusion=return_diffusion, 77 | return_guidance_losses=return_guidance_losses, 78 | apply_guidance=apply_guidance, 79 | class_free_guide_w=class_free_guide_w, 80 | guide_clean=guide_clean, 81 | ) 82 | 83 | def visualize_trajectory_by_rendering( 84 | self, 85 | data_batch, 86 | config_path, 87 | window=False, 88 | return_vis=False, 89 | draw_grippers=False, 90 | **kwargs 91 | ): 92 | batch_size = len(data_batch["color"]) 93 | results = [] 94 | for i in range(batch_size): 95 | vis_o3d = [] 96 | depth = data_batch["depth"][i].cpu().numpy() 97 | color = data_batch["color"][i].cpu().numpy().transpose(1, 2, 0) 98 | intr = data_batch["intrinsics"][i].cpu().numpy() 99 | # gt_traj = data_batch["gt_trajectory"][i].cpu().numpy() 100 | 101 | # backproject 102 | points_scene, scene_ids = DatasetUtils.backproject( 103 | depth, 104 | intr, 105 | depth > 0, 106 | # np.logical_and(hand_mask == 0, depth > 0), 107 | NOCS_convention=False, 108 | ) 109 | 110 | colors_scene = color.copy()[scene_ids[0], scene_ids[1]] 111 | pcd_scene = DatasetUtils.visualize_points( 112 | points_scene, colors_scene) 113 | # gt_traj_vis = DatasetUtils.visualize_3d_trajectory( 114 | # gt_traj, size=0.02, cmap_name="viridis" 115 | # ) 116 | 117 | vis_o3d = [pcd_scene] # + gt_traj_vis 118 | if "pred_trajectories" in data_batch: 119 | print("===> Visualizing pred trajectories") 120 | pred_trajs = data_batch["pred_trajectories"][i].cpu().numpy() 121 | pred_traj_colors = DatasetUtils.random_colors(len(pred_trajs)) 122 | for pi, pred_traj in enumerate(pred_trajs): 123 | _pred_traj_vis = DatasetUtils.visualize_3d_trajectory( 124 | pred_traj, size=0.01, cmap_name="plasma" 125 | ) 126 | if len(pred_trajs) > 1: 127 | _pred_traj_vis = [ 128 | s.paint_uniform_color(pred_traj_colors[pi]) 129 | for s in _pred_traj_vis 130 | ] 131 | 132 | pred_traj_vis = _pred_traj_vis[0] 133 | for ii in range(1, 
len(_pred_traj_vis)): 134 | pred_traj_vis += _pred_traj_vis[ii] 135 | 136 | vis_o3d += [pred_traj_vis] 137 | 138 | if window: 139 | o3d.visualization.draw(vis_o3d) 140 | 141 | if return_vis: 142 | return vis_o3d 143 | 144 | render_dist = np.median(np.linalg.norm(points_scene, axis=1)) 145 | render_img = DatasetUtils.render_offscreen( 146 | vis_o3d, 147 | config_path, 148 | # dist=render_dist, 149 | resize_factor=0.5, 150 | ) 151 | render_img = torchvision.transforms.ToTensor()(render_img) 152 | results.append(render_img) 153 | results = torch.stack(results, dim=0) # [B, C, H, W] 154 | data_batch.update({"pred_vis": results}) 155 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | # VidBot: Learning Generalizable 3D Actions from In-the-Wild 2D Human Videos for Zero-Shot Robotic Manipulation 3 | 
4 | **CVPR 2025** 5 | 6 | 7 | 
8 | [Paper (arXiv)](https://arxiv.org/abs/2503.07135) 9 | 10 | [Project Page](https://hanzhic.github.io/vidbot-project/) 11 | 12 | 13 | 
14 | 15 | 16 | 
17 | ![VidBot Teaser](assets/teaser.png) 18 | 
19 | 20 | This is the official repository of [**VidBot: Learning Generalizable 3D Actions from In-the-Wild 2D Human Videos for Zero-Shot Robotic Manipulation**](https://arxiv.org/abs/2503.07135). For more details, please check our [**project website**](https://hanzhic.github.io/vidbot-project/). 21 | 22 | 23 | 24 | ## Installation 25 | 26 | To install VidBot, follow these steps: 27 | 28 | 1. **Clone the Repository**: 29 | ```bash 30 | git clone https://github.com/HanzhiC/vidbot.git 31 | cd vidbot 32 | ``` 33 | 34 | 2. **Install Dependencies**: 35 | ```bash 36 | # Prepare the environment 37 | conda create -n vidbot python=3.10.9 38 | conda activate vidbot 39 | 40 | # Ensure PyTorch 1.13.1 is installed, pytorch-lightning might change the PyTorch version 41 | pip install pytorch-lightning==1.8.6 42 | pip install -r requirements.txt 43 | 44 | # Install PyTorch Scatter 45 | wget https://data.pyg.org/whl/torch-1.13.0%2Bcu117/torch_scatter-2.1.1%2Bpt113cu117-cp310-cp310-linux_x86_64.whl 46 | pip install torch_scatter-2.1.1+pt113cu117-cp310-cp310-linux_x86_64.whl 47 | rm -rf torch_scatter-2.1.1+pt113cu117-cp310-cp310-linux_x86_64.whl 48 | ``` 49 | 50 | 3. **Download Pretrained Weights and Demo Dataset**: 51 | ```bash 52 | sh scripts/download_ckpt_testdata.sh 53 | ``` 54 | You can now try out VidBot with the demo data we've provided! 55 | 56 | 4. **(Optional) Install Third-Party Modules**: 57 | ```bash 58 | sh scripts/prepare_third_party_modules.sh 59 | ``` 60 | Follow the installation instructions from [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [EfficientSAM](https://github.com/yformer/EfficientSAM), [GraspNet](https://github.com/graspnet/graspness_unofficial), and [GraspNetAPI](https://github.com/graspnet/graspnetAPI) to set up these third-party modules. 61 | 62 | **Note**: The `transformers` library should be version 4.26.1. Installing `GroundingDINO` might change the version. Installing `MinkowskiEngine` for GraspNet can be painful. However, our framework can still function without GraspNet. In such cases, we will employ a simplified method to obtain the grasp poses. 63 | 64 | ## Affordance Inference 65 | 66 | 1. **Quick Start with VidBot**: To quickly explore VidBot, you don't need to install any third-party modules. After downloading the weights and demo dataset, you can use our pre-saved bounding boxes to run the inference scripts with the following command: 67 | ```bash** 68 | bash scripts/test_demo.sh 69 | ``` 70 | 71 | 2. **Testing VidBot with Your Own Data**: To test VidBot using your own data, just put your collected dataset under the `./datasets/` folder. Please ensure your data is organized to match the structure of our demo dataset: 72 | ```text 73 | YOUR_DATASET_NAME/ 74 | ├── camera_intrinsic.json 75 | ├── color 76 | │ ├── 000000.png 77 | │ ├── 000001.png 78 | │ ├── 00000X.png 79 | ├── depth 80 | │ ├── 000000.png 81 | │ ├── 000001.png 82 | │ ├── 00000X.png 83 | ``` 84 | The `camera_intrinsic.json` file should be structured as follows: 85 | ```json 86 | { 87 | "width": width, 88 | "height": height, 89 | "intrinsic_matrix": [ 90 | fx, 91 | 0, 92 | 0, 93 | 0, 94 | fy, 95 | 0, 96 | cx, 97 | cy, 98 | 1 99 | ] 100 | } 101 | ``` 102 | **We recommend using an image resolution of 1280x720.** 103 | 104 | 3. 
**Run the Inference Script**: To run tests with your own data, execute the following command, ensuring you understand the meaning of each input argument: 105 | ```bash 106 | python demos/infer_affordance.py 107 | --config ./config/test_config.yaml 108 | --dataset YOUR_DATASET_NAME 109 | --frame FRAME_ID 110 | --instruction YOUR_INSTRUCTION 111 | --object OBJECT_CLASS 112 | --visualize 113 | ``` 114 | If you have installed GraspNet and wish to estimate the gripper pose, add the `--use_graspnet` option to the command. 115 | 116 | 117 | ## Citation 118 | 119 | **If you find our work useful, please cite:** 120 | 121 | ```bibtex 122 | @article{chen2025vidbot, 123 | author = {Chen, Hanzhi and Sun, Boyang and Zhang, Anran and Pollefeys, Marc and Leutenegger, Stefan}, 124 | title = {{VidBot}: Learning Generalizable 3D Actions from In-the-Wild 2D Human Videos for Zero-Shot Robotic Manipulation}, 125 | booktitle = {Proceedings of the Computer Vision and Pattern Recognition Conference}, 126 | year = {2025}, 127 | } 128 | ``` 129 | 130 | ## Acknowledgement 131 | Our codebase is built upon [TRACE](https://github.com/nv-tlabs/trace). Partial code is borrowed from [ConvONet](https://github.com/autonomousvision/convolutional_occupancy_networks), [afford-motion](https://github.com/afford-motion/afford-motion) and [rq-vae-transformer 132 | ](https://github.com/kakaobrain/rq-vae-transformer). Thanks for their great contribution! 133 | 134 | ## License 135 | 136 | This project is licensed under the MIT License. See [LICENSE](LICENSE) for more details. 137 | -------------------------------------------------------------------------------- /models/perceiver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from typing import List, Optional, Tuple 5 | from collections import OrderedDict 6 | from einops import rearrange 7 | from models.layers_3d import SinusoidalPosEmb 8 | from models.attention import SelfAttentionBlock, CrossAttentionLayer 9 | 10 | 11 | class FeaturePerceiver(nn.Module): 12 | 13 | def __init__( 14 | self, 15 | transition_dim, 16 | condition_dim, 17 | time_emb_dim, 18 | encoder_q_input_channels=512, 19 | encoder_kv_input_channels=256, 20 | encoder_num_heads=8, 21 | encoder_widening_factor=1, 22 | encoder_dropout=0.1, 23 | encoder_residual_dropout=0.0, 24 | encoder_self_attn_num_layers=2, 25 | decoder_q_input_channels=256, 26 | decoder_kv_input_channels=512, 27 | decoder_num_heads=8, 28 | decoder_widening_factor=1, 29 | decoder_dropout=0.1, 30 | decoder_residual_dropout=0.0, 31 | ) -> None: 32 | super().__init__() 33 | 34 | self.encoder_q_input_channels = encoder_q_input_channels 35 | self.encoder_kv_input_channels = encoder_kv_input_channels 36 | self.encoder_num_heads = encoder_num_heads 37 | self.encoder_widening_factor = encoder_widening_factor 38 | self.encoder_dropout = encoder_dropout 39 | self.encoder_residual_dropout = encoder_residual_dropout 40 | self.encoder_self_attn_num_layers = encoder_self_attn_num_layers 41 | 42 | self.decoder_q_input_channels = decoder_q_input_channels 43 | self.decoder_kv_input_channels = decoder_kv_input_channels 44 | self.decoder_num_heads = decoder_num_heads 45 | self.decoder_widening_factor = decoder_widening_factor 46 | self.decoder_dropout = decoder_dropout 47 | self.decoder_residual_dropout = decoder_residual_dropout 48 | 49 | self.condition_adapter = nn.Linear( 50 | condition_dim, self.encoder_q_input_channels, bias=True 51 | ) 52 | 53 | if 
time_emb_dim > 0: 54 | self.time_embedding_adapter = nn.Linear( 55 | time_emb_dim, self.encoder_q_input_channels, bias=True 56 | ) 57 | else: 58 | self.time_embedding_adapter = None 59 | 60 | self.encoder_adapter = nn.Linear( 61 | transition_dim, 62 | self.encoder_kv_input_channels, 63 | bias=True, 64 | ) 65 | self.decoder_adapter = nn.Linear( 66 | self.encoder_kv_input_channels, self.decoder_q_input_channels, bias=True 67 | ) 68 | 69 | self.encoder_cross_attn = CrossAttentionLayer( 70 | num_heads=self.encoder_num_heads, 71 | num_q_input_channels=self.encoder_q_input_channels, 72 | num_kv_input_channels=self.encoder_kv_input_channels, 73 | widening_factor=self.encoder_widening_factor, 74 | dropout=self.encoder_dropout, 75 | residual_dropout=self.encoder_residual_dropout, 76 | ) 77 | 78 | self.encoder_self_attn = SelfAttentionBlock( 79 | num_layers=self.encoder_self_attn_num_layers, 80 | num_heads=self.encoder_num_heads, 81 | num_channels=self.encoder_q_input_channels, 82 | widening_factor=self.encoder_widening_factor, 83 | dropout=self.encoder_dropout, 84 | residual_dropout=self.encoder_residual_dropout, 85 | ) 86 | 87 | self.decoder_cross_attn = CrossAttentionLayer( 88 | num_heads=self.decoder_num_heads, 89 | num_q_input_channels=self.decoder_q_input_channels, 90 | num_kv_input_channels=self.decoder_kv_input_channels, 91 | widening_factor=self.decoder_widening_factor, 92 | dropout=self.decoder_dropout, 93 | residual_dropout=self.decoder_residual_dropout, 94 | ) 95 | self.last_dim = self.decoder_q_input_channels 96 | 97 | 98 | def forward( 99 | self, 100 | x, 101 | condition_feat, 102 | time_embedding=None, 103 | ): 104 | """Forward pass of the ContactMLP. 105 | 106 | Args: 107 | x: input contact map, [bs, num_points, transition_dim] 108 | condition_feat: [bs, 1, condition_dim] 109 | time_embedding: [bs, 1, time_embedding_dim] 110 | 111 | Returns: 112 | Output contact map, [bs, num_points, contact_dim] 113 | """ 114 | 115 | # encoder 116 | enc_kv = self.encoder_adapter(x) # [bs, num_points, enc_kv_dim] 117 | cond_feat = self.condition_adapter(condition_feat) # [bs, 1, enc_q_dim] 118 | if time_embedding is not None and self.time_embedding_adapter is not None: 119 | time_embedding = self.time_embedding_adapter( 120 | time_embedding 121 | ) # [bs, 1, enc_q_dim] 122 | 123 | enc_q = torch.cat([cond_feat, time_embedding], dim=1) # [bs, 1 + 1, enc_q_dim] 124 | else: 125 | enc_q = cond_feat 126 | 127 | enc_q = self.encoder_cross_attn(enc_q, enc_kv).last_hidden_state 128 | enc_q = self.encoder_self_attn(enc_q).last_hidden_state 129 | 130 | # decoder 131 | dec_kv = enc_q 132 | dec_q = self.decoder_adapter(enc_kv) # [bs, num_points, dec_q_dim] 133 | dec_q = self.decoder_cross_attn( 134 | dec_q, dec_kv 135 | ).last_hidden_state # [bs, num_points, dec_q_dim] 136 | 137 | return dec_q 138 | 139 | 140 | if __name__ == "__main__": 141 | print("Testing ContactPerceiver") 142 | feat_preceiver = FeaturePreceiver( 143 | transition_dim=64 + 3, 144 | condition_dim=512, 145 | time_emb_dim=64, 146 | ) 147 | x = torch.randn(2, 80, 64 + 3) 148 | cond = torch.randn(2, 1, 512) 149 | time = torch.randn(2, 1, 64) 150 | feat = feat_preceiver(x, cond, time) 151 | print(feat.shape) 152 | -------------------------------------------------------------------------------- /demos/optimize_affordance.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import open3d as o3d 4 | import torch 5 | from scipy.signal import savgol_filter 6 | from 
algos.traj_optimizer import TrajectoryOptimizer 7 | from diffuser_utils.dataset_utils import visualize_3d_trajectory, load_json, visualize_points, backproject, interpolate_trajectory 8 | 9 | def run(colmap_results, data, device, visualize=True): 10 | traj, vis, vis_scene = [], [], [] 11 | 12 | # Prepare the tensors 13 | frame_ids = data["frame_ids"] 14 | intr = data["intr"] 15 | rgbs = data["rgbs"] 16 | depths = data["depths"] 17 | hand_masks = data["masks"] 18 | obj_masks = data["obj_masks"] 19 | hand_bboxes = data["hand_bboxes"] 20 | obj_bboxes = data["obj_bboxes"] 21 | rgb_tensors, depth_tensors, mask_tensors = [], [], [] 22 | for ii, fi in enumerate(frame_ids): 23 | rgb = rgbs[ii] / 255.0 24 | depth = depths[ii] 25 | mask_hand = hand_masks[ii] 26 | mask_obj = obj_masks[ii] 27 | mask_hand = cv2.dilate(mask_hand, np.ones((5, 5), np.uint8), iterations=3) 28 | mask_obj = cv2.dilate(mask_obj, np.ones((5, 5), np.uint8), iterations=3) 29 | mask_dynamic = np.logical_or(mask_hand > 0, mask_obj > 0) 30 | mask_static = (1 - mask_dynamic).astype(np.float32) 31 | rgb_tensors.append(torch.from_numpy(rgb).permute(2, 0, 1).float()) 32 | depth_tensors.append(torch.from_numpy(depth).unsqueeze(0).float()) 33 | mask_tensors.append(torch.from_numpy(mask_static).unsqueeze(0).float()) 34 | rgb_tensors = torch.stack(rgb_tensors).to(device) 35 | depth_tensors = torch.stack(depth_tensors).to(device) 36 | mask_tensors = torch.stack(mask_tensors).to(device) 37 | height, width = rgb_tensors.shape[-2:] 38 | intr = torch.from_numpy(intr).float().to(device) 39 | 40 | # Initialize the pose and scale optimizer 41 | traj_optimizer = TrajectoryOptimizer( 42 | resolution=(height, width), 43 | lr_scale_global=0.05, 44 | lr_scale=0.1, 45 | lr_pose=0.05, 46 | num_iters_scale=10, 47 | num_iters_pose=50, 48 | device=device, 49 | ) 50 | 51 | # Optimize the global scale 52 | scale_init_tensors, scale_global_final, key_idx = traj_optimizer.optimize_global_scale( 53 | rgb_tensors, 54 | depth_tensors, 55 | mask_tensors, 56 | colmap_results, 57 | ) 58 | 59 | # Optimize the pose and scale 60 | scale_init_tensors = scale_global_final * torch.ones_like(scale_init_tensors) 61 | T_kc_final, scale_final = traj_optimizer.optimize_pose( 62 | intr, 63 | rgb_tensors, 64 | depth_tensors, 65 | mask_tensors, 66 | scale_init_tensors, 67 | scale_global_final, 68 | colmap_results, 69 | key_idx=key_idx, 70 | optimize_pose=True, 71 | verbose=False, 72 | ) 73 | 74 | # Acquire the optimized results 75 | T_kc_final = T_kc_final.detach().cpu().numpy() 76 | scale_final = scale_final.detach().cpu().numpy() 77 | intr_np = intr.clone().cpu().numpy() 78 | 79 | # Pose and scale of the first frame 80 | T_kc0 = T_kc_final[0] 81 | scale_m2c0 = scale_final[0] 82 | T_kc0[:3, 3] = T_kc0[:3, 3] / scale_m2c0 83 | 84 | for ii, fi in enumerate(frame_ids): 85 | # Pose and scale of the current frame 86 | T_kc = T_kc_final[ii] 87 | scale_m2c = scale_final[ii] 88 | T_kc[:3, 3] = T_kc[:3, 3] / scale_m2c 89 | 90 | # Transformation from current frame to the first frame 91 | T_c0c = np.linalg.inv(T_kc0) @ T_kc 92 | 93 | # Get the depth and hand bbox 94 | depth = depths[ii] 95 | hand_bbox = hand_bboxes[ii] 96 | hand_bbox_mask = np.zeros_like(depth) 97 | hand_bbox_mask[hand_bbox[1] : hand_bbox[3], hand_bbox[0] : hand_bbox[2]] = 1 98 | points_hand, scene_ids = backproject(depth, intr_np, hand_bbox_mask > 0, False) 99 | points_hand = points_hand @ T_c0c[:3, :3].T + T_c0c[:3, 3] 100 | wp = np.median(points_hand, axis=0) 101 | traj.append(wp) 102 | 103 | # Acquire the hand points in 
the scene 104 | if visualize: 105 | hand_seg_mask = hand_masks[ii] 106 | hand_seg_mask = cv2.erode(hand_seg_mask, np.ones((5, 5), np.uint8), iterations=2) 107 | hand_seg_mask = hand_seg_mask * (hand_bbox_mask > 0) 108 | points_hand_scene, scene_ids = backproject(depth, intr_np, hand_seg_mask > 0, False) 109 | point_colors_scene = rgbs[ii][scene_ids[0], scene_ids[1]] 110 | point_colors_scene = point_colors_scene / 255.0 111 | points_hand_scene = points_hand_scene @ T_c0c[:3, :3].T + T_c0c[:3, 3] 112 | pcd_scene = visualize_points(points_hand_scene, point_colors_scene) 113 | vis_scene.append(pcd_scene) 114 | 115 | # Acquire the background points 116 | if ii == 0: 117 | rgb_orig = rgbs[ii] / 255.0 118 | hand_seg_mask = hand_masks[ii] 119 | points_orig, scene_ids = backproject(depth, intr_np, hand_seg_mask == 0, False) 120 | point_colors_orig = rgb_orig[scene_ids[0], scene_ids[1]] 121 | pcd_orig = visualize_points(points_orig, point_colors_orig) 122 | 123 | # Build trajectory and visualize 124 | traj = np.array(traj) 125 | traj = savgol_filter(traj, len(traj) - 1, (len(traj) + 1) // 2, axis=0) 126 | fill_indices = frame_ids - frame_ids[0] 127 | traj_smooth, _ = interpolate_trajectory(fill_indices, traj) 128 | filter_window = min(5, len(traj_smooth) - 1) 129 | traj_smooth = savgol_filter( 130 | traj_smooth, filter_window, (filter_window + 1) // 2, axis=0 131 | ) 132 | if visualize: 133 | _traj_vis = visualize_3d_trajectory(traj_smooth, size=0.05, cmap_name="viridis") 134 | traj_vis = _traj_vis[0] 135 | for v in _traj_vis[1:]: 136 | traj_vis += v 137 | 138 | hand_vis = vis_scene[0] 139 | for v in vis_scene[1:]: 140 | hand_vis += v 141 | vis = [pcd_orig, traj_vis, hand_vis] 142 | return traj_smooth, vis 143 | 144 | if __name__ == "__main__": 145 | # Load the data and prepare the tensors 146 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 147 | colmap_results_fpath = "datasets/epickitchens_traj_demo/colmap.json" 148 | data_fpath = "datasets/epickitchens_traj_demo/observation.npz" 149 | colmap_results = load_json(colmap_results_fpath) 150 | data = np.load(data_fpath) 151 | 152 | # Optimize the trajectory and visualize the results 153 | traj_smooth, vis = run(colmap_results, data, device, visualize=True) 154 | o3d.visualization.draw(vis) -------------------------------------------------------------------------------- /models/temporal.py: -------------------------------------------------------------------------------- 1 | # 2 | # Based on Diffuser: https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py 3 | # 4 | 5 | import torch 6 | import torch.nn as nn 7 | import einops 8 | from einops.layers.torch import Rearrange 9 | 10 | from models.layers_2d import ( 11 | Downsample1d, 12 | Upsample1d, 13 | Conv1dBlock, 14 | ) 15 | 16 | from models.layers_3d import SinusoidalPosEmb 17 | from models.perceiver import FeaturePerceiver 18 | 19 | 20 | class ResidualTemporalMapBlockConcat(nn.Module): 21 | 22 | def __init__( 23 | self, inp_channels, out_channels, time_embed_dim, horizon, kernel_size=5 24 | ): 25 | super().__init__() 26 | 27 | self.time_mlp = nn.Sequential( 28 | nn.Mish(), 29 | nn.Linear(time_embed_dim, out_channels), 30 | Rearrange("batch t -> batch t 1"), 31 | ) 32 | 33 | self.blocks = nn.ModuleList( 34 | [ 35 | Conv1dBlock(inp_channels, out_channels, kernel_size), 36 | Conv1dBlock(out_channels, out_channels, kernel_size), 37 | ] 38 | ) 39 | 40 | self.residual_conv = ( 41 | nn.Conv1d(inp_channels, out_channels, 1) 42 | if inp_channels != out_channels 43 | 
else nn.Identity() 44 | ) 45 | 46 | def forward(self, x, t): 47 | """ 48 | x : [ batch_size x inp_channels x horizon ] 49 | t : [ batch_size x embed_dim ] 50 | returns: 51 | out : [ batch_size x out_channels x horizon ] 52 | """ 53 | out = self.blocks[0](x) + self.time_mlp(t) 54 | out = self.blocks[1](out) 55 | return out + self.residual_conv(x) 56 | 57 | 58 | class TemporalMapUnet(nn.Module): 59 | 60 | def __init__( 61 | self, 62 | horizon, 63 | transition_dim, 64 | cond_dim, # additional dimension concatenated with the time dimension 65 | output_dim, 66 | dim=32, # time_dimesion 67 | dim_mults=(1, 2, 4, 8), 68 | use_preceiver=False, 69 | ): 70 | super().__init__() 71 | 72 | ResidualTemporalMapBlock = ResidualTemporalMapBlockConcat 73 | 74 | dims = [transition_dim, *map(lambda m: dim * m, dim_mults)] 75 | in_out = list(zip(dims[:-1], dims[1:])) 76 | 77 | time_dim = dim 78 | self.time_mlp = nn.Sequential( 79 | SinusoidalPosEmb(time_dim), 80 | nn.Linear(time_dim, time_dim * 4), 81 | nn.Mish(), 82 | nn.Linear(time_dim * 4, time_dim), 83 | ) 84 | 85 | cond_dim = cond_dim + time_dim 86 | 87 | self.downs = nn.ModuleList([]) 88 | self.ups = nn.ModuleList([]) 89 | num_resolutions = len(in_out) 90 | 91 | # Remember the property of the 1D convolution, [B, C_in, L_in] => [B, C_out, L_out] 92 | # L_out is dependent on the kernel size and stride, and L_in 93 | 94 | for ind, (dim_in, dim_out) in enumerate(in_out): 95 | is_last = ind >= (num_resolutions - 1) 96 | 97 | self.downs.append( 98 | nn.ModuleList( 99 | [ 100 | ResidualTemporalMapBlock( 101 | dim_in, dim_out, time_embed_dim=cond_dim, horizon=horizon 102 | ), # Feature dimension changes, no horizon changes 103 | ResidualTemporalMapBlock( 104 | dim_out, dim_out, time_embed_dim=cond_dim, horizon=horizon 105 | ), 106 | ( 107 | Downsample1d(dim_out) if not is_last else nn.Identity() 108 | ), # No feature dimension changes, but horizon changes 109 | ] 110 | ) 111 | ) 112 | 113 | if not is_last: 114 | horizon = horizon // 2 115 | 116 | mid_dim = dims[-1] 117 | self.mid_block1 = ResidualTemporalMapBlock( 118 | mid_dim, mid_dim, time_embed_dim=cond_dim, horizon=horizon 119 | ) 120 | self.mid_block2 = ResidualTemporalMapBlock( 121 | mid_dim, mid_dim, time_embed_dim=cond_dim, horizon=horizon 122 | ) 123 | 124 | final_up_dim = None 125 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 126 | is_last = ind >= (num_resolutions - 1) 127 | 128 | self.ups.append( 129 | nn.ModuleList( 130 | [ 131 | ResidualTemporalMapBlock( 132 | dim_out * 2, 133 | dim_in, 134 | time_embed_dim=cond_dim, 135 | horizon=horizon, 136 | ), # Feature dimension changes, no horizon changes 137 | ResidualTemporalMapBlock( 138 | dim_in, dim_in, time_embed_dim=cond_dim, horizon=horizon 139 | ), 140 | ( 141 | Upsample1d(dim_in) if not is_last else nn.Identity() 142 | ), # No feature dimension change, but horizon changes 143 | ] 144 | ) 145 | ) 146 | final_up_dim = dim_in 147 | 148 | if not is_last: 149 | horizon = horizon * 2 150 | 151 | self.final_conv = nn.Sequential( 152 | Conv1dBlock(final_up_dim, final_up_dim, kernel_size=5), 153 | nn.Conv1d(final_up_dim, output_dim, 1), 154 | ) 155 | self.use_preceiver = use_preceiver 156 | 157 | if self.use_preceiver: 158 | self.preceiver = FeaturePerceiver( 159 | transition_dim=transition_dim, 160 | condition_dim=cond_dim - time_dim, 161 | time_emb_dim=time_dim, 162 | ) 163 | self.proj = nn.Linear(self.preceiver.last_dim, transition_dim) 164 | 165 | def forward(self, x, cond, time): 166 | """ 167 | x : [ batch x horizon x transition ] 
168 | cond: [ batch x cond_dim ] 169 | time: [ batch ] 170 | """ 171 | t = self.time_mlp(time) 172 | 173 | if self.use_preceiver: 174 | x = self.preceiver(x, cond[:, None], t[:, None]) 175 | x = self.proj(x) 176 | x = einops.rearrange(x, "b h t -> b t h") 177 | t = torch.cat([t, cond], dim=-1) # [time+object+action+spatial] 178 | 179 | h = [] 180 | for ii, (resnet, resnet2, downsample) in enumerate(self.downs): 181 | x = resnet(x, t) 182 | x = resnet2(x, t) 183 | 184 | h.append(x) 185 | x = downsample( 186 | x 187 | ) # Increase the feature dimension, reduce the horizon (consider the spatial resolution in image) 188 | # print("Downsample step {}, with shape {}".format(ii, x.shape)) # [B, C, H] 189 | # print(f"[ models/temporal ] Downsample step {ii}, with shape {x.shape}") 190 | 191 | x = self.mid_block1(x, t) 192 | x = self.mid_block2(x, t) 193 | for ii, (resnet, resnet2, upsample) in enumerate(self.ups): 194 | x = torch.cat((x, h.pop()), dim=1) 195 | x = resnet(x, t) 196 | x = resnet2(x, t) 197 | x = upsample(x) # Decrease the feature dimension, increase the horizon 198 | # print("Upsample step {}, with shape {}".format(ii, x.shape)) # [B, C, H] 199 | 200 | x = self.final_conv(x) 201 | x = einops.rearrange(x, "b t h -> b h t") 202 | return x 203 | 204 | 205 | if __name__ == "__main__": 206 | model = TemporalMapUnet( 207 | horizon=80, # time horizon 208 | transition_dim=67, # dimension of the input trajectory 209 | cond_dim=32, # dimension of the condition (from image, depth, text, etc.) 210 | output_dim=3, # dimension of the output trajectory 211 | dim=32, # base feature dimension 212 | dim_mults=(2, 4, 8), # number of the layers 213 | use_preceiver=True, 214 | ) 215 | x = torch.randn(2, 80, 67) 216 | cond = torch.randn(2, 32) 217 | time = torch.randn(2) 218 | print("Input shape: ", x.permute(0, 2, 1).shape) # [B, input_dim, H] 219 | out = model(x, cond, time) # [B, dim', H'] 220 | print("Output shape: ", out.permute(0, 2, 1).shape) # [B, outpu_dim, H] 221 | 222 | # Input shape: torch.Size([2, 3, 20]) 223 | # Downsample step 0, with shape torch.Size([2, 64, 10]) 224 | # Downsample step 1, with shape torch.Size([2, 128, 5]) 225 | # Downsample step 2, with shape torch.Size([2, 256, 5]) 226 | # Upsample step 0, with shape torch.Size([2, 128, 10]) 227 | # Upsample step 1, with shape torch.Size([2, 64, 20]) 228 | # Output shape: torch.Size([2, 3, 20]) 229 | -------------------------------------------------------------------------------- /models/goal.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import einops 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import tqdm 7 | import numpy as np 8 | from models.layers_2d import Decoder, Encoder 9 | from models.helpers import FocalLoss 10 | from models.clip import clip, tokenize 11 | 12 | import torchvision.models as models 13 | import torchvision.transforms as transforms 14 | from torchvision.ops import roi_align, roi_pool 15 | from copy import deepcopy 16 | from models.layers_2d import load_clip 17 | from models.perceiver import FeaturePerceiver 18 | 19 | 20 | class GoalPredictor(pl.LightningModule): 21 | def __init__( 22 | self, 23 | in_channels=4, 24 | out_channels=3, 25 | resolution=[256, 448], 26 | channel_multiplier=[1, 2, 4, 8, 16], 27 | bbox_feature_dim=64, 28 | visual_feature_dim=512, 29 | encode_action=False, 30 | encode_bbox=False, 31 | encode_object=False, 32 | num_heads_attention=4, 33 | num_layers_attention=2, 34 | 
object_encode_mode="roi_pool", 35 | **kwargs, 36 | ): 37 | super().__init__() 38 | 39 | # self.gpt = gpt 40 | self.in_channels = in_channels 41 | self.out_channels = out_channels 42 | self.resolution = resolution 43 | 44 | self.encode_action = encode_action 45 | self.encode_bbox = encode_bbox 46 | self.encode_object = encode_object 47 | 48 | self.visual_feature_dim = visual_feature_dim 49 | self.bbox_feature_dim = bbox_feature_dim 50 | self.object_encode_mode = object_encode_mode 51 | 52 | # ch_mult=[1, 2, 4, 8, 16] 53 | self.channel_multiplier = channel_multiplier 54 | 55 | self.downscale_factor = 2 ** (len(self.channel_multiplier) - 1) 56 | attn_resolutions = ( 57 | resolution[0] // self.downscale_factor, 58 | resolution[1] // self.downscale_factor, 59 | ) 60 | 61 | # Decide the model architecture 62 | self.visual = Encoder( 63 | ch=64, 64 | ch_mult=self.channel_multiplier, 65 | num_res_blocks=2, 66 | attn_resolutions=attn_resolutions, 67 | in_channels=in_channels, 68 | out_ch=out_channels, 69 | resolution=resolution, 70 | double_z=False, 71 | z_channels=self.visual_feature_dim, 72 | ) 73 | 74 | self.decoder = Decoder( 75 | ch=64, 76 | ch_mult=self.channel_multiplier, 77 | num_res_blocks=2, 78 | attn_resolutions=attn_resolutions, 79 | in_channels=in_channels, 80 | out_ch=out_channels, 81 | resolution=resolution, 82 | double_z=False, 83 | z_channels=self.visual_feature_dim, 84 | ) 85 | 86 | self.transform = transforms.Compose( 87 | [ 88 | transforms.Normalize([0.485, 0.456, 0.406], [ 89 | 0.229, 0.224, 0.225]), 90 | ] 91 | ) 92 | 93 | if self.encode_object: 94 | if self.object_encode_mode == "vlm": 95 | obj_dim = self.vlm_dim 96 | else: 97 | obj_dim = self.visual_feature_dim 98 | 99 | self.object_encode_module = nn.Linear( 100 | obj_dim, self.visual_feature_dim) 101 | 102 | if self.encode_action: 103 | 104 | self.action_encode_module = nn.Linear( 105 | self.visual_feature_dim, self.visual_feature_dim 106 | ) 107 | 108 | if self.encode_bbox: 109 | 110 | self.bbox_encode_module = nn.Linear(4, bbox_feature_dim) 111 | 112 | fuser_dim = 0 113 | if self.encode_action: 114 | fuser_dim += self.visual_feature_dim 115 | if self.encode_object: 116 | fuser_dim += self.visual_feature_dim 117 | if self.encode_bbox: 118 | fuser_dim += self.bbox_feature_dim 119 | 120 | if self.encode_action or self.encode_object or self.encode_bbox: 121 | 122 | self.fuser = FeaturePerceiver( 123 | transition_dim=self.visual_feature_dim, 124 | condition_dim=fuser_dim, 125 | time_emb_dim=0, 126 | ) 127 | self.proj = nn.Linear(self.fuser.last_dim, self.visual_feature_dim) 128 | 129 | else: 130 | self.fuser = None 131 | 132 | self.depth_fuser = FeaturePerceiver( 133 | transition_dim=self.visual_feature_dim, 134 | condition_dim=self.visual_feature_dim, 135 | time_emb_dim=0, # No time embedding 136 | ) 137 | self.depth_proj = nn.Sequential( 138 | nn.Linear(self.depth_fuser.last_dim, self.visual_feature_dim), 139 | nn.TransformerEncoder( 140 | nn.TransformerEncoderLayer( 141 | d_model=self.visual_feature_dim, 142 | dim_feedforward=512, 143 | nhead=4, 144 | batch_first=True, 145 | ), 146 | num_layers=3, 147 | ), 148 | nn.Linear(self.visual_feature_dim, 1), 149 | ) 150 | 151 | def forward(self, data_batch, training=False): 152 | target_key = "vfd" 153 | color_key = "color" 154 | object_color_key = "object_color" 155 | 156 | if training: 157 | color_key += "_aug" 158 | object_color_key += "_aug" 159 | 160 | inputs = self.transform(data_batch[color_key]) 161 | depth = data_batch["depth"][:, None] # [B, 1, H, W] 162 | 163 | if 
self.in_channels == 4: 164 | start_pos_depth = data_batch["start_pos_depth"][:, None] 165 | start_pos_depth_res = depth - start_pos_depth 166 | inputs = torch.cat([inputs, start_pos_depth_res], dim=1) 167 | 168 | # Shape information 169 | batch_size = inputs.shape[0] 170 | h_in, w_in = inputs.shape[-2:] 171 | h_out, w_out = h_in // self.downscale_factor, w_in // self.downscale_factor 172 | 173 | # Box information 174 | bbox = data_batch["bbox"] # [B, 4] 175 | bbox_batch_id = torch.arange( 176 | batch_size, device=inputs.device 177 | ) # Only one box per sample 178 | bbox = torch.cat([bbox_batch_id[:, None], bbox], dim=1) # [B, 5] 179 | 180 | # Extract visual features 181 | context_feature = self.visual(inputs) 182 | context_feature = einops.rearrange( 183 | context_feature.clone(), "b c h w -> b (h w) c", h=h_out, w=w_out 184 | ) 185 | feature = context_feature 186 | 187 | # Acquire bbox features 188 | condition_feature = [] 189 | if self.encode_object: 190 | if self.object_encode_mode == "vlm": 191 | 192 | print("... Try another way to encode object") 193 | elif self.object_encode_mode in ["roi_pool", "roi_align"]: 194 | roi_res = 6 195 | 196 | roi_method = eval(self.object_encode_mode) 197 | context_feature_obj = einops.rearrange( 198 | context_feature, "b (h w) c -> b c h w", h=h_out, w=w_out 199 | ) 200 | spatial_scale = context_feature_obj.shape[-1] / \ 201 | inputs.shape[-1] 202 | assert spatial_scale == context_feature_obj.shape[-2] / \ 203 | inputs.shape[-2] 204 | 205 | context_feature_obj = roi_method( 206 | context_feature_obj, 207 | bbox, 208 | spatial_scale=spatial_scale, 209 | output_size=(roi_res, roi_res), 210 | ) # [B, c, roi_res, roi_res] 211 | object_feature = einops.rearrange( 212 | context_feature_obj, "b c h w -> b (h w) c" 213 | ) 214 | object_feature = object_feature.mean( 215 | dim=1)[:, None] # [B, 1, c] 216 | else: 217 | raise NotImplementedError( 218 | "Object encode mode {} not implemented".format( 219 | self.object_encode_mode 220 | ) 221 | ) 222 | object_feature = self.object_encode_module(object_feature) 223 | 224 | condition_feature.append(object_feature) 225 | 226 | if self.encode_action: 227 | 228 | action_feature = data_batch["action_feature"][:, None] 229 | action_feature = self.action_encode_module( 230 | action_feature) # [B, 1, c] 231 | 232 | condition_feature.append(action_feature) 233 | 234 | if self.encode_bbox: 235 | bbox_norm = bbox[:, 1:].clone() 236 | bbox_norm[:, [0, 2]] = bbox_norm[:, [0, 2]] / inputs.shape[-1] 237 | bbox_norm[:, [1, 3]] = bbox_norm[:, [1, 3]] / inputs.shape[-2] 238 | bbox_feature = self.bbox_encode_module(bbox_norm) # [B, 64] 239 | 240 | bbox_feature = bbox_feature[:, None] # [B, 1, 64] 241 | condition_feature.append(bbox_feature) 242 | 243 | condition_feature = torch.cat( 244 | condition_feature, dim=-1) # [B, 1, 2c+64) 245 | 246 | if self.fuser is not None and len(condition_feature) > 0: 247 | feature = self.fuser(feature, condition_feature) 248 | feature = self.proj(feature) 249 | 250 | depth_feature = feature 251 | verb_feature = data_batch["verb_feature"][:, None] 252 | 253 | depth_feature = self.depth_fuser(depth_feature, verb_feature) 254 | pred_depth = self.depth_proj(depth_feature) 255 | 256 | feature = einops.rearrange( 257 | feature, "b (h w) c -> b c h w", h=h_out, w=w_out) 258 | pred_depth = einops.rearrange( 259 | pred_depth, "b (h w) c -> b c h w", h=h_out, w=w_out 260 | ) 261 | 262 | pred = self.decoder(feature) 263 | pred_depth = F.interpolate( 264 | pred_depth, size=( 265 | self.resolution[0], 
self.resolution[1]), mode="bilinear" 266 | ) 267 | # Postprocess the output 268 | if pred is not None: 269 | pred_vf, pred_dres, pred_heatmap = pred[:, 270 | :2], pred_depth, pred[:, -1:] 271 | pred_vf = F.normalize(pred_vf, p=2, dim=1) # [-1, 1] 272 | pred_vf = pred_vf.clamp(-1, 1) 273 | if "d_res_scale" in data_batch: 274 | print( 275 | "... Rescale the depth with d_res_scale: ", 276 | data_batch["d_res_scale"], 277 | ) 278 | pred_dres = pred_dres * data_batch["d_res_scale"] 279 | pred_d_final = start_pos_depth - pred_dres # [B, 1, H, W] 280 | pred = torch.cat( 281 | [pred_vf, pred_d_final, pred_heatmap], dim=1 282 | ) # [B, 4, H, W] 283 | 284 | outputs = {"pred": pred} 285 | 286 | return outputs 287 | -------------------------------------------------------------------------------- /models/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from models.clip.model import build_model 14 | from models.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | 19 | BICUBIC = InterpolationMode.BICUBIC 20 | except ImportError: 21 | BICUBIC = Image.BICUBIC 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", # 243MB 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", # 277MB 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", #401MB 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", # 629MB 34 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", # 1.2GB 35 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", # 334MB 36 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", # 334MB 37 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 38 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 39 | } 40 | 41 | 42 | def _download(url: str, root: str): 43 | os.makedirs(root, exist_ok=True) 44 | filename = os.path.basename(url) 45 | 46 | expected_sha256 = url.split("/")[-2] 47 | download_target = os.path.join(root, filename) 48 | 49 | if os.path.exists(download_target) and not os.path.isfile(download_target): 50 | raise RuntimeError(f"{download_target} exists and is not a regular file") 51 | 52 | if 
os.path.isfile(download_target): 53 | if ( 54 | hashlib.sha256(open(download_target, "rb").read()).hexdigest() 55 | == expected_sha256 56 | ): 57 | return download_target 58 | else: 59 | warnings.warn( 60 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" 61 | ) 62 | 63 | print(f"Downloading CLIP model from {url}") 64 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 65 | with tqdm( 66 | total=int(source.info().get("Content-Length")), 67 | ncols=80, 68 | unit="iB", 69 | unit_scale=True, 70 | unit_divisor=1024, 71 | ) as loop: 72 | while True: 73 | buffer = source.read(8192) 74 | if not buffer: 75 | break 76 | 77 | output.write(buffer) 78 | loop.update(len(buffer)) 79 | 80 | if ( 81 | hashlib.sha256(open(download_target, "rb").read()).hexdigest() 82 | != expected_sha256 83 | ): 84 | raise RuntimeError( 85 | "Model has been downloaded but the SHA256 checksum does not not match" 86 | ) 87 | 88 | return download_target 89 | 90 | 91 | def _convert_image_to_rgb(image): 92 | return image.convert("RGB") 93 | 94 | 95 | def _transform(n_px): 96 | return Compose( 97 | [ 98 | # Resize(n_px, interpolation=BICUBIC), 99 | # CenterCrop(n_px), 100 | # _convert_image_to_rgb, 101 | # ToTensor(), 102 | Normalize( 103 | (0.48145466, 0.4578275, 0.40821073), 104 | (0.26862954, 0.26130258, 0.27577711), 105 | ), 106 | ] 107 | ) 108 | 109 | 110 | def available_models() -> List[str]: 111 | """Returns the names of available CLIP models""" 112 | return list(_MODELS.keys()) 113 | 114 | 115 | TORCH_HUB_ROOT = os.path.expandvars(os.getenv("$TORCH_HUB_ROOT", "$HOME/.torch_hub")) 116 | 117 | 118 | def load( 119 | name: str, 120 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 121 | jit: bool = False, 122 | download_root: str = None, 123 | ): 124 | """Load a CLIP model 125 | 126 | Parameters 127 | ---------- 128 | name : str 129 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 130 | 131 | device : Union[str, torch.device] 132 | The device to put the loaded model 133 | 134 | jit : bool 135 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 136 | 137 | download_root: str 138 | path to download the model files; by default, it uses "~/.torch_hub/clip" 139 | 140 | Returns 141 | ------- 142 | model : torch.nn.Module 143 | The CLIP model 144 | 145 | preprocess : Callable[[PIL.Image], torch.Tensor] 146 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 147 | """ 148 | if name in _MODELS: 149 | model_path = _download(_MODELS[name], download_root or TORCH_HUB_ROOT) 150 | elif os.path.isfile(name): 151 | model_path = name 152 | else: 153 | raise RuntimeError( 154 | f"Model {name} not found; available models = {available_models()}" 155 | ) 156 | 157 | with open(model_path, "rb") as opened_file: 158 | try: 159 | # loading JIT archive 160 | model = torch.jit.load( 161 | opened_file, map_location=device if jit else "cpu" 162 | ).eval() 163 | state_dict = None 164 | except RuntimeError: 165 | # loading saved state dict 166 | if jit: 167 | warnings.warn( 168 | f"File {model_path} is not a JIT archive. 
Loading as a state dict instead" 169 | ) 170 | jit = False 171 | state_dict = torch.load(opened_file, map_location="cpu") 172 | 173 | if not jit: 174 | model = build_model(state_dict or model.state_dict()).to(device) 175 | if str(device) == "cpu": 176 | model.float() 177 | return model, _transform(model.visual.input_resolution) 178 | 179 | # patch the device names 180 | device_holder = torch.jit.trace( 181 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] 182 | ) 183 | device_node = [ 184 | n 185 | for n in device_holder.graph.findAllNodes("prim::Constant") 186 | if "Device" in repr(n) 187 | ][-1] 188 | 189 | def patch_device(module): 190 | try: 191 | graphs = [module.graph] if hasattr(module, "graph") else [] 192 | except RuntimeError: 193 | graphs = [] 194 | 195 | if hasattr(module, "forward1"): 196 | graphs.append(module.forward1.graph) 197 | 198 | for graph in graphs: 199 | for node in graph.findAllNodes("prim::Constant"): 200 | if "value" in node.attributeNames() and str(node["value"]).startswith( 201 | "cuda" 202 | ): 203 | node.copyAttributes(device_node) 204 | 205 | model.apply(patch_device) 206 | patch_device(model.encode_image) 207 | patch_device(model.encode_text) 208 | 209 | # patch dtype to float32 on CPU 210 | if str(device) == "cpu": 211 | float_holder = torch.jit.trace( 212 | lambda: torch.ones([]).float(), example_inputs=[] 213 | ) 214 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 215 | float_node = float_input.node() 216 | 217 | def patch_float(module): 218 | try: 219 | graphs = [module.graph] if hasattr(module, "graph") else [] 220 | except RuntimeError: 221 | graphs = [] 222 | 223 | if hasattr(module, "forward1"): 224 | graphs.append(module.forward1.graph) 225 | 226 | for graph in graphs: 227 | for node in graph.findAllNodes("aten::to"): 228 | inputs = list(node.inputs()) 229 | for i in [ 230 | 1, 231 | 2, 232 | ]: # dtype can be the second or third argument to aten::to() 233 | if inputs[i].node()["value"] == 5: 234 | inputs[i].node().copyAttributes(float_node) 235 | 236 | model.apply(patch_float) 237 | patch_float(model.encode_image) 238 | patch_float(model.encode_text) 239 | 240 | model.float() 241 | 242 | return model, _transform(model.input_resolution.item()) 243 | 244 | 245 | def tokenize( 246 | texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False 247 | ) -> Union[torch.IntTensor, torch.LongTensor]: 248 | """ 249 | Returns the tokenized representation of given input string(s) 250 | 251 | Parameters 252 | ---------- 253 | texts : Union[str, List[str]] 254 | An input string or a list of input strings to tokenize 255 | 256 | context_length : int 257 | The context length to use; all CLIP models use 77 as the context length 258 | 259 | truncate: bool 260 | Whether to truncate the text in case its encoding is longer than the context length 261 | 262 | Returns 263 | ------- 264 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 265 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
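
    A minimal illustrative example (the values are assumptions for illustration;
    the output shape follows from the default context_length of 77):

        tokens = tokenize(["open the drawer", "pick up the cup"])
        tokens.shape  # torch.Size([2, 77])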
266 | """ 267 | if isinstance(texts, str): 268 | texts = [texts] 269 | 270 | sot_token = _tokenizer.encoder["<|startoftext|>"] 271 | eot_token = _tokenizer.encoder["<|endoftext|>"] 272 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 273 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 274 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 275 | else: 276 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 277 | 278 | for i, tokens in enumerate(all_tokens): 279 | if len(tokens) > context_length: 280 | if truncate: 281 | tokens = tokens[:context_length] 282 | tokens[-1] = eot_token 283 | else: 284 | raise RuntimeError( 285 | f"Input {texts[i]} is too long for context length {context_length}" 286 | ) 287 | result[i, : len(tokens)] = torch.tensor(tokens) 288 | 289 | return result 290 | -------------------------------------------------------------------------------- /demos/infer_affordance.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | import os 4 | import pytorch_lightning as pl 5 | from omegaconf import OmegaConf 6 | import torch 7 | import numpy as np 8 | import cv2 9 | from diffuser_utils.guidance_loss import DiffuserGuidance 10 | from diffuser_utils.guidance_params import GUIDANCE_PARAMS_DICT 11 | import diffuser_utils.dataset_utils as DatasetUtils 12 | from models.clip import clip 13 | import open3d as o3d 14 | from algos.afford_algos import AffordanceInferenceEngine 15 | from easydict import EasyDict as edict 16 | from copy import deepcopy 17 | from transformations import rotation_matrix 18 | 19 | # CLIP-based action text encoder 20 | VLM, VLM_TRANSFORM = clip.load("ViT-B/16", jit=False) 21 | VLM.float() 22 | VLM.eval() 23 | VLM.cuda() 24 | for p in VLM.parameters(): 25 | p.requires_grad = False 26 | 27 | 28 | def main(args): 29 | # Parse the instruction 30 | config = edict(OmegaConf.to_container(OmegaConf.load(args.config))) 31 | gripper_mesh = o3d.io.read_triangle_mesh(config.gripper_mesh_file) 32 | gripper_mesh.compute_vertex_normals() 33 | gripper_mesh.paint_uniform_color([255 / 255.0, 192 / 255.0, 203 / 255.0]) 34 | gripper_mesh.rotate(rotation_matrix( 35 | np.pi/2, [0, 0, 1])[:3, :3], center=[0, 0, 0]) 36 | 37 | # Parse the instruction 38 | frame_id = args.frame 39 | action = args.instruction.split(" ")[0] 40 | object_name = args.object 41 | if args.object == "" and not args.load_results: 42 | print("No object name provided, please provide an object name for the detector to work") 43 | return 44 | 45 | dataset_path = os.path.join(config.dataset_dir, args.dataset) 46 | camera_info = DatasetUtils.load_json( 47 | os.path.join(dataset_path, "camera_intrinsic.json") 48 | ) 49 | intr = np.array(camera_info["intrinsic_matrix"] 50 | ).reshape(3, 3).astype(np.float32).T 51 | 52 | if args.scale > 0: 53 | calib_scale = args.scale 54 | else: 55 | calib_scale = 1.5 * 640 / intr[0, 0] 56 | 57 | # Set the parameters for guidance 58 | if action not in GUIDANCE_PARAMS_DICT: 59 | print("WARNING: Action [{}] not found in guidance params, quality not guaranteed".format( 60 | action)) 61 | guidance_params = GUIDANCE_PARAMS_DICT["other"] 62 | else: 63 | guidance_params = GUIDANCE_PARAMS_DICT[action] 64 | 65 | goal_weight = guidance_params["goal_weight"] 66 | noncollide_weight = guidance_params["noncollide_weight"] 67 | normal_weight = guidance_params["normal_weight"] 68 | contact_weight = 
guidance_params["contact_weight"] 69 | fine_voxel_resolution = guidance_params["fine_voxel_resolution"] 70 | exclude_object_points = guidance_params["exclude_object_points"] 71 | print("Using guidance params: ", guidance_params) 72 | 73 | # Read the data 74 | depth_file_path = os.path.join( 75 | dataset_path, "depth_m3d", "{:06d}.png".format(frame_id)) 76 | if not os.path.exists(depth_file_path): 77 | depth_file_path = os.path.join( 78 | dataset_path, "depth", "{:06d}.png".format(frame_id)) 79 | 80 | color_file_path = os.path.join( 81 | dataset_path, "color", "{:06d}.png".format(frame_id)) 82 | if not os.path.exists(color_file_path): 83 | color_file_path = color_file_path.replace(".png", ".jpg") 84 | depth = cv2.imread(depth_file_path, -1) 85 | depth = depth / 1000.0 86 | color = cv2.imread(color_file_path, -1)[..., [2, 1, 0]].copy() 87 | depth[depth > 2] = 0 88 | data_batch = DatasetUtils.get_context_data_from_rgbd( 89 | color, 90 | depth, 91 | intr, 92 | voxel_resolution=32, 93 | fine_voxel_resolution=fine_voxel_resolution, 94 | ) 95 | 96 | # Initialize the inference engine 97 | config_traj_path = config.config_traj 98 | config_goal_path = config.config_goal 99 | config_contact_path = config.config_contact 100 | cfg_traj = edict(OmegaConf.to_container(OmegaConf.load(config_traj_path))) 101 | cfg_goal = edict(OmegaConf.to_container(OmegaConf.load(config_goal_path))) 102 | cfg_contact = edict(OmegaConf.to_container( 103 | OmegaConf.load(config_contact_path))) 104 | cfg_traj.TEST.ckpt_path = config.traj_ckpt 105 | cfg_goal.TEST.ckpt_path = config.goal_ckpt 106 | cfg_contact.TEST.ckpt_path = config.contact_ckpt 107 | cfg_traj.TEST.num_samples = 40 108 | cfg_traj.TEST.class_free_guide_weight = -0.7 109 | 110 | use_detector, use_esam = True, True 111 | 112 | if args.skip_coarse_stage: 113 | assert ( 114 | args.load_results 115 | ), "Must load results if skipping coarse affordance prediction" 116 | cfg_goal, cfg_contact = None, None 117 | 118 | if args.load_results: 119 | use_detector, use_esam = False, False 120 | 121 | diffuser_guidance = DiffuserGuidance( 122 | goal_weight=goal_weight, # 100.0 123 | noncollide_weight=noncollide_weight, 124 | contact_weight=contact_weight, 125 | normal_weight=normal_weight, 126 | scale=calib_scale, 127 | exclude_object_points=exclude_object_points, 128 | valid_horizon=-1, 129 | ) 130 | 131 | inference_engine = AffordanceInferenceEngine( 132 | traj_config=cfg_traj, 133 | goal_config=cfg_goal, 134 | contact_config=cfg_contact, 135 | traj_guidance=diffuser_guidance, 136 | use_detector=use_detector, 137 | use_esam=use_esam, 138 | use_graspnet=args.use_graspnet, 139 | detector_config=config.config_detector, 140 | detector_ckpt=config.detector_ckpt, 141 | esam_ckpt=config.esam_ckpt, 142 | graspnet_ckpt=config.graspnet_ckpt, 143 | ) 144 | 145 | for k, v in data_batch.items(): 146 | if isinstance(v, torch.Tensor): 147 | data_batch[k] = v.unsqueeze(0).cuda() 148 | _, null_embeddings = DatasetUtils.encode_text_clip( 149 | VLM, [""], max_length=None, device="cuda" 150 | ) 151 | data_batch["action_feature_null"] = null_embeddings.cuda() 152 | data_batch["action_text"] = args.instruction 153 | data_batch["gripper_points"] = torch.from_numpy( 154 | np.asarray(gripper_mesh.sample_points_uniformly( 155 | number_of_points=2048).points).astype(np.float32) 156 | ).cuda() 157 | 158 | # Run the open-vocabulary object detection 159 | meta_name = "{:06d}_{}.npz".format( 160 | frame_id, args.instruction.replace(" ", "-")) 161 | meta_save_path = os.path.join(dataset_path, 
"scene_meta", meta_name) 162 | os.makedirs(os.path.dirname(meta_save_path), exist_ok=True) 163 | if args.load_results: 164 | assert os.path.exists(meta_save_path), "Results {} not found".format( 165 | meta_save_path 166 | ) 167 | AffordanceInferenceEngine.load_results(meta_save_path, data_batch) 168 | else: 169 | inference_engine.forward_detect(data_batch, text=object_name) 170 | AffordanceInferenceEngine.export_results(meta_save_path, data_batch) 171 | 172 | num_objects = len(data_batch["bbox_all"][0]) 173 | if num_objects == 0: 174 | print("No objects detected") 175 | return 176 | for n in range(num_objects): 177 | outputs = {} 178 | results_name = "{:06d}_{:06d}_{}.npz".format( 179 | frame_id, n, args.instruction.replace(" ", "-") 180 | ) 181 | results_save_path = os.path.join( 182 | dataset_path, "prediction", results_name) 183 | if args.load_results: 184 | assert os.path.exists(results_save_path), "Results {} not found".format( 185 | results_save_path 186 | ) 187 | AffordanceInferenceEngine.load_results( 188 | results_save_path, data_batch) 189 | else: 190 | # Set the data for the object 191 | label_text = data_batch["label_text_all"][n] 192 | data_batch["bbox"] = data_batch["bbox_all"][:, n] 193 | data_batch["cropped_intr"] = data_batch["cropped_intr_all"][:, n] 194 | data_batch["object_mask"] = data_batch["object_mask_all"][:, n] 195 | data_batch["object_bbox_mask"] = data_batch["object_bbox_mask_all"][:, n] 196 | 197 | data_batch["object_color"] = data_batch["object_color_all"][:, n] 198 | data_batch["object_depth"] = data_batch["object_depth_all"][:, n] 199 | data_batch["object_points"] = data_batch["object_points_all"][:, n] 200 | data_batch["resize_ratio"] = data_batch["resize_ratio_all"][:, n] 201 | data_batch["label_text"] = label_text 202 | 203 | # Inference 204 | inference_engine.encode_action( 205 | data_batch, clip_model=VLM, max_length=20) 206 | if not args.skip_coarse_stage: 207 | inference_engine.forward_contact( 208 | data_batch, 209 | outputs, 210 | solve_vf=False, 211 | update_data_batch=True, 212 | sample_num=100, 213 | ) 214 | inference_engine.forward_goal( 215 | data_batch, 216 | outputs, 217 | solve_vf=False, 218 | update_data_batch=True, 219 | sample_num=100, 220 | ) 221 | 222 | inference_engine.compute_object_contact_normal( 223 | data_batch 224 | ) 225 | inference_engine.compute_object_grasp_pose( 226 | data_batch, collision_thresh=0.25) 227 | 228 | if not args.skip_fine_stage: 229 | inference_engine.forward_traj( 230 | data_batch, 231 | outputs, 232 | radii=0.65, 233 | scale=calib_scale, 234 | use_guidance=True, 235 | update_data_batch=True, 236 | ) 237 | data_batch["pred_trajectories"] = data_batch["pred_trajectories"][:, :, :60] 238 | inference_engine.smooth_traj(data_batch) 239 | 240 | # Save trajectories and guidance losses 241 | if not args.no_save: 242 | os.makedirs(os.path.dirname(results_save_path), exist_ok=True) 243 | AffordanceInferenceEngine.export_results( 244 | results_save_path, data_batch) 245 | 246 | if args.visualize: 247 | assert not args.skip_fine_stage, "Cannot visualize fine stage if it is skipped" 248 | # Update the predicted trajectories 249 | pred_trajs = data_batch["pred_trajectories"] # [B, N, H, 3] 250 | vis_o3d = inference_engine.nets["traj"].visualize_trajectory_by_rendering( 251 | data_batch, "configs/_render.json", window=False, return_vis=True 252 | ) 253 | 254 | if "guide_losses" in data_batch: 255 | pred_trajs_loss = data_batch["guide_losses"]["total_loss"].detach( 256 | ) 257 | traj_loss_colors = DatasetUtils.get_heatmap( 
258 |                 pred_trajs_loss.cpu().numpy(), cmap_name="turbo"
259 |             )
260 |             best_traj_idx = np.argmin(pred_trajs_loss.cpu().numpy())
261 |             best_traj = pred_trajs[0,
262 |                                    best_traj_idx].cpu().numpy().squeeze()
263 |             for i in range(1, len(vis_o3d)):
264 |                 vis_o3d[i].paint_uniform_color(
265 |                     traj_loss_colors[0, i - 1, :])
266 |         else:
267 |             traj_loss_colors = None
268 |         vis_o3d_trajs = vis_o3d[1]
269 |         for _vis_traj in vis_o3d[2:]:
270 |             vis_o3d_trajs += _vis_traj
271 | 
272 |         if "grasp_pose" in data_batch:
273 |             gripper_colors = DatasetUtils.get_heatmap(
274 |                 np.arange(len(best_traj))[None], cmap_name="plasma"
275 |             )[0]
276 |             grasp_pose = data_batch["grasp_pose"].cpu().numpy().squeeze()
277 |             gripper_mesh_init = deepcopy(gripper_mesh)
278 |             gripper_mesh_init.transform(grasp_pose)
279 |             gripper_mesh_init.paint_uniform_color(gripper_colors[10])
280 |             vis_o3d_best = gripper_mesh_init
281 | 
282 |             for hi in range(15, len(best_traj)):
283 |                 if hi % 20 == 0:
284 |                     gripper_mesh_hi = deepcopy(gripper_mesh)
285 |                     gripper_mesh_hi.transform(grasp_pose)
286 |                     gripper_mesh_hi.translate(best_traj[hi] - best_traj[0])
287 |                     gripper_mesh_hi.paint_uniform_color(gripper_colors[hi])
288 |                     vis_o3d_best += gripper_mesh_hi
289 |         vis_o3d_final = [vis_o3d[0], vis_o3d_trajs, vis_o3d_best]
290 | 
291 |         o3d.visualization.draw(
292 |             vis_o3d_final, title="Affordance Visualization for Instruction: {} (Object {})".format(
293 |                 args.instruction, n)
294 |         )
295 | 
296 | 
297 | if __name__ == "__main__":
298 |     parser = argparse.ArgumentParser()
299 |     parser.add_argument(
300 |         "--config",
301 |         type=str,
302 |         default="./config/test_config.yaml",
303 |         help="Path to the test config file that specifies the checkpoint paths, model config files, and the dataset directory"
304 |     )
305 |     parser.add_argument("-d", "--dataset", type=str,
306 |                         default="vidbot_data_demo", help="Dataset name")
307 |     parser.add_argument("-f", "--frame", type=int,
308 |                         default=0, help="Frame index")
309 |     parser.add_argument("-i", "--instruction", type=str,
310 |                         default="open drawer",
311 |                         help="Instruction fed to the affordance model; use the format 'verb object' (e.g. 'open drawer'), where the verb is a single word")
312 |     parser.add_argument("-o", "--object", type=str, default="",
313 |                         help="Object class name fed to the detector")
314 |     parser.add_argument("-v", "--visualize",
315 |                         action="store_true", help="Visualize the results")
316 |     parser.add_argument("-s", "--scale", type=float,
317 |                         default=-1, help="Scale of the trajectory; set to -1 to use the default scale derived from the camera intrinsics")
318 |     parser.add_argument("--no_save", action="store_true",
319 |                         help="Do not save the results")
320 |     parser.add_argument("--load_results", action="store_true",
321 |                         help="Load previously saved results from disk; use this if you don't want to install the detector")
322 |     parser.add_argument("--skip_coarse_stage",
323 |                         action="store_true", help="Skip the coarse stage; requires --load_results")
324 |     parser.add_argument("--skip_fine_stage",
325 |                         action="store_true", help="Skip the fine stage; make sure to also set --load_results")
326 |     parser.add_argument("--use_graspnet", action="store_true",
327 |                         help="Use GraspNet for grasp pose detection; pass this only if GraspNet is installed, otherwise a heuristic approach is used to acquire the grasp pose")
328 |     args = parser.parse_args()
329 |     pl.seed_everything(42)
330 | 
331 |     main(args)
332 | 
--------------------------------------------------------------------------------
/algos/traj_optimizer.py:
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from scipy.signal import savgol_filter 4 | from models.layers_2d import BackprojectDepth, Project3D 5 | from pytorch3d.transforms import rotation_6d_to_matrix, matrix_to_rotation_6d 6 | import numpy as np 7 | import cv2 8 | 9 | class TrajectoryOptimizer: 10 | def __init__( 11 | self, 12 | resolution=(256, 456), 13 | lr_scale_global=0.05, 14 | lr_scale=0.1, 15 | lr_pose=0.05, 16 | num_iters_scale=10, 17 | num_iters_pose=100, 18 | warp_mode="points", 19 | device="cuda", 20 | ): 21 | self.height, self.width = resolution[0], resolution[1] 22 | self.device = device 23 | self.lr_scale_global = lr_scale_global 24 | self.lr_scale = lr_scale 25 | self.lr_pose = lr_pose 26 | self.num_iters_scale = num_iters_scale 27 | self.num_iters_pose = num_iters_pose 28 | self.backproject_depth = BackprojectDepth(self.height, self.width).to(device) 29 | self.project_3d = Project3D().to(device) 30 | self.warp_mode = warp_mode 31 | 32 | def compute_warped_results( 33 | self, 34 | intrinsics, 35 | rgb_tensors, 36 | depth_tensors, 37 | mask_tensors, 38 | scale_tensors, 39 | rgb_key, 40 | depth_key, 41 | mask_key, 42 | scake_key, 43 | T_kc_tensors, 44 | mode, 45 | return_color=False, 46 | verbose=False, 47 | ): 48 | N, _, height, width = depth_tensors.shape 49 | depth_key = depth_key * scake_key[:, None, None, None] # To Colmap space 50 | depth_key = depth_key.repeat(N, 1, 1, 1) 51 | mask_key = mask_key.repeat(N, 1, 1, 1) 52 | rgb_key = rgb_key.repeat(N, 1, 1, 1) 53 | 54 | # Prepare the depth 55 | depth_tensors_tmp = ( 56 | depth_tensors * scale_tensors[:, None, None, None] 57 | ) # To Colmap space 58 | 59 | # Compute the warping flow from i to k, 60 | points = self.backproject_depth(depth_tensors_tmp, K=intrinsics) # [N, 3, H*W] 61 | points = points.permute(0, 2, 1) # [N, H*W, 3] 62 | pix_coords = self.project_3d( 63 | points, K=intrinsics, T=T_kc_tensors 64 | ) # [N, 2, H*W] 65 | 66 | # Acquire the backward pixel flow 67 | pix_coords[:, 0] = (pix_coords[:, 0] / (width - 1)) * 2 - 1 68 | pix_coords[:, 1] = (pix_coords[:, 1] / (height - 1)) * 2 - 1 69 | pix_coords = pix_coords.view(-1, 2, height, width) # [N, 2, H, W] 70 | pix_coords = pix_coords.permute(0, 2, 3, 1) # [N, H, W, 2] 71 | 72 | warped_depths_key = F.grid_sample( 73 | depth_key, 74 | pix_coords, 75 | mode="bilinear", 76 | padding_mode="border", 77 | align_corners=True, 78 | ) # In frame c 79 | 80 | warped_masks_key = F.grid_sample( 81 | mask_key, 82 | pix_coords, 83 | mode="nearest", 84 | padding_mode="border", 85 | align_corners=True, 86 | ) # In frame c 87 | 88 | if mode == "points": 89 | # Warp the points 90 | points_c = points # In frame c 91 | points_c = points_c @ T_kc_tensors[:, :3, :3].transpose( 92 | -1, -2 93 | ) + T_kc_tensors[:, :3, 3].unsqueeze(-2) # In frame k 94 | points_k = self.backproject_depth(depth_key, K=intrinsics) # In frame k 95 | points_k = points_k.view(-1, 3, height, width) # [N, 3, H, W] 96 | points_k_to_c = F.grid_sample( 97 | points_k, 98 | pix_coords, 99 | mode="nearest", 100 | padding_mode="border", 101 | align_corners=True, 102 | ) 103 | points_k_to_c = points_k_to_c.view(-1, 3, height, width) # [N, 3, H, W] 104 | points_k_to_c = points_k_to_c.permute(0, 2, 3, 1) # [N, H, W, 3] 105 | points_k_to_c = points_k_to_c.view(N, -1, 3) # [N, H*W, 3] 106 | points_k_to_c = points_k_to_c.clone().detach() 107 | 108 | # Warp the depth 109 | elif mode == "depth": 110 | points_k_to_c = 
self.backproject_depth( 111 | warped_depths_key, K=intrinsics 112 | ) # In frame c 113 | points_k_to_c = points_k_to_c.permute(0, 2, 1) # In frame c 114 | points_c = points # In frame c 115 | points_c = points_c @ T_kc_tensors[:, :3, :3].transpose( 116 | -1, -2 117 | ) + T_kc_tensors[:, :3, 3].unsqueeze(-2) # In frame k 118 | points_k_to_c = points_k_to_c @ T_kc_tensors[:, :3, :3].transpose( 119 | -1, -2 120 | ) + T_kc_tensors[:, :3, 3].unsqueeze(-2) # In frame k 121 | 122 | else: 123 | raise ValueError("Invalid mode: {}".format(mode)) 124 | 125 | if return_color: 126 | warped_rgbs_key = F.grid_sample( 127 | rgb_key, 128 | pix_coords, 129 | mode="bilinear", 130 | padding_mode="border", 131 | align_corners=True, 132 | ) 133 | if verbose: 134 | for bi in range(len(rgb_key)): 135 | warped_rgb_key_for_i = warped_rgbs_key[bi].detach().cpu() 136 | rgb_i = rgb_tensors[bi].detach().cpu() 137 | rgb_vis_i = torch.cat([rgb_i, warped_rgb_key_for_i], dim=-1) 138 | rgb_vis_i = rgb_vis_i.permute(1, 2, 0).cpu().numpy() 139 | rgb_vis_i = (rgb_vis_i * 255).astype(np.uint8) 140 | rgb_vis_i = rgb_vis_i[..., ::-1] 141 | cv2.imshow("rgb_{}".format(bi), rgb_vis_i) 142 | cv2.waitKey(0) 143 | 144 | return ( 145 | warped_rgbs_key, 146 | points_c, 147 | points_k_to_c, 148 | mask_tensors, 149 | warped_masks_key, 150 | ) 151 | 152 | return ( 153 | None, 154 | points_c, 155 | points_k_to_c, 156 | mask_tensors, 157 | warped_masks_key, 158 | ) 159 | 160 | def optimize_pose( 161 | self, 162 | intr, 163 | rgb_tensors, 164 | depth_tensors, 165 | mask_tensors, 166 | scale_init_tensors, 167 | scale_global, 168 | colmap_results, 169 | key_idx=0, 170 | depth_filter_thresh=1.25, 171 | optimize_pose=True, 172 | verbose=False, 173 | ): 174 | # Prepare the tensors 175 | T_wc_tensors = [] 176 | frame_ids = list(colmap_results.keys()) 177 | 178 | for ii, fi in enumerate(frame_ids): 179 | T_world_ci = np.array(colmap_results[str(fi)]["T_wc"]).reshape(4, 4) 180 | T_wc_tensors.append(torch.from_numpy(T_world_ci).float()) 181 | T_wc_tensors = torch.stack(T_wc_tensors).to(self.device) # [N, 4, 4] 182 | key_rgb = rgb_tensors[key_idx][None] # [1, 3, H, W] 183 | key_depth = depth_tensors[key_idx][None] # [1, 1, H, W] 184 | key_mask = mask_tensors[key_idx][None] # [1, 1, H, W] 185 | key_scale = scale_global.clone()[None] # [1] 186 | T_wk = T_wc_tensors[key_idx][None] # [1, 4, 4] 187 | T_kc_tensors = torch.matmul(torch.inverse(T_wk), T_wc_tensors) # [N, 4, 4] 188 | 189 | # Prepare the optimization 190 | scale_tensors_global = torch.ones_like(scale_init_tensors) * key_scale 191 | delta_scale = scale_init_tensors / scale_tensors_global 192 | delta_translation = torch.zeros_like(T_kc_tensors[:, :3, 3]).float() 193 | delta_r6d = matrix_to_rotation_6d( 194 | torch.eye(3, device=self.device).float() 195 | )[None].repeat(len(T_kc_tensors), 1) 196 | 197 | delta_scale.requires_grad = True 198 | delta_translation.requires_grad = optimize_pose 199 | delta_r6d.requires_grad = optimize_pose 200 | 201 | optimizer = torch.optim.Adam([delta_scale], self.lr_scale, betas=(0.9, 0.9)) 202 | if optimize_pose: 203 | optimizer.add_param_group( 204 | {"params": delta_translation, "lr": self.lr_pose, "betas": (0.9, 0.9)} 205 | ) 206 | optimizer.add_param_group( 207 | {"params": delta_r6d, "lr": self.lr_pose, "betas": (0.9, 0.9)} 208 | ) 209 | 210 | for it in range(self.num_iters_pose): 211 | optimizer.zero_grad() 212 | height, width = rgb_tensors.shape[-2:] 213 | scale_curr = scale_tensors_global * delta_scale 214 | T_kc_tensors_curr = T_kc_tensors.clone() 215 | 
delta_rot = rotation_6d_to_matrix(delta_r6d).transpose(-1, -2) 216 | delta_T = ( 217 | torch.eye(4, device=self.device) 218 | .float()[None] 219 | .repeat(len(T_kc_tensors), 1, 1) 220 | ) 221 | delta_T[..., :3, :3] = delta_rot 222 | delta_T[..., :3, 3] = delta_translation 223 | T_kc_tensors_curr = torch.matmul( 224 | delta_T, 225 | T_kc_tensors_curr, 226 | ) 227 | 228 | _, points_c, points_k_to_c, masks_c, warped_masks_key = ( 229 | self.compute_warped_results( 230 | intr, 231 | rgb_tensors, 232 | depth_tensors, 233 | mask_tensors, 234 | scale_curr, 235 | key_rgb, 236 | key_depth, 237 | key_mask, 238 | key_scale, 239 | T_kc_tensors_curr, 240 | mode=self.warp_mode, 241 | verbose=verbose, 242 | return_color=verbose, 243 | ) 244 | ) 245 | points_c = points_c.view(-1, height, width, 3).permute(0, 3, 1, 2) 246 | points_k_to_c = points_k_to_c.view(-1, height, width, 3).permute( 247 | 0, 3, 1, 2 248 | ) 249 | points_loss = F.mse_loss( 250 | points_c, points_k_to_c, reduction="none" 251 | ) # [N, 3, H, W] 252 | 253 | masks_static = masks_c * warped_masks_key 254 | 255 | points_loss = torch.cat( 256 | [points_loss[:key_idx], points_loss[key_idx + 1 :]], dim=0 257 | ) 258 | masks_static = torch.cat( 259 | [masks_static[:key_idx], masks_static[key_idx + 1 :]], dim=0 260 | ) 261 | depth_filter = torch.cat( 262 | [depth_tensors[:key_idx], depth_tensors[key_idx + 1 :]], dim=0 263 | ).repeat(1, 3, 1, 1) 264 | depth_filter = depth_filter < depth_filter_thresh 265 | points_loss = points_loss * masks_static * depth_filter # [N, 3, H, W] 266 | 267 | loss_geo = points_loss.mean() 268 | loss_scale_reg = F.l1_loss(delta_scale, torch.ones_like(delta_scale)) * 10 269 | loss_translation_reg = F.l1_loss( 270 | delta_translation, torch.zeros_like(delta_translation) 271 | ) 272 | loss_rot_reg = F.l1_loss( 273 | delta_r6d, 274 | matrix_to_rotation_6d(torch.eye(3, device=self.device).float())[None].repeat(len(T_kc_tensors), 1) 275 | ) 276 | loss_reg = loss_scale_reg + loss_translation_reg + loss_rot_reg 277 | loss = loss_geo + loss_reg 278 | 279 | # Compute the loss and backprop 280 | loss.backward() 281 | optimizer.step() 282 | 283 | T_kc_final = T_kc_tensors.clone() 284 | delta_rot = rotation_6d_to_matrix(delta_r6d).transpose(-1, -2) 285 | delta_T = ( 286 | torch.eye(4, device=self.device) 287 | .float()[None] 288 | .repeat(len(T_kc_tensors), 1, 1) 289 | ) 290 | delta_T[..., :3, :3] = delta_rot 291 | delta_T[..., :3, 3] = delta_translation 292 | T_kc_final = torch.matmul( 293 | delta_T, 294 | T_kc_tensors_curr, 295 | ) 296 | scale_final = scale_tensors_global * delta_scale 297 | return T_kc_final, scale_final 298 | 299 | def optimize_global_scale( 300 | self, 301 | rgb_tensors, 302 | depth_tensors, 303 | mask_tensors, 304 | colmap_results, 305 | ): 306 | # Prepare the results from colmap 307 | scale_init_tensors = [] 308 | metric_d_tensors, colmap_d_tensors, valid_d_tensors = [], [], [] 309 | frame_ids = list(colmap_results.keys()) 310 | frame_id_start = int(frame_ids[0]) 311 | for ii, fi in enumerate(frame_ids): 312 | depth = depth_tensors[ii, 0].cpu().numpy() 313 | mask = mask_tensors[ii, 0].cpu().numpy() 314 | uv = np.array(colmap_results[str(fi)]["uv"]).reshape(-1, 2) 315 | colmap_d = np.array(colmap_results[str(fi)]["d"]) 316 | uv_mask = np.logical_and( 317 | np.logical_and(uv[:, 0] >= 0, uv[:, 0] < depth.shape[1]), 318 | np.logical_and(uv[:, 1] >= 0, uv[:, 1] < depth.shape[0]), 319 | ) 320 | uv = uv[uv_mask] 321 | colmap_d = colmap_d[uv_mask] 322 | metric_d = depth[uv[:, 1], uv[:, 0]] # [S] 323 | valid_d = 
mask[uv[:, 1], uv[:, 0]] 324 | scale_init = np.median(colmap_d) / np.median(metric_d) 325 | scale_init_tensors.append(scale_init) 326 | metric_d_tensors.append(torch.from_numpy(metric_d).float()) 327 | colmap_d_tensors.append(torch.from_numpy(colmap_d).float()) 328 | valid_d_tensors.append(torch.from_numpy(valid_d).float()) 329 | 330 | # Prepare the tensors 331 | metric_d_tensors = torch.cat(metric_d_tensors).to(self.device) # [S] 332 | colmap_d_tensors = torch.cat(colmap_d_tensors).to(self.device) # [S] 333 | valid_d_tensors = torch.cat(valid_d_tensors).to(self.device) # [S] 334 | scale_init_tensors = ( 335 | torch.tensor(scale_init_tensors).float().to(self.device) 336 | ) # [N] 337 | 338 | # Start the optimization 339 | scale_global = torch.median(scale_init_tensors) 340 | delta_scale_global = torch.ones_like(scale_global) 341 | delta_scale_global.requires_grad = True 342 | optimizer_scale = torch.optim.Adam([delta_scale_global], self.lr_scale_global) 343 | 344 | # Do the optimization 345 | for it in range(self.num_iters_scale): 346 | optimizer_scale.zero_grad() 347 | scale_global_curr = scale_global * delta_scale_global 348 | loss_d = F.mse_loss( 349 | metric_d_tensors * scale_global_curr, 350 | colmap_d_tensors, 351 | reduction="none", 352 | ) # [S] 353 | 354 | loss_d = loss_d * valid_d_tensors 355 | loss_d = loss_d.sum() / valid_d_tensors.sum() 356 | loss_d.backward() 357 | optimizer_scale.step() 358 | scale_global_final = (scale_global * delta_scale_global).detach().clone() 359 | 360 | # Compute the amount of valid landmarks 361 | scale_global_np = scale_global_final.detach().cpu().numpy() 362 | key_idx, key_valid_diff_d = 0, -np.inf 363 | for ii, fi in enumerate(frame_ids): 364 | depth = depth_tensors[ii, 0].cpu().numpy() 365 | mask = mask_tensors[ii, 0].cpu().numpy() 366 | uv = np.array(colmap_results[str(fi)]["uv"]).reshape(-1, 2) 367 | colmap_d = np.array(colmap_results[str(fi)]["d"]) 368 | uv_mask = np.logical_and( 369 | np.logical_and(uv[:, 0] >= 0, uv[:, 0] < depth.shape[1]), 370 | np.logical_and(uv[:, 1] >= 0, uv[:, 1] < depth.shape[0]), 371 | ) 372 | uv = uv[uv_mask] 373 | colmap_d = colmap_d[uv_mask] 374 | metric_d = depth[uv[:, 1], uv[:, 0]] # [S] 375 | valid_d = mask[uv[:, 1], uv[:, 0]] # [S] 376 | diff_d = np.abs(colmap_d / scale_global_np - metric_d) # [S] 377 | diff_d[valid_d == 0] = np.inf 378 | 379 | # Compute the amount of valid landmarks 380 | valid_diff_d = (diff_d < 0.07).sum() 381 | if valid_diff_d > key_valid_diff_d: 382 | key_idx = ii 383 | key_valid_diff_d = valid_diff_d 384 | return scale_init_tensors, scale_global_final, key_idx 385 | 386 | -------------------------------------------------------------------------------- /models/attention.py: -------------------------------------------------------------------------------- 1 | # Code from https://github.com/krasserm/perceiver-io/blob/main/perceiver/model/core/modules.py 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from typing import List, Optional, Tuple 6 | from collections import OrderedDict 7 | from einops import rearrange 8 | 9 | KVCache = Tuple[torch.Tensor, torch.Tensor] 10 | 11 | class RotaryPositionEmbedding: 12 | # Specified in https://arxiv.org/abs/2104.09864 13 | # Modified from https://github.com/lucidrains/rotary-embedding-torch 14 | def __init__(self, frq_pos_enc: torch.Tensor, right_align: bool = False): 15 | # frq_pos_enc shape is (b, n, c). 16 | # frq_pos_enc is broadcast to (b, h, n, c). 
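        # Illustrative sketch of how `rotate` is used downstream (hypothetical shapes,
        # added for clarity): given frq_pos_enc of shape (1, n, c), only the first
        # c channels of a query/key tensor are rotated; the remaining channels pass
        # through unchanged.
        #   rot = RotaryPositionEmbedding(torch.randn(1, 16, 32))
        #   q = torch.randn(2, 4, 16, 64)    # (b, h, n, d) with d >= rotate_dim
        #   q_rot = rot.rotate(q)            # same shape as q, first 32 channels rotated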
17 | self.frq_pos_enc = rearrange(frq_pos_enc, "b n c -> b 1 n c") 18 | self.rotate_dim = frq_pos_enc.shape[-1] 19 | self.right_align = right_align 20 | 21 | def rotate(self, t): 22 | seq_len = t.shape[-2] 23 | if self.right_align: 24 | # q and k are right-aligned in Perceiver AR 25 | pos_enc = self.frq_pos_enc[..., -seq_len:, :] 26 | else: 27 | # q and k are left-aligned 28 | pos_enc = self.frq_pos_enc[..., :seq_len, :] 29 | 30 | t_rot, t_pass = t[..., : self.rotate_dim], t[..., self.rotate_dim :] 31 | t_rot = (t_rot * pos_enc.cos()) + (self._rotate_half(t_rot) * pos_enc.sin()) 32 | 33 | return torch.cat((t_rot, t_pass), dim=-1) 34 | 35 | @staticmethod 36 | def _rotate_half(x): 37 | # Rearranges channel dimension [x1, x2, x3, x4, ...] -> [-x2, x1, -x4, x3, ...] 38 | x = rearrange(x, "... (c r) -> ... c r", r=2) 39 | x1, x2 = x.unbind(dim=-1) 40 | x = torch.stack((-x2, x1), dim=-1) 41 | return rearrange(x, "... c r -> ... (c r)") 42 | 43 | 44 | class ModuleOutput(OrderedDict): 45 | def __getattr__(self, name): 46 | if name in self: 47 | return self[name] 48 | else: 49 | raise AttributeError("No such attribute: " + name) 50 | 51 | def __setattr__(self, name, value): 52 | self[name] = value 53 | 54 | def __delattr__(self, name): 55 | if name in self: 56 | del self[name] 57 | else: 58 | raise AttributeError("No such attribute: " + name) 59 | 60 | 61 | class Residual(nn.Module): 62 | def __init__(self, module: nn.Module, dropout: float = 0.0): 63 | super().__init__() 64 | self.module = module 65 | self.dropout = nn.Dropout(dropout) 66 | 67 | def forward(self, *args, **kwargs): 68 | output = self.module(*args, **kwargs) 69 | output.last_hidden_state = self.dropout(output.last_hidden_state) + args[0] 70 | return output 71 | 72 | 73 | class MultiHeadAttention(nn.Module): 74 | def __init__( 75 | self, 76 | num_heads: int, 77 | num_q_input_channels: int, 78 | num_kv_input_channels: int, 79 | num_qk_channels: Optional[int] = None, 80 | num_v_channels: Optional[int] = None, 81 | num_output_channels: Optional[int] = None, 82 | max_heads_parallel: Optional[int] = None, 83 | causal_attention: bool = False, 84 | dropout: float = 0.0, 85 | qkv_bias: bool = True, 86 | out_bias: bool = True, 87 | ): 88 | """Multi-head attention as specified in https://arxiv.org/abs/2107.14795 Appendix E plus support for rotary 89 | position embeddings (https://arxiv.org/abs/2104.09864) and causal attention. Causal attention requires 90 | queries and keys to be right-aligned, if they have different length. 91 | 92 | :param num_heads: Number of attention heads. 93 | :param num_q_input_channels: Number of query input channels. 94 | :param num_kv_input_channels: Number of key/value input channels. 95 | :param num_qk_channels: Number of query and key channels. Default is number `num_q_input_channels` 96 | :param num_v_channels: Number of value channels. Default is `num_qk_channels`. 97 | :param num_output_channels: Number of output channels. Default is `num_q_input_channels` 98 | :param max_heads_parallel: Maximum number of heads to be processed in parallel. Default is `num_heads`. 99 | :param causal_attention: Whether to apply a causal attention mask. Default is `False`. 100 | :param dropout: Dropout probability for attention matrix values. Default is `0.0` 101 | :param qkv_bias: Whether to use a bias term for query, key and value projections. Default is `True`. 102 | :param qkv_bias: Whether to use a bias term for output projection. Default is `True`. 
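
        A short usage sketch (the argument values are illustrative assumptions;
        the output channel count follows the defaults described above):

            attn = MultiHeadAttention(num_heads=4, num_q_input_channels=64, num_kv_input_channels=32)
            out = attn(torch.randn(2, 10, 64), torch.randn(2, 20, 32))
            out.last_hidden_state.shape  # torch.Size([2, 10, 64])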
103 | """ 104 | super().__init__() 105 | 106 | if num_qk_channels is None: 107 | num_qk_channels = num_q_input_channels 108 | 109 | if num_v_channels is None: 110 | num_v_channels = num_qk_channels 111 | 112 | if num_output_channels is None: 113 | num_output_channels = num_q_input_channels 114 | 115 | if num_qk_channels % num_heads != 0: 116 | raise ValueError("num_qk_channels must be divisible by num_heads") 117 | 118 | if num_v_channels % num_heads != 0: 119 | raise ValueError("num_v_channels must be divisible by num_heads") 120 | 121 | num_qk_channels_per_head = num_qk_channels // num_heads 122 | 123 | self.dp_scale = num_qk_channels_per_head**-0.5 124 | self.num_heads = num_heads 125 | self.num_qk_channels = num_qk_channels 126 | self.num_v_channels = num_v_channels 127 | self.causal_attention = causal_attention 128 | 129 | if max_heads_parallel is None: 130 | self.max_heads_parallel = num_heads 131 | else: 132 | self.max_heads_parallel = max_heads_parallel 133 | 134 | self.q_proj = nn.Linear(num_q_input_channels, num_qk_channels, bias=qkv_bias) 135 | self.k_proj = nn.Linear(num_kv_input_channels, num_qk_channels, bias=qkv_bias) 136 | self.v_proj = nn.Linear(num_kv_input_channels, num_v_channels, bias=qkv_bias) 137 | self.o_proj = nn.Linear(num_v_channels, num_output_channels, bias=out_bias) 138 | self.dropout = nn.Dropout(dropout) 139 | 140 | def forward( 141 | self, 142 | x_q: torch.Tensor, 143 | x_kv: torch.Tensor, 144 | pad_mask: Optional[torch.Tensor] = None, 145 | rot_pos_emb_q: Optional[RotaryPositionEmbedding] = None, 146 | rot_pos_emb_k: Optional[RotaryPositionEmbedding] = None, 147 | kv_cache: Optional[KVCache] = None, 148 | ): 149 | """... 150 | 151 | :param x_q: Query input of shape (B, N, D) where B is the batch size, N the query sequence length and D the 152 | number of query input channels (= `num_q_input_channels`) 153 | :param x_kv: Key/value input of shape (B, L, C) where B is the batch size, L the key/value sequence length and C 154 | are the number of key/value input channels (= `num_kv_input_channels`) 155 | :param pad_mask: Boolean key padding mask. `True` values indicate padding tokens. 156 | :param rot_pos_emb_q: Applies a rotary position embedding to query i.e. if defined, rotates the query. 157 | :param rot_pos_emb_k: Applies a rotary position embedding to key i.e. if defined, rotates the key. 158 | :param kv_cache: cache with past keys and values. 159 | :return: attention result of shape (B, N, F) where B is the batch size, N the query sequence length and F the 160 | number of output channels (= `num_output_channels`) 161 | """ 162 | 163 | q = self.q_proj(x_q) 164 | k = self.k_proj(x_kv) 165 | v = self.v_proj(x_kv) 166 | 167 | if kv_cache is not None: 168 | k_cache, v_cache = kv_cache 169 | k = torch.cat([k_cache, k], dim=1) 170 | v = torch.cat([v_cache, v], dim=1) 171 | kv_cache = (k, v) 172 | 173 | q, k, v = (rearrange(x, "b n (h c) -> b h n c", h=self.num_heads) for x in [q, k, v]) 174 | q = q * self.dp_scale 175 | 176 | if rot_pos_emb_q is not None: 177 | q = rot_pos_emb_q.rotate(q) 178 | 179 | if rot_pos_emb_k is not None: 180 | k = rot_pos_emb_k.rotate(k) 181 | 182 | if pad_mask is not None: 183 | pad_mask = rearrange(pad_mask, "b j -> b 1 1 j") 184 | 185 | if self.causal_attention: 186 | i = q.shape[2] 187 | j = k.shape[2] 188 | 189 | # If q and k have different length, causal masking only works if they are right-aligned. 
190 | causal_mask = torch.ones((i, j), device=x_q.device, dtype=torch.bool).triu(j - i + 1) 191 | 192 | o_chunks = [] 193 | 194 | # Only process a given maximum number of heads in 195 | # parallel, using several iterations, if necessary. 196 | for q_chunk, k_chunk, v_chunk in zip( 197 | q.split(self.max_heads_parallel, dim=1), 198 | k.split(self.max_heads_parallel, dim=1), 199 | v.split(self.max_heads_parallel, dim=1), 200 | ): 201 | attn = torch.einsum("b h i c, b h j c -> b h i j", q_chunk, k_chunk) 202 | attn_max_neg = -torch.finfo(attn.dtype).max 203 | 204 | if pad_mask is not None: 205 | attn.masked_fill_(pad_mask, attn_max_neg) 206 | 207 | if self.causal_attention: 208 | attn.masked_fill_(causal_mask, attn_max_neg) 209 | 210 | attn = attn.softmax(dim=-1) 211 | attn = self.dropout(attn) 212 | 213 | o_chunk = torch.einsum("b h i j, b h j c -> b h i c", attn, v_chunk) 214 | o_chunks.append(o_chunk) 215 | 216 | o = torch.cat(o_chunks, dim=1) 217 | o = rearrange(o, "b h n c -> b n (h c)", h=self.num_heads) 218 | o = self.o_proj(o) 219 | 220 | return ModuleOutput(last_hidden_state=o, kv_cache=kv_cache) 221 | 222 | 223 | class CrossAttention(nn.Module): 224 | def __init__( 225 | self, 226 | num_heads: int, 227 | num_q_input_channels: int, 228 | num_kv_input_channels: int, 229 | num_qk_channels: Optional[int] = None, 230 | num_v_channels: Optional[int] = None, 231 | max_heads_parallel: Optional[int] = None, 232 | causal_attention: bool = False, 233 | dropout: float = 0.0, 234 | qkv_bias: bool = True, 235 | out_bias: bool = True, 236 | ): 237 | """Pre-layer-norm cross-attention (see `MultiHeadAttention` for attention details).""" 238 | super().__init__() 239 | self.q_norm = nn.LayerNorm(num_q_input_channels) 240 | self.kv_norm = nn.LayerNorm(num_kv_input_channels) 241 | self.attention = MultiHeadAttention( 242 | num_heads=num_heads, 243 | num_q_input_channels=num_q_input_channels, 244 | num_kv_input_channels=num_kv_input_channels, 245 | num_qk_channels=num_qk_channels, 246 | num_v_channels=num_v_channels, 247 | max_heads_parallel=max_heads_parallel, 248 | causal_attention=causal_attention, 249 | dropout=dropout, 250 | qkv_bias=qkv_bias, 251 | out_bias=out_bias, 252 | ) 253 | 254 | def forward( 255 | self, 256 | x_q: torch.Tensor, 257 | x_kv: Optional[torch.Tensor] = None, 258 | x_kv_prefix: Optional[torch.Tensor] = None, 259 | pad_mask: Optional[torch.Tensor] = None, 260 | rot_pos_emb_q: Optional[RotaryPositionEmbedding] = None, 261 | rot_pos_emb_k: Optional[RotaryPositionEmbedding] = None, 262 | kv_cache: Optional[KVCache] = None, 263 | ): 264 | """Pre-layer-norm cross-attention of query input `x_q` to key/value input (`x_kv` or `x_kv_prefix`). 265 | 266 | If `x_kv_prefix` is defined, the entire key/value input is a concatenation of `x_kv_prefix` and `x_q` along 267 | the sequence dimension. In this case, the query attends to itself at the end of the key/value sequence (use 268 | case: Perceiver AR). If `x_kv_prefix` is not defined, `x_kv` is the entire key/value input. 
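
        A brief sketch of the prefix use case (shapes and constructor values are
        illustrative assumptions):

            ca = CrossAttention(num_heads=2, num_q_input_channels=64, num_kv_input_channels=64)
            latents = torch.randn(1, 8, 64)          # queries
            prefix = torch.randn(1, 32, 64)          # key/value prefix (Perceiver AR style)
            out = ca(latents, x_kv_prefix=prefix)    # keys/values are [prefix; latents], length 40
            out.last_hidden_state.shape              # torch.Size([1, 8, 64])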
269 | """ 270 | x_q = self.q_norm(x_q) 271 | 272 | if x_kv is None: 273 | x_kv_prefix = self.kv_norm(x_kv_prefix) 274 | x_kv = torch.cat([x_kv_prefix, x_q], dim=1) 275 | else: 276 | x_kv = self.kv_norm(x_kv) 277 | 278 | return self.attention( 279 | x_q, x_kv, pad_mask=pad_mask, rot_pos_emb_q=rot_pos_emb_q, rot_pos_emb_k=rot_pos_emb_k, kv_cache=kv_cache 280 | ) 281 | 282 | 283 | class SelfAttention(nn.Module): 284 | def __init__( 285 | self, 286 | num_heads: int, 287 | num_channels: int, 288 | num_qk_channels: Optional[int] = None, 289 | num_v_channels: Optional[int] = None, 290 | max_heads_parallel: Optional[int] = None, 291 | causal_attention: bool = False, 292 | dropout: float = 0.0, 293 | qkv_bias: bool = True, 294 | out_bias: bool = True, 295 | ): 296 | """Pre-layer norm self-attention (see `MultiHeadAttention` and for attention details).""" 297 | super().__init__() 298 | self.norm = nn.LayerNorm(num_channels) 299 | self.attention = MultiHeadAttention( 300 | num_heads=num_heads, 301 | num_q_input_channels=num_channels, 302 | num_kv_input_channels=num_channels, 303 | num_qk_channels=num_qk_channels, 304 | num_v_channels=num_v_channels, 305 | max_heads_parallel=max_heads_parallel, 306 | causal_attention=causal_attention, 307 | dropout=dropout, 308 | qkv_bias=qkv_bias, 309 | out_bias=out_bias, 310 | ) 311 | 312 | def forward( 313 | self, 314 | x: torch.Tensor, 315 | pad_mask: Optional[torch.Tensor] = None, 316 | rot_pos_emb: Optional[RotaryPositionEmbedding] = None, 317 | kv_cache: Optional[KVCache] = None, 318 | ): 319 | """Pre-layer-norm self-attention of input `x`.""" 320 | x = self.norm(x) 321 | return self.attention( 322 | x, 323 | x, 324 | pad_mask=pad_mask, 325 | rot_pos_emb_q=rot_pos_emb, 326 | rot_pos_emb_k=rot_pos_emb, 327 | kv_cache=kv_cache, 328 | ) 329 | 330 | 331 | class AbstractAttentionLayer(nn.Sequential): 332 | def empty_kv_cache(self, x) -> KVCache: 333 | k_cache = torch.empty(x.shape[0], 0, self.num_qk_channels, dtype=x.dtype, device=x.device) 334 | v_cache = torch.empty(x.shape[0], 0, self.num_v_channels, dtype=x.dtype, device=x.device) 335 | return k_cache, v_cache 336 | 337 | def forward(self, *args, kv_cache: Optional[KVCache] = None, **kwargs): 338 | attn_output = self[0](*args, kv_cache=kv_cache, **kwargs) 339 | mlp_output = self[1](attn_output.last_hidden_state) 340 | return ModuleOutput(last_hidden_state=mlp_output.last_hidden_state, kv_cache=attn_output.kv_cache) 341 | 342 | 343 | class CrossAttentionLayer(AbstractAttentionLayer): 344 | def __init__( 345 | self, 346 | num_heads: int, 347 | num_q_input_channels: int, 348 | num_kv_input_channels: int, 349 | num_qk_channels: Optional[int] = None, 350 | num_v_channels: Optional[int] = None, 351 | max_heads_parallel: Optional[int] = None, 352 | causal_attention: bool = False, 353 | widening_factor: int = 1, 354 | dropout: float = 0.0, 355 | residual_dropout: float = 0.0, 356 | attention_residual: bool = True, 357 | qkv_bias: bool = True, 358 | out_bias: bool = True, 359 | mlp_bias: bool = True, 360 | ): 361 | cross_attn = CrossAttention( 362 | num_heads=num_heads, 363 | num_q_input_channels=num_q_input_channels, 364 | num_kv_input_channels=num_kv_input_channels, 365 | num_qk_channels=num_qk_channels, 366 | num_v_channels=num_v_channels, 367 | max_heads_parallel=max_heads_parallel, 368 | causal_attention=causal_attention, 369 | dropout=dropout, 370 | qkv_bias=qkv_bias, 371 | out_bias=out_bias, 372 | ) 373 | 374 | self.num_qk_channels = cross_attn.attention.num_qk_channels 375 | self.num_v_channels = 
cross_attn.attention.num_v_channels 376 | 377 | super().__init__( 378 | Residual(cross_attn, residual_dropout) if attention_residual else cross_attn, 379 | Residual(MLP(num_q_input_channels, widening_factor, bias=mlp_bias), residual_dropout), 380 | ) 381 | 382 | 383 | class SelfAttentionLayer(AbstractAttentionLayer): 384 | def __init__( 385 | self, 386 | num_heads: int, 387 | num_channels: int, 388 | num_qk_channels: Optional[int] = None, 389 | num_v_channels: Optional[int] = None, 390 | max_heads_parallel: Optional[int] = None, 391 | causal_attention: bool = False, 392 | widening_factor: int = 1, 393 | dropout: float = 0.0, 394 | residual_dropout: float = 0.0, 395 | qkv_bias: bool = True, 396 | out_bias: bool = True, 397 | mlp_bias: bool = True, 398 | ): 399 | self_attn = SelfAttention( 400 | num_heads=num_heads, 401 | num_channels=num_channels, 402 | num_qk_channels=num_qk_channels, 403 | num_v_channels=num_v_channels, 404 | max_heads_parallel=max_heads_parallel, 405 | causal_attention=causal_attention, 406 | dropout=dropout, 407 | qkv_bias=qkv_bias, 408 | out_bias=out_bias, 409 | ) 410 | 411 | self.num_qk_channels = self_attn.attention.num_qk_channels 412 | self.num_v_channels = self_attn.attention.num_v_channels 413 | 414 | super().__init__( 415 | Residual(self_attn, residual_dropout), 416 | Residual(MLP(num_channels, widening_factor, bias=mlp_bias), residual_dropout), 417 | ) 418 | 419 | 420 | class SelfAttentionBlock(nn.Sequential): 421 | def __init__( 422 | self, 423 | num_layers: int, 424 | num_heads: int, 425 | num_channels: int, 426 | num_qk_channels: Optional[int] = None, 427 | num_v_channels: Optional[int] = None, 428 | num_rotary_layers: int = 1, 429 | max_heads_parallel: Optional[int] = None, 430 | causal_attention: bool = False, 431 | widening_factor: int = 1, 432 | dropout: float = 0.0, 433 | residual_dropout: float = 0.0, 434 | qkv_bias: bool = True, 435 | out_bias: bool = True, 436 | mlp_bias: bool = True, 437 | ): 438 | layers = [ 439 | SelfAttentionLayer( 440 | num_heads=num_heads, 441 | num_channels=num_channels, 442 | num_qk_channels=num_qk_channels, 443 | num_v_channels=num_v_channels, 444 | max_heads_parallel=max_heads_parallel, 445 | causal_attention=causal_attention, 446 | widening_factor=widening_factor, 447 | dropout=dropout, 448 | residual_dropout=residual_dropout, 449 | qkv_bias=qkv_bias, 450 | out_bias=out_bias, 451 | mlp_bias=mlp_bias, 452 | ) 453 | for _ in range(num_layers) 454 | ] 455 | 456 | 457 | self.num_rotary_layers = num_rotary_layers 458 | super().__init__(*layers) 459 | 460 | def forward( 461 | self, 462 | x: torch.Tensor, 463 | pad_mask: Optional[torch.Tensor] = None, 464 | rot_pos_emb: Optional[RotaryPositionEmbedding] = None, 465 | kv_cache: Optional[List[KVCache]] = None, 466 | ): 467 | if kv_cache is None: 468 | kv_cache_updated = None 469 | else: 470 | if len(kv_cache) == 0: 471 | # initialize kv_cache for each self-attention layer 472 | kv_cache = [layer.empty_kv_cache(x) for layer in self] 473 | kv_cache_updated = [] 474 | 475 | for i, layer in enumerate(self): 476 | rot_pos_emb_use = i < self.num_rotary_layers or self.num_rotary_layers == -1 477 | rot_pos_emb_i = rot_pos_emb if rot_pos_emb_use else None 478 | 479 | kv_cache_i = None if kv_cache is None else kv_cache[i] 480 | output = layer(x, pad_mask=pad_mask, rot_pos_emb=rot_pos_emb_i, kv_cache=kv_cache_i) 481 | 482 | x = output.last_hidden_state 483 | 484 | if kv_cache_updated is not None: 485 | kv_cache_updated.append(output.kv_cache) 486 | 487 | return 
ModuleOutput(last_hidden_state=x, kv_cache=kv_cache_updated) 488 | 489 | class MLP(nn.Sequential): 490 | def __init__(self, num_channels: int, widening_factor: int, bias: bool = True): 491 | super().__init__( 492 | nn.LayerNorm(num_channels), 493 | nn.Linear(num_channels, widening_factor * num_channels, bias=bias), 494 | nn.GELU(), 495 | nn.Linear(widening_factor * num_channels, num_channels, bias=bias), 496 | ) 497 | 498 | def forward(self, x): 499 | return ModuleOutput(last_hidden_state=super().forward(x)) -------------------------------------------------------------------------------- /models/feature_extractors.py: -------------------------------------------------------------------------------- 1 | import einops 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from models.layers_2d import ( 8 | BackprojectDepth, 9 | load_clip, 10 | Project3D, 11 | ) 12 | from torchvision.ops import FeaturePyramidNetwork 13 | import torchvision 14 | from models.helpers import TSDFVolume, get_view_frustum 15 | from models.layers_3d import VoxelGridEncoder 16 | 17 | from typing import Union, List, Tuple 18 | 19 | from models.perceiver import FeaturePerceiver 20 | 21 | 22 | class MultiScaleImageFeatureExtractor(nn.Module): 23 | _RESNET_MEAN = [0.485, 0.456, 0.406] 24 | _RESNET_STD = [0.229, 0.224, 0.225] 25 | 26 | def __init__( 27 | self, 28 | modelname: str = "dino_vits16", 29 | freeze: bool = False, 30 | scale_factors: list = [1, 1 / 2, 1 / 3], 31 | embedding_dim: int = None, 32 | ): 33 | super().__init__() 34 | self.freeze = freeze 35 | self.scale_factors = scale_factors 36 | self.embedding_dim = embedding_dim 37 | 38 | if "res" in modelname: 39 | self._net = getattr(torchvision.models, modelname)(pretrained=True) 40 | self._output_dim = self._net.fc.weight.shape[1] 41 | self._net.fc = nn.Identity() 42 | elif "dinov2" in modelname: 43 | self._net = torch.hub.load("facebookresearch/dinov2", modelname) 44 | self._output_dim = self._net.norm.weight.shape[0] 45 | elif "dino" in modelname: 46 | self._net = torch.hub.load("facebookresearch/dino:main", modelname) 47 | self._output_dim = self._net.norm.weight.shape[0] 48 | else: 49 | raise ValueError(f"Unknown model name {modelname}") 50 | 51 | for name, value in ( 52 | ("_resnet_mean", self._RESNET_MEAN), 53 | ("_resnet_std", self._RESNET_STD), 54 | ): 55 | self.register_buffer( 56 | name, torch.FloatTensor(value).view(1, 3, 1, 1), persistent=False 57 | ) 58 | 59 | if self.freeze: 60 | for param in self.parameters(): 61 | param.requires_grad = False 62 | 63 | if self.embedding_dim is not None: 64 | self._last_layer = nn.Linear(self._output_dim, self.embedding_dim) 65 | self._output_dim = self.embedding_dim 66 | else: 67 | self._last_layer = nn.Identity() 68 | 69 | def get_output_dim(self): 70 | return self._output_dim 71 | 72 | def forward(self, image_rgb: torch.Tensor): 73 | img_normed = self._resnet_normalize_image(image_rgb) 74 | features = self._compute_multiscale_features(img_normed) 75 | return features 76 | 77 | def _resnet_normalize_image(self, img: torch.Tensor): 78 | return (img - self._resnet_mean) / self._resnet_std 79 | 80 | def _compute_multiscale_features(self, img_normed: torch.Tensor): 81 | multiscale_features = None 82 | 83 | if len(self.scale_factors) <= 0: 84 | raise ValueError( 85 | f"Wrong format of self.scale_factors: {self.scale_factors}" 86 | ) 87 | 88 | for scale_factor in self.scale_factors: 89 | if scale_factor == 1: 90 | inp = img_normed 91 | else: 92 | inp = 
self._resize_image(img_normed, scale_factor) 93 | 94 | if multiscale_features is None: 95 | multiscale_features = self._net(inp) 96 | else: 97 | multiscale_features += self._net(inp) 98 | 99 | averaged_features = multiscale_features / len(self.scale_factors) 100 | averaged_features = self._last_layer(averaged_features) 101 | return averaged_features 102 | 103 | @staticmethod 104 | def _resize_image(image: torch.Tensor, scale_factor: float): 105 | return nn.functional.interpolate( 106 | image, scale_factor=scale_factor, mode="bilinear", align_corners=False 107 | ) 108 | 109 | 110 | class TSDFMapFeatureExtractor(nn.Module): 111 | def __init__( 112 | self, 113 | input_image_shape, 114 | voxel_resolution=32, 115 | voxel_feature_dim=64, 116 | vlm_feature_attn_dim=256, 117 | # use_feature_decoder=True, 118 | ): 119 | super().__init__() 120 | self.input_image_shape = input_image_shape 121 | self.voxel_resolution = voxel_resolution 122 | self.embedding_dim = voxel_feature_dim 123 | self.backproject = BackprojectDepth(input_image_shape[0], input_image_shape[1]) 124 | self.project_3d = Project3D() 125 | 126 | # Load pretrained backbone 127 | # self.vlm, self.vlm_transform = clip.load("ViT-B/16", jit=False) 128 | self.vlm, self.vlm_transform = load_clip() 129 | 130 | self.vlm.float() 131 | for p in self.vlm.parameters(): 132 | p.requires_grad = False 133 | 134 | # Load 3D Unet 135 | self.tsdf_net = VoxelGridEncoder( 136 | self.voxel_resolution, c_dim=self.embedding_dim 137 | ) 138 | 139 | self.feature_pyramid = FeaturePyramidNetwork( 140 | [64, 256, 512, 1024, 2048], voxel_feature_dim 141 | ) 142 | 143 | self.feature_map_pyramid_keys = ["res1", "res2", "res3"] 144 | 145 | # Cross Attention Layer 146 | # self.action_proj = nn.Linear(vlm_feature_attn_dim, self.embedding_dim, bias=True) 147 | self.vlm_preceiver_pyramid = nn.ModuleList() 148 | self.vlm_proj_pyramid = nn.ModuleList() 149 | vlm_preceiver = FeaturePerceiver( 150 | transition_dim=self.embedding_dim, 151 | condition_dim=vlm_feature_attn_dim, 152 | time_emb_dim=0, 153 | ) 154 | vlm_proj = nn.Linear(vlm_preceiver.last_dim, self.embedding_dim, bias=True) 155 | for _ in range(len(self.feature_map_pyramid_keys)): 156 | 157 | self.vlm_preceiver_pyramid.append(vlm_preceiver) 158 | self.vlm_proj_pyramid.append(vlm_proj) 159 | # Feature projection layer 160 | proj_dim_in = voxel_feature_dim * (1 + len(self.feature_map_pyramid_keys)) 161 | self.proj = nn.Linear(proj_dim_in, self.embedding_dim, bias=True) 162 | 163 | def compute_tsdf_volume(self, color, depth, intrinsics, verbose=False): 164 | cam_pose = np.eye(4) 165 | tsdf_grid_batch = [] 166 | tsdf_bounds_batch = [] 167 | tsdf_color_batch = [] 168 | mesh_batch = [] 169 | for i in range(len(depth)): 170 | d_np = depth[i].cpu().numpy()[0] 171 | c_np = ( 172 | color[i].cpu().numpy().transpose(1, 2, 0) 173 | ) # [H, W, 3], requested by TSDFVolume 174 | K_np = intrinsics[i].cpu().numpy() 175 | view_frust_pts = get_view_frustum(d_np, K_np, cam_pose) 176 | vol_bnds = np.zeros((3, 2)) 177 | vol_bnds[:, 0] = np.minimum( 178 | vol_bnds[:, 0], np.amin(view_frust_pts, axis=1) 179 | ).min() 180 | vol_bnds[:, 1] = np.maximum( 181 | vol_bnds[:, 1], np.amax(view_frust_pts, axis=1) 182 | ).max() 183 | tsdf = TSDFVolume(vol_bnds, voxel_dim=self.voxel_resolution) 184 | tsdf.integrate(c_np * 255, d_np, K_np, cam_pose) 185 | tsdf_grid = torch.from_numpy(tsdf.get_tsdf_volume()) 186 | tsdf_grid_batch.append(tsdf_grid) 187 | tsdf_bounds_batch.append(torch.from_numpy(vol_bnds[0])) 188 | if verbose: 189 | mesh = 
tsdf.get_mesh() 190 | color_grid = torch.from_numpy(tsdf.get_color_volume()) / 255.0 191 | mesh_batch.append(mesh) 192 | tsdf_color_batch.append(color_grid) 193 | tsdf_bounds_batch = ( 194 | torch.stack(tsdf_bounds_batch, dim=0).to(depth.device).float() 195 | ) 196 | tsdf_grid_batch = torch.stack(tsdf_grid_batch, dim=0).to(depth.device).float() 197 | if verbose: 198 | tsdf_color_batch = ( 199 | torch.stack(tsdf_color_batch, dim=0).to(depth.device).float() 200 | ) 201 | return tsdf_grid_batch, tsdf_color_batch, tsdf_bounds_batch, mesh_batch 202 | return tsdf_grid_batch 203 | 204 | def compute_context_features( 205 | self, color, depth, intrinsics, tsdf=None, action_features=None 206 | ): 207 | if tsdf is None: 208 | tsdf = self.compute_tsdf_volume(color, depth, intrinsics) 209 | 210 | h_in, w_in = color.shape[-2:] 211 | 212 | color = self.vlm_transform(color) 213 | color_features = self.vlm(color) # [B, N, C] 214 | color_features = self.feature_pyramid(color_features) # [B, N, C] 215 | 216 | # Action grounding 217 | if action_features is not None: 218 | for i, k in enumerate(self.feature_map_pyramid_keys): 219 | color_feature_i = color_features[k] # 220 | h, w = color_feature_i.shape[-2:] 221 | color_feature_i = einops.rearrange(color_feature_i, "B C H W-> B (H W) C") 222 | color_feature_i = self.vlm_preceiver_pyramid[i](color_feature_i, action_features[:, None]) 223 | color_feature_i = self.vlm_proj_pyramid[i](color_feature_i) 224 | color_feature_i = einops.rearrange(color_feature_i, "B (H W) C -> B C H W", H=h, W=w) 225 | color_features[k] = color_feature_i 226 | 227 | color_features_pyramid = [] 228 | for i, k in enumerate(self.feature_map_pyramid_keys): 229 | color_feature_i = color_features[k] 230 | color_feature_i = F.interpolate( 231 | color_feature_i, size=(h_in, w_in), mode="bilinear" 232 | ) 233 | color_features_pyramid.append(color_feature_i) 234 | points_map_pyramid = [tsdf] * len(color_features_pyramid) # [B, D, H, W] 235 | points_pe_pyramid = [self.tsdf_net(tsdf)] * len(color_features_pyramid) # [B, P, D, H, W] 236 | 237 | # color_features = einops.rearrange( 238 | # color_features, "B (H W) C -> B C H W", H=h_out, W=w_out 239 | # ) 240 | # import pdb; pdb.set_trace() 241 | # if self.use_feature_decoder: 242 | # color_features = self.feature_decoder(color_features) 243 | 244 | # color_features = F.interpolate( 245 | # color_features, size=tuple(self.input_image_shape), mode="bilinear" 246 | # ) 247 | 248 | 249 | # # Compute the point features pyramid 250 | # for i, k in enumerate(self.feature_map_pyramid_keys): 251 | # color_feature_i = color_features[k] 252 | # color_feature_i = F.interpolate( 253 | # color_feature_i, size=tuple(self.input_image_shape), mode="bilinear" 254 | # ) 255 | 256 | # color_features_pyramid.append(color_feature_i) 257 | return color_features_pyramid, points_map_pyramid, points_pe_pyramid 258 | 259 | @staticmethod 260 | def interpolate_voxel_grid_features(voxel_grid, query_points, voxel_bounds): 261 | """ 262 | Parameters 263 | ---------- 264 | voxel_grid : torch.Tensor 265 | with shape [B, C, D, H, W] 266 | query_points : torch.Tensor 267 | _with shape [B, N, 3] 268 | voxel_bounds: torch.Tensor 269 | _with shape [B, 2] 270 | """ 271 | voxel_bounds = voxel_bounds.unsqueeze(-1).repeat(1, 1, 3) # [B, 2, 3] 272 | query_points = (query_points - voxel_bounds[:, 0:1]) / ( 273 | voxel_bounds[:, 1:2] - voxel_bounds[:, 0:1] 274 | ) 275 | query_grids = ( 276 | query_points * 2 - 1 277 | ) # Normalize the query points from [0, 1] to [-1, 1] 278 | query_grids = 
query_grids[ 279 | ..., [2, 1, 0] 280 | ] # Convert to the voxel grid coordinate system 281 | query_grids = query_grids[:, :, None, None] # [B, N, 1, 1, 3] 282 | query_features = F.grid_sample( 283 | voxel_grid, query_grids, mode="bilinear", align_corners=True 284 | ) # [B, C, N, 1, 1] 285 | query_features = query_features.squeeze(-1).squeeze(-1) # [B, C, N] 286 | return query_features 287 | 288 | def interpolate_image_grid_features(self, image_grid, query_points, intrinsics): 289 | """ 290 | Parameters 291 | ---------- 292 | image_grid : torch.Tensor 293 | with shape [B, C, H, W] 294 | query_points : torch.Tensor 295 | _with shape [B, N, 3] 296 | """ 297 | batch_size, _, height, width = image_grid.shape 298 | query_grids = self.project_3d(query_points, intrinsics) # [B, 2, N] 299 | query_grids[:, 0] = (query_grids[:, 0] / (width - 1)) * 2 - 1 300 | query_grids[:, 1] = (query_grids[:, 1] / (height - 1)) * 2 - 1 301 | query_grids = query_grids.permute(0, 2, 1)[:, :, None] # [B, N, 1, 2] 302 | query_featurs = F.grid_sample( 303 | image_grid, query_grids, mode="bilinear", align_corners=True 304 | ) # [B, C, N, 1] 305 | query_featurs = query_featurs.squeeze(-1) 306 | return query_featurs 307 | 308 | def forward( 309 | self, 310 | color_features_pyramid, 311 | points_map_pyramid, 312 | points_pe_pyramid, 313 | query_points, 314 | intrinsics, 315 | voxel_bounds, 316 | **kwargs, 317 | ): 318 | """_summary_ 319 | 320 | Parameters 321 | ---------- 322 | color_features_pyramid : list of torch.Tensor 323 | with shape [[B, C, H, W]] 324 | points_map_pyramid : list of torch.Tensor for TSDF volume 325 | [[B, D, H, W]] 326 | points_pe_pyramid : list of torch.Tensor for TSDF volume feature 327 | [[B, P, D, H, W]] 328 | query_points : query points 329 | [B, N, 3] 330 | intrinsics : torch.Tensor or np.ndarray 331 | [3, 3] 332 | voxel_bounds : _type_ 333 | [B, 2] 334 | 335 | Returns 336 | ------- 337 | torch.Tensor 338 | shape of [B, N, C*4] 339 | """ 340 | assert len(color_features_pyramid) == len(points_map_pyramid) 341 | assert len(color_features_pyramid) == len(points_pe_pyramid) 342 | batch_size, num_query_points, _ = query_points.shape 343 | features = [] 344 | 345 | for i in range(len(color_features_pyramid)): 346 | # Re-arrange to feature maps 347 | color_feature_i = color_features_pyramid[i] # [B, C, H, W] 348 | points_pe_i = points_pe_pyramid[i] # [B, P, D, H, W] 349 | points_map_i = points_map_pyramid[i][:, None] # [B, 1, D, H, W] 350 | 351 | if i == 0: 352 | # Interpolate the voxel grid features 353 | feat_occ = self.interpolate_voxel_grid_features( 354 | points_map_i, query_points, voxel_bounds 355 | ) 356 | 357 | # Interpolate the voxel grid features 358 | feat_3d = self.interpolate_voxel_grid_features( 359 | points_pe_i, query_points, voxel_bounds 360 | ) 361 | features.append(feat_3d) # [B, C, N] 362 | 363 | # Interpolate the 2D feature maps 364 | feat_2d = self.interpolate_image_grid_features( 365 | color_feature_i, query_points, intrinsics 366 | ) 367 | features.append(feat_2d) # [B, C, N] 368 | features = torch.cat(features, dim=1).permute(0, 2, 1) # [B, N, C*3] 369 | features = self.proj(features) # [B, N, C] 370 | return features 371 | 372 | class TSDFMapGeometryExtractor(nn.Module): 373 | def __init__( 374 | self, 375 | input_image_shape, 376 | voxel_resolution=64, 377 | voxel_feature_dim=64, 378 | ): 379 | super().__init__() 380 | self.input_image_shape = input_image_shape 381 | self.voxel_resolution = voxel_resolution 382 | self.embedding_dim = voxel_feature_dim 383 | self.backproject 
= BackprojectDepth(input_image_shape[0], input_image_shape[1]) 384 | self.project_3d = Project3D() 385 | 386 | # Load 3D Unet 387 | self.tsdf_net = VoxelGridEncoder( 388 | self.voxel_resolution, c_dim=self.embedding_dim 389 | ) 390 | 391 | def compute_tsdf_volume(self, color, depth, intrinsics, verbose=False): 392 | cam_pose = np.eye(4) 393 | tsdf_grid_batch = [] 394 | tsdf_bounds_batch = [] 395 | tsdf_color_batch = [] 396 | mesh_batch = [] 397 | for i in range(len(depth)): 398 | d_np = depth[i].cpu().numpy()[0] 399 | c_np = ( 400 | color[i].cpu().numpy().transpose(1, 2, 0) 401 | ) # [H, W, 3], requested by TSDFVolume 402 | K_np = intrinsics[i].cpu().numpy() 403 | view_frust_pts = get_view_frustum(d_np, K_np, cam_pose) 404 | vol_bnds = np.zeros((3, 2)) 405 | vol_bnds[:, 0] = np.minimum( 406 | vol_bnds[:, 0], np.amin(view_frust_pts, axis=1) 407 | ).min() 408 | vol_bnds[:, 1] = np.maximum( 409 | vol_bnds[:, 1], np.amax(view_frust_pts, axis=1) 410 | ).max() 411 | tsdf = TSDFVolume(vol_bnds, voxel_dim=self.voxel_resolution) 412 | tsdf.integrate(c_np * 255, d_np, K_np, cam_pose) 413 | tsdf_grid = torch.from_numpy(tsdf.get_tsdf_volume()) 414 | tsdf_grid_batch.append(tsdf_grid) 415 | tsdf_bounds_batch.append(torch.from_numpy(vol_bnds[0])) 416 | if verbose: 417 | mesh = tsdf.get_mesh() 418 | color_grid = torch.from_numpy(tsdf.get_color_volume()) / 255.0 419 | mesh_batch.append(mesh) 420 | tsdf_color_batch.append(color_grid) 421 | tsdf_bounds_batch = ( 422 | torch.stack(tsdf_bounds_batch, dim=0).to(depth.device).float() 423 | ) 424 | tsdf_grid_batch = torch.stack(tsdf_grid_batch, dim=0).to(depth.device).float() 425 | if verbose: 426 | tsdf_color_batch = ( 427 | torch.stack(tsdf_color_batch, dim=0).to(depth.device).float() 428 | ) 429 | return tsdf_grid_batch, tsdf_color_batch, tsdf_bounds_batch, mesh_batch 430 | return tsdf_grid_batch 431 | 432 | def compute_context_features( 433 | self, color, depth, intrinsics, tsdf=None, action_featurs=None 434 | ): 435 | if tsdf is None: 436 | tsdf = self.compute_tsdf_volume(color, depth, intrinsics) 437 | 438 | 439 | color_features_pyramid = [None] # [B, C, H, W] 440 | points_map_pyramid = [tsdf] # [B, D, H, W] 441 | points_pe_pyramid = [self.tsdf_net(tsdf)] # [B, P, D, H, W] 442 | 443 | return color_features_pyramid, points_map_pyramid, points_pe_pyramid 444 | 445 | @staticmethod 446 | def interpolate_voxel_grid_features(voxel_grid, query_points, voxel_bounds): 447 | """ 448 | Parameters 449 | ---------- 450 | voxel_grid : torch.Tensor 451 | with shape [B, C, D, H, W] 452 | query_points : torch.Tensor 453 | _with shape [B, N, 3] 454 | voxel_bounds: torch.Tensor 455 | _with shape [B, 2] 456 | """ 457 | voxel_bounds = voxel_bounds.unsqueeze(-1).repeat(1, 1, 3) # [B, 2, 3] 458 | query_points = (query_points - voxel_bounds[:, 0:1]) / ( 459 | voxel_bounds[:, 1:2] - voxel_bounds[:, 0:1] 460 | ) 461 | query_grids = ( 462 | query_points * 2 - 1 463 | ) # Normalize the query points from [0, 1] to [-1, 1] 464 | query_grids = query_grids[ 465 | ..., [2, 1, 0] 466 | ] # Convert to the voxel grid coordinate system 467 | query_grids = query_grids[:, :, None, None] # [B, N, 1, 1, 3] 468 | query_features = F.grid_sample( 469 | voxel_grid, query_grids, mode="bilinear", align_corners=True 470 | ) # [B, C, N, 1, 1] 471 | query_features = query_features.squeeze(-1).squeeze(-1) # [B, C, N] 472 | return query_features 473 | 474 | def interpolate_image_grid_features(self, image_grid, query_points, intrinsics): 475 | """ 476 | Parameters 477 | ---------- 478 | image_grid : 
torch.Tensor 479 | with shape [B, C, H, W] 480 | query_points : torch.Tensor 481 | _with shape [B, N, 3] 482 | """ 483 | batch_size, _, height, width = image_grid.shape 484 | query_grids = self.project_3d(query_points, intrinsics) # [B, 2, N] 485 | query_grids[:, 0] = (query_grids[:, 0] / (width - 1)) * 2 - 1 486 | query_grids[:, 1] = (query_grids[:, 1] / (height - 1)) * 2 - 1 487 | query_grids = query_grids.permute(0, 2, 1)[:, :, None] # [B, N, 1, 2] 488 | query_featurs = F.grid_sample( 489 | image_grid, query_grids, mode="bilinear", align_corners=True 490 | ) # [B, C, N, 1] 491 | query_featurs = query_featurs.squeeze(-1) 492 | return query_featurs 493 | 494 | def forward( 495 | self, 496 | color_features_pyramid, 497 | points_map_pyramid, 498 | points_pe_pyramid, 499 | query_points, 500 | intrinsics, 501 | voxel_bounds, 502 | **kwargs, 503 | ): 504 | """_summary_ 505 | 506 | Parameters 507 | ---------- 508 | color_features_pyramid : list of torch.Tensor 509 | with shape [[B, C, H, W]] 510 | points_map_pyramid : list of torch.Tensor for TSDF volume 511 | [[B, D, H, W]] 512 | points_pe_pyramid : list of torch.Tensor for TSDF volume feature 513 | [[B, P, D, H, W]] 514 | query_points : query points 515 | [B, N, 3] 516 | intrinsics : torch.Tensor or np.ndarray 517 | [3, 3] 518 | voxel_bounds : _type_ 519 | [B, 2] 520 | 521 | Returns 522 | ------- 523 | torch.Tensor 524 | shape of [B, N, C*4] 525 | """ 526 | assert len(color_features_pyramid) == len(points_map_pyramid) 527 | assert len(color_features_pyramid) == len(points_pe_pyramid) 528 | batch_size, num_query_points, _ = query_points.shape 529 | features = [] 530 | 531 | for i in range(len(color_features_pyramid)): 532 | # Re-arrange to feature maps 533 | color_feature_i = color_features_pyramid[i] # [B, C, H, W] 534 | points_pe_i = points_pe_pyramid[i] # [B, P, D, H, W] 535 | points_map_i = points_map_pyramid[i][:, None] # [B, 1, D, H, W] 536 | 537 | if i == 0: 538 | # Interpolate the voxel grid features 539 | feat_occ = self.interpolate_voxel_grid_features( 540 | points_map_i, query_points, voxel_bounds 541 | ) 542 | 543 | # Interpolate the voxel grid features 544 | feat_3d = self.interpolate_voxel_grid_features( 545 | points_pe_i, query_points, voxel_bounds 546 | ) 547 | features.append(feat_3d) # [B, C, N] 548 | 549 | # # Interpolate the 2D feature maps 550 | # feat_2d = self.interpolate_image_grid_features( 551 | # color_feature_i, query_points, intrinsics 552 | # ) 553 | # features.append(feat_2d) # [B, C, N] 554 | features = torch.cat(features, dim=1).permute(0, 2, 1) # [B, N, C*3] 555 | return features 556 | -------------------------------------------------------------------------------- /models/helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from numba import njit, prange 7 | from skimage import measure 8 | import open3d as o3d 9 | import time 10 | import torch 11 | from math import ceil 12 | from models.clip import clip, tokenize 13 | 14 | 15 | def compute_null_text_embeddings(vlm, batch_size=1, device="cuda"): 16 | action_tokens_null = tokenize("") 17 | action_tokens_null = action_tokens_null.repeat(batch_size, 1) 18 | action_tokens_null = action_tokens_null.to(device) 19 | action_feature_null = vlm.encode_text(action_tokens_null).float() 20 | return action_feature_null 21 | 22 | 23 | def fourier_positional_encoding(input, L): # [B,...,C] 24 | shape = 
input.shape 25 | freq = 2 ** torch.arange(L, dtype=torch.float32, device=input.device) * np.pi # [L] 26 | spectrum = input[..., None] * freq # [B,...,C,L] 27 | sin, cos = spectrum.sin(), spectrum.cos() # [B,...,C,L] 28 | input_enc = torch.stack([sin, cos], dim=-2) # [B,...,C,2,L] 29 | input_enc = input_enc.view(*shape[:-1], -1) # [B,...,2CL] 30 | return input_enc 31 | 32 | 33 | def exists(x): 34 | return x is not None 35 | 36 | 37 | def default(val, d): 38 | if exists(val): 39 | return val 40 | return d() if callable(d) else d 41 | 42 | 43 | def round_up_multiple(num, mult): 44 | return ceil(num / mult) * mult 45 | 46 | 47 | def extract(a, t, x_shape): 48 | b, *_ = t.shape 49 | out = a.gather(-1, t) 50 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 51 | 52 | 53 | def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32): 54 | """ 55 | cosine schedule 56 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ 57 | """ 58 | steps = timesteps + 1 59 | x = np.linspace(0, steps, steps) 60 | alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 61 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 62 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 63 | betas_clipped = np.clip(betas, a_min=0, a_max=0.999) 64 | return torch.tensor(betas_clipped, dtype=dtype) 65 | 66 | 67 | # -----------------------------------------------------------------------------# 68 | # ---------------------------------- losses -----------------------------------# 69 | # -----------------------------------------------------------------------------# 70 | 71 | 72 | class FocalLoss(nn.Module): 73 | def __init__(self, gamma: float = 0, size_average: bool = True): 74 | super(FocalLoss, self).__init__() 75 | self.gamma = gamma 76 | self.size_average = size_average 77 | 78 | def forward(self, input, target): 79 | logpt = F.log_softmax(input, dim=-1) 80 | logpt = logpt.gather(1, target.view(-1, 1)).view(-1) 81 | pt = logpt.exp() 82 | 83 | loss = -1 * (1 - pt) ** self.gamma * logpt 84 | if self.size_average: 85 | return loss.mean() 86 | else: 87 | return loss.sum() 88 | 89 | 90 | class WeightedLoss(nn.Module): 91 | 92 | def __init__(self, weights): 93 | super().__init__() 94 | self.register_buffer("weights", weights) 95 | 96 | def forward(self, pred, targ): 97 | """ 98 | pred, targ : tensor 99 | [ batch_size x horizon x transition_dim ] 100 | """ 101 | loss = self._loss(pred, targ) 102 | weighted_loss = (loss * self.weights).mean() 103 | return weighted_loss 104 | 105 | 106 | class WeightedL1(WeightedLoss): 107 | 108 | def _loss(self, pred, targ): 109 | return torch.abs(pred - targ) 110 | 111 | 112 | class WeightedL2(WeightedLoss): 113 | 114 | def _loss(self, pred, targ): 115 | return F.mse_loss(pred, targ, reduction="none") 116 | 117 | 118 | Losses = { 119 | "l1": WeightedL1, 120 | "l2": WeightedL2, 121 | } 122 | 123 | 124 | class EMA: 125 | """ 126 | empirical moving average 127 | """ 128 | 129 | def __init__(self, beta): 130 | super().__init__() 131 | self.beta = beta 132 | 133 | def update_model_average(self, ma_model, current_model): 134 | with torch.no_grad(): 135 | ema_state_dict = ma_model.state_dict() 136 | for key, value in current_model.state_dict().items(): 137 | ema_value = ema_state_dict[key] 138 | ema_value.copy_(self.beta * ema_value + (1.0 - self.beta) * value) 139 | 140 | 141 | # -----------------------------------------------------------------------------# 142 | # ---------------------------------- TSDF -----------------------------------# 143 | # 
-----------------------------------------------------------------------------# 144 | 145 | 146 | def get_view_frustum(depth_im, cam_intr, cam_pose): 147 | """Get corners of 3D camera view frustum of depth image""" 148 | im_h = depth_im.shape[0] 149 | im_w = depth_im.shape[1] 150 | max_depth = np.max(depth_im) 151 | view_frust_pts = np.array( 152 | [ 153 | (np.array([0, 0, 0, im_w, im_w]) - cam_intr[0, 2]) 154 | * np.array([0, max_depth, max_depth, max_depth, max_depth]) 155 | / cam_intr[0, 0], 156 | (np.array([0, 0, im_h, 0, im_h]) - cam_intr[1, 2]) 157 | * np.array([0, max_depth, max_depth, max_depth, max_depth]) 158 | / cam_intr[1, 1], 159 | np.array([0, max_depth, max_depth, max_depth, max_depth]), 160 | ] 161 | ) 162 | view_frust_pts = view_frust_pts.T @ cam_pose[:3, :3].T + cam_pose[:3, 3] 163 | return view_frust_pts.T 164 | 165 | 166 | # try: 167 | # import pycuda.driver as cuda 168 | # import pycuda.autoinit 169 | # from pycuda.compiler import SourceModule 170 | 171 | # FUSION_GPU_MODE = 1 172 | # except Exception as err: 173 | # print("Warning: {}".format(err)) 174 | # print("Failed to import PyCUDA. Running fusion in CPU mode.") 175 | # FUSION_GPU_MODE = 0 176 | 177 | 178 | class TSDFVolume: 179 | """ 180 | Volumetric with TSDF representation 181 | """ 182 | 183 | def __init__( 184 | self, 185 | vol_bounds: np.ndarray, 186 | voxel_dim: float, 187 | use_gpu: bool = False, 188 | verbose: bool = False, 189 | num_margin: float = 5.0, 190 | enable_color=True, 191 | unknown_free=True, 192 | ): 193 | """ 194 | Constructor 195 | 196 | :param vol_bounds: An ndarray is shape (3,2), define the min & max bounds of voxels. 197 | :param voxel_size: Voxel size in meters. 198 | :param use_gpu: Use GPU for voxel update. 199 | :param verbose: Print verbose message or not. 
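Note that the second argument is `voxel_dim`, the number of voxels per axis; the metric voxel size is derived as the x-extent of `vol_bounds` divided by `voxel_dim`. Illustrative construction (bounds, resolution, and frame inputs are arbitrary assumptions):

    vol_bnds = np.array([[-0.5, 0.5], [-0.5, 0.5], [0.0, 1.0]])
    tsdf = TSDFVolume(vol_bnds, voxel_dim=32)
    tsdf.integrate(color_img, depth_img, intrinsic, np.eye(4))  # color_img: [H, W, 3] in [0, 255], depth_img in meters, intrinsic: [3, 3]
    tsdf_grid = tsdf.get_tsdf_volume()                          # (32, 32, 32) numpy array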
200 | """ 201 | 202 | vol_bounds = np.asarray(vol_bounds) 203 | assert vol_bounds.shape == (3, 2), "vol_bounds should be of shape (3,2)" 204 | 205 | self._verbose = verbose 206 | self._use_gpu = use_gpu 207 | self._vol_bounds = vol_bounds 208 | 209 | self._vol_dim = [voxel_dim, voxel_dim, voxel_dim] 210 | self._vox_size = float( 211 | (self._vol_bounds[0, 1] - self._vol_bounds[0, 0]) / voxel_dim 212 | ) 213 | self._trunc_margin = num_margin * self._vox_size # truncation on SDF 214 | 215 | # Check GPU 216 | if self._use_gpu: 217 | if torch.cuda.is_available(): 218 | if self._verbose: 219 | print("# Using GPU mode") 220 | self._device = torch.device("cuda:0") 221 | else: 222 | if self._verbose: 223 | print("# Not available CUDA device, using CPU mode") 224 | self._device = torch.device("cpu") 225 | else: 226 | if self._verbose: 227 | print("# Using CPU mode") 228 | self._device = torch.device("cpu") 229 | 230 | # Coordinate origin of the volume, set as the min value of volume bounds 231 | self._vol_origin = torch.tensor( 232 | self._vol_bounds[:, 0].copy(order="C"), device=self._device 233 | ).float() 234 | 235 | # Grid coordinates of voxels 236 | xx, yy, zz = torch.meshgrid( 237 | torch.arange(self._vol_dim[0]), 238 | torch.arange(self._vol_dim[1]), 239 | torch.arange(self._vol_dim[2]), 240 | indexing="ij", 241 | ) 242 | self._vox_coords = ( 243 | torch.cat([xx.reshape(1, -1), yy.reshape(1, -1), zz.reshape(1, -1)], dim=0) 244 | .int() 245 | .T 246 | ) 247 | if self._use_gpu: 248 | self._vox_coords = self._vox_coords.cuda() 249 | 250 | # World coordinates of voxel centers 251 | self._world_coords = self.vox2world( 252 | self._vol_origin, self._vox_coords, self._vox_size 253 | ) 254 | self.enable_color = enable_color 255 | 256 | # TSDF & weights 257 | self._tsdf_vol = torch.ones( 258 | size=self._vol_dim, device=self._device, dtype=torch.float32 259 | ) 260 | self._weight_vol = torch.zeros( 261 | size=self._vol_dim, device=self._device, dtype=torch.float32 262 | ) 263 | if self.enable_color: 264 | self._color_vol = torch.zeros( 265 | size=[*self._vol_dim, 3], device=self._device, dtype=torch.float32 266 | ) 267 | 268 | # Mesh paramters 269 | self._mesh = o3d.geometry.TriangleMesh() 270 | self.unknown_free = unknown_free 271 | @staticmethod 272 | def vox2world(vol_origin: torch.Tensor, vox_coords: torch.Tensor, vox_size): 273 | """ 274 | Converts voxel grid coordinates to world coordinates 275 | 276 | :param vol_origin: Origin of the volume in world coordinates, (3,1). 277 | :parma vol_coords: List of all grid coordinates in the volume, (N,3). 278 | :param vol_size: Size of volume. 279 | :retrun: Grid points under world coordinates. Tensor with shape (N, 3) 280 | """ 281 | 282 | cam_pts = torch.empty_like(vox_coords, dtype=torch.float32) 283 | cam_pts = vol_origin + (vox_size * vox_coords) 284 | 285 | return cam_pts 286 | 287 | @staticmethod 288 | def cam2pix(cam_pts: torch.Tensor, intrinsics: torch.Tensor): 289 | """ 290 | Convert points in camera coordinate to pixel coordinates 291 | 292 | :param cam_pts: Points in camera coordinates, (N,3). 293 | :param intrinsics: Vamera intrinsics, (3,3). 294 | :return: Pixel coordinate (u,v) cooresponding to input points. Tensor with shape (N, 2). 
295 | """ 296 | 297 | cam_pts_z = cam_pts[:, 2].repeat(3, 1).T 298 | pix = torch.round((cam_pts @ intrinsics.T) / cam_pts_z) 299 | 300 | return pix 301 | 302 | @staticmethod 303 | def ridgid_transform(points: torch.Tensor, transform: torch.Tensor): 304 | """ 305 | Apply rigid transform (4,4) on points 306 | 307 | :param points: Points, shape (N,3). 308 | :param transform: Tranform matrix, shape (4,4). 309 | :return: Points after transform. 310 | """ 311 | 312 | points_h = torch.cat( 313 | [points, torch.ones((points.shape[0], 1), device=points.device)], 1 314 | ) 315 | points_h = (transform @ points_h.T).T 316 | 317 | return points_h[:, :3] 318 | 319 | def get_tsdf_volume(self): 320 | return self._tsdf_vol.cpu().numpy() 321 | 322 | def get_color_volume(self): 323 | return self._color_vol.permute(3, 0, 1, 2).cpu().numpy() 324 | 325 | def get_mesh(self): 326 | """ 327 | Get mesh. 328 | """ 329 | return self._mesh 330 | 331 | def integrate(self, color_img, depth_img, intrinsic, cam_pose, weight: float = 1.0): 332 | """ 333 | Integrate an depth image to the TSDF volume 334 | 335 | :param depth_img: depth image with depth value in meter. 336 | :param intrinsics: camera intrinsics of shape (3,3). 337 | :param cam_pose: camera pose, transform matrix of shape (4,4) 338 | :param weight: weight assign for current frame, higher value indicate higher confidence 339 | """ 340 | 341 | time_begin = time.time() 342 | img_h, img_w = depth_img.shape 343 | depth_img = torch.from_numpy(depth_img).float().to(self._device) 344 | color_img = torch.from_numpy(color_img).float().to(self._device) # [H, W, 3] 345 | cam_pose = torch.from_numpy(cam_pose).float().to(self._device) 346 | intrinsic = torch.from_numpy(intrinsic).float().to(self._device) 347 | 348 | # TODO: 349 | # Better way to select valid voxels. 350 | # - Current: 351 | # -> Back project all voxels to frame pixels according to current camera pose. 352 | # -> Select valid pixels within frame size. 
353 | # - Possible: 354 | # -> Project pixel to voxel coordinates 355 | # -> hash voxel coordinates 356 | # -> dynamically allocate voxel chunks 357 | 358 | # Get the world coordinates of all voxels 359 | # world_points = geometry.vox2world(self._vol_origin, self._vox_coords, self._vox_size) 360 | 361 | # Get voxel centers under camera coordinates 362 | world_points = self.ridgid_transform( 363 | self._world_coords, cam_pose.inverse() 364 | ) # [N^3, 3] 365 | 366 | # Get the pixel coordinates (u,v) of all voxels under current camere pose 367 | # Multiple voxels can be projected to a same (u,v) 368 | voxel_uv = self.cam2pix(world_points, intrinsic) # [N^3, 3] 369 | voxel_u, voxel_v = voxel_uv[:, 0], voxel_uv[:, 1] # [N^3], [N^3] 370 | voxel_z = world_points[:, 2] 371 | 372 | # Filter out voxels points that visible in current frame 373 | pixel_mask = torch.logical_and( 374 | voxel_u >= 0, 375 | torch.logical_and( 376 | voxel_u < img_w, 377 | torch.logical_and( 378 | voxel_v >= 0, torch.logical_and(voxel_v < img_h, voxel_z > 0) 379 | ), 380 | ), 381 | ) 382 | 383 | # Get depth value 384 | depth_value = torch.zeros(voxel_u.shape, device=self._device) 385 | depth_value[pixel_mask] = depth_img[ 386 | voxel_v[pixel_mask].long(), voxel_u[pixel_mask].long() 387 | ] 388 | 389 | # Compute and Integrate TSDF 390 | sdf_value = depth_value - voxel_z # Compute SDF 391 | if self.unknown_free: 392 | voxel_mask = torch.logical_and( 393 | depth_value > 0, sdf_value >= -self._trunc_margin 394 | ) # Truncate SDF 395 | else: 396 | voxel_mask = depth_value > 0 # Truncate SDF 397 | 398 | tsdf_value = torch.minimum( 399 | torch.ones_like(sdf_value, device=self._device), 400 | sdf_value / self._trunc_margin, 401 | ) 402 | tsdf_value = tsdf_value[voxel_mask] 403 | # Get coordinates of valid voxels with valid TSDF value 404 | valid_vox_x = self._vox_coords[voxel_mask, 0].long() 405 | valid_vox_y = self._vox_coords[voxel_mask, 1].long() 406 | valid_vox_z = self._vox_coords[voxel_mask, 2].long() 407 | 408 | # Update TSDF of cooresponding voxels 409 | weight_old = self._weight_vol[valid_vox_x, valid_vox_y, valid_vox_z] 410 | tsdf_old = self._tsdf_vol[valid_vox_x, valid_vox_y, valid_vox_z] 411 | 412 | if self.enable_color: 413 | color_value = torch.zeros([voxel_u.shape[0], 3], device=self._device) 414 | color_value[pixel_mask] = color_img[ 415 | voxel_v[pixel_mask].long(), voxel_u[pixel_mask].long(), : 416 | ] 417 | color_value = color_value[voxel_mask] 418 | color_old = self._color_vol[valid_vox_x, valid_vox_y, valid_vox_z] 419 | 420 | else: 421 | color_value = None 422 | color_old = None 423 | tsdf_new, color_new, weight_new = self.update_tsdf( 424 | tsdf_old, tsdf_value, color_old, color_value, weight_old, weight 425 | ) 426 | 427 | self._tsdf_vol[valid_vox_x, valid_vox_y, valid_vox_z] = tsdf_new 428 | self._weight_vol[valid_vox_x, valid_vox_y, valid_vox_z] = weight_new 429 | 430 | if self.enable_color: 431 | self._color_vol[valid_vox_x, valid_vox_y, valid_vox_z] = color_new 432 | 433 | if self._verbose: 434 | print("# Update {} voxels.".format(len(tsdf_new))) 435 | print( 436 | "# Integration Timing: {:.5f} (second).".format( 437 | time.time() - time_begin 438 | ) 439 | ) 440 | 441 | def get_mesh(self): 442 | """ 443 | Extract mesh from current TSDF volume. 
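The surface is recovered with scikit-image marching cubes at the zero level set of the negated TSDF. Note that the returned vertices live in the normalized [-1, 1] volume cube rather than in metric world coordinates (the world-coordinate conversion is left commented out below):

    mesh = tsdf.get_mesh()  # o3d.geometry.TriangleMesh with vertices in [-1, 1]^3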
444 | """ 445 | 446 | time_begin = time.time() 447 | 448 | if self._use_gpu: 449 | tsdf_vol = self._tsdf_vol.cpu().numpy() 450 | vol_origin = self._vol_origin.cpu().numpy() 451 | if self.enable_color: 452 | color_vol = self._color_vol.cpu().numpy() / 255 453 | 454 | else: 455 | tsdf_vol = self._tsdf_vol.numpy() 456 | vol_origin = self._vol_origin.numpy() 457 | if self.enable_color: 458 | color_vol = self._color_vol.numpy() / 255 459 | 460 | _vertices, triangles, _, _ = measure.marching_cubes(-tsdf_vol, 0) 461 | vertices_sample = (_vertices / self._vol_dim[0] - 0.5) * 2 462 | 463 | # interpolation to get colors 464 | vertices_pt = torch.from_numpy(vertices_sample).float()[ 465 | None, None, None, :, [2, 1, 0] 466 | ] # [1, 1, 1, N, 3] 467 | 468 | # mesh_vertices = _vertices / self._vol_dim[0] + self._vol_origin.cpu().numpy() 469 | mesh_vertices = vertices_sample 470 | self._mesh.vertices = o3d.utility.Vector3dVector(mesh_vertices.astype(float)) 471 | self._mesh.triangles = o3d.utility.Vector3iVector(triangles.astype(np.int32)) 472 | if self.enable_color: 473 | color_vol_pt = ( 474 | torch.from_numpy(color_vol).float().permute(3, 0, 1, 2)[None] 475 | ) # [1, 3, H, W, D] 476 | vert_colors = torch.nn.functional.grid_sample( 477 | color_vol_pt, vertices_pt, align_corners=True 478 | ) # [1, 3, 1, 1, N] 479 | vert_colors = vert_colors.squeeze().cpu().numpy().T 480 | self._mesh.vertex_colors = o3d.utility.Vector3dVector( 481 | vert_colors.astype(float) 482 | ) 483 | 484 | self._mesh.compute_vertex_normals() 485 | if self._verbose: 486 | print("# Extracting Mesh: {} Vertices".format(mesh_vertices.shape[0])) 487 | print("# Meshing Timing: {:.5f} (second).".format(time.time() - time_begin)) 488 | return self._mesh 489 | 490 | def update_tsdf( 491 | self, tsdf_old, tsdf_new, color_old, color_new, weight_old, obs_weight 492 | ): 493 | """ 494 | Update the TSDF value of given voxel 495 | V = (wv + WV) / w + W 496 | 497 | :param tsdf_old: Old TSDF values. 498 | :param tsdf_new: New TSDF values. 499 | :param weight_old: Voxels weights. 500 | :param obs_weight: Weight of current update. 501 | :return: Updated TSDF values & Updated weights. 502 | """ 503 | 504 | tsdf_vol_int = torch.empty_like( 505 | tsdf_old, dtype=torch.float32, device=self._device 506 | ) 507 | weight_new = torch.empty_like( 508 | weight_old, dtype=torch.float32, device=self._device 509 | ) 510 | 511 | weight_new = weight_old + obs_weight 512 | tsdf_vol_int = (weight_old * tsdf_old + obs_weight * tsdf_new) / weight_new 513 | if color_old is not None: 514 | color_vol_int = ( 515 | weight_old[:, None] * color_old + obs_weight * color_new 516 | ) / weight_new[:, None] 517 | return tsdf_vol_int, color_vol_int, weight_new 518 | else: 519 | color_vol_int = None 520 | 521 | return tsdf_vol_int, color_vol_int, weight_new 522 | 523 | 524 | class TSDFVolume2(TSDFVolume): 525 | """ 526 | Volumetric with TSDF representation 527 | """ 528 | 529 | def __init__( 530 | self, 531 | vol_bounds: np.ndarray, 532 | voxel_size: float, 533 | use_gpu: bool = False, 534 | verbose: bool = False, 535 | num_margin: float = 5.0, 536 | enable_color=True, 537 | ): 538 | """ 539 | Constructor 540 | 541 | :param vol_bounds: An ndarray is shape (3,2), define the min & max bounds of voxels. 542 | :param voxel_size: Voxel size in meters. 543 | :param use_gpu: Use GPU for voxel update. 544 | :param verbose: Print verbose message or not. 
545 | """ 546 | 547 | vol_bounds = np.asarray(vol_bounds) 548 | assert vol_bounds.shape == (3, 2), "vol_bounds should be of shape (3,2)" 549 | 550 | self._verbose = verbose 551 | self._use_gpu = use_gpu 552 | self._vol_bounds = vol_bounds 553 | if self._use_gpu: 554 | if torch.cuda.is_available(): 555 | if self._verbose: 556 | print("# Using GPU mode") 557 | self._device = torch.device("cuda:0") 558 | else: 559 | if self._verbose: 560 | print("# Not available CUDA device, using CPU mode") 561 | self._device = torch.device("cpu") 562 | else: 563 | if self._verbose: 564 | print("# Using CPU mode") 565 | self._device = torch.device("cpu") 566 | self._voxel_size = float(voxel_size) 567 | self._trunc_margin = 5 * self._voxel_size # truncation on SDF 568 | self._vox_size = float(voxel_size) 569 | self._trunc_margin = num_margin * self._vox_size # truncation on SDF 570 | # Adjust volume bounds and ensure C-order contiguous 571 | self._vol_dim = ( 572 | np.ceil( 573 | (self._vol_bounds[:, 1] - self._vol_bounds[:, 0]) / self._voxel_size 574 | ) 575 | .copy(order="C") 576 | .astype(int) 577 | ) 578 | 579 | self._vol_bounds[:, 1] = ( 580 | self._vol_bounds[:, 0] + self._vol_dim * self._voxel_size 581 | ) 582 | self._vol_dim = self._vol_dim.tolist() 583 | # self._vol_origin = self._vol_bounds[:, 0].copy(order="C").astype(np.float32) 584 | # self._vol_dim = torch.tensor(self._vol_dim, device=self._device).int() 585 | self._vol_origin = torch.tensor( 586 | self._vol_bounds[:, 0].copy(order="C"), device=self._device 587 | ).float() 588 | # Grid coordinates of voxels 589 | xx, yy, zz = torch.meshgrid( 590 | torch.arange(self._vol_dim[0]), 591 | torch.arange(self._vol_dim[1]), 592 | torch.arange(self._vol_dim[2]), 593 | indexing="ij", 594 | ) 595 | self._vox_coords = ( 596 | torch.cat([xx.reshape(1, -1), yy.reshape(1, -1), zz.reshape(1, -1)], dim=0) 597 | .int() 598 | .T 599 | ) 600 | if self._use_gpu: 601 | self._vox_coords = self._vox_coords.cuda() 602 | 603 | # World coordinates of voxel centers 604 | self._world_coords = self.vox2world( 605 | self._vol_origin, self._vox_coords, self._vox_size 606 | ) 607 | self.enable_color = enable_color 608 | 609 | # TSDF & weights 610 | self._tsdf_vol = torch.ones( 611 | size=self._vol_dim, device=self._device, dtype=torch.float32 612 | ) 613 | self._weight_vol = torch.zeros( 614 | size=self._vol_dim, device=self._device, dtype=torch.float32 615 | ) 616 | if self.enable_color: 617 | self._color_vol = torch.zeros( 618 | size=[*self._vol_dim, 3], device=self._device, dtype=torch.float32 619 | ) 620 | 621 | # Mesh paramters 622 | self._mesh = o3d.geometry.TriangleMesh() 623 | -------------------------------------------------------------------------------- /models/clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | from models.clip.interpolate import interpolate_positional_embedding 10 | 11 | 12 | class Bottleneck(nn.Module): 13 | expansion = 4 14 | 15 | def __init__(self, inplanes, planes, stride=1): 16 | super().__init__() 17 | 18 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 19 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 20 | self.bn1 = nn.BatchNorm2d(planes) 21 | self.relu1 = nn.ReLU(inplace=True) 22 | 23 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 24 | self.bn2 = nn.BatchNorm2d(planes) 25 | self.relu2 = nn.ReLU(inplace=True) 26 | 27 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 28 | 29 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 30 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 31 | self.relu3 = nn.ReLU(inplace=True) 32 | 33 | self.downsample = None 34 | self.stride = stride 35 | 36 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 37 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 38 | self.downsample = nn.Sequential( 39 | OrderedDict( 40 | [ 41 | ("-1", nn.AvgPool2d(stride)), 42 | ( 43 | "0", 44 | nn.Conv2d( 45 | inplanes, 46 | planes * self.expansion, 47 | 1, 48 | stride=1, 49 | bias=False, 50 | ), 51 | ), 52 | ("1", nn.BatchNorm2d(planes * self.expansion)), 53 | ] 54 | ) 55 | ) 56 | 57 | def forward(self, x: torch.Tensor): 58 | identity = x 59 | 60 | out = self.relu1(self.bn1(self.conv1(x))) 61 | out = self.relu2(self.bn2(self.conv2(out))) 62 | out = self.avgpool(out) 63 | out = self.bn3(self.conv3(out)) 64 | 65 | if self.downsample is not None: 66 | identity = self.downsample(x) 67 | 68 | out += identity 69 | out = self.relu3(out) 70 | return out 71 | 72 | 73 | class AttentionPool2d(nn.Module): 74 | def __init__( 75 | self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None 76 | ): 77 | super().__init__() 78 | self.positional_embedding = nn.Parameter( 79 | torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5 80 | ) 81 | self.k_proj = nn.Linear(embed_dim, embed_dim) 82 | self.q_proj = nn.Linear(embed_dim, embed_dim) 83 | self.v_proj = nn.Linear(embed_dim, embed_dim) 84 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 85 | self.num_heads = num_heads 86 | self.spacial_dim = spacial_dim 87 | 88 | def forward(self, x): 89 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 90 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 91 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 92 | x, _ = F.multi_head_attention_forward( 93 | query=x[:1], 94 | key=x, 95 | value=x, 96 | embed_dim_to_check=x.shape[-1], 97 | num_heads=self.num_heads, 98 | q_proj_weight=self.q_proj.weight, 99 | k_proj_weight=self.k_proj.weight, 100 | v_proj_weight=self.v_proj.weight, 101 | in_proj_weight=None, 102 | in_proj_bias=torch.cat( 103 | [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias] 104 | ), 105 | bias_k=None, 106 | bias_v=None, 107 | add_zero_attn=False, 108 | dropout_p=0, 109 | out_proj_weight=self.c_proj.weight, 110 | out_proj_bias=self.c_proj.bias, 111 | use_separate_proj_weight=True, 112 | training=self.training, 113 | need_weights=False, 114 | ) 115 | return x.squeeze(0) 116 | 117 | def forward_v(self, x: torch.Tensor): 118 | """ 119 | Forward function for computing the value features for dense prediction (i.e., features for every image patch). 
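Following the MaskCLIP-style parametrization, only the value and output projections of the attention pool are applied to every spatial token, so an input of shape [N, C, H, W] yields per-patch features of shape [N, H*W + 1, C_out] instead of a single pooled vector; the caller (`ModifiedResNet.forward` with `patch_output=True`) drops the leading class token.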
120 | """ 121 | _, _, w, h = x.shape 122 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 123 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 124 | 125 | # Interpolate positional embedding to match the size of the input 126 | interpolated_pe = interpolate_positional_embedding( 127 | self.positional_embedding, x.permute(1, 0, 2), patch_size=1, w=w, h=h 128 | ) 129 | x = x + interpolated_pe[:, None, :] # (HW+1)NC 130 | 131 | v_in = F.linear(x, self.v_proj.weight, self.v_proj.bias) 132 | v_out = F.linear(v_in, self.c_proj.weight, self.c_proj.bias) 133 | v_out = v_out.permute(1, 0, 2) # (HW+1)NC -> N(HW+1)C 134 | return v_out 135 | 136 | 137 | class ModifiedResNet(nn.Module): 138 | """ 139 | A ResNet class that is similar to torchvision's but contains the following changes: 140 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 141 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 142 | - The final pooling layer is a QKV attention instead of an average pool 143 | """ 144 | 145 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 146 | super().__init__() 147 | self.output_dim = output_dim 148 | self.input_resolution = input_resolution 149 | 150 | # the 3-layer stem 151 | self.conv1 = nn.Conv2d( 152 | 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False 153 | ) 154 | self.bn1 = nn.BatchNorm2d(width // 2) 155 | self.relu1 = nn.ReLU(inplace=True) 156 | self.conv2 = nn.Conv2d( 157 | width // 2, width // 2, kernel_size=3, padding=1, bias=False 158 | ) 159 | self.bn2 = nn.BatchNorm2d(width // 2) 160 | self.relu2 = nn.ReLU(inplace=True) 161 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 162 | self.bn3 = nn.BatchNorm2d(width) 163 | self.relu3 = nn.ReLU(inplace=True) 164 | self.avgpool = nn.AvgPool2d(2) 165 | 166 | # residual layers 167 | self._inplanes = width # this is a *mutable* variable used during construction 168 | self.layer1 = self._make_layer(width, layers[0]) 169 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 170 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 171 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 172 | 173 | embed_dim = width * 32 # the ResNet feature dimension 174 | self.attnpool = AttentionPool2d( 175 | input_resolution // 32, embed_dim, heads, output_dim 176 | ) 177 | 178 | def _make_layer(self, planes, blocks, stride=1): 179 | layers = [Bottleneck(self._inplanes, planes, stride)] 180 | 181 | self._inplanes = planes * Bottleneck.expansion 182 | for _ in range(1, blocks): 183 | layers.append(Bottleneck(self._inplanes, planes)) 184 | 185 | return nn.Sequential(*layers) 186 | 187 | def forward(self, x, patch_output: bool = False): 188 | def stem(x): 189 | x = self.relu1(self.bn1(self.conv1(x))) 190 | x = self.relu2(self.bn2(self.conv2(x))) 191 | x = self.relu3(self.bn3(self.conv3(x))) 192 | x = self.avgpool(x) 193 | return x 194 | 195 | x = x.type(self.conv1.weight.dtype) 196 | x = stem(x) 197 | x = self.layer1(x) 198 | x = self.layer2(x) 199 | x = self.layer3(x) 200 | x = self.layer4(x) 201 | 202 | if patch_output: 203 | x = self.attnpool.forward_v(x) 204 | x = x[:, 1:, :] # remove the cls token 205 | else: 206 | x = self.attnpool(x) 207 | 208 | return x 209 | 210 | 211 | class LayerNorm(nn.LayerNorm): 212 | """Subclass torch's LayerNorm to handle fp16.""" 213 | 214 | def forward(self, x: torch.Tensor): 215 | 
orig_type = x.dtype 216 | ret = super().forward(x.type(torch.float32)) 217 | return ret.type(orig_type) 218 | 219 | 220 | class QuickGELU(nn.Module): 221 | def forward(self, x: torch.Tensor): 222 | return x * torch.sigmoid(1.702 * x) 223 | 224 | 225 | class ResidualAttentionBlock(nn.Module): 226 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 227 | super().__init__() 228 | 229 | self.attn = nn.MultiheadAttention(d_model, n_head) 230 | self.ln_1 = LayerNorm(d_model) 231 | self.mlp = nn.Sequential( 232 | OrderedDict( 233 | [ 234 | ("c_fc", nn.Linear(d_model, d_model * 4)), 235 | ("gelu", QuickGELU()), 236 | ("c_proj", nn.Linear(d_model * 4, d_model)), 237 | ] 238 | ) 239 | ) 240 | self.ln_2 = LayerNorm(d_model) 241 | self.attn_mask = attn_mask 242 | 243 | def attention(self, x: torch.Tensor): 244 | self.attn_mask = ( 245 | self.attn_mask.to(dtype=x.dtype, device=x.device) 246 | if self.attn_mask is not None 247 | else None 248 | ) 249 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 250 | 251 | def forward_v(self, x: torch.Tensor): 252 | """ 253 | Forward function for computing the value features for dense prediction (i.e., features for every image patch). 254 | """ 255 | # Get the weights and biases for the value projection, multihead attention uses 3 * embed_dim for the input projection 256 | v_in_proj_weight = self.attn.in_proj_weight[-self.attn.embed_dim :] 257 | v_in_proj_bias = self.attn.in_proj_bias[-self.attn.embed_dim :] 258 | 259 | v_in = F.linear(self.ln_1(x), v_in_proj_weight, v_in_proj_bias) 260 | v_out = F.linear(v_in, self.attn.out_proj.weight, self.attn.out_proj.bias) 261 | 262 | # Using the value features works the best. Adding this to 'x' or feeding 'v' to the LayerNorm then MLP degrades the performance 263 | return v_out 264 | 265 | def forward(self, x: torch.Tensor): 266 | x = x + self.attention(self.ln_1(x)) 267 | x = x + self.mlp(self.ln_2(x)) 268 | return x 269 | 270 | 271 | class Transformer(nn.Module): 272 | def __init__( 273 | self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None 274 | ): 275 | super().__init__() 276 | self.width = width 277 | self.layers = layers 278 | self.resblocks = nn.Sequential( 279 | *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)] 280 | ) 281 | 282 | def forward(self, x: torch.Tensor): 283 | return self.resblocks(x) 284 | 285 | 286 | class VisionTransformer(nn.Module): 287 | def __init__( 288 | self, 289 | input_resolution: int, 290 | patch_size: int, 291 | width: int, 292 | layers: int, 293 | heads: int, 294 | output_dim: int, 295 | output_indices: list = [6, 7, 8, 9, 10], 296 | ): 297 | super().__init__() 298 | self.input_resolution = input_resolution 299 | self.output_dim = output_dim 300 | self.conv1 = nn.Conv2d( 301 | in_channels=3, 302 | out_channels=width, 303 | kernel_size=patch_size, 304 | stride=patch_size, 305 | bias=False, 306 | ) 307 | 308 | scale = width**-0.5 309 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 310 | self.positional_embedding = nn.Parameter( 311 | scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width) 312 | ) 313 | self.ln_pre = LayerNorm(width) 314 | 315 | self.transformer = Transformer(width, layers, heads) 316 | 317 | self.ln_post = LayerNorm(width) 318 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 319 | 320 | self.patch_size = patch_size 321 | self.output_indices = output_indices 322 | 323 | def forward(self, x: torch.Tensor, patch_output: bool = 
False, return_dict=False): 324 | output_dict = {} 325 | 326 | _, _, w, h = x.shape 327 | 328 | x = self.conv1(x) # shape = [*, width, grid, grid] 329 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 330 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 331 | x = torch.cat( 332 | [ 333 | self.class_embedding.to(x.dtype) 334 | + torch.zeros( 335 | x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device 336 | ), 337 | x, 338 | ], 339 | dim=1, 340 | ) # shape = [*, grid ** 2 + 1, width] 341 | x = x + interpolate_positional_embedding( 342 | self.positional_embedding, x, patch_size=self.patch_size, w=w, h=h 343 | ) 344 | x = self.ln_pre(x) 345 | 346 | x = x.permute(1, 0, 2) # NLD -> LND 347 | 348 | if patch_output: 349 | *layers, last_resblock = self.transformer.resblocks 350 | penultimate = nn.Sequential(*layers) 351 | for ii, layer in enumerate(layers): 352 | _x = layer(x) # Results after layer normalization 353 | # TODO: add f"res{ii}" output dict here!!! 354 | if return_dict and ii in self.output_indices: 355 | res_i = ii - 5 356 | output_dict[f"res{res_i}"] = _x.permute(1, 0, 2) 357 | 358 | x = penultimate(x) 359 | x = last_resblock.forward_v(x) 360 | x = x.permute(1, 0, 2) # LND -> NLD 361 | 362 | # Extract the patch tokens, not the class token 363 | x = x[:, 1:, :] 364 | x = self.ln_post(x) 365 | if self.proj is not None: 366 | # This is equivalent to conv1d 367 | x = x @ self.proj 368 | 369 | if return_dict: 370 | output_dict["final"] = x 371 | for k, v in output_dict.items(): 372 | print(k, v.shape) 373 | return output_dict 374 | else: 375 | return x 376 | 377 | x = self.transformer(x) 378 | x = x.permute(1, 0, 2) # LND -> NLD 379 | 380 | x = self.ln_post(x[:, 0, :]) 381 | 382 | if self.proj is not None: 383 | x = x @ self.proj 384 | 385 | return x 386 | 387 | 388 | class CLIP(nn.Module): 389 | def __init__( 390 | self, 391 | embed_dim: int, 392 | # vision 393 | image_resolution: int, 394 | vision_layers: Union[Tuple[int, int, int, int], int], 395 | vision_width: int, 396 | vision_patch_size: int, 397 | # text 398 | context_length: int, 399 | vocab_size: int, 400 | transformer_width: int, 401 | transformer_heads: int, 402 | transformer_layers: int, 403 | ): 404 | super().__init__() 405 | 406 | self.context_length = context_length 407 | 408 | if isinstance(vision_layers, (tuple, list)): 409 | vision_heads = vision_width * 32 // 64 410 | self.visual = ModifiedResNet( 411 | layers=vision_layers, 412 | output_dim=embed_dim, 413 | heads=vision_heads, 414 | input_resolution=image_resolution, 415 | width=vision_width, 416 | ) 417 | else: 418 | vision_heads = vision_width // 64 419 | self.visual = VisionTransformer( 420 | input_resolution=image_resolution, 421 | patch_size=vision_patch_size, 422 | width=vision_width, 423 | layers=vision_layers, 424 | heads=vision_heads, 425 | output_dim=embed_dim, 426 | ) 427 | 428 | self.transformer = Transformer( 429 | width=transformer_width, 430 | layers=transformer_layers, 431 | heads=transformer_heads, 432 | attn_mask=self.build_attention_mask(), 433 | ) 434 | 435 | self.vocab_size = vocab_size 436 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 437 | self.positional_embedding = nn.Parameter( 438 | torch.empty(self.context_length, transformer_width) 439 | ) 440 | self.ln_final = LayerNorm(transformer_width) 441 | 442 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 443 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 444 | 445 | 
        self.initialize_parameters()

    def initialize_parameters(self):
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)

        if isinstance(self.visual, ModifiedResNet):
            if self.visual.attnpool is not None:
                std = self.visual.attnpool.c_proj.in_features**-0.5
                nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
                nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)

            for resnet_block in [
                self.visual.layer1,
                self.visual.layer2,
                self.visual.layer3,
                self.visual.layer4,
            ]:
                for name, param in resnet_block.named_parameters():
                    if name.endswith("bn3.weight"):
                        nn.init.zeros_(param)

        proj_std = (self.transformer.width**-0.5) * (
            (2 * self.transformer.layers) ** -0.5
        )
        attn_std = self.transformer.width**-0.5
        fc_std = (2 * self.transformer.width) ** -0.5
        for block in self.transformer.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)

        if self.text_projection is not None:
            nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)

    def build_attention_mask(self):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # zero out the lower diagonal
        return mask

    @property
    def dtype(self):
        return self.visual.conv1.weight.dtype

    def encode_image(self, image):
        return self.visual(image.type(self.dtype))

    def get_patch_encodings(self, image) -> torch.Tensor:
        """Get the encodings for each patch in the image"""
        return self.visual(image.type(self.dtype), patch_output=True)

    def get_image_encoder_projection(self) -> nn.Parameter:
        """Get vision transformer projection matrix."""
        assert isinstance(self.visual, VisionTransformer)
        return self.visual.proj

    def encode_text(self, text):
        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.type(self.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x).type(self.dtype)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

    def forward(self, image, text):
        image_features = self.encode_image(image)
        text_features = self.encode_text(text)

        # normalized features
        image_features = image_features / image_features.norm(dim=1, keepdim=True)
        text_features = text_features / text_features.norm(dim=1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape = [global_batch_size, global_batch_size]
        return logits_per_image, logits_per_text


def convert_weights(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.half()
            if l.bias is not None:
                l.bias.data = l.bias.data.half()

        if isinstance(l, nn.MultiheadAttention):
            for attr in [
                *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
                "in_proj_bias",
                "bias_k",
                "bias_v",
            ]:
                tensor = getattr(l, attr)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ["text_projection", "proj"]:
            if hasattr(l, name):
                attr = getattr(l, name)
                if attr is not None:
                    attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)


def build_model(state_dict: dict):
    vit = "visual.proj" in state_dict

    if vit:
        vision_width = state_dict["visual.conv1.weight"].shape[0]
        vision_layers = len(
            [
                k
                for k in state_dict.keys()
                if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
            ]
        )
        vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
        grid_size = round(
            (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5
        )
        image_resolution = vision_patch_size * grid_size
    else:
        counts: list = [
            len(
                set(
                    k.split(".")[2]
                    for k in state_dict
                    if k.startswith(f"visual.layer{b}")
                )
            )
            for b in [1, 2, 3, 4]
        ]
        vision_layers = tuple(counts)
        vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
        output_width = round(
            (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5
        )
        vision_patch_size = None
        assert (
            output_width**2 + 1
            == state_dict["visual.attnpool.positional_embedding"].shape[0]
        )
        image_resolution = output_width * 32

    embed_dim = state_dict["text_projection"].shape[1]
    context_length = state_dict["positional_embedding"].shape[0]
    vocab_size = state_dict["token_embedding.weight"].shape[0]
    transformer_width = state_dict["ln_final.weight"].shape[0]
    transformer_heads = transformer_width // 64
    transformer_layers = len(
        set(
            k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")
        )
    )

    model = CLIP(
        embed_dim,
        image_resolution,
        vision_layers,
        vision_width,
        vision_patch_size,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers,
    )

    for key in ["input_resolution", "context_length", "vocab_size"]:
        if key in state_dict:
            del state_dict[key]

    convert_weights(model)
    model.load_state_dict(state_dict)
    return model.eval()
--------------------------------------------------------------------------------
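For orientation, a minimal usage sketch of this file's dense-feature path follows. It is not part of the repository: the checkpoint filename, the TorchScript-archive loading step, and the import path (which assumes the repo root is on `PYTHONPATH`) are assumptions; only `build_model`, `encode_image`, `get_patch_encodings`, and the `patch_output`/`return_dict` flags come from the code above.

```python
# Minimal sketch (not part of the repo): dense patch-level CLIP features from a stock ViT checkpoint.
import torch

from models.clip.model import build_model  # assumption: repo root is on PYTHONPATH

# OpenAI's released CLIP weights ship as TorchScript archives; build_model() only needs the state dict.
jit_archive = torch.jit.load("ViT-B-16.pt", map_location="cpu")  # hypothetical local checkpoint path
model = build_model(jit_archive.state_dict()).float().eval()     # .float() undoes convert_weights() for CPU use

images = torch.randn(2, 3, 224, 224)  # stand-in for a batch of CLIP-normalized RGB images

with torch.no_grad():
    global_feats = model.encode_image(images)        # [2, embed_dim] global image embeddings
    patch_feats = model.get_patch_encodings(images)  # [2, (224 // patch_size) ** 2, embed_dim] dense features
    feats_by_layer = model.visual(
        images.type(model.dtype), patch_output=True, return_dict=True
    )  # {"res1"..."res5": intermediate block activations, "final": projected dense feature map}
```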