├── algos
│   ├── __init__.py
│   ├── goal_algos.py
│   ├── contact_algos.py
│   ├── traj_algos.py
│   └── traj_optimizer.py
├── models
│   ├── __init__.py
│   ├── clip
│   │   ├── __init__.py
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   ├── README.md
│   │   ├── interpolate.py
│   │   ├── simple_tokenizer.py
│   │   ├── clip.py
│   │   └── model.py
│   ├── contact.py
│   ├── perceiver.py
│   ├── temporal.py
│   ├── goal.py
│   ├── attention.py
│   ├── feature_extractors.py
│   └── helpers.py
├── assets
│   └── teaser.png
├── requirements.txt
├── .gitignore
├── scripts
│   ├── download_ckpt_testdata.sh
│   ├── prepare_third_party_modules.sh
│   └── test_demo.sh
├── .gitmodules
├── config
│   └── test_config.yaml
├── LICENSE
├── diffuser_utils
│   └── guidance_params.py
├── README.md
└── demos
    ├── optimize_affordance.py
    └── infer_affordance.py
/algos/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ethz-mrl/VidBot/HEAD/assets/teaser.png
--------------------------------------------------------------------------------
/models/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
3 | """
4 | Modified from https://github.com/openai/CLIP
5 | """
6 |
--------------------------------------------------------------------------------
/models/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ethz-mrl/VidBot/HEAD/models/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/models/clip/README.md:
--------------------------------------------------------------------------------
1 | # CLIP
2 | Modified version of [CLIP](https://github.com/openai/CLIP) with support for dense patch-level feature extraction
3 | (based on [MaskCLIP](https://arxiv.org/abs/2112.01071) parametrization) and interpolation of the positional encoding.
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu117
2 | torch==1.13.1+cu117
3 | torchvision==0.14.1+cu117
4 | numpy==1.26.3
5 | transformers==4.26.1
6 | open3d
7 | opencv-python
8 | omegaconf
9 | easydict
10 | flow_vis
11 | numba
12 | scikit-image
13 | transformations
14 | ftfy
15 | einops
16 | gdown
17 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | wandb/
2 | out/
3 | ckpt/
4 |
5 | # Byte-compiled / optimized / DLL files
6 | __pycache__/
7 | *.py[cod]
8 |
9 | .vscode/*
10 |
11 | *.egg-info
12 | wandb/*
13 | logs/*
14 | datasets/*.json
15 | datasets/labels/*
16 | *.npz
17 | *.npy
18 | .tmp/
19 | results/
20 | # third_party/
21 | pretrained/
22 | *.zip
23 | vis/
24 | inference_results/
25 | *.blend*
26 | manip_render/
27 | datasets
28 | pretrained_weights/
--------------------------------------------------------------------------------
/scripts/download_ckpt_testdata.sh:
--------------------------------------------------------------------------------
1 | # mkdir -p datasets
2 | gdown https://drive.google.com/uc?id=1IDCJ-wB05sMVKdLiG0IO_OajsY0ihvpi
3 | unzip vidbot_data_demo.zip -d datasets/
4 | rm vidbot_data_demo.zip
5 |
6 | gdown https://drive.google.com/uc?id=14ByEGX4zKB7VjIE7fybXiq11kpRvvMqz
7 | unzip epickitchens_traj_demo.zip -d datasets/
8 | rm epickitchens_traj_demo.zip
9 |
10 | gdown https://drive.google.com/uc?id=1pBZHU65WDwqcZTAm2TpBsLmhn6S9M82l
11 | unzip pretrained.zip
12 | rm pretrained.zip
13 |
14 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "third_party/GroundingDINO"]
2 | path = third_party/GroundingDINO
3 | url = https://github.com/IDEA-Research/GroundingDINO.git
4 | [submodule "third_party/EfficientSAM"]
5 | path = third_party/EfficientSAM
6 | url = https://github.com/yformer/EfficientSAM.git
7 | [submodule "third_party/graspnetAPI"]
8 | path = third_party/graspnetAPI
9 | url = git@github.com:HanzhiC/graspnetAPI.git
10 | [submodule "third_party/graspness_unofficial"]
11 | path = third_party/graspness_unofficial
12 | url = https://github.com/HanzhiC/graspness_unofficial.git
13 |
--------------------------------------------------------------------------------
/config/test_config.yaml:
--------------------------------------------------------------------------------
1 | # Dataset and Gripper Mesh
2 | dataset_dir: ./datasets/
3 | gripper_mesh_file: ./assets/panda_hand_mesh.obj
4 |
5 | # VidBot Affordance Prediction Models
6 | config_traj: ./pretrained/traj/config.yaml
7 | config_goal: ./pretrained/goal/config.yaml
8 | config_contact: ./pretrained/contact/config.yaml
9 | traj_ckpt: ./pretrained/traj/final.ckpt
10 | goal_ckpt: ./pretrained/goal/final.ckpt
11 | contact_ckpt: ./pretrained/contact/final.ckpt
12 |
13 | # Optional: Open-world Detectors and Grasp Detector
14 | config_detector: ./third_party/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py
15 | esam_ckpt: ./third_party/EfficientSAM/weights/efficient_sam_vitt.pt
16 | graspnet_ckpt: ./third_party/graspness_unofficial/weights/minkuresunet_kinect.tar
17 | detector_ckpt: ./third_party/GroundingDINO/weights/groundingdino_swint_ogc.pth
18 |
--------------------------------------------------------------------------------
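The demo scripts read these paths from `test_config.yaml`. A minimal loading sketch using `omegaconf` (listed in `requirements.txt`); the snippet is illustrative and not a file in the repository:

```python
# Minimal sketch: load config/test_config.yaml with OmegaConf and read a few
# of the paths defined above. Illustrative only, not part of the repository.
from omegaconf import OmegaConf

cfg = OmegaConf.load("config/test_config.yaml")
print(cfg.dataset_dir)    # ./datasets/
print(cfg.traj_ckpt)      # ./pretrained/traj/final.ckpt
print(cfg.detector_ckpt)  # ./third_party/GroundingDINO/weights/groundingdino_swint_ogc.pth
```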
/scripts/prepare_third_party_modules.sh:
--------------------------------------------------------------------------------
1 | # Clone the third-party modules
2 | echo "Cloning third-party modules"
3 | git submodule update --init --recursive
4 |
5 | # EfficientSAM weight is already downloaded via cloning the EfficientSAM repo
6 | echo "EfficientSAM weight is already downloaded via cloning the EfficientSAM repo"
7 |
8 | # Download the weight of the GroundingDINO
9 | echo "Downloading GroundingDINO weight"
10 | mkdir -p third_party/GroundingDINO/weights && cd third_party/GroundingDINO/weights
11 | wget -q https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth
12 | cd ../../../
13 |
14 | # Download the weight of the GraspNet
15 | echo "Downloading GraspNet weight"
16 | mkdir -p third_party/graspness_unofficial/weights && cd third_party/graspness_unofficial/weights
17 | gdown https://drive.google.com/uc?id=10o5fc8LQsbI8H0pIC2RTJMNapW9eczqF
18 | cd ../../../
19 |
20 | echo "Done with preparing third-party modules"
21 | echo "Follow their instructions for installation!"
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Hanzhi Chen
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/diffuser_utils/guidance_params.py:
--------------------------------------------------------------------------------
1 | PARAMS1 = {
2 | "goal_weight": 100.0,
3 | "noncollide_weight": 200.0,
4 | "normal_weight": 200.0,
5 | "contact_weight": 0.0,
6 | "fine_voxel_resolution": 32,
7 | "exclude_object_points": True,
8 | }
9 |
10 | PARAMS2 = {
11 | "goal_weight": 100.0,
12 | "noncollide_weight": 200.0,
13 | "normal_weight": 10.0,
14 | "contact_weight": 0.0,
15 | "fine_voxel_resolution": 128,
16 | "exclude_object_points": False,
17 | }
18 |
19 | PARAMS3 = {
20 | "goal_weight": 100.0,
21 | "noncollide_weight": 500.0,
22 | "normal_weight": 0.0,
23 | "contact_weight": 500.0,
24 | "fine_voxel_resolution": 128,
25 | "exclude_object_points": True,
26 | }
27 |
28 | GUIDANCE_PARAMS_DICT = {
29 | "open": PARAMS1,
30 | "close": PARAMS1,
31 | "pull": PARAMS1,
32 | "push": PARAMS1,
33 | "press": PARAMS1,
34 | "pick": PARAMS2,
35 | "pickup": PARAMS2,
36 | "take": PARAMS2,
37 | "get": PARAMS2,
38 | "put": PARAMS2,
39 | "place": PARAMS2,
40 | "putdown": PARAMS2,
41 | "drop": PARAMS2,
42 | "wipe": PARAMS3,
43 | "move": PARAMS3,
44 | "other": PARAMS2,
45 |
46 | }
47 |
48 | COMMON_ACTIONS = [
49 | "open",
50 | "close",
51 | "pull",
52 | "push",
53 | "press",
54 | "pick",
55 | "pickup",
56 | "take",
57 | "get",
58 | "put",
59 | "place",
60 | "putdown",
61 | "drop",
62 | "wipe",
63 | "move",
64 | ]
65 |
--------------------------------------------------------------------------------
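The mapping above is keyed by the instruction's verb, with verbs outside `COMMON_ACTIONS` falling back to the `"other"` entry, mirroring the check in `algos/goal_algos.py`, `algos/contact_algos.py`, and `algos/traj_algos.py`. A small sketch of that lookup (the helper `guidance_params_for` is hypothetical, not part of the repository):

```python
# Hypothetical helper: resolve guidance parameters for an instruction verb,
# falling back to the generic "other" entry as the algos modules do.
from diffuser_utils.guidance_params import COMMON_ACTIONS, GUIDANCE_PARAMS_DICT


def guidance_params_for(action_text: str) -> dict:
    verb = action_text.split(" ")[0]          # e.g. "pickup sponge" -> "pickup"
    if verb not in COMMON_ACTIONS:
        verb = "other"
    return GUIDANCE_PARAMS_DICT[verb]


print(guidance_params_for("pickup sponge")["fine_voxel_resolution"])  # 128
print(guidance_params_for("wipe table")["contact_weight"])            # 500.0
```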
/algos/goal_algos.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import copy
3 |
4 | import torch
5 | import torch.nn as nn
6 | import pytorch_lightning as pl
7 | import torch.nn.functional as F
8 | from models.goal import GoalPredictor
9 | import diffuser_utils.dataset_utils as DatasetUtils
10 | from diffuser_utils.guidance_params import COMMON_ACTIONS
11 | import pandas as pd
12 |
13 |
14 | class GoalPredictorModule(pl.LightningModule):
15 | def __init__(self, algo_config):
16 | super(GoalPredictorModule, self).__init__()
17 | self.algo_config = algo_config
18 | self.nets = nn.ModuleDict()
19 | policy_kwargs = algo_config.model
20 |
21 | self.nets["policy"] = GoalPredictor(**policy_kwargs)
22 |
23 | @torch.no_grad()
24 | def encode_action(self, data_batch, clip_model, max_length=20):
25 | action_tokens, action_feature = DatasetUtils.encode_text_clip(
26 | clip_model,
27 | [data_batch["action_text"]],
28 | max_length=max_length,
29 | device="cuda",
30 | )
31 |
32 |         action_tokens = action_tokens.to(self.device)
33 |         action_feature = action_feature.to(self.device)
34 |
35 | action_text = data_batch["action_text"]
36 | verb_text = action_text.split(" ")[0]
37 | if verb_text not in COMMON_ACTIONS:
38 | verb_text = "other"
39 | else:
40 | verb_text = verb_text.replace("-", "")
41 | verb_text = [verb_text]
42 |
43 | verb_tokens, verb_feature = DatasetUtils.encode_text_clip(
44 | clip_model,
45 | verb_text,
46 | max_length=max_length,
47 | device="cuda",
48 | )
49 |
50 |         verb_tokens = verb_tokens.to(self.device)
51 |         verb_feature = verb_feature.to(self.device)
52 |
53 | data_batch.update({"action_feature": action_feature.float()})
54 | data_batch.update({"verb_feature": verb_feature.float()})
55 |
56 | def forward(self, data_batch, training=False):
57 | # self.encode_action(data_batch)
58 | curr_policy = self.nets["policy"]
59 | outputs = curr_policy(data_batch, training)
60 | return outputs
61 |
--------------------------------------------------------------------------------
/algos/contact_algos.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import copy
3 |
4 | import torch
5 | import torch.nn as nn
6 | import pytorch_lightning as pl
7 | import torch.nn.functional as F
8 |
9 | import diffuser_utils.dataset_utils as DatasetUtils
10 | from models.contact import ContactPredictor
11 | from diffuser_utils.guidance_params import COMMON_ACTIONS
12 |
13 |
14 | class ContactPredictorModule(pl.LightningModule):
15 | def __init__(self, algo_config):
16 | super(ContactPredictorModule, self).__init__()
17 | self.algo_config = algo_config
18 | self.nets = nn.ModuleDict()
19 |
20 | # Initialize the contact former
21 | policy_kwargs = algo_config.model
22 |
23 | self.nets["policy"] = ContactPredictor(**policy_kwargs)
24 |
25 | @torch.no_grad()
26 | def encode_action(self, data_batch, clip_model, max_length=20):
27 | action_tokens, action_feature = DatasetUtils.encode_text_clip(
28 | clip_model,
29 | [data_batch["action_text"]],
30 | max_length=max_length,
31 | device="cuda",
32 | )
33 |
34 |         action_tokens = action_tokens.to(self.device)
35 |         action_feature = action_feature.to(self.device)
36 |
37 | action_text = data_batch["action_text"]
38 | verb_text = action_text.split(" ")[0]
39 | if verb_text not in COMMON_ACTIONS:
40 | verb_text = "other"
41 | else:
42 | verb_text = verb_text.replace("-", "")
43 | verb_text = [verb_text]
44 |
45 | verb_tokens, verb_feature = DatasetUtils.encode_text_clip(
46 | clip_model,
47 | verb_text,
48 | max_length=max_length,
49 | device="cuda",
50 | )
51 |
52 |         verb_tokens = verb_tokens.to(self.device)
53 |         verb_feature = verb_feature.to(self.device)
54 |
55 | data_batch.update({"action_feature": action_feature.float()})
56 | data_batch.update({"verb_feature": verb_feature.float()})
57 |
58 | def forward(self, data_batch, training=False):
59 | # self.encode_action(data_batch)
60 | curr_policy = self.nets["policy"]
61 | outputs = curr_policy(data_batch, training)
62 | return outputs
63 |
--------------------------------------------------------------------------------
/scripts/test_demo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Exit on error, undefined variables, and pipeline failures
4 | set -euo pipefail
5 |
6 | # Check if the --use_graspnet flag is passed as an argument
7 | USE_GRASPNET=false
8 | if [[ "$#" -gt 0 && "$1" == "--use_graspnet" ]]; then
9 | USE_GRASPNET=true
10 | fi
11 |
12 | # Define the --use_graspnet option based on the variable
13 | graspnet_option=""
14 | if [ "$USE_GRASPNET" = true ]; then
15 | graspnet_option="--use_graspnet"
16 | fi
17 |
18 | # Define array of commands with conditional --use_graspnet
19 | commands=(
20 | "python demos/infer_affordance.py -v -f 0 -i 'pickup sponge' --load_results --no_save $graspnet_option"
21 | "python demos/infer_affordance.py -v -f 0 -i 'pickup brush' --load_results --no_save $graspnet_option"
22 | "python demos/infer_affordance.py -v -f 1 -i 'place sponge' --load_results --no_save $graspnet_option"
23 | "python demos/infer_affordance.py -v -f 2 -i 'take mug' --load_results --no_save $graspnet_option"
24 | "python demos/infer_affordance.py -v -f 3 -i 'take kettle' --load_results --no_save $graspnet_option"
25 | "python demos/infer_affordance.py -v -f 4 -i 'take paper' --load_results --no_save $graspnet_option"
26 | "python demos/infer_affordance.py -v -f 5 -i 'pickup bottle' --load_results --no_save $graspnet_option"
27 | "python demos/infer_affordance.py -v -f 6 -i 'push door' --load_results --no_save $graspnet_option"
28 | "python demos/infer_affordance.py -v -f 7 -i 'open cabinet' --load_results --no_save $graspnet_option"
29 | "python demos/infer_affordance.py -v -f 8 -i 'close cabinet' --load_results --no_save $graspnet_option"
30 | "python demos/infer_affordance.py -v -f 9 -i 'pickup driller' --load_results --no_save $graspnet_option"
31 | "python demos/infer_affordance.py -v -f 10 -i 'place driller' --load_results --no_save $graspnet_option"
32 | )
33 |
34 | # Get number of commands
35 | num_commands=${#commands[@]}
36 |
37 | # Generate shuffled indices
38 | mapfile -t shuffled_indices < <(shuf -i 0-$((num_commands - 1)))
39 |
40 | # Execute commands in shuffled order
41 | for index in "${shuffled_indices[@]}"; do
42 | echo ">> Running: ${commands[$index]}"
43 | bash -c "${commands[$index]}"
44 | done
--------------------------------------------------------------------------------
/models/clip/interpolate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 |
5 | def interpolate_positional_embedding(
6 | positional_embedding: torch.Tensor, x: torch.Tensor, patch_size: int, w: int, h: int
7 | ):
8 | """
9 | Interpolate the positional encoding for CLIP to the number of patches in the image given width and height.
10 | Modified from DINO ViT `interpolate_pos_encoding` method.
11 | https://github.com/facebookresearch/dino/blob/7c446df5b9f45747937fb0d72314eb9f7b66930a/vision_transformer.py#L174
12 | """
13 | assert positional_embedding.ndim == 2, "pos_encoding must be 2D"
14 |
15 | # Number of patches in input
16 | num_patches = x.shape[1] - 1
17 | # Original number of patches for square images
18 | num_og_patches = positional_embedding.shape[0] - 1
19 |
20 | if num_patches == num_og_patches and w == h:
21 | # No interpolation needed
22 | return positional_embedding.to(x.dtype)
23 |
24 | dim = x.shape[-1]
25 | class_pos_embed = positional_embedding[:1] # (1, dim)
26 | patch_pos_embed = positional_embedding[1:] # (num_og_patches, dim)
27 |
28 | # Compute number of tokens
29 | w0 = w // patch_size
30 | h0 = h // patch_size
31 | assert w0 * h0 == num_patches, "Number of patches does not match"
32 |
33 | # Add a small number to avoid floating point error in the interpolation
34 | # see discussion at https://github.com/facebookresearch/dino/issues/8
35 | w0, h0 = w0 + 0.1, h0 + 0.1
36 |
37 | # Interpolate
38 | patch_per_ax = int(np.sqrt(num_og_patches))
39 | patch_pos_embed_interp = torch.nn.functional.interpolate(
40 | patch_pos_embed.reshape(1, patch_per_ax, patch_per_ax, dim).permute(0, 3, 1, 2),
41 | # (1, dim, patch_per_ax, patch_per_ax)
42 | scale_factor=(w0 / patch_per_ax, h0 / patch_per_ax),
43 | mode="bicubic",
44 | align_corners=False,
45 | recompute_scale_factor=False,
46 | ) # (1, dim, w0, h0)
47 | assert (
48 | int(w0) == patch_pos_embed_interp.shape[-2] and int(h0) == patch_pos_embed_interp.shape[-1]
49 | ), "Interpolation error."
50 |
51 | patch_pos_embed_interp = patch_pos_embed_interp.permute(0, 2, 3, 1).reshape(-1, dim) # (w0 * h0, dim)
52 | # Concat class token embedding and interpolated patch embeddings
53 | pos_embed_interp = torch.cat([class_pos_embed, patch_pos_embed_interp], dim=0) # (w0 * h0 + 1, dim)
54 | return pos_embed_interp.to(x.dtype)
55 |
--------------------------------------------------------------------------------
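A usage sketch for `interpolate_positional_embedding`, assuming a ViT-B/16-style CLIP encoder (14x14 patch grid plus a class token, width 768) and a non-square 448x224 input; all tensors below are dummies for illustration:

```python
# Dummy-tensor sketch: interpolate a 14x14 (+ class token) positional embedding
# to a 28x14 patch grid, as produced by a 448x224 input with patch size 16.
import torch

from models.clip.interpolate import interpolate_positional_embedding

patch_size, w, h, dim = 16, 448, 224, 768
pos_embed = torch.randn(14 * 14 + 1, dim)            # (num_og_patches + 1, dim)
num_patches = (w // patch_size) * (h // patch_size)  # 28 * 14 = 392
x = torch.randn(1, num_patches + 1, dim)             # class token + patch tokens

pos_embed_interp = interpolate_positional_embedding(pos_embed, x, patch_size, w, h)
print(pos_embed_interp.shape)                        # torch.Size([393, 768])
```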
/models/contact.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import einops
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import tqdm
7 | import numpy as np
8 |
9 | from models.layers_2d import ResNet50Encoder, ResNet50Decoder, PositionalEmbeddingV2
10 | import einops
11 | import torchvision.transforms as transforms
12 | from diffuser_utils.dataset_utils import compute_model_size
13 | from models.perceiver import FeaturePerceiver
14 |
15 |
16 | class ContactPredictor(pl.LightningModule):
17 | def __init__(
18 | self,
19 | in_channels=3,
20 | out_channels=2,
21 | use_skip=True,
22 | encode_action=False,
23 | use_min_loss=False,
24 | **kwargs,
25 | ):
26 | super(ContactPredictor, self).__init__()
27 | self.in_channels = in_channels
28 | self.out_channels = out_channels
29 | self.use_skip = use_skip
30 | self.encode_action = encode_action
31 | self.use_min_loss = use_min_loss
32 | self.transform = transforms.Compose(
33 | [
34 | transforms.Normalize([0.485, 0.456, 0.406], [
35 | 0.229, 0.224, 0.225]),
36 | ]
37 | )
38 | self.visual = ResNet50Encoder(input_channels=in_channels)
39 | self.decoder = ResNet50Decoder(
40 | output_channels=out_channels, use_skip=use_skip)
41 | self.bottleneck_feature_key = self.decoder.bottleneck_feature_key
42 | self.latent_dim = self.decoder.latent_dim
43 | self.visual_proj = nn.Linear(self.latent_dim, 512)
44 | # self.proj = nn.Linear(self.latent_dim, 512)
45 |
46 | if self.encode_action:
47 | self.action_proj = nn.Linear(512, 512)
48 | self.action_fuser = FeaturePerceiver(
49 | transition_dim=512, condition_dim=512, time_emb_dim=0
50 | )
51 | self.final_proj = nn.Linear(self.action_fuser.last_dim, 512)
52 | else:
53 | print("No action encoding")
54 |
55 | self.fuser = nn.Sequential(
56 | nn.TransformerEncoderLayer(
57 | d_model=512, nhead=4, dim_feedforward=512, batch_first=True
58 | ),
59 | nn.Linear(512, self.latent_dim),
60 | )
61 | self.positional_embedding = PositionalEmbeddingV2(
62 | d_model=512, max_len=400)
63 |
64 | def forward(self, data_batch, training=False):
65 | outputs = {}
66 | object_color_key = "object_color"
67 | object_depth_key = "object_depth"
68 | if training:
69 | object_color_key += "_aug"
70 |
71 | inputs = self.transform(data_batch[object_color_key])
72 | if self.in_channels == 4:
73 | object_depth = data_batch[object_depth_key][:, None]
74 | inputs = torch.cat([inputs, object_depth], dim=1)
75 | features = self.visual(inputs)
76 | features = self.forward_latent(data_batch, features)
77 | pred = self.decoder(features)
78 |
79 | # Post-process the prediction
80 | pred_final = []
81 | pred_vfs = pred[:, : self.out_channels - 1] # [B, 8, H, W]
82 | pred_mask = pred[:, self.out_channels - 1:] # [B, 1, H, W]
83 | for hi in range(0, self.out_channels - 1, 2):
84 | pred_vf = pred_vfs[:, hi: hi + 2] # [B, 2, H, W]
85 | pred_vf = F.normalize(pred_vf, p=2, dim=1) # [-1, 1]
86 | pred_vf = pred_vf.clamp(-1, 1)
87 | pred_final.append(pred_vf)
88 | pred_final.append(pred_mask)
89 | pred_final = torch.cat(pred_final, dim=1)
90 | outputs["pred"] = pred_final # [B, 8+1, H, W]
91 | return outputs
92 |
93 | def forward_latent(self, data_batch, features):
94 | latent = features[self.bottleneck_feature_key]
95 | h, w = latent.shape[-2:]
96 | latent = einops.rearrange(latent, "b c h w -> b (h w) c")
97 | latent = self.visual_proj(latent)
98 | # latent = self.proj(latent)
99 | if self.encode_action:
100 | action_feature = data_batch["action_feature"][:, None]
101 | action_feature = self.action_proj(action_feature)
102 | latent = self.action_fuser(latent, action_feature)
103 | latent = self.final_proj(latent)
104 |
105 | latent = self.positional_embedding(latent)
106 | latent = self.fuser(latent)
107 | latent = einops.rearrange(latent, "b (h w) c -> b c h w", h=h, w=w)
108 | features[self.bottleneck_feature_key] = latent
109 | return features
110 |
--------------------------------------------------------------------------------
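For reference, the post-processing at the end of `ContactPredictor.forward` treats the first `out_channels - 1` decoder channels as 2D vector fields that are L2-normalized pair-wise, and the last channel as a mask logit. A dummy-tensor sketch of that step (the choice of `out_channels=9`, i.e. four vector fields plus one mask, is an assumption for illustration):

```python
# Dummy-tensor sketch of the vector-field post-processing in ContactPredictor.forward.
import torch
import torch.nn.functional as F

out_channels = 9                                 # assumed: 4 vector fields (8 ch) + 1 mask
pred = torch.randn(2, out_channels, 64, 64)      # stand-in for the decoder output

pred_final = []
pred_vfs = pred[:, : out_channels - 1]           # [B, 8, H, W]
pred_mask = pred[:, out_channels - 1:]           # [B, 1, H, W]
for hi in range(0, out_channels - 1, 2):
    vf = F.normalize(pred_vfs[:, hi: hi + 2], p=2, dim=1).clamp(-1, 1)
    pred_final.append(vf)                        # unit-norm 2D vectors in [-1, 1]
pred_final.append(pred_mask)
pred_final = torch.cat(pred_final, dim=1)        # [B, 9, H, W]
print(pred_final.shape)
```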
/models/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from collections.abc import Sequence
5 | from functools import lru_cache
6 |
7 | import ftfy
8 | import regex as re
9 |
10 |
11 | @lru_cache()
12 | def default_bpe():
13 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
14 |
15 |
16 | @lru_cache()
17 | def bytes_to_unicode():
18 | """
19 | Returns list of utf-8 byte and a corresponding list of unicode strings.
20 | The reversible bpe codes work on unicode strings.
21 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
22 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
23 |     This is a significant percentage of your normal, say, 32K bpe vocab.
24 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
25 | And avoids mapping to whitespace/control characters the bpe code barfs on.
26 | """
27 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
28 | cs = bs[:]
29 | n = 0
30 | for b in range(2**8):
31 | if b not in bs:
32 | bs.append(b)
33 | cs.append(2**8+n)
34 | n += 1
35 | cs = [chr(n) for n in cs]
36 | return dict(zip(bs, cs))
37 |
38 |
39 | def get_pairs(word):
40 | """Return set of symbol pairs in a word.
41 | Word is represented as tuple of symbols (symbols being variable-length strings).
42 | """
43 | pairs = set()
44 | prev_char = word[0]
45 | for char in word[1:]:
46 | pairs.add((prev_char, char))
47 | prev_char = char
48 | return pairs
49 |
50 |
51 | def basic_clean(text):
52 |     # Accept non-string inputs (e.g. a list of strings) by joining them
53 |     # into a single comma-separated string before cleaning.
54 |     if not isinstance(text, str):
55 |         text = ', '.join(text)
56 |
57 | text = ftfy.fix_text(text)
58 | text = html.unescape(html.unescape(text))
59 | return text.strip()
60 |
61 |
62 | def whitespace_clean(text):
63 | text = re.sub(r'\s+', ' ', text)
64 | text = text.strip()
65 | return text
66 |
67 |
68 | class SimpleTokenizer(object):
69 | def __init__(self, bpe_path: str = default_bpe()):
70 | self.byte_encoder = bytes_to_unicode()
71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
72 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
73 | merges = merges[1:49152-256-2+1]
74 | merges = [tuple(merge.split()) for merge in merges]
75 | vocab = list(bytes_to_unicode().values())
76 |         vocab = vocab + [v+'</w>' for v in vocab]
77 | for merge in merges:
78 | vocab.append(''.join(merge))
79 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
80 | self.encoder = dict(zip(vocab, range(len(vocab))))
81 | self.decoder = {v: k for k, v in self.encoder.items()}
82 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
83 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
84 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
85 |
86 | def bpe(self, token):
87 | if token in self.cache:
88 | return self.cache[token]
89 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
90 | pairs = get_pairs(word)
91 |
92 | if not pairs:
93 |             return token+'</w>'
94 |
95 | while True:
96 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
97 | if bigram not in self.bpe_ranks:
98 | break
99 | first, second = bigram
100 | new_word = []
101 | i = 0
102 | while i < len(word):
103 | try:
104 | j = word.index(first, i)
105 | new_word.extend(word[i:j])
106 | i = j
107 |                 except ValueError:
108 | new_word.extend(word[i:])
109 | break
110 |
111 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
112 | new_word.append(first+second)
113 | i += 2
114 | else:
115 | new_word.append(word[i])
116 | i += 1
117 | new_word = tuple(new_word)
118 | word = new_word
119 | if len(word) == 1:
120 | break
121 | else:
122 | pairs = get_pairs(word)
123 | word = ' '.join(word)
124 | self.cache[token] = word
125 | return word
126 |
127 | def encode(self, text):
128 | bpe_tokens = []
129 | text = whitespace_clean(basic_clean(text)).lower()
130 | for token in re.findall(self.pat, text):
131 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
132 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
133 | return bpe_tokens
134 |
135 | def decode(self, tokens):
136 | text = ''.join([self.decoder[token] for token in tokens])
137 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
138 | return text
139 |
--------------------------------------------------------------------------------
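A quick round-trip sketch of how the tokenizer above is typically used; the instruction string is just an example:

```python
# Encode an instruction into BPE ids and decode it back.
from models.clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()
ids = tokenizer.encode("pickup sponge")
print(ids)                    # list of BPE token ids
print(tokenizer.decode(ids))  # "pickup sponge " (end-of-word markers become spaces)
```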
/algos/traj_algos.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import copy
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import pytorch_lightning as pl
8 | import torch.nn.functional as F
9 | import diffuser_utils.dataset_utils as DatasetUtils
10 | from models.diffuser import DiffuserModel
11 | from models.helpers import EMA
12 | import open3d as o3d
13 | from diffuser_utils.guidance_params import COMMON_ACTIONS
14 | import torchvision
15 |
16 |
17 | class TrajectoryDiffusionModule(pl.LightningModule):
18 | def __init__(self, algo_config):
19 | super(TrajectoryDiffusionModule, self).__init__()
20 | self.algo_config = algo_config
21 | self.nets = nn.ModuleDict()
22 |
23 | # Initialize the diffuser
24 | policy_kwargs = algo_config.model
25 | self.nets["policy"] = DiffuserModel(**policy_kwargs)
26 |
27 |
28 | @torch.no_grad()
29 | def encode_action(self, data_batch, clip_model, max_length=20):
30 | action_tokens, action_feature = DatasetUtils.encode_text_clip(
31 | clip_model,
32 | [data_batch["action_text"]],
33 | max_length=max_length,
34 | device="cuda",
35 | )
36 |
37 |         action_tokens = action_tokens.to(self.device)
38 |         action_feature = action_feature.to(self.device)
39 |
40 | action_text = data_batch["action_text"]
41 | verb_text = action_text.split(" ")[0]
42 | if verb_text not in COMMON_ACTIONS:
43 | verb_text = "other"
44 | else:
45 | verb_text = verb_text.replace("-", "")
46 | verb_text = [verb_text]
47 |
48 | verb_tokens, verb_feature = DatasetUtils.encode_text_clip(
49 | clip_model,
50 | verb_text,
51 | max_length=max_length,
52 | device="cuda",
53 | )
54 |
55 |         verb_tokens = verb_tokens.to(self.device)
56 |         verb_feature = verb_feature.to(self.device)
57 |
58 | data_batch.update({"action_feature": action_feature.float()})
59 | data_batch.update({"verb_feature": verb_feature.float()})
60 |
61 |
62 | def forward(
63 | self,
64 | data_batch,
65 | num_samp=1,
66 | return_diffusion=False,
67 | return_guidance_losses=False,
68 | apply_guidance=False,
69 | class_free_guide_w=0.0,
70 | guide_clean=False,
71 | ):
72 | curr_policy = self.nets["policy"]
73 | return curr_policy(
74 | data_batch,
75 | num_samp,
76 | return_diffusion=return_diffusion,
77 | return_guidance_losses=return_guidance_losses,
78 | apply_guidance=apply_guidance,
79 | class_free_guide_w=class_free_guide_w,
80 | guide_clean=guide_clean,
81 | )
82 |
83 | def visualize_trajectory_by_rendering(
84 | self,
85 | data_batch,
86 | config_path,
87 | window=False,
88 | return_vis=False,
89 | draw_grippers=False,
90 | **kwargs
91 | ):
92 | batch_size = len(data_batch["color"])
93 | results = []
94 | for i in range(batch_size):
95 | vis_o3d = []
96 | depth = data_batch["depth"][i].cpu().numpy()
97 | color = data_batch["color"][i].cpu().numpy().transpose(1, 2, 0)
98 | intr = data_batch["intrinsics"][i].cpu().numpy()
99 | # gt_traj = data_batch["gt_trajectory"][i].cpu().numpy()
100 |
101 | # backproject
102 | points_scene, scene_ids = DatasetUtils.backproject(
103 | depth,
104 | intr,
105 | depth > 0,
106 | # np.logical_and(hand_mask == 0, depth > 0),
107 | NOCS_convention=False,
108 | )
109 |
110 | colors_scene = color.copy()[scene_ids[0], scene_ids[1]]
111 | pcd_scene = DatasetUtils.visualize_points(
112 | points_scene, colors_scene)
113 | # gt_traj_vis = DatasetUtils.visualize_3d_trajectory(
114 | # gt_traj, size=0.02, cmap_name="viridis"
115 | # )
116 |
117 | vis_o3d = [pcd_scene] # + gt_traj_vis
118 | if "pred_trajectories" in data_batch:
119 | print("===> Visualizing pred trajectories")
120 | pred_trajs = data_batch["pred_trajectories"][i].cpu().numpy()
121 | pred_traj_colors = DatasetUtils.random_colors(len(pred_trajs))
122 | for pi, pred_traj in enumerate(pred_trajs):
123 | _pred_traj_vis = DatasetUtils.visualize_3d_trajectory(
124 | pred_traj, size=0.01, cmap_name="plasma"
125 | )
126 | if len(pred_trajs) > 1:
127 | _pred_traj_vis = [
128 | s.paint_uniform_color(pred_traj_colors[pi])
129 | for s in _pred_traj_vis
130 | ]
131 |
132 | pred_traj_vis = _pred_traj_vis[0]
133 | for ii in range(1, len(_pred_traj_vis)):
134 | pred_traj_vis += _pred_traj_vis[ii]
135 |
136 | vis_o3d += [pred_traj_vis]
137 |
138 | if window:
139 | o3d.visualization.draw(vis_o3d)
140 |
141 | if return_vis:
142 | return vis_o3d
143 |
144 | render_dist = np.median(np.linalg.norm(points_scene, axis=1))
145 | render_img = DatasetUtils.render_offscreen(
146 | vis_o3d,
147 | config_path,
148 | # dist=render_dist,
149 | resize_factor=0.5,
150 | )
151 | render_img = torchvision.transforms.ToTensor()(render_img)
152 | results.append(render_img)
153 | results = torch.stack(results, dim=0) # [B, C, H, W]
154 | data_batch.update({"pred_vis": results})
155 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # VidBot: Learning Generalizable 3D Actions from In-the-Wild 2D Human Videos for Zero-Shot Robotic Manipulation
3 |
4 | CVPR 2025
5 |
6 |
14 |
15 |
16 |
17 |
18 |
19 |
20 | This is the official repository of [**VidBot: Learning Generalizable 3D Actions from In-the-Wild 2D Human Videos for Zero-Shot Robotic Manipulation**](https://arxiv.org/abs/2503.07135). For more details, please check our [**project website**](https://hanzhic.github.io/vidbot-project/).
21 |
22 |
23 |
24 | ## Installation
25 |
26 | To install VidBot, follow these steps:
27 |
28 | 1. **Clone the Repository**:
29 | ```bash
30 | git clone https://github.com/HanzhiC/vidbot.git
31 | cd vidbot
32 | ```
33 |
34 | 2. **Install Dependencies**:
35 | ```bash
36 | # Prepare the environment
37 | conda create -n vidbot python=3.10.9
38 | conda activate vidbot
39 |
40 | # Ensure PyTorch 1.13.1 is installed; installing pytorch-lightning might change the PyTorch version
41 | pip install pytorch-lightning==1.8.6
42 | pip install -r requirements.txt
43 |
44 | # Install PyTorch Scatter
45 | wget https://data.pyg.org/whl/torch-1.13.0%2Bcu117/torch_scatter-2.1.1%2Bpt113cu117-cp310-cp310-linux_x86_64.whl
46 | pip install torch_scatter-2.1.1+pt113cu117-cp310-cp310-linux_x86_64.whl
47 | rm -rf torch_scatter-2.1.1+pt113cu117-cp310-cp310-linux_x86_64.whl
48 | ```
49 |
50 | 3. **Download Pretrained Weights and Demo Dataset**:
51 | ```bash
52 | sh scripts/download_ckpt_testdata.sh
53 | ```
54 | You can now try out VidBot with the demo data we've provided!
55 |
56 | 4. **(Optional) Install Third-Party Modules**:
57 | ```bash
58 | sh scripts/prepare_third_party_modules.sh
59 | ```
60 | Follow the installation instructions from [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO), [EfficientSAM](https://github.com/yformer/EfficientSAM), [GraspNet](https://github.com/graspnet/graspness_unofficial), and [GraspNetAPI](https://github.com/graspnet/graspnetAPI) to set up these third-party modules.
61 |
62 | **Note**: The `transformers` library should be version 4.26.1. Installing `GroundingDINO` might change the version. Installing `MinkowskiEngine` for GraspNet can be painful. However, our framework can still function without GraspNet. In such cases, we will employ a simplified method to obtain the grasp poses.
63 |
64 | ## Affordance Inference
65 |
66 | 1. **Quick Start with VidBot**: To quickly explore VidBot, you don't need to install any third-party modules. After downloading the weights and demo dataset, you can use our pre-saved bounding boxes to run the inference scripts with the following command:
67 | ```bash
68 | bash scripts/test_demo.sh
69 | ```
70 |
71 | 2. **Testing VidBot with Your Own Data**: To test VidBot using your own data, just put your collected dataset under the `./datasets/` folder. Please ensure your data is organized to match the structure of our demo dataset:
72 | ```text
73 | YOUR_DATASET_NAME/
74 | ├── camera_intrinsic.json
75 | ├── color
76 | │ ├── 000000.png
77 | │ ├── 000001.png
78 | │ ├── 00000X.png
79 | ├── depth
80 | │ ├── 000000.png
81 | │ ├── 000001.png
82 | │ ├── 00000X.png
83 | ```
84 | The `camera_intrinsic.json` file should be structured as follows:
85 | ```json
86 | {
87 | "width": width,
88 | "height": height,
89 | "intrinsic_matrix": [
90 | fx,
91 | 0,
92 | 0,
93 | 0,
94 | fy,
95 | 0,
96 | cx,
97 | cy,
98 | 1
99 | ]
100 | }
101 | ```
102 | **We recommend using an image resolution of 1280x720.**
103 |
104 | 3. **Run the Inference Script**: To run tests with your own data, execute the following command, ensuring you understand the meaning of each input argument:
105 | ```bash
106 | python demos/infer_affordance.py \
107 |     --config ./config/test_config.yaml \
108 |     --dataset YOUR_DATASET_NAME \
109 |     --frame FRAME_ID \
110 |     --instruction YOUR_INSTRUCTION \
111 |     --object OBJECT_CLASS \
112 |     --visualize
113 | ```
114 | If you have installed GraspNet and wish to estimate the gripper pose, add the `--use_graspnet` option to the command.
115 |
116 |
117 | ## Citation
118 |
119 | **If you find our work useful, please cite:**
120 |
121 | ```bibtex
122 | @inproceedings{chen2025vidbot,
123 | author = {Chen, Hanzhi and Sun, Boyang and Zhang, Anran and Pollefeys, Marc and Leutenegger, Stefan},
124 | title = {{VidBot}: Learning Generalizable 3D Actions from In-the-Wild 2D Human Videos for Zero-Shot Robotic Manipulation},
125 | booktitle = {Proceedings of the Computer Vision and Pattern Recognition Conference},
126 | year = {2025},
127 | }
128 | ```
129 |
130 | ## Acknowledgement
131 | Our codebase is built upon [TRACE](https://github.com/nv-tlabs/trace). Partial code is borrowed from [ConvONet](https://github.com/autonomousvision/convolutional_occupancy_networks), [afford-motion](https://github.com/afford-motion/afford-motion) and [rq-vae-transformer](https://github.com/kakaobrain/rq-vae-transformer). Thanks for their great contributions!
132 |
133 |
134 | ## License
135 |
136 | This project is licensed under the MIT License. See [LICENSE](LICENSE) for more details.
137 |
--------------------------------------------------------------------------------
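Relating to the `camera_intrinsic.json` format described in the README above: the 3x3 intrinsic matrix is stored flattened in the order (fx, 0, 0, 0, fy, 0, cx, cy, 1), i.e. column-major. A hedged sketch for writing such a file for a 1280x720 camera; the fx/fy/cx/cy values and the dataset name are placeholders:

```python
# Sketch: write camera_intrinsic.json in the layout shown in the README
# (column-major flattening of K). All intrinsic values below are placeholders.
import json

width, height = 1280, 720
fx, fy, cx, cy = 910.0, 910.0, 640.0, 360.0  # replace with your calibration

intrinsics = {
    "width": width,
    "height": height,
    "intrinsic_matrix": [fx, 0, 0, 0, fy, 0, cx, cy, 1],
}

with open("datasets/YOUR_DATASET_NAME/camera_intrinsic.json", "w") as f:
    json.dump(intrinsics, f, indent=4)
```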
/models/perceiver.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from typing import List, Optional, Tuple
5 | from collections import OrderedDict
6 | from einops import rearrange
7 | from models.layers_3d import SinusoidalPosEmb
8 | from models.attention import SelfAttentionBlock, CrossAttentionLayer
9 |
10 |
11 | class FeaturePerceiver(nn.Module):
12 |
13 | def __init__(
14 | self,
15 | transition_dim,
16 | condition_dim,
17 | time_emb_dim,
18 | encoder_q_input_channels=512,
19 | encoder_kv_input_channels=256,
20 | encoder_num_heads=8,
21 | encoder_widening_factor=1,
22 | encoder_dropout=0.1,
23 | encoder_residual_dropout=0.0,
24 | encoder_self_attn_num_layers=2,
25 | decoder_q_input_channels=256,
26 | decoder_kv_input_channels=512,
27 | decoder_num_heads=8,
28 | decoder_widening_factor=1,
29 | decoder_dropout=0.1,
30 | decoder_residual_dropout=0.0,
31 | ) -> None:
32 | super().__init__()
33 |
34 | self.encoder_q_input_channels = encoder_q_input_channels
35 | self.encoder_kv_input_channels = encoder_kv_input_channels
36 | self.encoder_num_heads = encoder_num_heads
37 | self.encoder_widening_factor = encoder_widening_factor
38 | self.encoder_dropout = encoder_dropout
39 | self.encoder_residual_dropout = encoder_residual_dropout
40 | self.encoder_self_attn_num_layers = encoder_self_attn_num_layers
41 |
42 | self.decoder_q_input_channels = decoder_q_input_channels
43 | self.decoder_kv_input_channels = decoder_kv_input_channels
44 | self.decoder_num_heads = decoder_num_heads
45 | self.decoder_widening_factor = decoder_widening_factor
46 | self.decoder_dropout = decoder_dropout
47 | self.decoder_residual_dropout = decoder_residual_dropout
48 |
49 | self.condition_adapter = nn.Linear(
50 | condition_dim, self.encoder_q_input_channels, bias=True
51 | )
52 |
53 | if time_emb_dim > 0:
54 | self.time_embedding_adapter = nn.Linear(
55 | time_emb_dim, self.encoder_q_input_channels, bias=True
56 | )
57 | else:
58 | self.time_embedding_adapter = None
59 |
60 | self.encoder_adapter = nn.Linear(
61 | transition_dim,
62 | self.encoder_kv_input_channels,
63 | bias=True,
64 | )
65 | self.decoder_adapter = nn.Linear(
66 | self.encoder_kv_input_channels, self.decoder_q_input_channels, bias=True
67 | )
68 |
69 | self.encoder_cross_attn = CrossAttentionLayer(
70 | num_heads=self.encoder_num_heads,
71 | num_q_input_channels=self.encoder_q_input_channels,
72 | num_kv_input_channels=self.encoder_kv_input_channels,
73 | widening_factor=self.encoder_widening_factor,
74 | dropout=self.encoder_dropout,
75 | residual_dropout=self.encoder_residual_dropout,
76 | )
77 |
78 | self.encoder_self_attn = SelfAttentionBlock(
79 | num_layers=self.encoder_self_attn_num_layers,
80 | num_heads=self.encoder_num_heads,
81 | num_channels=self.encoder_q_input_channels,
82 | widening_factor=self.encoder_widening_factor,
83 | dropout=self.encoder_dropout,
84 | residual_dropout=self.encoder_residual_dropout,
85 | )
86 |
87 | self.decoder_cross_attn = CrossAttentionLayer(
88 | num_heads=self.decoder_num_heads,
89 | num_q_input_channels=self.decoder_q_input_channels,
90 | num_kv_input_channels=self.decoder_kv_input_channels,
91 | widening_factor=self.decoder_widening_factor,
92 | dropout=self.decoder_dropout,
93 | residual_dropout=self.decoder_residual_dropout,
94 | )
95 | self.last_dim = self.decoder_q_input_channels
96 |
97 |
98 | def forward(
99 | self,
100 | x,
101 | condition_feat,
102 | time_embedding=None,
103 | ):
104 |         """Forward pass of the FeaturePerceiver.
105 |
106 |         Args:
107 |             x: input features, [bs, num_points, transition_dim]
108 |             condition_feat: condition features, [bs, 1, condition_dim]
109 |             time_embedding: optional time embedding, [bs, 1, time_emb_dim]
110 |
111 |         Returns:
112 |             Decoded features, [bs, num_points, decoder_q_input_channels]
113 |         """
114 |
115 | # encoder
116 | enc_kv = self.encoder_adapter(x) # [bs, num_points, enc_kv_dim]
117 | cond_feat = self.condition_adapter(condition_feat) # [bs, 1, enc_q_dim]
118 | if time_embedding is not None and self.time_embedding_adapter is not None:
119 | time_embedding = self.time_embedding_adapter(
120 | time_embedding
121 | ) # [bs, 1, enc_q_dim]
122 |
123 | enc_q = torch.cat([cond_feat, time_embedding], dim=1) # [bs, 1 + 1, enc_q_dim]
124 | else:
125 | enc_q = cond_feat
126 |
127 | enc_q = self.encoder_cross_attn(enc_q, enc_kv).last_hidden_state
128 | enc_q = self.encoder_self_attn(enc_q).last_hidden_state
129 |
130 | # decoder
131 | dec_kv = enc_q
132 | dec_q = self.decoder_adapter(enc_kv) # [bs, num_points, dec_q_dim]
133 | dec_q = self.decoder_cross_attn(
134 | dec_q, dec_kv
135 | ).last_hidden_state # [bs, num_points, dec_q_dim]
136 |
137 | return dec_q
138 |
139 |
140 | if __name__ == "__main__":
141 |     print("Testing FeaturePerceiver")
142 |     feat_perceiver = FeaturePerceiver(
143 | transition_dim=64 + 3,
144 | condition_dim=512,
145 | time_emb_dim=64,
146 | )
147 | x = torch.randn(2, 80, 64 + 3)
148 | cond = torch.randn(2, 1, 512)
149 | time = torch.randn(2, 1, 64)
150 |     feat = feat_perceiver(x, cond, time)
151 | print(feat.shape)
152 |
--------------------------------------------------------------------------------
/demos/optimize_affordance.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import open3d as o3d
4 | import torch
5 | from scipy.signal import savgol_filter
6 | from algos.traj_optimizer import TrajectoryOptimizer
7 | from diffuser_utils.dataset_utils import visualize_3d_trajectory, load_json, visualize_points, backproject, interpolate_trajectory
8 |
9 | def run(colmap_results, data, device, visualize=True):
10 | traj, vis, vis_scene = [], [], []
11 |
12 | # Prepare the tensors
13 | frame_ids = data["frame_ids"]
14 | intr = data["intr"]
15 | rgbs = data["rgbs"]
16 | depths = data["depths"]
17 | hand_masks = data["masks"]
18 | obj_masks = data["obj_masks"]
19 | hand_bboxes = data["hand_bboxes"]
20 | obj_bboxes = data["obj_bboxes"]
21 | rgb_tensors, depth_tensors, mask_tensors = [], [], []
22 | for ii, fi in enumerate(frame_ids):
23 | rgb = rgbs[ii] / 255.0
24 | depth = depths[ii]
25 | mask_hand = hand_masks[ii]
26 | mask_obj = obj_masks[ii]
27 | mask_hand = cv2.dilate(mask_hand, np.ones((5, 5), np.uint8), iterations=3)
28 | mask_obj = cv2.dilate(mask_obj, np.ones((5, 5), np.uint8), iterations=3)
29 | mask_dynamic = np.logical_or(mask_hand > 0, mask_obj > 0)
30 | mask_static = (1 - mask_dynamic).astype(np.float32)
31 | rgb_tensors.append(torch.from_numpy(rgb).permute(2, 0, 1).float())
32 | depth_tensors.append(torch.from_numpy(depth).unsqueeze(0).float())
33 | mask_tensors.append(torch.from_numpy(mask_static).unsqueeze(0).float())
34 | rgb_tensors = torch.stack(rgb_tensors).to(device)
35 | depth_tensors = torch.stack(depth_tensors).to(device)
36 | mask_tensors = torch.stack(mask_tensors).to(device)
37 | height, width = rgb_tensors.shape[-2:]
38 | intr = torch.from_numpy(intr).float().to(device)
39 |
40 | # Initialize the pose and scale optimizer
41 | traj_optimizer = TrajectoryOptimizer(
42 | resolution=(height, width),
43 | lr_scale_global=0.05,
44 | lr_scale=0.1,
45 | lr_pose=0.05,
46 | num_iters_scale=10,
47 | num_iters_pose=50,
48 | device=device,
49 | )
50 |
51 | # Optimize the global scale
52 | scale_init_tensors, scale_global_final, key_idx = traj_optimizer.optimize_global_scale(
53 | rgb_tensors,
54 | depth_tensors,
55 | mask_tensors,
56 | colmap_results,
57 | )
58 |
59 | # Optimize the pose and scale
60 | scale_init_tensors = scale_global_final * torch.ones_like(scale_init_tensors)
61 | T_kc_final, scale_final = traj_optimizer.optimize_pose(
62 | intr,
63 | rgb_tensors,
64 | depth_tensors,
65 | mask_tensors,
66 | scale_init_tensors,
67 | scale_global_final,
68 | colmap_results,
69 | key_idx=key_idx,
70 | optimize_pose=True,
71 | verbose=False,
72 | )
73 |
74 | # Acquire the optimized results
75 | T_kc_final = T_kc_final.detach().cpu().numpy()
76 | scale_final = scale_final.detach().cpu().numpy()
77 | intr_np = intr.clone().cpu().numpy()
78 |
79 | # Pose and scale of the first frame
80 | T_kc0 = T_kc_final[0]
81 | scale_m2c0 = scale_final[0]
82 | T_kc0[:3, 3] = T_kc0[:3, 3] / scale_m2c0
83 |
84 | for ii, fi in enumerate(frame_ids):
85 | # Pose and scale of the current frame
86 | T_kc = T_kc_final[ii]
87 | scale_m2c = scale_final[ii]
88 | T_kc[:3, 3] = T_kc[:3, 3] / scale_m2c
89 |
90 | # Transformation from current frame to the first frame
91 | T_c0c = np.linalg.inv(T_kc0) @ T_kc
92 |
93 | # Get the depth and hand bbox
94 | depth = depths[ii]
95 | hand_bbox = hand_bboxes[ii]
96 | hand_bbox_mask = np.zeros_like(depth)
97 | hand_bbox_mask[hand_bbox[1] : hand_bbox[3], hand_bbox[0] : hand_bbox[2]] = 1
98 | points_hand, scene_ids = backproject(depth, intr_np, hand_bbox_mask > 0, False)
99 | points_hand = points_hand @ T_c0c[:3, :3].T + T_c0c[:3, 3]
100 | wp = np.median(points_hand, axis=0)
101 | traj.append(wp)
102 |
103 | # Acquire the hand points in the scene
104 | if visualize:
105 | hand_seg_mask = hand_masks[ii]
106 | hand_seg_mask = cv2.erode(hand_seg_mask, np.ones((5, 5), np.uint8), iterations=2)
107 | hand_seg_mask = hand_seg_mask * (hand_bbox_mask > 0)
108 | points_hand_scene, scene_ids = backproject(depth, intr_np, hand_seg_mask > 0, False)
109 | point_colors_scene = rgbs[ii][scene_ids[0], scene_ids[1]]
110 | point_colors_scene = point_colors_scene / 255.0
111 | points_hand_scene = points_hand_scene @ T_c0c[:3, :3].T + T_c0c[:3, 3]
112 | pcd_scene = visualize_points(points_hand_scene, point_colors_scene)
113 | vis_scene.append(pcd_scene)
114 |
115 | # Acquire the background points
116 | if ii == 0:
117 | rgb_orig = rgbs[ii] / 255.0
118 | hand_seg_mask = hand_masks[ii]
119 | points_orig, scene_ids = backproject(depth, intr_np, hand_seg_mask == 0, False)
120 | point_colors_orig = rgb_orig[scene_ids[0], scene_ids[1]]
121 | pcd_orig = visualize_points(points_orig, point_colors_orig)
122 |
123 | # Build trajectory and visualize
124 | traj = np.array(traj)
125 | traj = savgol_filter(traj, len(traj) - 1, (len(traj) + 1) // 2, axis=0)
126 | fill_indices = frame_ids - frame_ids[0]
127 | traj_smooth, _ = interpolate_trajectory(fill_indices, traj)
128 | filter_window = min(5, len(traj_smooth) - 1)
129 | traj_smooth = savgol_filter(
130 | traj_smooth, filter_window, (filter_window + 1) // 2, axis=0
131 | )
132 | if visualize:
133 | _traj_vis = visualize_3d_trajectory(traj_smooth, size=0.05, cmap_name="viridis")
134 | traj_vis = _traj_vis[0]
135 | for v in _traj_vis[1:]:
136 | traj_vis += v
137 |
138 | hand_vis = vis_scene[0]
139 | for v in vis_scene[1:]:
140 | hand_vis += v
141 | vis = [pcd_orig, traj_vis, hand_vis]
142 | return traj_smooth, vis
143 |
144 | if __name__ == "__main__":
145 | # Load the data and prepare the tensors
146 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
147 | colmap_results_fpath = "datasets/epickitchens_traj_demo/colmap.json"
148 | data_fpath = "datasets/epickitchens_traj_demo/observation.npz"
149 | colmap_results = load_json(colmap_results_fpath)
150 | data = np.load(data_fpath)
151 |
152 | # Optimize the trajectory and visualize the results
153 | traj_smooth, vis = run(colmap_results, data, device, visualize=True)
154 | o3d.visualization.draw(vis)
--------------------------------------------------------------------------------
/models/temporal.py:
--------------------------------------------------------------------------------
1 | #
2 | # Based on Diffuser: https://github.com/jannerm/diffuser/blob/main/diffuser/models/temporal.py
3 | #
4 |
5 | import torch
6 | import torch.nn as nn
7 | import einops
8 | from einops.layers.torch import Rearrange
9 |
10 | from models.layers_2d import (
11 | Downsample1d,
12 | Upsample1d,
13 | Conv1dBlock,
14 | )
15 |
16 | from models.layers_3d import SinusoidalPosEmb
17 | from models.perceiver import FeaturePerceiver
18 |
19 |
20 | class ResidualTemporalMapBlockConcat(nn.Module):
21 |
22 | def __init__(
23 | self, inp_channels, out_channels, time_embed_dim, horizon, kernel_size=5
24 | ):
25 | super().__init__()
26 |
27 | self.time_mlp = nn.Sequential(
28 | nn.Mish(),
29 | nn.Linear(time_embed_dim, out_channels),
30 | Rearrange("batch t -> batch t 1"),
31 | )
32 |
33 | self.blocks = nn.ModuleList(
34 | [
35 | Conv1dBlock(inp_channels, out_channels, kernel_size),
36 | Conv1dBlock(out_channels, out_channels, kernel_size),
37 | ]
38 | )
39 |
40 | self.residual_conv = (
41 | nn.Conv1d(inp_channels, out_channels, 1)
42 | if inp_channels != out_channels
43 | else nn.Identity()
44 | )
45 |
46 | def forward(self, x, t):
47 | """
48 | x : [ batch_size x inp_channels x horizon ]
49 | t : [ batch_size x embed_dim ]
50 | returns:
51 | out : [ batch_size x out_channels x horizon ]
52 | """
53 | out = self.blocks[0](x) + self.time_mlp(t)
54 | out = self.blocks[1](out)
55 | return out + self.residual_conv(x)
56 |
57 |
58 | class TemporalMapUnet(nn.Module):
59 |
60 | def __init__(
61 | self,
62 | horizon,
63 | transition_dim,
64 | cond_dim, # additional dimension concatenated with the time dimension
65 | output_dim,
66 |         dim=32,  # time embedding dimension
67 | dim_mults=(1, 2, 4, 8),
68 | use_preceiver=False,
69 | ):
70 | super().__init__()
71 |
72 | ResidualTemporalMapBlock = ResidualTemporalMapBlockConcat
73 |
74 | dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
75 | in_out = list(zip(dims[:-1], dims[1:]))
76 |
77 | time_dim = dim
78 | self.time_mlp = nn.Sequential(
79 | SinusoidalPosEmb(time_dim),
80 | nn.Linear(time_dim, time_dim * 4),
81 | nn.Mish(),
82 | nn.Linear(time_dim * 4, time_dim),
83 | )
84 |
85 | cond_dim = cond_dim + time_dim
86 |
87 | self.downs = nn.ModuleList([])
88 | self.ups = nn.ModuleList([])
89 | num_resolutions = len(in_out)
90 |
91 | # Remember the property of the 1D convolution, [B, C_in, L_in] => [B, C_out, L_out]
92 | # L_out is dependent on the kernel size and stride, and L_in
93 |
94 | for ind, (dim_in, dim_out) in enumerate(in_out):
95 | is_last = ind >= (num_resolutions - 1)
96 |
97 | self.downs.append(
98 | nn.ModuleList(
99 | [
100 | ResidualTemporalMapBlock(
101 | dim_in, dim_out, time_embed_dim=cond_dim, horizon=horizon
102 | ), # Feature dimension changes, no horizon changes
103 | ResidualTemporalMapBlock(
104 | dim_out, dim_out, time_embed_dim=cond_dim, horizon=horizon
105 | ),
106 | (
107 | Downsample1d(dim_out) if not is_last else nn.Identity()
108 | ), # No feature dimension changes, but horizon changes
109 | ]
110 | )
111 | )
112 |
113 | if not is_last:
114 | horizon = horizon // 2
115 |
116 | mid_dim = dims[-1]
117 | self.mid_block1 = ResidualTemporalMapBlock(
118 | mid_dim, mid_dim, time_embed_dim=cond_dim, horizon=horizon
119 | )
120 | self.mid_block2 = ResidualTemporalMapBlock(
121 | mid_dim, mid_dim, time_embed_dim=cond_dim, horizon=horizon
122 | )
123 |
124 | final_up_dim = None
125 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
126 | is_last = ind >= (num_resolutions - 1)
127 |
128 | self.ups.append(
129 | nn.ModuleList(
130 | [
131 | ResidualTemporalMapBlock(
132 | dim_out * 2,
133 | dim_in,
134 | time_embed_dim=cond_dim,
135 | horizon=horizon,
136 | ), # Feature dimension changes, no horizon changes
137 | ResidualTemporalMapBlock(
138 | dim_in, dim_in, time_embed_dim=cond_dim, horizon=horizon
139 | ),
140 | (
141 | Upsample1d(dim_in) if not is_last else nn.Identity()
142 | ), # No feature dimension change, but horizon changes
143 | ]
144 | )
145 | )
146 | final_up_dim = dim_in
147 |
148 | if not is_last:
149 | horizon = horizon * 2
150 |
151 | self.final_conv = nn.Sequential(
152 | Conv1dBlock(final_up_dim, final_up_dim, kernel_size=5),
153 | nn.Conv1d(final_up_dim, output_dim, 1),
154 | )
155 | self.use_preceiver = use_preceiver
156 |
157 | if self.use_preceiver:
158 | self.preceiver = FeaturePerceiver(
159 | transition_dim=transition_dim,
160 | condition_dim=cond_dim - time_dim,
161 | time_emb_dim=time_dim,
162 | )
163 | self.proj = nn.Linear(self.preceiver.last_dim, transition_dim)
164 |
165 | def forward(self, x, cond, time):
166 | """
167 | x : [ batch x horizon x transition ]
168 | cond: [ batch x cond_dim ]
169 | time: [ batch ]
170 | """
171 | t = self.time_mlp(time)
172 |
173 | if self.use_preceiver:
174 | x = self.preceiver(x, cond[:, None], t[:, None])
175 | x = self.proj(x)
176 | x = einops.rearrange(x, "b h t -> b t h")
177 | t = torch.cat([t, cond], dim=-1) # [time+object+action+spatial]
178 |
179 | h = []
180 | for ii, (resnet, resnet2, downsample) in enumerate(self.downs):
181 | x = resnet(x, t)
182 | x = resnet2(x, t)
183 |
184 | h.append(x)
185 | x = downsample(
186 | x
187 | ) # Increase the feature dimension, reduce the horizon (consider the spatial resolution in image)
188 | # print("Downsample step {}, with shape {}".format(ii, x.shape)) # [B, C, H]
189 | # print(f"[ models/temporal ] Downsample step {ii}, with shape {x.shape}")
190 |
191 | x = self.mid_block1(x, t)
192 | x = self.mid_block2(x, t)
193 | for ii, (resnet, resnet2, upsample) in enumerate(self.ups):
194 | x = torch.cat((x, h.pop()), dim=1)
195 | x = resnet(x, t)
196 | x = resnet2(x, t)
197 | x = upsample(x) # Decrease the feature dimension, increase the horizon
198 | # print("Upsample step {}, with shape {}".format(ii, x.shape)) # [B, C, H]
199 |
200 | x = self.final_conv(x)
201 | x = einops.rearrange(x, "b t h -> b h t")
202 | return x
203 |
204 |
205 | if __name__ == "__main__":
206 | model = TemporalMapUnet(
207 | horizon=80, # time horizon
208 | transition_dim=67, # dimension of the input trajectory
209 | cond_dim=32, # dimension of the condition (from image, depth, text, etc.)
210 | output_dim=3, # dimension of the output trajectory
211 | dim=32, # base feature dimension
212 |         dim_mults=(2, 4, 8),  # channel multipliers per U-Net level
213 | use_preceiver=True,
214 | )
215 | x = torch.randn(2, 80, 67)
216 | cond = torch.randn(2, 32)
217 | time = torch.randn(2)
218 | print("Input shape: ", x.permute(0, 2, 1).shape) # [B, input_dim, H]
219 | out = model(x, cond, time) # [B, dim', H']
220 |     print("Output shape: ", out.permute(0, 2, 1).shape)  # [B, output_dim, H]
221 |
222 | # Input shape:  torch.Size([2, 67, 80])
223 | # Downsample step 0, with shape torch.Size([2, 64, 40])
224 | # Downsample step 1, with shape torch.Size([2, 128, 20])
225 | # Downsample step 2, with shape torch.Size([2, 256, 20])
226 | # Upsample step 0, with shape torch.Size([2, 128, 40])
227 | # Upsample step 1, with shape torch.Size([2, 64, 80])
228 | # Output shape:  torch.Size([2, 3, 80])
229 |
--------------------------------------------------------------------------------
/models/goal.py:
--------------------------------------------------------------------------------
1 | import pytorch_lightning as pl
2 | import einops
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import tqdm
7 | import numpy as np
8 | from models.layers_2d import Decoder, Encoder
9 | from models.helpers import FocalLoss
10 | from models.clip import clip, tokenize
11 |
12 | import torchvision.models as models
13 | import torchvision.transforms as transforms
14 | from torchvision.ops import roi_align, roi_pool
15 | from copy import deepcopy
16 | from models.layers_2d import load_clip
17 | from models.perceiver import FeaturePerceiver
18 |
19 |
20 | class GoalPredictor(pl.LightningModule):
21 | def __init__(
22 | self,
23 | in_channels=4,
24 | out_channels=3,
25 | resolution=[256, 448],
26 | channel_multiplier=[1, 2, 4, 8, 16],
27 | bbox_feature_dim=64,
28 | visual_feature_dim=512,
29 | encode_action=False,
30 | encode_bbox=False,
31 | encode_object=False,
32 | num_heads_attention=4,
33 | num_layers_attention=2,
34 | object_encode_mode="roi_pool",
35 | **kwargs,
36 | ):
37 | super().__init__()
38 |
39 | # self.gpt = gpt
40 | self.in_channels = in_channels
41 | self.out_channels = out_channels
42 | self.resolution = resolution
43 |
44 | self.encode_action = encode_action
45 | self.encode_bbox = encode_bbox
46 | self.encode_object = encode_object
47 |
48 | self.visual_feature_dim = visual_feature_dim
49 | self.bbox_feature_dim = bbox_feature_dim
50 | self.object_encode_mode = object_encode_mode
51 |
52 | # ch_mult=[1, 2, 4, 8, 16]
53 | self.channel_multiplier = channel_multiplier
54 |
55 | self.downscale_factor = 2 ** (len(self.channel_multiplier) - 1)
56 | attn_resolutions = (
57 | resolution[0] // self.downscale_factor,
58 | resolution[1] // self.downscale_factor,
59 | )
60 |
61 | # Decide the model architecture
62 | self.visual = Encoder(
63 | ch=64,
64 | ch_mult=self.channel_multiplier,
65 | num_res_blocks=2,
66 | attn_resolutions=attn_resolutions,
67 | in_channels=in_channels,
68 | out_ch=out_channels,
69 | resolution=resolution,
70 | double_z=False,
71 | z_channels=self.visual_feature_dim,
72 | )
73 |
74 | self.decoder = Decoder(
75 | ch=64,
76 | ch_mult=self.channel_multiplier,
77 | num_res_blocks=2,
78 | attn_resolutions=attn_resolutions,
79 | in_channels=in_channels,
80 | out_ch=out_channels,
81 | resolution=resolution,
82 | double_z=False,
83 | z_channels=self.visual_feature_dim,
84 | )
85 |
86 | self.transform = transforms.Compose(
87 | [
88 |                 transforms.Normalize(
89 |                     [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
90 | ]
91 | )
92 |
93 | if self.encode_object:
94 | if self.object_encode_mode == "vlm":
95 | obj_dim = self.vlm_dim
96 | else:
97 | obj_dim = self.visual_feature_dim
98 |
99 | self.object_encode_module = nn.Linear(
100 | obj_dim, self.visual_feature_dim)
101 |
102 | if self.encode_action:
103 |
104 | self.action_encode_module = nn.Linear(
105 | self.visual_feature_dim, self.visual_feature_dim
106 | )
107 |
108 | if self.encode_bbox:
109 |
110 | self.bbox_encode_module = nn.Linear(4, bbox_feature_dim)
111 |
112 | fuser_dim = 0
113 | if self.encode_action:
114 | fuser_dim += self.visual_feature_dim
115 | if self.encode_object:
116 | fuser_dim += self.visual_feature_dim
117 | if self.encode_bbox:
118 | fuser_dim += self.bbox_feature_dim
119 |
120 | if self.encode_action or self.encode_object or self.encode_bbox:
121 |
122 | self.fuser = FeaturePerceiver(
123 | transition_dim=self.visual_feature_dim,
124 | condition_dim=fuser_dim,
125 | time_emb_dim=0,
126 | )
127 | self.proj = nn.Linear(self.fuser.last_dim, self.visual_feature_dim)
128 |
129 | else:
130 | self.fuser = None
131 |
132 | self.depth_fuser = FeaturePerceiver(
133 | transition_dim=self.visual_feature_dim,
134 | condition_dim=self.visual_feature_dim,
135 | time_emb_dim=0, # No time embedding
136 | )
137 | self.depth_proj = nn.Sequential(
138 | nn.Linear(self.depth_fuser.last_dim, self.visual_feature_dim),
139 | nn.TransformerEncoder(
140 | nn.TransformerEncoderLayer(
141 | d_model=self.visual_feature_dim,
142 | dim_feedforward=512,
143 | nhead=4,
144 | batch_first=True,
145 | ),
146 | num_layers=3,
147 | ),
148 | nn.Linear(self.visual_feature_dim, 1),
149 | )
150 |
151 | def forward(self, data_batch, training=False):
152 | target_key = "vfd"
153 | color_key = "color"
154 | object_color_key = "object_color"
155 |
156 | if training:
157 | color_key += "_aug"
158 | object_color_key += "_aug"
159 |
160 | inputs = self.transform(data_batch[color_key])
161 | depth = data_batch["depth"][:, None] # [B, 1, H, W]
162 |
163 | if self.in_channels == 4:
164 | start_pos_depth = data_batch["start_pos_depth"][:, None]
165 | start_pos_depth_res = depth - start_pos_depth
166 | inputs = torch.cat([inputs, start_pos_depth_res], dim=1)
167 |
168 | # Shape information
169 | batch_size = inputs.shape[0]
170 | h_in, w_in = inputs.shape[-2:]
171 | h_out, w_out = h_in // self.downscale_factor, w_in // self.downscale_factor
172 |
173 | # Box information
174 | bbox = data_batch["bbox"] # [B, 4]
175 | bbox_batch_id = torch.arange(
176 | batch_size, device=inputs.device
177 | ) # Only one box per sample
178 | bbox = torch.cat([bbox_batch_id[:, None], bbox], dim=1) # [B, 5]
179 |
180 | # Extract visual features
181 | context_feature = self.visual(inputs)
182 | context_feature = einops.rearrange(
183 | context_feature.clone(), "b c h w -> b (h w) c", h=h_out, w=w_out
184 | )
185 | feature = context_feature
186 |
187 | # Acquire bbox features
188 | condition_feature = []
189 | if self.encode_object:
190 | if self.object_encode_mode == "vlm":
191 |                 raise NotImplementedError(
192 |                     "VLM-based object encoding is not implemented")
193 | elif self.object_encode_mode in ["roi_pool", "roi_align"]:
194 | roi_res = 6
195 |
196 |                 roi_method = roi_align if self.object_encode_mode == "roi_align" else roi_pool
197 | context_feature_obj = einops.rearrange(
198 | context_feature, "b (h w) c -> b c h w", h=h_out, w=w_out
199 | )
200 | spatial_scale = context_feature_obj.shape[-1] / \
201 | inputs.shape[-1]
202 | assert spatial_scale == context_feature_obj.shape[-2] / \
203 | inputs.shape[-2]
204 |
205 | context_feature_obj = roi_method(
206 | context_feature_obj,
207 | bbox,
208 | spatial_scale=spatial_scale,
209 | output_size=(roi_res, roi_res),
210 | ) # [B, c, roi_res, roi_res]
211 | object_feature = einops.rearrange(
212 | context_feature_obj, "b c h w -> b (h w) c"
213 | )
214 | object_feature = object_feature.mean(
215 | dim=1)[:, None] # [B, 1, c]
216 | else:
217 | raise NotImplementedError(
218 | "Object encode mode {} not implemented".format(
219 | self.object_encode_mode
220 | )
221 | )
222 | object_feature = self.object_encode_module(object_feature)
223 |
224 | condition_feature.append(object_feature)
225 |
226 | if self.encode_action:
227 |
228 | action_feature = data_batch["action_feature"][:, None]
229 | action_feature = self.action_encode_module(
230 | action_feature) # [B, 1, c]
231 |
232 | condition_feature.append(action_feature)
233 |
234 | if self.encode_bbox:
235 | bbox_norm = bbox[:, 1:].clone()
236 | bbox_norm[:, [0, 2]] = bbox_norm[:, [0, 2]] / inputs.shape[-1]
237 | bbox_norm[:, [1, 3]] = bbox_norm[:, [1, 3]] / inputs.shape[-2]
238 | bbox_feature = self.bbox_encode_module(bbox_norm) # [B, 64]
239 |
240 | bbox_feature = bbox_feature[:, None] # [B, 1, 64]
241 | condition_feature.append(bbox_feature)
242 |
243 |         if len(condition_feature) > 0:
244 |             condition_feature = torch.cat(
245 |                 condition_feature, dim=-1)  # [B, 1, 2c+64]
246 | 
247 |         if self.fuser is not None and len(condition_feature) > 0:
248 |             feature = self.fuser(feature, condition_feature)
249 |             feature = self.proj(feature)
250 | depth_feature = feature
251 | verb_feature = data_batch["verb_feature"][:, None]
252 |
253 | depth_feature = self.depth_fuser(depth_feature, verb_feature)
254 | pred_depth = self.depth_proj(depth_feature)
255 |
256 | feature = einops.rearrange(
257 | feature, "b (h w) c -> b c h w", h=h_out, w=w_out)
258 | pred_depth = einops.rearrange(
259 | pred_depth, "b (h w) c -> b c h w", h=h_out, w=w_out
260 | )
261 |
262 | pred = self.decoder(feature)
263 | pred_depth = F.interpolate(
264 | pred_depth, size=(
265 | self.resolution[0], self.resolution[1]), mode="bilinear"
266 | )
267 | # Postprocess the output
268 | if pred is not None:
269 |             pred_vf, pred_dres, pred_heatmap = (
270 |                 pred[:, :2], pred_depth, pred[:, -1:])
271 | pred_vf = F.normalize(pred_vf, p=2, dim=1) # [-1, 1]
272 | pred_vf = pred_vf.clamp(-1, 1)
273 | if "d_res_scale" in data_batch:
274 | print(
275 | "... Rescale the depth with d_res_scale: ",
276 | data_batch["d_res_scale"],
277 | )
278 | pred_dres = pred_dres * data_batch["d_res_scale"]
279 | pred_d_final = start_pos_depth - pred_dres # [B, 1, H, W]
280 | pred = torch.cat(
281 | [pred_vf, pred_d_final, pred_heatmap], dim=1
282 | ) # [B, 4, H, W]
283 |
284 | outputs = {"pred": pred}
285 |
286 | return outputs
287 |
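288 | 
289 | if __name__ == "__main__":
290 |     # Minimal smoke-test sketch (not part of the training or inference pipeline). It instantiates
291 |     # the predictor with its default flags and feeds dummy tensors; all shapes below are
292 |     # assumptions inferred from forward() above, the real data_batch comes from the dataset utils.
293 |     model = GoalPredictor(in_channels=4, out_channels=3, resolution=[256, 448])
294 |     data_batch = {
295 |         "color": torch.rand(1, 3, 256, 448),                 # RGB in [0, 1]
296 |         "depth": torch.rand(1, 256, 448) * 2.0,              # metric depth [B, H, W]
297 |         "start_pos_depth": torch.rand(1, 256, 448) * 2.0,    # depth at the start position
298 |         "bbox": torch.tensor([[96.0, 64.0, 224.0, 160.0]]),  # [x1, y1, x2, y2] in pixels
299 |         "verb_feature": torch.randn(1, 512),                 # e.g. a CLIP text embedding
300 |     }
301 |     with torch.no_grad():
302 |         outputs = model(data_batch)
303 |     print("Prediction shape: ", outputs["pred"].shape)       # expected [1, 4, 256, 448]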
--------------------------------------------------------------------------------
/models/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Any, Union, List
6 | from pkg_resources import packaging
7 |
8 | import torch
9 | from PIL import Image
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11 | from tqdm import tqdm
12 |
13 | from models.clip.model import build_model
14 | from models.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
15 |
16 | try:
17 | from torchvision.transforms import InterpolationMode
18 |
19 | BICUBIC = InterpolationMode.BICUBIC
20 | except ImportError:
21 | BICUBIC = Image.BICUBIC
22 |
23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25 |
26 | __all__ = ["available_models", "load", "tokenize"]
27 | _tokenizer = _Tokenizer()
28 |
29 | _MODELS = {
30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", # 243MB
31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", # 277MB
32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", #401MB
33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", # 629MB
34 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", # 1.2GB
35 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", # 334MB
36 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", # 334MB
37 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
38 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
39 | }
40 |
41 |
42 | def _download(url: str, root: str):
43 | os.makedirs(root, exist_ok=True)
44 | filename = os.path.basename(url)
45 |
46 | expected_sha256 = url.split("/")[-2]
47 | download_target = os.path.join(root, filename)
48 |
49 | if os.path.exists(download_target) and not os.path.isfile(download_target):
50 | raise RuntimeError(f"{download_target} exists and is not a regular file")
51 |
52 | if os.path.isfile(download_target):
53 | if (
54 | hashlib.sha256(open(download_target, "rb").read()).hexdigest()
55 | == expected_sha256
56 | ):
57 | return download_target
58 | else:
59 | warnings.warn(
60 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
61 | )
62 |
63 | print(f"Downloading CLIP model from {url}")
64 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
65 | with tqdm(
66 | total=int(source.info().get("Content-Length")),
67 | ncols=80,
68 | unit="iB",
69 | unit_scale=True,
70 | unit_divisor=1024,
71 | ) as loop:
72 | while True:
73 | buffer = source.read(8192)
74 | if not buffer:
75 | break
76 |
77 | output.write(buffer)
78 | loop.update(len(buffer))
79 |
80 | if (
81 | hashlib.sha256(open(download_target, "rb").read()).hexdigest()
82 | != expected_sha256
83 | ):
84 | raise RuntimeError(
85 |             "Model has been downloaded but the SHA256 checksum does not match"
86 | )
87 |
88 | return download_target
89 |
90 |
91 | def _convert_image_to_rgb(image):
92 | return image.convert("RGB")
93 |
94 |
95 | def _transform(n_px):
96 | return Compose(
97 | [
98 | # Resize(n_px, interpolation=BICUBIC),
99 | # CenterCrop(n_px),
100 | # _convert_image_to_rgb,
101 | # ToTensor(),
102 | Normalize(
103 | (0.48145466, 0.4578275, 0.40821073),
104 | (0.26862954, 0.26130258, 0.27577711),
105 | ),
106 | ]
107 | )
108 |
109 |
110 | def available_models() -> List[str]:
111 | """Returns the names of available CLIP models"""
112 | return list(_MODELS.keys())
113 |
114 |
115 | TORCH_HUB_ROOT = os.path.expandvars(os.getenv("TORCH_HUB_ROOT", "$HOME/.torch_hub"))
116 |
117 |
118 | def load(
119 | name: str,
120 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
121 | jit: bool = False,
122 | download_root: str = None,
123 | ):
124 | """Load a CLIP model
125 |
126 | Parameters
127 | ----------
128 | name : str
129 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
130 |
131 | device : Union[str, torch.device]
132 | The device to put the loaded model
133 |
134 | jit : bool
135 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
136 |
137 | download_root: str
138 |         path to download the model files; by default, it uses $TORCH_HUB_ROOT ("~/.torch_hub" if unset)
139 |
140 | Returns
141 | -------
142 | model : torch.nn.Module
143 | The CLIP model
144 |
145 | preprocess : Callable[[PIL.Image], torch.Tensor]
146 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
147 | """
148 | if name in _MODELS:
149 | model_path = _download(_MODELS[name], download_root or TORCH_HUB_ROOT)
150 | elif os.path.isfile(name):
151 | model_path = name
152 | else:
153 | raise RuntimeError(
154 | f"Model {name} not found; available models = {available_models()}"
155 | )
156 |
157 | with open(model_path, "rb") as opened_file:
158 | try:
159 | # loading JIT archive
160 | model = torch.jit.load(
161 | opened_file, map_location=device if jit else "cpu"
162 | ).eval()
163 | state_dict = None
164 | except RuntimeError:
165 | # loading saved state dict
166 | if jit:
167 | warnings.warn(
168 | f"File {model_path} is not a JIT archive. Loading as a state dict instead"
169 | )
170 | jit = False
171 | state_dict = torch.load(opened_file, map_location="cpu")
172 |
173 | if not jit:
174 | model = build_model(state_dict or model.state_dict()).to(device)
175 | if str(device) == "cpu":
176 | model.float()
177 | return model, _transform(model.visual.input_resolution)
178 |
179 | # patch the device names
180 | device_holder = torch.jit.trace(
181 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
182 | )
183 | device_node = [
184 | n
185 | for n in device_holder.graph.findAllNodes("prim::Constant")
186 | if "Device" in repr(n)
187 | ][-1]
188 |
189 | def patch_device(module):
190 | try:
191 | graphs = [module.graph] if hasattr(module, "graph") else []
192 | except RuntimeError:
193 | graphs = []
194 |
195 | if hasattr(module, "forward1"):
196 | graphs.append(module.forward1.graph)
197 |
198 | for graph in graphs:
199 | for node in graph.findAllNodes("prim::Constant"):
200 | if "value" in node.attributeNames() and str(node["value"]).startswith(
201 | "cuda"
202 | ):
203 | node.copyAttributes(device_node)
204 |
205 | model.apply(patch_device)
206 | patch_device(model.encode_image)
207 | patch_device(model.encode_text)
208 |
209 | # patch dtype to float32 on CPU
210 | if str(device) == "cpu":
211 | float_holder = torch.jit.trace(
212 | lambda: torch.ones([]).float(), example_inputs=[]
213 | )
214 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
215 | float_node = float_input.node()
216 |
217 | def patch_float(module):
218 | try:
219 | graphs = [module.graph] if hasattr(module, "graph") else []
220 | except RuntimeError:
221 | graphs = []
222 |
223 | if hasattr(module, "forward1"):
224 | graphs.append(module.forward1.graph)
225 |
226 | for graph in graphs:
227 | for node in graph.findAllNodes("aten::to"):
228 | inputs = list(node.inputs())
229 | for i in [
230 | 1,
231 | 2,
232 | ]: # dtype can be the second or third argument to aten::to()
233 | if inputs[i].node()["value"] == 5:
234 | inputs[i].node().copyAttributes(float_node)
235 |
236 | model.apply(patch_float)
237 | patch_float(model.encode_image)
238 | patch_float(model.encode_text)
239 |
240 | model.float()
241 |
242 | return model, _transform(model.input_resolution.item())
243 |
244 |
245 | def tokenize(
246 | texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False
247 | ) -> Union[torch.IntTensor, torch.LongTensor]:
248 | """
249 | Returns the tokenized representation of given input string(s)
250 |
251 | Parameters
252 | ----------
253 | texts : Union[str, List[str]]
254 | An input string or a list of input strings to tokenize
255 |
256 | context_length : int
257 | The context length to use; all CLIP models use 77 as the context length
258 |
259 | truncate: bool
260 | Whether to truncate the text in case its encoding is longer than the context length
261 |
262 | Returns
263 | -------
264 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
265 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
266 | """
267 | if isinstance(texts, str):
268 | texts = [texts]
269 |
270 | sot_token = _tokenizer.encoder["<|startoftext|>"]
271 | eot_token = _tokenizer.encoder["<|endoftext|>"]
272 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
273 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
274 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
275 | else:
276 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
277 |
278 | for i, tokens in enumerate(all_tokens):
279 | if len(tokens) > context_length:
280 | if truncate:
281 | tokens = tokens[:context_length]
282 | tokens[-1] = eot_token
283 | else:
284 | raise RuntimeError(
285 | f"Input {texts[i]} is too long for context length {context_length}"
286 | )
287 | result[i, : len(tokens)] = torch.tensor(tokens)
288 |
289 | return result
290 |
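291 | # Usage sketch (mirrors how the demo scripts consume this module; the prompt string is only a
292 | # placeholder):
293 | #
294 | #   from models.clip import clip, tokenize
295 | #   print(clip.available_models())                  # names accepted by load()
296 | #   model, preprocess = clip.load("ViT-B/16", jit=False, device="cpu")
297 | #   tokens = tokenize(["open the drawer"])          # [1, 77] token tensor
298 | #   text_features = model.encode_text(tokens)       # text embedding, 512-d for ViT-B/16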
--------------------------------------------------------------------------------
/demos/infer_affordance.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import sys
3 | import os
4 | import pytorch_lightning as pl
5 | from omegaconf import OmegaConf
6 | import torch
7 | import numpy as np
8 | import cv2
9 | from diffuser_utils.guidance_loss import DiffuserGuidance
10 | from diffuser_utils.guidance_params import GUIDANCE_PARAMS_DICT
11 | import diffuser_utils.dataset_utils as DatasetUtils
12 | from models.clip import clip
13 | import open3d as o3d
14 | from algos.afford_algos import AffordanceInferenceEngine
15 | from easydict import EasyDict as edict
16 | from copy import deepcopy
17 | from transformations import rotation_matrix
18 |
19 | # CLIP-based action text encoder
20 | VLM, VLM_TRANSFORM = clip.load("ViT-B/16", jit=False)
21 | VLM.float()
22 | VLM.eval()
23 | VLM.cuda()
24 | for p in VLM.parameters():
25 | p.requires_grad = False
26 |
27 |
28 | def main(args):
29 |     # Load the test config and the gripper mesh used for visualization
30 | config = edict(OmegaConf.to_container(OmegaConf.load(args.config)))
31 | gripper_mesh = o3d.io.read_triangle_mesh(config.gripper_mesh_file)
32 | gripper_mesh.compute_vertex_normals()
33 | gripper_mesh.paint_uniform_color([255 / 255.0, 192 / 255.0, 203 / 255.0])
34 | gripper_mesh.rotate(rotation_matrix(
35 | np.pi/2, [0, 0, 1])[:3, :3], center=[0, 0, 0])
36 |
37 | # Parse the instruction
38 | frame_id = args.frame
39 | action = args.instruction.split(" ")[0]
40 | object_name = args.object
41 | if args.object == "" and not args.load_results:
42 |         print("No object name provided; the detector requires an object class name (use -o)")
43 | return
44 |
45 | dataset_path = os.path.join(config.dataset_dir, args.dataset)
46 | camera_info = DatasetUtils.load_json(
47 | os.path.join(dataset_path, "camera_intrinsic.json")
48 | )
49 | intr = np.array(camera_info["intrinsic_matrix"]
50 | ).reshape(3, 3).astype(np.float32).T
51 |
52 | if args.scale > 0:
53 | calib_scale = args.scale
54 | else:
55 | calib_scale = 1.5 * 640 / intr[0, 0]
56 |
57 | # Set the parameters for guidance
58 | if action not in GUIDANCE_PARAMS_DICT:
59 | print("WARNING: Action [{}] not found in guidance params, quality not guaranteed".format(
60 | action))
61 | guidance_params = GUIDANCE_PARAMS_DICT["other"]
62 | else:
63 | guidance_params = GUIDANCE_PARAMS_DICT[action]
64 |
65 | goal_weight = guidance_params["goal_weight"]
66 | noncollide_weight = guidance_params["noncollide_weight"]
67 | normal_weight = guidance_params["normal_weight"]
68 | contact_weight = guidance_params["contact_weight"]
69 | fine_voxel_resolution = guidance_params["fine_voxel_resolution"]
70 | exclude_object_points = guidance_params["exclude_object_points"]
71 | print("Using guidance params: ", guidance_params)
72 |
73 | # Read the data
74 | depth_file_path = os.path.join(
75 | dataset_path, "depth_m3d", "{:06d}.png".format(frame_id))
76 | if not os.path.exists(depth_file_path):
77 | depth_file_path = os.path.join(
78 | dataset_path, "depth", "{:06d}.png".format(frame_id))
79 |
80 | color_file_path = os.path.join(
81 | dataset_path, "color", "{:06d}.png".format(frame_id))
82 | if not os.path.exists(color_file_path):
83 | color_file_path = color_file_path.replace(".png", ".jpg")
84 | depth = cv2.imread(depth_file_path, -1)
85 | depth = depth / 1000.0
86 | color = cv2.imread(color_file_path, -1)[..., [2, 1, 0]].copy()
87 | depth[depth > 2] = 0
88 | data_batch = DatasetUtils.get_context_data_from_rgbd(
89 | color,
90 | depth,
91 | intr,
92 | voxel_resolution=32,
93 | fine_voxel_resolution=fine_voxel_resolution,
94 | )
95 |
96 | # Initialize the inference engine
97 | config_traj_path = config.config_traj
98 | config_goal_path = config.config_goal
99 | config_contact_path = config.config_contact
100 | cfg_traj = edict(OmegaConf.to_container(OmegaConf.load(config_traj_path)))
101 | cfg_goal = edict(OmegaConf.to_container(OmegaConf.load(config_goal_path)))
102 | cfg_contact = edict(OmegaConf.to_container(
103 | OmegaConf.load(config_contact_path)))
104 | cfg_traj.TEST.ckpt_path = config.traj_ckpt
105 | cfg_goal.TEST.ckpt_path = config.goal_ckpt
106 | cfg_contact.TEST.ckpt_path = config.contact_ckpt
107 | cfg_traj.TEST.num_samples = 40
108 | cfg_traj.TEST.class_free_guide_weight = -0.7
109 |
110 | use_detector, use_esam = True, True
111 |
112 | if args.skip_coarse_stage:
113 | assert (
114 | args.load_results
115 | ), "Must load results if skipping coarse affordance prediction"
116 | cfg_goal, cfg_contact = None, None
117 |
118 | if args.load_results:
119 | use_detector, use_esam = False, False
120 |
121 | diffuser_guidance = DiffuserGuidance(
122 | goal_weight=goal_weight, # 100.0
123 | noncollide_weight=noncollide_weight,
124 | contact_weight=contact_weight,
125 | normal_weight=normal_weight,
126 | scale=calib_scale,
127 | exclude_object_points=exclude_object_points,
128 | valid_horizon=-1,
129 | )
130 |
131 | inference_engine = AffordanceInferenceEngine(
132 | traj_config=cfg_traj,
133 | goal_config=cfg_goal,
134 | contact_config=cfg_contact,
135 | traj_guidance=diffuser_guidance,
136 | use_detector=use_detector,
137 | use_esam=use_esam,
138 | use_graspnet=args.use_graspnet,
139 | detector_config=config.config_detector,
140 | detector_ckpt=config.detector_ckpt,
141 | esam_ckpt=config.esam_ckpt,
142 | graspnet_ckpt=config.graspnet_ckpt,
143 | )
144 |
145 | for k, v in data_batch.items():
146 | if isinstance(v, torch.Tensor):
147 | data_batch[k] = v.unsqueeze(0).cuda()
148 | _, null_embeddings = DatasetUtils.encode_text_clip(
149 | VLM, [""], max_length=None, device="cuda"
150 | )
151 | data_batch["action_feature_null"] = null_embeddings.cuda()
152 | data_batch["action_text"] = args.instruction
153 | data_batch["gripper_points"] = torch.from_numpy(
154 | np.asarray(gripper_mesh.sample_points_uniformly(
155 | number_of_points=2048).points).astype(np.float32)
156 | ).cuda()
157 |
158 | # Run the open-vocabulary object detection
159 | meta_name = "{:06d}_{}.npz".format(
160 | frame_id, args.instruction.replace(" ", "-"))
161 | meta_save_path = os.path.join(dataset_path, "scene_meta", meta_name)
162 | os.makedirs(os.path.dirname(meta_save_path), exist_ok=True)
163 | if args.load_results:
164 | assert os.path.exists(meta_save_path), "Results {} not found".format(
165 | meta_save_path
166 | )
167 | AffordanceInferenceEngine.load_results(meta_save_path, data_batch)
168 | else:
169 | inference_engine.forward_detect(data_batch, text=object_name)
170 | AffordanceInferenceEngine.export_results(meta_save_path, data_batch)
171 |
172 | num_objects = len(data_batch["bbox_all"][0])
173 | if num_objects == 0:
174 | print("No objects detected")
175 | return
176 | for n in range(num_objects):
177 | outputs = {}
178 | results_name = "{:06d}_{:06d}_{}.npz".format(
179 | frame_id, n, args.instruction.replace(" ", "-")
180 | )
181 | results_save_path = os.path.join(
182 | dataset_path, "prediction", results_name)
183 | if args.load_results:
184 | assert os.path.exists(results_save_path), "Results {} not found".format(
185 | results_save_path
186 | )
187 | AffordanceInferenceEngine.load_results(
188 | results_save_path, data_batch)
189 | else:
190 | # Set the data for the object
191 | label_text = data_batch["label_text_all"][n]
192 | data_batch["bbox"] = data_batch["bbox_all"][:, n]
193 | data_batch["cropped_intr"] = data_batch["cropped_intr_all"][:, n]
194 | data_batch["object_mask"] = data_batch["object_mask_all"][:, n]
195 | data_batch["object_bbox_mask"] = data_batch["object_bbox_mask_all"][:, n]
196 |
197 | data_batch["object_color"] = data_batch["object_color_all"][:, n]
198 | data_batch["object_depth"] = data_batch["object_depth_all"][:, n]
199 | data_batch["object_points"] = data_batch["object_points_all"][:, n]
200 | data_batch["resize_ratio"] = data_batch["resize_ratio_all"][:, n]
201 | data_batch["label_text"] = label_text
202 |
203 | # Inference
204 | inference_engine.encode_action(
205 | data_batch, clip_model=VLM, max_length=20)
206 | if not args.skip_coarse_stage:
207 | inference_engine.forward_contact(
208 | data_batch,
209 | outputs,
210 | solve_vf=False,
211 | update_data_batch=True,
212 | sample_num=100,
213 | )
214 | inference_engine.forward_goal(
215 | data_batch,
216 | outputs,
217 | solve_vf=False,
218 | update_data_batch=True,
219 | sample_num=100,
220 | )
221 |
222 | inference_engine.compute_object_contact_normal(
223 | data_batch
224 | )
225 | inference_engine.compute_object_grasp_pose(
226 | data_batch, collision_thresh=0.25)
227 |
228 | if not args.skip_fine_stage:
229 | inference_engine.forward_traj(
230 | data_batch,
231 | outputs,
232 | radii=0.65,
233 | scale=calib_scale,
234 | use_guidance=True,
235 | update_data_batch=True,
236 | )
237 | data_batch["pred_trajectories"] = data_batch["pred_trajectories"][:, :, :60]
238 | inference_engine.smooth_traj(data_batch)
239 |
240 | # Save trajectories and guidance losses
241 | if not args.no_save:
242 | os.makedirs(os.path.dirname(results_save_path), exist_ok=True)
243 | AffordanceInferenceEngine.export_results(
244 | results_save_path, data_batch)
245 |
246 | if args.visualize:
247 | assert not args.skip_fine_stage, "Cannot visualize fine stage if it is skipped"
248 | # Update the predicted trajectories
249 | pred_trajs = data_batch["pred_trajectories"] # [B, N, H, 3]
250 | vis_o3d = inference_engine.nets["traj"].visualize_trajectory_by_rendering(
251 | data_batch, "configs/_render.json", window=False, return_vis=True
252 | )
253 |
254 | if "guide_losses" in data_batch:
255 |                 pred_trajs_loss = data_batch["guide_losses"][
256 |                     "total_loss"].detach()
257 | traj_loss_colors = DatasetUtils.get_heatmap(
258 | pred_trajs_loss.cpu().numpy(), cmap_name="turbo"
259 | )
260 | best_traj_idx = np.argmin(pred_trajs_loss.cpu().numpy())
261 | best_traj = pred_trajs[0,
262 | best_traj_idx].cpu().numpy().squeeze()
263 | for i in range(1, len(vis_o3d)):
264 | vis_o3d[i].paint_uniform_color(
265 | traj_loss_colors[0, i - 1, :])
266 | else:
267 | traj_loss_colors = None
268 | vis_o3d_trajs = vis_o3d[1]
269 | for _vis_traj in vis_o3d[2:]:
270 | vis_o3d_trajs += _vis_traj
271 |
272 | if "grasp_pose" in data_batch:
273 | gripper_colors = DatasetUtils.get_heatmap(
274 | np.arange(len(best_traj))[None], cmap_name="plasma"
275 | )[0]
276 | grasp_pose = data_batch["grasp_pose"].cpu().numpy().squeeze()
277 | gripper_mesh_init = deepcopy(gripper_mesh)
278 | gripper_mesh_init.transform(grasp_pose)
279 | gripper_mesh_init.paint_uniform_color(gripper_colors[10])
280 | vis_o3d_best = gripper_mesh_init
281 |
282 | for hi in range(15, len(best_traj)):
283 | if hi % 20 == 0:
284 | gripper_mesh_hi = deepcopy(gripper_mesh)
285 | gripper_mesh_hi.transform(grasp_pose)
286 | gripper_mesh_hi.translate(best_traj[hi] - best_traj[0])
287 | gripper_mesh_hi.paint_uniform_color(gripper_colors[hi])
288 | vis_o3d_best += gripper_mesh_hi
289 | vis_o3d_final = [vis_o3d[0], vis_o3d_trajs, vis_o3d_best]
290 |
291 | o3d.visualization.draw(
292 | vis_o3d_final, title="Affordance Visualization for Instruction: {} (Object {})".format(
293 | args.instruction, n)
294 | )
295 |
296 |
297 | if __name__ == "__main__":
298 | parser = argparse.ArgumentParser()
299 | parser.add_argument(
300 | "--config",
301 | type=str,
302 | default="./config/test_config.yaml",
303 |         help="Path to the test config that specifies the checkpoint paths and the dataset directory"
304 | )
305 | parser.add_argument("-d", "--dataset", type=str,
306 | default="vidbot_data_demo", help="Dataset name")
307 | parser.add_argument("-f", "--frame", type=int,
308 | default=0, help="Frame index")
309 | parser.add_argument("-i", "--instruction", type=str,
310 | default="open drawer",
311 |                         help="Instruction fed to the affordance model, in the format Verb + Object; the verb must be a single word, e.g. 'open drawer'")
312 | parser.add_argument("-o", "--object", type=str, default="",
313 | help="Object class name fed to the detector")
314 | parser.add_argument("-v", "--visualize",
315 | action="store_true", help="Visualize the results")
316 | parser.add_argument("-s", "--scale", type=float,
317 | default=-1, help="Scale of the trajectory, set this to -1 if you want to use the default scale")
318 | parser.add_argument("--no_save", action="store_true",
319 | help="Do not save the results")
320 |     parser.add_argument("--load_results", action="store_true",
321 |                         help="Load previously exported results from file; use this if you don't want to install the detector")
322 |     parser.add_argument("--skip_coarse_stage",
323 |                         action="store_true", help="Skip the coarse stage; make sure to also pass --load_results")
324 |     parser.add_argument("--skip_fine_stage",
325 |                         action="store_true", help="Skip the fine stage; make sure to also pass --load_results")
326 |     parser.add_argument("--use_graspnet", action="store_true",
327 |                         help="Use GraspNet if it is successfully installed, otherwise a heuristic approach is used to acquire the grasp pose")
328 | args = parser.parse_args()
329 | pl.seed_everything(42)
330 |
331 | main(args)
332 |
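333 | # Example invocation (hypothetical; assumes the demo checkpoints and test data fetched by
334 | # scripts/download_ckpt_testdata.sh are in place and referenced by config/test_config.yaml):
335 | #
336 | #   python demos/infer_affordance.py -d vidbot_data_demo -f 0 -i "open drawer" -o drawer -v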
--------------------------------------------------------------------------------
/algos/traj_optimizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from scipy.signal import savgol_filter
4 | from models.layers_2d import BackprojectDepth, Project3D
5 | from pytorch3d.transforms import rotation_6d_to_matrix, matrix_to_rotation_6d
6 | import numpy as np
7 | import cv2
8 |
9 | class TrajectoryOptimizer:
10 | def __init__(
11 | self,
12 | resolution=(256, 456),
13 | lr_scale_global=0.05,
14 | lr_scale=0.1,
15 | lr_pose=0.05,
16 | num_iters_scale=10,
17 | num_iters_pose=100,
18 | warp_mode="points",
19 | device="cuda",
20 | ):
21 | self.height, self.width = resolution[0], resolution[1]
22 | self.device = device
23 | self.lr_scale_global = lr_scale_global
24 | self.lr_scale = lr_scale
25 | self.lr_pose = lr_pose
26 | self.num_iters_scale = num_iters_scale
27 | self.num_iters_pose = num_iters_pose
28 | self.backproject_depth = BackprojectDepth(self.height, self.width).to(device)
29 | self.project_3d = Project3D().to(device)
30 | self.warp_mode = warp_mode
31 |
32 | def compute_warped_results(
33 | self,
34 | intrinsics,
35 | rgb_tensors,
36 | depth_tensors,
37 | mask_tensors,
38 | scale_tensors,
39 | rgb_key,
40 | depth_key,
41 | mask_key,
42 |         scale_key,
43 | T_kc_tensors,
44 | mode,
45 | return_color=False,
46 | verbose=False,
47 | ):
48 | N, _, height, width = depth_tensors.shape
49 |         depth_key = depth_key * scale_key[:, None, None, None]  # To Colmap space
50 | depth_key = depth_key.repeat(N, 1, 1, 1)
51 | mask_key = mask_key.repeat(N, 1, 1, 1)
52 | rgb_key = rgb_key.repeat(N, 1, 1, 1)
53 |
54 | # Prepare the depth
55 | depth_tensors_tmp = (
56 | depth_tensors * scale_tensors[:, None, None, None]
57 | ) # To Colmap space
58 |
59 |         # Compute the warping flow from frame i to the key frame k
60 | points = self.backproject_depth(depth_tensors_tmp, K=intrinsics) # [N, 3, H*W]
61 | points = points.permute(0, 2, 1) # [N, H*W, 3]
62 | pix_coords = self.project_3d(
63 | points, K=intrinsics, T=T_kc_tensors
64 | ) # [N, 2, H*W]
65 |
66 | # Acquire the backward pixel flow
67 | pix_coords[:, 0] = (pix_coords[:, 0] / (width - 1)) * 2 - 1
68 | pix_coords[:, 1] = (pix_coords[:, 1] / (height - 1)) * 2 - 1
69 | pix_coords = pix_coords.view(-1, 2, height, width) # [N, 2, H, W]
70 | pix_coords = pix_coords.permute(0, 2, 3, 1) # [N, H, W, 2]
71 |
72 | warped_depths_key = F.grid_sample(
73 | depth_key,
74 | pix_coords,
75 | mode="bilinear",
76 | padding_mode="border",
77 | align_corners=True,
78 | ) # In frame c
79 |
80 | warped_masks_key = F.grid_sample(
81 | mask_key,
82 | pix_coords,
83 | mode="nearest",
84 | padding_mode="border",
85 | align_corners=True,
86 | ) # In frame c
87 |
88 | if mode == "points":
89 | # Warp the points
90 | points_c = points # In frame c
91 | points_c = points_c @ T_kc_tensors[:, :3, :3].transpose(
92 | -1, -2
93 | ) + T_kc_tensors[:, :3, 3].unsqueeze(-2) # In frame k
94 | points_k = self.backproject_depth(depth_key, K=intrinsics) # In frame k
95 | points_k = points_k.view(-1, 3, height, width) # [N, 3, H, W]
96 | points_k_to_c = F.grid_sample(
97 | points_k,
98 | pix_coords,
99 | mode="nearest",
100 | padding_mode="border",
101 | align_corners=True,
102 | )
103 | points_k_to_c = points_k_to_c.view(-1, 3, height, width) # [N, 3, H, W]
104 | points_k_to_c = points_k_to_c.permute(0, 2, 3, 1) # [N, H, W, 3]
105 | points_k_to_c = points_k_to_c.view(N, -1, 3) # [N, H*W, 3]
106 | points_k_to_c = points_k_to_c.clone().detach()
107 |
108 | # Warp the depth
109 | elif mode == "depth":
110 | points_k_to_c = self.backproject_depth(
111 | warped_depths_key, K=intrinsics
112 | ) # In frame c
113 | points_k_to_c = points_k_to_c.permute(0, 2, 1) # In frame c
114 | points_c = points # In frame c
115 | points_c = points_c @ T_kc_tensors[:, :3, :3].transpose(
116 | -1, -2
117 | ) + T_kc_tensors[:, :3, 3].unsqueeze(-2) # In frame k
118 | points_k_to_c = points_k_to_c @ T_kc_tensors[:, :3, :3].transpose(
119 | -1, -2
120 | ) + T_kc_tensors[:, :3, 3].unsqueeze(-2) # In frame k
121 |
122 | else:
123 | raise ValueError("Invalid mode: {}".format(mode))
124 |
125 | if return_color:
126 | warped_rgbs_key = F.grid_sample(
127 | rgb_key,
128 | pix_coords,
129 | mode="bilinear",
130 | padding_mode="border",
131 | align_corners=True,
132 | )
133 | if verbose:
134 | for bi in range(len(rgb_key)):
135 | warped_rgb_key_for_i = warped_rgbs_key[bi].detach().cpu()
136 | rgb_i = rgb_tensors[bi].detach().cpu()
137 | rgb_vis_i = torch.cat([rgb_i, warped_rgb_key_for_i], dim=-1)
138 | rgb_vis_i = rgb_vis_i.permute(1, 2, 0).cpu().numpy()
139 | rgb_vis_i = (rgb_vis_i * 255).astype(np.uint8)
140 | rgb_vis_i = rgb_vis_i[..., ::-1]
141 | cv2.imshow("rgb_{}".format(bi), rgb_vis_i)
142 | cv2.waitKey(0)
143 |
144 | return (
145 | warped_rgbs_key,
146 | points_c,
147 | points_k_to_c,
148 | mask_tensors,
149 | warped_masks_key,
150 | )
151 |
152 | return (
153 | None,
154 | points_c,
155 | points_k_to_c,
156 | mask_tensors,
157 | warped_masks_key,
158 | )
159 |
160 | def optimize_pose(
161 | self,
162 | intr,
163 | rgb_tensors,
164 | depth_tensors,
165 | mask_tensors,
166 | scale_init_tensors,
167 | scale_global,
168 | colmap_results,
169 | key_idx=0,
170 | depth_filter_thresh=1.25,
171 | optimize_pose=True,
172 | verbose=False,
173 | ):
174 | # Prepare the tensors
175 | T_wc_tensors = []
176 | frame_ids = list(colmap_results.keys())
177 |
178 | for ii, fi in enumerate(frame_ids):
179 | T_world_ci = np.array(colmap_results[str(fi)]["T_wc"]).reshape(4, 4)
180 | T_wc_tensors.append(torch.from_numpy(T_world_ci).float())
181 | T_wc_tensors = torch.stack(T_wc_tensors).to(self.device) # [N, 4, 4]
182 | key_rgb = rgb_tensors[key_idx][None] # [1, 3, H, W]
183 | key_depth = depth_tensors[key_idx][None] # [1, 1, H, W]
184 | key_mask = mask_tensors[key_idx][None] # [1, 1, H, W]
185 | key_scale = scale_global.clone()[None] # [1]
186 | T_wk = T_wc_tensors[key_idx][None] # [1, 4, 4]
187 | T_kc_tensors = torch.matmul(torch.inverse(T_wk), T_wc_tensors) # [N, 4, 4]
188 |
189 | # Prepare the optimization
190 | scale_tensors_global = torch.ones_like(scale_init_tensors) * key_scale
191 | delta_scale = scale_init_tensors / scale_tensors_global
192 | delta_translation = torch.zeros_like(T_kc_tensors[:, :3, 3]).float()
193 | delta_r6d = matrix_to_rotation_6d(
194 | torch.eye(3, device=self.device).float()
195 | )[None].repeat(len(T_kc_tensors), 1)
196 |
197 | delta_scale.requires_grad = True
198 | delta_translation.requires_grad = optimize_pose
199 | delta_r6d.requires_grad = optimize_pose
200 |
201 | optimizer = torch.optim.Adam([delta_scale], self.lr_scale, betas=(0.9, 0.9))
202 | if optimize_pose:
203 | optimizer.add_param_group(
204 | {"params": delta_translation, "lr": self.lr_pose, "betas": (0.9, 0.9)}
205 | )
206 | optimizer.add_param_group(
207 | {"params": delta_r6d, "lr": self.lr_pose, "betas": (0.9, 0.9)}
208 | )
209 |
210 | for it in range(self.num_iters_pose):
211 | optimizer.zero_grad()
212 | height, width = rgb_tensors.shape[-2:]
213 | scale_curr = scale_tensors_global * delta_scale
214 | T_kc_tensors_curr = T_kc_tensors.clone()
215 | delta_rot = rotation_6d_to_matrix(delta_r6d).transpose(-1, -2)
216 | delta_T = (
217 | torch.eye(4, device=self.device)
218 | .float()[None]
219 | .repeat(len(T_kc_tensors), 1, 1)
220 | )
221 | delta_T[..., :3, :3] = delta_rot
222 | delta_T[..., :3, 3] = delta_translation
223 | T_kc_tensors_curr = torch.matmul(
224 | delta_T,
225 | T_kc_tensors_curr,
226 | )
227 |
228 | _, points_c, points_k_to_c, masks_c, warped_masks_key = (
229 | self.compute_warped_results(
230 | intr,
231 | rgb_tensors,
232 | depth_tensors,
233 | mask_tensors,
234 | scale_curr,
235 | key_rgb,
236 | key_depth,
237 | key_mask,
238 | key_scale,
239 | T_kc_tensors_curr,
240 | mode=self.warp_mode,
241 | verbose=verbose,
242 | return_color=verbose,
243 | )
244 | )
245 | points_c = points_c.view(-1, height, width, 3).permute(0, 3, 1, 2)
246 | points_k_to_c = points_k_to_c.view(-1, height, width, 3).permute(
247 | 0, 3, 1, 2
248 | )
249 | points_loss = F.mse_loss(
250 | points_c, points_k_to_c, reduction="none"
251 | ) # [N, 3, H, W]
252 |
253 | masks_static = masks_c * warped_masks_key
254 |
255 | points_loss = torch.cat(
256 | [points_loss[:key_idx], points_loss[key_idx + 1 :]], dim=0
257 | )
258 | masks_static = torch.cat(
259 | [masks_static[:key_idx], masks_static[key_idx + 1 :]], dim=0
260 | )
261 | depth_filter = torch.cat(
262 | [depth_tensors[:key_idx], depth_tensors[key_idx + 1 :]], dim=0
263 | ).repeat(1, 3, 1, 1)
264 | depth_filter = depth_filter < depth_filter_thresh
265 | points_loss = points_loss * masks_static * depth_filter # [N, 3, H, W]
266 |
267 | loss_geo = points_loss.mean()
268 | loss_scale_reg = F.l1_loss(delta_scale, torch.ones_like(delta_scale)) * 10
269 | loss_translation_reg = F.l1_loss(
270 | delta_translation, torch.zeros_like(delta_translation)
271 | )
272 | loss_rot_reg = F.l1_loss(
273 | delta_r6d,
274 | matrix_to_rotation_6d(torch.eye(3, device=self.device).float())[None].repeat(len(T_kc_tensors), 1)
275 | )
276 | loss_reg = loss_scale_reg + loss_translation_reg + loss_rot_reg
277 | loss = loss_geo + loss_reg
278 |
279 | # Compute the loss and backprop
280 | loss.backward()
281 | optimizer.step()
282 |
283 | T_kc_final = T_kc_tensors.clone()
284 | delta_rot = rotation_6d_to_matrix(delta_r6d).transpose(-1, -2)
285 | delta_T = (
286 | torch.eye(4, device=self.device)
287 | .float()[None]
288 | .repeat(len(T_kc_tensors), 1, 1)
289 | )
290 | delta_T[..., :3, :3] = delta_rot
291 | delta_T[..., :3, 3] = delta_translation
292 | T_kc_final = torch.matmul(
293 | delta_T,
294 |             T_kc_final,
295 | )
296 | scale_final = scale_tensors_global * delta_scale
297 | return T_kc_final, scale_final
298 |
299 | def optimize_global_scale(
300 | self,
301 | rgb_tensors,
302 | depth_tensors,
303 | mask_tensors,
304 | colmap_results,
305 | ):
306 | # Prepare the results from colmap
307 | scale_init_tensors = []
308 | metric_d_tensors, colmap_d_tensors, valid_d_tensors = [], [], []
309 | frame_ids = list(colmap_results.keys())
310 | frame_id_start = int(frame_ids[0])
311 | for ii, fi in enumerate(frame_ids):
312 | depth = depth_tensors[ii, 0].cpu().numpy()
313 | mask = mask_tensors[ii, 0].cpu().numpy()
314 | uv = np.array(colmap_results[str(fi)]["uv"]).reshape(-1, 2)
315 | colmap_d = np.array(colmap_results[str(fi)]["d"])
316 | uv_mask = np.logical_and(
317 | np.logical_and(uv[:, 0] >= 0, uv[:, 0] < depth.shape[1]),
318 | np.logical_and(uv[:, 1] >= 0, uv[:, 1] < depth.shape[0]),
319 | )
320 | uv = uv[uv_mask]
321 | colmap_d = colmap_d[uv_mask]
322 | metric_d = depth[uv[:, 1], uv[:, 0]] # [S]
323 | valid_d = mask[uv[:, 1], uv[:, 0]]
324 | scale_init = np.median(colmap_d) / np.median(metric_d)
325 | scale_init_tensors.append(scale_init)
326 | metric_d_tensors.append(torch.from_numpy(metric_d).float())
327 | colmap_d_tensors.append(torch.from_numpy(colmap_d).float())
328 | valid_d_tensors.append(torch.from_numpy(valid_d).float())
329 |
330 | # Prepare the tensors
331 | metric_d_tensors = torch.cat(metric_d_tensors).to(self.device) # [S]
332 | colmap_d_tensors = torch.cat(colmap_d_tensors).to(self.device) # [S]
333 | valid_d_tensors = torch.cat(valid_d_tensors).to(self.device) # [S]
334 | scale_init_tensors = (
335 | torch.tensor(scale_init_tensors).float().to(self.device)
336 | ) # [N]
337 |
338 | # Start the optimization
339 | scale_global = torch.median(scale_init_tensors)
340 | delta_scale_global = torch.ones_like(scale_global)
341 | delta_scale_global.requires_grad = True
342 | optimizer_scale = torch.optim.Adam([delta_scale_global], self.lr_scale_global)
343 |
344 | # Do the optimization
345 | for it in range(self.num_iters_scale):
346 | optimizer_scale.zero_grad()
347 | scale_global_curr = scale_global * delta_scale_global
348 | loss_d = F.mse_loss(
349 | metric_d_tensors * scale_global_curr,
350 | colmap_d_tensors,
351 | reduction="none",
352 | ) # [S]
353 |
354 | loss_d = loss_d * valid_d_tensors
355 | loss_d = loss_d.sum() / valid_d_tensors.sum()
356 | loss_d.backward()
357 | optimizer_scale.step()
358 | scale_global_final = (scale_global * delta_scale_global).detach().clone()
359 |
360 | # Compute the amount of valid landmarks
361 | scale_global_np = scale_global_final.detach().cpu().numpy()
362 | key_idx, key_valid_diff_d = 0, -np.inf
363 | for ii, fi in enumerate(frame_ids):
364 | depth = depth_tensors[ii, 0].cpu().numpy()
365 | mask = mask_tensors[ii, 0].cpu().numpy()
366 | uv = np.array(colmap_results[str(fi)]["uv"]).reshape(-1, 2)
367 | colmap_d = np.array(colmap_results[str(fi)]["d"])
368 | uv_mask = np.logical_and(
369 | np.logical_and(uv[:, 0] >= 0, uv[:, 0] < depth.shape[1]),
370 | np.logical_and(uv[:, 1] >= 0, uv[:, 1] < depth.shape[0]),
371 | )
372 | uv = uv[uv_mask]
373 | colmap_d = colmap_d[uv_mask]
374 | metric_d = depth[uv[:, 1], uv[:, 0]] # [S]
375 | valid_d = mask[uv[:, 1], uv[:, 0]] # [S]
376 | diff_d = np.abs(colmap_d / scale_global_np - metric_d) # [S]
377 | diff_d[valid_d == 0] = np.inf
378 |
379 | # Compute the amount of valid landmarks
380 | valid_diff_d = (diff_d < 0.07).sum()
381 | if valid_diff_d > key_valid_diff_d:
382 | key_idx = ii
383 | key_valid_diff_d = valid_diff_d
384 | return scale_init_tensors, scale_global_final, key_idx
385 |
386 |
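387 | # Usage sketch (the shapes and the COLMAP landmark dictionary below are illustrative
388 | # assumptions; the real inputs come from the video preprocessing pipeline):
389 | #
390 | #   optimizer = TrajectoryOptimizer(resolution=(256, 456), device="cuda")
391 | #   colmap_results = {
392 | #       "0": {"T_wc": np.eye(4).tolist(), "uv": [[120, 80]], "d": [1.2]},
393 | #       "1": {"T_wc": np.eye(4).tolist(), "uv": [[118, 82]], "d": [1.1]},
394 | #   }
395 | #   # rgb: [N, 3, H, W], depth and mask: [N, 1, H, W]; intr: intrinsics as expected by
396 | #   # BackprojectDepth / Project3D. All tensors live on the optimizer's device.
397 | #   scale_init, scale_global, key_idx = optimizer.optimize_global_scale(
398 | #       rgb, depth, mask, colmap_results)
399 | #   T_kc, scales = optimizer.optimize_pose(
400 | #       intr, rgb, depth, mask, scale_init, scale_global, colmap_results, key_idx=key_idx)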
--------------------------------------------------------------------------------
/models/attention.py:
--------------------------------------------------------------------------------
1 | # Code from https://github.com/krasserm/perceiver-io/blob/main/perceiver/model/core/modules.py
2 | import torch
3 | import torch.nn as nn
4 | import numpy as np
5 | from typing import List, Optional, Tuple
6 | from collections import OrderedDict
7 | from einops import rearrange
8 |
9 | KVCache = Tuple[torch.Tensor, torch.Tensor]
10 |
11 | class RotaryPositionEmbedding:
12 | # Specified in https://arxiv.org/abs/2104.09864
13 | # Modified from https://github.com/lucidrains/rotary-embedding-torch
14 | def __init__(self, frq_pos_enc: torch.Tensor, right_align: bool = False):
15 | # frq_pos_enc shape is (b, n, c).
16 | # frq_pos_enc is broadcast to (b, h, n, c).
17 | self.frq_pos_enc = rearrange(frq_pos_enc, "b n c -> b 1 n c")
18 | self.rotate_dim = frq_pos_enc.shape[-1]
19 | self.right_align = right_align
20 |
21 | def rotate(self, t):
22 | seq_len = t.shape[-2]
23 | if self.right_align:
24 | # q and k are right-aligned in Perceiver AR
25 | pos_enc = self.frq_pos_enc[..., -seq_len:, :]
26 | else:
27 | # q and k are left-aligned
28 | pos_enc = self.frq_pos_enc[..., :seq_len, :]
29 |
30 | t_rot, t_pass = t[..., : self.rotate_dim], t[..., self.rotate_dim :]
31 | t_rot = (t_rot * pos_enc.cos()) + (self._rotate_half(t_rot) * pos_enc.sin())
32 |
33 | return torch.cat((t_rot, t_pass), dim=-1)
34 |
35 | @staticmethod
36 | def _rotate_half(x):
37 | # Rearranges channel dimension [x1, x2, x3, x4, ...] -> [-x2, x1, -x4, x3, ...]
38 | x = rearrange(x, "... (c r) -> ... c r", r=2)
39 | x1, x2 = x.unbind(dim=-1)
40 | x = torch.stack((-x2, x1), dim=-1)
41 | return rearrange(x, "... c r -> ... (c r)")
42 |
43 |
44 | class ModuleOutput(OrderedDict):
45 | def __getattr__(self, name):
46 | if name in self:
47 | return self[name]
48 | else:
49 | raise AttributeError("No such attribute: " + name)
50 |
51 | def __setattr__(self, name, value):
52 | self[name] = value
53 |
54 | def __delattr__(self, name):
55 | if name in self:
56 | del self[name]
57 | else:
58 | raise AttributeError("No such attribute: " + name)
59 |
60 |
61 | class Residual(nn.Module):
62 | def __init__(self, module: nn.Module, dropout: float = 0.0):
63 | super().__init__()
64 | self.module = module
65 | self.dropout = nn.Dropout(dropout)
66 |
67 | def forward(self, *args, **kwargs):
68 | output = self.module(*args, **kwargs)
69 | output.last_hidden_state = self.dropout(output.last_hidden_state) + args[0]
70 | return output
71 |
72 |
73 | class MultiHeadAttention(nn.Module):
74 | def __init__(
75 | self,
76 | num_heads: int,
77 | num_q_input_channels: int,
78 | num_kv_input_channels: int,
79 | num_qk_channels: Optional[int] = None,
80 | num_v_channels: Optional[int] = None,
81 | num_output_channels: Optional[int] = None,
82 | max_heads_parallel: Optional[int] = None,
83 | causal_attention: bool = False,
84 | dropout: float = 0.0,
85 | qkv_bias: bool = True,
86 | out_bias: bool = True,
87 | ):
88 | """Multi-head attention as specified in https://arxiv.org/abs/2107.14795 Appendix E plus support for rotary
89 | position embeddings (https://arxiv.org/abs/2104.09864) and causal attention. Causal attention requires
90 | queries and keys to be right-aligned, if they have different length.
91 |
92 | :param num_heads: Number of attention heads.
93 | :param num_q_input_channels: Number of query input channels.
94 | :param num_kv_input_channels: Number of key/value input channels.
95 |         :param num_qk_channels: Number of query and key channels. Default is `num_q_input_channels`.
96 | :param num_v_channels: Number of value channels. Default is `num_qk_channels`.
97 | :param num_output_channels: Number of output channels. Default is `num_q_input_channels`
98 | :param max_heads_parallel: Maximum number of heads to be processed in parallel. Default is `num_heads`.
99 | :param causal_attention: Whether to apply a causal attention mask. Default is `False`.
100 | :param dropout: Dropout probability for attention matrix values. Default is `0.0`
101 | :param qkv_bias: Whether to use a bias term for query, key and value projections. Default is `True`.
102 |         :param out_bias: Whether to use a bias term for the output projection. Default is `True`.
103 | """
104 | super().__init__()
105 |
106 | if num_qk_channels is None:
107 | num_qk_channels = num_q_input_channels
108 |
109 | if num_v_channels is None:
110 | num_v_channels = num_qk_channels
111 |
112 | if num_output_channels is None:
113 | num_output_channels = num_q_input_channels
114 |
115 | if num_qk_channels % num_heads != 0:
116 | raise ValueError("num_qk_channels must be divisible by num_heads")
117 |
118 | if num_v_channels % num_heads != 0:
119 | raise ValueError("num_v_channels must be divisible by num_heads")
120 |
121 | num_qk_channels_per_head = num_qk_channels // num_heads
122 |
123 | self.dp_scale = num_qk_channels_per_head**-0.5
124 | self.num_heads = num_heads
125 | self.num_qk_channels = num_qk_channels
126 | self.num_v_channels = num_v_channels
127 | self.causal_attention = causal_attention
128 |
129 | if max_heads_parallel is None:
130 | self.max_heads_parallel = num_heads
131 | else:
132 | self.max_heads_parallel = max_heads_parallel
133 |
134 | self.q_proj = nn.Linear(num_q_input_channels, num_qk_channels, bias=qkv_bias)
135 | self.k_proj = nn.Linear(num_kv_input_channels, num_qk_channels, bias=qkv_bias)
136 | self.v_proj = nn.Linear(num_kv_input_channels, num_v_channels, bias=qkv_bias)
137 | self.o_proj = nn.Linear(num_v_channels, num_output_channels, bias=out_bias)
138 | self.dropout = nn.Dropout(dropout)
139 |
140 | def forward(
141 | self,
142 | x_q: torch.Tensor,
143 | x_kv: torch.Tensor,
144 | pad_mask: Optional[torch.Tensor] = None,
145 | rot_pos_emb_q: Optional[RotaryPositionEmbedding] = None,
146 | rot_pos_emb_k: Optional[RotaryPositionEmbedding] = None,
147 | kv_cache: Optional[KVCache] = None,
148 | ):
149 |         """Compute multi-head attention of query input `x_q` over key/value input `x_kv`.
150 |
151 | :param x_q: Query input of shape (B, N, D) where B is the batch size, N the query sequence length and D the
152 | number of query input channels (= `num_q_input_channels`)
153 | :param x_kv: Key/value input of shape (B, L, C) where B is the batch size, L the key/value sequence length and C
154 | are the number of key/value input channels (= `num_kv_input_channels`)
155 | :param pad_mask: Boolean key padding mask. `True` values indicate padding tokens.
156 | :param rot_pos_emb_q: Applies a rotary position embedding to query i.e. if defined, rotates the query.
157 | :param rot_pos_emb_k: Applies a rotary position embedding to key i.e. if defined, rotates the key.
158 | :param kv_cache: cache with past keys and values.
159 | :return: attention result of shape (B, N, F) where B is the batch size, N the query sequence length and F the
160 | number of output channels (= `num_output_channels`)
161 | """
162 |
163 | q = self.q_proj(x_q)
164 | k = self.k_proj(x_kv)
165 | v = self.v_proj(x_kv)
166 |
167 | if kv_cache is not None:
168 | k_cache, v_cache = kv_cache
169 | k = torch.cat([k_cache, k], dim=1)
170 | v = torch.cat([v_cache, v], dim=1)
171 | kv_cache = (k, v)
172 |
173 | q, k, v = (rearrange(x, "b n (h c) -> b h n c", h=self.num_heads) for x in [q, k, v])
174 | q = q * self.dp_scale
175 |
176 | if rot_pos_emb_q is not None:
177 | q = rot_pos_emb_q.rotate(q)
178 |
179 | if rot_pos_emb_k is not None:
180 | k = rot_pos_emb_k.rotate(k)
181 |
182 | if pad_mask is not None:
183 | pad_mask = rearrange(pad_mask, "b j -> b 1 1 j")
184 |
185 | if self.causal_attention:
186 | i = q.shape[2]
187 | j = k.shape[2]
188 |
189 | # If q and k have different length, causal masking only works if they are right-aligned.
190 | causal_mask = torch.ones((i, j), device=x_q.device, dtype=torch.bool).triu(j - i + 1)
191 |
192 | o_chunks = []
193 |
194 | # Only process a given maximum number of heads in
195 | # parallel, using several iterations, if necessary.
196 | for q_chunk, k_chunk, v_chunk in zip(
197 | q.split(self.max_heads_parallel, dim=1),
198 | k.split(self.max_heads_parallel, dim=1),
199 | v.split(self.max_heads_parallel, dim=1),
200 | ):
201 | attn = torch.einsum("b h i c, b h j c -> b h i j", q_chunk, k_chunk)
202 | attn_max_neg = -torch.finfo(attn.dtype).max
203 |
204 | if pad_mask is not None:
205 | attn.masked_fill_(pad_mask, attn_max_neg)
206 |
207 | if self.causal_attention:
208 | attn.masked_fill_(causal_mask, attn_max_neg)
209 |
210 | attn = attn.softmax(dim=-1)
211 | attn = self.dropout(attn)
212 |
213 | o_chunk = torch.einsum("b h i j, b h j c -> b h i c", attn, v_chunk)
214 | o_chunks.append(o_chunk)
215 |
216 | o = torch.cat(o_chunks, dim=1)
217 | o = rearrange(o, "b h n c -> b n (h c)", h=self.num_heads)
218 | o = self.o_proj(o)
219 |
220 | return ModuleOutput(last_hidden_state=o, kv_cache=kv_cache)
221 |
222 |
223 | class CrossAttention(nn.Module):
224 | def __init__(
225 | self,
226 | num_heads: int,
227 | num_q_input_channels: int,
228 | num_kv_input_channels: int,
229 | num_qk_channels: Optional[int] = None,
230 | num_v_channels: Optional[int] = None,
231 | max_heads_parallel: Optional[int] = None,
232 | causal_attention: bool = False,
233 | dropout: float = 0.0,
234 | qkv_bias: bool = True,
235 | out_bias: bool = True,
236 | ):
237 | """Pre-layer-norm cross-attention (see `MultiHeadAttention` for attention details)."""
238 | super().__init__()
239 | self.q_norm = nn.LayerNorm(num_q_input_channels)
240 | self.kv_norm = nn.LayerNorm(num_kv_input_channels)
241 | self.attention = MultiHeadAttention(
242 | num_heads=num_heads,
243 | num_q_input_channels=num_q_input_channels,
244 | num_kv_input_channels=num_kv_input_channels,
245 | num_qk_channels=num_qk_channels,
246 | num_v_channels=num_v_channels,
247 | max_heads_parallel=max_heads_parallel,
248 | causal_attention=causal_attention,
249 | dropout=dropout,
250 | qkv_bias=qkv_bias,
251 | out_bias=out_bias,
252 | )
253 |
254 | def forward(
255 | self,
256 | x_q: torch.Tensor,
257 | x_kv: Optional[torch.Tensor] = None,
258 | x_kv_prefix: Optional[torch.Tensor] = None,
259 | pad_mask: Optional[torch.Tensor] = None,
260 | rot_pos_emb_q: Optional[RotaryPositionEmbedding] = None,
261 | rot_pos_emb_k: Optional[RotaryPositionEmbedding] = None,
262 | kv_cache: Optional[KVCache] = None,
263 | ):
264 | """Pre-layer-norm cross-attention of query input `x_q` to key/value input (`x_kv` or `x_kv_prefix`).
265 |
266 | If `x_kv_prefix` is defined, the entire key/value input is a concatenation of `x_kv_prefix` and `x_q` along
267 | the sequence dimension. In this case, the query attends to itself at the end of the key/value sequence (use
268 | case: Perceiver AR). If `x_kv_prefix` is not defined, `x_kv` is the entire key/value input.
269 | """
270 | x_q = self.q_norm(x_q)
271 |
272 | if x_kv is None:
273 | x_kv_prefix = self.kv_norm(x_kv_prefix)
274 | x_kv = torch.cat([x_kv_prefix, x_q], dim=1)
275 | else:
276 | x_kv = self.kv_norm(x_kv)
277 |
278 | return self.attention(
279 | x_q, x_kv, pad_mask=pad_mask, rot_pos_emb_q=rot_pos_emb_q, rot_pos_emb_k=rot_pos_emb_k, kv_cache=kv_cache
280 | )
281 |
282 |
283 | class SelfAttention(nn.Module):
284 | def __init__(
285 | self,
286 | num_heads: int,
287 | num_channels: int,
288 | num_qk_channels: Optional[int] = None,
289 | num_v_channels: Optional[int] = None,
290 | max_heads_parallel: Optional[int] = None,
291 | causal_attention: bool = False,
292 | dropout: float = 0.0,
293 | qkv_bias: bool = True,
294 | out_bias: bool = True,
295 | ):
296 |         """Pre-layer-norm self-attention (see `MultiHeadAttention` for attention details)."""
297 | super().__init__()
298 | self.norm = nn.LayerNorm(num_channels)
299 | self.attention = MultiHeadAttention(
300 | num_heads=num_heads,
301 | num_q_input_channels=num_channels,
302 | num_kv_input_channels=num_channels,
303 | num_qk_channels=num_qk_channels,
304 | num_v_channels=num_v_channels,
305 | max_heads_parallel=max_heads_parallel,
306 | causal_attention=causal_attention,
307 | dropout=dropout,
308 | qkv_bias=qkv_bias,
309 | out_bias=out_bias,
310 | )
311 |
312 | def forward(
313 | self,
314 | x: torch.Tensor,
315 | pad_mask: Optional[torch.Tensor] = None,
316 | rot_pos_emb: Optional[RotaryPositionEmbedding] = None,
317 | kv_cache: Optional[KVCache] = None,
318 | ):
319 | """Pre-layer-norm self-attention of input `x`."""
320 | x = self.norm(x)
321 | return self.attention(
322 | x,
323 | x,
324 | pad_mask=pad_mask,
325 | rot_pos_emb_q=rot_pos_emb,
326 | rot_pos_emb_k=rot_pos_emb,
327 | kv_cache=kv_cache,
328 | )
329 |
330 |
331 | class AbstractAttentionLayer(nn.Sequential):
332 | def empty_kv_cache(self, x) -> KVCache:
333 | k_cache = torch.empty(x.shape[0], 0, self.num_qk_channels, dtype=x.dtype, device=x.device)
334 | v_cache = torch.empty(x.shape[0], 0, self.num_v_channels, dtype=x.dtype, device=x.device)
335 | return k_cache, v_cache
336 |
337 | def forward(self, *args, kv_cache: Optional[KVCache] = None, **kwargs):
338 | attn_output = self[0](*args, kv_cache=kv_cache, **kwargs)
339 | mlp_output = self[1](attn_output.last_hidden_state)
340 | return ModuleOutput(last_hidden_state=mlp_output.last_hidden_state, kv_cache=attn_output.kv_cache)
341 |
342 |
343 | class CrossAttentionLayer(AbstractAttentionLayer):
344 | def __init__(
345 | self,
346 | num_heads: int,
347 | num_q_input_channels: int,
348 | num_kv_input_channels: int,
349 | num_qk_channels: Optional[int] = None,
350 | num_v_channels: Optional[int] = None,
351 | max_heads_parallel: Optional[int] = None,
352 | causal_attention: bool = False,
353 | widening_factor: int = 1,
354 | dropout: float = 0.0,
355 | residual_dropout: float = 0.0,
356 | attention_residual: bool = True,
357 | qkv_bias: bool = True,
358 | out_bias: bool = True,
359 | mlp_bias: bool = True,
360 | ):
361 | cross_attn = CrossAttention(
362 | num_heads=num_heads,
363 | num_q_input_channels=num_q_input_channels,
364 | num_kv_input_channels=num_kv_input_channels,
365 | num_qk_channels=num_qk_channels,
366 | num_v_channels=num_v_channels,
367 | max_heads_parallel=max_heads_parallel,
368 | causal_attention=causal_attention,
369 | dropout=dropout,
370 | qkv_bias=qkv_bias,
371 | out_bias=out_bias,
372 | )
373 |
374 | self.num_qk_channels = cross_attn.attention.num_qk_channels
375 | self.num_v_channels = cross_attn.attention.num_v_channels
376 |
377 | super().__init__(
378 | Residual(cross_attn, residual_dropout) if attention_residual else cross_attn,
379 | Residual(MLP(num_q_input_channels, widening_factor, bias=mlp_bias), residual_dropout),
380 | )
381 |
382 |
383 | class SelfAttentionLayer(AbstractAttentionLayer):
384 | def __init__(
385 | self,
386 | num_heads: int,
387 | num_channels: int,
388 | num_qk_channels: Optional[int] = None,
389 | num_v_channels: Optional[int] = None,
390 | max_heads_parallel: Optional[int] = None,
391 | causal_attention: bool = False,
392 | widening_factor: int = 1,
393 | dropout: float = 0.0,
394 | residual_dropout: float = 0.0,
395 | qkv_bias: bool = True,
396 | out_bias: bool = True,
397 | mlp_bias: bool = True,
398 | ):
399 | self_attn = SelfAttention(
400 | num_heads=num_heads,
401 | num_channels=num_channels,
402 | num_qk_channels=num_qk_channels,
403 | num_v_channels=num_v_channels,
404 | max_heads_parallel=max_heads_parallel,
405 | causal_attention=causal_attention,
406 | dropout=dropout,
407 | qkv_bias=qkv_bias,
408 | out_bias=out_bias,
409 | )
410 |
411 | self.num_qk_channels = self_attn.attention.num_qk_channels
412 | self.num_v_channels = self_attn.attention.num_v_channels
413 |
414 | super().__init__(
415 | Residual(self_attn, residual_dropout),
416 | Residual(MLP(num_channels, widening_factor, bias=mlp_bias), residual_dropout),
417 | )
418 |
419 |
420 | class SelfAttentionBlock(nn.Sequential):
421 | def __init__(
422 | self,
423 | num_layers: int,
424 | num_heads: int,
425 | num_channels: int,
426 | num_qk_channels: Optional[int] = None,
427 | num_v_channels: Optional[int] = None,
428 | num_rotary_layers: int = 1,
429 | max_heads_parallel: Optional[int] = None,
430 | causal_attention: bool = False,
431 | widening_factor: int = 1,
432 | dropout: float = 0.0,
433 | residual_dropout: float = 0.0,
434 | qkv_bias: bool = True,
435 | out_bias: bool = True,
436 | mlp_bias: bool = True,
437 | ):
438 | layers = [
439 | SelfAttentionLayer(
440 | num_heads=num_heads,
441 | num_channels=num_channels,
442 | num_qk_channels=num_qk_channels,
443 | num_v_channels=num_v_channels,
444 | max_heads_parallel=max_heads_parallel,
445 | causal_attention=causal_attention,
446 | widening_factor=widening_factor,
447 | dropout=dropout,
448 | residual_dropout=residual_dropout,
449 | qkv_bias=qkv_bias,
450 | out_bias=out_bias,
451 | mlp_bias=mlp_bias,
452 | )
453 | for _ in range(num_layers)
454 | ]
455 |
456 |
457 | self.num_rotary_layers = num_rotary_layers
458 | super().__init__(*layers)
459 |
460 | def forward(
461 | self,
462 | x: torch.Tensor,
463 | pad_mask: Optional[torch.Tensor] = None,
464 | rot_pos_emb: Optional[RotaryPositionEmbedding] = None,
465 | kv_cache: Optional[List[KVCache]] = None,
466 | ):
467 | if kv_cache is None:
468 | kv_cache_updated = None
469 | else:
470 | if len(kv_cache) == 0:
471 | # initialize kv_cache for each self-attention layer
472 | kv_cache = [layer.empty_kv_cache(x) for layer in self]
473 | kv_cache_updated = []
474 |
475 | for i, layer in enumerate(self):
476 | rot_pos_emb_use = i < self.num_rotary_layers or self.num_rotary_layers == -1
477 | rot_pos_emb_i = rot_pos_emb if rot_pos_emb_use else None
478 |
479 | kv_cache_i = None if kv_cache is None else kv_cache[i]
480 | output = layer(x, pad_mask=pad_mask, rot_pos_emb=rot_pos_emb_i, kv_cache=kv_cache_i)
481 |
482 | x = output.last_hidden_state
483 |
484 | if kv_cache_updated is not None:
485 | kv_cache_updated.append(output.kv_cache)
486 |
487 | return ModuleOutput(last_hidden_state=x, kv_cache=kv_cache_updated)
488 |
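# Usage sketch (illustrative only): incremental decoding with the per-layer KV cache. Passing
# `kv_cache=[]` on the first call lets the block initialize one empty cache per layer; later
# calls feed only the newly generated position together with the returned cache.
#
#   block = SelfAttentionBlock(num_layers=4, num_heads=8, num_channels=256, causal_attention=True)
#   x = torch.randn(1, 10, 256)                  # prompt embeddings
#   out = block(x, kv_cache=[])                  # caches keys/values for all 10 positions
#   x_next = torch.randn(1, 1, 256)              # embedding of the next position only
#   out = block(x_next, kv_cache=out.kv_cache)   # attends to the cached prefix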
489 | class MLP(nn.Sequential):
490 | def __init__(self, num_channels: int, widening_factor: int, bias: bool = True):
491 | super().__init__(
492 | nn.LayerNorm(num_channels),
493 | nn.Linear(num_channels, widening_factor * num_channels, bias=bias),
494 | nn.GELU(),
495 | nn.Linear(widening_factor * num_channels, num_channels, bias=bias),
496 | )
497 |
498 | def forward(self, x):
499 | return ModuleOutput(last_hidden_state=super().forward(x))
--------------------------------------------------------------------------------
/models/feature_extractors.py:
--------------------------------------------------------------------------------
1 | import einops
2 | import torch
3 | import numpy as np
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from models.layers_2d import (
8 | BackprojectDepth,
9 | load_clip,
10 | Project3D,
11 | )
12 | from torchvision.ops import FeaturePyramidNetwork
13 | import torchvision
14 | from models.helpers import TSDFVolume, get_view_frustum
15 | from models.layers_3d import VoxelGridEncoder
16 |
17 | from typing import Union, List, Tuple
18 |
19 | from models.perceiver import FeaturePerceiver
20 |
21 |
22 | class MultiScaleImageFeatureExtractor(nn.Module):
23 | _RESNET_MEAN = [0.485, 0.456, 0.406]
24 | _RESNET_STD = [0.229, 0.224, 0.225]
25 |
26 | def __init__(
27 | self,
28 | modelname: str = "dino_vits16",
29 | freeze: bool = False,
30 | scale_factors: list = [1, 1 / 2, 1 / 3],
31 | embedding_dim: int = None,
32 | ):
33 | super().__init__()
34 | self.freeze = freeze
35 | self.scale_factors = scale_factors
36 | self.embedding_dim = embedding_dim
37 |
38 | if "res" in modelname:
39 | self._net = getattr(torchvision.models, modelname)(pretrained=True)
40 | self._output_dim = self._net.fc.weight.shape[1]
41 | self._net.fc = nn.Identity()
42 | elif "dinov2" in modelname:
43 | self._net = torch.hub.load("facebookresearch/dinov2", modelname)
44 | self._output_dim = self._net.norm.weight.shape[0]
45 | elif "dino" in modelname:
46 | self._net = torch.hub.load("facebookresearch/dino:main", modelname)
47 | self._output_dim = self._net.norm.weight.shape[0]
48 | else:
49 | raise ValueError(f"Unknown model name {modelname}")
50 |
51 | for name, value in (
52 | ("_resnet_mean", self._RESNET_MEAN),
53 | ("_resnet_std", self._RESNET_STD),
54 | ):
55 | self.register_buffer(
56 | name, torch.FloatTensor(value).view(1, 3, 1, 1), persistent=False
57 | )
58 |
59 | if self.freeze:
60 | for param in self.parameters():
61 | param.requires_grad = False
62 |
63 | if self.embedding_dim is not None:
64 | self._last_layer = nn.Linear(self._output_dim, self.embedding_dim)
65 | self._output_dim = self.embedding_dim
66 | else:
67 | self._last_layer = nn.Identity()
68 |
69 | def get_output_dim(self):
70 | return self._output_dim
71 |
72 | def forward(self, image_rgb: torch.Tensor):
73 | img_normed = self._resnet_normalize_image(image_rgb)
74 | features = self._compute_multiscale_features(img_normed)
75 | return features
76 |
77 | def _resnet_normalize_image(self, img: torch.Tensor):
78 | return (img - self._resnet_mean) / self._resnet_std
79 |
80 | def _compute_multiscale_features(self, img_normed: torch.Tensor):
81 | multiscale_features = None
82 |
83 | if len(self.scale_factors) <= 0:
84 | raise ValueError(
85 | f"Wrong format of self.scale_factors: {self.scale_factors}"
86 | )
87 |
88 | for scale_factor in self.scale_factors:
89 | if scale_factor == 1:
90 | inp = img_normed
91 | else:
92 | inp = self._resize_image(img_normed, scale_factor)
93 |
94 | if multiscale_features is None:
95 | multiscale_features = self._net(inp)
96 | else:
97 | multiscale_features += self._net(inp)
98 |
99 | averaged_features = multiscale_features / len(self.scale_factors)
100 | averaged_features = self._last_layer(averaged_features)
101 | return averaged_features
102 |
103 | @staticmethod
104 | def _resize_image(image: torch.Tensor, scale_factor: float):
105 | return nn.functional.interpolate(
106 | image, scale_factor=scale_factor, mode="bilinear", align_corners=False
107 | )
108 |
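# Usage sketch (illustrative only): features are computed at every configured scale, averaged,
# and optionally projected to `embedding_dim`.
#
#   extractor = MultiScaleImageFeatureExtractor(modelname="dino_vits16", embedding_dim=256)
#   images = torch.rand(4, 3, 224, 224)   # RGB in [0, 1]
#   feats = extractor(images)             # [4, 256], averaged over scales 1, 1/2 and 1/3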
109 |
110 | class TSDFMapFeatureExtractor(nn.Module):
111 | def __init__(
112 | self,
113 | input_image_shape,
114 | voxel_resolution=32,
115 | voxel_feature_dim=64,
116 | vlm_feature_attn_dim=256,
117 | # use_feature_decoder=True,
118 | ):
119 | super().__init__()
120 | self.input_image_shape = input_image_shape
121 | self.voxel_resolution = voxel_resolution
122 | self.embedding_dim = voxel_feature_dim
123 | self.backproject = BackprojectDepth(input_image_shape[0], input_image_shape[1])
124 | self.project_3d = Project3D()
125 |
126 | # Load pretrained backbone
127 | # self.vlm, self.vlm_transform = clip.load("ViT-B/16", jit=False)
128 | self.vlm, self.vlm_transform = load_clip()
129 |
130 | self.vlm.float()
131 | for p in self.vlm.parameters():
132 | p.requires_grad = False
133 |
134 | # Load 3D Unet
135 | self.tsdf_net = VoxelGridEncoder(
136 | self.voxel_resolution, c_dim=self.embedding_dim
137 | )
138 |
139 | self.feature_pyramid = FeaturePyramidNetwork(
140 | [64, 256, 512, 1024, 2048], voxel_feature_dim
141 | )
142 |
143 | self.feature_map_pyramid_keys = ["res1", "res2", "res3"]
144 |
145 | # Cross Attention Layer
146 | # self.action_proj = nn.Linear(vlm_feature_attn_dim, self.embedding_dim, bias=True)
147 | self.vlm_preceiver_pyramid = nn.ModuleList()
148 | self.vlm_proj_pyramid = nn.ModuleList()
149 | vlm_preceiver = FeaturePerceiver(
150 | transition_dim=self.embedding_dim,
151 | condition_dim=vlm_feature_attn_dim,
152 | time_emb_dim=0,
153 | )
154 | vlm_proj = nn.Linear(vlm_preceiver.last_dim, self.embedding_dim, bias=True)
155 | for _ in range(len(self.feature_map_pyramid_keys)):
156 |
157 |             self.vlm_preceiver_pyramid.append(vlm_preceiver)  # note: the same perceiver instance is shared across all pyramid levels
158 |             self.vlm_proj_pyramid.append(vlm_proj)  # the projection layer is shared as well
159 | # Feature projection layer
160 | proj_dim_in = voxel_feature_dim * (1 + len(self.feature_map_pyramid_keys))
161 | self.proj = nn.Linear(proj_dim_in, self.embedding_dim, bias=True)
162 |
163 | def compute_tsdf_volume(self, color, depth, intrinsics, verbose=False):
164 | cam_pose = np.eye(4)
165 | tsdf_grid_batch = []
166 | tsdf_bounds_batch = []
167 | tsdf_color_batch = []
168 | mesh_batch = []
169 | for i in range(len(depth)):
170 | d_np = depth[i].cpu().numpy()[0]
171 | c_np = (
172 | color[i].cpu().numpy().transpose(1, 2, 0)
173 | ) # [H, W, 3], requested by TSDFVolume
174 | K_np = intrinsics[i].cpu().numpy()
175 | view_frust_pts = get_view_frustum(d_np, K_np, cam_pose)
176 | vol_bnds = np.zeros((3, 2))
177 | vol_bnds[:, 0] = np.minimum(
178 | vol_bnds[:, 0], np.amin(view_frust_pts, axis=1)
179 | ).min()
180 | vol_bnds[:, 1] = np.maximum(
181 | vol_bnds[:, 1], np.amax(view_frust_pts, axis=1)
182 | ).max()
183 | tsdf = TSDFVolume(vol_bnds, voxel_dim=self.voxel_resolution)
184 | tsdf.integrate(c_np * 255, d_np, K_np, cam_pose)
185 | tsdf_grid = torch.from_numpy(tsdf.get_tsdf_volume())
186 | tsdf_grid_batch.append(tsdf_grid)
187 | tsdf_bounds_batch.append(torch.from_numpy(vol_bnds[0]))
188 | if verbose:
189 | mesh = tsdf.get_mesh()
190 | color_grid = torch.from_numpy(tsdf.get_color_volume()) / 255.0
191 | mesh_batch.append(mesh)
192 | tsdf_color_batch.append(color_grid)
193 | tsdf_bounds_batch = (
194 | torch.stack(tsdf_bounds_batch, dim=0).to(depth.device).float()
195 | )
196 | tsdf_grid_batch = torch.stack(tsdf_grid_batch, dim=0).to(depth.device).float()
197 | if verbose:
198 | tsdf_color_batch = (
199 | torch.stack(tsdf_color_batch, dim=0).to(depth.device).float()
200 | )
201 | return tsdf_grid_batch, tsdf_color_batch, tsdf_bounds_batch, mesh_batch
202 | return tsdf_grid_batch
203 |
204 | def compute_context_features(
205 | self, color, depth, intrinsics, tsdf=None, action_features=None
206 | ):
207 | if tsdf is None:
208 | tsdf = self.compute_tsdf_volume(color, depth, intrinsics)
209 |
210 | h_in, w_in = color.shape[-2:]
211 |
212 | color = self.vlm_transform(color)
213 | color_features = self.vlm(color) # [B, N, C]
214 | color_features = self.feature_pyramid(color_features) # [B, N, C]
215 |
216 | # Action grounding
217 | if action_features is not None:
218 | for i, k in enumerate(self.feature_map_pyramid_keys):
219 | color_feature_i = color_features[k] #
220 | h, w = color_feature_i.shape[-2:]
221 |                 color_feature_i = einops.rearrange(color_feature_i, "B C H W -> B (H W) C")
222 | color_feature_i = self.vlm_preceiver_pyramid[i](color_feature_i, action_features[:, None])
223 | color_feature_i = self.vlm_proj_pyramid[i](color_feature_i)
224 | color_feature_i = einops.rearrange(color_feature_i, "B (H W) C -> B C H W", H=h, W=w)
225 | color_features[k] = color_feature_i
226 |
227 | color_features_pyramid = []
228 | for i, k in enumerate(self.feature_map_pyramid_keys):
229 | color_feature_i = color_features[k]
230 | color_feature_i = F.interpolate(
231 | color_feature_i, size=(h_in, w_in), mode="bilinear"
232 | )
233 | color_features_pyramid.append(color_feature_i)
234 | points_map_pyramid = [tsdf] * len(color_features_pyramid) # [B, D, H, W]
235 | points_pe_pyramid = [self.tsdf_net(tsdf)] * len(color_features_pyramid) # [B, P, D, H, W]
236 |
237 | # color_features = einops.rearrange(
238 | # color_features, "B (H W) C -> B C H W", H=h_out, W=w_out
239 | # )
240 | # import pdb; pdb.set_trace()
241 | # if self.use_feature_decoder:
242 | # color_features = self.feature_decoder(color_features)
243 |
244 | # color_features = F.interpolate(
245 | # color_features, size=tuple(self.input_image_shape), mode="bilinear"
246 | # )
247 |
248 |
249 | # # Compute the point features pyramid
250 | # for i, k in enumerate(self.feature_map_pyramid_keys):
251 | # color_feature_i = color_features[k]
252 | # color_feature_i = F.interpolate(
253 | # color_feature_i, size=tuple(self.input_image_shape), mode="bilinear"
254 | # )
255 |
256 | # color_features_pyramid.append(color_feature_i)
257 | return color_features_pyramid, points_map_pyramid, points_pe_pyramid
258 |
259 | @staticmethod
260 | def interpolate_voxel_grid_features(voxel_grid, query_points, voxel_bounds):
261 | """
262 | Parameters
263 | ----------
264 | voxel_grid : torch.Tensor
265 | with shape [B, C, D, H, W]
266 | query_points : torch.Tensor
267 |             with shape [B, N, 3]
268 |         voxel_bounds: torch.Tensor
269 |             with shape [B, 2]
270 | """
271 | voxel_bounds = voxel_bounds.unsqueeze(-1).repeat(1, 1, 3) # [B, 2, 3]
272 | query_points = (query_points - voxel_bounds[:, 0:1]) / (
273 | voxel_bounds[:, 1:2] - voxel_bounds[:, 0:1]
274 | )
275 | query_grids = (
276 | query_points * 2 - 1
277 | ) # Normalize the query points from [0, 1] to [-1, 1]
278 | query_grids = query_grids[
279 | ..., [2, 1, 0]
280 | ] # Convert to the voxel grid coordinate system
281 | query_grids = query_grids[:, :, None, None] # [B, N, 1, 1, 3]
282 | query_features = F.grid_sample(
283 | voxel_grid, query_grids, mode="bilinear", align_corners=True
284 | ) # [B, C, N, 1, 1]
285 | query_features = query_features.squeeze(-1).squeeze(-1) # [B, C, N]
286 | return query_features
287 |
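# Worked example (illustrative only) of the normalization above: with voxel_bounds = [-1.0, 3.0]
# on every axis, a query coordinate of 1.0 maps to (1.0 - (-1.0)) / (3.0 - (-1.0)) = 0.5 in [0, 1]
# and then to 2 * 0.5 - 1 = 0.0 in the [-1, 1] range expected by F.grid_sample; the final
# [2, 1, 0] index flip reorders (x, y, z) into the (W, H, D) ordering grid_sample uses for 5D inputs.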
288 | def interpolate_image_grid_features(self, image_grid, query_points, intrinsics):
289 | """
290 | Parameters
291 | ----------
292 | image_grid : torch.Tensor
293 | with shape [B, C, H, W]
294 | query_points : torch.Tensor
295 |             with shape [B, N, 3]
296 | """
297 | batch_size, _, height, width = image_grid.shape
298 | query_grids = self.project_3d(query_points, intrinsics) # [B, 2, N]
299 | query_grids[:, 0] = (query_grids[:, 0] / (width - 1)) * 2 - 1
300 | query_grids[:, 1] = (query_grids[:, 1] / (height - 1)) * 2 - 1
301 | query_grids = query_grids.permute(0, 2, 1)[:, :, None] # [B, N, 1, 2]
302 |         query_features = F.grid_sample(
303 |             image_grid, query_grids, mode="bilinear", align_corners=True
304 |         )  # [B, C, N, 1]
305 |         query_features = query_features.squeeze(-1)
306 |         return query_features
307 |
308 | def forward(
309 | self,
310 | color_features_pyramid,
311 | points_map_pyramid,
312 | points_pe_pyramid,
313 | query_points,
314 | intrinsics,
315 | voxel_bounds,
316 | **kwargs,
317 | ):
318 | """_summary_
319 |
320 | Parameters
321 | ----------
322 | color_features_pyramid : list of torch.Tensor
323 | with shape [[B, C, H, W]]
324 | points_map_pyramid : list of torch.Tensor for TSDF volume
325 | [[B, D, H, W]]
326 | points_pe_pyramid : list of torch.Tensor for TSDF volume feature
327 | [[B, P, D, H, W]]
328 | query_points : query points
329 | [B, N, 3]
330 | intrinsics : torch.Tensor or np.ndarray
331 | [3, 3]
332 |         voxel_bounds : torch.Tensor
333 | [B, 2]
334 |
335 | Returns
336 | -------
337 | torch.Tensor
338 |             shape of [B, N, C], the per-point features after the final projection
339 | """
340 | assert len(color_features_pyramid) == len(points_map_pyramid)
341 | assert len(color_features_pyramid) == len(points_pe_pyramid)
342 | batch_size, num_query_points, _ = query_points.shape
343 | features = []
344 |
345 | for i in range(len(color_features_pyramid)):
346 | # Re-arrange to feature maps
347 | color_feature_i = color_features_pyramid[i] # [B, C, H, W]
348 | points_pe_i = points_pe_pyramid[i] # [B, P, D, H, W]
349 | points_map_i = points_map_pyramid[i][:, None] # [B, 1, D, H, W]
350 |
351 | if i == 0:
352 | # Interpolate the voxel grid features
353 | feat_occ = self.interpolate_voxel_grid_features(
354 | points_map_i, query_points, voxel_bounds
355 | )
356 |
357 | # Interpolate the voxel grid features
358 | feat_3d = self.interpolate_voxel_grid_features(
359 | points_pe_i, query_points, voxel_bounds
360 | )
361 | features.append(feat_3d) # [B, C, N]
362 |
363 | # Interpolate the 2D feature maps
364 | feat_2d = self.interpolate_image_grid_features(
365 | color_feature_i, query_points, intrinsics
366 | )
367 | features.append(feat_2d) # [B, C, N]
368 |         features = torch.cat(features, dim=1).permute(0, 2, 1)  # [B, N, C*4]
369 | features = self.proj(features) # [B, N, C]
370 | return features
371 |
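# Usage sketch (illustrative only) of the two-stage interface above, assuming `extractor` is a
# TSDFMapFeatureExtractor and `color`, `depth`, `K`, `action_feat`, `query_pts`, `bounds` are
# tensors of shapes [B, 3, H, W], [B, 1, H, W], [B, 3, 3], [B, C_a], [B, N, 3] and [B, 2]:
#
#   color_pyr, map_pyr, pe_pyr = extractor.compute_context_features(
#       color, depth, K, action_features=action_feat
#   )
#   point_feats = extractor(color_pyr, map_pyr, pe_pyr, query_pts, K, bounds)   # [B, N, C]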
372 | class TSDFMapGeometryExtractor(nn.Module):
373 | def __init__(
374 | self,
375 | input_image_shape,
376 | voxel_resolution=64,
377 | voxel_feature_dim=64,
378 | ):
379 | super().__init__()
380 | self.input_image_shape = input_image_shape
381 | self.voxel_resolution = voxel_resolution
382 | self.embedding_dim = voxel_feature_dim
383 | self.backproject = BackprojectDepth(input_image_shape[0], input_image_shape[1])
384 | self.project_3d = Project3D()
385 |
386 | # Load 3D Unet
387 | self.tsdf_net = VoxelGridEncoder(
388 | self.voxel_resolution, c_dim=self.embedding_dim
389 | )
390 |
391 | def compute_tsdf_volume(self, color, depth, intrinsics, verbose=False):
392 | cam_pose = np.eye(4)
393 | tsdf_grid_batch = []
394 | tsdf_bounds_batch = []
395 | tsdf_color_batch = []
396 | mesh_batch = []
397 | for i in range(len(depth)):
398 | d_np = depth[i].cpu().numpy()[0]
399 | c_np = (
400 | color[i].cpu().numpy().transpose(1, 2, 0)
401 | ) # [H, W, 3], requested by TSDFVolume
402 | K_np = intrinsics[i].cpu().numpy()
403 | view_frust_pts = get_view_frustum(d_np, K_np, cam_pose)
404 | vol_bnds = np.zeros((3, 2))
405 | vol_bnds[:, 0] = np.minimum(
406 | vol_bnds[:, 0], np.amin(view_frust_pts, axis=1)
407 | ).min()
408 | vol_bnds[:, 1] = np.maximum(
409 | vol_bnds[:, 1], np.amax(view_frust_pts, axis=1)
410 | ).max()
411 | tsdf = TSDFVolume(vol_bnds, voxel_dim=self.voxel_resolution)
412 | tsdf.integrate(c_np * 255, d_np, K_np, cam_pose)
413 | tsdf_grid = torch.from_numpy(tsdf.get_tsdf_volume())
414 | tsdf_grid_batch.append(tsdf_grid)
415 | tsdf_bounds_batch.append(torch.from_numpy(vol_bnds[0]))
416 | if verbose:
417 | mesh = tsdf.get_mesh()
418 | color_grid = torch.from_numpy(tsdf.get_color_volume()) / 255.0
419 | mesh_batch.append(mesh)
420 | tsdf_color_batch.append(color_grid)
421 | tsdf_bounds_batch = (
422 | torch.stack(tsdf_bounds_batch, dim=0).to(depth.device).float()
423 | )
424 | tsdf_grid_batch = torch.stack(tsdf_grid_batch, dim=0).to(depth.device).float()
425 | if verbose:
426 | tsdf_color_batch = (
427 | torch.stack(tsdf_color_batch, dim=0).to(depth.device).float()
428 | )
429 | return tsdf_grid_batch, tsdf_color_batch, tsdf_bounds_batch, mesh_batch
430 | return tsdf_grid_batch
431 |
432 | def compute_context_features(
433 | self, color, depth, intrinsics, tsdf=None, action_featurs=None
434 | ):
435 | if tsdf is None:
436 | tsdf = self.compute_tsdf_volume(color, depth, intrinsics)
437 |
438 |
439 | color_features_pyramid = [None] # [B, C, H, W]
440 | points_map_pyramid = [tsdf] # [B, D, H, W]
441 | points_pe_pyramid = [self.tsdf_net(tsdf)] # [B, P, D, H, W]
442 |
443 | return color_features_pyramid, points_map_pyramid, points_pe_pyramid
444 |
445 | @staticmethod
446 | def interpolate_voxel_grid_features(voxel_grid, query_points, voxel_bounds):
447 | """
448 | Parameters
449 | ----------
450 | voxel_grid : torch.Tensor
451 | with shape [B, C, D, H, W]
452 | query_points : torch.Tensor
453 |             with shape [B, N, 3]
454 |         voxel_bounds: torch.Tensor
455 |             with shape [B, 2]
456 | """
457 | voxel_bounds = voxel_bounds.unsqueeze(-1).repeat(1, 1, 3) # [B, 2, 3]
458 | query_points = (query_points - voxel_bounds[:, 0:1]) / (
459 | voxel_bounds[:, 1:2] - voxel_bounds[:, 0:1]
460 | )
461 | query_grids = (
462 | query_points * 2 - 1
463 | ) # Normalize the query points from [0, 1] to [-1, 1]
464 | query_grids = query_grids[
465 | ..., [2, 1, 0]
466 | ] # Convert to the voxel grid coordinate system
467 | query_grids = query_grids[:, :, None, None] # [B, N, 1, 1, 3]
468 | query_features = F.grid_sample(
469 | voxel_grid, query_grids, mode="bilinear", align_corners=True
470 | ) # [B, C, N, 1, 1]
471 | query_features = query_features.squeeze(-1).squeeze(-1) # [B, C, N]
472 | return query_features
473 |
474 | def interpolate_image_grid_features(self, image_grid, query_points, intrinsics):
475 | """
476 | Parameters
477 | ----------
478 | image_grid : torch.Tensor
479 | with shape [B, C, H, W]
480 | query_points : torch.Tensor
481 |             with shape [B, N, 3]
482 | """
483 | batch_size, _, height, width = image_grid.shape
484 | query_grids = self.project_3d(query_points, intrinsics) # [B, 2, N]
485 | query_grids[:, 0] = (query_grids[:, 0] / (width - 1)) * 2 - 1
486 | query_grids[:, 1] = (query_grids[:, 1] / (height - 1)) * 2 - 1
487 | query_grids = query_grids.permute(0, 2, 1)[:, :, None] # [B, N, 1, 2]
488 |         query_features = F.grid_sample(
489 |             image_grid, query_grids, mode="bilinear", align_corners=True
490 |         )  # [B, C, N, 1]
491 |         query_features = query_features.squeeze(-1)
492 |         return query_features
493 |
494 | def forward(
495 | self,
496 | color_features_pyramid,
497 | points_map_pyramid,
498 | points_pe_pyramid,
499 | query_points,
500 | intrinsics,
501 | voxel_bounds,
502 | **kwargs,
503 | ):
504 | """_summary_
505 |
506 | Parameters
507 | ----------
508 | color_features_pyramid : list of torch.Tensor
509 | with shape [[B, C, H, W]]
510 | points_map_pyramid : list of torch.Tensor for TSDF volume
511 | [[B, D, H, W]]
512 | points_pe_pyramid : list of torch.Tensor for TSDF volume feature
513 | [[B, P, D, H, W]]
514 | query_points : query points
515 | [B, N, 3]
516 | intrinsics : torch.Tensor or np.ndarray
517 | [3, 3]
518 |         voxel_bounds : torch.Tensor
519 | [B, 2]
520 |
521 | Returns
522 | -------
523 | torch.Tensor
524 |             shape of [B, N, C], the interpolated TSDF features
525 | """
526 | assert len(color_features_pyramid) == len(points_map_pyramid)
527 | assert len(color_features_pyramid) == len(points_pe_pyramid)
528 | batch_size, num_query_points, _ = query_points.shape
529 | features = []
530 |
531 | for i in range(len(color_features_pyramid)):
532 | # Re-arrange to feature maps
533 | color_feature_i = color_features_pyramid[i] # [B, C, H, W]
534 | points_pe_i = points_pe_pyramid[i] # [B, P, D, H, W]
535 | points_map_i = points_map_pyramid[i][:, None] # [B, 1, D, H, W]
536 |
537 | if i == 0:
538 | # Interpolate the voxel grid features
539 | feat_occ = self.interpolate_voxel_grid_features(
540 | points_map_i, query_points, voxel_bounds
541 | )
542 |
543 | # Interpolate the voxel grid features
544 | feat_3d = self.interpolate_voxel_grid_features(
545 | points_pe_i, query_points, voxel_bounds
546 | )
547 | features.append(feat_3d) # [B, C, N]
548 |
549 | # # Interpolate the 2D feature maps
550 | # feat_2d = self.interpolate_image_grid_features(
551 | # color_feature_i, query_points, intrinsics
552 | # )
553 | # features.append(feat_2d) # [B, C, N]
554 |         features = torch.cat(features, dim=1).permute(0, 2, 1)  # [B, N, C]
555 | return features
556 |
--------------------------------------------------------------------------------
/models/helpers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from numba import njit, prange
7 | from skimage import measure
8 | import open3d as o3d
9 | import time
10 |
11 | from math import ceil
12 | from models.clip import clip, tokenize
13 |
14 |
15 | def compute_null_text_embeddings(vlm, batch_size=1, device="cuda"):
16 | action_tokens_null = tokenize("")
17 | action_tokens_null = action_tokens_null.repeat(batch_size, 1)
18 | action_tokens_null = action_tokens_null.to(device)
19 | action_feature_null = vlm.encode_text(action_tokens_null).float()
20 | return action_feature_null
21 |
22 |
23 | def fourier_positional_encoding(input, L): # [B,...,C]
24 | shape = input.shape
25 | freq = 2 ** torch.arange(L, dtype=torch.float32, device=input.device) * np.pi # [L]
26 | spectrum = input[..., None] * freq # [B,...,C,L]
27 | sin, cos = spectrum.sin(), spectrum.cos() # [B,...,C,L]
28 | input_enc = torch.stack([sin, cos], dim=-2) # [B,...,C,2,L]
29 | input_enc = input_enc.view(*shape[:-1], -1) # [B,...,2CL]
30 | return input_enc
31 |
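# Shape sketch (illustrative only): for an input of shape [B, N, 3] and L = 4 frequency bands,
# `spectrum` is [B, N, 3, 4], the stacked sin/cos tensor is [B, N, 3, 2, 4], and the returned
# encoding is flattened to [B, N, 3 * 2 * 4] = [B, N, 24].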
32 |
33 | def exists(x):
34 | return x is not None
35 |
36 |
37 | def default(val, d):
38 | if exists(val):
39 | return val
40 | return d() if callable(d) else d
41 |
42 |
43 | def round_up_multiple(num, mult):
44 | return ceil(num / mult) * mult
45 |
46 |
47 | def extract(a, t, x_shape):
48 | b, *_ = t.shape
49 | out = a.gather(-1, t)
50 | return out.reshape(b, *((1,) * (len(x_shape) - 1)))
51 |
52 |
53 | def cosine_beta_schedule(timesteps, s=0.008, dtype=torch.float32):
54 | """
55 | cosine schedule
56 | as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
57 | """
58 | steps = timesteps + 1
59 | x = np.linspace(0, steps, steps)
60 | alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
61 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
62 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
63 | betas_clipped = np.clip(betas, a_min=0, a_max=0.999)
64 | return torch.tensor(betas_clipped, dtype=dtype)
65 |
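# Example (illustrative only): cosine_beta_schedule(1000) returns a length-1000 tensor of betas
# built from alpha_bar(t) = cos^2(((t / T + s) / (1 + s)) * pi / 2), with each beta given by
# 1 - alpha_bar(t) / alpha_bar(t - 1) and clipped to at most 0.999, as in the improved-DDPM schedule.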
66 |
67 | # -----------------------------------------------------------------------------#
68 | # ---------------------------------- losses -----------------------------------#
69 | # -----------------------------------------------------------------------------#
70 |
71 |
72 | class FocalLoss(nn.Module):
73 | def __init__(self, gamma: float = 0, size_average: bool = True):
74 | super(FocalLoss, self).__init__()
75 | self.gamma = gamma
76 | self.size_average = size_average
77 |
78 | def forward(self, input, target):
79 | logpt = F.log_softmax(input, dim=-1)
80 | logpt = logpt.gather(1, target.view(-1, 1)).view(-1)
81 | pt = logpt.exp()
82 |
83 | loss = -1 * (1 - pt) ** self.gamma * logpt
84 | if self.size_average:
85 | return loss.mean()
86 | else:
87 | return loss.sum()
88 |
89 |
90 | class WeightedLoss(nn.Module):
91 |
92 | def __init__(self, weights):
93 | super().__init__()
94 | self.register_buffer("weights", weights)
95 |
96 | def forward(self, pred, targ):
97 | """
98 | pred, targ : tensor
99 | [ batch_size x horizon x transition_dim ]
100 | """
101 | loss = self._loss(pred, targ)
102 | weighted_loss = (loss * self.weights).mean()
103 | return weighted_loss
104 |
105 |
106 | class WeightedL1(WeightedLoss):
107 |
108 | def _loss(self, pred, targ):
109 | return torch.abs(pred - targ)
110 |
111 |
112 | class WeightedL2(WeightedLoss):
113 |
114 | def _loss(self, pred, targ):
115 | return F.mse_loss(pred, targ, reduction="none")
116 |
117 |
118 | Losses = {
119 | "l1": WeightedL1,
120 | "l2": WeightedL2,
121 | }
122 |
123 |
124 | class EMA:
125 | """
126 |     Exponential moving average (EMA) of model weights.
127 | """
128 |
129 | def __init__(self, beta):
130 | super().__init__()
131 | self.beta = beta
132 |
133 | def update_model_average(self, ma_model, current_model):
134 | with torch.no_grad():
135 | ema_state_dict = ma_model.state_dict()
136 | for key, value in current_model.state_dict().items():
137 | ema_value = ema_state_dict[key]
138 | ema_value.copy_(self.beta * ema_value + (1.0 - self.beta) * value)
139 |
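# Usage sketch (illustrative only): maintaining an EMA copy of a model during training. The names
# `model` and `num_steps` are assumptions of the caller, not defined here.
#
#   import copy
#   ema = EMA(beta=0.995)
#   ema_model = copy.deepcopy(model)
#   for step in range(num_steps):
#       ...                                        # optimizer step on `model`
#       ema.update_model_average(ema_model, model)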
140 |
141 | # -----------------------------------------------------------------------------#
142 | # ---------------------------------- TSDF -----------------------------------#
143 | # -----------------------------------------------------------------------------#
144 |
145 |
146 | def get_view_frustum(depth_im, cam_intr, cam_pose):
147 | """Get corners of 3D camera view frustum of depth image"""
148 | im_h = depth_im.shape[0]
149 | im_w = depth_im.shape[1]
150 | max_depth = np.max(depth_im)
151 | view_frust_pts = np.array(
152 | [
153 | (np.array([0, 0, 0, im_w, im_w]) - cam_intr[0, 2])
154 | * np.array([0, max_depth, max_depth, max_depth, max_depth])
155 | / cam_intr[0, 0],
156 | (np.array([0, 0, im_h, 0, im_h]) - cam_intr[1, 2])
157 | * np.array([0, max_depth, max_depth, max_depth, max_depth])
158 | / cam_intr[1, 1],
159 | np.array([0, max_depth, max_depth, max_depth, max_depth]),
160 | ]
161 | )
162 | view_frust_pts = view_frust_pts.T @ cam_pose[:3, :3].T + cam_pose[:3, 3]
163 | return view_frust_pts.T
164 |
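# The five points returned by `get_view_frustum` above are the camera center plus the four image
# corners back-projected to the maximum observed depth, i.e. the corner points of the view frustum.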
165 |
166 | # try:
167 | # import pycuda.driver as cuda
168 | # import pycuda.autoinit
169 | # from pycuda.compiler import SourceModule
170 |
171 | # FUSION_GPU_MODE = 1
172 | # except Exception as err:
173 | # print("Warning: {}".format(err))
174 | # print("Failed to import PyCUDA. Running fusion in CPU mode.")
175 | # FUSION_GPU_MODE = 0
176 |
177 |
178 | class TSDFVolume:
179 | """
180 |     Volumetric map with a TSDF (truncated signed distance function) representation.
181 | """
182 |
183 | def __init__(
184 | self,
185 | vol_bounds: np.ndarray,
186 | voxel_dim: float,
187 | use_gpu: bool = False,
188 | verbose: bool = False,
189 | num_margin: float = 5.0,
190 | enable_color=True,
191 | unknown_free=True,
192 | ):
193 | """
194 | Constructor
195 |
196 |         :param vol_bounds: An ndarray of shape (3,2) defining the min & max bounds of the volume.
197 |         :param voxel_dim: Number of voxels along each axis of the (cubic) volume.
198 | :param use_gpu: Use GPU for voxel update.
199 | :param verbose: Print verbose message or not.
200 | """
201 |
202 | vol_bounds = np.asarray(vol_bounds)
203 | assert vol_bounds.shape == (3, 2), "vol_bounds should be of shape (3,2)"
204 |
205 | self._verbose = verbose
206 | self._use_gpu = use_gpu
207 | self._vol_bounds = vol_bounds
208 |
209 | self._vol_dim = [voxel_dim, voxel_dim, voxel_dim]
210 | self._vox_size = float(
211 | (self._vol_bounds[0, 1] - self._vol_bounds[0, 0]) / voxel_dim
212 | )
213 | self._trunc_margin = num_margin * self._vox_size # truncation on SDF
214 |
215 | # Check GPU
216 | if self._use_gpu:
217 | if torch.cuda.is_available():
218 | if self._verbose:
219 | print("# Using GPU mode")
220 | self._device = torch.device("cuda:0")
221 | else:
222 | if self._verbose:
223 | print("# Not available CUDA device, using CPU mode")
224 | self._device = torch.device("cpu")
225 | else:
226 | if self._verbose:
227 | print("# Using CPU mode")
228 | self._device = torch.device("cpu")
229 |
230 | # Coordinate origin of the volume, set as the min value of volume bounds
231 | self._vol_origin = torch.tensor(
232 | self._vol_bounds[:, 0].copy(order="C"), device=self._device
233 | ).float()
234 |
235 | # Grid coordinates of voxels
236 | xx, yy, zz = torch.meshgrid(
237 | torch.arange(self._vol_dim[0]),
238 | torch.arange(self._vol_dim[1]),
239 | torch.arange(self._vol_dim[2]),
240 | indexing="ij",
241 | )
242 | self._vox_coords = (
243 | torch.cat([xx.reshape(1, -1), yy.reshape(1, -1), zz.reshape(1, -1)], dim=0)
244 | .int()
245 | .T
246 | )
247 | if self._use_gpu:
248 | self._vox_coords = self._vox_coords.cuda()
249 |
250 | # World coordinates of voxel centers
251 | self._world_coords = self.vox2world(
252 | self._vol_origin, self._vox_coords, self._vox_size
253 | )
254 | self.enable_color = enable_color
255 |
256 | # TSDF & weights
257 | self._tsdf_vol = torch.ones(
258 | size=self._vol_dim, device=self._device, dtype=torch.float32
259 | )
260 | self._weight_vol = torch.zeros(
261 | size=self._vol_dim, device=self._device, dtype=torch.float32
262 | )
263 | if self.enable_color:
264 | self._color_vol = torch.zeros(
265 | size=[*self._vol_dim, 3], device=self._device, dtype=torch.float32
266 | )
267 |
268 |         # Mesh parameters
269 | self._mesh = o3d.geometry.TriangleMesh()
270 | self.unknown_free = unknown_free
271 | @staticmethod
272 | def vox2world(vol_origin: torch.Tensor, vox_coords: torch.Tensor, vox_size):
273 | """
274 | Converts voxel grid coordinates to world coordinates
275 |
276 | :param vol_origin: Origin of the volume in world coordinates, (3,1).
277 |         :param vox_coords: List of all grid coordinates in the volume, (N,3).
278 |         :param vox_size: Size of each voxel.
279 |         :return: Grid points in world coordinates. Tensor with shape (N, 3).
280 | """
281 |
282 | cam_pts = torch.empty_like(vox_coords, dtype=torch.float32)
283 | cam_pts = vol_origin + (vox_size * vox_coords)
284 |
285 | return cam_pts
286 |
287 | @staticmethod
288 | def cam2pix(cam_pts: torch.Tensor, intrinsics: torch.Tensor):
289 | """
290 | Convert points in camera coordinate to pixel coordinates
291 |
292 | :param cam_pts: Points in camera coordinates, (N,3).
293 |         :param intrinsics: Camera intrinsics, (3,3).
294 |         :return: Homogeneous pixel coordinates (u, v, 1) corresponding to the input points. Tensor with shape (N, 3).
295 | """
296 |
297 | cam_pts_z = cam_pts[:, 2].repeat(3, 1).T
298 | pix = torch.round((cam_pts @ intrinsics.T) / cam_pts_z)
299 |
300 | return pix
301 |
302 | @staticmethod
303 |     def rigid_transform(points: torch.Tensor, transform: torch.Tensor):
304 |         """
305 |         Apply a rigid transform (4,4) to points.
306 |
307 |         :param points: Points, shape (N,3).
308 |         :param transform: Transform matrix, shape (4,4).
309 |         :return: Points after the transform.
310 | """
311 |
312 | points_h = torch.cat(
313 | [points, torch.ones((points.shape[0], 1), device=points.device)], 1
314 | )
315 | points_h = (transform @ points_h.T).T
316 |
317 | return points_h[:, :3]
318 |
319 | def get_tsdf_volume(self):
320 | return self._tsdf_vol.cpu().numpy()
321 |
322 | def get_color_volume(self):
323 | return self._color_vol.permute(3, 0, 1, 2).cpu().numpy()
324 |
325 | def get_mesh(self):
326 | """
327 |         Get the current mesh. Note: this definition is shadowed by the marching-cubes `get_mesh` defined further below.
328 | """
329 | return self._mesh
330 |
331 | def integrate(self, color_img, depth_img, intrinsic, cam_pose, weight: float = 1.0):
332 | """
333 |         Integrate a color/depth frame into the TSDF volume.
334 |
335 |         :param depth_img: depth image with depth values in meters.
336 |         :param intrinsic: camera intrinsics of shape (3,3).
337 |         :param cam_pose: camera pose, transform matrix of shape (4,4).
338 |         :param weight: weight assigned to the current frame; a higher value indicates higher confidence.
339 | """
340 |
341 | time_begin = time.time()
342 | img_h, img_w = depth_img.shape
343 | depth_img = torch.from_numpy(depth_img).float().to(self._device)
344 | color_img = torch.from_numpy(color_img).float().to(self._device) # [H, W, 3]
345 | cam_pose = torch.from_numpy(cam_pose).float().to(self._device)
346 | intrinsic = torch.from_numpy(intrinsic).float().to(self._device)
347 |
348 | # TODO:
349 | # Better way to select valid voxels.
350 | # - Current:
351 | # -> Back project all voxels to frame pixels according to current camera pose.
352 | # -> Select valid pixels within frame size.
353 | # - Possible:
354 | # -> Project pixel to voxel coordinates
355 | # -> hash voxel coordinates
356 | # -> dynamically allocate voxel chunks
357 |
358 | # Get the world coordinates of all voxels
359 | # world_points = geometry.vox2world(self._vol_origin, self._vox_coords, self._vox_size)
360 |
361 | # Get voxel centers under camera coordinates
362 |         world_points = self.rigid_transform(
363 | self._world_coords, cam_pose.inverse()
364 | ) # [N^3, 3]
365 |
366 |         # Get the pixel coordinates (u,v) of all voxels under the current camera pose
367 |         # Multiple voxels can be projected to the same (u,v)
368 | voxel_uv = self.cam2pix(world_points, intrinsic) # [N^3, 3]
369 | voxel_u, voxel_v = voxel_uv[:, 0], voxel_uv[:, 1] # [N^3], [N^3]
370 | voxel_z = world_points[:, 2]
371 |
372 |         # Keep only the voxels that are visible in the current frame
373 | pixel_mask = torch.logical_and(
374 | voxel_u >= 0,
375 | torch.logical_and(
376 | voxel_u < img_w,
377 | torch.logical_and(
378 | voxel_v >= 0, torch.logical_and(voxel_v < img_h, voxel_z > 0)
379 | ),
380 | ),
381 | )
382 |
383 | # Get depth value
384 | depth_value = torch.zeros(voxel_u.shape, device=self._device)
385 | depth_value[pixel_mask] = depth_img[
386 | voxel_v[pixel_mask].long(), voxel_u[pixel_mask].long()
387 | ]
388 |
389 | # Compute and Integrate TSDF
390 | sdf_value = depth_value - voxel_z # Compute SDF
391 | if self.unknown_free:
392 | voxel_mask = torch.logical_and(
393 | depth_value > 0, sdf_value >= -self._trunc_margin
394 | ) # Truncate SDF
395 | else:
396 | voxel_mask = depth_value > 0 # Truncate SDF
397 |
398 | tsdf_value = torch.minimum(
399 | torch.ones_like(sdf_value, device=self._device),
400 | sdf_value / self._trunc_margin,
401 | )
402 | tsdf_value = tsdf_value[voxel_mask]
403 | # Get coordinates of valid voxels with valid TSDF value
404 | valid_vox_x = self._vox_coords[voxel_mask, 0].long()
405 | valid_vox_y = self._vox_coords[voxel_mask, 1].long()
406 | valid_vox_z = self._vox_coords[voxel_mask, 2].long()
407 |
408 |         # Update TSDF of the corresponding voxels
409 | weight_old = self._weight_vol[valid_vox_x, valid_vox_y, valid_vox_z]
410 | tsdf_old = self._tsdf_vol[valid_vox_x, valid_vox_y, valid_vox_z]
411 |
412 | if self.enable_color:
413 | color_value = torch.zeros([voxel_u.shape[0], 3], device=self._device)
414 | color_value[pixel_mask] = color_img[
415 | voxel_v[pixel_mask].long(), voxel_u[pixel_mask].long(), :
416 | ]
417 | color_value = color_value[voxel_mask]
418 | color_old = self._color_vol[valid_vox_x, valid_vox_y, valid_vox_z]
419 |
420 | else:
421 | color_value = None
422 | color_old = None
423 | tsdf_new, color_new, weight_new = self.update_tsdf(
424 | tsdf_old, tsdf_value, color_old, color_value, weight_old, weight
425 | )
426 |
427 | self._tsdf_vol[valid_vox_x, valid_vox_y, valid_vox_z] = tsdf_new
428 | self._weight_vol[valid_vox_x, valid_vox_y, valid_vox_z] = weight_new
429 |
430 | if self.enable_color:
431 | self._color_vol[valid_vox_x, valid_vox_y, valid_vox_z] = color_new
432 |
433 | if self._verbose:
434 | print("# Update {} voxels.".format(len(tsdf_new)))
435 | print(
436 | "# Integration Timing: {:.5f} (second).".format(
437 | time.time() - time_begin
438 | )
439 | )
440 |
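# Usage sketch (illustrative only): fusing a single RGB-D frame and extracting a mesh. `color`,
# `depth` and `K` are assumed numpy arrays of shapes (H, W, 3) in [0, 255], (H, W) in meters,
# and (3, 3) respectively.
#
#   vol_bnds = np.array([[-1.0, 1.0], [-1.0, 1.0], [0.0, 2.0]])
#   tsdf = TSDFVolume(vol_bnds, voxel_dim=64)
#   tsdf.integrate(color, depth, K, np.eye(4))
#   mesh = tsdf.get_mesh()   # open3d TriangleMesh extracted with marching cubes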
441 | def get_mesh(self):
442 | """
443 | Extract mesh from current TSDF volume.
444 | """
445 |
446 | time_begin = time.time()
447 |
448 | if self._use_gpu:
449 | tsdf_vol = self._tsdf_vol.cpu().numpy()
450 | vol_origin = self._vol_origin.cpu().numpy()
451 | if self.enable_color:
452 | color_vol = self._color_vol.cpu().numpy() / 255
453 |
454 | else:
455 | tsdf_vol = self._tsdf_vol.numpy()
456 | vol_origin = self._vol_origin.numpy()
457 | if self.enable_color:
458 | color_vol = self._color_vol.numpy() / 255
459 |
460 | _vertices, triangles, _, _ = measure.marching_cubes(-tsdf_vol, 0)
461 | vertices_sample = (_vertices / self._vol_dim[0] - 0.5) * 2
462 |
463 | # interpolation to get colors
464 | vertices_pt = torch.from_numpy(vertices_sample).float()[
465 | None, None, None, :, [2, 1, 0]
466 | ] # [1, 1, 1, N, 3]
467 |
468 | # mesh_vertices = _vertices / self._vol_dim[0] + self._vol_origin.cpu().numpy()
469 | mesh_vertices = vertices_sample
470 | self._mesh.vertices = o3d.utility.Vector3dVector(mesh_vertices.astype(float))
471 | self._mesh.triangles = o3d.utility.Vector3iVector(triangles.astype(np.int32))
472 | if self.enable_color:
473 | color_vol_pt = (
474 | torch.from_numpy(color_vol).float().permute(3, 0, 1, 2)[None]
475 | ) # [1, 3, H, W, D]
476 | vert_colors = torch.nn.functional.grid_sample(
477 | color_vol_pt, vertices_pt, align_corners=True
478 | ) # [1, 3, 1, 1, N]
479 | vert_colors = vert_colors.squeeze().cpu().numpy().T
480 | self._mesh.vertex_colors = o3d.utility.Vector3dVector(
481 | vert_colors.astype(float)
482 | )
483 |
484 | self._mesh.compute_vertex_normals()
485 | if self._verbose:
486 | print("# Extracting Mesh: {} Vertices".format(mesh_vertices.shape[0]))
487 | print("# Meshing Timing: {:.5f} (second).".format(time.time() - time_begin))
488 | return self._mesh
489 |
490 | def update_tsdf(
491 | self, tsdf_old, tsdf_new, color_old, color_new, weight_old, obs_weight
492 | ):
493 | """
494 |         Update the TSDF values of the given voxels with a weighted running average:
495 |         tsdf = (weight_old * tsdf_old + obs_weight * tsdf_new) / (weight_old + obs_weight)
496 |
497 | :param tsdf_old: Old TSDF values.
498 | :param tsdf_new: New TSDF values.
499 | :param weight_old: Voxels weights.
500 | :param obs_weight: Weight of current update.
501 | :return: Updated TSDF values & Updated weights.
502 | """
503 |
504 | tsdf_vol_int = torch.empty_like(
505 | tsdf_old, dtype=torch.float32, device=self._device
506 | )
507 | weight_new = torch.empty_like(
508 | weight_old, dtype=torch.float32, device=self._device
509 | )
510 |
511 | weight_new = weight_old + obs_weight
512 | tsdf_vol_int = (weight_old * tsdf_old + obs_weight * tsdf_new) / weight_new
513 | if color_old is not None:
514 | color_vol_int = (
515 | weight_old[:, None] * color_old + obs_weight * color_new
516 | ) / weight_new[:, None]
517 | return tsdf_vol_int, color_vol_int, weight_new
518 | else:
519 | color_vol_int = None
520 |
521 | return tsdf_vol_int, color_vol_int, weight_new
522 |
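# Worked example (illustrative only) of the running average in `update_tsdf` above: a voxel with
# tsdf_old = 0.2 and weight_old = 3.0 that is observed again with tsdf_new = -0.1 and
# obs_weight = 1.0 gets weight_new = 4.0 and a fused value of (3.0 * 0.2 + 1.0 * (-0.1)) / 4.0 = 0.125.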
523 |
524 | class TSDFVolume2(TSDFVolume):
525 | """
526 |     Volumetric map with a TSDF (truncated signed distance function) representation.
527 | """
528 |
529 | def __init__(
530 | self,
531 | vol_bounds: np.ndarray,
532 | voxel_size: float,
533 | use_gpu: bool = False,
534 | verbose: bool = False,
535 | num_margin: float = 5.0,
536 | enable_color=True,
537 | ):
538 | """
539 | Constructor
540 |
541 |         :param vol_bounds: An ndarray of shape (3,2) defining the min & max bounds of the volume.
542 | :param voxel_size: Voxel size in meters.
543 | :param use_gpu: Use GPU for voxel update.
544 | :param verbose: Print verbose message or not.
545 | """
546 |
547 | vol_bounds = np.asarray(vol_bounds)
548 | assert vol_bounds.shape == (3, 2), "vol_bounds should be of shape (3,2)"
549 |
550 | self._verbose = verbose
551 | self._use_gpu = use_gpu
552 | self._vol_bounds = vol_bounds
553 | if self._use_gpu:
554 | if torch.cuda.is_available():
555 | if self._verbose:
556 | print("# Using GPU mode")
557 | self._device = torch.device("cuda:0")
558 | else:
559 | if self._verbose:
560 | print("# Not available CUDA device, using CPU mode")
561 | self._device = torch.device("cpu")
562 | else:
563 | if self._verbose:
564 | print("# Using CPU mode")
565 | self._device = torch.device("cpu")
566 |         self._voxel_size = float(voxel_size)
567 |         self._vox_size = self._voxel_size
568 |
569 |         self._trunc_margin = num_margin * self._vox_size  # truncation on SDF
570 | # Adjust volume bounds and ensure C-order contiguous
571 | self._vol_dim = (
572 | np.ceil(
573 | (self._vol_bounds[:, 1] - self._vol_bounds[:, 0]) / self._voxel_size
574 | )
575 | .copy(order="C")
576 | .astype(int)
577 | )
578 |
579 | self._vol_bounds[:, 1] = (
580 | self._vol_bounds[:, 0] + self._vol_dim * self._voxel_size
581 | )
582 | self._vol_dim = self._vol_dim.tolist()
583 | # self._vol_origin = self._vol_bounds[:, 0].copy(order="C").astype(np.float32)
584 | # self._vol_dim = torch.tensor(self._vol_dim, device=self._device).int()
585 | self._vol_origin = torch.tensor(
586 | self._vol_bounds[:, 0].copy(order="C"), device=self._device
587 | ).float()
588 | # Grid coordinates of voxels
589 | xx, yy, zz = torch.meshgrid(
590 | torch.arange(self._vol_dim[0]),
591 | torch.arange(self._vol_dim[1]),
592 | torch.arange(self._vol_dim[2]),
593 | indexing="ij",
594 | )
595 | self._vox_coords = (
596 | torch.cat([xx.reshape(1, -1), yy.reshape(1, -1), zz.reshape(1, -1)], dim=0)
597 | .int()
598 | .T
599 | )
600 | if self._use_gpu:
601 | self._vox_coords = self._vox_coords.cuda()
602 |
603 | # World coordinates of voxel centers
604 | self._world_coords = self.vox2world(
605 | self._vol_origin, self._vox_coords, self._vox_size
606 | )
607 | self.enable_color = enable_color
608 |
609 | # TSDF & weights
610 | self._tsdf_vol = torch.ones(
611 | size=self._vol_dim, device=self._device, dtype=torch.float32
612 | )
613 | self._weight_vol = torch.zeros(
614 | size=self._vol_dim, device=self._device, dtype=torch.float32
615 | )
616 | if self.enable_color:
617 | self._color_vol = torch.zeros(
618 | size=[*self._vol_dim, 3], device=self._device, dtype=torch.float32
619 | )
620 |
621 |         # Mesh parameters
622 | self._mesh = o3d.geometry.TriangleMesh()
623 |
--------------------------------------------------------------------------------
/models/clip/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 | from models.clip.interpolate import interpolate_positional_embedding
10 |
11 |
12 | class Bottleneck(nn.Module):
13 | expansion = 4
14 |
15 | def __init__(self, inplanes, planes, stride=1):
16 | super().__init__()
17 |
18 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
19 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
20 | self.bn1 = nn.BatchNorm2d(planes)
21 | self.relu1 = nn.ReLU(inplace=True)
22 |
23 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
24 | self.bn2 = nn.BatchNorm2d(planes)
25 | self.relu2 = nn.ReLU(inplace=True)
26 |
27 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
28 |
29 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
30 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
31 | self.relu3 = nn.ReLU(inplace=True)
32 |
33 | self.downsample = None
34 | self.stride = stride
35 |
36 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
37 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
38 | self.downsample = nn.Sequential(
39 | OrderedDict(
40 | [
41 | ("-1", nn.AvgPool2d(stride)),
42 | (
43 | "0",
44 | nn.Conv2d(
45 | inplanes,
46 | planes * self.expansion,
47 | 1,
48 | stride=1,
49 | bias=False,
50 | ),
51 | ),
52 | ("1", nn.BatchNorm2d(planes * self.expansion)),
53 | ]
54 | )
55 | )
56 |
57 | def forward(self, x: torch.Tensor):
58 | identity = x
59 |
60 | out = self.relu1(self.bn1(self.conv1(x)))
61 | out = self.relu2(self.bn2(self.conv2(out)))
62 | out = self.avgpool(out)
63 | out = self.bn3(self.conv3(out))
64 |
65 | if self.downsample is not None:
66 | identity = self.downsample(x)
67 |
68 | out += identity
69 | out = self.relu3(out)
70 | return out
71 |
72 |
73 | class AttentionPool2d(nn.Module):
74 | def __init__(
75 | self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
76 | ):
77 | super().__init__()
78 | self.positional_embedding = nn.Parameter(
79 | torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5
80 | )
81 | self.k_proj = nn.Linear(embed_dim, embed_dim)
82 | self.q_proj = nn.Linear(embed_dim, embed_dim)
83 | self.v_proj = nn.Linear(embed_dim, embed_dim)
84 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
85 | self.num_heads = num_heads
86 | self.spacial_dim = spacial_dim
87 |
88 | def forward(self, x):
89 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
90 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
91 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
92 | x, _ = F.multi_head_attention_forward(
93 | query=x[:1],
94 | key=x,
95 | value=x,
96 | embed_dim_to_check=x.shape[-1],
97 | num_heads=self.num_heads,
98 | q_proj_weight=self.q_proj.weight,
99 | k_proj_weight=self.k_proj.weight,
100 | v_proj_weight=self.v_proj.weight,
101 | in_proj_weight=None,
102 | in_proj_bias=torch.cat(
103 | [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
104 | ),
105 | bias_k=None,
106 | bias_v=None,
107 | add_zero_attn=False,
108 | dropout_p=0,
109 | out_proj_weight=self.c_proj.weight,
110 | out_proj_bias=self.c_proj.bias,
111 | use_separate_proj_weight=True,
112 | training=self.training,
113 | need_weights=False,
114 | )
115 | return x.squeeze(0)
116 |
117 | def forward_v(self, x: torch.Tensor):
118 | """
119 | Forward function for computing the value features for dense prediction (i.e., features for every image patch).
120 | """
121 | _, _, w, h = x.shape
122 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
123 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
124 |
125 | # Interpolate positional embedding to match the size of the input
126 | interpolated_pe = interpolate_positional_embedding(
127 | self.positional_embedding, x.permute(1, 0, 2), patch_size=1, w=w, h=h
128 | )
129 | x = x + interpolated_pe[:, None, :] # (HW+1)NC
130 |
131 | v_in = F.linear(x, self.v_proj.weight, self.v_proj.bias)
132 | v_out = F.linear(v_in, self.c_proj.weight, self.c_proj.bias)
133 | v_out = v_out.permute(1, 0, 2) # (HW+1)NC -> N(HW+1)C
134 | return v_out
135 |
136 |
137 | class ModifiedResNet(nn.Module):
138 | """
139 | A ResNet class that is similar to torchvision's but contains the following changes:
140 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
141 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
142 | - The final pooling layer is a QKV attention instead of an average pool
143 | """
144 |
145 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
146 | super().__init__()
147 | self.output_dim = output_dim
148 | self.input_resolution = input_resolution
149 |
150 | # the 3-layer stem
151 | self.conv1 = nn.Conv2d(
152 | 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False
153 | )
154 | self.bn1 = nn.BatchNorm2d(width // 2)
155 | self.relu1 = nn.ReLU(inplace=True)
156 | self.conv2 = nn.Conv2d(
157 | width // 2, width // 2, kernel_size=3, padding=1, bias=False
158 | )
159 | self.bn2 = nn.BatchNorm2d(width // 2)
160 | self.relu2 = nn.ReLU(inplace=True)
161 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
162 | self.bn3 = nn.BatchNorm2d(width)
163 | self.relu3 = nn.ReLU(inplace=True)
164 | self.avgpool = nn.AvgPool2d(2)
165 |
166 | # residual layers
167 | self._inplanes = width # this is a *mutable* variable used during construction
168 | self.layer1 = self._make_layer(width, layers[0])
169 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
170 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
171 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
172 |
173 | embed_dim = width * 32 # the ResNet feature dimension
174 | self.attnpool = AttentionPool2d(
175 | input_resolution // 32, embed_dim, heads, output_dim
176 | )
177 |
178 | def _make_layer(self, planes, blocks, stride=1):
179 | layers = [Bottleneck(self._inplanes, planes, stride)]
180 |
181 | self._inplanes = planes * Bottleneck.expansion
182 | for _ in range(1, blocks):
183 | layers.append(Bottleneck(self._inplanes, planes))
184 |
185 | return nn.Sequential(*layers)
186 |
187 | def forward(self, x, patch_output: bool = False):
188 | def stem(x):
189 | x = self.relu1(self.bn1(self.conv1(x)))
190 | x = self.relu2(self.bn2(self.conv2(x)))
191 | x = self.relu3(self.bn3(self.conv3(x)))
192 | x = self.avgpool(x)
193 | return x
194 |
195 | x = x.type(self.conv1.weight.dtype)
196 | x = stem(x)
197 | x = self.layer1(x)
198 | x = self.layer2(x)
199 | x = self.layer3(x)
200 | x = self.layer4(x)
201 |
202 | if patch_output:
203 | x = self.attnpool.forward_v(x)
204 | x = x[:, 1:, :] # remove the cls token
205 | else:
206 | x = self.attnpool(x)
207 |
208 | return x
209 |
210 |
211 | class LayerNorm(nn.LayerNorm):
212 | """Subclass torch's LayerNorm to handle fp16."""
213 |
214 | def forward(self, x: torch.Tensor):
215 | orig_type = x.dtype
216 | ret = super().forward(x.type(torch.float32))
217 | return ret.type(orig_type)
218 |
219 |
220 | class QuickGELU(nn.Module):
221 | def forward(self, x: torch.Tensor):
222 | return x * torch.sigmoid(1.702 * x)
223 |
224 |
225 | class ResidualAttentionBlock(nn.Module):
226 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
227 | super().__init__()
228 |
229 | self.attn = nn.MultiheadAttention(d_model, n_head)
230 | self.ln_1 = LayerNorm(d_model)
231 | self.mlp = nn.Sequential(
232 | OrderedDict(
233 | [
234 | ("c_fc", nn.Linear(d_model, d_model * 4)),
235 | ("gelu", QuickGELU()),
236 | ("c_proj", nn.Linear(d_model * 4, d_model)),
237 | ]
238 | )
239 | )
240 | self.ln_2 = LayerNorm(d_model)
241 | self.attn_mask = attn_mask
242 |
243 | def attention(self, x: torch.Tensor):
244 | self.attn_mask = (
245 | self.attn_mask.to(dtype=x.dtype, device=x.device)
246 | if self.attn_mask is not None
247 | else None
248 | )
249 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
250 |
251 | def forward_v(self, x: torch.Tensor):
252 | """
253 | Forward function for computing the value features for dense prediction (i.e., features for every image patch).
254 | """
255 | # Get the weights and biases for the value projection, multihead attention uses 3 * embed_dim for the input projection
256 | v_in_proj_weight = self.attn.in_proj_weight[-self.attn.embed_dim :]
257 | v_in_proj_bias = self.attn.in_proj_bias[-self.attn.embed_dim :]
258 |
259 | v_in = F.linear(self.ln_1(x), v_in_proj_weight, v_in_proj_bias)
260 | v_out = F.linear(v_in, self.attn.out_proj.weight, self.attn.out_proj.bias)
261 |
262 | # Using the value features works the best. Adding this to 'x' or feeding 'v' to the LayerNorm then MLP degrades the performance
263 | return v_out
264 |
265 | def forward(self, x: torch.Tensor):
266 | x = x + self.attention(self.ln_1(x))
267 | x = x + self.mlp(self.ln_2(x))
268 | return x
269 |
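# Sketch (illustrative only) of how `forward_v` differs from the regular forward pass: it returns
# each token's value projection followed by the attention output projection, without mixing tokens,
# which is the MaskCLIP-style parametrization for dense per-patch features.
#
#   block = ResidualAttentionBlock(d_model=768, n_head=12)
#   tokens = torch.randn(197, 2, 768)     # LND layout: sequence, batch, width
#   dense = block.forward_v(tokens)       # [197, 2, 768], value features only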
270 |
271 | class Transformer(nn.Module):
272 | def __init__(
273 | self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None
274 | ):
275 | super().__init__()
276 | self.width = width
277 | self.layers = layers
278 | self.resblocks = nn.Sequential(
279 | *[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]
280 | )
281 |
282 | def forward(self, x: torch.Tensor):
283 | return self.resblocks(x)
284 |
285 |
286 | class VisionTransformer(nn.Module):
287 | def __init__(
288 | self,
289 | input_resolution: int,
290 | patch_size: int,
291 | width: int,
292 | layers: int,
293 | heads: int,
294 | output_dim: int,
295 | output_indices: list = [6, 7, 8, 9, 10],
296 | ):
297 | super().__init__()
298 | self.input_resolution = input_resolution
299 | self.output_dim = output_dim
300 | self.conv1 = nn.Conv2d(
301 | in_channels=3,
302 | out_channels=width,
303 | kernel_size=patch_size,
304 | stride=patch_size,
305 | bias=False,
306 | )
307 |
308 | scale = width**-0.5
309 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
310 | self.positional_embedding = nn.Parameter(
311 | scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)
312 | )
313 | self.ln_pre = LayerNorm(width)
314 |
315 | self.transformer = Transformer(width, layers, heads)
316 |
317 | self.ln_post = LayerNorm(width)
318 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
319 |
320 | self.patch_size = patch_size
321 | self.output_indices = output_indices
322 |
323 | def forward(self, x: torch.Tensor, patch_output: bool = False, return_dict=False):
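        """
        Args:
            x: image batch of shape [B, 3, H, W].
            patch_output: if True, return per-patch features from the value read-out of the
                last block instead of the pooled class-token embedding.
            return_dict: if True (together with patch_output), also return intermediate
                block features, keyed "res1".."res5" for the default output_indices,
                alongside the projected "final" patch features.
        """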
324 | output_dict = {}
325 |
326 | _, _, w, h = x.shape
327 |
328 | x = self.conv1(x) # shape = [*, width, grid, grid]
329 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
330 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
331 | x = torch.cat(
332 | [
333 | self.class_embedding.to(x.dtype)
334 | + torch.zeros(
335 | x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device
336 | ),
337 | x,
338 | ],
339 | dim=1,
340 | ) # shape = [*, grid ** 2 + 1, width]
341 | x = x + interpolate_positional_embedding(
342 | self.positional_embedding, x, patch_size=self.patch_size, w=w, h=h
343 | )
344 | x = self.ln_pre(x)
345 |
346 | x = x.permute(1, 0, 2) # NLD -> LND
347 |
348 |         if patch_output:
349 |             # Dense-feature path (MaskCLIP-style): run every residual block except the
350 |             # last one normally, collecting intermediate features if requested, then
351 |             # read out value-projected tokens from the final block.
352 |             *layers, last_resblock = self.transformer.resblocks
353 |             for ii, layer in enumerate(layers):
354 |                 x = layer(x)
355 |                 if return_dict and ii in self.output_indices:
356 |                     # Store intermediate block outputs ("res1".."res5" with the default output_indices), in NLD layout
357 |                     res_i = ii - 5
358 |                     output_dict[f"res{res_i}"] = x.permute(1, 0, 2)
359 | 
360 |             x = last_resblock.forward_v(x)
361 |             x = x.permute(1, 0, 2)  # LND -> NLD
362 | 
363 |             # Keep only the patch tokens, dropping the class token
364 |             x = x[:, 1:, :]
365 |             x = self.ln_post(x)
366 |             if self.proj is not None:
367 |                 # Project to the shared embedding space (equivalent to a 1x1 convolution)
368 |                 x = x @ self.proj
369 | 
370 |             if return_dict:
371 |                 output_dict["final"] = x
372 |                 return output_dict
373 |             else:
374 |                 return x
376 |
377 | x = self.transformer(x)
378 | x = x.permute(1, 0, 2) # LND -> NLD
379 |
380 | x = self.ln_post(x[:, 0, :])
381 |
382 | if self.proj is not None:
383 | x = x @ self.proj
384 |
385 | return x
386 |
387 |
388 | class CLIP(nn.Module):
389 | def __init__(
390 | self,
391 | embed_dim: int,
392 | # vision
393 | image_resolution: int,
394 | vision_layers: Union[Tuple[int, int, int, int], int],
395 | vision_width: int,
396 | vision_patch_size: int,
397 | # text
398 | context_length: int,
399 | vocab_size: int,
400 | transformer_width: int,
401 | transformer_heads: int,
402 | transformer_layers: int,
403 | ):
404 | super().__init__()
405 |
406 | self.context_length = context_length
407 |
408 | if isinstance(vision_layers, (tuple, list)):
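            # ResNet backbone: the last stage has width * 32 channels, split into 64-dim attention heads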
409 | vision_heads = vision_width * 32 // 64
410 | self.visual = ModifiedResNet(
411 | layers=vision_layers,
412 | output_dim=embed_dim,
413 | heads=vision_heads,
414 | input_resolution=image_resolution,
415 | width=vision_width,
416 | )
417 | else:
418 | vision_heads = vision_width // 64
419 | self.visual = VisionTransformer(
420 | input_resolution=image_resolution,
421 | patch_size=vision_patch_size,
422 | width=vision_width,
423 | layers=vision_layers,
424 | heads=vision_heads,
425 | output_dim=embed_dim,
426 | )
427 |
428 | self.transformer = Transformer(
429 | width=transformer_width,
430 | layers=transformer_layers,
431 | heads=transformer_heads,
432 | attn_mask=self.build_attention_mask(),
433 | )
434 |
435 | self.vocab_size = vocab_size
436 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
437 | self.positional_embedding = nn.Parameter(
438 | torch.empty(self.context_length, transformer_width)
439 | )
440 | self.ln_final = LayerNorm(transformer_width)
441 |
442 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
443 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
444 |
445 | self.initialize_parameters()
446 |
447 | def initialize_parameters(self):
448 | nn.init.normal_(self.token_embedding.weight, std=0.02)
449 | nn.init.normal_(self.positional_embedding, std=0.01)
450 |
451 | if isinstance(self.visual, ModifiedResNet):
452 | if self.visual.attnpool is not None:
453 | std = self.visual.attnpool.c_proj.in_features**-0.5
454 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
455 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
456 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
457 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
458 |
459 | for resnet_block in [
460 | self.visual.layer1,
461 | self.visual.layer2,
462 | self.visual.layer3,
463 | self.visual.layer4,
464 | ]:
465 | for name, param in resnet_block.named_parameters():
466 | if name.endswith("bn3.weight"):
467 | nn.init.zeros_(param)
468 |
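        # Residual-branch projections (attn.out_proj, mlp.c_proj) are scaled down with depth,
        # (2 * layers) ** -0.5, so the residual stream stays well-conditioned at initialization.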
469 | proj_std = (self.transformer.width**-0.5) * (
470 | (2 * self.transformer.layers) ** -0.5
471 | )
472 | attn_std = self.transformer.width**-0.5
473 | fc_std = (2 * self.transformer.width) ** -0.5
474 | for block in self.transformer.resblocks:
475 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
476 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
477 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
478 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
479 |
480 | if self.text_projection is not None:
481 | nn.init.normal_(self.text_projection, std=self.transformer.width**-0.5)
482 |
483 | def build_attention_mask(self):
484 | # lazily create causal attention mask, with full attention between the vision tokens
485 | # pytorch uses additive attention mask; fill with -inf
486 | mask = torch.empty(self.context_length, self.context_length)
487 | mask.fill_(float("-inf"))
488 |         mask.triu_(1)  # keep the upper triangle; zero out everything below the diagonal
489 | return mask
490 |
491 | @property
492 | def dtype(self):
493 | return self.visual.conv1.weight.dtype
494 |
495 | def encode_image(self, image):
496 | return self.visual(image.type(self.dtype))
497 |
498 | def get_patch_encodings(self, image) -> torch.Tensor:
499 | """Get the encodings for each patch in the image"""
500 | return self.visual(image.type(self.dtype), patch_output=True)
501 |
502 | def get_image_encoder_projection(self) -> nn.Parameter:
503 | """Get vision transformer projection matrix."""
504 | assert isinstance(self.visual, VisionTransformer)
505 | return self.visual.proj
506 |
507 | def encode_text(self, text):
508 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
509 |
510 | x = x + self.positional_embedding.type(self.dtype)
511 | x = x.permute(1, 0, 2) # NLD -> LND
512 | x = self.transformer(x)
513 | x = x.permute(1, 0, 2) # LND -> NLD
514 | x = self.ln_final(x).type(self.dtype)
515 |
516 | # x.shape = [batch_size, n_ctx, transformer.width]
517 | # take features from the eot embedding (eot_token is the highest number in each sequence)
518 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
519 |
520 | return x
521 |
522 | def forward(self, image, text):
523 | image_features = self.encode_image(image)
524 | text_features = self.encode_text(text)
525 |
526 | # normalized features
527 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
528 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
529 |
530 | # cosine similarity as logits
531 | logit_scale = self.logit_scale.exp()
532 | logits_per_image = logit_scale * image_features @ text_features.t()
533 | logits_per_text = logits_per_image.t()
534 |
535 | # shape = [global_batch_size, global_batch_size]
536 | return logits_per_image, logits_per_text
537 |
538 |
539 | def convert_weights(model: nn.Module):
540 | """Convert applicable model parameters to fp16"""
541 |
542 | def _convert_weights_to_fp16(l):
543 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
544 | l.weight.data = l.weight.data.half()
545 | if l.bias is not None:
546 | l.bias.data = l.bias.data.half()
547 |
548 | if isinstance(l, nn.MultiheadAttention):
549 | for attr in [
550 | *[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]],
551 | "in_proj_bias",
552 | "bias_k",
553 | "bias_v",
554 | ]:
555 | tensor = getattr(l, attr)
556 | if tensor is not None:
557 | tensor.data = tensor.data.half()
558 |
559 | for name in ["text_projection", "proj"]:
560 | if hasattr(l, name):
561 | attr = getattr(l, name)
562 | if attr is not None:
563 | attr.data = attr.data.half()
564 |
565 | model.apply(_convert_weights_to_fp16)
566 |
567 |
568 | def build_model(state_dict: dict):
569 | vit = "visual.proj" in state_dict
570 |
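    # Infer all architecture hyperparameters from tensor shapes and key names in the
    # checkpoint, so the same builder covers both ViT and ModifiedResNet visual encoders.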
571 | if vit:
572 | vision_width = state_dict["visual.conv1.weight"].shape[0]
573 | vision_layers = len(
574 | [
575 | k
576 | for k in state_dict.keys()
577 | if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")
578 | ]
579 | )
580 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
581 | grid_size = round(
582 | (state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5
583 | )
584 | image_resolution = vision_patch_size * grid_size
585 | else:
586 | counts: list = [
587 | len(
588 | set(
589 | k.split(".")[2]
590 | for k in state_dict
591 | if k.startswith(f"visual.layer{b}")
592 | )
593 | )
594 | for b in [1, 2, 3, 4]
595 | ]
596 | vision_layers = tuple(counts)
597 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
598 | output_width = round(
599 | (state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5
600 | )
601 | vision_patch_size = None
602 | assert (
603 | output_width**2 + 1
604 | == state_dict["visual.attnpool.positional_embedding"].shape[0]
605 | )
606 | image_resolution = output_width * 32
607 |
608 | embed_dim = state_dict["text_projection"].shape[1]
609 | context_length = state_dict["positional_embedding"].shape[0]
610 | vocab_size = state_dict["token_embedding.weight"].shape[0]
611 | transformer_width = state_dict["ln_final.weight"].shape[0]
612 | transformer_heads = transformer_width // 64
613 | transformer_layers = len(
614 | set(
615 | k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")
616 | )
617 | )
618 |
619 | model = CLIP(
620 | embed_dim,
621 | image_resolution,
622 | vision_layers,
623 | vision_width,
624 | vision_patch_size,
625 | context_length,
626 | vocab_size,
627 | transformer_width,
628 | transformer_heads,
629 | transformer_layers,
630 | )
631 |
632 | for key in ["input_resolution", "context_length", "vocab_size"]:
633 | if key in state_dict:
634 | del state_dict[key]
635 |
636 | convert_weights(model)
637 | model.load_state_dict(state_dict)
638 | return model.eval()
639 |
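# Usage sketch: build a model from an OpenAI-released CLIP checkpoint and extract dense
# patch features via get_patch_encodings(). The path below is a placeholder; OpenAI ships
# its CLIP weights as TorchScript archives, hence the torch.jit.load round-trip. Since
# convert_weights() casts the model to fp16, running on GPU is assumed here.
#
#   state_dict = torch.jit.load("ViT-B-16.pt", map_location="cpu").state_dict()
#   model = build_model(state_dict).cuda()
#   images = torch.randn(1, 3, 224, 224).cuda()  # a preprocessed image batch
#   with torch.no_grad():
#       patch_feats = model.get_patch_encodings(images)  # [1, num_patches, embed_dim]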
--------------------------------------------------------------------------------