├── .gitignore
├── CLIP
│   ├── .github
│   │   └── workflows
│   │       └── test.yml
│   ├── .gitignore
│   ├── CLIP.png
│   ├── LICENSE
│   ├── MANIFEST.in
│   ├── README.md
│   ├── clip
│   │   ├── __init__.py
│   │   ├── auxiliary.py
│   │   ├── bpe_simple_vocab_16e6.txt
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   ├── clip.py
│   │   ├── clip_explainability.py
│   │   ├── clip_gradcam.py
│   │   ├── model.py
│   │   ├── model_explainability.py
│   │   └── simple_tokenizer.py
│   ├── model-card.md
│   ├── requirements.txt
│   ├── setup.py
│   └── tests
│       └── test_consistency.py
├── LICENSE
├── README.md
├── arm
│   ├── LICENSE
│   ├── __init__.py
│   ├── network_utils.py
│   ├── optim
│   │   ├── LICENSE
│   │   ├── __init__.py
│   │   └── lamb.py
│   └── utils.py
├── assets
│   ├── hair_dryer_behind_table.gif
│   ├── hair_dryer_completion.gif
│   ├── hair_dryer_completion_blender.gif
│   ├── hair_dryer_completion_blender_legend.png
│   ├── hair_dryer_scene.png
│   ├── matterport.png
│   ├── teaser.gif
│   └── vn_poster_relevancies.png
├── datagen
│   ├── README.md
│   └── assets
│       └── unity-build.png
├── dataset.py
├── eval.py
├── fusion.py
├── generate_relevancy.py
├── generate_thor_data.py
├── grads.png
├── matterport.png
├── net.py
├── plot_utils.py
├── point_cloud.py
├── scene_files
│   ├── arkit_kitchen.pkl
│   ├── arkit_vn_poster.pkl
│   ├── matterport_hallway.pkl
│   ├── matterport_kitchen.pkl
│   ├── matterport_living_room.pkl
│   └── walle.pkl
├── semabs.yml
├── summarize.py
├── test_semantic_classes.txt
├── test_thor_rooms.txt
├── train_ovssc.py
├── train_vool.py
├── unet3d.py
├── utils.py
└── visualize.py
/.gitignore:
--------------------------------------------------------------------------------
1 | crv04
2 | crv05
3 | *.pth
4 | *.ply
5 | *.obj
6 | *.jpg
7 | *.blend
8 | *.blend1
9 | tags
10 | log*
11 | temp*
12 | *.hdf5
13 | *.pkl
14 | *.lock
15 | .vscode/
16 | logs/
17 | *cache*
18 |
19 | # Byte-compiled / optimized / DLL files
20 | __pycache__/
21 | *.py[cod]
22 | *$py.class
23 |
24 | # C extensions
25 | *.so
26 |
27 | # Distribution / packaging
28 | .Python
29 | build/
30 | develop-eggs/
31 | dist/
32 | downloads/
33 | eggs/
34 | .eggs/
35 | lib/
36 | lib64/
37 | parts/
38 | sdist/
39 | var/
40 | wheels/
41 | pip-wheel-metadata/
42 | share/python-wheels/
43 | *.egg-info/
44 | .installed.cfg
45 | *.egg
46 | MANIFEST
47 |
48 | # PyInstaller
49 | # Usually these files are written by a python script from a template
50 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
51 | *.manifest
52 | *.spec
53 |
54 | # Installer logs
55 | pip-log.txt
56 | pip-delete-this-directory.txt
57 |
58 | # Unit test / coverage reports
59 | htmlcov/
60 | .tox/
61 | .nox/
62 | .coverage
63 | .coverage.*
64 | .cache
65 | nosetests.xml
66 | coverage.xml
67 | *.cover
68 | *.py,cover
69 | .hypothesis/
70 | .pytest_cache/
71 |
72 | # Translations
73 | *.mo
74 | *.pot
75 |
76 | # Django stuff:
77 | *.log
78 | local_settings.py
79 | db.sqlite3
80 | db.sqlite3-journal
81 |
82 | # Flask stuff:
83 | instance/
84 | .webassets-cache
85 |
86 | # Scrapy stuff:
87 | .scrapy
88 |
89 | # Sphinx documentation
90 | docs/_build/
91 |
92 | # PyBuilder
93 | target/
94 |
95 | # Jupyter Notebook
96 | .ipynb_checkpoints
97 |
98 | # IPython
99 | profile_default/
100 | ipython_config.py
101 |
102 | # pyenv
103 | .python-version
104 |
105 | # pipenv
106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
109 | # install all needed dependencies.
110 | #Pipfile.lock
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
--------------------------------------------------------------------------------
/CLIP/.github/workflows/test.yml:
--------------------------------------------------------------------------------
1 | name: test
2 | on:
3 | push:
4 | branches:
5 | - main
6 | pull_request:
7 | branches:
8 | - main
9 | jobs:
10 | CLIP-test:
11 | runs-on: ubuntu-latest
12 | strategy:
13 | matrix:
14 | python-version: [3.7, 3.8]
15 | pytorch-version: [1.7.1, 1.9.0]
16 | steps:
17 | - uses: conda-incubator/setup-miniconda@v2
18 | - run: conda install -n test python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} torchvision cpuonly -c pytorch
19 | - uses: actions/checkout@v2
20 | - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH
21 | - run: pip install pytest
22 | - run: pip install .
23 | - run: pytest
24 |
--------------------------------------------------------------------------------
/CLIP/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | *.py[cod]
3 | *$py.class
4 | *.egg-info
5 | .pytest_cache
6 | .ipynb_checkpoints
7 |
8 | thumbs.db
9 | .DS_Store
10 | .idea
11 |
--------------------------------------------------------------------------------
/CLIP/CLIP.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/CLIP/CLIP.png
--------------------------------------------------------------------------------
/CLIP/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 OpenAI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
23 |
--------------------------------------------------------------------------------
/CLIP/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include clip/bpe_simple_vocab_16e6.txt.gz
2 |
--------------------------------------------------------------------------------
/CLIP/README.md:
--------------------------------------------------------------------------------
1 | # CLIP
2 |
3 | [[Blog]](https://openai.com/blog/clip/) [[Paper]](https://arxiv.org/abs/2103.00020) [[Model Card]](model-card.md) [[Colab]](https://colab.research.google.com/github/openai/clip/blob/master/notebooks/Interacting_with_CLIP.ipynb)
4 |
5 | CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be instructed in natural language to predict the most relevant text snippet, given an image, without directly optimizing for the task, similarly to the zero-shot capabilities of GPT-2 and 3. We found CLIP matches the performance of the original ResNet50 on ImageNet “zero-shot” without using any of the original 1.28M labeled examples, overcoming several major challenges in computer vision.
6 |
7 |
8 |
9 | ## Approach
10 |
11 | 
12 |
13 |
14 |
15 | ## Usage
16 |
17 | First, [install PyTorch 1.7.1](https://pytorch.org/get-started/locally/) and torchvision, as well as small additional dependencies, and then install this repo as a Python package. On a CUDA GPU machine, the following will do the trick:
18 |
19 | ```bash
20 | $ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0
21 | $ pip install ftfy regex tqdm
22 | $ pip install git+https://github.com/openai/CLIP.git
23 | ```
24 |
25 | Replace `cudatoolkit=11.0` above with the appropriate CUDA version on your machine or `cpuonly` when installing on a machine without a GPU.
26 |
27 | ```python
28 | import torch
29 | import clip
30 | from PIL import Image
31 |
32 | device = "cuda" if torch.cuda.is_available() else "cpu"
33 | model, preprocess = clip.load("ViT-B/32", device=device)
34 |
35 | image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
36 | text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
37 |
38 | with torch.no_grad():
39 | image_features = model.encode_image(image)
40 | text_features = model.encode_text(text)
41 |
42 | logits_per_image, logits_per_text = model(image, text)
43 | probs = logits_per_image.softmax(dim=-1).cpu().numpy()
44 |
45 | print("Label probs:", probs) # prints: [[0.9927937 0.00421068 0.00299572]]
46 | ```
47 |
48 |
49 | ## API
50 |
51 | The CLIP module `clip` provides the following methods:
52 |
53 | #### `clip.available_models()`
54 |
55 | Returns the names of the available CLIP models.
56 |
57 | #### `clip.load(name, device=..., jit=False)`
58 |
59 | Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The `name` argument can also be a path to a local checkpoint.
60 |
61 | The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU. When `jit` is `False`, a non-JIT version of the model will be loaded.
62 |
63 | #### `clip.tokenize(text: Union[str, List[str]], context_length=77)`
64 |
65 | Returns a LongTensor containing tokenized sequences of given text input(s). This can be used as the input to the model.
66 |
67 | ---
68 |
69 | The model returned by `clip.load()` supports the following methods:
70 |
71 | #### `model.encode_image(image: Tensor)`
72 |
73 | Given a batch of images, returns the image features encoded by the vision portion of the CLIP model.
74 |
75 | #### `model.encode_text(text: Tensor)`
76 |
77 | Given a batch of text tokens, returns the text features encoded by the language portion of the CLIP model.
78 |
79 | #### `model(image: Tensor, text: Tensor)`
80 |
81 | Given a batch of images and a batch of text tokens, returns two Tensors, containing the logit scores corresponding to each image and text input. The values are cosine similarities between the corresponding image and text features, times 100.
82 |
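For example, the "cosine similarity times 100" relationship above can be checked directly against the output of `model(image, text)`. A minimal sketch, reusing the image and labels from the usage snippet (the learned temperature `logit_scale.exp()` is approximately 100 for the released checkpoints):

```python
import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    logits_per_image, _ = model(image, text)
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # normalize the features, then scale by the model's learned temperature
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    manual_logits = model.logit_scale.exp() * image_features @ text_features.T

print(torch.allclose(logits_per_image, manual_logits, atol=1e-2))  # expected: True
```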
83 |
84 |
85 | ## More Examples
86 |
87 | ### Zero-Shot Prediction
88 |
89 | The code below performs zero-shot prediction using CLIP, as shown in Appendix B in the paper. This example takes an image from the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html), and predicts the most likely labels among the 100 textual labels from the dataset.
90 |
91 | ```python
92 | import os
93 | import clip
94 | import torch
95 | from torchvision.datasets import CIFAR100
96 |
97 | # Load the model
98 | device = "cuda" if torch.cuda.is_available() else "cpu"
99 | model, preprocess = clip.load('ViT-B/32', device)
100 |
101 | # Download the dataset
102 | cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)
103 |
104 | # Prepare the inputs
105 | image, class_id = cifar100[3637]
106 | image_input = preprocess(image).unsqueeze(0).to(device)
107 | text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
108 |
109 | # Calculate features
110 | with torch.no_grad():
111 | image_features = model.encode_image(image_input)
112 | text_features = model.encode_text(text_inputs)
113 |
114 | # Pick the top 5 most similar labels for the image
115 | image_features /= image_features.norm(dim=-1, keepdim=True)
116 | text_features /= text_features.norm(dim=-1, keepdim=True)
117 | similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
118 | values, indices = similarity[0].topk(5)
119 |
120 | # Print the result
121 | print("\nTop predictions:\n")
122 | for value, index in zip(values, indices):
123 | print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")
124 | ```
125 |
126 | The output will look like the following (the exact numbers may be slightly different depending on the compute device):
127 |
128 | ```
129 | Top predictions:
130 |
131 | snake: 65.31%
132 | turtle: 12.29%
133 | sweet_pepper: 3.83%
134 | lizard: 1.88%
135 | crocodile: 1.75%
136 | ```
137 |
138 | Note that this example uses the `encode_image()` and `encode_text()` methods that return the encoded features of given inputs.
139 |
140 |
141 | ### Linear-probe evaluation
142 |
143 | The example below uses [scikit-learn](https://scikit-learn.org/) to perform logistic regression on image features.
144 |
145 | ```python
146 | import os
147 | import clip
148 | import torch
149 |
150 | import numpy as np
151 | from sklearn.linear_model import LogisticRegression
152 | from torch.utils.data import DataLoader
153 | from torchvision.datasets import CIFAR100
154 | from tqdm import tqdm
155 |
156 | # Load the model
157 | device = "cuda" if torch.cuda.is_available() else "cpu"
158 | model, preprocess = clip.load('ViT-B/32', device)
159 |
160 | # Load the dataset
161 | root = os.path.expanduser("~/.cache")
162 | train = CIFAR100(root, download=True, train=True, transform=preprocess)
163 | test = CIFAR100(root, download=True, train=False, transform=preprocess)
164 |
165 |
166 | def get_features(dataset):
167 | all_features = []
168 | all_labels = []
169 |
170 | with torch.no_grad():
171 | for images, labels in tqdm(DataLoader(dataset, batch_size=100)):
172 | features = model.encode_image(images.to(device))
173 |
174 | all_features.append(features)
175 | all_labels.append(labels)
176 |
177 | return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy()
178 |
179 | # Calculate the image features
180 | train_features, train_labels = get_features(train)
181 | test_features, test_labels = get_features(test)
182 |
183 | # Perform logistic regression
184 | classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1)
185 | classifier.fit(train_features, train_labels)
186 |
187 | # Evaluate using the logistic regression classifier
188 | predictions = classifier.predict(test_features)
189 | accuracy = np.mean((test_labels == predictions).astype(float)) * 100.
190 | print(f"Accuracy = {accuracy:.3f}")
191 | ```
192 |
193 | Note that the `C` value should be determined via a hyperparameter sweep using a validation split.
194 |
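A minimal sketch of such a sweep, continuing from the `train_features` / `train_labels` arrays above and holding out a tenth of the training set as a validation split (the grid of `C` values is an arbitrary illustration):

```python
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

tr_x, val_x, tr_y, val_y = train_test_split(
    train_features, train_labels, test_size=0.1, random_state=0
)

best_c, best_acc = None, -1.0
for c in [0.001, 0.01, 0.1, 0.316, 1.0, 10.0]:
    clf = LogisticRegression(random_state=0, C=c, max_iter=1000)
    clf.fit(tr_x, tr_y)
    acc = clf.score(val_x, val_y)
    if acc > best_acc:
        best_c, best_acc = c, acc

print(f"Best C = {best_c} (validation accuracy = {100 * best_acc:.2f}%)")
```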
--------------------------------------------------------------------------------
/CLIP/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 | from .clip_gradcam import ClipGradcam
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | import torchvision
7 | from functools import reduce
8 |
9 |
10 | def factors(n):
11 | return set(
12 | reduce(
13 | list.__add__,
14 | ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0),
15 | )
16 | )
17 |
18 |
19 | saliency_configs = {
20 | "ours": lambda img_dim: {
21 | "distractor_labels": {},
22 | "horizontal_flipping": True,
23 | "augmentations": 5,
24 | "imagenet_prompt_ensemble": False,
25 | "positive_attn_only": True,
26 | "cropping_augmentations": [
27 | {"tile_size": img_dim, "stride": img_dim // 4},
28 | {"tile_size": int(img_dim * 2 / 3), "stride": int(img_dim * 2 / 3) // 4},
29 | {"tile_size": img_dim // 2, "stride": (img_dim // 2) // 4},
30 | {"tile_size": img_dim // 4, "stride": (img_dim // 4) // 4},
31 | ],
32 | },
33 | "chefer_et_al": lambda img_dim: {
34 | "distractor_labels": {},
35 | "horizontal_flipping": False,
36 | "augmentations": 0,
37 | "imagenet_prompt_ensemble": False,
38 | "positive_attn_only": True,
39 | "cropping_augmentations": [{"tile_size": img_dim, "stride": img_dim // 4}],
40 | },
41 | }
42 |
43 |
44 | class ClipWrapper:
45 | # SINGLETON WRAPPER
46 | clip_model = None
47 | clip_preprocess = None
48 | clip_gradcam = None
49 | lavt = None
50 | device = None
51 | jittering_transforms = None
52 |
53 | def __init__(self, clip_model_type, device, **kwargs):
54 | ClipWrapper.device = device
55 | ClipWrapper.jittering_transforms = torchvision.transforms.ColorJitter(
56 | brightness=0.6, contrast=0.6, saturation=0.6, hue=0.1
57 | )
58 | ClipWrapper.clip_model, ClipWrapper.clip_preprocess = load(
59 | clip_model_type, ClipWrapper.device, **kwargs
60 | )
61 | ClipWrapper.clip_gradcam = ClipGradcam(
62 | clip_model_name=clip_model_type,
63 | classes=[""],
64 | templates=["{}"],
65 | device=ClipWrapper.device,
66 | **kwargs
67 | )
68 |
69 | @classmethod
70 | def check_initialized(cls, clip_model_type="ViT-B/32", **kwargs):
71 | if cls.clip_gradcam is None:
72 | ClipWrapper(
73 | clip_model_type=clip_model_type,
74 | device="cuda" if torch.cuda.is_available() else "cpu",
75 | **kwargs
76 | )
77 |
78 | @classmethod
79 | def get_clip_text_feature(cls, string):
80 | ClipWrapper.check_initialized()
81 | with torch.no_grad():
82 | return (
83 | cls.clip_model.encode_text(
84 | tokenize(string, context_length=77).to(cls.device)
85 | )
86 | .squeeze()
87 | .cpu()
88 | .numpy()
89 | )
90 |
91 | @classmethod
92 | def get_visual_feature(cls, rgb, tile_attn_mask, device=None):
93 | if device is None:
94 | device = ClipWrapper.device
95 | ClipWrapper.check_initialized()
96 | rgb = ClipWrapper.clip_preprocess(Image.fromarray(rgb)).unsqueeze(0)
97 | with torch.no_grad():
98 | clip_feature = ClipWrapper.clip_model.encode_image(
99 | rgb.to(ClipWrapper.device), tile_attn_mask=tile_attn_mask
100 | ).squeeze()
101 | return clip_feature.to(device)
102 |
103 | @classmethod
104 | def get_clip_saliency(
105 | cls,
106 | img,
107 | text_labels,
108 | prompts,
109 | distractor_labels=set(),
110 | use_lavt=False,
111 | **kwargs
112 | ):
113 | cls.check_initialized()
114 | if use_lavt:
115 | return cls.lavt.localize(img=img, prompts=text_labels)
116 | cls.clip_gradcam.templates = prompts
117 | cls.clip_gradcam.set_classes(text_labels)
118 | text_label_features = torch.stack(
119 | list(cls.clip_gradcam.class_to_language_feature.values()), dim=0
120 | )
121 | text_label_features = text_label_features.squeeze(dim=-1).cpu()
122 | text_maps = cls.get_clip_saliency_convolve(
123 | img=img, text_labels=text_labels, **kwargs
124 | )
125 | if len(distractor_labels) > 0:
126 | distractor_labels = set(distractor_labels) - set(text_labels)
127 | cls.clip_gradcam.set_classes(list(distractor_labels))
128 | distractor_maps = cls.get_clip_saliency_convolve(
129 | img=img, text_labels=list(distractor_labels), **kwargs
130 | )
131 | text_maps -= distractor_maps.mean(dim=0)
132 | text_maps = text_maps.cpu()
133 | return text_maps, text_label_features.squeeze(dim=-1)
134 |
135 | @classmethod
136 | def get_clip_saliency_convolve(
137 | cls,
138 | text_labels,
139 | horizontal_flipping=False,
140 | positive_attn_only: bool = False,
141 | tile_batch_size=32,
142 | prompt_batch_size=32,
143 | tile_interpolate_batch_size=32,
144 | **kwargs
145 | ):
146 | cls.clip_gradcam.positive_attn_only = positive_attn_only
147 | tiles, tile_imgs, counts, tile_sizes = cls.create_tiles(**kwargs)
148 | outputs = {
149 | k: torch.zeros(
150 | [len(text_labels)] + list(count.shape), device=cls.device
151 | ).half()
152 | for k, count in counts.items()
153 | }
154 | tile_gradcams = torch.cat(
155 | [
156 | torch.cat(
157 | [
158 | cls.clip_gradcam(
159 | x=tile_imgs[tile_idx : tile_idx + tile_batch_size],
160 | o=text_labels[prompt_idx : prompt_idx + prompt_batch_size],
161 | )
162 | for tile_idx in np.arange(0, len(tile_imgs), tile_batch_size)
163 | ],
164 | dim=1,
165 | )
166 | for prompt_idx in np.arange(0, len(text_labels), prompt_batch_size)
167 | ],
168 | dim=0,
169 | )
170 | if horizontal_flipping:
171 | flipped_tile_imgs = tile_imgs[
172 | ..., torch.flip(torch.arange(0, tile_imgs.shape[-1]), dims=[0])
173 | ]
174 | flipped_tile_gradcams = torch.cat(
175 | [
176 | torch.cat(
177 | [
178 | cls.clip_gradcam(
179 | x=flipped_tile_imgs[
180 | tile_idx : tile_idx + tile_batch_size
181 | ],
182 | o=text_labels[
183 | prompt_idx : prompt_idx + prompt_batch_size
184 | ],
185 | )
186 | for tile_idx in np.arange(
187 | 0, len(tile_imgs), tile_batch_size
188 | )
189 | ],
190 | dim=1,
191 | )
192 | for prompt_idx in np.arange(0, len(text_labels), prompt_batch_size)
193 | ],
194 | dim=0,
195 | )
196 | with torch.no_grad():
197 | flipped_tile_gradcams = flipped_tile_gradcams[
198 | ...,
199 | torch.flip(
200 | torch.arange(0, flipped_tile_gradcams.shape[-1]), dims=[0]
201 | ),
202 | ]
203 | tile_gradcams = (tile_gradcams + flipped_tile_gradcams) / 2
204 | del flipped_tile_gradcams
205 | with torch.no_grad():
206 | torch.cuda.empty_cache()
207 | for tile_size in np.unique(tile_sizes):
208 | tile_size_mask = tile_sizes == tile_size
209 | curr_size_grads = tile_gradcams[:, tile_size_mask]
210 | curr_size_tiles = tiles[tile_size_mask]
211 | for tile_idx in np.arange(
212 | 0, curr_size_grads.shape[1], tile_interpolate_batch_size
213 | ):
214 | resized_tiles = torch.nn.functional.interpolate(
215 | curr_size_grads[
216 | :, tile_idx : tile_idx + tile_interpolate_batch_size
217 | ],
218 | size=tile_size,
219 | mode="bilinear",
220 | align_corners=False,
221 | )
222 | for tile_idx, tile_slice in enumerate(
223 | curr_size_tiles[
224 | tile_idx : tile_idx + tile_interpolate_batch_size
225 | ]
226 | ):
227 | outputs[tile_size][tile_slice] += resized_tiles[
228 | :, tile_idx, ...
229 | ]
230 | output = sum(
231 | output.float() / count
232 | for output, count in zip(outputs.values(), counts.values())
233 | ) / len(counts)
234 | del outputs, counts, tile_gradcams
235 | output = output.cpu()
236 | return output
237 |
238 | @classmethod
239 | def create_tiles(cls, img, augmentations, cropping_augmentations, **kwargs):
240 | assert type(img) == np.ndarray
241 | images = []
242 | cls.check_initialized()
243 | # compute image crops
244 | img_pil = Image.fromarray(img)
245 | images.append(np.array(img_pil))
246 | for _ in range(augmentations):
247 | images.append(np.array(cls.jittering_transforms(img_pil)))
248 | # for taking average
249 | counts = {
250 | crop_aug["tile_size"]: torch.zeros(img.shape[:2], device=cls.device).float()
251 | + 1e-5
252 | for crop_aug in cropping_augmentations
253 | }
254 | tiles = []
255 | tile_imgs = []
256 | tile_sizes = []
257 | for img in images:
258 | for crop_aug in cropping_augmentations:
259 | tile_size = crop_aug["tile_size"]
260 | stride = crop_aug["stride"]
261 | for y in np.arange(0, img.shape[1] - tile_size + 1, stride):
262 | if y >= img.shape[0]:
263 | continue
264 | for x in np.arange(0, img.shape[0] - tile_size + 1, stride):
265 | if x >= img.shape[1]:
266 | continue
267 | tile = (
268 | slice(None, None),
269 | slice(x, x + tile_size),
270 | slice(y, y + tile_size),
271 | )
272 | tiles.append(tile)
273 | counts[tile_size][tile[1:]] += 1
274 | tile_sizes.append(tile_size)
275 | # this is currently the biggest bottleneck
276 | tile_imgs.append(
277 | cls.clip_gradcam.preprocess(
278 | Image.fromarray(img[tiles[-1][1:]])
279 | )
280 | )
281 | tile_imgs = torch.stack(tile_imgs).to(cls.device)
282 | return np.array(tiles), tile_imgs, counts, np.array(tile_sizes)
283 |
284 |
285 | imagenet_templates = [
286 | "a bad photo of a {}.",
287 | "a photo of many {}.",
288 | "a sculpture of a {}.",
289 | "a photo of the hard to see {}.",
290 | "a low resolution photo of the {}.",
291 | "a rendering of a {}.",
292 | "graffiti of a {}.",
293 | "a bad photo of the {}.",
294 | "a cropped photo of the {}.",
295 | "a tattoo of a {}.",
296 | "the embroidered {}.",
297 | "a photo of a hard to see {}.",
298 | "a bright photo of a {}.",
299 | "a photo of a clean {}.",
300 | "a photo of a dirty {}.",
301 | "a dark photo of the {}.",
302 | "a drawing of a {}.",
303 | "a photo of my {}.",
304 | "the plastic {}.",
305 | "a photo of the cool {}.",
306 | "a close-up photo of a {}.",
307 | "a black and white photo of the {}.",
308 | "a painting of the {}.",
309 | "a painting of a {}.",
310 | "a pixelated photo of the {}.",
311 | "a sculpture of the {}.",
312 | "a bright photo of the {}.",
313 | "a cropped photo of a {}.",
314 | "a plastic {}.",
315 | "a photo of the dirty {}.",
316 | "a jpeg corrupted photo of a {}.",
317 | "a blurry photo of the {}.",
318 | "a photo of the {}.",
319 | "a good photo of the {}.",
320 | "a rendering of the {}.",
321 | "a {} in a video game.",
322 | "a photo of one {}.",
323 | "a doodle of a {}.",
324 | "a close-up photo of the {}.",
325 | "a photo of a {}.",
326 | "the origami {}.",
327 | "the {} in a video game.",
328 | "a sketch of a {}.",
329 | "a doodle of the {}.",
330 | "a origami {}.",
331 | "a low resolution photo of a {}.",
332 | "the toy {}.",
333 | "a rendition of the {}.",
334 | "a photo of the clean {}.",
335 | "a photo of a large {}.",
336 | "a rendition of a {}.",
337 | "a photo of a nice {}.",
338 | "a photo of a weird {}.",
339 | "a blurry photo of a {}.",
340 | "a cartoon {}.",
341 | "art of a {}.",
342 | "a sketch of the {}.",
343 | "a embroidered {}.",
344 | "a pixelated photo of a {}.",
345 | "itap of the {}.",
346 | "a jpeg corrupted photo of the {}.",
347 | "a good photo of a {}.",
348 | "a plushie {}.",
349 | "a photo of the nice {}.",
350 | "a photo of the small {}.",
351 | "a photo of the weird {}.",
352 | "the cartoon {}.",
353 | "art of the {}.",
354 | "a drawing of the {}.",
355 | "a photo of the large {}.",
356 | "a black and white photo of a {}.",
357 | "the plushie {}.",
358 | "a dark photo of a {}.",
359 | "itap of a {}.",
360 | "graffiti of the {}.",
361 | "a toy {}.",
362 | "itap of my {}.",
363 | "a photo of a cool {}.",
364 | "a photo of a small {}.",
365 | "a tattoo of the {}.",
366 | ]
367 |
368 | __all__ = ["ClipWrapper", "imagenet_templates"]
369 |
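Putting the pieces above together, a hypothetical end-to-end call that produces one relevancy map per label for an RGB image. The image path, labels, and prompt template are illustrative, and `saliency_configs["ours"]` is parameterized by the image dimension used for the multi-scale crops:

```python
import numpy as np
from PIL import Image
from clip import ClipWrapper, saliency_configs

# illustrative input; any HxWx3 uint8 array works
img = np.array(Image.open("CLIP.png").convert("RGB"))

text_maps, text_features = ClipWrapper.get_clip_saliency(
    img=img,
    text_labels=["a diagram", "a dog"],
    prompts=["a photo of a {} in the scene."],
    **saliency_configs["ours"](min(img.shape[:2])),
)
print(text_maps.shape)  # (num labels, H, W): per-label relevancy at image resolution
```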
--------------------------------------------------------------------------------
/CLIP/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/CLIP/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/CLIP/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Any, Union, List
6 |
7 | import torch
8 | from PIL import Image
9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10 | from tqdm import tqdm
11 |
12 | from .model import build_model
13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14 |
15 | try:
16 | from torchvision.transforms import InterpolationMode
17 |
18 | BICUBIC = InterpolationMode.BICUBIC
19 | except ImportError:
20 | BICUBIC = Image.BICUBIC
21 |
22 |
23 | # if [int(i) for i in torch.__version__.split(".")] < [1, 7, 1]:
24 | # warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25 |
26 | __all__ = ["available_models", "load", "tokenize", "tokenizer"]
27 | tokenizer = _Tokenizer()
28 |
29 | _MODELS = {
30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
36 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
37 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
38 | }
39 |
40 |
41 | def _download(url: str, root: str):
42 | os.makedirs(root, exist_ok=True)
43 | filename = os.path.basename(url)
44 |
45 | expected_sha256 = url.split("/")[-2]
46 | download_target = os.path.join(root, filename)
47 |
48 | if os.path.exists(download_target) and not os.path.isfile(download_target):
49 | raise RuntimeError(f"{download_target} exists and is not a regular file")
50 |
51 | if os.path.isfile(download_target):
52 | if (
53 | hashlib.sha256(open(download_target, "rb").read()).hexdigest()
54 | == expected_sha256
55 | ):
56 | return download_target
57 | else:
58 | warnings.warn(
59 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
60 | )
61 |
62 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
63 | with tqdm(
64 | total=int(source.info().get("Content-Length")),
65 | ncols=80,
66 | unit="iB",
67 | unit_scale=True,
68 | unit_divisor=1024,
69 | ) as loop:
70 | while True:
71 | buffer = source.read(8192)
72 | if not buffer:
73 | break
74 |
75 | output.write(buffer)
76 | loop.update(len(buffer))
77 |
78 | if (
79 | hashlib.sha256(open(download_target, "rb").read()).hexdigest()
80 | != expected_sha256
81 | ):
82 | raise RuntimeError(
83 | f"Model has been downloaded but the SHA256 checksum does not not match"
84 | )
85 |
86 | return download_target
87 |
88 |
89 | def _convert_image_to_rgb(image):
90 | return image.convert("RGB")
91 |
92 |
93 | def _transform(n_px, overload_resolution=False):
94 | transforms = [
95 | _convert_image_to_rgb,
96 | ToTensor(),
97 | Normalize(
98 | (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
99 | ),
100 | ]
101 | if not overload_resolution:
102 | transforms = [Resize(224, interpolation=BICUBIC), CenterCrop(n_px)] + transforms
103 | return Compose(transforms)
104 |
105 |
106 | def available_models() -> List[str]:
107 | """Returns the names of available CLIP models"""
108 | return list(_MODELS.keys())
109 |
110 |
111 | def load(
112 | name: str,
113 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
114 | jit: bool = False,
115 | download_root: str = None,
116 | overload_resolution=False,
117 | ):
118 | """Load a CLIP model
119 | Parameters
120 | ----------
121 | name : str
122 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
123 | device : Union[str, torch.device]
124 | The device to put the loaded model
125 | jit : bool
126 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
127 | download_root: str
128 | path to download the model files; by default, it uses "~/.cache/clip"
129 | Returns
130 | -------
131 | model : torch.nn.Module
132 | The CLIP model
133 | preprocess : Callable[[PIL.Image], torch.Tensor]
134 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
135 | """
136 | if name in _MODELS:
137 | model_path = _download(
138 | _MODELS[name], download_root or os.path.expanduser("~/.cache/clip")
139 | )
140 | elif os.path.isfile(name):
141 | model_path = name
142 | else:
143 | raise RuntimeError(
144 | f"Model {name} not found; available models = {available_models()}"
145 | )
146 |
147 | try:
148 | # loading JIT archive
149 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
150 | state_dict = None
151 | except RuntimeError:
152 | # loading saved state dict
153 | if jit:
154 | warnings.warn(
155 | f"File {model_path} is not a JIT archive. Loading as a state dict instead"
156 | )
157 | jit = False
158 | state_dict = torch.load(model_path, map_location="cpu")
159 |
160 | if not jit:
161 | model = build_model(state_dict or model.state_dict()).to(device)
162 | if str(device) == "cpu":
163 | model.float()
164 | return model, _transform(model.visual.input_resolution, overload_resolution)
165 |
166 | # patch the device names
167 | device_holder = torch.jit.trace(
168 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
169 | )
170 | device_node = [
171 | n
172 | for n in device_holder.graph.findAllNodes("prim::Constant")
173 | if "Device" in repr(n)
174 | ][-1]
175 |
176 | def patch_device(module):
177 | try:
178 | graphs = [module.graph] if hasattr(module, "graph") else []
179 | except RuntimeError:
180 | graphs = []
181 |
182 | if hasattr(module, "forward1"):
183 | graphs.append(module.forward1.graph)
184 |
185 | for graph in graphs:
186 | for node in graph.findAllNodes("prim::Constant"):
187 | if "value" in node.attributeNames() and str(node["value"]).startswith(
188 | "cuda"
189 | ):
190 | node.copyAttributes(device_node)
191 |
192 | model.apply(patch_device)
193 | patch_device(model.encode_image)
194 | patch_device(model.encode_text)
195 |
196 | # patch dtype to float32 on CPU
197 | if str(device) == "cpu":
198 | float_holder = torch.jit.trace(
199 | lambda: torch.ones([]).float(), example_inputs=[]
200 | )
201 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
202 | float_node = float_input.node()
203 |
204 | def patch_float(module):
205 | try:
206 | graphs = [module.graph] if hasattr(module, "graph") else []
207 | except RuntimeError:
208 | graphs = []
209 |
210 | if hasattr(module, "forward1"):
211 | graphs.append(module.forward1.graph)
212 |
213 | for graph in graphs:
214 | for node in graph.findAllNodes("aten::to"):
215 | inputs = list(node.inputs())
216 | for i in [
217 | 1,
218 | 2,
219 | ]: # dtype can be the second or third argument to aten::to()
220 | if inputs[i].node()["value"] == 5:
221 | inputs[i].node().copyAttributes(float_node)
222 |
223 | model.apply(patch_float)
224 | patch_float(model.encode_image)
225 | patch_float(model.encode_text)
226 |
227 | model.float()
228 |
229 | return model, _transform(model.input_resolution.item(), overload_resolution)
230 |
231 |
232 | def tokenize(
233 | texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False
234 | ) -> torch.LongTensor:
235 | """
236 | Returns the tokenized representation of given input string(s)
237 |
238 | Parameters
239 | ----------
240 | texts : Union[str, List[str]]
241 | An input string or a list of input strings to tokenize
242 |
243 | context_length : int
244 | The context length to use; all CLIP models use 77 as the context length
245 |
246 | truncate: bool
247 | Whether to truncate the text in case its encoding is longer than the context length
248 |
249 | Returns
250 | -------
251 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
252 | """
253 | if isinstance(texts, str):
254 | texts = [texts]
255 |
256 | sot_token = tokenizer.encoder["<|startoftext|>"]
257 | eot_token = tokenizer.encoder["<|endoftext|>"]
258 | all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] for text in texts]
259 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
260 |
261 | for i, tokens in enumerate(all_tokens):
262 | if len(tokens) > context_length:
263 | if truncate:
264 | tokens = tokens[:context_length]
265 | tokens[-1] = eot_token
266 | else:
267 | raise RuntimeError(
268 | f"Input {texts[i]} is too long for context length {context_length}"
269 | )
270 | result[i, : len(tokens)] = torch.tensor(tokens)
271 |
272 | return result
273 |
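As a quick illustration, `tokenize` pads (or, with `truncate=True`, truncates) every sequence to `context_length`, so a batch of prompts always shares the shape `[N, 77]`:

```python
import clip

tokens = clip.tokenize(["a photo of a cat", "a photo of a dog"])
print(tokens.shape, tokens.dtype)  # torch.Size([2, 77]) torch.int64
```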
--------------------------------------------------------------------------------
/CLIP/clip/clip_explainability.py:
--------------------------------------------------------------------------------
1 | # modified from: https://github.com/hila-chefer/Transformer-MM-Explainability/blob/main/CLIP/clip/clip.py
2 |
3 | import hashlib
4 | import os
5 | import urllib
6 | import warnings
7 | from typing import Any, Union, List
8 | from pkg_resources import packaging
9 |
10 | import torch
11 | from PIL import Image
12 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
13 | from tqdm import tqdm
14 |
15 | from .model_explainability import build_model
16 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
17 |
18 | try:
19 | from torchvision.transforms import InterpolationMode
20 |
21 | BICUBIC = InterpolationMode.BICUBIC
22 | except ImportError:
23 | BICUBIC = Image.BICUBIC
24 |
25 |
26 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
27 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
28 |
29 |
30 | __all__ = ["available_models", "load", "tokenize"]
31 | _tokenizer = _Tokenizer()
32 |
33 |
34 | _MODELS = {
35 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
36 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
37 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
38 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
39 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
40 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
41 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
42 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
43 | }
44 |
45 |
46 | def _download(url: str, root: str):
47 | os.makedirs(root, exist_ok=True)
48 | filename = os.path.basename(url)
49 |
50 | expected_sha256 = url.split("/")[-2]
51 | download_target = os.path.join(root, filename)
52 |
53 | if os.path.exists(download_target) and not os.path.isfile(download_target):
54 | raise RuntimeError(f"{download_target} exists and is not a regular file")
55 |
56 | if os.path.isfile(download_target):
57 | if (
58 | hashlib.sha256(open(download_target, "rb").read()).hexdigest()
59 | == expected_sha256
60 | ):
61 | return download_target
62 | else:
63 | warnings.warn(
64 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
65 | )
66 |
67 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
68 | with tqdm(
69 | total=int(source.info().get("Content-Length")),
70 | ncols=80,
71 | unit="iB",
72 | unit_scale=True,
73 | unit_divisor=1024,
74 | ) as loop:
75 | while True:
76 | buffer = source.read(8192)
77 | if not buffer:
78 | break
79 |
80 | output.write(buffer)
81 | loop.update(len(buffer))
82 |
83 | if (
84 | hashlib.sha256(open(download_target, "rb").read()).hexdigest()
85 | != expected_sha256
86 | ):
87 | raise RuntimeError(
88 | f"Model has been downloaded but the SHA256 checksum does not not match"
89 | )
90 |
91 | return download_target
92 |
93 |
94 | def _convert_image_to_rgb(image):
95 | return image.convert("RGB")
96 |
97 |
98 | def _transform(n_px, overload_resolution=False):
99 | transforms = [
100 | _convert_image_to_rgb,
101 | ToTensor(),
102 | Normalize(
103 | (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
104 | ),
105 | ]
106 | if not overload_resolution:
107 | transforms = [Resize(224, interpolation=BICUBIC), CenterCrop(n_px)] + transforms
108 | return Compose(transforms)
109 |
110 |
111 | def available_models() -> List[str]:
112 | """Returns the names of available CLIP models"""
113 | return list(_MODELS.keys())
114 |
115 |
116 | def load(
117 | name: str,
118 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
119 | jit: bool = False,
120 | download_root: str = None,
121 | overload_resolution=False,
122 | ):
123 | """Load a CLIP model
124 | Parameters
125 | ----------
126 | name : str
127 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
128 | device : Union[str, torch.device]
129 | The device to put the loaded model
130 | jit : bool
131 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
132 | download_root: str
133 | path to download the model files; by default, it uses "~/.cache/clip"
134 | Returns
135 | -------
136 | model : torch.nn.Module
137 | The CLIP model
138 | preprocess : Callable[[PIL.Image], torch.Tensor]
139 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
140 | """
141 | if name in _MODELS:
142 | model_path = _download(
143 | _MODELS[name], download_root or os.path.expanduser("~/.cache/clip")
144 | )
145 | elif os.path.isfile(name):
146 | model_path = name
147 | else:
148 | raise RuntimeError(
149 | f"Model {name} not found; available models = {available_models()}"
150 | )
151 |
152 | try:
153 | # loading JIT archive
154 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
155 | state_dict = None
156 | except RuntimeError:
157 | # loading saved state dict
158 | if jit:
159 | warnings.warn(
160 | f"File {model_path} is not a JIT archive. Loading as a state dict instead"
161 | )
162 | jit = False
163 | state_dict = torch.load(model_path, map_location="cpu")
164 |
165 | if not jit:
166 | model = build_model(state_dict or model.state_dict()).to(device)
167 | if str(device) == "cpu":
168 | model.float()
169 | return model, _transform(model.visual.input_resolution, overload_resolution)
170 |
171 | # patch the device names
172 | device_holder = torch.jit.trace(
173 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
174 | )
175 | device_node = [
176 | n
177 | for n in device_holder.graph.findAllNodes("prim::Constant")
178 | if "Device" in repr(n)
179 | ][-1]
180 |
181 | def patch_device(module):
182 | try:
183 | graphs = [module.graph] if hasattr(module, "graph") else []
184 | except RuntimeError:
185 | graphs = []
186 |
187 | if hasattr(module, "forward1"):
188 | graphs.append(module.forward1.graph)
189 |
190 | for graph in graphs:
191 | for node in graph.findAllNodes("prim::Constant"):
192 | if "value" in node.attributeNames() and str(node["value"]).startswith(
193 | "cuda"
194 | ):
195 | node.copyAttributes(device_node)
196 |
197 | model.apply(patch_device)
198 | patch_device(model.encode_image)
199 | patch_device(model.encode_text)
200 |
201 | # patch dtype to float32 on CPU
202 | if str(device) == "cpu":
203 | float_holder = torch.jit.trace(
204 | lambda: torch.ones([]).float(), example_inputs=[]
205 | )
206 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
207 | float_node = float_input.node()
208 |
209 | def patch_float(module):
210 | try:
211 | graphs = [module.graph] if hasattr(module, "graph") else []
212 | except RuntimeError:
213 | graphs = []
214 |
215 | if hasattr(module, "forward1"):
216 | graphs.append(module.forward1.graph)
217 |
218 | for graph in graphs:
219 | for node in graph.findAllNodes("aten::to"):
220 | inputs = list(node.inputs())
221 | for i in [
222 | 1,
223 | 2,
224 | ]: # dtype can be the second or third argument to aten::to()
225 | if inputs[i].node()["value"] == 5:
226 | inputs[i].node().copyAttributes(float_node)
227 |
228 | model.apply(patch_float)
229 | patch_float(model.encode_image)
230 | patch_float(model.encode_text)
231 |
232 | model.float()
233 |
234 | return model, _transform(model.input_resolution.item(), overload_resolution)
235 |
236 |
237 | def tokenize(
238 | texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False
239 | ) -> torch.LongTensor:
240 | """
241 | Returns the tokenized representation of given input string(s)
242 | Parameters
243 | ----------
244 | texts : Union[str, List[str]]
245 | An input string or a list of input strings to tokenize
246 | context_length : int
247 | The context length to use; all CLIP models use 77 as the context length
248 | truncate: bool
249 | Whether to truncate the text in case its encoding is longer than the context length
250 | Returns
251 | -------
252 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
253 | """
254 | if isinstance(texts, str):
255 | texts = [texts]
256 |
257 | sot_token = _tokenizer.encoder["<|startoftext|>"]
258 | eot_token = _tokenizer.encoder["<|endoftext|>"]
259 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
260 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
261 |
262 | for i, tokens in enumerate(all_tokens):
263 | if len(tokens) > context_length:
264 | if truncate:
265 | tokens = tokens[:context_length]
266 | tokens[-1] = eot_token
267 | else:
268 | raise RuntimeError(
269 | f"Input {texts[i]} is too long for context length {context_length}"
270 | )
271 | result[i, : len(tokens)] = torch.tensor(tokens)
272 |
273 | return result
274 |
--------------------------------------------------------------------------------
/CLIP/clip/clip_gradcam.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import torch
3 | import torch.nn as nn
4 | from .clip_explainability import load
5 | from .clip import tokenize
6 | from torch import device
7 | import numpy as np
8 | import torch.nn.functional as nnf
9 | import itertools
10 |
11 |
12 | def zeroshot_classifier(clip_model, classnames, templates, device):
13 | with torch.no_grad():
14 | texts = list(
15 | itertools.chain(
16 | *[
17 | [template.format(classname) for template in templates]
18 | for classname in classnames
19 | ]
20 | )
21 | ) # format with class
22 | texts = tokenize(texts).to(device) # tokenize
23 | class_embeddings = clip_model.encode_text(texts)
24 | class_embeddings = class_embeddings.view(len(classnames), len(templates), -1)
25 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
26 | zeroshot_weights = class_embeddings.mean(dim=1)
27 | return zeroshot_weights.T # shape: [dim, n classes]
28 |
29 |
30 | class ClipGradcam(nn.Module):
31 | def __init__(
32 | self,
33 | clip_model_name: str,
34 | classes: List[str],
35 | templates: List[str],
36 | device: device,
37 | num_layers=10,
38 | positive_attn_only=False,
39 | **kwargs
40 | ):
41 |
42 | super(ClipGradcam, self).__init__()
43 | self.clip_model_name = clip_model_name
44 | self.model, self.preprocess = load(clip_model_name, device=device, **kwargs)
45 | self.templates = templates
46 | self.device = device
47 | self.target_classes = None
48 | self.set_classes(classes)
49 | self.num_layers = num_layers
50 | self.positive_attn_only = positive_attn_only
51 | self.num_res_attn_blocks = {
52 | "ViT-B/32": 12,
53 | "ViT-B/16": 12,
54 | "ViT-L/14": 16,
55 | "ViT-L/14@336px": 16,
56 | }[clip_model_name]
57 |
58 | def forward(self, x: torch.Tensor, o: List[str]):
59 | """
60 | non-standard hack around an nn, really should be more principled here
61 | """
62 | image_features = self.model.encode_image(x.to(self.device))
63 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
64 | zeroshot_weights = torch.cat(
65 | [self.class_to_language_feature[prompt] for prompt in o], dim=1
66 | )
67 | logits_per_image = 100.0 * image_features @ zeroshot_weights
68 | return self.interpret(logits_per_image, self.model, self.device)
69 |
70 | def interpret(self, logits_per_image, model, device):
71 | # modified from: https://colab.research.google.com/github/hila-chefer/Transformer-MM-Explainability/blob/main/CLIP_explainability.ipynb#scrollTo=fWKGyu2YAeSV
72 | batch_size = logits_per_image.shape[0]
73 | num_prompts = logits_per_image.shape[1]
74 | one_hot = [logit for logit in logits_per_image.sum(dim=0)]
75 | model.zero_grad()
76 |
77 | image_attn_blocks = list(
78 | dict(model.visual.transformer.resblocks.named_children()).values()
79 | )
80 | num_tokens = image_attn_blocks[0].attn_probs.shape[-1]
81 | R = torch.eye(
82 | num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype
83 | ).to(device)
84 | R = R[None, None, :, :].repeat(num_prompts, batch_size, 1, 1)
85 | for i, block in enumerate(image_attn_blocks):
86 | if i <= self.num_layers:
87 | continue
88 | # TODO try scaling block.attn_probs by value magnitude
89 | # TODO actual parallelized prompt gradients
90 | grad = torch.stack(
91 | [
92 | torch.autograd.grad(logit, [block.attn_probs], retain_graph=True)[
93 | 0
94 | ].detach()
95 | for logit in one_hot
96 | ]
97 | )
98 | grad = grad.view(
99 | num_prompts,
100 | batch_size,
101 | self.num_res_attn_blocks,
102 | num_tokens,
103 | num_tokens,
104 | )
105 | cam = (
106 | block.attn_probs.view(
107 | 1, batch_size, self.num_res_attn_blocks, num_tokens, num_tokens
108 | )
109 | .detach()
110 | .repeat(num_prompts, 1, 1, 1, 1)
111 | )
112 | cam = cam.reshape(num_prompts, batch_size, -1, cam.shape[-1], cam.shape[-1])
113 | grad = grad.reshape(
114 | num_prompts, batch_size, -1, grad.shape[-1], grad.shape[-1]
115 | )
116 | cam = grad * cam
117 | cam = cam.reshape(
118 | num_prompts * batch_size, -1, cam.shape[-1], cam.shape[-1]
119 | )
120 | if self.positive_attn_only:
121 | cam = cam.clamp(min=0)
122 | # average of all heads
123 | cam = cam.mean(dim=-3)
124 | R = R + torch.bmm(
125 | cam, R.view(num_prompts * batch_size, num_tokens, num_tokens)
126 | ).view(num_prompts, batch_size, num_tokens, num_tokens)
127 | image_relevance = R[:, :, 0, 1:]
128 | img_dim = int(np.sqrt(num_tokens - 1))
129 | image_relevance = image_relevance.reshape(
130 | num_prompts, batch_size, img_dim, img_dim
131 | )
132 | return image_relevance
133 |
134 | def set_classes(self, classes):
135 | self.target_classes = classes
136 | language_features = zeroshot_classifier(
137 | self.model, self.target_classes, self.templates, self.device
138 | )
139 |
140 | self.class_to_language_feature = {}
141 | for i, c in enumerate(self.target_classes):
142 | self.class_to_language_feature[c] = language_features[:, [i]]
143 |
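A hypothetical standalone use of `ClipGradcam`, mirroring how `ClipWrapper` in `clip/__init__.py` drives it; the image path, class names, and prompt template are placeholders. For ViT-B/32 at 224px input, the relevance is returned over the 7x7 patch grid:

```python
import torch
from PIL import Image
from clip import ClipGradcam

device = "cuda" if torch.cuda.is_available() else "cpu"
gradcam = ClipGradcam(
    clip_model_name="ViT-B/32",
    classes=["a diagram", "a dog"],
    templates=["a photo of {}."],
    device=device,
)

# preprocess a single image and compute per-class relevance over the ViT patch grid
x = gradcam.preprocess(Image.open("CLIP.png")).unsqueeze(0)
relevance = gradcam(x=x, o=["a diagram", "a dog"])
print(relevance.shape)  # (num prompts, batch, 7, 7) for ViT-B/32
```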
--------------------------------------------------------------------------------
/CLIP/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(
13 | os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz"
14 | )
15 |
16 |
17 | @lru_cache()
18 | def bytes_to_unicode():
19 | """
20 | Returns a list of utf-8 bytes and a corresponding list of unicode strings.
21 | The reversible bpe codes work on unicode strings.
22 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
23 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
24 | This is a significant percentage of your normal, say, 32K bpe vocab.
25 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
26 | And avoids mapping to whitespace/control characters the bpe code barfs on.
27 | """
28 | bs = (
29 | list(range(ord("!"), ord("~") + 1))
30 | + list(range(ord("¡"), ord("¬") + 1))
31 | + list(range(ord("®"), ord("ÿ") + 1))
32 | )
33 | cs = bs[:]
34 | n = 0
35 | for b in range(2**8):
36 | if b not in bs:
37 | bs.append(b)
38 | cs.append(2**8 + n)
39 | n += 1
40 | cs = [chr(n) for n in cs]
41 | return dict(zip(bs, cs))
42 |
43 |
44 | def get_pairs(word):
45 | """Return set of symbol pairs in a word.
46 | Word is represented as tuple of symbols (symbols being variable-length strings).
47 | """
48 | pairs = set()
49 | prev_char = word[0]
50 | for char in word[1:]:
51 | pairs.add((prev_char, char))
52 | prev_char = char
53 | return pairs
54 |
55 |
56 | def basic_clean(text):
57 | text = ftfy.fix_text(text)
58 | text = html.unescape(html.unescape(text))
59 | return text.strip()
60 |
61 |
62 | def whitespace_clean(text):
63 | text = re.sub(r"\s+", " ", text)
64 | text = text.strip()
65 | return text
66 |
67 |
68 | class SimpleTokenizer(object):
69 | def __init__(self, bpe_path: str = default_bpe()):
70 | self.byte_encoder = bytes_to_unicode()
71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
72 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n")
73 | merges = merges[1 : 49152 - 256 - 2 + 1]
74 | merges = [tuple(merge.split()) for merge in merges]
75 | vocab = list(bytes_to_unicode().values())
76 | vocab = vocab + [v + "</w>" for v in vocab]
77 | for merge in merges:
78 | vocab.append("".join(merge))
79 | vocab.extend(["<|startoftext|>", "<|endoftext|>"])
80 | self.encoder = dict(zip(vocab, range(len(vocab))))
81 | self.decoder = {v: k for k, v in self.encoder.items()}
82 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
83 | self.cache = {
84 | "<|startoftext|>": "<|startoftext|>",
85 | "<|endoftext|>": "<|endoftext|>",
86 | }
87 | self.pat = re.compile(
88 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
89 | re.IGNORECASE,
90 | )
91 |
92 | def bpe(self, token):
93 | if token in self.cache:
94 | return self.cache[token]
95 | word = tuple(token[:-1]) + (token[-1] + "</w>",)
96 | pairs = get_pairs(word)
97 |
98 | if not pairs:
99 | return token + "</w>"
100 |
101 | while True:
102 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
103 | if bigram not in self.bpe_ranks:
104 | break
105 | first, second = bigram
106 | new_word = []
107 | i = 0
108 | while i < len(word):
109 | try:
110 | j = word.index(first, i)
111 | new_word.extend(word[i:j])
112 | i = j
113 | except:
114 | new_word.extend(word[i:])
115 | break
116 |
117 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
118 | new_word.append(first + second)
119 | i += 2
120 | else:
121 | new_word.append(word[i])
122 | i += 1
123 | new_word = tuple(new_word)
124 | word = new_word
125 | if len(word) == 1:
126 | break
127 | else:
128 | pairs = get_pairs(word)
129 | word = " ".join(word)
130 | self.cache[token] = word
131 | return word
132 |
133 | def encode(self, text):
134 | bpe_tokens = []
135 | text = whitespace_clean(basic_clean(text)).lower()
136 | for token in re.findall(self.pat, text):
137 | token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
138 | bpe_tokens.extend(
139 | self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
140 | )
141 | return bpe_tokens
142 |
143 | def decode(self, tokens):
144 | text = "".join([self.decoder[token] for token in tokens])
145 | text = (
146 | bytearray([self.byte_decoder[c] for c in text])
147 | .decode("utf-8", errors="replace")
148 | .replace("</w>", " ")
149 | )
150 | return text
151 |
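A quick round-trip with the tokenizer above. Note that `encode` does not add the `<|startoftext|>` / `<|endoftext|>` markers; `clip.tokenize` in `clip.py` adds those and pads to the context length:

```python
from clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()
ids = tokenizer.encode("Hello, world!!")
print(ids)                    # BPE token ids
print(tokenizer.decode(ids))  # "hello , world !!" (lower-cased, spaces re-inserted at word ends)
```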
--------------------------------------------------------------------------------
/CLIP/model-card.md:
--------------------------------------------------------------------------------
1 | # Model Card: CLIP
2 |
3 | Inspired by [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993) and [Lessons from Archives (Jo & Gebru)](https://arxiv.org/pdf/1912.10389.pdf), we’re providing some accompanying information about the multimodal model.
4 |
5 | ## Model Details
6 |
7 | The CLIP model was developed by researchers at OpenAI to learn about what contributes to robustness in computer vision tasks. The model was also developed to test the ability of models to generalize to arbitrary image classification tasks in a zero-shot manner. It was not developed for general model deployment - to deploy models like CLIP, researchers will first need to carefully study their capabilities in relation to the specific context they’re being deployed within.
8 |
9 | ### Model Date
10 |
11 | January 2021
12 |
13 | ### Model Type
14 |
15 | The base model uses a ResNet50 with several modifications as an image encoder and uses a masked self-attention Transformer as a text encoder. These encoders are trained to maximize the similarity of (image, text) pairs via a contrastive loss. There is also a variant of the model where the ResNet image encoder is replaced with a Vision Transformer.
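
As a concrete illustration of the contrastive setup, the sketch below computes image-text similarity scores with the bundled `clip` package; the image file and candidate captions are just placeholders.

```python
import clip
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

# Encode an image and a few candidate captions into the shared embedding space.
image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

# Cosine similarity between L2-normalized embeddings gives zero-shot class scores.
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
print(probs)
```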
16 |
17 | ### Model Versions
18 |
19 | Initially, we’ve released one CLIP model based on the Vision Transformer architecture equivalent to ViT-B/32, along with the RN50 model, using the architecture equivalent to ResNet-50.
20 |
21 | As part of the staged release process, we have also released the RN101 model, as well as RN50x4, a RN50 scaled up 4x according to the [EfficientNet](https://arxiv.org/abs/1905.11946) scaling rule. In July 2021, we additionally released the RN50x16 and ViT-B/16 models.
22 |
23 | Please see the paper linked below for further details about their specification.
24 |
25 | ### Documents
26 |
27 | - [Blog Post](https://openai.com/blog/clip/)
28 | - [CLIP Paper](https://arxiv.org/abs/2103.00020)
29 |
30 |
31 |
32 | ## Model Use
33 |
34 | ### Intended Use
35 |
36 | The model is intended as a research output for research communities. We hope that this model will enable researchers to better understand and explore zero-shot, arbitrary image classification. We also hope it can be used for interdisciplinary studies of the potential impact of such models - the CLIP paper includes a discussion of potential downstream impacts to provide an example for this sort of analysis.
37 |
38 | #### Primary intended uses
39 |
40 | The primary intended users of these models are AI researchers.
41 |
42 | We primarily imagine the model will be used by researchers to better understand robustness, generalization, and other capabilities, biases, and constraints of computer vision models.
43 |
44 | ### Out-of-Scope Use Cases
45 |
46 | **Any** deployed use case of the model - whether commercial or not - is currently out of scope. Non-deployed use cases such as image search in a constrained environment, are also not recommended unless there is thorough in-domain testing of the model with a specific, fixed class taxonomy. This is because our safety assessment demonstrated a high need for task specific testing especially given the variability of CLIP’s performance with different class taxonomies. This makes untested and unconstrained deployment of the model in any use case currently potentially harmful.
47 |
48 | Certain use cases which would fall under the domain of surveillance and facial recognition are always out-of-scope regardless of performance of the model. This is because the use of artificial intelligence for tasks such as these can be premature currently given the lack of testing norms and checks to ensure its fair use.
49 |
50 | Since the model has not been purposefully trained in or evaluated on any languages other than English, its use should be limited to English language use cases.
51 |
52 |
53 |
54 | ## Data
55 |
56 | The model was trained on publicly available image-caption data. This was done through a combination of crawling a handful of websites and using commonly-used pre-existing image datasets such as [YFCC100M](http://projects.dfki.uni-kl.de/yfcc100m/). A large portion of the data comes from our crawling of the internet. This means that the data is more representative of people and societies most connected to the internet which tend to skew towards more developed nations, and younger, male users.
57 |
58 | ### Data Mission Statement
59 |
60 | Our goal with building this dataset was to test out robustness and generalizability in computer vision tasks. As a result, the focus was on gathering large quantities of data from different publicly-available internet data sources. The data was gathered in a mostly non-interventionist manner. However, we only crawled websites that had policies against excessively violent and adult images and allowed us to filter out such content. We do not intend for this dataset to be used as the basis for any commercial or deployed model and will not be releasing the dataset.
61 |
62 |
63 |
64 | ## Performance and Limitations
65 |
66 | ### Performance
67 |
68 | We have evaluated the performance of CLIP on a wide range of benchmarks across a variety of computer vision datasets, ranging from OCR to texture recognition to fine-grained classification. The paper describes model performance on the following datasets:
69 |
70 | - Food101
71 | - CIFAR10
72 | - CIFAR100
73 | - Birdsnap
74 | - SUN397
75 | - Stanford Cars
76 | - FGVC Aircraft
77 | - VOC2007
78 | - DTD
79 | - Oxford-IIIT Pet dataset
80 | - Caltech101
81 | - Flowers102
82 | - MNIST
83 | - SVHN
84 | - IIIT5K
85 | - Hateful Memes
86 | - SST-2
87 | - UCF101
88 | - Kinetics700
89 | - Country211
90 | - CLEVR Counting
91 | - KITTI Distance
92 | - STL-10
93 | - RareAct
94 | - Flickr30
95 | - MSCOCO
96 | - ImageNet
97 | - ImageNet-A
98 | - ImageNet-R
99 | - ImageNet Sketch
100 | - ObjectNet (ImageNet Overlap)
101 | - Youtube-BB
102 | - ImageNet-Vid
103 |
104 | ## Limitations
105 |
106 | CLIP and our analysis of it have a number of limitations. CLIP currently struggles with certain tasks such as fine-grained classification and counting objects. CLIP also poses issues with regard to fairness and bias, which we discuss in the paper and briefly in the next section. Additionally, our approach to testing CLIP has an important limitation: in many cases we have used linear probes to evaluate the performance of CLIP, and there is evidence suggesting that linear probes can underestimate model performance.
107 |
108 | ### Bias and Fairness
109 |
110 | We find that the performance of CLIP - and the specific biases it exhibits - can depend significantly on class design and the choices one makes for categories to include and exclude. We tested the risk of certain kinds of denigration with CLIP by classifying images of people from [Fairface](https://arxiv.org/abs/1908.04913) into crime-related and non-human animal categories. We found significant disparities with respect to race and gender. Additionally, we found that these disparities could shift based on how the classes were constructed. (Details captured in the Broader Impacts Section in the paper).
111 |
112 | We also tested the performance of CLIP on gender, race and age classification using the Fairface dataset (We default to using race categories as they are constructed in the Fairface dataset.) in order to assess quality of performance across different demographics. We found accuracy >96% across all races for gender classification with ‘Middle Eastern’ having the highest accuracy (98.4%) and ‘White’ having the lowest (96.5%). Additionally, CLIP averaged ~93% for racial classification and ~63% for age classification. Our use of evaluations to test for gender, race and age classification as well as denigration harms is simply to evaluate performance of the model across people and surface potential risks and not to demonstrate an endorsement/enthusiasm for such tasks.
113 |
114 |
115 |
116 | ## Feedback
117 |
118 | ### Where to send questions or comments about the model
119 |
120 | Please use [this Google Form](https://forms.gle/Uv7afRH5dvY34ZEs9)
121 |
--------------------------------------------------------------------------------
/CLIP/requirements.txt:
--------------------------------------------------------------------------------
1 | ftfy
2 | regex
3 | tqdm
4 | torch
5 | torchvision
6 |
--------------------------------------------------------------------------------
/CLIP/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import pkg_resources
4 | from setuptools import setup, find_packages
5 |
6 | setup(
7 | name="clip",
8 | py_modules=["clip"],
9 | version="1.0",
10 | description="",
11 | author="OpenAI",
12 | packages=find_packages(exclude=["tests*"]),
13 | install_requires=[
14 | str(r)
15 | for r in pkg_resources.parse_requirements(
16 | open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
17 | )
18 | ],
19 | include_package_data=True,
20 | extras_require={"dev": ["pytest"]},
21 | )
22 |
--------------------------------------------------------------------------------
/CLIP/tests/test_consistency.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 | from PIL import Image
5 |
6 | import clip
7 |
8 |
9 | @pytest.mark.parametrize("model_name", clip.available_models())
10 | def test_consistency(model_name):
11 | device = "cpu"
12 | jit_model, transform = clip.load(model_name, device=device, jit=True)
13 | py_model, _ = clip.load(model_name, device=device, jit=False)
14 |
15 | image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device)
16 | text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)
17 |
18 | with torch.no_grad():
19 | logits_per_image, _ = jit_model(image, text)
20 | jit_probs = logits_per_image.softmax(dim=-1).cpu().numpy()
21 |
22 | logits_per_image, _ = py_model(image, text)
23 | py_probs = logits_per_image.softmax(dim=-1).cpu().numpy()
24 |
25 | assert np.allclose(jit_probs, py_probs, atol=0.01, rtol=0.1)
26 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Columbia Artificial Intelligence and Robotics Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Semantic Abstraction: Open-World 3D Scene Understanding from 2D Vision-Language Models
2 |
3 |
4 | [Huy Ha](https://www.cs.columbia.edu/~huy/), [Shuran Song](https://www.cs.columbia.edu/~shurans/)
5 |
6 | Columbia University, New York, NY, United States
7 |
8 | [Conference on Robot Learning 2022](https://corl2022.org/)
9 |
10 | [Project Page](https://semantic-abstraction.cs.columbia.edu/) | [Arxiv](https://arxiv.org/abs/2207.11514)
11 |
12 | [](https://huggingface.co/spaces/huy-ha/semabs-relevancy)
13 |
14 |
15 | ![teaser](assets/teaser.gif)
16 |
17 | Our approach, Semantic Abstraction, unlocks 2D VLMs' capabilities for 3D scene understanding. Trained with a limited synthetic dataset, our model generalizes to unseen classes in a novel domain (i.e., the real world), even for small objects like “rubiks cube”, long-tail concepts like “harry potter”, and hidden objects like the “used N95s in the garbage bin”. Unseen classes are bolded.
18 |
19 |
20 |
21 |
22 |
23 |
24 | This repository contains code for generating relevancies, training, and evaluating [Semantic Abstraction](https://semantic-abstraction.cs.columbia.edu/).
25 | It has been tested on Ubuntu 18.04 and 20.04, NVIDIA GTX 1080, NVIDIA RTX A6000, NVIDIA GeForce RTX 3080, and NVIDIA GeForce RTX 3090.
26 |
27 | If you find this codebase useful, consider citing:
28 |
29 | ```bibtex
30 | @inproceedings{ha2022semabs,
31 | title={Semantic Abstraction: Open-World 3{D} Scene Understanding from 2{D} Vision-Language Models},
32 | author = {Ha, Huy and Song, Shuran},
33 | booktitle={Proceedings of the 2022 Conference on Robot Learning},
34 | year={2022}
35 | }
36 | ```
37 |
38 | If you have any questions, please contact [me](https://www.cs.columbia.edu/~huy/) at `huy [at] cs [dot] columbia [dot] edu`.
39 |
40 | **Table of Contents**
41 |
42 | - [Setup](#setup)
43 | - [Environment](#environment)
44 | - [Models](#models)
45 | - [Dataset (Optional)](#dataset-optional)
46 | - [Multi-scale Relevancy Extractor](#multi-scale-relevancy-extractor)
47 | - [Evaluation](#evaluation)
48 | - [Summarize](#summarize)
49 | - [Run inference](#run-inference)
50 | - [Visualization](#visualization)
51 | - [Training](#training)
52 | - [OVSSC](#ovssc)
53 | - [VOOL](#vool)
54 | - [Codebase Walkthrough](#codebase-walkthrough)
55 | - [Acknowledgements](#acknowledgements)
56 |
57 | # Setup
58 |
59 | ## Environment
60 |
61 | Create the conda environment
62 |
63 | ```sh
64 | conda env create -f semabs.yml
65 | ```
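
Then activate it before running any of the commands below (assuming the environment defined in `semabs.yml` is named `semabs`)

```sh
conda activate semabs
```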
66 |
67 | ## Models
68 |
69 | Download the model checkpoints (~3.5GB) by running this command at the root of the repo
70 | ```sh
71 | wget https://semantic-abstraction.cs.columbia.edu/downloads/models.tar.lz4 -O - | tar --use-compress-program=lz4 -xf - -C ./
72 | ```
73 |
74 | You should have the following directory structure
75 |
76 | ```console
77 | ❯ tree /path/to/semantic-abstraction/models
78 | /path/to/semantic-abstraction/models
79 | ├── chefer_et_al
80 | │ ├── ovssc
81 | │ │ ├── args.pkl
82 | │ │ ├── ovssc.pth
83 | │ │ └── ovssc_eval_stats.pkl
84 | │ └── vool
85 | │ ├── args.pkl
86 | │ ├── vool.pth
87 | │ └── vool_eval_stats.pkl
88 | ├── clipspatial
89 | │ └── vool
90 | │ ├── args.pkl
91 | │ ├── vool.pth
92 | │ └── vool_eval_stats.pkl
93 | ├── ours
94 | │ ├── ovssc
95 | │ │ ├── args.pkl
96 | │ │ ├── ovssc.pth
97 | │ │ └── ovssc_eval_stats.pkl
98 | │ └── vool
99 | │ ├── args.pkl
100 | │ ├── vool.pth
101 | │ └── vool_eval_stats.pkl
102 | └── semaware
103 | ├── ovssc
104 | │ ├── args.pkl
105 | │ ├── ovssc.pth
106 | │ └── ovssc_eval_stats.pkl
107 | └── vool
108 | ├── args.pkl
109 | ├── vool.pth
110 | └── vool_eval_stats.pkl
111 |
112 | 11 directories, 21 files
113 | ```
114 |
115 |
116 | ## Dataset (Optional)
117 |
118 | To run the evaluation inference or training, you will need the dataset.
119 |
120 | Download the dataset (~269GB) by running the following at the root of the repo
121 | ```sh
122 | wget https://semantic-abstraction.cs.columbia.edu/downloads/dataset.tar.lz4 -O - | tar --use-compress-program=lz4 -xf - -C ./
123 | ```
124 |
125 |
126 | We also provide a preprocessed NYU dataset (~53GB) for training and evaluation
127 | ```sh
128 | wget https://semantic-abstraction.cs.columbia.edu/downloads/nyu_ovssc.tar.lz4 -O - | tar --use-compress-program=lz4 -xf - -C ./
129 | ```
130 |
131 |
132 | # Multi-scale Relevancy Extractor
133 |
134 | Play around with the multi-scale relevancy extractor on [](https://huggingface.co/spaces/huy-ha/semabs-relevancy).
135 |
136 | To run the multi-scale relevancy extractor locally on a GPU with the provided Matterport image
137 |
138 | ```sh
139 | python generate_relevancy.py image
140 | ```
141 |
142 | which will output
143 |
144 | 
145 |
146 | Try passing in your own image!
147 |
148 | ```console
149 | ❯ python generate_relevancy.py image --help
150 |
151 | Usage: generate_relevancy.py image [OPTIONS] [FILE_PATH]
152 |
153 | Generates a multi-scale relevancy for image at `file_path`.
154 |
155 | ╭─ Arguments ───────────────────────────────────────────────────────────────────────────────────╮
156 | │ file_path [FILE_PATH] path of image file [default: matterport.png] │
157 | ╰───────────────────────────────────────────────────────────────────────────────────────────────╯
158 | ╭─ Options ─────────────────────────────────────────────────────────────────────────────────────╮
159 | │ --labels TEXT list of object categories (e.g.: "nintendo switch") │
160 | │ [default: basketball jersey, nintendo switch, television, ping pong │
161 | │ table, vase, fireplace, abstract painting of a vespa, carpet, wall] │
162 | │ --prompts TEXT prompt template to use with CLIP. │
163 | │ [default: a photograph of a {} in a home.] │
164 | │ --help Show this message and exit. │
165 | ╰───────────────────────────────────────────────────────────────────────────────────────────────╯
166 | ```
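
For example, assuming a typer-style CLI where list options are repeated once per value, querying your own image (`my_room.png` and the labels below are placeholders) might look like

```sh
python generate_relevancy.py image my_room.png \
  --labels "nintendo switch" --labels "television" --labels "vase" \
  --prompts "a photograph of a {} in a home."
```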
167 |
168 | # Evaluation
169 |
170 | ## Summarize
171 |
172 | The evaluation result dataframes (`*_eval_stats*.pkl`) are provided along with the model checkpoints.
173 | To summarize them in a table
174 |
175 | ```console
176 | ❯ python summarize.py
177 | OVSSC THOR
178 | ╷ ╷ ╷ ╷
179 | Approach │ Novel Room │ Novel Visual │ Novel Vocab │ Novel Class
180 | ═════════════════════════╪════════════╪══════════════╪═════════════╪═════════════
181 | Semantic Aware │ 32.2 │ 31.9 │ 20.2 │ 0.0
182 | SemAbs + [Chefer et al] │ 26.6 │ 24.3 │ 17.8 │ 12.2
183 | ─────────────────────────┼────────────┼──────────────┼─────────────┼─────────────
184 | Ours │ 40.1 │ 36.4 │ 33.4 │ 37.9
185 | ╵ ╵ ╵ ╵
186 | FULL VOOL THOR
187 | ╷ ╷ ╷ ╷ ╷
188 | Approach │ Spatial Relation │ Novel Room │ Novel Visual │ Novel Vocab │ Novel Class
189 | ═════════════════════════╪══════════════════╪════════════╪══════════════╪═════════════╪═════════════
190 | Semantic Aware │ in │ 15.0 │ 14.7 │ 7.6 │ 1.8
191 | │ on │ 9.0 │ 8.9 │ 11.4 │ 4.5
192 | │ on the left of │ 11.2 │ 11.1 │ 14.4 │ 4.0
193 | │ behind │ 12.8 │ 12.6 │ 14.1 │ 2.2
194 | │ on the right of │ 13.1 │ 13.0 │ 11.5 │ 3.4
195 | │ in front of │ 11.2 │ 11.1 │ 9.3 │ 2.2
196 | │ mean │ 12.1 │ 11.9 │ 11.4 │ 3.0
197 | ─────────────────────────┼──────────────────┼────────────┼──────────────┼─────────────┼─────────────
198 | ClipSpatial │ in │ 9.6 │ 8.6 │ 7.1 │ 3.3
199 | │ on │ 14.1 │ 12.1 │ 18.5 │ 20.0
200 | │ on the left of │ 11.0 │ 9.4 │ 14.2 │ 13.2
201 | │ behind │ 11.3 │ 9.9 │ 14.1 │ 8.9
202 | │ on the right of │ 12.1 │ 10.6 │ 16.2 │ 11.5
203 | │ in front of │ 12.3 │ 10.3 │ 15.7 │ 9.9
204 | │ mean │ 11.7 │ 10.1 │ 14.3 │ 11.2
205 | ─────────────────────────┼──────────────────┼────────────┼──────────────┼─────────────┼─────────────
206 | SemAbs + [Chefer et al] │ in │ 11.8 │ 11.1 │ 5.7 │ 2.1
207 | │ on │ 7.0 │ 6.7 │ 11.3 │ 7.1
208 | │ on the left of │ 9.5 │ 9.3 │ 13.7 │ 4.9
209 | │ behind │ 7.6 │ 7.6 │ 10.6 │ 2.5
210 | │ on the right of │ 9.2 │ 9.2 │ 11.0 │ 3.9
211 | │ in front of │ 9.4 │ 9.0 │ 12.0 │ 3.3
212 | │ mean │ 9.1 │ 8.8 │ 10.7 │ 4.0
213 | ─────────────────────────┼──────────────────┼────────────┼──────────────┼─────────────┼─────────────
214 | Ours │ in │ 17.8 │ 17.5 │ 8.5 │ 7.3
215 | │ on │ 21.0 │ 18.0 │ 27.2 │ 28.1
216 | │ on the left of │ 22.0 │ 20.3 │ 27.7 │ 25.1
217 | │ behind │ 19.9 │ 18.0 │ 22.8 │ 16.7
218 | │ on the right of │ 23.2 │ 21.7 │ 28.1 │ 22.1
219 | │ in front of │ 21.5 │ 19.4 │ 25.8 │ 19.1
220 | │ mean │ 20.9 │ 19.2 │ 23.4 │ 19.7
221 |
222 | OVSSC NYU
223 | ╷ ╷ ╷ ╷ ╷ ╷ ╷ ╷ ╷ ╷ ╷ ╷
224 | Approach │ Ceiling │ Floor │ Wall │ Window │ Chair │ Bed │ Sofa │ Table │ Tvs │ Furn │ Objs │ Mean
225 | ═══════════════════╪═════════╪═══════╪══════╪════════╪═══════╪══════╪══════╪═══════╪══════╪══════╪══════╪══════
226 | Ours (Supervised) │ 22.6 │ 46.1 │ 33.9 │ 35.9 │ 23.9 │ 55.9 │ 37.9 │ 19.7 │ 30.8 │ 39.8 │ 27.7 │ 34.0
227 | ───────────────────┼─────────┼───────┼──────┼────────┼───────┼──────┼──────┼───────┼──────┼──────┼──────┼──────
228 | Ours (Zeroshot) │ 13.7 │ 17.3 │ 13.5 │ 25.2 │ 15.2 │ 33.3 │ 31.5 │ 12.0 │ 23.7 │ 25.6 │ 19.9 │ 21.0
229 | ╵ ╵ ╵ ╵ ╵ ╵ ╵ ╵ ╵ ╵ ╵ ╵
230 | ```
231 |
232 | ## Run inference
233 |
234 | To run inference, make sure the dataset is [downloaded](#dataset-optional).
235 |
236 | Then, regenerate the evaluation result dataframes by running the evaluation script.
237 |
238 | For OVSSC
239 |
240 | ```sh
241 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=1 eval.py --task ovssc --file_path dataset/ --gpus 0 --load models/ours/ovssc/ovssc.pth
242 | ```
243 |
244 | Inference can be sped up by using more than one GPU.
245 | For instance, to use GPU 0, 1, 2, and 3, use `--nproc_per_node=4` and `--gpus 0 1 2 3`.
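
For example, the OVSSC evaluation command above with four GPUs becomes

```sh
python -m torch.distributed.run --nnodes=1 --nproc_per_node=4 eval.py --task ovssc --file_path dataset/ --gpus 0 1 2 3 --load models/ours/ovssc/ovssc.pth
```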
246 |
247 | For OVSSC on NYU
248 | ```sh
249 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=1 eval.py --task ovssc --file_path nyu_ovssc/ --gpus 0 --load models/ours/ovssc/ovssc.pth
250 | ```
251 |
252 | Similarly, for VOOL
253 |
254 | ```sh
255 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=1 eval.py --task vool --file_path dataset/ --gpus 0 --load models/ours/vool/vool.pth
256 | ```
257 |
258 | ## Visualization
259 |
260 | The `visualize.py` script takes as input the scene pickle file and the network checkpoint.
261 |
262 | The pickle file should be a dictionary with the following keys and types
263 | ```python
264 | rgb: np.ndarray # shape h x w x 3
265 | depth: np.ndarray # shape h x w
266 | img_shape: Tuple[int, int]
267 | cam_intr: np.ndarray # shape 4 x 4
268 | cam_extr: np.ndarray # shape 4 x 4
269 | ovssc_obj_classes: List[str]
270 | descriptions: List[List[str]]
271 | ```
272 | After being loaded, `rgb` and `depth` will be resized to `img_shape`, which matches the image dimensions in `cam_intr`.
273 | Each element of `descriptions` is a list containing the target object name, the spatial preposition, and the reference object name, in that order.
274 | We provide some example scene pickle files from [Habitat Matterport 3D](https://aihabitat.org/datasets/hm3d/) and [ARKitScenes](https://github.com/apple/ARKitScenes) in `scene_files/`.
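
As a reference for preparing your own scenes, here is a minimal sketch that assembles and saves such a pickle; every value below is a placeholder, not taken from our data.

```python
import pickle

import numpy as np

h, w = 480, 640  # placeholder capture resolution
scene = {
    "rgb": np.zeros((h, w, 3), dtype=np.uint8),    # your color image
    "depth": np.ones((h, w), dtype=np.float32),    # depth map
    "img_shape": (h, w),                           # resolution matching cam_intr
    "cam_intr": np.eye(4),                         # 4 x 4 camera intrinsics
    "cam_extr": np.eye(4),                         # 4 x 4 camera-to-world extrinsics
    "ovssc_obj_classes": ["sofa", "table", "hair dryer"],
    # each description: [target object, spatial preposition, reference object]
    "descriptions": [["hair dryer", "behind", "table"]],
}
with open("scene_files/my_scene.pkl", "wb") as f:
    pickle.dump(scene, f)
```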
275 |
276 | Visualizing OVSSC generates a `.mp4` video of the completion, along with `.obj` meshes, while visualizing VOOL generates `.mp4` videos along with `.ply` pointclouds for each description.
277 |
278 | For instance, running
279 | ```sh
280 | # OVSSC
281 | python visualize.py ovssc-inference scene_files/arkit_vn_poster.pkl models/ours/ovssc/ovssc.pth
282 | python visualize.py ovssc-visualize visualization/arkit_vn_poster
283 | # VOOL
284 | python visualize.py vool-inference scene_files/arkit_vn_poster.pkl models/ours/vool/vool.pth
285 | python visualize.py vool-visualize visualization/arkit_vn_poster
286 | ```
287 | will output to `visualization/arkit_vn_poster`, including the following relevancies
288 |
289 | 
290 |
291 | and the following VOOL localization for `the hair dryer with its wires all tangled up behind the table legs` and OVSSC completion:
292 |
293 | | RGB | Localization | Completion |
294 | | -------------------------------- | --------------------------------------- | ------------------------------------- |
295 | |  |  |  |
296 |
297 |
298 | While these visualizations are sufficient for debugging, I recommend using the `ply` and `obj` files to render in Blender.
299 |
300 |
301 | | Legend | Localization |
302 | | ---------------------------------------------------- | --------------------------------------------- |
303 | |  |  |
304 |
305 |
306 | # Training
307 |
308 | To train the models, make sure the dataset is [downloaded](#dataset-optional).
309 |
310 | ## OVSSC
311 |
312 | To retrain our model
313 |
314 | ```sh
315 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=8 train_ovssc.py --file_path dataset/ --log models/new-ours --gpus 0 1 2 3 4 5 6 7 --epochs 200 --saliency_config ours
316 | ```
317 |
318 | To retrain the semantic aware model
319 |
320 | ```sh
321 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=8 train_ovssc.py --file_path dataset/ --log models/new-ours --gpus 0 1 2 3 4 5 6 7 --epochs 200 --approach semantic_aware
322 | ```
323 |
324 | To retrain the semantic abstraction + [Chefer et al.](https://github.com/hila-chefer/Transformer-MM-Explainability) model
325 |
326 | ```sh
327 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=8 train_ovssc.py --file_path dataset/ --log models/new-ours --gpus 0 1 2 3 4 5 6 7 --epochs 200 --saliency_config chefer_et_al
328 | ```
329 |
330 | ## VOOL
331 |
332 | To retrain our model
333 |
334 | ```sh
335 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=8 train_vool.py --file_path dataset/ --log models/new-ours --gpus 0 1 2 3 4 5 6 7 --epochs 200 --saliency_config ours
336 | ```
337 |
338 | To retrain the semantic aware model
339 |
340 | ```sh
341 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=8 train_vool.py --file_path dataset/ --log models/new-ours --gpus 0 1 2 3 4 5 6 7 --epochs 200 --approach semantic_aware
342 | ```
343 |
344 | To retrain the CLIP-Spatial model
345 |
346 | ```sh
347 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=8 train_vool.py --file_path dataset/ --log models/new-ours --gpus 0 1 2 3 4 5 6 7 --epochs 200 --approach clip_spatial
348 | ```
349 |
350 | To retrain the semantic abstraction + [Chefer et al.](https://github.com/hila-chefer/Transformer-MM-Explainability) model
351 |
352 | ```sh
353 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=8 train_vool.py --file_path dataset/ --log models/new-ours --gpus 0 1 2 3 4 5 6 7 --epochs 200 --saliency_config chefer_et_al
354 | ```
355 |
356 |
357 | # Codebase Walkthrough
358 |
359 | Below, we provide a summary of the networks in this codebase, along with how and where our method, as described in the paper, is implemented. The links in the bullet points below link to specific lines in this codebase. We hope this code annotation helps clarify the network architecture and training procedures; a minimal data-flow sketch follows the list.
360 |
361 | - [`SemAbs3D`](net.py#L319): This class implements the SemAbs module. It contains two networks, a [`ResidualUNet3D`](unet3d.py) (i.e, $f_\mathrm{encode}$) and an [`ImplicitVolumetricDecoder`](net.py#L204) (i.e, $f_\mathrm{decoder}$).
362 | - [`SemanticAwareOVSSC`](net.py#L441): This class implements the SemAware baseline for the OVSSC task. It inherits directly from `SemAbs3D`, with two crucial differences: 1) it [takes RGB pointclouds as input](train_ovssc.py#L237) instead of [saliency pointclouds](utils.py#L94), and 2) it uses its [sampled feature pointclouds to point](net.py#L455) to text features of semantic classes (i.e., encoded using CLIP's text encoder). These two differences together mean it has to learn to recognize semantic classes from RGB inputs by itself, which leads to overfitting to the training semantic classes.
363 | - [`SemAbsVOOL`](net.py#L468): This class implements the SemAbs module for the VOOL task. In addition to the learnable parameters of `SemAbs3D`, it includes a [set of relational embeddings](net.py#L489), one for each spatial relation in the closed vocabulary.
364 | - [`SemanticAwareVOOL`](net.py#L581): This class implements the SemAware baseline for the VOOL task. Similar to `SemanticAwareOVSSC`, it [takes as input RGB pointclouds](train_vool.py#L222). However, specifically for the VOOL task, it uses the entire localization description ([encoded using CLIP's text encoder and learned spatial relations](net.py#L590)) to [point to regions within the scene](net.py#L606).
365 | - [`ClipSpatialVOOL`](net.py#L638): This class implements the CLIPSpatial baseline for the VOOL task. In contrast to other VOOL networks, it does not attempt to learn spatial relations or semantics. Instead, it completely relies on relevancy inputs from CLIP.
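
The data flow shared by these networks, heavily simplified: a relevancy/saliency voxel grid is encoded into a feature volume, which an implicit decoder samples at query points. The snippet below is only a sketch with stand-in layers and made-up shapes; see [`net.py`](net.py) for the actual `ResidualUNet3D` and `ImplicitVolumetricDecoder`.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: B scenes, a voxelized saliency grid, and M query points.
B, C, D, H, W, M = 2, 1, 32, 32, 32, 1024
saliency_grid = torch.rand(B, C, D, H, W)    # relevancy-valued voxel grid (the input abstraction)
query_points = torch.rand(B, M, 3) * 2 - 1   # query coordinates in [-1, 1]^3

# f_encode: a 3D UNet (ResidualUNet3D in this codebase) produces a feature volume.
feat_dim = 16
f_encode = torch.nn.Conv3d(C, feat_dim, kernel_size=3, padding=1)  # stand-in for the UNet
feature_volume = f_encode(saliency_grid)     # B x feat_dim x D x H x W

# f_decoder: an implicit decoder trilinearly samples features at the query points
# and maps them to per-point logits (ImplicitVolumetricDecoder in this codebase).
sampled = F.grid_sample(
    feature_volume,
    query_points.view(B, M, 1, 1, 3),
    align_corners=True,
).reshape(B, feat_dim, M).permute(0, 2, 1)   # B x M x feat_dim
f_decoder = torch.nn.Linear(feat_dim, 1)     # stand-in for the MLP decoder
occupancy_logits = f_decoder(sampled).squeeze(-1)  # B x M completion logits
print(occupancy_logits.shape)
```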
366 |
367 |
368 | A few tips for training your semantic abstraction module:
369 | - We have observed that performing [small random transformations](dataset.py#L534) on input and output point clouds helps generalization significantly.
370 | - To account for positive/negative class balance when using the binary cross entropy loss (in both [OVSSC](train_ovssc.py#L147) and [VOOL](train_vool.py#L171)), we found that using [`--balance_spatial_sampling`](utils.py#L71) helps tremendously. This [biases the subsampling of query points](dataset.py#L607) such that as many positive points as possible are sampled without replacement to achieve a balanced batch (see the sketch after this list).
371 | - Remember to [rescale your relevancy values to a reasonable range](dataset.py#L1052)!
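
A minimal sketch of the kind of balanced query-point subsampling described above; the helper name and the 50/50 split are illustrative assumptions, not the exact logic in `dataset.py`.

```python
import numpy as np

def balanced_subsample(occupancy: np.ndarray, num_queries: int) -> np.ndarray:
    """Pick query-point indices so positives are as represented as possible.

    occupancy: boolean array over candidate query points (True = positive).
    Returns indices into the candidate points, sampled without replacement.
    """
    pos_idx = np.flatnonzero(occupancy)
    neg_idx = np.flatnonzero(~occupancy)
    # Take up to half the budget from positives (there are usually far fewer of them),
    # then fill the rest of the budget with negatives.
    num_pos = min(len(pos_idx), num_queries // 2)
    num_neg = min(len(neg_idx), num_queries - num_pos)
    chosen = np.concatenate([
        np.random.choice(pos_idx, size=num_pos, replace=False),
        np.random.choice(neg_idx, size=num_neg, replace=False),
    ])
    np.random.shuffle(chosen)
    return chosen

# Example: 4096 query points from a mostly-empty scene.
occ = np.random.rand(100_000) < 0.01
idx = balanced_subsample(occ, num_queries=4096)
print(idx.shape, occ[idx].mean())
```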
372 |
373 |
374 | # Acknowledgements
375 |
376 | We would like to thank Samir Yitzhak Gadre, Cheng Chi and Zhenjia Xu for their helpful feedback and fruitful discussions.
377 | This work was supported in part by NSF Award #2143601, #2132519, JP Morgan Faculty Research Award, and Google Research Award.
378 | The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies, either expressed or implied, of the sponsors.
379 |
380 | Code:
381 | - The relevancy extraction code was modified from [Chefer et al.'s codebase](https://github.com/hila-chefer/Transformer-MM-Explainability) and [CLIP on Wheels' codebase](https://cow.cs.columbia.edu/).
382 | - The [3D U-Net](unet3d.py) definition was taken from [Adrian Wolny](https://github.com/wolny/pytorch-3dunet/).
383 | - Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio).
384 | - The [LAMB](https://arxiv.org/pdf/1904.00962.pdf) PyTorch implementation is from the [Attention-driven Robotic Manipulation (ARM)](https://github.com/stepjam/ARM#running-experiments).
385 |
386 | ![matterport](assets/matterport.png)
--------------------------------------------------------------------------------
/arm/LICENSE:
--------------------------------------------------------------------------------
1 | Q-attention: Enabling Efficient Learning for Vision-based Robotic Manipulation
2 |
3 | LICENCE AGREEMENT
4 |
5 | WE (Imperial College of Science, Technology and Medicine, (“Imperial College London”))
6 | ARE WILLING TO LICENSE THIS SOFTWARE TO YOU (a licensee “You”) ONLY ON THE
7 | CONDITION THAT YOU ACCEPT ALL OF THE TERMS CONTAINED IN THE
8 | FOLLOWING AGREEMENT. PLEASE READ THE AGREEMENT CAREFULLY BEFORE
9 | DOWNLOADING THE SOFTWARE. BY EXERCISING THE OPTION TO DOWNLOAD
10 | THE SOFTWARE YOU AGREE TO BE BOUND BY THE TERMS OF THE AGREEMENT.
11 | SOFTWARE LICENCE AGREEMENT (EXCLUDING BSD COMPONENTS)
12 |
13 | 1.This Agreement pertains to a worldwide, non-exclusive, temporary, fully paid-up, royalty
14 | free, non-transferable, non-sub- licensable licence (the “Licence”) to use the Q-attention
15 | source code, including any modification, part or derivative (the “Software”).
16 | Ownership and Licence. Your rights to use and download the Software onto your computer,
17 | and all other copies that You are authorised to make, are specified in this Agreement.
18 | However, we (or our licensors) retain all rights, including but not limited to all copyright and
19 | other intellectual property rights anywhere in the world, in the Software not expressly
20 | granted to You in this Agreement.
21 |
22 | 2. Permitted use of the Licence:
23 |
24 | (a) You may download and install the Software onto one computer or server for use in
25 | accordance with Clause 2(b) of this Agreement provided that You ensure that the Software is
26 | not accessible by other users unless they have themselves accepted the terms of this licence
27 | agreement.
28 |
29 | (b) You may use the Software solely for non-commercial, internal or academic research
30 | purposes and only in accordance with the terms of this Agreement. You may not use the
31 | Software for commercial purposes, including but not limited to (1) integration of all or part of
32 | the source code or the Software into a product for sale or licence by or on behalf of You to
33 | third parties or (2) use of the Software or any derivative of it for research to develop software
34 | products for sale or licence to a third party or (3) use of the Software or any derivative of it
35 | for research to develop non-software products for sale or licence to a third party, or (4) use of
36 | the Software to provide any service to an external organisation for which payment is
37 | received.
38 |
39 | Should You wish to use the Software for commercial purposes, You shall
40 | email researchcontracts.engineering@imperial.ac.uk .
41 |
42 | (c) Right to Copy. You may copy the Software for back-up and archival purposes, provided
43 | that each copy is kept in your possession and provided You reproduce our copyright notice
44 | (set out in Schedule 1) on each copy.
45 |
46 | (d) Transfer and sub-licensing. You may not rent, lend, or lease the Software and You may
47 | not transmit, transfer or sub-license this licence to use the Software or any of your rights or
48 | obligations under this Agreement to another party.
49 |
50 | (e) Identity of Licensee. The licence granted herein is personal to You. You shall not permit
51 | any third party to access, modify or otherwise use the Software nor shall You access modify
52 | or otherwise use the Software on behalf of any third party. If You wish to obtain a licence for
53 | mutiple users or a site licence for the Software please contact us
54 | at researchcontracts.engineering@imperial.ac.uk .
55 |
56 | (f) Publications and presentations. You may make public, results or data obtained from,
57 | dependent on or arising from research carried out using the Software, provided that any such
58 | presentation or publication identifies the Software as the source of the results or the data,
59 | including the Copyright Notice given in each element of the Software, and stating that the
60 | Software has been made available for use by You under licence from Imperial College London
61 | and You provide a copy of any such publication to Imperial College London.
62 |
63 | 3. Prohibited Uses. You may not, without written permission from us
64 | at researchcontracts.engineering@imperial.ac.uk :
65 |
66 | (a) Use, copy, modify, merge, or transfer copies of the Software or any documentation
67 | provided by us which relates to the Software except as provided in this Agreement;
68 |
69 | (b) Use any back-up or archival copies of the Software (or allow anyone else to use such
70 | copies) for any purpose other than to replace the original copy in the event it is destroyed or
71 | becomes defective; or
72 |
73 | (c) Disassemble, decompile or "unlock", reverse translate, or in any manner decode the
74 | Software for any reason.
75 |
76 | 4. Warranty Disclaimer
77 |
78 | (a) Disclaimer. The Software has been developed for research purposes only. You
79 | acknowledge that we are providing the Software to You under this licence agreement free of
80 | charge and on condition that the disclaimer set out below shall apply. We do not represent or
81 | warrant that the Software as to: (i) the quality, accuracy or reliability of the Software; (ii) the
82 | suitability of the Software for any particular use or for use under any specific conditions; and
83 | (iii) whether use of the Software will infringe third-party rights.
84 | You acknowledge that You have reviewed and evaluated the Software to determine that it
85 | meets your needs and that You assume all responsibility and liability for determining the
86 | suitability of the Software as fit for your particular purposes and requirements. Subject to
87 | Clause 4(b), we exclude and expressly disclaim all express and implied representations,
88 | warranties, conditions and terms not stated herein (including the implied conditions or
89 | warranties of satisfactory quality, merchantable quality, merchantability and fitness for
90 | purpose).
91 |
92 | (b) Savings. Some jurisdictions may imply warranties, conditions or terms or impose
93 | obligations upon us which cannot, in whole or in part, be excluded, restricted or modified or
94 | otherwise do not allow the exclusion of implied warranties, conditions or terms, in which
95 | case the above warranty disclaimer and exclusion will only apply to You to the extent
96 | permitted in the relevant jurisdiction and does not in any event exclude any implied
97 | warranties, conditions or terms which may not under applicable law be excluded.
98 |
99 | (c) Imperial College London disclaims all responsibility for the use which is made of the
100 | Software and any liability for the outcomes arising from using the Software.
101 |
102 | 5. Limitation of Liability
103 |
104 | (a) You acknowledge that we are providing the Software to You under this licence agreement
105 | free of charge and on condition that the limitation of liability set out below shall apply.
106 | Accordingly, subject to Clause 5(b), we exclude all liability whether in contract, tort,
107 | negligence or otherwise, in respect of the Software and/or any related documentation
108 | provided to You by us including, but not limited to, liability for loss or corruption of data,
109 | loss of contracts, loss of income, loss of profits, loss of cover and any consequential or indirect
110 | loss or damage of any kind arising out of or in connection with this licence agreement,
111 | however caused. This exclusion shall apply even if we have been advised of the possibility of
112 | such loss or damage.
113 |
114 | (b) You agree to indemnify Imperial College London and hold it harmless from and against
115 | any and all claims, damages and liabilities asserted by third parties (including claims for
116 | negligence) which arise directly or indirectly from the use of the Software or any derivative
117 | of it or the sale of any products based on the Software. You undertake to make no liability
118 | claim against any employee, student, agent or appointee of Imperial College London, in
119 | connection with this Licence or the Software.
120 |
121 | (c) Nothing in this Agreement shall have the effect of excluding or limiting our statutory
122 | liability.
123 |
124 | (d) Some jurisdictions do not allow these limitations or exclusions either wholly or in part,
125 | and, to that extent, they may not apply to you. Nothing in this licence agreement will affect
126 | your statutory rights or other relevant statutory provisions which cannot be excluded,
127 | restricted or modified, and its terms and conditions must be read and construed subject to any
128 | such statutory rights and/or provisions.
129 |
130 | 6. Confidentiality. You agree not to disclose any confidential information provided to You by
131 | us pursuant to this Agreement to any third party without our prior written consent. The
132 | obligations in this Clause 6 shall survive the termination of this Agreement for any reason.
133 |
134 | 7. Termination.
135 |
136 | (a) We may terminate this licence agreement and your right to use the Software at any time
137 | with immediate effect upon written notice to You.
138 |
139 | (b) This licence agreement and your right to use the Software automatically terminate if You:
140 | (i) fail to comply with any provisions of this Agreement; or
141 | (ii) destroy the copies of the Software in your possession, or voluntarily return the Software
142 | to us.
143 |
144 | (c) Upon termination You will destroy all copies of the Software.
145 |
146 | (d) Otherwise, the restrictions on your rights to use the Software will expire 10 (ten) years
147 | after first use of the Software under this licence agreement.
148 |
149 | 8. Miscellaneous Provisions.
150 |
151 | (a) This Agreement will be governed by and construed in accordance with the substantive
152 | laws of England and Wales whose courts shall have exclusive jurisdiction over all disputes
153 | which may arise between us.
154 |
155 | (b) This is the entire agreement between us relating to the Software, and supersedes any prior
156 | purchase order, communications, advertising or representations concerning the Software.
157 |
158 | (c) No change or modification of this Agreement will be valid unless it is in writing, and is
159 | signed by us.
160 |
161 | (d) The unenforceability or invalidity of any part of this Agreement will not affect the
162 | enforceability or validity of the remaining parts.
163 |
164 | BSD Elements of the Software
165 |
166 | For BSD elements of the Software, the following terms shall apply:
167 |
168 | Copyright as indicated in the header of the individual element of the Software.
169 |
170 | All rights reserved.
171 |
172 | Redistribution and use in source and binary forms, with or without modification, are
173 | permitted provided that the following conditions are met:
174 |
175 | 1. Redistributions of source code must retain the above copyright notice, this list of
176 | conditions and the following disclaimer.
177 |
178 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of
179 | conditions and the following disclaimer in the documentation and/or other materials
180 | provided with the distribution.
181 |
182 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to
183 | endorse or promote products derived from this software without specific prior written
184 | permission.
185 |
186 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
187 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
188 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
189 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
190 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
191 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
192 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
193 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
194 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
195 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
196 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
--------------------------------------------------------------------------------
/arm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/arm/__init__.py
--------------------------------------------------------------------------------
/arm/optim/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 cybertronai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/arm/optim/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/arm/optim/__init__.py
--------------------------------------------------------------------------------
/arm/optim/lamb.py:
--------------------------------------------------------------------------------
1 | # From https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py
2 |
3 | """Lamb optimizer."""
4 |
5 | import collections
6 | import math
7 |
8 | import torch
9 | from torch.optim import Optimizer
10 |
11 |
12 | # def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
13 | # """Log a histogram of trust ratio scalars in across layers."""
14 | # results = collections.defaultdict(list)
15 | # for group in optimizer.param_groups:
16 | # for p in group['params']:
17 | # state = optimizer.state[p]
18 | # for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
19 | # if i in state:
20 | # results[i].append(state[i])
21 | #
22 | # for k, v in results.items():
23 | # event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)
24 |
25 |
26 | class Lamb(Optimizer):
27 | r"""Implements Lamb algorithm.
28 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
29 | Arguments:
30 | params (iterable): iterable of parameters to optimize or dicts defining
31 | parameter groups
32 | lr (float, optional): learning rate (default: 1e-3)
33 | betas (Tuple[float, float], optional): coefficients used for computing
34 | running averages of gradient and its square (default: (0.9, 0.999))
35 | eps (float, optional): term added to the denominator to improve
36 | numerical stability (default: 1e-8)
37 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
38 | adam (bool, optional): always use trust ratio = 1, which turns this into
39 | Adam. Useful for comparison purposes.
40 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
41 | https://arxiv.org/abs/1904.00962
42 | """
43 |
44 | def __init__(
45 | self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0, adam=False
46 | ):
47 | if not 0.0 <= lr:
48 | raise ValueError("Invalid learning rate: {}".format(lr))
49 | if not 0.0 <= eps:
50 | raise ValueError("Invalid epsilon value: {}".format(eps))
51 | if not 0.0 <= betas[0] < 1.0:
52 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
53 | if not 0.0 <= betas[1] < 1.0:
54 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
55 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
56 | self.adam = adam
57 | super(Lamb, self).__init__(params, defaults)
58 |
59 | def step(self, closure=None):
60 | """Performs a single optimization step.
61 | Arguments:
62 | closure (callable, optional): A closure that reevaluates the model
63 | and returns the loss.
64 | """
65 | loss = None
66 | if closure is not None:
67 | loss = closure()
68 |
69 | for group in self.param_groups:
70 | for p in group["params"]:
71 | if p.grad is None:
72 | continue
73 | grad = p.grad.data
74 | if grad.is_sparse:
75 | raise RuntimeError(
76 | "Lamb does not support sparse gradients, consider SparseAdam instad."
77 | )
78 |
79 | state = self.state[p]
80 |
81 | # State initialization
82 | if len(state) == 0:
83 | state["step"] = 0
84 | # Exponential moving average of gradient values
85 | state["exp_avg"] = torch.zeros_like(p.data)
86 | # Exponential moving average of squared gradient values
87 | state["exp_avg_sq"] = torch.zeros_like(p.data)
88 |
89 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
90 | beta1, beta2 = group["betas"]
91 |
92 | state["step"] += 1
93 |
94 | # Decay the first and second moment running average coefficient
95 | # m_t
96 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
97 | # v_t
98 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
99 |
100 | # Paper v3 does not use debiasing.
101 | # bias_correction1 = 1 - beta1 ** state['step']
102 | # bias_correction2 = 1 - beta2 ** state['step']
103 | # Apply bias to lr to avoid broadcast.
104 | step_size = group[
105 | "lr"
106 | ] # * math.sqrt(bias_correction2) / bias_correction1
107 |
108 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
109 |
110 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group["eps"])
111 | if group["weight_decay"] != 0:
112 | adam_step.add_(p.data, alpha=group["weight_decay"])
113 |
114 | adam_norm = adam_step.pow(2).sum().sqrt()
115 | if weight_norm == 0 or adam_norm == 0:
116 | trust_ratio = 1
117 | else:
118 | trust_ratio = weight_norm / adam_norm
119 | state["weight_norm"] = weight_norm
120 | state["adam_norm"] = adam_norm
121 | state["trust_ratio"] = trust_ratio
122 | if self.adam:
123 | trust_ratio = 1
124 |
125 | p.data.add_(adam_step, alpha=-step_size * trust_ratio)
126 |
127 | return loss
128 |
--------------------------------------------------------------------------------
/arm/utils.py:
--------------------------------------------------------------------------------
1 | # Adapted from: https://github.com/stepjam/ARM/blob/main/arm/utils.py
2 |
3 | import torch
4 | import numpy as np
5 | from scipy.spatial.transform import Rotation
6 |
7 | import pyrender
8 | import trimesh
9 | from pyrender.trackball import Trackball
10 |
11 |
12 | def normalize_quaternion(quat):
13 | return np.array(quat) / np.linalg.norm(quat, axis=-1, keepdims=True)
14 |
15 |
16 | def quaternion_to_discrete_euler(quaternion, resolution):
17 | euler = Rotation.from_quat(quaternion).as_euler("xyz", degrees=True) + 180
18 | assert np.min(euler) >= 0 and np.max(euler) <= 360
19 | disc = np.around((euler / resolution)).astype(int)
20 | disc[disc == int(360 / resolution)] = 0
21 | return disc
22 |
23 |
24 | def discrete_euler_to_quaternion(discrete_euler, resolution):
25 | euler = (discrete_euler * resolution) - 180
26 | return Rotation.from_euler("xyz", euler, degrees=True).as_quat()
27 |
28 |
29 | def point_to_voxel_index(
30 | point: np.ndarray, voxel_size: np.ndarray, coord_bounds: np.ndarray
31 | ):
32 | bb_mins = np.array(coord_bounds[0:3])
33 | bb_maxs = np.array(coord_bounds[3:])
34 | dims_m_one = np.array([voxel_size] * 3) - 1
35 | bb_ranges = bb_maxs - bb_mins
36 | res = bb_ranges / (np.array([voxel_size] * 3) + 1e-12)
37 | voxel_index = np.minimum(
38 | np.floor((point - bb_mins) / (res + 1e-12)).astype(np.int32), dims_m_one
39 | )
40 | return voxel_index
41 |
42 |
43 | def stack_on_channel(x):
44 | # expect (B, T, C, ...)
45 | return torch.cat(torch.split(x, 1, dim=1), dim=2).squeeze(1)
46 |
47 |
48 | def _compute_initial_camera_pose(scene):
49 | # Adapted from:
50 | # https://github.com/mmatl/pyrender/blob/master/pyrender/viewer.py#L1032
51 | centroid = scene.centroid
52 | scale = scene.scale
53 | # if scale == 0.0:
54 | # scale = DEFAULT_SCENE_SCALE
55 | scale = 4.0
56 | s2 = 1.0 / np.sqrt(2.0)
57 | cp = np.eye(4)
58 | cp[:3, :3] = np.array([[0.0, -s2, s2], [1.0, 0.0, 0.0], [0.0, s2, s2]])
59 | hfov = np.pi / 6.0
60 | dist = scale / (2.0 * np.tan(hfov))
61 | cp[:3, 3] = dist * np.array([1.0, 0.0, 1.0]) + centroid
62 | return cp
63 |
64 |
65 | def _from_trimesh_scene(trimesh_scene, bg_color=None, ambient_light=None):
66 | # convert trimesh geometries to pyrender geometries
67 | geometries = {
68 | name: pyrender.Mesh.from_trimesh(geom, smooth=False)
69 | for name, geom in trimesh_scene.geometry.items()
70 | }
71 | # create the pyrender scene object
72 | scene_pr = pyrender.Scene(bg_color=bg_color, ambient_light=ambient_light)
73 | # add every node with geometry to the pyrender scene
74 | for node in trimesh_scene.graph.nodes_geometry:
75 | pose, geom_name = trimesh_scene.graph[node]
76 | scene_pr.add(geometries[geom_name], pose=pose)
77 | return scene_pr
78 |
79 |
80 | def create_voxel_scene(
81 | voxel_grid: np.ndarray,
82 | q_attention: np.ndarray = None,
83 | highlight_coordinate: np.ndarray = None,
84 | highlight_gt_coordinate: np.ndarray = None,
85 | highlight_alpha: float = 1.0,
86 | voxel_size: float = 0.1,
87 | show_bb: bool = False,
88 | alpha: float = 0.5,
89 | ):
90 | _, d, h, w = voxel_grid.shape
91 | v = voxel_grid.transpose((1, 2, 3, 0))
92 | occupancy = v[:, :, :, -1] != 0
93 | alpha = np.expand_dims(np.full_like(occupancy, alpha, dtype=np.float32), -1)
94 | rgb = np.concatenate([(v[:, :, :, 3:6] + 1) / 2.0, alpha], axis=-1)
95 |
96 | if q_attention is not None:
97 | q = np.max(q_attention, 0)
98 | q = q / np.max(q)
99 | show_q = q > 0.75
100 | occupancy = (show_q + occupancy).astype(bool)
101 | q = np.expand_dims(q - 0.5, -1) # Max q can be is 0.9
102 | q_rgb = np.concatenate(
103 | [q, np.zeros_like(q), np.zeros_like(q), np.clip(q, 0, 1)], axis=-1
104 | )
105 | rgb = np.where(np.expand_dims(show_q, -1), q_rgb, rgb)
106 |
107 | if highlight_coordinate is not None:
108 | x, y, z = highlight_coordinate
109 | occupancy[x, y, z] = True
110 | rgb[x, y, z] = [1.0, 0.0, 0.0, highlight_alpha]
111 |
112 | if highlight_gt_coordinate is not None:
113 | x, y, z = highlight_gt_coordinate
114 | occupancy[x, y, z] = True
115 | rgb[x, y, z] = [0.0, 0.0, 1.0, highlight_alpha]
116 |
117 | transform = trimesh.transformations.scale_and_translate(
118 | scale=voxel_size, translate=(0.0, 0.0, 0.0)
119 | )
120 | trimesh_voxel_grid = trimesh.voxel.VoxelGrid(
121 | encoding=occupancy, transform=transform
122 | )
123 | geometry = trimesh_voxel_grid.as_boxes(colors=rgb)
124 | scene = trimesh.Scene()
125 | scene.add_geometry(geometry)
126 | if show_bb:
127 | assert d == h == w
128 | _create_bounding_box(scene, voxel_size, d)
129 | return scene
130 |
131 |
132 | def visualise_voxel(
133 | voxel_grid: np.ndarray,
134 | q_attention: np.ndarray = None,
135 | highlight_coordinate: np.ndarray = None,
136 | highlight_gt_coordinate: np.ndarray = None,
137 | highlight_alpha: float = 1.0,
138 | rotation_amount: float = 0.0,
139 | show: bool = False,
140 | voxel_size: float = 0.1,
141 | offscreen_renderer: pyrender.OffscreenRenderer = None,
142 | show_bb: bool = False,
143 | alpha: float = 0.5,
144 | render_gripper=False,
145 | gripper_pose=None,
146 | gripper_mesh_scale=1.0,
147 | ):
148 | scene = create_voxel_scene(
149 | voxel_grid,
150 | q_attention,
151 | highlight_coordinate,
152 | highlight_gt_coordinate,
153 | highlight_alpha,
154 | voxel_size,
155 | show_bb,
156 | alpha,
157 | )
158 | if show:
159 | scene.show()
160 | else:
161 | r = offscreen_renderer or pyrender.OffscreenRenderer(
162 | viewport_width=1920, viewport_height=1080, point_size=1.0
163 | )
164 | s = _from_trimesh_scene(
165 | scene, ambient_light=[0.8, 0.8, 0.8], bg_color=[1.0, 1.0, 1.0]
166 | )
167 | cam = pyrender.PerspectiveCamera(
168 | yfov=np.pi / 4.0, aspectRatio=r.viewport_width / r.viewport_height
169 | )
170 | p = _compute_initial_camera_pose(s)
171 | t = Trackball(p, (r.viewport_width, r.viewport_height), s.scale, s.centroid)
172 | t.rotate(rotation_amount, np.array([0.0, 0.0, 1.0]))
173 | s.add(cam, pose=t.pose)
174 |
175 | if render_gripper:
176 | gripper_trimesh = trimesh.load("peract_colab/meshes/hand.dae", force="mesh")
177 | gripper_trimesh.vertices *= gripper_mesh_scale
178 | radii = np.linalg.norm(
179 | gripper_trimesh.vertices - gripper_trimesh.center_mass, axis=1
180 | )
181 | gripper_trimesh.visual.vertex_colors = trimesh.visual.interpolate(
182 | radii * gripper_mesh_scale, color_map="winter"
183 | )
184 | gripper_mesh = pyrender.Mesh.from_trimesh(
185 | gripper_trimesh, poses=np.array([gripper_pose]), smooth=False
186 | )
187 | s.add(gripper_mesh)
188 | color, depth = r.render(s)
189 | return color.copy()
190 |
191 |
192 | def get_gripper_render_pose(
193 | voxel_scale, scene_bound_origin, continuous_trans, continuous_quat
194 | ):
195 | # finger tip to gripper offset
196 | offset = np.array(
197 | [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0.1 * voxel_scale], [0, 0, 0, 1]]
198 | )
199 |
200 | # scale and translate by origin
201 | translation = (continuous_trans - (np.array(scene_bound_origin[:3]))) * voxel_scale
202 | mat = np.eye(4, 4)
203 | mat[:3, :3] = Rotation.from_quat(
204 | [continuous_quat[0], continuous_quat[1], continuous_quat[2], continuous_quat[3]]
205 | ).as_matrix()
206 | offset_mat = np.matmul(mat, offset)
207 | mat[:3, 3] = translation - offset_mat[:3, 3]
208 | return mat
209 |
--------------------------------------------------------------------------------
/assets/hair_dryer_behind_table.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/assets/hair_dryer_behind_table.gif
--------------------------------------------------------------------------------
/assets/hair_dryer_completion.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/assets/hair_dryer_completion.gif
--------------------------------------------------------------------------------
/assets/hair_dryer_completion_blender.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/assets/hair_dryer_completion_blender.gif
--------------------------------------------------------------------------------
/assets/hair_dryer_completion_blender_legend.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/assets/hair_dryer_completion_blender_legend.png
--------------------------------------------------------------------------------
/assets/hair_dryer_scene.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/assets/hair_dryer_scene.png
--------------------------------------------------------------------------------
/assets/matterport.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/assets/matterport.png
--------------------------------------------------------------------------------
/assets/teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/assets/teaser.gif
--------------------------------------------------------------------------------
/assets/vn_poster_relevancies.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/assets/vn_poster_relevancies.png
--------------------------------------------------------------------------------
/datagen/README.md:
--------------------------------------------------------------------------------
1 | # THOR Data Generation Instructions
2 |
3 | This data generation code has been tested on Ubuntu 18.04 with an NVIDIA GTX 1080 and Unity version 2019.4.38f1.
4 |
5 | ## Compiling a Custom Version of THOR
6 |
7 | In [Unity](https://unity.com/) version 2019.4.38f1, open the Unity project from our [custom THOR repo](https://github.com/huy-ha/ai2thor).
8 | Our custom version extends THOR with a [`SceneVolumeExporter`](https://github.com/huy-ha/ai2thor/blob/main/unity/Assets/SceneVolumeExporter.cs) for exporting 3D scenes.
9 | You can modify the output directory by changing [`rootPath`](https://github.com/huy-ha/ai2thor/blob/main/unity/Assets/SceneVolumeExporter.cs#L8) to somewhere convenient.
10 | This will be the path used for `--path_to_exported_scenes` in `generate_thor_data.py`.
11 | You can also export at a higher [resolution](https://github.com/huy-ha/ai2thor/blob/main/unity/Assets/SceneVolumeExporter.cs#L355), though this will take longer.
12 |
13 | When building the project, make sure to select all scenes.
14 |
15 | 
16 |
17 | The path to the resulting build will be used for `--path_to_custom_unity` in `generate_thor_data.py`.
18 |
19 | ## Generating data
20 |
21 | ### Export/download all 3D scenes
22 |
23 | `SceneVolumeExporter` works by [naively grid sampling THOR scenes with a collision probe](https://github.com/huy-ha/ai2thor/blob/main/unity/Assets/SceneVolumeExporter.cs#L378-L420), then dumping the collision results to `.txt` files.
24 | It also [queries all receptacles](https://github.com/huy-ha/ai2thor/blob/main/unity/Assets/SceneVolumeExporter.cs#L131-L172) and dumps this information to a `.txt` file.
25 | The result is the un-transformed 3D semantic occupancy of all 3D scenes at the `rootPath`.
26 | While this is costly, it only needs to be done once.
27 | Alternatively, download our pre-processed exported 3D scenes.
28 |
29 | ```sh
30 | wget https://semantic-abstraction.cs.columbia.edu/downloads/groundtruth3dpoints.tar.lz4 -O - | tar --use-compress-program=lz4 -xf - -C ./
31 | ```
32 |
33 |
34 | When `generate_thor_data.py` is run, it first checks that all scenes' ground-truth 3D semantic occupancies have been exported to `path_to_exported_scenes`.
35 |
36 | Note: if you get "Reading from AI2-THOR backend timed out" errors, you can remove the timeout limit in [AI2THOR's python server](https://github.com/allenai/ai2thor/blob/main/ai2thor/fifo_server.py#L130-L131).
37 |
38 |
39 | ### Generating SemAbs data points
40 |
41 | After installing `ai2thor`
42 |
43 | ```sh
44 | pip install ai2thor==4.2.0
45 | ```
46 | you're ready to run the data generation script. It works as follows:
47 | 1. Choose a random THOR room.
48 | 2. Choose a random viewpoint in that room.
49 | 3. Transform and filter the ground-truth 3D semantic occupancy at `path_to_exported_scenes` to match that viewpoint.
50 | 4. Preprocess THOR's return values (segmentation maps, object class names, TSDF generation, etc.).
51 | 5. Detect all spatial relations.
52 |
53 | An example command might look like this:
54 | ```sh
55 | python generate_thor_data.py --dataset_dir_path /path/to/dataset --num_pts 50000 --path_to_custom_unity /path/to/unity/exec --path_to_exported_scenes /path/to/3dscenes
56 | ```
57 | If you have issues launching THOR, try disabling multi-processing by using the `--local` flag.
58 |
59 | ### Generating relevancies
60 |
61 | The last step is to generate the relevancies for the dumped data points:
62 | ```sh
63 | python generate_relevancy.py dataset /path/to/dataset num_processes
64 | ```
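65 | 
66 | For example, a hypothetical run that splits the available GPUs across 4 CLIP workers (in `generate_relevancy.py`, `num_processes` sets how many `ClipWrapper` Ray actors are spawned, each allocated `num_cuda_devices / num_processes` GPUs; depending on your `typer` version, the boolean `local` argument of the `dataset` command may also need to be passed explicitly):
67 | 
68 | ```sh
69 | python generate_relevancy.py dataset /path/to/dataset 4
70 | ```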
--------------------------------------------------------------------------------
/datagen/assets/unity-build.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/datagen/assets/unity-build.png
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | from tqdm import tqdm
4 | import torch
5 | import os
6 | import pickle
7 | from dataset import ObjectLocalizationDataset, SceneCompletionDataset
8 | from train_vool import get_losses as vool_get_losses, approach as vool_approaches
9 | from train_ovssc import get_losses as ovssc_get_losses, approach as ovssc_approaches
10 | import utils
11 | from torch.utils.data import DataLoader
12 |
13 | import torch.distributed as dist
14 | from torch.utils.data.distributed import DistributedSampler
15 |
16 | if __name__ == "__main__":
17 | parser = utils.config_parser()
18 | parser.add_argument("--task", choices=["ovssc", "vool"], required=True)
19 | args = parser.parse_args()
20 | with open(os.path.dirname(args.load) + "/args.pkl", "rb") as file:
21 | exp_args = pickle.load(file)
22 | for arg in vars(exp_args):
23 | if any(arg == s for s in ["device", "file_path", "load", "gpus", "task"]):
24 | continue
25 | setattr(args, arg, getattr(exp_args, arg))
26 | args.domain_randomization = False
27 | args.scene_bounds = torch.tensor(args.scene_bounds)
28 | args.batch_size = 1
29 | args.num_workers = 8
30 | args.balance_spatial_sampling = False
31 | args.detailed_analysis = True
32 | ddp = len(args.gpus) > 1
33 | approaches = ovssc_approaches if args.task == "ovssc" else vool_approaches
34 | dataset_class = (
35 | SceneCompletionDataset if args.task == "ovssc" else ObjectLocalizationDataset
36 | )
37 | exp_dict = utils.setup_experiment(
38 | args=args,
39 | net_class=approaches[args.approach],
40 | dataset_class=dataset_class,
41 | split_file_path=args.file_path
42 | + ("/vool_split.pkl" if args.task == "vool" else "/ssc_split.pkl"),
43 | return_vis=True,
44 | ddp=ddp,
45 | )
46 | net = exp_dict["net"]
47 | net.eval()
48 |     net.requires_grad_(False)
49 | epoch = exp_dict["start_epoch"]
50 | eval_detailed_stats = pd.DataFrame()
51 | with torch.no_grad():
52 | for split, dataset in exp_dict["datasets"].items():
53 | if split == "train":
54 | continue
55 | sampler = None
56 | if ddp:
57 | sampler = DistributedSampler(
58 | dataset=dataset, shuffle=False, drop_last=False
59 | )
60 | sampler.set_epoch(0)
61 | loader = DataLoader(
62 | dataset=dataset,
63 | num_workers=args.num_workers,
64 | batch_size=1,
65 | sampler=sampler,
66 | )
67 | detailed_stats = utils.loop(
68 | net=net,
69 | loader=loader,
70 | get_losses_fn=ovssc_get_losses
71 | if args.task == "ovssc"
72 | else vool_get_losses,
73 | **{
74 | **vars(args),
75 | "optimizer": None,
76 | "lr_scheduler": None,
77 | "cutoffs": np.arange(-2.5, -0.0, 0.1),
78 | "pbar": tqdm(
79 | total=len(loader),
80 | dynamic_ncols=True,
81 | unit="batch",
82 | postfix=f"| {split.upper()} ",
83 | ),
84 | "detailed_analysis": True,
85 | },
86 | )
87 | detailed_stats["epoch"] = [epoch] * len(detailed_stats)
88 | detailed_stats["split"] = [split] * len(detailed_stats)
89 | eval_detailed_stats = pd.concat([eval_detailed_stats, detailed_stats])
90 | if (ddp and dist.get_rank() == 0) or not ddp:
91 | stats_path = os.path.splitext(args.load)[0] + f"_eval_stats.pkl"
92 | eval_detailed_stats.to_pickle(stats_path)
93 | print("dumped stats to ", stats_path)
94 |
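95 | 
96 | # Editorial usage sketch (not part of the original script): `--task` is defined above,
97 | # while `--load` and `--gpus` are assumed to be provided by utils.config_parser(),
98 | # which supplies the remaining arguments used here (args.load, args.gpus, ...).
99 | #   python eval.py --task ovssc --load /path/to/checkpoint.pth --gpus 0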
--------------------------------------------------------------------------------
/fusion.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2018 Andy Zeng
2 | # Source: https://github.com/andyzeng/tsdf-fusion-python/blob/master/fusion.py
3 | # BSD 2-Clause License
4 |
5 | # Copyright (c) 2019, Princeton University
6 | # All rights reserved.
7 |
8 | # Redistribution and use in source and binary forms, with or without
9 | # modification, are permitted provided that the following conditions are met:
10 |
11 | # 1. Redistributions of source code must retain the above copyright notice, this
12 | # list of conditions and the following disclaimer.
13 |
14 | # 2. Redistributions in binary form must reproduce the above copyright notice,
15 | # this list of conditions and the following disclaimer in the documentation
16 | # and/or other materials provided with the distribution.
17 |
18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 | # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
28 |
29 | import numpy as np
30 | from numba import njit, prange
31 | from skimage import measure
32 |
33 |
34 | class TSDFVolume:
35 | """Volumetric TSDF Fusion of RGB-D Images."""
36 |
37 | def __init__(self, vol_bnds, voxel_size):
38 | """Constructor.
39 | Args:
40 | vol_bnds (ndarray): An ndarray of shape (3, 2). Specifies the
41 | xyz bounds (min/max) in meters.
42 | voxel_size (float): The volume discretization in meters.
43 | """
44 | vol_bnds = np.asarray(vol_bnds)
45 | assert vol_bnds.shape == (3, 2), "[!] `vol_bnds` should be of shape (3, 2)."
46 | assert (vol_bnds[:, 0] < vol_bnds[:, 1]).all()
47 |
48 | # Define voxel volume parameters
49 | self._vol_bnds = vol_bnds
50 | self._voxel_size = float(voxel_size)
51 | self._trunc_margin = 5 * self._voxel_size # truncation on SDF
52 | self._color_const = 256 * 256
53 |
54 | # Adjust volume bounds and ensure C-order contiguous
55 | self._vol_dim = (
56 | np.ceil((self._vol_bnds[:, 1] - self._vol_bnds[:, 0]) / self._voxel_size)
57 | .copy(order="C")
58 | .astype(int)
59 | )
60 | self._vol_bnds[:, 1] = self._vol_bnds[:, 0] + self._vol_dim * self._voxel_size
61 | self._vol_origin = self._vol_bnds[:, 0].copy(order="C").astype(np.float32)
62 |
63 | # Initialize pointers to voxel volume in CPU memory
64 | # Assume all unobserved regions are occupied
65 | self._tsdf_vol_cpu = -np.ones(self._vol_dim).astype(np.float32)
66 | # for computing the cumulative moving average of observations per voxel
67 | self._weight_vol_cpu = np.zeros(self._vol_dim).astype(np.float32)
68 | self._color_vol_cpu = np.zeros(self._vol_dim).astype(np.float32)
69 |
70 | # Get voxel grid coordinates
71 | xv, yv, zv = np.meshgrid(
72 | range(self._vol_dim[0]),
73 | range(self._vol_dim[1]),
74 | range(self._vol_dim[2]),
75 | indexing="ij",
76 | )
77 | self.vox_coords = (
78 | np.concatenate(
79 | [xv.reshape(1, -1), yv.reshape(1, -1), zv.reshape(1, -1)], axis=0
80 | )
81 | .astype(int)
82 | .T
83 | )
84 |
85 | @staticmethod
86 | @njit(parallel=True)
87 | def vox2world(vol_origin, vox_coords, vox_size):
88 | """Convert voxel grid coordinates to world coordinates."""
89 | vol_origin = vol_origin.astype(np.float32)
90 | vox_coords = vox_coords.astype(np.float32)
91 | cam_pts = np.empty_like(vox_coords, dtype=np.float32)
92 | for i in prange(vox_coords.shape[0]):
93 | for j in range(3):
94 | cam_pts[i, j] = vol_origin[j] + (vox_size * vox_coords[i, j])
95 | return cam_pts
96 |
97 | @staticmethod
98 | @njit(parallel=True)
99 | def cam2pix(cam_pts, intr):
100 | """Convert camera coordinates to pixel coordinates."""
101 | intr = intr.astype(np.float32)
102 | fx, fy = intr[0, 0], intr[1, 1]
103 | cx, cy = intr[0, 2], intr[1, 2]
104 | pix = np.empty((cam_pts.shape[0], 2), dtype=np.int64)
105 | for i in prange(cam_pts.shape[0]):
106 | pix[i, 0] = int(np.round((cam_pts[i, 0] * fx / cam_pts[i, 2]) + cx))
107 | pix[i, 1] = int(np.round((cam_pts[i, 1] * fy / cam_pts[i, 2]) + cy))
108 | return pix
109 |
110 | @staticmethod
111 | @njit(parallel=True)
112 | def integrate_tsdf(tsdf_vol, dist, w_old, obs_weight):
113 | """Integrate the TSDF volume."""
114 | tsdf_vol_int = np.empty_like(tsdf_vol, dtype=np.float32)
115 | w_new = np.empty_like(w_old, dtype=np.float32)
116 | for i in prange(len(tsdf_vol)):
117 | w_new[i] = w_old[i] + obs_weight
118 | tsdf_vol_int[i] = (w_old[i] * tsdf_vol[i] + obs_weight * dist[i]) / w_new[i]
119 | return tsdf_vol_int, w_new
120 |
121 | def integrate(self, color_im, depth_im, cam_intr, cam_pose, obs_weight=1.0):
122 | """Integrate an RGB-D frame into the TSDF volume.
123 | Args:
124 | color_im (ndarray): An RGB image of shape (H, W, 3).
125 | depth_im (ndarray): A depth image of shape (H, W).
126 | cam_intr (ndarray): The camera intrinsics matrix of shape (3, 3).
127 | cam_pose (ndarray): The camera pose (i.e. extrinsics) of shape (4, 4).
128 |             obs_weight (float): The weight to assign to the current observation. A higher
129 |                 value gives the current observation more influence on the fused average.
130 |         """
131 | im_h, im_w = depth_im.shape
132 |
133 | # Fold RGB color image into a single channel image
134 | color_im = color_im.astype(np.float32)
135 | color_im = np.floor(
136 | color_im[..., 2] * self._color_const
137 | + color_im[..., 1] * 256
138 | + color_im[..., 0]
139 | )
140 |
141 | # Convert voxel grid coordinates to pixel coordinates
142 | cam_pts = self.vox2world(self._vol_origin, self.vox_coords, self._voxel_size)
143 | cam_pts = rigid_transform(cam_pts, np.linalg.inv(cam_pose))
144 | pix_z = cam_pts[:, 2]
145 | pix = self.cam2pix(cam_pts, cam_intr)
146 | pix_x, pix_y = pix[:, 0], pix[:, 1]
147 |
148 | # Eliminate pixels outside view frustum
149 | valid_pix = np.logical_and(
150 | pix_x >= 0,
151 | np.logical_and(
152 | pix_x < im_w,
153 | np.logical_and(pix_y >= 0, np.logical_and(pix_y < im_h, pix_z > 0)),
154 | ),
155 | )
156 | depth_val = np.zeros(pix_x.shape)
157 | depth_val[valid_pix] = depth_im[pix_y[valid_pix], pix_x[valid_pix]]
158 |
159 | # Integrate TSDF
160 | depth_diff = depth_val - pix_z
161 | valid_pts = np.logical_and(depth_val > 0, depth_diff >= -self._trunc_margin)
162 | dist = np.maximum(-1, np.minimum(1, depth_diff / self._trunc_margin))
163 | valid_vox_x = self.vox_coords[valid_pts, 0]
164 | valid_vox_y = self.vox_coords[valid_pts, 1]
165 | valid_vox_z = self.vox_coords[valid_pts, 2]
166 | w_old = self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z]
167 | tsdf_vals = self._tsdf_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z]
168 | valid_dist = dist[valid_pts]
169 | tsdf_vol_new, w_new = self.integrate_tsdf(
170 | tsdf_vals, valid_dist, w_old, obs_weight
171 | )
172 | self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = w_new
173 | self._tsdf_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = tsdf_vol_new
174 |
175 | # Integrate color
176 | old_color = self._color_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z]
177 | old_b = np.floor(old_color / self._color_const)
178 | old_g = np.floor((old_color - old_b * self._color_const) / 256)
179 | old_r = old_color - old_b * self._color_const - old_g * 256
180 | new_color = color_im[pix_y[valid_pts], pix_x[valid_pts]]
181 | new_b = np.floor(new_color / self._color_const)
182 | new_g = np.floor((new_color - new_b * self._color_const) / 256)
183 | new_r = new_color - new_b * self._color_const - new_g * 256
184 | new_b = np.minimum(
185 | 255.0, np.round((w_old * old_b + obs_weight * new_b) / w_new)
186 | )
187 | new_g = np.minimum(
188 | 255.0, np.round((w_old * old_g + obs_weight * new_g) / w_new)
189 | )
190 | new_r = np.minimum(
191 | 255.0, np.round((w_old * old_r + obs_weight * new_r) / w_new)
192 | )
193 | self._color_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = (
194 | new_b * self._color_const + new_g * 256 + new_r
195 | )
196 |
197 | def get_volume(self):
198 | # Fold RGB color image into a single channel image
199 | color_vol = np.zeros([3] + list(self._color_vol_cpu.shape)).astype(np.uint8)
200 | color_vol[2, ...] = np.floor(self._color_vol_cpu / self._color_const)
201 | color_vol[1, ...] = np.floor(
202 | (self._color_vol_cpu - color_vol[2, ...] * self._color_const) / 256
203 | )
204 | color_vol[0, ...] = (
205 | self._color_vol_cpu
206 | - color_vol[2, ...] * self._color_const
207 | - color_vol[1, ...] * 256
208 | )
209 | return self._tsdf_vol_cpu, color_vol
210 |
211 | def get_point_cloud(self):
212 | """Extract a point cloud from the voxel volume."""
213 | tsdf_vol, color_vol = self.get_volume()
214 |
215 | # Marching cubes
216 | verts = measure.marching_cubes_lewiner(tsdf_vol, level=0)[0]
217 | verts_ind = np.round(verts).astype(int)
218 | verts = verts * self._voxel_size + self._vol_origin
219 |
220 | # Get vertex colors
221 |         rgb_vals = self._color_vol_cpu[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]]
222 | colors_b = np.floor(rgb_vals / self._color_const)
223 | colors_g = np.floor((rgb_vals - colors_b * self._color_const) / 256)
224 | colors_r = rgb_vals - colors_b * self._color_const - colors_g * 256
225 | colors = np.floor(np.asarray([colors_r, colors_g, colors_b])).T
226 | colors = colors.astype(np.uint8)
227 |
228 | pc = np.hstack([verts, colors])
229 | return pc
230 |
231 | def get_mesh(self):
232 | """Compute a mesh from the voxel volume using marching cubes."""
233 | tsdf_vol, color_vol = self.get_volume()
234 |
235 | # Marching cubes
236 | verts, faces, norms, vals = measure.marching_cubes_lewiner(tsdf_vol, level=0)
237 | verts_ind = np.round(verts).astype(int)
238 | # voxel grid coordinates to world coordinates
239 | verts = verts * self._voxel_size + self._vol_origin
240 |
241 | # Get vertex colors
242 |         rgb_vals = self._color_vol_cpu[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]]
243 | colors_b = np.floor(rgb_vals / self._color_const)
244 | colors_g = np.floor((rgb_vals - colors_b * self._color_const) / 256)
245 | colors_r = rgb_vals - colors_b * self._color_const - colors_g * 256
246 | colors = np.floor(np.asarray([colors_r, colors_g, colors_b])).T
247 | colors = colors.astype(np.uint8)
248 | return verts, faces, norms, colors
249 |
250 |
251 | def rigid_transform(xyz, transform):
252 | """Applies a rigid transform to an (N, 3) pointcloud."""
253 | xyz_h = np.hstack([xyz, np.ones((len(xyz), 1), dtype=np.float32)])
254 | xyz_t_h = np.dot(transform, xyz_h.T).T
255 | return xyz_t_h[:, :3]
256 |
257 |
258 | def get_view_frustum(depth_im, cam_intr, cam_pose):
259 | """Get corners of 3D camera view frustum of depth image"""
260 | im_h = depth_im.shape[0]
261 | im_w = depth_im.shape[1]
262 | max_depth = np.max(depth_im)
263 | view_frust_pts = np.array(
264 | [
265 | (np.array([0, 0, 0, im_w, im_w]) - cam_intr[0, 2])
266 | * np.array([0, max_depth, max_depth, max_depth, max_depth])
267 | / cam_intr[0, 0],
268 | (np.array([0, 0, im_h, 0, im_h]) - cam_intr[1, 2])
269 | * np.array([0, max_depth, max_depth, max_depth, max_depth])
270 | / cam_intr[1, 1],
271 | np.array([0, max_depth, max_depth, max_depth, max_depth]),
272 | ]
273 | )
274 | view_frust_pts = rigid_transform(view_frust_pts.T, cam_pose).T
275 | return view_frust_pts
276 |
277 |
278 | def meshwrite(filename, verts, faces, norms, colors):
279 | """Save a 3D mesh to a polygon .ply file."""
280 | # Write header
281 | ply_file = open(filename, "w")
282 | ply_file.write("ply\n")
283 | ply_file.write("format ascii 1.0\n")
284 | ply_file.write("element vertex %d\n" % (verts.shape[0]))
285 | ply_file.write("property float x\n")
286 | ply_file.write("property float y\n")
287 | ply_file.write("property float z\n")
288 | ply_file.write("property float nx\n")
289 | ply_file.write("property float ny\n")
290 | ply_file.write("property float nz\n")
291 | ply_file.write("property uchar red\n")
292 | ply_file.write("property uchar green\n")
293 | ply_file.write("property uchar blue\n")
294 | ply_file.write("element face %d\n" % (faces.shape[0]))
295 | ply_file.write("property list uchar int vertex_index\n")
296 | ply_file.write("end_header\n")
297 |
298 | # Write vertex list
299 | for i in range(verts.shape[0]):
300 | ply_file.write(
301 | "%f %f %f %f %f %f %d %d %d\n"
302 | % (
303 | verts[i, 0],
304 | verts[i, 1],
305 | verts[i, 2],
306 | norms[i, 0],
307 | norms[i, 1],
308 | norms[i, 2],
309 | colors[i, 0],
310 | colors[i, 1],
311 | colors[i, 2],
312 | )
313 | )
314 |
315 | # Write face list
316 | for i in range(faces.shape[0]):
317 | ply_file.write("3 %d %d %d\n" % (faces[i, 0], faces[i, 1], faces[i, 2]))
318 |
319 | ply_file.close()
320 |
321 |
322 | def pcwrite(filename, xyzrgb):
323 | """Save a point cloud to a polygon .ply file."""
324 | xyz = xyzrgb[:, :3]
325 | rgb = xyzrgb[:, 3:].astype(np.uint8)
326 |
327 | # Write header
328 | ply_file = open(filename, "w")
329 | ply_file.write("ply\n")
330 | ply_file.write("format ascii 1.0\n")
331 | ply_file.write("element vertex %d\n" % (xyz.shape[0]))
332 | ply_file.write("property float x\n")
333 | ply_file.write("property float y\n")
334 | ply_file.write("property float z\n")
335 | ply_file.write("property uchar red\n")
336 | ply_file.write("property uchar green\n")
337 | ply_file.write("property uchar blue\n")
338 | ply_file.write("end_header\n")
339 |
340 | # Write vertex list
341 | for i in range(xyz.shape[0]):
342 | ply_file.write(
343 | "%f %f %f %d %d %d\n"
344 | % (
345 | xyz[i, 0],
346 | xyz[i, 1],
347 | xyz[i, 2],
348 | rgb[i, 0],
349 | rgb[i, 1],
350 | rgb[i, 2],
351 | )
352 | )
353 |     ply_file.close()
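354 | 
355 | 
356 | if __name__ == "__main__":
357 |     # Editorial usage sketch (not part of the original module): fuse one synthetic
358 |     # depth frame and export the zero-level-set mesh. The camera intrinsics, pose,
359 |     # bounds, and voxel size below are illustrative assumptions only.
360 |     H, W = 240, 320
361 |     cam_intr = np.array([[200.0, 0.0, W / 2.0], [0.0, 200.0, H / 2.0], [0.0, 0.0, 1.0]])
362 |     cam_pose = np.eye(4)  # camera at the origin, looking down +z
363 |     depth_im = np.ones((H, W), dtype=np.float32)  # a flat wall 1 m in front of the camera
364 |     color_im = np.full((H, W, 3), 128, dtype=np.uint8)
365 |     vol_bnds = np.array([[-0.5, 0.5], [-0.5, 0.5], [0.5, 1.5]])  # xyz min/max in meters
366 |     tsdf = TSDFVolume(vol_bnds, voxel_size=0.02)
367 |     tsdf.integrate(color_im, depth_im, cam_intr, cam_pose, obs_weight=1.0)
368 |     verts, faces, norms, colors = tsdf.get_mesh()
369 |     meshwrite("example_mesh.ply", verts, faces, norms, colors)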
--------------------------------------------------------------------------------
/generate_relevancy.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from pathlib import Path
3 | import h5py
4 | import torch
5 | from tqdm import tqdm
6 | import ray
7 | from utils import write_to_hdf5
8 | from filelock import FileLock
9 | import numpy as np
10 | from CLIP.clip import ClipWrapper, saliency_configs, imagenet_templates
11 | from dataset import synonyms, deref_h5py
12 | import typer
13 | import imageio
14 | from matplotlib import pyplot as plt
15 | import cv2
16 | from time import time
17 |
18 | app = typer.Typer()
19 |
20 |
21 | def resize_and_add_data(dataset, data):
22 | data_shape = np.array(data.shape)
23 | dataset_shape = np.array(dataset.shape)
24 | assert (dataset_shape[1:] == data_shape[1:]).all()
25 | dataset.resize(dataset_shape[0] + data_shape[0], axis=0)
26 | dataset[-data_shape[0] :, ...] = data
27 | return [
28 | dataset.regionref[dataset_shape[0] + i, ...]
29 | for i in np.arange(0, data_shape[0])
30 | ]
31 |
32 |
33 | def get_datastructure(image_shape, relevancy_shape, tsdf_dim, num_output_pts, **kwargs):
34 | image_shape = list(image_shape)
35 | relevancy_shape = list(relevancy_shape)
36 | return {
37 | "rgb": {"dtype": "uint8", "item_shape": image_shape + [3]},
38 | "depth": {"dtype": "f", "item_shape": image_shape},
39 | "seg": {"dtype": "i", "item_shape": image_shape},
40 | "saliencies": {"dtype": "f", "item_shape": relevancy_shape},
41 | "tsdf_value_pts": {"dtype": "f", "item_shape": [np.prod(tsdf_dim)]},
42 | "tsdf_xyz_pts": {"dtype": "f", "item_shape": [np.prod(tsdf_dim), 3]},
43 | "full_xyz_pts": {"dtype": "f", "item_shape": [num_output_pts, 3]},
44 | "full_objid_pts": {"dtype": "i", "item_shape": [num_output_pts]},
45 | }
46 |
47 |
48 | def init_dataset(file_path, data_structure):
49 | with h5py.File(file_path, mode="w") as file:
50 | # setup
51 | for key, data_info in data_structure.items():
52 | file.create_dataset(
53 | name=key,
54 | shape=tuple([0] + data_info["item_shape"]),
55 | dtype=data_info["dtype"],
56 | chunks=tuple([1] + data_info["item_shape"]),
57 | compression="gzip",
58 | compression_opts=9,
59 | maxshape=tuple([None] + data_info["item_shape"]),
60 | )
61 |
62 |
63 | @ray.remote
64 | def generate_saliency_helper(
65 | clip_wrapper, rgb_inputs, prompts, text_labels, scene_path, replace
66 | ):
67 | saliencies = {
68 | rgb_name: {
69 | saliency_config_name: ray.get(
70 | clip_wrapper.get_clip_saliency.remote(
71 | img=rgb,
72 | text_labels=text_labels,
73 | prompts=prompts
74 | if "imagenet_prompt_ensemble"
75 | not in saliency_config(img_dim=min(rgb.shape[:2]))
76 | or not saliency_config(img_dim=min(rgb.shape[:2]))[
77 | "imagenet_prompt_ensemble"
78 | ]
79 | else imagenet_templates,
80 | **saliency_config(img_dim=min(rgb.shape[:2])),
81 | )
82 | )
83 | for saliency_config_name, saliency_config in saliency_configs.items()
84 | }
85 | for rgb_name, rgb in rgb_inputs.items()
86 | }
87 | with FileLock(scene_path + ".lock"):
88 | with h5py.File(scene_path, mode="a") as f:
89 | saliency_group = f["data"].create_group("saliencies")
90 | for rgb_name, rgb_saliencies in saliencies.items():
91 | for (
92 | saliency_config_name,
93 | (config_saliency, text_label_features),
94 | ) in rgb_saliencies.items():
95 | storage_dims = np.array(f["saliencies"].shape)[1:]
96 | config_saliency = torch.nn.functional.interpolate(
97 | config_saliency[:, None, :, :],
98 | size=tuple(storage_dims),
99 | mode="nearest-exact"
100 | # mode='bilinear',
101 | # align_corners=False
102 | )[:, 0]
103 |
104 | config_saliency = torch.cat(
105 | [config_saliency, config_saliency.mean(dim=0, keepdim=True)],
106 | dim=0,
107 | )
108 | text_label_features = torch.cat(
109 | [
110 | text_label_features,
111 | text_label_features.mean(dim=0, keepdim=True),
112 | ],
113 | dim=0,
114 | )
115 | text_label_features /= text_label_features.norm(
116 | dim=-1, keepdim=True
117 | )
118 | write_to_hdf5(
119 | saliency_group,
120 | key=rgb_name
121 | + "|"
122 | + saliency_config_name
123 | + "|saliency_text_labels",
124 | value=np.array(text_labels + ["mean"]).astype("S"),
125 | replace=replace,
126 | )
127 | write_to_hdf5(
128 | saliency_group,
129 | key=rgb_name
130 | + "|"
131 | + saliency_config_name
132 | + "|saliency_text_label_features",
133 | value=text_label_features,
134 | replace=replace,
135 | )
136 | region_references = resize_and_add_data(
137 | dataset=f["saliencies"], data=config_saliency
138 | )
139 | write_to_hdf5(
140 | saliency_group,
141 | key=rgb_name + "|" + saliency_config_name,
142 | dtype=h5py.regionref_dtype,
143 | value=region_references,
144 | replace=replace,
145 | )
146 | return clip_wrapper
147 |
148 |
149 | @app.command()
150 | def dataset(
151 | file_path: str,
152 | num_processes: int,
153 | local: bool,
154 | prompts: List[str] = ["a render of a {} in a game engine."],
155 | replace=False,
156 | ):
157 | if "matterport" in file_path or "nyu" in file_path:
158 | prompts = ["a photograph of a {} in a home."]
159 | print(prompts)
160 | tasks = []
161 | ray.init(log_to_driver=True, local_mode=local)
162 | num_cuda_devices = torch.cuda.device_count()
163 | assert num_cuda_devices > 0
164 | print(f"[INFO] FOUND {num_cuda_devices} CUDA DEVICE")
165 | wrapper_actor_cls = ray.remote(ClipWrapper)
166 | available_clip_wrappers = [
167 | wrapper_actor_cls.options(num_gpus=num_cuda_devices / num_processes).remote(
168 | clip_model_type="ViT-B/32", device="cuda"
169 | )
170 | for _ in range(num_processes)
171 | ]
172 |
173 | scene_paths = list(reversed(sorted(map(str, Path(file_path).rglob("*.hdf5")))))
174 | if replace:
175 | if input("Replace = True. Delete existing relevancies? [y/n]") != "y":
176 | exit()
177 | for scene_path in tqdm(
178 | scene_paths, dynamic_ncols=True, desc="deleting existing relevancies"
179 | ):
180 | try:
181 | with h5py.File(scene_path, mode="a") as f:
182 | for k in f["data"]:
183 | if "salienc" in k:
184 | del f[f"data/{k}"]
185 | if "saliencies" in f:
186 | data_shape = list(f["saliencies"].shape[1:])
187 | del f["saliencies"]
188 | f.create_dataset(
189 | name="saliencies",
190 | shape=tuple([0] + data_shape),
191 | dtype="f",
192 | chunks=tuple([1] + data_shape),
193 | compression="gzip",
194 | compression_opts=9,
195 | maxshape=tuple([None] + data_shape),
196 | )
197 | except Exception as e:
198 | print(e, scene_path)
199 | exit()
200 | for scene_path in tqdm(
201 | scene_paths, dynamic_ncols=True, desc="generating relevancies", smoothing=0.001
202 | ):
203 | assert len(available_clip_wrappers) > 0
204 | try:
205 | with h5py.File(scene_path, mode="a") as f:
206 | scene_already_done = "saliencies" in f["data"]
207 | if not scene_already_done or replace:
208 | if scene_already_done:
209 | for k in f["data"]:
210 | if "salienc" in k:
211 | del f[f"data/{k}"]
212 | data_shape = f["saliencies"].shape[1:]
213 | if "saliencies" in f:
214 | del f["saliencies"]
215 | f.create_dataset(
216 | name="saliencies",
217 | shape=tuple([0] + data_shape),
218 | dtype="f",
219 | chunks=tuple([1] + data_shape),
220 | compression="gzip",
221 | compression_opts=9,
222 | maxshape=tuple([None] + data_shape),
223 | )
224 |
225 | if "data/visible_scene_obj_labels" in f:
226 | del f["data/visible_scene_obj_labels"]
227 | objid_to_class = np.array(f[f"data/objid_to_class"]).astype(str)
228 | text_labels = objid_to_class.copy()
229 | scene_has_groundtruth = (
230 | "seg" in f["data"] and "full_objid_pts" in f["data"]
231 | )
232 | visible_scene_obj_labels = text_labels.copy()
233 | if scene_has_groundtruth:
234 | objids_in_scene = list(
235 | set(
236 | deref_h5py(
237 | dataset=f["full_objid_pts"],
238 | refs=f["data/full_objid_pts"],
239 | )
240 | .astype(int)
241 | .reshape(-1)
242 | )
243 | - {-1}
244 | ) # remove empty
245 | scene_object_labels = text_labels.copy()[objids_in_scene]
246 |
247 | # remove objects which are not in view
248 | gt_seg = deref_h5py(dataset=f["seg"], refs=f["data"]["seg"])[0]
249 | visible_obj_ids = list(map(int, set(np.unique(gt_seg)) - {-1}))
250 | visible_obj_labels = text_labels[visible_obj_ids]
251 | visible_scene_obj_labels = list(
252 | set(visible_obj_labels).intersection(
253 | set(scene_object_labels)
254 | )
255 | )
256 | visible_scene_obj_labels = list(
257 | sorted(
258 | set(
259 | map(
260 | lambda c: c.split("[")[0].lstrip().rstrip(),
261 | visible_scene_obj_labels,
262 | )
263 | )
264 | )
265 | )
266 | # visible_scene_obj_labels used to filter
267 | # objects both visible and in scene
268 | text_labels = visible_obj_labels.copy()
269 | text_labels = set(text_labels)
270 |
271 | # create saliency maps necessary for descriptions
272 | if (
273 | "descriptions" in f["data"]
274 | and len(np.array(f["data/descriptions/spatial_relation_name"]))
275 | > 0
276 | ):
277 | target_obj_names = np.array(
278 | f["data/descriptions/target_obj_name"]
279 | ).astype(str)
280 | reference_obj_names = np.array(
281 | f["data/descriptions/reference_obj_name"]
282 | ).astype(str)
283 | spatial_relation_names = np.array(
284 | f["data/descriptions/spatial_relation_name"]
285 | ).astype(str)
286 | text_labels = text_labels.union(
287 | target_obj_names.tolist() + reference_obj_names.tolist()
288 | )
289 |
290 | # gradcam for clip spatial
291 | descriptions = ""
292 | for desc_part in [
293 | target_obj_names,
294 | " ",
295 | spatial_relation_names,
296 | " a ",
297 | reference_obj_names,
298 | ]:
299 | descriptions = np.char.add(descriptions, desc_part)
300 | text_labels = text_labels.union(descriptions)
301 | # descriptions with synonyms
302 | descriptions = ""
303 | for desc_part in [
304 | np.array(
305 | list(
306 | map(
307 | lambda x: x
308 | if x not in synonyms.keys()
309 | else synonyms[x],
310 | target_obj_names,
311 | )
312 | )
313 | ),
314 | " ",
315 | spatial_relation_names,
316 | " a ",
317 | np.array(
318 | list(
319 | map(
320 | lambda x: x
321 | if x not in synonyms.keys()
322 | else synonyms[x],
323 | reference_obj_names,
324 | )
325 | )
326 | ),
327 | ]:
328 | descriptions = np.char.add(descriptions, desc_part)
329 | text_labels = text_labels.union(descriptions)
330 | text_labels = set(
331 | map(lambda c: c.split("[")[0].lstrip().rstrip(), text_labels)
332 | )
333 |
334 | # do synonyms
335 | text_labels = text_labels.union(
336 | map(
337 | lambda text_label: synonyms[text_label],
338 | filter(
339 | lambda text_label: text_label in synonyms, text_labels
340 | ),
341 | )
342 | )
343 | for remove_label in {"unlabelled", "empty", "out of bounds"}:
344 | if remove_label in text_labels:
345 | text_labels.remove(remove_label)
346 | text_labels = list(sorted(text_labels))
347 |
348 | rgb_inputs = {"rgb": np.array(f["rgb"][f["data"]["rgb"][0]][0])}
349 | if (
350 | "domain_randomized_rgb" in f["data"]
351 | and len(np.array(f["data/domain_randomized_rgb"])[0].shape) > 1
352 | ):
353 | rgb_inputs["domain_randomized_rgb"] = np.array(
354 | f["data/domain_randomized_rgb"]
355 | )[0]
356 | write_to_hdf5(
357 | f["data"],
358 | key="visible_scene_obj_labels",
359 | value=np.array(visible_scene_obj_labels).astype("S"),
360 | replace=replace,
361 | )
362 | clip_wrapper = available_clip_wrappers.pop()
363 | tasks.append(
364 | generate_saliency_helper.remote(
365 | clip_wrapper=clip_wrapper,
366 | scene_path=scene_path,
367 | rgb_inputs=rgb_inputs,
368 | text_labels=text_labels,
369 | prompts=prompts,
370 | replace=replace,
371 | )
372 | )
373 | except Exception as e:
374 | print(e)
375 | print(scene_path, "invalid hdf5 file")
376 | if len(available_clip_wrappers) == 0:
377 | readies, tasks = ray.wait(tasks, num_returns=1)
378 | num_readies = len(readies)
379 | try:
380 | available_clip_wrappers.extend(ray.get(readies))
381 | except Exception as e:
382 | print(e)
383 | available_clip_wrappers.extend(
384 | [
385 | wrapper_actor_cls.options(
386 | num_gpus=num_cuda_devices / num_processes
387 | ).remote(clip_model_type="ViT-B/32", device="cuda")
388 | for _ in range(num_readies)
389 | ]
390 | )
391 | ray.get(tasks)
392 |
393 |
394 | @app.command()
395 | def image(
396 | file_path: str = typer.Argument(
397 | default="matterport.png", help="path of image file"
398 | ),
399 | labels: List[str] = typer.Option(
400 | default=[
401 | "basketball jersey",
402 | "nintendo switch",
403 | "television",
404 | "ping pong table",
405 | "vase",
406 | "fireplace",
407 | "abstract painting of a vespa",
408 | "carpet",
409 | "wall",
410 | ],
411 | help='list of object categories (e.g.: "nintendo switch")',
412 | ),
413 | prompts: List[str] = typer.Option(
414 | default=["a photograph of a {} in a home."],
415 | help="prompt template to use with CLIP.",
416 | ),
417 | ):
418 | """
419 | Generates a multi-scale relevancy for image at `file_path`.
420 | """
421 | img = np.array(imageio.imread(file_path))
422 | assert img.dtype == np.uint8
423 | h, w, c = img.shape
424 | start = time()
425 | grads = ClipWrapper.get_clip_saliency(
426 | img=img,
427 | text_labels=np.array(labels),
428 | prompts=prompts,
429 | **saliency_configs["ours"](h),
430 | )[0]
431 | print(f"get gradcam took {float(time() - start)} seconds", grads.shape)
432 | grads -= grads.mean(axis=0)
433 | grads = grads.cpu().numpy()
434 | fig, axes = plt.subplots(3, 3)
435 | axes = axes.flatten()
436 | vmin = 0.002
437 | cmap = plt.get_cmap("jet")
438 | vmax = 0.008
439 | for ax, label_grad, label in zip(axes, grads, labels):
440 | ax.axis("off")
441 | ax.imshow(img)
442 | ax.set_title(label, fontsize=12)
443 | grad = np.clip((label_grad - vmin) / (vmax - vmin), a_min=0.0, a_max=1.0)
444 | colored_grad = cmap(grad)
445 | grad = 1 - grad
446 | colored_grad[..., -1] = grad * 0.7
447 | ax.imshow(colored_grad)
448 | plt.tight_layout(pad=0)
449 | plt.savefig("grads.png")
450 | print("dumped relevancy to grads.png")
451 | plt.show()
452 |
453 |
454 | if __name__ == "__main__":
455 | app()
456 |
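457 | 
458 | # Editorial example (not part of the original script): multi-scale relevancy for a
459 | # single image via the `image` command above, with a couple of custom labels, e.g.
460 | #   python generate_relevancy.py image matterport.png --labels "fireplace" --labels "vase"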
--------------------------------------------------------------------------------
/grads.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/grads.png
--------------------------------------------------------------------------------
/matterport.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/matterport.png
--------------------------------------------------------------------------------
/plot_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from matplotlib.patches import Patch
3 | import matplotlib.pyplot as plt
4 | import io
5 | from PIL import Image
6 | import open3d as o3d
7 | from skimage.measure import block_reduce
8 | import matplotlib.cm as cm
9 | import matplotlib as mpl
10 |
11 |
12 | def plot_to_png(fig):
13 | buf = io.BytesIO()
14 | plt.savefig(buf, format="png")
15 | buf.seek(0)
16 | img = np.array(Image.open(buf)).astype(np.uint8)
17 | return img
18 |
19 |
20 | def set_view_and_save_img(fig, ax, views):
21 | for elev, azim in views:
22 | ax.view_init(elev=elev, azim=azim)
23 | yield plot_to_png(fig)
24 |
25 |
26 | def plot_pointcloud(
27 | xyz,
28 | features,
29 | object_labels=None,
30 | background_color=(0.1, 0.1, 0.1, 0.99),
31 | num_points=50000,
32 | views=[(45, 135)],
33 | pts_size=3,
34 | alpha=0.5,
35 | plot_empty=False,
36 | visualize_ghost_points=False,
37 | object_colors=None,
38 | delete_fig=True,
39 | show_plot=False,
40 | bounds=[[-1, -1, -0.1], [1, 1, 1.9]],
41 | ):
42 | is_semantic = len(features.shape) == 1
43 | if type(alpha) is float:
44 | alpha = np.ones(xyz.shape[0]).astype(np.float32) * alpha
45 | if not plot_empty and is_semantic and object_labels is not None:
46 | mask = np.ones_like(alpha).astype(bool)
47 | for remove_label in ["empty", "unlabelled", "out of bounds"]:
48 | if remove_label in object_labels.tolist():
49 | remove_idx = object_labels.tolist().index(remove_label)
50 | mask = np.logical_and(mask, features != remove_idx)
51 | xyz = xyz[mask, :]
52 | features = features[mask, ...]
53 | alpha = alpha[mask]
54 | if type(pts_size) != int and type(pts_size) != float:
55 | pts_size = pts_size[mask]
56 | # subsample
57 | if xyz.shape[0] > num_points:
58 | indices = np.random.choice(xyz.shape[0], size=num_points, replace=False)
59 | xyz = xyz[indices, :]
60 | features = features[indices, ...]
61 | alpha = alpha[indices]
62 | if type(pts_size) != int and type(pts_size) != float:
63 | pts_size = pts_size[indices]
64 |
65 | fig = plt.figure(figsize=(6, 6), dpi=160)
66 | ax = fig.add_subplot(111, projection="3d")
67 | x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2]
68 | ax.set_facecolor(background_color)
69 | ax.w_xaxis.set_pane_color(background_color)
70 | ax.w_yaxis.set_pane_color(background_color)
71 | ax.w_zaxis.set_pane_color(background_color)
72 | # ax._axis3don = False
73 |
74 | if is_semantic and object_labels is not None:
75 | object_ids = list(np.unique(features))
76 | object_labels = object_labels[object_ids].tolist()
77 | if object_colors is not None:
78 | object_colors = object_colors[object_ids]
79 |         features = features.astype(int)
80 | # repack object ids
81 | repacked_obj_ids = np.zeros(features.shape).astype(np.uint32)
82 | for i, j in enumerate(object_ids):
83 | repacked_obj_ids[features == j] = i
84 | features = repacked_obj_ids
85 |
86 | object_ids = list(np.unique(features))
87 | colors = np.zeros((len(features), 4)).astype(np.uint8)
88 | if object_colors is None:
89 | cmap = plt.get_cmap("tab20")
90 | object_colors = (255 * cmap(np.array(object_ids) % 20)).astype(np.uint8)
91 | for obj_id in np.unique(features):
92 | colors[features == obj_id, :] = object_colors[obj_id]
93 | colors = colors.astype(float) / 255.0
94 | object_colors = object_colors.astype(float) / 255
95 | handles = [
96 | Patch(facecolor=c, edgecolor="grey", label=label)
97 | for label, c in zip(object_labels, object_colors)
98 | ]
99 |
100 | l = ax.legend(
101 | handles=handles,
102 | labels=object_labels,
103 | loc="lower center",
104 | bbox_to_anchor=(0.5, 0),
105 | ncol=4,
106 | facecolor=(0, 0, 0, 0.1),
107 | fontsize=8,
108 | framealpha=0,
109 | )
110 | plt.setp(l.get_texts(), color=(0.8, 0.8, 0.8))
111 | else:
112 | colors = features.astype(float)
113 | if colors.max() > 1.0:
114 | colors /= 255.0
115 | assert colors.max() <= 1.0
116 | # ensure alpha has same dims as colors
117 | if colors.shape[-1] == 4:
118 | colors[:, -1] = alpha
119 | ax.scatter(x, y, z, c=colors, s=pts_size)
120 | if visualize_ghost_points:
121 | x, y, z = np.array(np.unique(xyz, axis=0)).T
122 | ax.scatter(x, y, z, color=[1.0, 1.0, 1.0, 0.02], s=pts_size)
123 |
124 | # Hide axes ticks
125 | ax.set_xticks([])
126 | ax.set_yticks([])
127 | ax.set_zticks([])
128 | ax.axes.set_xlim3d(left=bounds[0][0], right=bounds[1][0])
129 | ax.axes.set_ylim3d(bottom=bounds[0][1], top=bounds[1][1])
130 | ax.axes.set_zlim3d(bottom=bounds[0][2], top=bounds[1][2])
131 | plt.tight_layout(pad=0)
132 | imgs = list(set_view_and_save_img(fig, ax, views))
133 | if show_plot:
134 | plt.show()
135 | if delete_fig:
136 | plt.close(fig)
137 | return imgs
138 |
139 |
140 | # meshes = []
141 | # for class_id in np.unique(features):
142 | # mask = features == class_id
143 | # pcd = o3d.geometry.PointCloud()
144 | # pcd.points = o3d.utility.Vector3dVector(xyz[mask, :])
145 | # pcd.estimate_normals(
146 | # search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30))
147 | # radii = [0.005, 0.01, 0.02, 0.04]
148 | # rec_mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
149 | # pcd, o3d.utility.DoubleVector(radii))
150 | # rec_mesh.paint_uniform_color(object_colors[class_id][:3])
151 | # meshes.append(rec_mesh)
152 | # o3d.visualization.draw_geometries(meshes)
153 |
154 |
155 | def view_tsdf(tsdf, simplify=True):
156 | main_color = "#00000055"
157 | mpl.rcParams["text.color"] = main_color
158 | mpl.rcParams["axes.labelcolor"] = main_color
159 | mpl.rcParams["xtick.color"] = main_color
160 | mpl.rcParams["ytick.color"] = main_color
161 | mpl.rc("axes", edgecolor=main_color)
162 | mpl.rcParams["grid.color"] = "#00000033"
163 |
164 | if simplify:
165 | tsdf = block_reduce(tsdf, block_size=(8, 8, 8), func=np.mean)
166 | print("block reduced", tsdf.shape)
167 |
168 | x = np.arange(tsdf.shape[0])[:, None, None]
169 | y = np.arange(tsdf.shape[1])[None, :, None]
170 | z = np.arange(tsdf.shape[2])[None, None, :]
171 | x, y, z = np.broadcast_arrays(x, y, z)
172 |
173 | c = cm.plasma((tsdf.ravel() + 1))
174 | alphas = (tsdf.ravel() < 0).astype(float)
175 | c[..., -1] = alphas
176 |
177 | fig = plt.figure()
178 |     ax = fig.add_subplot(projection="3d")
179 | ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=c, s=1)
180 | ax.w_xaxis.set_pane_color((0.0, 0.0, 0.0, 0.0))
181 | ax.w_yaxis.set_pane_color((0.0, 0.0, 0.0, 0.0))
182 | ax.w_zaxis.set_pane_color((0.0, 0.0, 0.0, 0.0))
183 |
184 | # Hide axes ticks
185 | ax.tick_params(axis="x", colors=(0.0, 0.0, 0.0, 0.0))
186 | ax.tick_params(axis="y", colors=(0.0, 0.0, 0.0, 0.0))
187 | ax.tick_params(axis="z", colors=(0.0, 0.0, 0.0, 0.0))
188 | ax.view_init(20, -110)
189 |
190 | plt.show()
191 |
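192 | 
193 | # Editorial usage sketch (not part of the original module): render a labelled point
194 | # cloud from two viewpoints. `xyz_pts` (N, 3) and `semantic_ids` (N,) are assumed
195 | # inputs; `semantic_ids` indexes into `labels`.
196 | #
197 | #   labels = np.array(["chair", "table", "empty"])
198 | #   imgs = plot_pointcloud(xyz_pts, semantic_ids, object_labels=labels,
199 | #                          views=[(45, 135), (45, 225)])
200 | #   Image.fromarray(imgs[0]).save("pointcloud.png")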
--------------------------------------------------------------------------------
/point_cloud.py:
--------------------------------------------------------------------------------
1 | import pybullet_data
2 | import numpy as np
3 | from numba import njit, prange
4 | import pybullet as p
5 | import matplotlib.pyplot as plt
6 |
7 |
8 | def transform_pointcloud(xyz_pts, rigid_transform):
9 | """Apply rigid transformation to 3D pointcloud.
10 | Args:
11 | xyz_pts: Nx3 float array of 3D points
12 | rigid_transform: 3x4 or 4x4 float array defining a rigid transformation (rotation and translation)
13 | Returns:
14 | xyz_pts: Nx3 float array of transformed 3D points
15 | """
16 | xyz_pts = np.dot(rigid_transform[:3, :3], xyz_pts.T) # apply rotation
17 | # apply translation
18 | xyz_pts = xyz_pts + np.tile(
19 | rigid_transform[:3, 3].reshape(3, 1), (1, xyz_pts.shape[1])
20 | )
21 | return xyz_pts.T
22 |
23 |
24 | def filter_pts_bounds(xyz, bounds):
25 | mask = xyz[:, 0] >= bounds[0, 0]
26 | mask = np.logical_and(mask, xyz[:, 0] <= bounds[1, 0])
27 | mask = np.logical_and(mask, xyz[:, 1] >= bounds[0, 1])
28 | mask = np.logical_and(mask, xyz[:, 1] <= bounds[1, 1])
29 | mask = np.logical_and(mask, xyz[:, 2] >= bounds[0, 2])
30 | mask = np.logical_and(mask, xyz[:, 2] <= bounds[1, 2])
31 | return mask
32 |
33 |
34 | def get_pointcloud(depth_img, color_img, cam_intr, cam_pose=None):
35 | """Get 3D pointcloud from depth image.
36 |
37 | Args:
38 | depth_img: HxW float array of depth values in meters aligned with color_img
39 | color_img: HxWx3 uint8 array of color image
40 | cam_intr: 3x3 float array of camera intrinsic parameters
41 | cam_pose: (optional) 3x4 float array of camera pose matrix
42 |
43 | Returns:
44 | cam_pts: Nx3 float array of 3D points in camera/world coordinates
45 | color_pts: Nx3 uint8 array of color points
46 | """
47 |
48 | img_h = depth_img.shape[0]
49 | img_w = depth_img.shape[1]
50 |
51 | # Project depth into 3D pointcloud in camera coordinates
52 | pixel_x, pixel_y = np.meshgrid(
53 | np.linspace(0, img_w - 1, img_w), np.linspace(0, img_h - 1, img_h)
54 | )
55 | cam_pts_x = np.multiply(pixel_x - cam_intr[0, 2], depth_img / cam_intr[0, 0])
56 | cam_pts_y = np.multiply(pixel_y - cam_intr[1, 2], depth_img / cam_intr[1, 1])
57 | cam_pts_z = depth_img
58 | cam_pts = (
59 | np.array([cam_pts_x, cam_pts_y, cam_pts_z]).transpose(1, 2, 0).reshape(-1, 3)
60 | )
61 |
62 | if cam_pose is not None:
63 | cam_pts = transform_pointcloud(cam_pts, cam_pose)
64 | color_pts = None if color_img is None else color_img.reshape(-1, 3)
65 | # TODO check memory leak here
66 | return cam_pts, color_pts
67 |
68 |
69 | def project_pts_to_2d(pts, camera_view_matrix, camera_intrisic):
70 | """Project points to 2D.
71 | Args:
72 | pts: Nx3 float array of 3D points in world coordinates.
73 |         camera_view_matrix: 4x4 float array. A world-to-camera (wrd2cam) transformation defining the camera's rotation and translation.
74 |         camera_intrisic: 3x3 float array of camera intrinsics, e.g. [ [f,0,cx],[0,f,cy],[0,0,1] ], where f is the focal length.
75 | Returns:
76 | coord_2d: Nx3 float array of 2D pixel. (w, h, d) the last one is depth
77 | """
78 | pts_c = transform_pointcloud(pts, camera_view_matrix[0:3, :])
79 | rot_algix = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0]])
80 | pts_c = transform_pointcloud(pts_c, rot_algix) # Nx3
81 | coord_2d = np.dot(camera_intrisic, pts_c.T) # 3xN
82 | coord_2d[0:2, :] = coord_2d[0:2, :] / np.tile(coord_2d[2, :], (2, 1))
83 | coord_2d[2, :] = pts_c[:, 2]
84 | coord_2d = np.array([coord_2d[1], coord_2d[0], coord_2d[2]])
85 | return coord_2d.T
86 |
87 |
88 | def check_pts_in_frustum(xyz_pts, depth, cam_pose, cam_intr):
89 | # xyz_pts (N,3)
90 | cam_pts = transform_pointcloud(
91 | xyz_pts=xyz_pts, rigid_transform=np.linalg.inv(cam_pose)
92 | )
93 | cam_pts_x = cam_pts[..., 0]
94 | cam_pts_y = cam_pts[..., 1]
95 | pix_z = cam_pts[..., 2]
96 |
97 | pix_x = (cam_intr[0, 0] / pix_z) * cam_pts_x + cam_intr[0, 2]
98 | pix_y = (cam_intr[1, 1] / pix_z) * cam_pts_y + cam_intr[1, 2]
99 |
100 | # camera to pixel space
101 | h, w = depth.shape
102 |
103 | valid_pix = np.logical_and(
104 | pix_x >= 0,
105 | np.logical_and(
106 | pix_x < w, np.logical_and(pix_y >= 0, np.logical_and(pix_y < h, pix_z > 0))
107 | ),
108 | )
109 | in_frustum_mask = valid_pix.reshape(-1)
110 | return in_frustum_mask
111 |
112 |
113 | def meshwrite(filename, verts, colors, faces=None):
114 | """Save 3D mesh to a polygon .ply file.
115 | Args:
116 | filename: string; path to mesh file. (suffix should be .ply)
117 | verts: [N, 3]. Coordinates of each vertex
118 |         colors: [N, 3]. RGB of each vertex. (type: uint8)
119 | faces: (optional) [M, 4]
120 | """
121 | # Write header
122 | ply_file = open(filename, "w")
123 | ply_file.write("ply\n")
124 | ply_file.write("format ascii 1.0\n")
125 | ply_file.write("element vertex %d\n" % (verts.shape[0]))
126 | ply_file.write("property float x\n")
127 | ply_file.write("property float y\n")
128 | ply_file.write("property float z\n")
129 | ply_file.write("property uchar red\n")
130 | ply_file.write("property uchar green\n")
131 | ply_file.write("property uchar blue\n")
132 | if faces is not None:
133 | ply_file.write("element face %d\n" % (faces.shape[0]))
134 | ply_file.write("end_header\n")
135 |
136 | # Write vertex list
137 | for i in range(verts.shape[0]):
138 | ply_file.write(
139 | "%f %f %f %d %d %d\n"
140 | % (
141 | verts[i, 0],
142 | verts[i, 1],
143 | verts[i, 2],
144 | colors[i, 0],
145 | colors[i, 1],
146 | colors[i, 2],
147 | )
148 | )
149 |
150 | # Write face list
151 | if faces is not None:
152 | for i in range(faces.shape[0]):
153 | ply_file.write(
154 | "4 %d %d %d %d\n" % (faces[i, 0], faces[i, 1], faces[i, 2], faces[i, 3])
155 | )
156 |
157 | ply_file.close()
158 |
159 |
160 | @njit(parallel=True)
161 | def cam2pix(cam_pts, intr):
162 | """Convert camera coordinates to pixel coordinates."""
163 | intr = intr.astype(np.float32)
164 | fx, fy = intr[0, 0], intr[1, 1]
165 | cx, cy = intr[0, 2], intr[1, 2]
166 | pix = np.empty((cam_pts.shape[0], 2), dtype=np.int64)
167 | for i in prange(cam_pts.shape[0]):
168 | pix[i, 0] = int(np.round((cam_pts[i, 0] * fx / cam_pts[i, 2]) + cx))
169 | pix[i, 1] = int(np.round((cam_pts[i, 1] * fy / cam_pts[i, 2]) + cy))
170 | return pix
171 |
172 |
173 | def compute_empty_mask(
174 | scene_bounds, depth_img, intrinsic_matrix, extrinsic_matrix, voxel_resolution=20
175 | ):
176 | # parts taken from
177 | # https://github.com/andyzeng/tsdf-fusion-python/blob/3f22a940d90f684145b1f29b1feaa92e09eb1db6/fusion.py#L170
178 |
179 | # start off all empty
180 | grid_shape = [voxel_resolution] * 3
181 | mask = np.ones(grid_shape).astype(int)
182 | # get volume points
183 | lc = scene_bounds[0]
184 | uc = scene_bounds[1]
185 |
186 | # get voxel indices
187 | grid_idxs = np.stack(
188 | np.meshgrid(*[np.arange(0, dim) for dim in grid_shape]), axis=-1
189 | )
190 |
191 | # voxel indices to world pts
192 | idx_scale = np.array(grid_shape) - 1
193 | scales = (uc - lc) / idx_scale
194 | offsets = lc
195 | grid_points = grid_idxs.astype(float) * scales + offsets
196 |
197 | flattened_grid_points = grid_points.reshape(-1, 3)
198 | print(flattened_grid_points.min(axis=0), flattened_grid_points.max(axis=0))
199 |
200 | # world pts to camera centric frame pts
201 | xyz_h = np.hstack(
202 | [
203 | flattened_grid_points,
204 | np.ones((len(flattened_grid_points), 1), dtype=np.float32),
205 | ]
206 | )
207 | xyz_t_h = np.dot(np.linalg.inv(extrinsic_matrix), xyz_h.T).T
208 | cam_pts = xyz_t_h[:, :3]
209 | pix_z = cam_pts[:, 2]
210 | pix = cam2pix(cam_pts, intrinsic_matrix)
211 | pix_x, pix_y = pix[:, 0], pix[:, 1]
212 |     im_h, im_w = depth_img.shape
213 |
214 | valid_pix = np.logical_and(
215 | pix_x >= 0,
216 | np.logical_and(
217 | pix_x < im_w,
218 | np.logical_and(pix_y >= 0, np.logical_and(pix_y < im_h, pix_z > 0)),
219 | ),
220 | )
221 | inframe_indices = grid_idxs.reshape(-1, 3)[valid_pix, :]
222 |
223 | # depth_val = np.zeros(pix_x.shape)
224 | # depth_val[valid_pix] = depth_img[pix_y[valid_pix], pix_x[valid_pix]]
225 | observed_indices = inframe_indices[
226 | (depth_img[pix_y[valid_pix], pix_x[valid_pix]] > pix_z[valid_pix])
227 | ]
228 |
229 | print("before:", mask.mean(), mask.shape, observed_indices.shape)
230 | for idx in observed_indices:
231 | mask[tuple(idx)] = 0
232 | print(mask.mean())
233 | print(observed_indices.shape, mask.shape)
234 | # mask[observed_indices] = 0
235 | print("after:", mask.mean())
236 |
237 | ax = plt.figure().add_subplot(projection="3d")
238 | ax.voxels(mask)
239 | # pts = grid_points[mask, :]
240 | # ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2])
241 | plt.show()
242 | return mask.astype(bool)
243 |
244 |
245 | def subsample(seg_pts, num_pts, random_state, balanced=True):
246 | probabilities = np.ones(seg_pts.shape).astype(np.float64)
247 | if balanced:
248 | unique_semantic_ids = np.unique(seg_pts)
249 | num_semantic_ids = len(unique_semantic_ids)
250 | for semantic_id in unique_semantic_ids:
251 | mask = seg_pts == semantic_id
252 | probabilities[mask] = 1.0 / (int((mask).sum().item()) * num_semantic_ids)
253 | else:
254 | probabilities /= probabilities.sum()
255 | indices = random_state.choice(
256 | seg_pts.shape[0], size=num_pts, replace=False, p=probabilities
257 | )
258 | return indices
259 |
260 |
261 | if __name__ == "__main__":
262 | # TODO change this to filter input sampled points out based on
263 | # view point
264 | from datagen.simulation.asset import make_object, occluder_objects, partnet_objs
265 | from datagen.simulation import Camera
266 |
267 | object_keys = [k for k in occluder_objects]
268 | object_def = occluder_objects[object_keys[10]]
269 | p.connect(p.GUI)
270 | p.resetDebugVisualizerCamera(
271 | cameraDistance=4.0,
272 | cameraYaw=270,
273 | cameraPitch=-20,
274 | cameraTargetPosition=(0, 0, 0.4),
275 | )
276 | p.setAdditionalSearchPath(pybullet_data.getDataPath())
277 | p.setRealTimeSimulation(False)
278 | p.resetSimulation()
279 | p.setGravity(0, 0, -9.8)
280 | planeUid = p.loadURDF(fileName="plane.urdf", useFixedBase=True)
281 | occluder_obj = make_object(**object_def)
282 | camera = Camera(position=[-1, 1, 1], lookat=[0, 0, 0.5])
283 | view = camera.get_image(return_pose=True, segmentation_mask=True)
284 | mask = compute_empty_mask(
285 | scene_bounds=np.array([[-1.0, -1.0, -0.1], [1.0, 1.0, 1.9]]),
286 | depth_img=view[1],
287 | intrinsic_matrix=view[-2],
288 | extrinsic_matrix=view[-1],
289 | )
290 |
--------------------------------------------------------------------------------
/scene_files/arkit_kitchen.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/scene_files/arkit_kitchen.pkl
--------------------------------------------------------------------------------
/scene_files/arkit_vn_poster.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/scene_files/arkit_vn_poster.pkl
--------------------------------------------------------------------------------
/scene_files/matterport_hallway.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/scene_files/matterport_hallway.pkl
--------------------------------------------------------------------------------
/scene_files/matterport_kitchen.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/scene_files/matterport_kitchen.pkl
--------------------------------------------------------------------------------
/scene_files/matterport_living_room.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/scene_files/matterport_living_room.pkl
--------------------------------------------------------------------------------
/scene_files/walle.pkl:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/scene_files/walle.pkl
--------------------------------------------------------------------------------
/semabs.yml:
--------------------------------------------------------------------------------
1 | name: semabs
2 | channels:
3 | - pyg
4 | - pytorch
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _libgcc_mutex=0.1=main
9 | - _openmp_mutex=5.1=1_gnu
10 | - blas=1.0=mkl
11 | - blosc=1.21.0=h4ff587b_1
12 | - bottleneck=1.3.5=py38h7deecbd_0
13 | - brotli=1.0.9=h5eee18b_7
14 | - brotli-bin=1.0.9=h5eee18b_7
15 | - brotlipy=0.7.0=py38h27cfd23_1003
16 | - brunsli=0.1=h2531618_0
17 | - bzip2=1.0.8=h7f98852_4
18 | - c-ares=1.18.1=h7f8727e_0
19 | - ca-certificates=2022.10.11=h06a4308_0
20 | - certifi=2022.9.24=py38h06a4308_0
21 | - cffi=1.15.0=py38h7f8727e_0
22 | - cfitsio=3.470=h5893167_7
23 | - charls=2.2.0=h2531618_0
24 | - charset-normalizer=2.1.1=pyhd8ed1ab_0
25 | - cloudpickle=2.0.0=pyhd3eb1b0_0
26 | - cryptography=38.0.1=py38h9ce1e76_0
27 | - cudatoolkit=11.6.0=hecad31d_10
28 | - cycler=0.11.0=pyhd3eb1b0_0
29 | - cytoolz=0.12.0=py38h5eee18b_0
30 | - dask-core=2022.7.0=py38h06a4308_0
31 | - ffmpeg=4.3.2=hca11adc_0
32 | - fftw=3.3.9=h27cfd23_1
33 | - fonttools=4.25.0=pyhd3eb1b0_0
34 | - freetype=2.12.1=h4a9f257_0
35 | - fsspec=2022.10.0=py38h06a4308_0
36 | - gettext=0.21.0=hf68c758_0
37 | - giflib=5.2.1=h516909a_2
38 | - gmp=6.2.1=h58526e2_0
39 | - gnutls=3.6.15=he1e5248_0
40 | - icu=58.2=he6710b0_3
41 | - idna=3.4=pyhd8ed1ab_0
42 | - imagecodecs=2021.8.26=py38hf0132c2_1
43 | - importlib-metadata=4.11.3=py38h06a4308_0
44 | - importlib_metadata=4.11.3=hd3eb1b0_0
45 | - intel-openmp=2021.4.0=h06a4308_3561
46 | - jpeg=9e=h7f8727e_0
47 | - jxrlib=1.1=h7b6447c_2
48 | - kiwisolver=1.4.2=py38h295c915_0
49 | - krb5=1.19.2=hac12032_0
50 | - lame=3.100=h7b6447c_0
51 | - lcms2=2.12=h3be6417_0
52 | - lerc=3.0=h9c3ff4c_0
53 | - libaec=1.0.4=he6710b0_1
54 | - libbrotlicommon=1.0.9=h5eee18b_7
55 | - libbrotlidec=1.0.9=h5eee18b_7
56 | - libbrotlienc=1.0.9=h5eee18b_7
57 | - libcurl=7.85.0=h91b91d3_0
58 | - libdeflate=1.8=h7f98852_0
59 | - libedit=3.1.20210910=h7f8727e_0
60 | - libev=4.33=h7f8727e_1
61 | - libffi=3.2.1=hf484d3e_1007
62 | - libgcc-ng=11.2.0=h1234567_1
63 | - libgfortran-ng=11.2.0=h00389a5_1
64 | - libgfortran5=11.2.0=h1234567_1
65 | - libgomp=11.2.0=h1234567_1
66 | - libidn2=2.3.2=h7f8727e_0
67 | - libllvm11=11.1.0=h9e868ea_6
68 | - libnghttp2=1.46.0=hce63b2e_0
69 | - libpng=1.6.37=hed695b0_2
70 | - libssh2=1.10.0=h8f2d780_0
71 | - libstdcxx-ng=11.2.0=h1234567_1
72 | - libtasn1=4.16.0=h27cfd23_0
73 | - libtiff=4.4.0=hecacb30_1
74 | - libunistring=0.9.10=h14c3975_0
75 | - libwebp=1.2.4=h11a3e52_0
76 | - libwebp-base=1.2.4=h5eee18b_0
77 | - libxml2=2.9.14=h74e7548_0
78 | - libzopfli=1.0.3=he6710b0_0
79 | - llvmlite=0.39.1=py38he621ea3_0
80 | - locket=1.0.0=py38h06a4308_0
81 | - lz4-c=1.9.3=h9c3ff4c_1
82 | - matplotlib-base=3.5.3=py38hf590b9c_0
83 | - mkl=2021.4.0=h06a4308_640
84 | - mkl-service=2.4.0=py38h497a2fe_0
85 | - mkl_fft=1.3.1=py38h8666266_1
86 | - mkl_random=1.2.2=py38h1abd341_0
87 | - munkres=1.1.4=py_0
88 | - ncurses=6.3=h5eee18b_3
89 | - nettle=3.7.3=hbbd107a_1
90 | - networkx=2.8.4=py38h06a4308_0
91 | - numba=0.56.3=py38h417a72b_0
92 | - numexpr=2.8.3=py38h807cd23_0
93 | - numpy=1.22.3=py38he7a7128_0
94 | - numpy-base=1.22.3=py38hf524024_0
95 | - openh264=2.1.1=h780b84a_0
96 | - openjpeg=2.4.0=h3ad879b_0
97 | - openssl=1.1.1s=h7f8727e_0
98 | - packaging=21.3=pyhd3eb1b0_0
99 | - pandas=1.4.4=py38h6a678d5_0
100 | - partd=1.2.0=pyhd3eb1b0_1
101 | - pillow=9.2.0=py38hace64e9_1
102 | - pip=22.2.2=py38h06a4308_0
103 | - pycparser=2.21=pyhd8ed1ab_0
104 | - pyopenssl=22.1.0=pyhd8ed1ab_0
105 | - pyparsing=3.0.9=py38h06a4308_0
106 | - pysocks=1.7.1=py38h578d9bd_5
107 | - python=3.8.0=h0371630_2
108 | - python-dateutil=2.8.2=pyhd3eb1b0_0
109 | - python_abi=3.8=2_cp38
110 | - pytorch=1.12.1=py3.8_cuda11.6_cudnn8.3.2_0
111 | - pytorch-mutex=1.0=cuda
112 | - pytorch-scatter=2.0.9=py38_torch_1.12.0_cu116
113 | - pytz=2022.1=py38h06a4308_0
114 | - pywavelets=1.3.0=py38h7f8727e_0
115 | - readline=7.0=h7b6447c_5
116 | - requests=2.28.1=pyhd8ed1ab_1
117 | - scikit-image=0.19.2=py38h51133e4_0
118 | - scipy=1.9.3=py38h14f4228_0
119 | - seaborn=0.12.0=py38h06a4308_0
120 | - setuptools=65.5.0=py38h06a4308_0
121 | - six=1.16.0=pyh6c4a22f_0
122 | - snappy=1.1.9=h295c915_0
123 | - sqlite=3.33.0=h62c20be_0
124 | - tbb=2021.6.0=hdb19cb5_0
125 | - tifffile=2021.7.2=pyhd3eb1b0_2
126 | - tk=8.6.12=h1ccaba5_0
127 | - toolz=0.12.0=py38h06a4308_0
128 | - torchaudio=0.12.1=py38_cu116
129 | - torchvision=0.13.1=py38_cu116
130 | - tqdm=4.64.1=py38h06a4308_0
131 | - typing_extensions=4.4.0=pyha770c72_0
132 | - urllib3=1.26.12=py38h06a4308_0
133 | - wheel=0.37.1=pyhd3eb1b0_0
134 | - x264=1!161.3030=h7f98852_1
135 | - xz=5.2.6=h5eee18b_0
136 | - yaml=0.2.5=h7b6447c_0
137 | - zfp=0.5.5=h2531618_6
138 | - zlib=1.2.13=h5eee18b_0
139 | - zstd=1.5.2=ha4553b6_0
140 | - pip:
141 | - addict==2.4.0
142 | - aiosignal==1.2.0
143 | - asttokens==2.1.0
144 | - attrs==22.1.0
145 | - backcall==0.2.0
146 | - click==8.0.4
147 | - configargparse==1.5.3
148 | - dash==2.7.0
149 | - dash-core-components==2.0.0
150 | - dash-html-components==2.0.0
151 | - dash-table==5.0.0
152 | - debugpy==1.6.3
153 | - decorator==5.1.1
154 | - distlib==0.3.6
155 | - entrypoints==0.4
156 | - executing==1.2.0
157 | - fastjsonschema==2.16.2
158 | - filelock==3.8.0
159 | - flask==2.2.2
160 | - frozenlist==1.3.1
161 | - ftfy==6.1.1
162 | - greenlet==2.0.1
163 | - grpcio==1.43.0
164 | - h5py==3.7.0
165 | - imageio==2.22.4
166 | - imageio-ffmpeg==0.4.7
167 | - importlib-resources==5.10.0
168 | - ipykernel==6.17.0
169 | - ipython==8.6.0
170 | - ipywidgets==8.0.2
171 | - itsdangerous==2.1.2
172 | - jedi==0.18.1
173 | - jinja2==3.1.2
174 | - joblib==1.2.0
175 | - jsonschema==4.17.0
176 | - jupyter-client==7.4.4
177 | - jupyter-core==4.11.2
178 | - jupyterlab-widgets==3.0.3
179 | - markupsafe==2.1.1
180 | - matplotlib-inline==0.1.6
181 | - msgpack==1.0.4
182 | - nbformat==5.5.0
183 | - neovim==0.3.1
184 | - nest-asyncio==1.5.6
185 | - open3d==0.16.0
186 | - opencv-python==4.6.0.66
187 | - parso==0.8.3
188 | - pexpect==4.8.0
189 | - pickleshare==0.7.5
190 | - pkgutil-resolve-name==1.3.10
191 | - platformdirs==2.5.3
192 | - plotly==5.11.0
193 | - prompt-toolkit==3.0.32
194 | - protobuf==3.20.1
195 | - psutil==5.9.3
196 | - ptyprocess==0.7.0
197 | - pure-eval==0.2.2
198 | - pybullet==3.2.5
199 | - pygments==2.13.0
200 | - pynvim==0.4.3
201 | - pyquaternion==0.9.9
202 | - pyrsistent==0.19.2
203 | - pyyaml==6.0
204 | - pyzmq==24.0.1
205 | - ray==2.0.1
206 | - regex==2022.10.31
207 | - rich
208 | - sacremoses==0.0.53
209 | - scikit-learn==1.1.3
210 | - stack-data==0.6.0
211 | - tabulate==0.9.0
212 | - tenacity==8.1.0
213 | - tensorboardx==2.5.1
214 | - threadpoolctl==3.1.0
215 | - tokenizers==0.10.3
216 | - torchtyping==0.1.4
217 | - tornado==6.2
218 | - traitlets==5.5.0
219 | - transformers==4.5.1
220 | - transforms3d==0.4.1
221 | - typeguard==2.13.3
222 | - typer
223 | - virtualenv==20.16.6
224 | - wcwidth==0.2.5
225 | - werkzeug==2.2.2
226 | - widgetsnbextension==4.0.3
227 | - zipp==3.10.0
--------------------------------------------------------------------------------
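Note: the environment above pins pytorch=1.12.1 against cudatoolkit 11.6, with pytorch-scatter built for the same CUDA version. A minimal sanity check after creating and activating the `semabs` environment (a sketch; assumes an NVIDIA GPU with a CUDA 11.6-compatible driver is available):

    import torch
    import torch_scatter  # provided by the pinned pytorch-scatter=2.0.9 package

    assert torch.__version__.startswith("1.12"), torch.__version__
    assert torch.version.cuda.startswith("11.6"), torch.version.cuda
    assert torch.cuda.is_available(), "no CUDA device visible"

--------------------------------------------------------------------------------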
/summarize.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import rich
3 | import pickle
4 | from dataset import synonyms
5 | import numpy as np
6 | from rich.console import Console
7 | from rich.table import Table
8 |
9 | test_objs = set(
10 |     filter(None, map(str.strip, open("test_semantic_classes.txt", "r")))
11 | )
12 |
13 |
14 | def summarize_ovssc(metric="voxel32x32x32_iou"):
15 | ssc_approaches = {
16 | "Semantic Aware": pickle.load(
17 | open("models/semaware/ovssc/ovssc_eval_stats.pkl", "rb")
18 | ),
19 | "SemAbs + [Chefer et al]": pickle.load(
20 | open("models/chefer_et_al/ovssc/ovssc_eval_stats.pkl", "rb")
21 | ),
22 | "Ours": pickle.load(
23 | open(
24 | "models/ours/ovssc/ovssc_eval_stats.pkl",
25 | "rb",
26 | )
27 | ),
28 | }
29 |
30 | ovssc_stats = {
31 | "approach": [],
32 | "novel rooms": [],
33 | "novel visual": [],
34 | "novel vocab": [],
35 | "novel class": [],
36 | }
37 | pd.options.display.float_format = "{:,.3f}".format
38 | for approach, approach_stats in ssc_approaches.items():
39 | # approach_stats = approach_stats[approach_stats.label!='']
40 | approach_stats["room_id"] = approach_stats["scene_id"].apply(
41 | lambda s: int(s.split("_")[0].split("FloorPlan")[1])
42 | )
43 | approach_stats[metric] = approach_stats[metric] * 100
44 | cutoff_analysis = approach_stats.groupby("cutoff")[[metric]].mean()
45 | best_cutoff = cutoff_analysis[metric].idxmax()
46 | df = approach_stats[approach_stats.cutoff == best_cutoff]
47 | novel_class_mask = df.label.isin(test_objs)
48 | novel_vocab_mask = df.label.isin(synonyms.values())
49 | ovssc_stats["approach"].append(approach)
50 |
51 | novel_rooms_df = df[(df.split == "unseen_instances") & (~novel_class_mask)]
52 | mean_per_room = np.array(novel_rooms_df.groupby("room_id")[metric].mean())
53 | ovssc_stats["novel rooms"].append(mean_per_room.mean())
54 |
55 | novel_rooms_dr_df = df[
56 | (df.split == "unseen_instances_dr") & (~novel_class_mask)
57 | ]
58 | mean_per_room = np.array(novel_rooms_dr_df.groupby("room_id")[metric].mean())
59 | ovssc_stats["novel visual"].append(mean_per_room.mean())
60 |
61 | unseen_class_df = df[novel_class_mask]
62 | mean_per_label = unseen_class_df.groupby("label")[metric].mean()
63 | ovssc_stats["novel class"].append(np.array(mean_per_label).mean())
64 |
65 | unseen_vocab_df = df[
66 | (df.split == "unseen_instances_synonyms") & novel_vocab_mask
67 | ]
68 | mean_per_label = unseen_vocab_df.groupby("label")[metric].mean()
69 | ovssc_stats["novel vocab"].append(np.array(mean_per_label).mean())
70 | ovssc_stats = pd.DataFrame.from_dict(ovssc_stats)
71 | table = Table(title="OVSSC THOR", box=rich.box.MINIMAL_DOUBLE_HEAD)
72 | table.add_column("Approach", justify="left")
73 | table.add_column("Novel Room", justify="right")
74 | table.add_column("Novel Visual", justify="right")
75 | table.add_column("Novel Vocab", justify="right")
76 | table.add_column("Novel Class", justify="right")
77 |     for row in ovssc_stats.to_csv().split("\n")[1:-1]:  # skip the CSV header and trailing empty line
78 | approach, novel_room, novel_visual, novel_vocab, novel_class = row.split(",")[
79 | 1:
80 | ]
81 | table.add_row(
82 | approach,
83 | f"{float(novel_room):.01f}",
84 | f"{float(novel_visual):.01f}",
85 | f"{float(novel_vocab):.01f}",
86 | f"{float(novel_class):.01f}",
87 | end_section=approach == "SemAbs + [Chefer et al]",
88 | style="green" if approach == "Ours" else "white",
89 | )
90 | console = Console()
91 | console.print(table)
92 |
93 |
94 | def summarize_vool(metric="voxel32x32x32_iou"):
95 | vool_approaches = {
96 | "Semantic Aware": pickle.load(
97 | open("models/semaware/vool/vool_eval_stats.pkl", "rb")
98 | ),
99 | "ClipSpatial": pickle.load(
100 | open("models/clipspatial/vool/vool_eval_stats.pkl", "rb")
101 | ),
102 | "SemAbs + [Chefer et al]": pickle.load(
103 | open("models/chefer_et_al/vool/vool_eval_stats.pkl", "rb")
104 | ),
105 | "Ours": pickle.load(open("models/ours/vool/vool_eval_stats.pkl", "rb")),
106 | }
107 | vool_stats = {
108 | "approach": [],
109 | "relation": [],
110 | "novel rooms": [],
111 | "novel visual": [],
112 | "novel vocab": [],
113 | "novel class": [],
114 | }
115 | relations = vool_approaches["Ours"].spatial_relation_name.unique()
116 | for approach in vool_approaches.keys():
117 | approach_stats = vool_approaches[approach]
118 | approach_stats["room_id"] = approach_stats["scene_id"].apply(
119 | lambda s: int(s.split("_")[0].split("FloorPlan")[1])
120 | )
121 | cutoff_analysis = approach_stats.groupby("cutoff")[[metric]].mean()
122 | best_cutoff = cutoff_analysis[metric].idxmax()
123 | approach_stats[metric] = approach_stats[metric] * 100
124 | for relation in relations:
125 | if relation == "[pad]":
126 | continue
127 | df = approach_stats[approach_stats.cutoff == best_cutoff]
128 | df = df[df.spatial_relation_name == relation]
129 |
130 | novel_vocab_mask = df.target_obj_name.isin(
131 | synonyms.values()
132 | ) | df.reference_obj_name.isin(synonyms.values())
133 | novel_class_mask = df.target_obj_name.isin(
134 | test_objs
135 | ) | df.reference_obj_name.isin(test_objs)
136 |
137 | vool_stats["approach"].append(approach)
138 | vool_stats["relation"].append(relation)
139 | novel_rooms_df = df[(df.split == "unseen_instances") & (~novel_class_mask)]
140 | mean_per_room = np.array(novel_rooms_df.groupby("room_id")[metric].mean())
141 | vool_stats["novel rooms"].append(np.nanmean(mean_per_room))
142 | novel_rooms_dr_df = df[
143 | (df.split == "unseen_instances_dr") & (~novel_class_mask)
144 | ]
145 | mean_per_room = np.array(
146 | novel_rooms_dr_df.groupby("room_id")[metric].mean()
147 | )
148 | vool_stats["novel visual"].append(np.nanmean(mean_per_room))
149 |
150 | unseen_class_df = df[novel_class_mask]
151 | vool_stats["novel class"].append(np.nanmean(unseen_class_df[metric]))
152 | unseen_vocab_df = df[
153 | (df.split == "unseen_instances_synonyms") & novel_vocab_mask
154 | ]
155 | vool_stats["novel vocab"].append(np.nanmean(unseen_vocab_df[metric]))
156 | vool_stats = pd.DataFrame.from_dict(vool_stats)
157 | for approach_i, approach in enumerate(vool_approaches.keys()):
158 | mean_df = pd.DataFrame.from_dict(
159 | {
160 | "approach": [approach],
161 | "relation": ["mean"],
162 | **{
163 | split: [
164 | np.array(
165 | vool_stats[(vool_stats.approach == approach)][[split]]
166 | ).mean()
167 | ]
168 | for split in [
169 | "novel rooms",
170 | "novel visual",
171 | "novel vocab",
172 | "novel class",
173 | ]
174 | },
175 | }
176 | )
177 | vool_stats = pd.concat(
178 | [
179 |                 vool_stats.iloc[0 : (approach_i + 1) * 6 + approach_i],  # 6 non-pad relations per approach, offset by mean rows already inserted
180 | mean_df,
181 | vool_stats.iloc[(approach_i + 1) * 6 + approach_i :],
182 | ]
183 | )
184 | table = Table(title="FULL VOOL THOR", box=rich.box.MINIMAL_DOUBLE_HEAD)
185 | table.add_column("Approach", justify="left")
186 | table.add_column("Spatial Relation", justify="left")
187 | table.add_column("Novel Room", justify="right")
188 | table.add_column("Novel Visual", justify="right")
189 | table.add_column("Novel Vocab", justify="right")
190 | table.add_column("Novel Class", justify="right")
191 | last_approach = ""
192 | for row in vool_stats.to_csv().split("\n")[1:-1]:
193 | (
194 | approach,
195 | spatial_relation,
196 | novel_room,
197 | novel_visual,
198 | novel_vocab,
199 | novel_class,
200 | ) = row.split(",")[1:]
201 | table.add_row(
202 | approach if approach != last_approach else "",
203 | spatial_relation,
204 | f"{float(novel_room):.01f}",
205 | f"{float(novel_visual):.01f}",
206 | f"{float(novel_vocab):.01f}",
207 | f"{float(novel_class):.01f}",
208 | end_section=spatial_relation == "mean",
209 | style=("green" if approach == "Ours" else "white"),
210 | )
211 | last_approach = approach
212 | console = Console()
213 | console.print(table)
214 |
215 |
216 | def summarize_nyuv2(metric="voxel60x60x60_iou"):
217 | ssc_approaches = {
218 | "Ours (Supervised)": pickle.load(
219 | open(
220 | "models/ours/ovssc/ovssc_eval_stats_supervised_nyu_merged.pkl",
221 | "rb",
222 | )
223 | ),
224 | "Ours (Zeroshot)": pickle.load(
225 | open(
226 | "models/ours/ovssc/ovssc_eval_stats_zs_nyu_merged.pkl",
227 | "rb",
228 | )
229 | ),
230 | }
231 | classes = [
232 | "ceiling",
233 | "floor",
234 | "wall",
235 | "window",
236 | "chair",
237 | "bed",
238 | "sofa",
239 | "table",
240 | "tvs",
241 | "furn",
242 | "objs",
243 | "mean",
244 | ]
245 | table = Table(title="OVSSC NYU", box=rich.box.MINIMAL_DOUBLE_HEAD)
246 | table.add_column("Approach", justify="left")
247 | for c in classes:
248 | table.add_column(c.title(), justify="right")
249 | for approach, approach_stats in ssc_approaches.items():
250 | approach_stats[metric] = approach_stats[metric] * 100
251 | cutoff_analysis = approach_stats.groupby("cutoff")[[metric]].mean()
252 | best_cutoff = cutoff_analysis[metric].idxmax()
253 | df = approach_stats[approach_stats.cutoff == best_cutoff]
254 | row = [approach]
255 | for c in classes:
256 | if c != "mean":
257 | row.append(f"{df[df.label == c][metric].mean():.01f}")
258 | else:
259 | row.append(
260 | f'{np.array(df.groupby("label")[metric].mean()).mean():.01f}'
261 | )
262 | table.add_row(
263 | *row,
264 | end_section=approach == "Ours (Supervised)",
265 | style="green" if approach == "Ours (Zeroshot)" else "white",
266 | )
267 | console = Console()
268 | console.print(table)
269 |
270 |
271 | if __name__ == "__main__":
272 | summarize_ovssc()
273 | summarize_vool()
274 | summarize_nyuv2()
275 |
--------------------------------------------------------------------------------
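Note: the eval-stats pickles loaded in summarize.py above are pandas DataFrames with one row per (scene, label, cutoff) combination and the same metric columns used here. A quick way to inspect one before summarizing (a sketch; assumes the models/ours/ovssc/ovssc_eval_stats.pkl path referenced in summarize_ovssc exists locally):

    import pickle

    with open("models/ours/ovssc/ovssc_eval_stats.pkl", "rb") as f:
        stats = pickle.load(f)  # pandas DataFrame
    print(stats.columns.tolist())  # expect scene_id, label, split, cutoff, voxel32x32x32_iou, ...
    # the same cutoff sweep summarize_ovssc uses to pick the best threshold
    print(stats.groupby("cutoff")["voxel32x32x32_iou"].mean())

--------------------------------------------------------------------------------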
/test_semantic_classes.txt:
--------------------------------------------------------------------------------
1 | pot
2 | mug
3 | safe
4 | teddy bear
5 | basket ball
6 | wine bottle
7 |
--------------------------------------------------------------------------------
/test_thor_rooms.txt:
--------------------------------------------------------------------------------
1 | FloorPlan26_physics
2 | FloorPlan27_physics
3 | FloorPlan28_physics
4 | FloorPlan29_physics
5 | FloorPlan30_physics
6 | FloorPlan226_physics
7 | FloorPlan227_physics
8 | FloorPlan228_physics
9 | FloorPlan229_physics
10 | FloorPlan230_physics
11 | FloorPlan326_physics
12 | FloorPlan327_physics
13 | FloorPlan328_physics
14 | FloorPlan329_physics
15 | FloorPlan330_physics
16 | FloorPlan426_physics
17 | FloorPlan427_physics
18 | FloorPlan428_physics
19 | FloorPlan429_physics
20 | FloorPlan430_physics
21 |
--------------------------------------------------------------------------------
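Note: these held-out rooms follow the standard AI2-THOR FloorPlan numbering (kitchens 1-30, living rooms 201-230, bedrooms 301-330, bathrooms 401-430), i.e. the last five rooms of each type. summarize.py recovers the numeric id with int(s.split("_")[0].split("FloorPlan")[1]); a hypothetical helper for mapping that id back to a room type could look like:

    def thor_room_type(room_id: int) -> str:
        # assumes the standard AI2-THOR FloorPlan numbering described above
        if room_id <= 30:
            return "kitchen"
        if room_id <= 230:
            return "living room"
        if room_id <= 330:
            return "bedroom"
        return "bathroom"

    # e.g. thor_room_type(226) -> "living room", thor_room_type(430) -> "bathroom"

--------------------------------------------------------------------------------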
/train_ovssc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch.nn.functional import binary_cross_entropy_with_logits
4 | from net import SemAbs3D, SemanticAwareOVSSC
5 | import utils
6 | import pandas as pd
7 | from dataset import SceneCompletionDataset
8 | from typing import Dict, Tuple, Union
9 |
10 |
11 | def get_detailed_stats(
12 | prediction,
13 | gt_label,
14 | xyz_pts,
15 | patch_labels,
16 | scene_ids,
17 | scene_bounds,
18 | ignore_pts,
19 | detailed_analysis=False,
20 | eval_device="cuda",
21 | **kwargs,
22 | ):
23 | num_scenes, num_patches = patch_labels.shape
24 | retvals = {
25 | "scene_id": np.array([[scene_id] * num_patches for scene_id in scene_ids])
26 | .reshape(-1)
27 | .tolist(),
28 | "label": patch_labels.reshape(-1).tolist(),
29 | }
30 | retvals.update(
31 | {
32 | f"point_{k}": v
33 | for k, v in utils.prediction_analysis(
34 | prediction=prediction.to(eval_device),
35 | label=gt_label.to(eval_device),
36 | ignore=ignore_pts.to(eval_device),
37 | ).items()
38 | }
39 | )
40 | voxelized_pts = utils.voxelize_points(
41 | prediction=prediction,
42 | label=gt_label,
43 | xyz_pts=xyz_pts,
44 | voxel_shape=(32, 32, 32),
45 | scene_bounds=scene_bounds,
46 | ignore_pts=ignore_pts,
47 | )
48 | retvals.update(
49 | {
50 | "voxel32x32x32_" + k: v
51 | for k, v in utils.prediction_analysis(
52 | **{k: v.to(eval_device) for k, v in voxelized_pts.items()}
53 | ).items()
54 | }
55 | )
56 | if detailed_analysis:
57 | voxelized_pts = utils.voxelize_points(
58 | prediction=prediction,
59 | label=gt_label,
60 | xyz_pts=xyz_pts,
61 | voxel_shape=(64, 64, 64),
62 | scene_bounds=scene_bounds,
63 | ignore_pts=ignore_pts,
64 | )
65 | retvals.update(
66 | {
67 | "voxel64x64x64_" + k: v
68 | for k, v in utils.prediction_analysis(
69 | **{k: v.to(eval_device) for k, v in voxelized_pts.items()}
70 | ).items()
71 | }
72 | )
73 | for i, label in enumerate(patch_labels.reshape(-1).tolist()):
74 | if label == "": # skip padding classes
75 | for k in retvals.keys():
76 | if "voxel" in k or "point" in k:
77 |                     retvals[k][i] = np.nan
78 | return pd.DataFrame.from_dict(retvals)
79 |
80 |
81 | def get_losses(
82 | net,
83 | batch: dict,
84 | cutoffs=[0],
85 | balance_positive_negative: bool = False,
86 | **kwargs,
87 | ) -> Tuple[Dict[str, Union[float, torch.Tensor]], pd.DataFrame]:
88 | stats = {}
89 | num_pts = batch["output_xyz_pts"].shape[2]
90 | if num_pts <= 500000:
91 | outputs = net(**batch)
92 | else:
93 | num_patches = 1
94 |         # a single forward pass would likely exhaust GPU memory; run the patches through in chunks instead
95 | outputs = torch.cat(
96 | [
97 | net(
98 | **{
99 | **batch,
100 | "input_feature_pts": batch["input_feature_pts"][
101 | :, patch_i * num_patches : (patch_i + 1) * num_patches, ...
102 | ]
103 | if batch["input_feature_pts"].shape[1]
104 | == batch["output_xyz_pts"].shape[1]
105 | else batch["input_feature_pts"],
106 | "output_xyz_pts": batch["output_xyz_pts"][
107 | :, patch_i * num_patches : (patch_i + 1) * num_patches, ...
108 | ],
109 | "semantic_class_features": batch["semantic_class_features"][
110 | :, patch_i * num_patches : (patch_i + 1) * num_patches, ...
111 | ],
112 | }
113 | )
114 | for patch_i in range(len(batch["patch_labels"]) // num_patches + 1)
115 | if np.prod(
116 | batch["output_xyz_pts"][
117 | :, patch_i * num_patches : (patch_i + 1) * num_patches, ...
118 | ].shape
119 | )
120 | > 0
121 | ],
122 | dim=1,
123 | )
124 |
125 | batch["patch_labels"] = np.array(batch["patch_labels"]).T
126 | padding_mask = torch.from_numpy(batch["patch_labels"] == "").bool()
127 | batch["out_of_bounds_pts"] = batch["out_of_bounds_pts"].view(outputs.shape)
128 | ignore_pts_mask = torch.zeros_like(outputs).bool()
129 | # ignore all padding labels
130 | ignore_pts_mask[padding_mask] = True
131 | # ignore all points out of bounds
132 | ignore_pts_mask = torch.logical_or(ignore_pts_mask, batch["out_of_bounds_pts"])
133 | # don't eval on points outside of frustum
134 | ignore_pts_mask = torch.logical_or(
135 | ignore_pts_mask, batch["out_of_frustum_pts_mask"]
136 | )
137 | stats["loss"] = binary_cross_entropy_with_logits(
138 | outputs[~ignore_pts_mask],
139 | batch["output_label_pts"][~ignore_pts_mask],
140 | weight=utils.get_bce_weight(
141 | output_label_pts=batch["output_label_pts"],
142 | balance_positive_negative=balance_positive_negative,
143 | )[~ignore_pts_mask],
144 | )
145 | with torch.no_grad():
146 | vision_accuracy_mask = (
147 | (outputs > 0.0).long() == batch["output_label_pts"]
148 | ).float()
149 | stats["accuracy"] = vision_accuracy_mask[~ignore_pts_mask].mean()
150 | detailed_stats = [
151 | get_detailed_stats(
152 | prediction=outputs > cutoff,
153 | gt_label=batch["output_label_pts"].bool(),
154 | xyz_pts=batch["output_xyz_pts"],
155 | ignore_pts=ignore_pts_mask,
156 | patch_labels=batch["patch_labels"],
157 | scene_ids=batch["scene_id"],
158 | eval_device=net.device,
159 | **kwargs,
160 | )
161 | for cutoff in cutoffs
162 | ]
163 | for detailed_stat, cutoff in zip(detailed_stats, cutoffs):
164 | detailed_stat["cutoff"] = [cutoff] * len(detailed_stat)
165 | detailed_stats = pd.concat(detailed_stats)
166 | for k in detailed_stats.columns:
167 | if "iou" in k:
168 | stats[k] = detailed_stats[k].mean()
169 | return stats, detailed_stats
170 |
171 |
172 | approach = {
173 | "semantic_abstraction": SemAbs3D,
174 | "semantic_aware": SemanticAwareOVSSC,
175 | }
176 |
177 |
178 | if __name__ == "__main__":
179 | parser = utils.config_parser()
180 | parser.add_argument("--log", type=str, required=True)
181 | parser.add_argument(
182 | "--approach", choices=approach.keys(), default="semantic_abstraction"
183 | )
184 | args = parser.parse_args()
185 | if args.approach == "semantic_aware":
186 | args.network_inputs = ["rgb"]
187 | utils.train(
188 | get_losses_fn=get_losses,
189 | **utils.setup_experiment(
190 | args=args,
191 | ddp=len(args.gpus) > 1,
192 | net_class=approach[args.approach],
193 | dataset_class=SceneCompletionDataset,
194 | split_file_path=args.file_path + "/ssc_split.pkl",
195 | ),
196 | **vars(args),
197 | )
198 |
--------------------------------------------------------------------------------
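Note: the num_pts > 500000 fallback in get_losses above runs the network on slices of the patch dimension and concatenates the results along dim=1, which matches a single full forward pass as long as the network treats patches independently (as the chunking implies). A toy illustration of that pattern, using a stand-in function rather than the project's network:

    import torch

    x = torch.randn(2, 7, 5)    # toy (batch, patches, features) tensor
    f = lambda t: t * 2.0       # stand-in for the per-patch forward pass
    chunk = 2

    full = f(x)
    chunked = torch.cat(
        [f(x[:, i : i + chunk]) for i in range(0, x.shape[1], chunk)],
        dim=1,
    )
    assert torch.allclose(full, chunked)

--------------------------------------------------------------------------------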
/train_vool.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Tuple, Union
2 | import numpy as np
3 | from dataset import ObjectLocalizationDataset
4 | from net import (
5 | SemAbsVOOL,
6 | ClipSpatialVOOL,
7 | SemanticAwareVOOL,
8 | )
9 | import utils
10 | from torch.nn.functional import binary_cross_entropy_with_logits
11 | import torch
12 | import pandas as pd
13 |
14 |
15 | def get_detailed_stats(
16 | prediction,
17 | gt_label,
18 | xyz_pts,
19 | scene_ids,
20 | target_obj_names,
21 | reference_obj_names,
22 | spatial_relation_names,
23 | scene_bounds,
24 | ignore_pts,
25 | detailed_analysis=False,
26 | eval_device="cuda",
27 | **kwargs,
28 | ):
29 | num_scenes, num_descs = gt_label.shape[:2]
30 |
31 | retvals = {
32 | "scene_id": np.array([[scene_id] * num_descs for scene_id in scene_ids])
33 | .reshape(-1)
34 | .tolist(),
35 | "target_obj_name": np.array(target_obj_names).T.reshape(-1).tolist(),
36 | "reference_obj_name": np.array(reference_obj_names).T.reshape(-1).tolist(),
37 | "spatial_relation_name": np.array(spatial_relation_names)
38 | .T.reshape(-1)
39 | .tolist(),
40 | }
41 | retvals.update(
42 | {
43 | f"point_{k}": v
44 | for k, v in utils.prediction_analysis(
45 | prediction=prediction.to(eval_device),
46 | label=gt_label.to(eval_device),
47 | ignore=ignore_pts.to(eval_device),
48 | ).items()
49 | }
50 | )
51 | num_desc_b = 10
52 | outputs = []
53 |     for i in np.arange(0, num_descs + num_desc_b + 1, num_desc_b):  # voxelize num_desc_b descriptions at a time; empty tail slices are skipped below
54 | if np.prod(prediction[:, i : i + num_desc_b].shape) == 0:
55 | continue
56 | outputs.append(
57 | utils.voxelize_points(
58 | prediction=prediction[:, i : i + num_desc_b],
59 | label=gt_label[:, i : i + num_desc_b],
60 | xyz_pts=xyz_pts[:, i : i + num_desc_b],
61 | voxel_shape=(32, 32, 32),
62 | scene_bounds=scene_bounds,
63 | ignore_pts=ignore_pts[:, i : i + num_desc_b],
64 | device=eval_device,
65 | )
66 | )
67 | voxelized_pts = {
68 | k: torch.cat([output[k] for output in outputs], dim=1)
69 | for k in outputs[0].keys()
70 | }
71 | retvals.update(
72 | {
73 | "voxel32x32x32_" + k: v
74 | for k, v in utils.prediction_analysis(
75 | **{k: v.to(eval_device) for k, v in voxelized_pts.items()}
76 | ).items()
77 | }
78 | )
79 | if detailed_analysis:
80 | outputs = []
81 | for i in np.arange(0, num_descs + num_desc_b + 1, num_desc_b):
82 | if np.prod(prediction[:, i : i + num_desc_b].shape) == 0:
83 | continue
84 | outputs.append(
85 | utils.voxelize_points(
86 | prediction=prediction[:, i : i + num_desc_b],
87 | label=gt_label[:, i : i + num_desc_b],
88 | xyz_pts=xyz_pts[:, i : i + num_desc_b],
89 | voxel_shape=(64, 64, 64),
90 | scene_bounds=scene_bounds,
91 | ignore_pts=ignore_pts[:, i : i + num_desc_b],
92 | device=eval_device,
93 | )
94 | )
95 | voxelized_pts = {
96 | k: torch.cat([output[k] for output in outputs], dim=1)
97 | for k in outputs[0].keys()
98 | }
99 | retvals.update(
100 | {
101 | "voxel64x64x64_" + k: v
102 | for k, v in utils.prediction_analysis(
103 | **{k: v.to(eval_device) for k, v in voxelized_pts.items()}
104 | ).items()
105 | }
106 | )
107 |
108 | for i, spatial_relation in enumerate(
109 | np.array(spatial_relation_names).T.reshape(-1)
110 | ):
111 | if spatial_relation == "[pad]": # skip padding classes
112 | for k in retvals.keys():
113 | if "voxel" in k or "point" in k:
114 |                     retvals[k][i] = np.nan
115 | return pd.DataFrame.from_dict(retvals)
116 |
117 |
118 | def get_losses(
119 | net, batch: dict, cutoffs=[-2.0], balance_positive_negative: bool = False, **kwargs
120 | ) -> Tuple[Dict[str, Union[float, torch.Tensor]], pd.DataFrame]:
121 | stats = {}
122 | batch_size, total_num_descs, num_pts = batch["output_label_pts"].shape
123 | if num_pts <= 500000:
124 | outputs = net(**batch)
125 | else:
126 | num_descs = 1
127 |         # a single forward pass would likely exhaust GPU memory; run the descriptions through in chunks instead
128 | outputs = torch.cat(
129 | [
130 | net(
131 | **{
132 | **batch,
133 | "input_target_saliency_pts": batch["input_target_saliency_pts"][
134 | :, desc_i * num_descs : (desc_i + 1) * num_descs, ...
135 | ],
136 | "input_reference_saliency_pts": batch[
137 | "input_reference_saliency_pts"
138 | ][:, desc_i * num_descs : (desc_i + 1) * num_descs, ...],
139 | "input_description_saliency_pts": batch[
140 | "input_description_saliency_pts"
141 | ][:, desc_i * num_descs : (desc_i + 1) * num_descs, ...],
142 | "output_xyz_pts": batch["output_xyz_pts"][
143 | :, desc_i * num_descs : (desc_i + 1) * num_descs, ...
144 | ],
145 | "spatial_relation_name": (
146 | np.array(batch["spatial_relation_name"])
147 | .T[:, desc_i * num_descs : (desc_i + 1) * num_descs]
148 | .T
149 | ),
150 | }
151 | )
152 | for desc_i in range(total_num_descs // num_descs + 1)
153 | if np.prod(
154 | batch["output_xyz_pts"][
155 | :, desc_i * num_descs : (desc_i + 1) * num_descs, ...
156 | ].shape
157 | )
158 | > 0
159 | ],
160 | dim=1,
161 | )
162 |
163 | padding_mask = torch.from_numpy(
164 | np.array(batch["spatial_relation_name"]).T == "[pad]"
165 | ).bool()
166 | ignore_pts_mask = torch.zeros_like(outputs).bool()
167 | # ignore all padding labels
168 | ignore_pts_mask[padding_mask] = True
169 | # ignore all points out of bounds
170 | ignore_pts_mask = torch.logical_or(ignore_pts_mask, batch["out_of_bounds_pts"])
171 | stats["loss"] = binary_cross_entropy_with_logits(
172 | outputs,
173 | batch["output_label_pts"],
174 | weight=utils.get_bce_weight(
175 | output_label_pts=batch["output_label_pts"],
176 | balance_positive_negative=balance_positive_negative,
177 | ),
178 | )
179 |
180 | with torch.no_grad():
181 | accuracy = ((outputs > 0.0).long() == batch["output_label_pts"]).float()[
182 | ~ignore_pts_mask
183 | ]
184 | stats["accuracy"] = accuracy.mean()
185 | detailed_stats = [
186 | get_detailed_stats(
187 | prediction=outputs > cutoff,
188 | gt_label=batch["output_label_pts"].bool(),
189 | xyz_pts=batch["output_xyz_pts"],
190 | ignore_pts=ignore_pts_mask,
191 | target_obj_names=batch["target_obj_name"],
192 | reference_obj_names=batch["reference_obj_name"],
193 | spatial_relation_names=batch["spatial_relation_name"],
194 | scene_ids=batch["scene_id"],
195 | eval_device=net.device,
196 | **kwargs,
197 | )
198 | for cutoff in cutoffs
199 | ]
200 | for detailed_stat, cutoff in zip(detailed_stats, cutoffs):
201 | detailed_stat["cutoff"] = [cutoff] * len(detailed_stat)
202 | detailed_stats = pd.concat(detailed_stats)
203 | for k in detailed_stats.columns:
204 | if "iou" in k:
205 | stats[k] = detailed_stats[k].mean()
206 | return stats, detailed_stats
207 |
208 |
209 | approach = {
210 | "semantic_abstraction": SemAbsVOOL,
211 | "semantic_aware": SemanticAwareVOOL,
212 | "clip_spatial": ClipSpatialVOOL,
213 | }
214 |
215 | if __name__ == "__main__":
216 | parser = utils.config_parser()
217 | parser.add_argument("--log", type=str, required=True)
218 | parser.add_argument(
219 | "--approach", choices=approach.keys(), default="semantic_abstraction"
220 | )
221 | args = parser.parse_args()
222 | if args.approach == "semantic_aware":
223 | args.network_inputs = ["rgb"]
224 | utils.train(
225 | get_losses_fn=get_losses,
226 | **utils.setup_experiment(
227 | args=args,
228 | net_class=approach[args.approach],
229 | dataset_class=ObjectLocalizationDataset,
230 | split_file_path=args.file_path + "/vool_split.pkl",
231 | ),
232 | **vars(args),
233 | )
234 |
--------------------------------------------------------------------------------
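Note: both training scripts build an ignore_pts_mask from padded entries plus out-of-bounds (and, for OVSSC, out-of-frustum) points; the mask excludes those points from the reported accuracy, and in train_ovssc.py also from the BCE loss. A minimal illustration of that masking pattern with toy tensors (not the project's data layout):

    import torch
    from torch.nn.functional import binary_cross_entropy_with_logits

    logits = torch.randn(2, 3, 8)                       # toy (batch, descriptions, points)
    labels = torch.randint(0, 2, logits.shape).float()
    ignore = torch.zeros_like(logits).bool()
    ignore[:, -1] = True                                # e.g. a padded description slot

    loss = binary_cross_entropy_with_logits(logits[~ignore], labels[~ignore])
    accuracy = ((logits > 0.0).float() == labels).float()[~ignore].mean()

--------------------------------------------------------------------------------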