├── .gitignore ├── CLIP ├── .github │ └── workflows │ │ └── test.yml ├── .gitignore ├── CLIP.png ├── LICENSE ├── MANIFEST.in ├── README.md ├── clip │ ├── __init__.py │ ├── auxiliary.py │ ├── bpe_simple_vocab_16e6.txt │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── clip.py │ ├── clip_explainability.py │ ├── clip_gradcam.py │ ├── model.py │ ├── model_explainability.py │ └── simple_tokenizer.py ├── model-card.md ├── requirements.txt ├── setup.py └── tests │ └── test_consistency.py ├── LICENSE ├── README.md ├── arm ├── LICENSE ├── __init__.py ├── network_utils.py ├── optim │ ├── LICENSE │ ├── __init__.py │ └── lamb.py └── utils.py ├── assets ├── hair_dryer_behind_table.gif ├── hair_dryer_completion.gif ├── hair_dryer_completion_blender.gif ├── hair_dryer_completion_blender_legend.png ├── hair_dryer_scene.png ├── matterport.png ├── teaser.gif └── vn_poster_relevancies.png ├── datagen ├── README.md └── assets │ └── unity-build.png ├── dataset.py ├── eval.py ├── fusion.py ├── generate_relevancy.py ├── generate_thor_data.py ├── grads.png ├── matterport.png ├── net.py ├── plot_utils.py ├── point_cloud.py ├── scene_files ├── arkit_kitchen.pkl ├── arkit_vn_poster.pkl ├── matterport_hallway.pkl ├── matterport_kitchen.pkl ├── matterport_living_room.pkl └── walle.pkl ├── semabs.yml ├── summarize.py ├── test_semantic_classes.txt ├── test_thor_rooms.txt ├── train_ovssc.py ├── train_vool.py ├── unet3d.py ├── utils.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- 1 | crv04 2 | crv05 3 | *.pth 4 | *.ply 5 | *.obj 6 | *.jpg 7 | *.blend 8 | *.blend1 9 | tags 10 | log* 11 | temp* 12 | *.hdf5 13 | *.pkl 14 | *.lock 15 | .vscode/ 16 | logs/ 17 | *cache* 18 | 19 | # Byte-compiled / optimized / DLL files 20 | __pycache__/ 21 | *.py[cod] 22 | *$py.class 23 | 24 | # C extensions 25 | *.so 26 | 27 | # Distribution / packaging 28 | .Python 29 | build/ 30 | develop-eggs/ 31 | dist/ 32 | downloads/ 33 | eggs/ 34 | .eggs/ 35 | lib/ 36 | lib64/ 37 | parts/ 38 | sdist/ 39 | var/ 40 | wheels/ 41 | pip-wheel-metadata/ 42 | share/python-wheels/ 43 | *.egg-info/ 44 | .installed.cfg 45 | *.egg 46 | MANIFEST 47 | 48 | # PyInstaller 49 | # Usually these files are written by a python script from a template 50 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 51 | *.manifest 52 | *.spec 53 | 54 | # Installer logs 55 | pip-log.txt 56 | pip-delete-this-directory.txt 57 | 58 | # Unit test / coverage reports 59 | htmlcov/ 60 | .tox/ 61 | .nox/ 62 | .coverage 63 | .coverage.* 64 | .cache 65 | nosetests.xml 66 | coverage.xml 67 | *.cover 68 | *.py,cover 69 | .hypothesis/ 70 | .pytest_cache/ 71 | 72 | # Translations 73 | *.mo 74 | *.pot 75 | 76 | # Django stuff: 77 | *.log 78 | local_settings.py 79 | db.sqlite3 80 | db.sqlite3-journal 81 | 82 | # Flask stuff: 83 | instance/ 84 | .webassets-cache 85 | 86 | # Scrapy stuff: 87 | .scrapy 88 | 89 | # Sphinx documentation 90 | docs/_build/ 91 | 92 | # PyBuilder 93 | target/ 94 | 95 | # Jupyter Notebook 96 | .ipynb_checkpoints 97 | 98 | # IPython 99 | profile_default/ 100 | ipython_config.py 101 | 102 | # pyenv 103 | .python-version 104 | 105 | # pipenv 106 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 107 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 108 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 109 | # install all needed dependencies. 
110 | #Pipfile.lock 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | -------------------------------------------------------------------------------- /CLIP/.github/workflows/test.yml: -------------------------------------------------------------------------------- 1 | name: test 2 | on: 3 | push: 4 | branches: 5 | - main 6 | pull_request: 7 | branches: 8 | - main 9 | jobs: 10 | CLIP-test: 11 | runs-on: ubuntu-latest 12 | strategy: 13 | matrix: 14 | python-version: [3.7, 3.8] 15 | pytorch-version: [1.7.1, 1.9.0] 16 | steps: 17 | - uses: conda-incubator/setup-miniconda@v2 18 | - run: conda install -n test python=${{ matrix.python-version }} pytorch=${{ matrix.pytorch-version }} torchvision cpuonly -c pytorch 19 | - uses: actions/checkout@v2 20 | - run: echo "$CONDA/envs/test/bin" >> $GITHUB_PATH 21 | - run: pip install pytest 22 | - run: pip install . 23 | - run: pytest 24 | -------------------------------------------------------------------------------- /CLIP/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | *.py[cod] 3 | *$py.class 4 | *.egg-info 5 | .pytest_cache 6 | .ipynb_checkpoints 7 | 8 | thumbs.db 9 | .DS_Store 10 | .idea 11 | -------------------------------------------------------------------------------- /CLIP/CLIP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/CLIP/CLIP.png -------------------------------------------------------------------------------- /CLIP/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 OpenAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | 23 | -------------------------------------------------------------------------------- /CLIP/MANIFEST.in: -------------------------------------------------------------------------------- 1 | include clip/bpe_simple_vocab_16e6.txt.gz 2 | -------------------------------------------------------------------------------- /CLIP/README.md: -------------------------------------------------------------------------------- 1 | # CLIP 2 | 3 | [[Blog]](https://openai.com/blog/clip/) [[Paper]](https://arxiv.org/abs/2103.00020) [[Model Card]](model-card.md) [[Colab]](https://colab.research.google.com/github/openai/clip/blob/master/notebooks/Interacting_with_CLIP.ipynb) 4 | 5 | CLIP (Contrastive Language-Image Pre-Training) is a neural network trained on a variety of (image, text) pairs. It can be instructed in natural language to predict the most relevant text snippet, given an image, without directly optimizing for the task, similarly to the zero-shot capabilities of GPT-2 and 3. We found CLIP matches the performance of the original ResNet50 on ImageNet “zero-shot” without using any of the original 1.28M labeled examples, overcoming several major challenges in computer vision. 6 | 7 | 8 | 9 | ## Approach 10 | 11 | ![CLIP](CLIP.png) 12 | 13 | 14 | 15 | ## Usage 16 | 17 | First, [install PyTorch 1.7.1](https://pytorch.org/get-started/locally/) and torchvision, as well as small additional dependencies, and then install this repo as a Python package. On a CUDA GPU machine, the following will do the trick: 18 | 19 | ```bash 20 | $ conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=11.0 21 | $ pip install ftfy regex tqdm 22 | $ pip install git+https://github.com/openai/CLIP.git 23 | ``` 24 | 25 | Replace `cudatoolkit=11.0` above with the appropriate CUDA version on your machine or `cpuonly` when installing on a machine without a GPU. 26 | 27 | ```python 28 | import torch 29 | import clip 30 | from PIL import Image 31 | 32 | device = "cuda" if torch.cuda.is_available() else "cpu" 33 | model, preprocess = clip.load("ViT-B/32", device=device) 34 | 35 | image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device) 36 | text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device) 37 | 38 | with torch.no_grad(): 39 | image_features = model.encode_image(image) 40 | text_features = model.encode_text(text) 41 | 42 | logits_per_image, logits_per_text = model(image, text) 43 | probs = logits_per_image.softmax(dim=-1).cpu().numpy() 44 | 45 | print("Label probs:", probs) # prints: [[0.9927937 0.00421068 0.00299572]] 46 | ``` 47 | 48 | 49 | ## API 50 | 51 | The CLIP module `clip` provides the following methods: 52 | 53 | #### `clip.available_models()` 54 | 55 | Returns the names of the available CLIP models. 56 | 57 | #### `clip.load(name, device=..., jit=False)` 58 | 59 | Returns the model and the TorchVision transform needed by the model, specified by the model name returned by `clip.available_models()`. It will download the model as necessary. The `name` argument can also be a path to a local checkpoint. 60 | 61 | The device to run the model can be optionally specified, and the default is to use the first CUDA device if there is any, otherwise the CPU. When `jit` is `False`, a non-JIT version of the model will be loaded. 62 | 63 | #### `clip.tokenize(text: Union[str, List[str]], context_length=77)` 64 | 65 | Returns a LongTensor containing tokenized sequences of given text input(s). 
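For example, a minimal sketch of tokenizing a few prompts (shapes assume the default `context_length=77`):

```python
import clip

# tokenize a batch of prompts into a LongTensor of token ids
tokens = clip.tokenize(["a diagram", "a dog", "a cat"])
print(tokens.shape)  # torch.Size([3, 77]) -- one row per prompt, padded to the context length
```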
This can be used as the input to the model 66 | 67 | --- 68 | 69 | The model returned by `clip.load()` supports the following methods: 70 | 71 | #### `model.encode_image(image: Tensor)` 72 | 73 | Given a batch of images, returns the image features encoded by the vision portion of the CLIP model. 74 | 75 | #### `model.encode_text(text: Tensor)` 76 | 77 | Given a batch of text tokens, returns the text features encoded by the language portion of the CLIP model. 78 | 79 | #### `model(image: Tensor, text: Tensor)` 80 | 81 | Given a batch of images and a batch of text tokens, returns two Tensors, containing the logit scores corresponding to each image and text input. The values are cosine similarities between the corresponding image and text features, times 100. 82 | 83 | 84 | 85 | ## More Examples 86 | 87 | ### Zero-Shot Prediction 88 | 89 | The code below performs zero-shot prediction using CLIP, as shown in Appendix B in the paper. This example takes an image from the [CIFAR-100 dataset](https://www.cs.toronto.edu/~kriz/cifar.html), and predicts the most likely labels among the 100 textual labels from the dataset. 90 | 91 | ```python 92 | import os 93 | import clip 94 | import torch 95 | from torchvision.datasets import CIFAR100 96 | 97 | # Load the model 98 | device = "cuda" if torch.cuda.is_available() else "cpu" 99 | model, preprocess = clip.load('ViT-B/32', device) 100 | 101 | # Download the dataset 102 | cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False) 103 | 104 | # Prepare the inputs 105 | image, class_id = cifar100[3637] 106 | image_input = preprocess(image).unsqueeze(0).to(device) 107 | text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device) 108 | 109 | # Calculate features 110 | with torch.no_grad(): 111 | image_features = model.encode_image(image_input) 112 | text_features = model.encode_text(text_inputs) 113 | 114 | # Pick the top 5 most similar labels for the image 115 | image_features /= image_features.norm(dim=-1, keepdim=True) 116 | text_features /= text_features.norm(dim=-1, keepdim=True) 117 | similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1) 118 | values, indices = similarity[0].topk(5) 119 | 120 | # Print the result 121 | print("\nTop predictions:\n") 122 | for value, index in zip(values, indices): 123 | print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%") 124 | ``` 125 | 126 | The output will look like the following (the exact numbers may be slightly different depending on the compute device): 127 | 128 | ``` 129 | Top predictions: 130 | 131 | snake: 65.31% 132 | turtle: 12.29% 133 | sweet_pepper: 3.83% 134 | lizard: 1.88% 135 | crocodile: 1.75% 136 | ``` 137 | 138 | Note that this example uses the `encode_image()` and `encode_text()` methods that return the encoded features of given inputs. 139 | 140 | 141 | ### Linear-probe evaluation 142 | 143 | The example below uses [scikit-learn](https://scikit-learn.org/) to perform logistic regression on image features. 
144 | 145 | ```python 146 | import os 147 | import clip 148 | import torch 149 | 150 | import numpy as np 151 | from sklearn.linear_model import LogisticRegression 152 | from torch.utils.data import DataLoader 153 | from torchvision.datasets import CIFAR100 154 | from tqdm import tqdm 155 | 156 | # Load the model 157 | device = "cuda" if torch.cuda.is_available() else "cpu" 158 | model, preprocess = clip.load('ViT-B/32', device) 159 | 160 | # Load the dataset 161 | root = os.path.expanduser("~/.cache") 162 | train = CIFAR100(root, download=True, train=True, transform=preprocess) 163 | test = CIFAR100(root, download=True, train=False, transform=preprocess) 164 | 165 | 166 | def get_features(dataset): 167 | all_features = [] 168 | all_labels = [] 169 | 170 | with torch.no_grad(): 171 | for images, labels in tqdm(DataLoader(dataset, batch_size=100)): 172 | features = model.encode_image(images.to(device)) 173 | 174 | all_features.append(features) 175 | all_labels.append(labels) 176 | 177 | return torch.cat(all_features).cpu().numpy(), torch.cat(all_labels).cpu().numpy() 178 | 179 | # Calculate the image features 180 | train_features, train_labels = get_features(train) 181 | test_features, test_labels = get_features(test) 182 | 183 | # Perform logistic regression 184 | classifier = LogisticRegression(random_state=0, C=0.316, max_iter=1000, verbose=1) 185 | classifier.fit(train_features, train_labels) 186 | 187 | # Evaluate using the logistic regression classifier 188 | predictions = classifier.predict(test_features) 189 | accuracy = np.mean((test_labels == predictions).astype(np.float)) * 100. 190 | print(f"Accuracy = {accuracy:.3f}") 191 | ``` 192 | 193 | Note that the `C` value should be determined via a hyperparameter sweep using a validation split. 194 | -------------------------------------------------------------------------------- /CLIP/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | from .clip_gradcam import ClipGradcam 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import torchvision 7 | from functools import reduce 8 | 9 | 10 | def factors(n): 11 | return set( 12 | reduce( 13 | list.__add__, 14 | ([i, n // i] for i in range(1, int(n**0.5) + 1) if n % i == 0), 15 | ) 16 | ) 17 | 18 | 19 | saliency_configs = { 20 | "ours": lambda img_dim: { 21 | "distractor_labels": {}, 22 | "horizontal_flipping": True, 23 | "augmentations": 5, 24 | "imagenet_prompt_ensemble": False, 25 | "positive_attn_only": True, 26 | "cropping_augmentations": [ 27 | {"tile_size": img_dim, "stride": img_dim // 4}, 28 | {"tile_size": int(img_dim * 2 / 3), "stride": int(img_dim * 2 / 3) // 4}, 29 | {"tile_size": img_dim // 2, "stride": (img_dim // 2) // 4}, 30 | {"tile_size": img_dim // 4, "stride": (img_dim // 4) // 4}, 31 | ], 32 | }, 33 | "chefer_et_al": lambda img_dim: { 34 | "distractor_labels": {}, 35 | "horizontal_flipping": False, 36 | "augmentations": 0, 37 | "imagenet_prompt_ensemble": False, 38 | "positive_attn_only": True, 39 | "cropping_augmentations": [{"tile_size": img_dim, "stride": img_dim // 4}], 40 | }, 41 | } 42 | 43 | 44 | class ClipWrapper: 45 | # SINGLETON WRAPPER 46 | clip_model = None 47 | clip_preprocess = None 48 | clip_gradcam = None 49 | lavt = None 50 | device = None 51 | jittering_transforms = None 52 | 53 | def __init__(self, clip_model_type, device, **kwargs): 54 | ClipWrapper.device = device 55 | ClipWrapper.jittering_transforms = torchvision.transforms.ColorJitter( 56 | 
brightness=0.6, contrast=0.6, saturation=0.6, hue=0.1 57 | ) 58 | ClipWrapper.clip_model, ClipWrapper.clip_preprocess = load( 59 | clip_model_type, ClipWrapper.device, **kwargs 60 | ) 61 | ClipWrapper.clip_gradcam = ClipGradcam( 62 | clip_model_name=clip_model_type, 63 | classes=[""], 64 | templates=["{}"], 65 | device=ClipWrapper.device, 66 | **kwargs 67 | ) 68 | 69 | @classmethod 70 | def check_initialized(cls, clip_model_type="ViT-B/32", **kwargs): 71 | if cls.clip_gradcam is None: 72 | ClipWrapper( 73 | clip_model_type=clip_model_type, 74 | device="cuda" if torch.cuda.is_available() else "cpu", 75 | **kwargs 76 | ) 77 | 78 | @classmethod 79 | def get_clip_text_feature(cls, string): 80 | ClipWrapper.check_initialized() 81 | with torch.no_grad(): 82 | return ( 83 | cls.clip_model.encode_text( 84 | tokenize(string, context_length=77).to(cls.device) 85 | ) 86 | .squeeze() 87 | .cpu() 88 | .numpy() 89 | ) 90 | 91 | @classmethod 92 | def get_visual_feature(cls, rgb, tile_attn_mask, device=None): 93 | if device is None: 94 | device = ClipWrapper.device 95 | ClipWrapper.check_initialized() 96 | rgb = ClipWrapper.clip_preprocess(Image.fromarray(rgb)).unsqueeze(0) 97 | with torch.no_grad(): 98 | clip_feature = ClipWrapper.clip_model.encode_image( 99 | rgb.to(ClipWrapper.device), tile_attn_mask=tile_attn_mask 100 | ).squeeze() 101 | return clip_feature.to(device) 102 | 103 | @classmethod 104 | def get_clip_saliency( 105 | cls, 106 | img, 107 | text_labels, 108 | prompts, 109 | distractor_labels=set(), 110 | use_lavt=False, 111 | **kwargs 112 | ): 113 | cls.check_initialized() 114 | if use_lavt: 115 | return cls.lavt.localize(img=img, prompts=text_labels) 116 | cls.clip_gradcam.templates = prompts 117 | cls.clip_gradcam.set_classes(text_labels) 118 | text_label_features = torch.stack( 119 | list(cls.clip_gradcam.class_to_language_feature.values()), dim=0 120 | ) 121 | text_label_features = text_label_features.squeeze(dim=-1).cpu() 122 | text_maps = cls.get_clip_saliency_convolve( 123 | img=img, text_labels=text_labels, **kwargs 124 | ) 125 | if len(distractor_labels) > 0: 126 | distractor_labels = set(distractor_labels) - set(text_labels) 127 | cls.clip_gradcam.set_classes(list(distractor_labels)) 128 | distractor_maps = cls.get_clip_saliency_convolve( 129 | img=img, text_labels=list(distractor_labels), **kwargs 130 | ) 131 | text_maps -= distractor_maps.mean(dim=0) 132 | text_maps = text_maps.cpu() 133 | return text_maps, text_label_features.squeeze(dim=-1) 134 | 135 | @classmethod 136 | def get_clip_saliency_convolve( 137 | cls, 138 | text_labels, 139 | horizontal_flipping=False, 140 | positive_attn_only: bool = False, 141 | tile_batch_size=32, 142 | prompt_batch_size=32, 143 | tile_interpolate_batch_size=32, 144 | **kwargs 145 | ): 146 | cls.clip_gradcam.positive_attn_only = positive_attn_only 147 | tiles, tile_imgs, counts, tile_sizes = cls.create_tiles(**kwargs) 148 | outputs = { 149 | k: torch.zeros( 150 | [len(text_labels)] + list(count.shape), device=cls.device 151 | ).half() 152 | for k, count in counts.items() 153 | } 154 | tile_gradcams = torch.cat( 155 | [ 156 | torch.cat( 157 | [ 158 | cls.clip_gradcam( 159 | x=tile_imgs[tile_idx : tile_idx + tile_batch_size], 160 | o=text_labels[prompt_idx : prompt_idx + prompt_batch_size], 161 | ) 162 | for tile_idx in np.arange(0, len(tile_imgs), tile_batch_size) 163 | ], 164 | dim=1, 165 | ) 166 | for prompt_idx in np.arange(0, len(text_labels), prompt_batch_size) 167 | ], 168 | dim=0, 169 | ) 170 | if horizontal_flipping: 171 | flipped_tile_imgs 
= tile_imgs[ 172 | ..., torch.flip(torch.arange(0, tile_imgs.shape[-1]), dims=[0]) 173 | ] 174 | flipped_tile_gradcams = torch.cat( 175 | [ 176 | torch.cat( 177 | [ 178 | cls.clip_gradcam( 179 | x=flipped_tile_imgs[ 180 | tile_idx : tile_idx + tile_batch_size 181 | ], 182 | o=text_labels[ 183 | prompt_idx : prompt_idx + prompt_batch_size 184 | ], 185 | ) 186 | for tile_idx in np.arange( 187 | 0, len(tile_imgs), tile_batch_size 188 | ) 189 | ], 190 | dim=1, 191 | ) 192 | for prompt_idx in np.arange(0, len(text_labels), prompt_batch_size) 193 | ], 194 | dim=0, 195 | ) 196 | with torch.no_grad(): 197 | flipped_tile_gradcams = flipped_tile_gradcams[ 198 | ..., 199 | torch.flip( 200 | torch.arange(0, flipped_tile_gradcams.shape[-1]), dims=[0] 201 | ), 202 | ] 203 | tile_gradcams = (tile_gradcams + flipped_tile_gradcams) / 2 204 | del flipped_tile_gradcams 205 | with torch.no_grad(): 206 | torch.cuda.empty_cache() 207 | for tile_size in np.unique(tile_sizes): 208 | tile_size_mask = tile_sizes == tile_size 209 | curr_size_grads = tile_gradcams[:, tile_size_mask] 210 | curr_size_tiles = tiles[tile_size_mask] 211 | for tile_idx in np.arange( 212 | 0, curr_size_grads.shape[1], tile_interpolate_batch_size 213 | ): 214 | resized_tiles = torch.nn.functional.interpolate( 215 | curr_size_grads[ 216 | :, tile_idx : tile_idx + tile_interpolate_batch_size 217 | ], 218 | size=tile_size, 219 | mode="bilinear", 220 | align_corners=False, 221 | ) 222 | for tile_idx, tile_slice in enumerate( 223 | curr_size_tiles[ 224 | tile_idx : tile_idx + tile_interpolate_batch_size 225 | ] 226 | ): 227 | outputs[tile_size][tile_slice] += resized_tiles[ 228 | :, tile_idx, ... 229 | ] 230 | output = sum( 231 | output.float() / count 232 | for output, count in zip(outputs.values(), counts.values()) 233 | ) / len(counts) 234 | del outputs, counts, tile_gradcams 235 | output = output.cpu() 236 | return output 237 | 238 | @classmethod 239 | def create_tiles(cls, img, augmentations, cropping_augmentations, **kwargs): 240 | assert type(img) == np.ndarray 241 | images = [] 242 | cls.check_initialized() 243 | # compute image crops 244 | img_pil = Image.fromarray(img) 245 | images.append(np.array(img_pil)) 246 | for _ in range(augmentations): 247 | images.append(np.array(cls.jittering_transforms(img_pil))) 248 | # for taking average 249 | counts = { 250 | crop_aug["tile_size"]: torch.zeros(img.shape[:2], device=cls.device).float() 251 | + 1e-5 252 | for crop_aug in cropping_augmentations 253 | } 254 | tiles = [] 255 | tile_imgs = [] 256 | tile_sizes = [] 257 | for img in images: 258 | for crop_aug in cropping_augmentations: 259 | tile_size = crop_aug["tile_size"] 260 | stride = crop_aug["stride"] 261 | for y in np.arange(0, img.shape[1] - tile_size + 1, stride): 262 | if y >= img.shape[0]: 263 | continue 264 | for x in np.arange(0, img.shape[0] - tile_size + 1, stride): 265 | if x >= img.shape[1]: 266 | continue 267 | tile = ( 268 | slice(None, None), 269 | slice(x, x + tile_size), 270 | slice(y, y + tile_size), 271 | ) 272 | tiles.append(tile) 273 | counts[tile_size][tile[1:]] += 1 274 | tile_sizes.append(tile_size) 275 | # this is currently biggest bottle neck 276 | tile_imgs.append( 277 | cls.clip_gradcam.preprocess( 278 | Image.fromarray(img[tiles[-1][1:]]) 279 | ) 280 | ) 281 | tile_imgs = torch.stack(tile_imgs).to(cls.device) 282 | return np.array(tiles), tile_imgs, counts, np.array(tile_sizes) 283 | 284 | 285 | imagenet_templates = [ 286 | "a bad photo of a {}.", 287 | "a photo of many {}.", 288 | "a sculpture of a {}.", 289 | 
"a photo of the hard to see {}.", 290 | "a low resolution photo of the {}.", 291 | "a rendering of a {}.", 292 | "graffiti of a {}.", 293 | "a bad photo of the {}.", 294 | "a cropped photo of the {}.", 295 | "a tattoo of a {}.", 296 | "the embroidered {}.", 297 | "a photo of a hard to see {}.", 298 | "a bright photo of a {}.", 299 | "a photo of a clean {}.", 300 | "a photo of a dirty {}.", 301 | "a dark photo of the {}.", 302 | "a drawing of a {}.", 303 | "a photo of my {}.", 304 | "the plastic {}.", 305 | "a photo of the cool {}.", 306 | "a close-up photo of a {}.", 307 | "a black and white photo of the {}.", 308 | "a painting of the {}.", 309 | "a painting of a {}.", 310 | "a pixelated photo of the {}.", 311 | "a sculpture of the {}.", 312 | "a bright photo of the {}.", 313 | "a cropped photo of a {}.", 314 | "a plastic {}.", 315 | "a photo of the dirty {}.", 316 | "a jpeg corrupted photo of a {}.", 317 | "a blurry photo of the {}.", 318 | "a photo of the {}.", 319 | "a good photo of the {}.", 320 | "a rendering of the {}.", 321 | "a {} in a video game.", 322 | "a photo of one {}.", 323 | "a doodle of a {}.", 324 | "a close-up photo of the {}.", 325 | "a photo of a {}.", 326 | "the origami {}.", 327 | "the {} in a video game.", 328 | "a sketch of a {}.", 329 | "a doodle of the {}.", 330 | "a origami {}.", 331 | "a low resolution photo of a {}.", 332 | "the toy {}.", 333 | "a rendition of the {}.", 334 | "a photo of the clean {}.", 335 | "a photo of a large {}.", 336 | "a rendition of a {}.", 337 | "a photo of a nice {}.", 338 | "a photo of a weird {}.", 339 | "a blurry photo of a {}.", 340 | "a cartoon {}.", 341 | "art of a {}.", 342 | "a sketch of the {}.", 343 | "a embroidered {}.", 344 | "a pixelated photo of a {}.", 345 | "itap of the {}.", 346 | "a jpeg corrupted photo of the {}.", 347 | "a good photo of a {}.", 348 | "a plushie {}.", 349 | "a photo of the nice {}.", 350 | "a photo of the small {}.", 351 | "a photo of the weird {}.", 352 | "the cartoon {}.", 353 | "art of the {}.", 354 | "a drawing of the {}.", 355 | "a photo of the large {}.", 356 | "a black and white photo of a {}.", 357 | "the plushie {}.", 358 | "a dark photo of a {}.", 359 | "itap of a {}.", 360 | "graffiti of the {}.", 361 | "a toy {}.", 362 | "itap of my {}.", 363 | "a photo of a cool {}.", 364 | "a photo of a small {}.", 365 | "a tattoo of the {}.", 366 | ] 367 | 368 | __all__ = ["ClipWrapper", "imagenet_templates"] 369 | -------------------------------------------------------------------------------- /CLIP/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/CLIP/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /CLIP/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | # if 
[int(i) for i in torch.__version__.split(".")] < [1, 7, 1]: 24 | # warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | __all__ = ["available_models", "load", "tokenize", "tokenizer"] 27 | tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 37 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 38 | } 39 | 40 | 41 | def _download(url: str, root: str): 42 | os.makedirs(root, exist_ok=True) 43 | filename = os.path.basename(url) 44 | 45 | expected_sha256 = url.split("/")[-2] 46 | download_target = os.path.join(root, filename) 47 | 48 | if os.path.exists(download_target) and not os.path.isfile(download_target): 49 | raise RuntimeError(f"{download_target} exists and is not a regular file") 50 | 51 | if os.path.isfile(download_target): 52 | if ( 53 | hashlib.sha256(open(download_target, "rb").read()).hexdigest() 54 | == expected_sha256 55 | ): 56 | return download_target 57 | else: 58 | warnings.warn( 59 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" 60 | ) 61 | 62 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 63 | with tqdm( 64 | total=int(source.info().get("Content-Length")), 65 | ncols=80, 66 | unit="iB", 67 | unit_scale=True, 68 | unit_divisor=1024, 69 | ) as loop: 70 | while True: 71 | buffer = source.read(8192) 72 | if not buffer: 73 | break 74 | 75 | output.write(buffer) 76 | loop.update(len(buffer)) 77 | 78 | if ( 79 | hashlib.sha256(open(download_target, "rb").read()).hexdigest() 80 | != expected_sha256 81 | ): 82 | raise RuntimeError( 83 | f"Model has been downloaded but the SHA256 checksum does not not match" 84 | ) 85 | 86 | return download_target 87 | 88 | 89 | def _convert_image_to_rgb(image): 90 | return image.convert("RGB") 91 | 92 | 93 | def _transform(n_px, overload_resolution=False): 94 | transforms = [ 95 | _convert_image_to_rgb, 96 | ToTensor(), 97 | Normalize( 98 | (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711) 99 | ), 100 | ] 101 | if not overload_resolution: 102 | transforms = [Resize(224, interpolation=BICUBIC), CenterCrop(n_px)] + transforms 103 | return Compose(transforms) 104 | 105 | 106 | def available_models() -> List[str]: 107 | """Returns the names of available CLIP models""" 108 | return list(_MODELS.keys()) 109 | 110 | 111 | def load( 112 | name: str, 113 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 114 | jit: bool = False, 115 
| download_root: str = None, 116 | overload_resolution=False, 117 | ): 118 | """Load a CLIP model 119 | Parameters 120 | ---------- 121 | name : str 122 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 123 | device : Union[str, torch.device] 124 | The device to put the loaded model 125 | jit : bool 126 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 127 | download_root: str 128 | path to download the model files; by default, it uses "~/.cache/clip" 129 | Returns 130 | ------- 131 | model : torch.nn.Module 132 | The CLIP model 133 | preprocess : Callable[[PIL.Image], torch.Tensor] 134 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 135 | """ 136 | if name in _MODELS: 137 | model_path = _download( 138 | _MODELS[name], download_root or os.path.expanduser("~/.cache/clip") 139 | ) 140 | elif os.path.isfile(name): 141 | model_path = name 142 | else: 143 | raise RuntimeError( 144 | f"Model {name} not found; available models = {available_models()}" 145 | ) 146 | 147 | try: 148 | # loading JIT archive 149 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 150 | state_dict = None 151 | except RuntimeError: 152 | # loading saved state dict 153 | if jit: 154 | warnings.warn( 155 | f"File {model_path} is not a JIT archive. Loading as a state dict instead" 156 | ) 157 | jit = False 158 | state_dict = torch.load(model_path, map_location="cpu") 159 | 160 | if not jit: 161 | model = build_model(state_dict or model.state_dict()).to(device) 162 | if str(device) == "cpu": 163 | model.float() 164 | return model, _transform(model.visual.input_resolution, overload_resolution) 165 | 166 | # patch the device names 167 | device_holder = torch.jit.trace( 168 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] 169 | ) 170 | device_node = [ 171 | n 172 | for n in device_holder.graph.findAllNodes("prim::Constant") 173 | if "Device" in repr(n) 174 | ][-1] 175 | 176 | def patch_device(module): 177 | try: 178 | graphs = [module.graph] if hasattr(module, "graph") else [] 179 | except RuntimeError: 180 | graphs = [] 181 | 182 | if hasattr(module, "forward1"): 183 | graphs.append(module.forward1.graph) 184 | 185 | for graph in graphs: 186 | for node in graph.findAllNodes("prim::Constant"): 187 | if "value" in node.attributeNames() and str(node["value"]).startswith( 188 | "cuda" 189 | ): 190 | node.copyAttributes(device_node) 191 | 192 | model.apply(patch_device) 193 | patch_device(model.encode_image) 194 | patch_device(model.encode_text) 195 | 196 | # patch dtype to float32 on CPU 197 | if str(device) == "cpu": 198 | float_holder = torch.jit.trace( 199 | lambda: torch.ones([]).float(), example_inputs=[] 200 | ) 201 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 202 | float_node = float_input.node() 203 | 204 | def patch_float(module): 205 | try: 206 | graphs = [module.graph] if hasattr(module, "graph") else [] 207 | except RuntimeError: 208 | graphs = [] 209 | 210 | if hasattr(module, "forward1"): 211 | graphs.append(module.forward1.graph) 212 | 213 | for graph in graphs: 214 | for node in graph.findAllNodes("aten::to"): 215 | inputs = list(node.inputs()) 216 | for i in [ 217 | 1, 218 | 2, 219 | ]: # dtype can be the second or third argument to aten::to() 220 | if inputs[i].node()["value"] == 5: 221 | inputs[i].node().copyAttributes(float_node) 222 | 223 | model.apply(patch_float) 224 | 
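# note: encode_image / encode_text are scripted methods rather than submodules,
# so model.apply above does not reach their graphs; they appear to be patched explicitly for that reason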
patch_float(model.encode_image) 225 | patch_float(model.encode_text) 226 | 227 | model.float() 228 | 229 | return model, _transform(model.input_resolution.item(), overload_resolution) 230 | 231 | 232 | def tokenize( 233 | texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False 234 | ) -> torch.LongTensor: 235 | """ 236 | Returns the tokenized representation of given input string(s) 237 | 238 | Parameters 239 | ---------- 240 | texts : Union[str, List[str]] 241 | An input string or a list of input strings to tokenize 242 | 243 | context_length : int 244 | The context length to use; all CLIP models use 77 as the context length 245 | 246 | truncate: bool 247 | Whether to truncate the text in case its encoding is longer than the context length 248 | 249 | Returns 250 | ------- 251 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 252 | """ 253 | if isinstance(texts, str): 254 | texts = [texts] 255 | 256 | sot_token = tokenizer.encoder["<|startoftext|>"] 257 | eot_token = tokenizer.encoder["<|endoftext|>"] 258 | all_tokens = [[sot_token] + tokenizer.encode(text) + [eot_token] for text in texts] 259 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 260 | 261 | for i, tokens in enumerate(all_tokens): 262 | if len(tokens) > context_length: 263 | if truncate: 264 | tokens = tokens[:context_length] 265 | tokens[-1] = eot_token 266 | else: 267 | raise RuntimeError( 268 | f"Input {texts[i]} is too long for context length {context_length}" 269 | ) 270 | result[i, : len(tokens)] = torch.tensor(tokens) 271 | 272 | return result 273 | -------------------------------------------------------------------------------- /CLIP/clip/clip_explainability.py: -------------------------------------------------------------------------------- 1 | # modified from: https://github.com/hila-chefer/Transformer-MM-Explainability/blob/main/CLIP/clip/clip.py 2 | 3 | import hashlib 4 | import os 5 | import urllib 6 | import warnings 7 | from typing import Any, Union, List 8 | from pkg_resources import packaging 9 | 10 | import torch 11 | from PIL import Image 12 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 13 | from tqdm import tqdm 14 | 15 | from .model_explainability import build_model 16 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 17 | 18 | try: 19 | from torchvision.transforms import InterpolationMode 20 | 21 | BICUBIC = InterpolationMode.BICUBIC 22 | except ImportError: 23 | BICUBIC = Image.BICUBIC 24 | 25 | 26 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 27 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 28 | 29 | 30 | __all__ = ["available_models", "load", "tokenize"] 31 | _tokenizer = _Tokenizer() 32 | 33 | 34 | _MODELS = { 35 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 36 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 37 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 38 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 39 | "ViT-B/32": 
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 40 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 41 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 42 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 43 | } 44 | 45 | 46 | def _download(url: str, root: str): 47 | os.makedirs(root, exist_ok=True) 48 | filename = os.path.basename(url) 49 | 50 | expected_sha256 = url.split("/")[-2] 51 | download_target = os.path.join(root, filename) 52 | 53 | if os.path.exists(download_target) and not os.path.isfile(download_target): 54 | raise RuntimeError(f"{download_target} exists and is not a regular file") 55 | 56 | if os.path.isfile(download_target): 57 | if ( 58 | hashlib.sha256(open(download_target, "rb").read()).hexdigest() 59 | == expected_sha256 60 | ): 61 | return download_target 62 | else: 63 | warnings.warn( 64 | f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file" 65 | ) 66 | 67 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 68 | with tqdm( 69 | total=int(source.info().get("Content-Length")), 70 | ncols=80, 71 | unit="iB", 72 | unit_scale=True, 73 | unit_divisor=1024, 74 | ) as loop: 75 | while True: 76 | buffer = source.read(8192) 77 | if not buffer: 78 | break 79 | 80 | output.write(buffer) 81 | loop.update(len(buffer)) 82 | 83 | if ( 84 | hashlib.sha256(open(download_target, "rb").read()).hexdigest() 85 | != expected_sha256 86 | ): 87 | raise RuntimeError( 88 | f"Model has been downloaded but the SHA256 checksum does not not match" 89 | ) 90 | 91 | return download_target 92 | 93 | 94 | def _convert_image_to_rgb(image): 95 | return image.convert("RGB") 96 | 97 | 98 | def _transform(n_px, overload_resolution=False): 99 | transforms = [ 100 | _convert_image_to_rgb, 101 | ToTensor(), 102 | Normalize( 103 | (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711) 104 | ), 105 | ] 106 | if not overload_resolution: 107 | transforms = [Resize(224, interpolation=BICUBIC), CenterCrop(n_px)] + transforms 108 | return Compose(transforms) 109 | 110 | 111 | def available_models() -> List[str]: 112 | """Returns the names of available CLIP models""" 113 | return list(_MODELS.keys()) 114 | 115 | 116 | def load( 117 | name: str, 118 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 119 | jit: bool = False, 120 | download_root: str = None, 121 | overload_resolution=False, 122 | ): 123 | """Load a CLIP model 124 | Parameters 125 | ---------- 126 | name : str 127 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 128 | device : Union[str, torch.device] 129 | The device to put the loaded model 130 | jit : bool 131 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 
132 | download_root: str 133 | path to download the model files; by default, it uses "~/.cache/clip" 134 | Returns 135 | ------- 136 | model : torch.nn.Module 137 | The CLIP model 138 | preprocess : Callable[[PIL.Image], torch.Tensor] 139 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 140 | """ 141 | if name in _MODELS: 142 | model_path = _download( 143 | _MODELS[name], download_root or os.path.expanduser("~/.cache/clip") 144 | ) 145 | elif os.path.isfile(name): 146 | model_path = name 147 | else: 148 | raise RuntimeError( 149 | f"Model {name} not found; available models = {available_models()}" 150 | ) 151 | 152 | try: 153 | # loading JIT archive 154 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 155 | state_dict = None 156 | except RuntimeError: 157 | # loading saved state dict 158 | if jit: 159 | warnings.warn( 160 | f"File {model_path} is not a JIT archive. Loading as a state dict instead" 161 | ) 162 | jit = False 163 | state_dict = torch.load(model_path, map_location="cpu") 164 | 165 | if not jit: 166 | model = build_model(state_dict or model.state_dict()).to(device) 167 | if str(device) == "cpu": 168 | model.float() 169 | return model, _transform(model.visual.input_resolution, overload_resolution) 170 | 171 | # patch the device names 172 | device_holder = torch.jit.trace( 173 | lambda: torch.ones([]).to(torch.device(device)), example_inputs=[] 174 | ) 175 | device_node = [ 176 | n 177 | for n in device_holder.graph.findAllNodes("prim::Constant") 178 | if "Device" in repr(n) 179 | ][-1] 180 | 181 | def patch_device(module): 182 | try: 183 | graphs = [module.graph] if hasattr(module, "graph") else [] 184 | except RuntimeError: 185 | graphs = [] 186 | 187 | if hasattr(module, "forward1"): 188 | graphs.append(module.forward1.graph) 189 | 190 | for graph in graphs: 191 | for node in graph.findAllNodes("prim::Constant"): 192 | if "value" in node.attributeNames() and str(node["value"]).startswith( 193 | "cuda" 194 | ): 195 | node.copyAttributes(device_node) 196 | 197 | model.apply(patch_device) 198 | patch_device(model.encode_image) 199 | patch_device(model.encode_text) 200 | 201 | # patch dtype to float32 on CPU 202 | if str(device) == "cpu": 203 | float_holder = torch.jit.trace( 204 | lambda: torch.ones([]).float(), example_inputs=[] 205 | ) 206 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 207 | float_node = float_input.node() 208 | 209 | def patch_float(module): 210 | try: 211 | graphs = [module.graph] if hasattr(module, "graph") else [] 212 | except RuntimeError: 213 | graphs = [] 214 | 215 | if hasattr(module, "forward1"): 216 | graphs.append(module.forward1.graph) 217 | 218 | for graph in graphs: 219 | for node in graph.findAllNodes("aten::to"): 220 | inputs = list(node.inputs()) 221 | for i in [ 222 | 1, 223 | 2, 224 | ]: # dtype can be the second or third argument to aten::to() 225 | if inputs[i].node()["value"] == 5: 226 | inputs[i].node().copyAttributes(float_node) 227 | 228 | model.apply(patch_float) 229 | patch_float(model.encode_image) 230 | patch_float(model.encode_text) 231 | 232 | model.float() 233 | 234 | return model, _transform(model.input_resolution.item(), overload_resolution) 235 | 236 | 237 | def tokenize( 238 | texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False 239 | ) -> torch.LongTensor: 240 | """ 241 | Returns the tokenized representation of given input string(s) 242 | Parameters 243 | ---------- 244 | texts 
: Union[str, List[str]] 245 | An input string or a list of input strings to tokenize 246 | context_length : int 247 | The context length to use; all CLIP models use 77 as the context length 248 | truncate: bool 249 | Whether to truncate the text in case its encoding is longer than the context length 250 | Returns 251 | ------- 252 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 253 | """ 254 | if isinstance(texts, str): 255 | texts = [texts] 256 | 257 | sot_token = _tokenizer.encoder["<|startoftext|>"] 258 | eot_token = _tokenizer.encoder["<|endoftext|>"] 259 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 260 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 261 | 262 | for i, tokens in enumerate(all_tokens): 263 | if len(tokens) > context_length: 264 | if truncate: 265 | tokens = tokens[:context_length] 266 | tokens[-1] = eot_token 267 | else: 268 | raise RuntimeError( 269 | f"Input {texts[i]} is too long for context length {context_length}" 270 | ) 271 | result[i, : len(tokens)] = torch.tensor(tokens) 272 | 273 | return result 274 | -------------------------------------------------------------------------------- /CLIP/clip/clip_gradcam.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch 3 | import torch.nn as nn 4 | from .clip_explainability import load 5 | from .clip import tokenize 6 | from torch import device 7 | import numpy as np 8 | import torch.nn.functional as nnf 9 | import itertools 10 | 11 | 12 | def zeroshot_classifier(clip_model, classnames, templates, device): 13 | with torch.no_grad(): 14 | texts = list( 15 | itertools.chain( 16 | *[ 17 | [template.format(classname) for template in templates] 18 | for classname in classnames 19 | ] 20 | ) 21 | ) # format with class 22 | texts = tokenize(texts).to(device) # tokenize 23 | class_embeddings = clip_model.encode_text(texts) 24 | class_embeddings = class_embeddings.view(len(classnames), len(templates), -1) 25 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 26 | zeroshot_weights = class_embeddings.mean(dim=1) 27 | return zeroshot_weights.T # shape: [dim, n classes] 28 | 29 | 30 | class ClipGradcam(nn.Module): 31 | def __init__( 32 | self, 33 | clip_model_name: str, 34 | classes: List[str], 35 | templates: List[str], 36 | device: device, 37 | num_layers=10, 38 | positive_attn_only=False, 39 | **kwargs 40 | ): 41 | 42 | super(ClipGradcam, self).__init__() 43 | self.clip_model_name = clip_model_name 44 | self.model, self.preprocess = load(clip_model_name, device=device, **kwargs) 45 | self.templates = templates 46 | self.device = device 47 | self.target_classes = None 48 | self.set_classes(classes) 49 | self.num_layers = num_layers 50 | self.positive_attn_only = positive_attn_only 51 | self.num_res_attn_blocks = { 52 | "ViT-B/32": 12, 53 | "ViT-B/16": 12, 54 | "ViT-L/14": 16, 55 | "ViT-L/14@336px": 16, 56 | }[clip_model_name] 57 | 58 | def forward(self, x: torch.Tensor, o: List[str]): 59 | """ 60 | non-standard hack around an nn, really should be more principled here 61 | """ 62 | image_features = self.model.encode_image(x.to(self.device)) 63 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 64 | zeroshot_weights = torch.cat( 65 | [self.class_to_language_feature[prompt] for prompt in o], dim=1 66 | ) 67 | logits_per_image = 100.0 * image_features @ zeroshot_weights 68 | return 
self.interpret(logits_per_image, self.model, self.device) 69 | 70 | def interpret(self, logits_per_image, model, device): 71 | # modified from: https://colab.research.google.com/github/hila-chefer/Transformer-MM-Explainability/blob/main/CLIP_explainability.ipynb#scrollTo=fWKGyu2YAeSV 72 | batch_size = logits_per_image.shape[0] 73 | num_prompts = logits_per_image.shape[1] 74 | one_hot = [logit for logit in logits_per_image.sum(dim=0)] 75 | model.zero_grad() 76 | 77 | image_attn_blocks = list( 78 | dict(model.visual.transformer.resblocks.named_children()).values() 79 | ) 80 | num_tokens = image_attn_blocks[0].attn_probs.shape[-1] 81 | R = torch.eye( 82 | num_tokens, num_tokens, dtype=image_attn_blocks[0].attn_probs.dtype 83 | ).to(device) 84 | R = R[None, None, :, :].repeat(num_prompts, batch_size, 1, 1) 85 | for i, block in enumerate(image_attn_blocks): 86 | if i <= self.num_layers: 87 | continue 88 | # TODO try scaling block.attn_probs by value magnitude 89 | # TODO actual parallelized prompt gradients 90 | grad = torch.stack( 91 | [ 92 | torch.autograd.grad(logit, [block.attn_probs], retain_graph=True)[ 93 | 0 94 | ].detach() 95 | for logit in one_hot 96 | ] 97 | ) 98 | grad = grad.view( 99 | num_prompts, 100 | batch_size, 101 | self.num_res_attn_blocks, 102 | num_tokens, 103 | num_tokens, 104 | ) 105 | cam = ( 106 | block.attn_probs.view( 107 | 1, batch_size, self.num_res_attn_blocks, num_tokens, num_tokens 108 | ) 109 | .detach() 110 | .repeat(num_prompts, 1, 1, 1, 1) 111 | ) 112 | cam = cam.reshape(num_prompts, batch_size, -1, cam.shape[-1], cam.shape[-1]) 113 | grad = grad.reshape( 114 | num_prompts, batch_size, -1, grad.shape[-1], grad.shape[-1] 115 | ) 116 | cam = grad * cam 117 | cam = cam.reshape( 118 | num_prompts * batch_size, -1, cam.shape[-1], cam.shape[-1] 119 | ) 120 | if self.positive_attn_only: 121 | cam = cam.clamp(min=0) 122 | # average of all heads 123 | cam = cam.mean(dim=-3) 124 | R = R + torch.bmm( 125 | cam, R.view(num_prompts * batch_size, num_tokens, num_tokens) 126 | ).view(num_prompts, batch_size, num_tokens, num_tokens) 127 | image_relevance = R[:, :, 0, 1:] 128 | img_dim = int(np.sqrt(num_tokens - 1)) 129 | image_relevance = image_relevance.reshape( 130 | num_prompts, batch_size, img_dim, img_dim 131 | ) 132 | return image_relevance 133 | 134 | def set_classes(self, classes): 135 | self.target_classes = classes 136 | language_features = zeroshot_classifier( 137 | self.model, self.target_classes, self.templates, self.device 138 | ) 139 | 140 | self.class_to_language_feature = {} 141 | for i, c in enumerate(self.target_classes): 142 | self.class_to_language_feature[c] = language_features[:, [i]] 143 | -------------------------------------------------------------------------------- /CLIP/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join( 13 | os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz" 14 | ) 15 | 16 | 17 | @lru_cache() 18 | def bytes_to_unicode(): 19 | """ 20 | Returns list of utf-8 byte and a corresponding list of unicode strings. 21 | The reversible bpe codes work on unicode strings. 22 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 
23 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 24 | This is a signficant percentage of your normal, say, 32K bpe vocab. 25 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 26 | And avoids mapping to whitespace/control characters the bpe code barfs on. 27 | """ 28 | bs = ( 29 | list(range(ord("!"), ord("~") + 1)) 30 | + list(range(ord("¡"), ord("¬") + 1)) 31 | + list(range(ord("®"), ord("ÿ") + 1)) 32 | ) 33 | cs = bs[:] 34 | n = 0 35 | for b in range(2**8): 36 | if b not in bs: 37 | bs.append(b) 38 | cs.append(2**8 + n) 39 | n += 1 40 | cs = [chr(n) for n in cs] 41 | return dict(zip(bs, cs)) 42 | 43 | 44 | def get_pairs(word): 45 | """Return set of symbol pairs in a word. 46 | Word is represented as tuple of symbols (symbols being variable-length strings). 47 | """ 48 | pairs = set() 49 | prev_char = word[0] 50 | for char in word[1:]: 51 | pairs.add((prev_char, char)) 52 | prev_char = char 53 | return pairs 54 | 55 | 56 | def basic_clean(text): 57 | text = ftfy.fix_text(text) 58 | text = html.unescape(html.unescape(text)) 59 | return text.strip() 60 | 61 | 62 | def whitespace_clean(text): 63 | text = re.sub(r"\s+", " ", text) 64 | text = text.strip() 65 | return text 66 | 67 | 68 | class SimpleTokenizer(object): 69 | def __init__(self, bpe_path: str = default_bpe()): 70 | self.byte_encoder = bytes_to_unicode() 71 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 72 | merges = gzip.open(bpe_path).read().decode("utf-8").split("\n") 73 | merges = merges[1 : 49152 - 256 - 2 + 1] 74 | merges = [tuple(merge.split()) for merge in merges] 75 | vocab = list(bytes_to_unicode().values()) 76 | vocab = vocab + [v + "" for v in vocab] 77 | for merge in merges: 78 | vocab.append("".join(merge)) 79 | vocab.extend(["<|startoftext|>", "<|endoftext|>"]) 80 | self.encoder = dict(zip(vocab, range(len(vocab)))) 81 | self.decoder = {v: k for k, v in self.encoder.items()} 82 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 83 | self.cache = { 84 | "<|startoftext|>": "<|startoftext|>", 85 | "<|endoftext|>": "<|endoftext|>", 86 | } 87 | self.pat = re.compile( 88 | r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", 89 | re.IGNORECASE, 90 | ) 91 | 92 | def bpe(self, token): 93 | if token in self.cache: 94 | return self.cache[token] 95 | word = tuple(token[:-1]) + (token[-1] + "",) 96 | pairs = get_pairs(word) 97 | 98 | if not pairs: 99 | return token + "" 100 | 101 | while True: 102 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) 103 | if bigram not in self.bpe_ranks: 104 | break 105 | first, second = bigram 106 | new_word = [] 107 | i = 0 108 | while i < len(word): 109 | try: 110 | j = word.index(first, i) 111 | new_word.extend(word[i:j]) 112 | i = j 113 | except: 114 | new_word.extend(word[i:]) 115 | break 116 | 117 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 118 | new_word.append(first + second) 119 | i += 2 120 | else: 121 | new_word.append(word[i]) 122 | i += 1 123 | new_word = tuple(new_word) 124 | word = new_word 125 | if len(word) == 1: 126 | break 127 | else: 128 | pairs = get_pairs(word) 129 | word = " ".join(word) 130 | self.cache[token] = word 131 | return word 132 | 133 | def encode(self, text): 134 | bpe_tokens = [] 135 | text = whitespace_clean(basic_clean(text)).lower() 136 | for token in re.findall(self.pat, text): 137 | token = "".join(self.byte_encoder[b] for b in 
token.encode("utf-8")) 138 | bpe_tokens.extend( 139 | self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") 140 | ) 141 | return bpe_tokens 142 | 143 | def decode(self, tokens): 144 | text = "".join([self.decoder[token] for token in tokens]) 145 | text = ( 146 | bytearray([self.byte_decoder[c] for c in text]) 147 | .decode("utf-8", errors="replace") 148 | .replace("", " ") 149 | ) 150 | return text 151 | -------------------------------------------------------------------------------- /CLIP/model-card.md: -------------------------------------------------------------------------------- 1 | # Model Card: CLIP 2 | 3 | Inspired by [Model Cards for Model Reporting (Mitchell et al.)](https://arxiv.org/abs/1810.03993) and [Lessons from Archives (Jo & Gebru)](https://arxiv.org/pdf/1912.10389.pdf), we’re providing some accompanying information about the multimodal model. 4 | 5 | ## Model Details 6 | 7 | The CLIP model was developed by researchers at OpenAI to learn about what contributes to robustness in computer vision tasks. The model was also developed to test the ability of models to generalize to arbitrary image classification tasks in a zero-shot manner. It was not developed for general model deployment - to deploy models like CLIP, researchers will first need to carefully study their capabilities in relation to the specific context they’re being deployed within. 8 | 9 | ### Model Date 10 | 11 | January 2021 12 | 13 | ### Model Type 14 | 15 | The base model uses a ResNet50 with several modifications as an image encoder and uses a masked self-attention Transformer as a text encoder. These encoders are trained to maximize the similarity of (image, text) pairs via a contrastive loss. There is also a variant of the model where the ResNet image encoder is replaced with a Vision Transformer. 16 | 17 | ### Model Versions 18 | 19 | Initially, we’ve released one CLIP model based on the Vision Transformer architecture equivalent to ViT-B/32, along with the RN50 model, using the architecture equivalent to ResNet-50. 20 | 21 | As part of the staged release process, we have also released the RN101 model, as well as RN50x4, a RN50 scaled up 4x according to the [EfficientNet](https://arxiv.org/abs/1905.11946) scaling rule. In July 2021, we additionally released the RN50x16 and ViT-B/16 models. 22 | 23 | Please see the paper linked below for further details about their specification. 24 | 25 | ### Documents 26 | 27 | - [Blog Post](https://openai.com/blog/clip/) 28 | - [CLIP Paper](https://arxiv.org/abs/2103.00020) 29 | 30 | 31 | 32 | ## Model Use 33 | 34 | ### Intended Use 35 | 36 | The model is intended as a research output for research communities. We hope that this model will enable researchers to better understand and explore zero-shot, arbitrary image classification. We also hope it can be used for interdisciplinary studies of the potential impact of such models - the CLIP paper includes a discussion of potential downstream impacts to provide an example for this sort of analysis. 37 | 38 | #### Primary intended uses 39 | 40 | The primary intended users of these models are AI researchers. 41 | 42 | We primarily imagine the model will be used by researchers to better understand robustness, generalization, and other capabilities, biases, and constraints of computer vision models. 43 | 44 | ### Out-of-Scope Use Cases 45 | 46 | **Any** deployed use case of the model - whether commercial or not - is currently out of scope. 
Non-deployed use cases such as image search in a constrained environment, are also not recommended unless there is thorough in-domain testing of the model with a specific, fixed class taxonomy. This is because our safety assessment demonstrated a high need for task specific testing especially given the variability of CLIP’s performance with different class taxonomies. This makes untested and unconstrained deployment of the model in any use case currently potentially harmful. 47 | 48 | Certain use cases which would fall under the domain of surveillance and facial recognition are always out-of-scope regardless of performance of the model. This is because the use of artificial intelligence for tasks such as these can be premature currently given the lack of testing norms and checks to ensure its fair use. 49 | 50 | Since the model has not been purposefully trained in or evaluated on any languages other than English, its use should be limited to English language use cases. 51 | 52 | 53 | 54 | ## Data 55 | 56 | The model was trained on publicly available image-caption data. This was done through a combination of crawling a handful of websites and using commonly-used pre-existing image datasets such as [YFCC100M](http://projects.dfki.uni-kl.de/yfcc100m/). A large portion of the data comes from our crawling of the internet. This means that the data is more representative of people and societies most connected to the internet which tend to skew towards more developed nations, and younger, male users. 57 | 58 | ### Data Mission Statement 59 | 60 | Our goal with building this dataset was to test out robustness and generalizability in computer vision tasks. As a result, the focus was on gathering large quantities of data from different publicly-available internet data sources. The data was gathered in a mostly non-interventionist manner. However, we only crawled websites that had policies against excessively violent and adult images and allowed us to filter out such content. We do not intend for this dataset to be used as the basis for any commercial or deployed model and will not be releasing the dataset. 61 | 62 | 63 | 64 | ## Performance and Limitations 65 | 66 | ### Performance 67 | 68 | We have evaluated the performance of CLIP on a wide range of benchmarks across a variety of computer vision datasets such as OCR to texture recognition to fine-grained classification. The paper describes model performance on the following datasets: 69 | 70 | - Food101 71 | - CIFAR10 72 | - CIFAR100 73 | - Birdsnap 74 | - SUN397 75 | - Stanford Cars 76 | - FGVC Aircraft 77 | - VOC2007 78 | - DTD 79 | - Oxford-IIIT Pet dataset 80 | - Caltech101 81 | - Flowers102 82 | - MNIST 83 | - SVHN 84 | - IIIT5K 85 | - Hateful Memes 86 | - SST-2 87 | - UCF101 88 | - Kinetics700 89 | - Country211 90 | - CLEVR Counting 91 | - KITTI Distance 92 | - STL-10 93 | - RareAct 94 | - Flickr30 95 | - MSCOCO 96 | - ImageNet 97 | - ImageNet-A 98 | - ImageNet-R 99 | - ImageNet Sketch 100 | - ObjectNet (ImageNet Overlap) 101 | - Youtube-BB 102 | - ImageNet-Vid 103 | 104 | ## Limitations 105 | 106 | CLIP and our analysis of it have a number of limitations. CLIP currently struggles with respect to certain tasks such as fine grained classification and counting objects. CLIP also poses issues with regards to fairness and bias which we discuss in the paper and briefly in the next section. 
Additionally, our approach to testing CLIP also has an important limitation- in many cases we have used linear probes to evaluate the performance of CLIP and there is evidence suggesting that linear probes can underestimate model performance. 107 | 108 | ### Bias and Fairness 109 | 110 | We find that the performance of CLIP - and the specific biases it exhibits - can depend significantly on class design and the choices one makes for categories to include and exclude. We tested the risk of certain kinds of denigration with CLIP by classifying images of people from [Fairface](https://arxiv.org/abs/1908.04913) into crime-related and non-human animal categories. We found significant disparities with respect to race and gender. Additionally, we found that these disparities could shift based on how the classes were constructed. (Details captured in the Broader Impacts Section in the paper). 111 | 112 | We also tested the performance of CLIP on gender, race and age classification using the Fairface dataset (We default to using race categories as they are constructed in the Fairface dataset.) in order to assess quality of performance across different demographics. We found accuracy >96% across all races for gender classification with ‘Middle Eastern’ having the highest accuracy (98.4%) and ‘White’ having the lowest (96.5%). Additionally, CLIP averaged ~93% for racial classification and ~63% for age classification. Our use of evaluations to test for gender, race and age classification as well as denigration harms is simply to evaluate performance of the model across people and surface potential risks and not to demonstrate an endorsement/enthusiasm for such tasks. 113 | 114 | 115 | 116 | ## Feedback 117 | 118 | ### Where to send questions or comments about the model 119 | 120 | Please use [this Google Form](https://forms.gle/Uv7afRH5dvY34ZEs9) 121 | -------------------------------------------------------------------------------- /CLIP/requirements.txt: -------------------------------------------------------------------------------- 1 | ftfy 2 | regex 3 | tqdm 4 | torch 5 | torchvision 6 | -------------------------------------------------------------------------------- /CLIP/setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pkg_resources 4 | from setuptools import setup, find_packages 5 | 6 | setup( 7 | name="clip", 8 | py_modules=["clip"], 9 | version="1.0", 10 | description="", 11 | author="OpenAI", 12 | packages=find_packages(exclude=["tests*"]), 13 | install_requires=[ 14 | str(r) 15 | for r in pkg_resources.parse_requirements( 16 | open(os.path.join(os.path.dirname(__file__), "requirements.txt")) 17 | ) 18 | ], 19 | include_package_data=True, 20 | extras_require={"dev": ["pytest"]}, 21 | ) 22 | -------------------------------------------------------------------------------- /CLIP/tests/test_consistency.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | from PIL import Image 5 | 6 | import clip 7 | 8 | 9 | @pytest.mark.parametrize("model_name", clip.available_models()) 10 | def test_consistency(model_name): 11 | device = "cpu" 12 | jit_model, transform = clip.load(model_name, device=device, jit=True) 13 | py_model, _ = clip.load(model_name, device=device, jit=False) 14 | 15 | image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device) 16 | text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device) 17 | 18 | with 
torch.no_grad(): 19 | logits_per_image, _ = jit_model(image, text) 20 | jit_probs = logits_per_image.softmax(dim=-1).cpu().numpy() 21 | 22 | logits_per_image, _ = py_model(image, text) 23 | py_probs = logits_per_image.softmax(dim=-1).cpu().numpy() 24 | 25 | assert np.allclose(jit_probs, py_probs, atol=0.01, rtol=0.1) 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Columbia Artificial Intelligence and Robotics Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

Semantic Abstraction: Open-World 3D Scene Understanding from 2D Vision-Language Models

2 |
3 | 4 | [Huy Ha](https://www.cs.columbia.edu/~huy/), [Shuran Song](https://www.cs.columbia.edu/~shurans/) 5 | 6 | Columbia University, New York, NY, United States 7 | 8 | [Conference on Robot Learning 2022](https://corl2022.org/) 9 | 10 | [Project Page](https://semantic-abstraction.cs.columbia.edu/) | [Arxiv](https://arxiv.org/abs/2207.11514) 11 | 12 | [![HuggingFace Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/huy-ha/semabs-relevancy) 13 | 14 |
15 | 16 | 17 | Our approach, Semantic Abstraction, unlocks 2D VLM's capabilities to 3D scene understanding. Trained with a limited synthetic dataset, our model generalizes to unseen classes in a novel domain (i.e., real world), even for small objects like “rubiks cube”, long-tail concepts like “harry potter”, and hidden objects like the “used N95s in the garbage bin”. Unseen classes are bolded. 18 | 19 |
20 |
21 | 22 |
23 | 24 | This repository contains code for generating relevancies, training, and evaluating [Semantic Abstraction](https://semantic-abstraction.cs.columbia.edu/). 25 | It has been tested on Ubuntu 18.04 and 20.04, NVIDIA GTX 1080, NVIDIA RTX A6000, NVIDIA GeForce RTX 3080, and NVIDIA GeForce RTX 3090. 26 | 27 | If you find this codebase useful, consider citing: 28 | 29 | ```bibtex 30 | @inproceedings{ha2022semabs, 31 | title={Semantic Abstraction: Open-World 3{D} Scene Understanding from 2{D} Vision-Language Models}, 32 | author = {Ha, Huy and Song, Shuran}, 33 | booktitle={Proceedings of the 2022 Conference on Robot Learning}, 34 | year={2022} 35 | } 36 | ``` 37 | 38 | If you have any questions, please contact [me](https://www.cs.columbia.edu/~huy/) at `huy [at] cs [dot] columbia [dot] edu`. 39 | 40 | **Table of Contents** 41 | 42 | - [Setup](#setup) 43 | - [Environment](#environment) 44 | - [Models](#models) 45 | - [Dataset (Optional)](#dataset-optional) 46 | - [Multi-scale Relevancy Extractor](#multi-scale-relevancy-extractor) 47 | - [Evaluation](#evaluation) 48 | - [Summarize](#summarize) 49 | - [Run inference](#run-inference) 50 | - [Visualization](#visualization) 51 | - [Training](#training) 52 | - [OVSSC](#ovssc) 53 | - [VOOL](#vool) 54 | - [Codebase Walkthrough](#codebase-walkthrough) 55 | - [Acknowledgements](#acknowledgements) 56 | 57 | # Setup 58 | 59 | ## Environment 60 | 61 | Create the conda environment 62 | 63 | ```sh 64 | conda env create -f semabs.yml 65 | ``` 66 | 67 | ## Models 68 | 69 | Download the model checkpoints (~3.5GB) by running this command at the root of the repo 70 | ```sh 71 | wget https://semantic-abstraction.cs.columbia.edu/downloads/models.tar.lz4 -O - | tar --use-compress-program=lz4 -xf - -C ./ 72 | ``` 73 | 74 | You should have the following directory structure 75 | 76 | ```console 77 | ❯ tree /path/to/semantic-abstraction/models 78 | /path/to/semantic-abstraction/models 79 | ├── chefer_et_al 80 | │   ├── ovssc 81 | │   │   ├── args.pkl 82 | │   │   ├── ovssc.pth 83 | │   │   └── ovssc_eval_stats.pkl 84 | │   └── vool 85 | │   ├── args.pkl 86 | │   ├── vool.pth 87 | │   └── vool_eval_stats.pkl 88 | ├── clipspatial 89 | │   └── vool 90 | │   ├── args.pkl 91 | │   ├── vool.pth 92 | │   └── vool_eval_stats.pkl 93 | ├── ours 94 | │   ├── ovssc 95 | │   │   ├── args.pkl 96 | │   │   ├── ovssc.pth 97 | │   │   └── ovssc_eval_stats.pkl 98 | │   └── vool 99 | │   ├── args.pkl 100 | │   ├── vool.pth 101 | │   └── vool_eval_stats.pkl 102 | └── semaware 103 | ├── ovssc 104 | │   ├── args.pkl 105 | │   ├── ovssc.pth 106 | │   └── ovssc_eval_stats.pkl 107 | └── vool 108 | ├── args.pkl 109 | ├── vool.pth 110 | └── vool_eval_stats.pkl 111 | 112 | 11 directories, 21 files 113 | ``` 114 | 115 | 116 | ## Dataset (Optional) 117 | 118 | To run the evaluation inference or training, you will need the dataset. 
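Note: the download commands in this section (and the model download above) stream `.tar.lz4` archives through the `lz4` CLI, so make sure it is available on your machine first. A quick sketch (the package names below are our assumption for Ubuntu; adjust for your distribution):

```sh
# Ubuntu 20.04 ships the CLI in the `lz4` package;
# Ubuntu 18.04 provides it via `liblz4-tool`.
sudo apt-get install lz4 || sudo apt-get install liblz4-tool
```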
119 | 120 | Download the dataset (~269GB) by running the following at the root of the repo 121 | ```sh 122 | wget https://semantic-abstraction.cs.columbia.edu/downloads/dataset.tar.lz4 -O - | tar --use-compress-program=lz4 -xf - -C ./ 123 | ``` 124 | 125 | 126 | We also preprocessed the (~53GB) NYU dataset for training and evaluation 127 | ```sh 128 | wget https://semantic-abstraction.cs.columbia.edu/downloads/nyu_ovssc.tar.lz4 -O - | tar --use-compress-program=lz4 -xf - -C ./ 129 | ``` 130 | 131 | 132 | # Multi-scale Relevancy Extractor 133 | 134 | Play around with the multi-scale relevancy extractor on [![HuggingFace Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/huy-ha/semabs-relevancy). 135 | 136 | To run the multi-scale relevancy on GPU locally with the provided image from matterport 137 | 138 | ```sh 139 | python generate_relevancy.py image 140 | ``` 141 | 142 | which will output 143 | 144 | ![](grads.png) 145 | 146 | Try passing in your own image! 147 | 148 | ```console 149 | ❯ python generate_relevancy.py image --help 150 | 151 | Usage: generate_relevancy.py image [OPTIONS] [FILE_PATH] 152 | 153 | Generates a multi-scale relevancy for image at `file_path`. 154 | 155 | ╭─ Arguments ───────────────────────────────────────────────────────────────────────────────────╮ 156 | │ file_path [FILE_PATH] path of image file [default: matterport.png] │ 157 | ╰───────────────────────────────────────────────────────────────────────────────────────────────╯ 158 | ╭─ Options ─────────────────────────────────────────────────────────────────────────────────────╮ 159 | │ --labels TEXT list of object categories (e.g.: "nintendo switch") │ 160 | │ [default: basketball jersey, nintendo switch, television, ping pong │ 161 | │ table, vase, fireplace, abstract painting of a vespa, carpet, wall] │ 162 | │ --prompts TEXT prompt template to use with CLIP. │ 163 | │ [default: a photograph of a {} in a home.] │ 164 | │ --help Show this message and exit. │ 165 | ╰───────────────────────────────────────────────────────────────────────────────────────────────╯ 166 | ``` 167 | 168 | # Evaluation 169 | 170 | ## Summarize 171 | 172 | The evaluation result dataframes (`*_eval_stats*.pkl`) are provided along with the model checkpoints. 
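If you would like to inspect one of these dataframes directly before summarizing, the sketch below assumes the pickle holds a pandas object (e.g., a `DataFrame`); if your copy unpickles to something else, such as a dict of dataframes, index into it first:

```python
# Peek at the released OVSSC evaluation results.
import pandas as pd

stats = pd.read_pickle("models/ours/ovssc/ovssc_eval_stats.pkl")
print(type(stats))
if isinstance(stats, pd.DataFrame):
    print(stats.columns.tolist())  # which metrics and splits were logged
    print(stats.head())
```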
173 | To summarize them in a table 174 | 175 | ```console 176 | ❯ python summarize.py 177 | OVSSC THOR 178 | ╷ ╷ ╷ ╷ 179 | Approach │ Novel Room │ Novel Visual │ Novel Vocab │ Novel Class 180 | ═════════════════════════╪════════════╪══════════════╪═════════════╪═════════════ 181 | Semantic Aware │ 32.2 │ 31.9 │ 20.2 │ 0.0 182 | SemAbs + [Chefer et al] │ 26.6 │ 24.3 │ 17.8 │ 12.2 183 | ─────────────────────────┼────────────┼──────────────┼─────────────┼───────────── 184 | Ours │ 40.1 │ 36.4 │ 33.4 │ 37.9 185 | ╵ ╵ ╵ ╵ 186 | FULL VOOL THOR 187 | ╷ ╷ ╷ ╷ ╷ 188 | Approach │ Spatial Relation │ Novel Room │ Novel Visual │ Novel Vocab │ Novel Class 189 | ═════════════════════════╪══════════════════╪════════════╪══════════════╪═════════════╪═════════════ 190 | Semantic Aware │ in │ 15.0 │ 14.7 │ 7.6 │ 1.8 191 | │ on │ 9.0 │ 8.9 │ 11.4 │ 4.5 192 | │ on the left of │ 11.2 │ 11.1 │ 14.4 │ 4.0 193 | │ behind │ 12.8 │ 12.6 │ 14.1 │ 2.2 194 | │ on the right of │ 13.1 │ 13.0 │ 11.5 │ 3.4 195 | │ in front of │ 11.2 │ 11.1 │ 9.3 │ 2.2 196 | │ mean │ 12.1 │ 11.9 │ 11.4 │ 3.0 197 | ─────────────────────────┼──────────────────┼────────────┼──────────────┼─────────────┼───────────── 198 | ClipSpatial │ in │ 9.6 │ 8.6 │ 7.1 │ 3.3 199 | │ on │ 14.1 │ 12.1 │ 18.5 │ 20.0 200 | │ on the left of │ 11.0 │ 9.4 │ 14.2 │ 13.2 201 | │ behind │ 11.3 │ 9.9 │ 14.1 │ 8.9 202 | │ on the right of │ 12.1 │ 10.6 │ 16.2 │ 11.5 203 | │ in front of │ 12.3 │ 10.3 │ 15.7 │ 9.9 204 | │ mean │ 11.7 │ 10.1 │ 14.3 │ 11.2 205 | ─────────────────────────┼──────────────────┼────────────┼──────────────┼─────────────┼───────────── 206 | SemAbs + [Chefer et al] │ in │ 11.8 │ 11.1 │ 5.7 │ 2.1 207 | │ on │ 7.0 │ 6.7 │ 11.3 │ 7.1 208 | │ on the left of │ 9.5 │ 9.3 │ 13.7 │ 4.9 209 | │ behind │ 7.6 │ 7.6 │ 10.6 │ 2.5 210 | │ on the right of │ 9.2 │ 9.2 │ 11.0 │ 3.9 211 | │ in front of │ 9.4 │ 9.0 │ 12.0 │ 3.3 212 | │ mean │ 9.1 │ 8.8 │ 10.7 │ 4.0 213 | ─────────────────────────┼──────────────────┼────────────┼──────────────┼─────────────┼───────────── 214 | Ours │ in │ 17.8 │ 17.5 │ 8.5 │ 7.3 215 | │ on │ 21.0 │ 18.0 │ 27.2 │ 28.1 216 | │ on the left of │ 22.0 │ 20.3 │ 27.7 │ 25.1 217 | │ behind │ 19.9 │ 18.0 │ 22.8 │ 16.7 218 | │ on the right of │ 23.2 │ 21.7 │ 28.1 │ 22.1 219 | │ in front of │ 21.5 │ 19.4 │ 25.8 │ 19.1 220 | │ mean │ 20.9 │ 19.2 │ 23.4 │ 19.7 221 | 222 | OVSSC NYU 223 | ╷ ╷ ╷ ╷ ╷ ╷ ╷ ╷ ╷ ╷ ╷ ╷ 224 | Approach │ Ceiling │ Floor │ Wall │ Window │ Chair │ Bed │ Sofa │ Table │ Tvs │ Furn │ Objs │ Mean 225 | ═══════════════════╪═════════╪═══════╪══════╪════════╪═══════╪══════╪══════╪═══════╪══════╪══════╪══════╪══════ 226 | Ours (Supervised) │ 22.6 │ 46.1 │ 33.9 │ 35.9 │ 23.9 │ 55.9 │ 37.9 │ 19.7 │ 30.8 │ 39.8 │ 27.7 │ 34.0 227 | ───────────────────┼─────────┼───────┼──────┼────────┼───────┼──────┼──────┼───────┼──────┼──────┼──────┼────── 228 | Ours (Zeroshot) │ 13.7 │ 17.3 │ 13.5 │ 25.2 │ 15.2 │ 33.3 │ 31.5 │ 12.0 │ 23.7 │ 25.6 │ 19.9 │ 21.0 229 | ╵ ╵ ╵ ╵ ╵ ╵ ╵ ╵ ╵ ╵ ╵ ╵ 230 | ``` 231 | 232 | ## Run inference 233 | 234 | To run inference, make sure the dataset is [downloaded](#dataset-optional). 235 | 236 | Then, regenerate the evaluation result dataframes by running the evaluation script. 237 | 238 | For OVSSC 239 | 240 | ```sh 241 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=1 eval.py --task ovssc --file_path dataset/ --gpus 0 --load models/ours/ovssc/ovssc.pth 242 | ``` 243 | 244 | Inference can be sped up by using more than one GPU. 
245 | For instance, to use GPU 0, 1, 2, and 3, use `--nproc_per_node=4` and `--gpus 0 1 2 3`. 246 | 247 | For OVSSC on NYU 248 | ```sh 249 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=1 eval.py --task ovssc --file_path nyu_ovssc/ --gpus 0 --load models/ours/ovssc/ovssc.pth 250 | ``` 251 | 252 | Similarly, for VOOL 253 | 254 | ```sh 255 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=1 eval.py --task vool --file_path dataset/ --gpus 0 --load models/ours/vool/vool.pth 256 | ``` 257 | 258 | ## Visualization 259 | 260 | The `visualize.py` script takes as input the scene pickle file and the network checkpoint. 261 | 262 | The pickle file should be a dictionary with the following keys and types 263 | ```python 264 | rgb: np.ndarray # shape h x w x 3 265 | depth: np.ndarray # shape h x w 266 | img_shape: Tuple[int, int] 267 | cam_intr: np.ndarray # shape 4 x 4 268 | cam_extr: np.ndarray # shape 4 x 4 269 | ovssc_obj_classes: List[str] 270 | descriptions: List[List[str]] 271 | ``` 272 | After being loaded, `rgb` and `depth` will be resized to `img_shape`, which matches the image dimensions in `cam_intr`. 273 | Each element in descriptions is a list containing the target object name, spatial preposition and reference object name respectively. 274 | We provide some example scene pickle files from [Habitat Matterport 3D](https://aihabitat.org/datasets/hm3d/) and [ARKitScenes](https://github.com/apple/ARKitScenes) in `scene_files/`. 275 | 276 | Visualizing OVSSC generates a `.mp4` video of the completion, along with `.obj` meshes, while visualizing VOOL generates `.mp4` videos along with `.ply` pointclouds for each description. 277 | 278 | For instance, running 279 | ```sh 280 | # OVSSC 281 | python visualize.py ovssc-inference scene_files/arkit_vn_poster.pkl models/ours/ovssc/ovssc.pth 282 | python visualize.py ovssc-visualize visualization/arkit_vn_poster 283 | # VOOL 284 | python visualize.py vool-inference scene_files/arkit_vn_poster.pkl models/ours/vool/vool.pth 285 | python visualize.py vool-visualize visualization/arkit_vn_poster 286 | ``` 287 | Will output to `visualization/arkit_vn_poster`, including the following relevancies 288 | 289 | ![](assets/vn_poster_relevancies.png) 290 | 291 | and the following VOOL localization for `the hair dryer with its wires all tangled up behind the table legs` and OVSSC completion: 292 | 293 | | RGB | Localization | Completion | 294 | | -------------------------------- | --------------------------------------- | ------------------------------------- | 295 | | ![](assets/hair_dryer_scene.png) | ![](assets/hair_dryer_behind_table.gif) | ![](assets/hair_dryer_completion.gif) | 296 | 297 | 298 | While these visualizations are sufficient for debugging, I recommend using the `ply` and `obj` files to render in Blender. 299 | 300 | 301 | | Legend | Localization | 302 | | ---------------------------------------------------- | --------------------------------------------- | 303 | | ![](assets/hair_dryer_completion_blender_legend.png) | ![](assets/hair_dryer_completion_blender.gif) | 304 | 305 | 306 | # Training 307 | 308 | To train the models, make sure the dataset is [downloaded](#dataset-optional). 
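The commands below assume an 8-GPU machine. To first sanity-check your environment and dataset on a single GPU, a scaled-down run like the following can serve as a smoke test (a sketch: it only shrinks `--nproc_per_node`, `--gpus`, and `--epochs` relative to the full commands below, and the log directory `models/smoke-test` is an arbitrary choice; whether the default batch size fits in a single GPU's memory is not verified here):

```sh
python -m torch.distributed.run --nnodes=1 --nproc_per_node=1 train_ovssc.py --file_path dataset/ --log models/smoke-test --gpus 0 --epochs 1 --saliency_config ours
```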
309 | 310 | ## OVSSC 311 | 312 | To retrain our model 313 | 314 | ```sh 315 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=8 train_ovssc.py --file_path dataset/ --log models/new-ours --gpus 0 1 2 3 4 5 6 7 --epochs 200 --saliency_config ours 316 | ``` 317 | 318 | To retrain the semantic aware model 319 | 320 | ```sh 321 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=8 train_ovssc.py --file_path dataset/ --log models/new-ours --gpus 0 1 2 3 4 5 6 7 --epochs 200 --approach semantic_aware 322 | ``` 323 | 324 | To retrain the semantic abstraction + [Chefer et al.](https://github.com/hila-chefer/Transformer-MM-Explainability) model 325 | 326 | ```sh 327 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=8 train_ovssc.py --file_path dataset/ --log models/new-ours --gpus 0 1 2 3 4 5 6 7 --epochs 200 --approach semantic_aware 328 | ``` 329 | 330 | ## VOOL 331 | 332 | To retrain our model 333 | 334 | ```sh 335 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=8 train_vool.py --file_path dataset/ --log models/new-ours --gpus 0 1 2 3 4 5 6 7 --epochs 200 --saliency_config ours 336 | ``` 337 | 338 | To retrain the semantic aware model 339 | 340 | ```sh 341 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=8 train_vool.py --file_path dataset/ --log models/new-ours --gpus 0 1 2 3 4 5 6 7 --epochs 200 --approach semantic_aware 342 | ``` 343 | 344 | To retrain the CLIP-Spatial model 345 | 346 | ```sh 347 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=8 train_vool.py --file_path dataset/ --log models/new-ours --gpus 0 1 2 3 4 5 6 7 --epochs 200 --approach clip_spatial 348 | ``` 349 | 350 | To retrain the semantic abstraction + [Chefer et al.](https://github.com/hila-chefer/Transformer-MM-Explainability) model 351 | 352 | ```sh 353 | python -m torch.distributed.run --nnodes=1 --nproc_per_node=8 train_vool.py --file_path dataset/ --log models/new-ours --gpus 0 1 2 3 4 5 6 7 --epochs 200 --approach semantic_aware 354 | ``` 355 | 356 | 357 | # Codebase Walkthrough 358 | 359 | Below, we've provided a summary of the networks provided, along with how and where our method as described in the paper is implemented. The links in the bullet points below point to specific lines in this codebase. We hope this code annotation helps clarify the network architecture and training procedures. 360 | 361 | - [`SemAbs3D`](net.py#L319): This class implements the SemAbs module. It contains two networks, a [`ResidualUNet3D`](unet3d.py) (i.e., $f_\mathrm{encode}$) and an [`ImplicitVolumetricDecoder`](net.py#L204) (i.e., $f_\mathrm{decoder}$). 362 | - [`SemanticAwareOVSSC`](net.py#L441): This class implements the SemAware baseline for the OVSSC task. It inherits directly from `SemAbs3D`, with two crucial differences: 1) it [takes RGB pointclouds as input](train_ovssc.py#L237) instead of [saliency pointclouds](utils.py#L94), and 2) it uses its [sampled feature pointclouds to point](net.py#L455) to text features of semantic classes (i.e., encoded using CLIP's text encoder). These two differences together mean it has to learn to recognize semantic classes from RGB inputs by itself, which leads to overfitting to the training semantic classes. 363 | - [`SemAbsVOOL`](net.py#L468): This class implements the SemAbs module for the VOOL task. In addition to the learnable parameters of `SemAbs3D`, it includes a [set of relational embeddings](net.py#L489), one for each of the closed vocabulary of spatial relations. 
364 | - [`SemanticAwareVOOL`](net.py#L581): This class implements the SemAware baseline for the VOOL task. Similar to `SemanticAwareOVSSC`, it [takes as input RGB pointclouds](train_vool.py#L222). However, specifically for the VOOL task, it uses the entire localization description ([encoded using CLIP's text encoders and learned spatial relations](net.py#L590)) to [point to regions within the scene](net.py#L606). 365 | - [`ClipSpatialVOOL`](net.py#L638): This class implements the CLIPSpatial baseline for the VOOL task. In contrast to other VOOL networks, it does not attempt to learn spatial relations or semantics. Instead, it completely relies on relevancy inputs from CLIP. 366 | 367 | 368 | A few tips for training your semantic abstraction module: 369 | - We have observed that performing [small random transformations](dataset.py#L534) on input and output point clouds helps generalization significantly. 370 | - To account for positive/negative class balance when using the binary cross entropy loss (in both [OVSSC](train_ovssc.py#L147) and [VOOL](train_vool.py#L171)), we found that using [`--balance_spatial_sampling`](utils.py#L71) helps tremendously. This [biases the subsampling of query points](dataset.py#L607) such that as many positive points are sampled as possible without replacement to achieve a balanced batch. 371 | - Remember to [rescale your relevancy values in a reasonable range](dataset.py#L1052)! 372 | 373 | 374 | # Acknowledgements 375 | 376 | We would like to thank Samir Yitzhak Gadre, Cheng Chi and Zhenjia Xu for their helpful feedback and fruitful discussions. 377 | This work was supported in part by NSF Award #2143601, #2132519, JP Morgan Faculty Research Award, and Google Research Award. 378 | The views and conclusions contained herein are those of the authors and should not be interpreted as necessarily representing the official policies, either expressed or implied, of the sponsors. 379 | 380 | Code: 381 | - The relevancy extraction code was modified from [Chefer et al.'s codebase](https://github.com/hila-chefer/Transformer-MM-Explainability) and [CLIP on Wheels' codebase](https://cow.cs.columbia.edu/). 382 | - The [3D U-Net](unet3d.py) definition was taken from [Adrian Wolny](https://github.com/wolny/pytorch-3dunet/). 383 | - Integrated into [Huggingface Spaces 🤗](https://huggingface.co/spaces) using [Gradio](https://github.com/gradio-app/gradio). 384 | - The [LAMB](https://arxiv.org/pdf/1904.00962.pdf) PyTorch implementation is from the [Attention-driven Robotic Manipulation (ARM)](https://github.com/stepjam/ARM#running-experiments). 385 | 386 | 387 | ![](https://visitor-badge.glitch.me/badge?page_id=huy-ha.semabs-relevancy) -------------------------------------------------------------------------------- /arm/LICENSE: -------------------------------------------------------------------------------- 1 | Q-attention: Enabling Efficient Learning for Vision-based Robotic Manipulation 2 | 3 | LICENCE AGREEMENT 4 | 5 | WE (Imperial College of Science, Technology and Medicine, (“Imperial College London”)) 6 | ARE WILLING TO LICENSE THIS SOFTWARE TO YOU (a licensee “You”) ONLY ON THE 7 | CONDITION THAT YOU ACCEPT ALL OF THE TERMS CONTAINED IN THE 8 | FOLLOWING AGREEMENT. PLEASE READ THE AGREEMENT CAREFULLY BEFORE 9 | DOWNLOADING THE SOFTWARE. BY EXERCISING THE OPTION TO DOWNLOAD 10 | THE SOFTWARE YOU AGREE TO BE BOUND BY THE TERMS OF THE AGREEMENT. 
11 | SOFTWARE LICENCE AGREEMENT (EXCLUDING BSD COMPONENTS) 12 | 13 | 1.This Agreement pertains to a worldwide, non-exclusive, temporary, fully paid-up, royalty 14 | free, non-transferable, non-sub- licensable licence (the “Licence”) to use the Q-attention 15 | source code, including any modification, part or derivative (the “Software”). 16 | Ownership and Licence. Your rights to use and download the Software onto your computer, 17 | and all other copies that You are authorised to make, are specified in this Agreement. 18 | However, we (or our licensors) retain all rights, including but not limited to all copyright and 19 | other intellectual property rights anywhere in the world, in the Software not expressly 20 | granted to You in this Agreement. 21 | 22 | 2. Permitted use of the Licence: 23 | 24 | (a) You may download and install the Software onto one computer or server for use in 25 | accordance with Clause 2(b) of this Agreement provided that You ensure that the Software is 26 | not accessible by other users unless they have themselves accepted the terms of this licence 27 | agreement. 28 | 29 | (b) You may use the Software solely for non-commercial, internal or academic research 30 | purposes and only in accordance with the terms of this Agreement. You may not use the 31 | Software for commercial purposes, including but not limited to (1) integration of all or part of 32 | the source code or the Software into a product for sale or licence by or on behalf of You to 33 | third parties or (2) use of the Software or any derivative of it for research to develop software 34 | products for sale or licence to a third party or (3) use of the Software or any derivative of it 35 | for research to develop non-software products for sale or licence to a third party, or (4) use of 36 | the Software to provide any service to an external organisation for which payment is 37 | received. 38 | 39 | Should You wish to use the Software for commercial purposes, You shall 40 | email researchcontracts.engineering@imperial.ac.uk . 41 | 42 | (c) Right to Copy. You may copy the Software for back-up and archival purposes, provided 43 | that each copy is kept in your possession and provided You reproduce our copyright notice 44 | (set out in Schedule 1) on each copy. 45 | 46 | (d) Transfer and sub-licensing. You may not rent, lend, or lease the Software and You may 47 | not transmit, transfer or sub-license this licence to use the Software or any of your rights or 48 | obligations under this Agreement to another party. 49 | 50 | (e) Identity of Licensee. The licence granted herein is personal to You. You shall not permit 51 | any third party to access, modify or otherwise use the Software nor shall You access modify 52 | or otherwise use the Software on behalf of any third party. If You wish to obtain a licence for 53 | mutiple users or a site licence for the Software please contact us 54 | at researchcontracts.engineering@imperial.ac.uk . 55 | 56 | (f) Publications and presentations. You may make public, results or data obtained from, 57 | dependent on or arising from research carried out using the Software, provided that any such 58 | presentation or publication identifies the Software as the source of the results or the data, 59 | including the Copyright Notice given in each element of the Software, and stating that the 60 | Software has been made available for use by You under licence from Imperial College London 61 | and You provide a copy of any such publication to Imperial College London. 62 | 63 | 3. 
Prohibited Uses. You may not, without written permission from us 64 | at researchcontracts.engineering@imperial.ac.uk : 65 | 66 | (a) Use, copy, modify, merge, or transfer copies of the Software or any documentation 67 | provided by us which relates to the Software except as provided in this Agreement; 68 | 69 | (b) Use any back-up or archival copies of the Software (or allow anyone else to use such 70 | copies) for any purpose other than to replace the original copy in the event it is destroyed or 71 | becomes defective; or 72 | 73 | (c) Disassemble, decompile or "unlock", reverse translate, or in any manner decode the 74 | Software for any reason. 75 | 76 | 4. Warranty Disclaimer 77 | 78 | (a) Disclaimer. The Software has been developed for research purposes only. You 79 | acknowledge that we are providing the Software to You under this licence agreement free of 80 | charge and on condition that the disclaimer set out below shall apply. We do not represent or 81 | warrant that the Software as to: (i) the quality, accuracy or reliability of the Software; (ii) the 82 | suitability of the Software for any particular use or for use under any specific conditions; and 83 | (iii) whether use of the Software will infringe third-party rights. 84 | You acknowledge that You have reviewed and evaluated the Software to determine that it 85 | meets your needs and that You assume all responsibility and liability for determining the 86 | suitability of the Software as fit for your particular purposes and requirements. Subject to 87 | Clause 4(b), we exclude and expressly disclaim all express and implied representations, 88 | warranties, conditions and terms not stated herein (including the implied conditions or 89 | warranties of satisfactory quality, merchantable quality, merchantability and fitness for 90 | purpose). 91 | 92 | (b) Savings. Some jurisdictions may imply warranties, conditions or terms or impose 93 | obligations upon us which cannot, in whole or in part, be excluded, restricted or modified or 94 | otherwise do not allow the exclusion of implied warranties, conditions or terms, in which 95 | case the above warranty disclaimer and exclusion will only apply to You to the extent 96 | permitted in the relevant jurisdiction and does not in any event exclude any implied 97 | warranties, conditions or terms which may not under applicable law be excluded. 98 | 99 | (c) Imperial College London disclaims all responsibility for the use which is made of the 100 | Software and any liability for the outcomes arising from using the Software. 101 | 102 | 5. Limitation of Liability 103 | 104 | (a) You acknowledge that we are providing the Software to You under this licence agreement 105 | free of charge and on condition that the limitation of liability set out below shall apply. 106 | Accordingly, subject to Clause 5(b), we exclude all liability whether in contract, tort, 107 | negligence or otherwise, in respect of the Software and/or any related documentation 108 | provided to You by us including, but not limited to, liability for loss or corruption of data, 109 | loss of contracts, loss of income, loss of profits, loss of cover and any consequential or indirect 110 | loss or damage of any kind arising out of or in connection with this licence agreement, 111 | however caused. This exclusion shall apply even if we have been advised of the possibility of 112 | such loss or damage. 
113 | 114 | (b) You agree to indemnify Imperial College London and hold it harmless from and against 115 | any and all claims, damages and liabilities asserted by third parties (including claims for 116 | negligence) which arise directly or indirectly from the use of the Software or any derivative 117 | of it or the sale of any products based on the Software. You undertake to make no liability 118 | claim against any employee, student, agent or appointee of Imperial College London, in 119 | connection with this Licence or the Software. 120 | 121 | (c) Nothing in this Agreement shall have the effect of excluding or limiting our statutory 122 | liability. 123 | 124 | (d) Some jurisdictions do not allow these limitations or exclusions either wholly or in part, 125 | and, to that extent, they may not apply to you. Nothing in this licence agreement will affect 126 | your statutory rights or other relevant statutory provisions which cannot be excluded, 127 | restricted or modified, and its terms and conditions must be read and construed subject to any 128 | such statutory rights and/or provisions. 129 | 130 | 6. Confidentiality. You agree not to disclose any confidential information provided to You by 131 | us pursuant to this Agreement to any third party without our prior written consent. The 132 | obligations in this Clause 6 shall survive the termination of this Agreement for any reason. 133 | 134 | 7. Termination. 135 | 136 | (a) We may terminate this licence agreement and your right to use the Software at any time 137 | with immediate effect upon written notice to You. 138 | 139 | (b) This licence agreement and your right to use the Software automatically terminate if You: 140 | (i) fail to comply with any provisions of this Agreement; or 141 | (ii) destroy the copies of the Software in your possession, or voluntarily return the Software 142 | to us. 143 | 144 | (c) Upon termination You will destroy all copies of the Software. 145 | 146 | (d) Otherwise, the restrictions on your rights to use the Software will expire 10 (ten) years 147 | after first use of the Software under this licence agreement. 148 | 149 | 8. Miscellaneous Provisions. 150 | 151 | (a) This Agreement will be governed by and construed in accordance with the substantive 152 | laws of England and Wales whose courts shall have exclusive jurisdiction over all disputes 153 | which may arise between us. 154 | 155 | (b) This is the entire agreement between us relating to the Software, and supersedes any prior 156 | purchase order, communications, advertising or representations concerning the Software. 157 | 158 | (c) No change or modification of this Agreement will be valid unless it is in writing, and is 159 | signed by us. 160 | 161 | (d) The unenforceability or invalidity of any part of this Agreement will not affect the 162 | enforceability or validity of the remaining parts. 163 | 164 | BSD Elements of the Software 165 | 166 | For BSD elements of the Software, the following terms shall apply: 167 | 168 | Copyright as indicated in the header of the individual element of the Software. 169 | 170 | All rights reserved. 171 | 172 | Redistribution and use in source and binary forms, with or without modification, are 173 | permitted provided that the following conditions are met: 174 | 175 | 1. Redistributions of source code must retain the above copyright notice, this list of 176 | conditions and the following disclaimer. 177 | 178 | 2. 
Redistributions in binary form must reproduce the above copyright notice, this list of 179 | conditions and the following disclaimer in the documentation and/or other materials 180 | provided with the distribution. 181 | 182 | 3. Neither the name of the copyright holder nor the names of its contributors may be used to 183 | endorse or promote products derived from this software without specific prior written 184 | permission. 185 | 186 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 187 | "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 188 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 189 | A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 190 | HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 191 | SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 192 | LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 193 | DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 194 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 195 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 196 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /arm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/arm/__init__.py -------------------------------------------------------------------------------- /arm/optim/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 cybertronai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /arm/optim/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/arm/optim/__init__.py -------------------------------------------------------------------------------- /arm/optim/lamb.py: -------------------------------------------------------------------------------- 1 | # From https://github.com/cybertronai/pytorch-lamb/blob/master/pytorch_lamb/lamb.py 2 | 3 | """Lamb optimizer.""" 4 | 5 | import collections 6 | import math 7 | 8 | import torch 9 | from torch.optim import Optimizer 10 | 11 | 12 | # def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int): 13 | # """Log a histogram of trust ratio scalars in across layers.""" 14 | # results = collections.defaultdict(list) 15 | # for group in optimizer.param_groups: 16 | # for p in group['params']: 17 | # state = optimizer.state[p] 18 | # for i in ('weight_norm', 'adam_norm', 'trust_ratio'): 19 | # if i in state: 20 | # results[i].append(state[i]) 21 | # 22 | # for k, v in results.items(): 23 | # event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count) 24 | 25 | 26 | class Lamb(Optimizer): 27 | r"""Implements Lamb algorithm. 28 | It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. 29 | Arguments: 30 | params (iterable): iterable of parameters to optimize or dicts defining 31 | parameter groups 32 | lr (float, optional): learning rate (default: 1e-3) 33 | betas (Tuple[float, float], optional): coefficients used for computing 34 | running averages of gradient and its square (default: (0.9, 0.999)) 35 | eps (float, optional): term added to the denominator to improve 36 | numerical stability (default: 1e-8) 37 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 38 | adam (bool, optional): always use trust ratio = 1, which turns this into 39 | Adam. Useful for comparison purposes. 40 | .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: 41 | https://arxiv.org/abs/1904.00962 42 | """ 43 | 44 | def __init__( 45 | self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0, adam=False 46 | ): 47 | if not 0.0 <= lr: 48 | raise ValueError("Invalid learning rate: {}".format(lr)) 49 | if not 0.0 <= eps: 50 | raise ValueError("Invalid epsilon value: {}".format(eps)) 51 | if not 0.0 <= betas[0] < 1.0: 52 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 53 | if not 0.0 <= betas[1] < 1.0: 54 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 55 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 56 | self.adam = adam 57 | super(Lamb, self).__init__(params, defaults) 58 | 59 | def step(self, closure=None): 60 | """Performs a single optimization step. 61 | Arguments: 62 | closure (callable, optional): A closure that reevaluates the model 63 | and returns the loss. 64 | """ 65 | loss = None 66 | if closure is not None: 67 | loss = closure() 68 | 69 | for group in self.param_groups: 70 | for p in group["params"]: 71 | if p.grad is None: 72 | continue 73 | grad = p.grad.data 74 | if grad.is_sparse: 75 | raise RuntimeError( 76 | "Lamb does not support sparse gradients, consider SparseAdam instad." 
77 | ) 78 | 79 | state = self.state[p] 80 | 81 | # State initialization 82 | if len(state) == 0: 83 | state["step"] = 0 84 | # Exponential moving average of gradient values 85 | state["exp_avg"] = torch.zeros_like(p.data) 86 | # Exponential moving average of squared gradient values 87 | state["exp_avg_sq"] = torch.zeros_like(p.data) 88 | 89 | exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] 90 | beta1, beta2 = group["betas"] 91 | 92 | state["step"] += 1 93 | 94 | # Decay the first and second moment running average coefficient 95 | # m_t 96 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 97 | # v_t 98 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 99 | 100 | # Paper v3 does not use debiasing. 101 | # bias_correction1 = 1 - beta1 ** state['step'] 102 | # bias_correction2 = 1 - beta2 ** state['step'] 103 | # Apply bias to lr to avoid broadcast. 104 | step_size = group[ 105 | "lr" 106 | ] # * math.sqrt(bias_correction2) / bias_correction1 107 | 108 | weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10) 109 | 110 | adam_step = exp_avg / exp_avg_sq.sqrt().add(group["eps"]) 111 | if group["weight_decay"] != 0: 112 | adam_step.add_(p.data, alpha=group["weight_decay"]) 113 | 114 | adam_norm = adam_step.pow(2).sum().sqrt() 115 | if weight_norm == 0 or adam_norm == 0: 116 | trust_ratio = 1 117 | else: 118 | trust_ratio = weight_norm / adam_norm 119 | state["weight_norm"] = weight_norm 120 | state["adam_norm"] = adam_norm 121 | state["trust_ratio"] = trust_ratio 122 | if self.adam: 123 | trust_ratio = 1 124 | 125 | p.data.add_(adam_step, alpha=-step_size * trust_ratio) 126 | 127 | return loss 128 | -------------------------------------------------------------------------------- /arm/utils.py: -------------------------------------------------------------------------------- 1 | # Adapted from: https://github.com/stepjam/ARM/blob/main/arm/utils.py 2 | 3 | import torch 4 | import numpy as np 5 | from scipy.spatial.transform import Rotation 6 | 7 | import pyrender 8 | import trimesh 9 | from pyrender.trackball import Trackball 10 | 11 | 12 | def normalize_quaternion(quat): 13 | return np.array(quat) / np.linalg.norm(quat, axis=-1, keepdims=True) 14 | 15 | 16 | def quaternion_to_discrete_euler(quaternion, resolution): 17 | euler = Rotation.from_quat(quaternion).as_euler("xyz", degrees=True) + 180 18 | assert np.min(euler) >= 0 and np.max(euler) <= 360 19 | disc = np.around((euler / resolution)).astype(int) 20 | disc[disc == int(360 / resolution)] = 0 21 | return disc 22 | 23 | 24 | def discrete_euler_to_quaternion(discrete_euler, resolution): 25 | euluer = (discrete_euler * resolution) - 180 26 | return Rotation.from_euler("xyz", euluer, degrees=True).as_quat() 27 | 28 | 29 | def point_to_voxel_index( 30 | point: np.ndarray, voxel_size: np.ndarray, coord_bounds: np.ndarray 31 | ): 32 | bb_mins = np.array(coord_bounds[0:3]) 33 | bb_maxs = np.array(coord_bounds[3:]) 34 | dims_m_one = np.array([voxel_size] * 3) - 1 35 | bb_ranges = bb_maxs - bb_mins 36 | res = bb_ranges / (np.array([voxel_size] * 3) + 1e-12) 37 | voxel_indicy = np.minimum( 38 | np.floor((point - bb_mins) / (res + 1e-12)).astype(np.int32), dims_m_one 39 | ) 40 | return voxel_indicy 41 | 42 | 43 | def stack_on_channel(x): 44 | # expect (B, T, C, ...) 
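    # Split along the time dim, concatenate the chunks along the channel dim,
    # then drop the singleton time dim: (B, T, C, ...) -> (B, T*C, ...).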
45 | return torch.cat(torch.split(x, 1, dim=1), dim=2).squeeze(1) 46 | 47 | 48 | def _compute_initial_camera_pose(scene): 49 | # Adapted from: 50 | # https://github.com/mmatl/pyrender/blob/master/pyrender/viewer.py#L1032 51 | centroid = scene.centroid 52 | scale = scene.scale 53 | # if scale == 0.0: 54 | # scale = DEFAULT_SCENE_SCALE 55 | scale = 4.0 56 | s2 = 1.0 / np.sqrt(2.0) 57 | cp = np.eye(4) 58 | cp[:3, :3] = np.array([[0.0, -s2, s2], [1.0, 0.0, 0.0], [0.0, s2, s2]]) 59 | hfov = np.pi / 6.0 60 | dist = scale / (2.0 * np.tan(hfov)) 61 | cp[:3, 3] = dist * np.array([1.0, 0.0, 1.0]) + centroid 62 | return cp 63 | 64 | 65 | def _from_trimesh_scene(trimesh_scene, bg_color=None, ambient_light=None): 66 | # convert trimesh geometries to pyrender geometries 67 | geometries = { 68 | name: pyrender.Mesh.from_trimesh(geom, smooth=False) 69 | for name, geom in trimesh_scene.geometry.items() 70 | } 71 | # create the pyrender scene object 72 | scene_pr = pyrender.Scene(bg_color=bg_color, ambient_light=ambient_light) 73 | # add every node with geometry to the pyrender scene 74 | for node in trimesh_scene.graph.nodes_geometry: 75 | pose, geom_name = trimesh_scene.graph[node] 76 | scene_pr.add(geometries[geom_name], pose=pose) 77 | return scene_pr 78 | 79 | 80 | def create_voxel_scene( 81 | voxel_grid: np.ndarray, 82 | q_attention: np.ndarray = None, 83 | highlight_coordinate: np.ndarray = None, 84 | highlight_gt_coordinate: np.ndarray = None, 85 | highlight_alpha: float = 1.0, 86 | voxel_size: float = 0.1, 87 | show_bb: bool = False, 88 | alpha: float = 0.5, 89 | ): 90 | _, d, h, w = voxel_grid.shape 91 | v = voxel_grid.transpose((1, 2, 3, 0)) 92 | occupancy = v[:, :, :, -1] != 0 93 | alpha = np.expand_dims(np.full_like(occupancy, alpha, dtype=np.float32), -1) 94 | rgb = np.concatenate([(v[:, :, :, 3:6] + 1) / 2.0, alpha], axis=-1) 95 | 96 | if q_attention is not None: 97 | q = np.max(q_attention, 0) 98 | q = q / np.max(q) 99 | show_q = q > 0.75 100 | occupancy = (show_q + occupancy).astype(bool) 101 | q = np.expand_dims(q - 0.5, -1) # Max q can be is 0.9 102 | q_rgb = np.concatenate( 103 | [q, np.zeros_like(q), np.zeros_like(q), np.clip(q, 0, 1)], axis=-1 104 | ) 105 | rgb = np.where(np.expand_dims(show_q, -1), q_rgb, rgb) 106 | 107 | if highlight_coordinate is not None: 108 | x, y, z = highlight_coordinate 109 | occupancy[x, y, z] = True 110 | rgb[x, y, z] = [1.0, 0.0, 0.0, highlight_alpha] 111 | 112 | if highlight_gt_coordinate is not None: 113 | x, y, z = highlight_gt_coordinate 114 | occupancy[x, y, z] = True 115 | rgb[x, y, z] = [0.0, 0.0, 1.0, highlight_alpha] 116 | 117 | transform = trimesh.transformations.scale_and_translate( 118 | scale=voxel_size, translate=(0.0, 0.0, 0.0) 119 | ) 120 | trimesh_voxel_grid = trimesh.voxel.VoxelGrid( 121 | encoding=occupancy, transform=transform 122 | ) 123 | geometry = trimesh_voxel_grid.as_boxes(colors=rgb) 124 | scene = trimesh.Scene() 125 | scene.add_geometry(geometry) 126 | if show_bb: 127 | assert d == h == w 128 | _create_bounding_box(scene, voxel_size, d) 129 | return scene 130 | 131 | 132 | def visualise_voxel( 133 | voxel_grid: np.ndarray, 134 | q_attention: np.ndarray = None, 135 | highlight_coordinate: np.ndarray = None, 136 | highlight_gt_coordinate: np.ndarray = None, 137 | highlight_alpha: float = 1.0, 138 | rotation_amount: float = 0.0, 139 | show: bool = False, 140 | voxel_size: float = 0.1, 141 | offscreen_renderer: pyrender.OffscreenRenderer = None, 142 | show_bb: bool = False, 143 | alpha: float = 0.5, 144 | render_gripper=False, 145 | 
gripper_pose=None, 146 | gripper_mesh_scale=1.0, 147 | ): 148 | scene = create_voxel_scene( 149 | voxel_grid, 150 | q_attention, 151 | highlight_coordinate, 152 | highlight_gt_coordinate, 153 | highlight_alpha, 154 | voxel_size, 155 | show_bb, 156 | alpha, 157 | ) 158 | if show: 159 | scene.show() 160 | else: 161 | r = offscreen_renderer or pyrender.OffscreenRenderer( 162 | viewport_width=1920, viewport_height=1080, point_size=1.0 163 | ) 164 | s = _from_trimesh_scene( 165 | scene, ambient_light=[0.8, 0.8, 0.8], bg_color=[1.0, 1.0, 1.0] 166 | ) 167 | cam = pyrender.PerspectiveCamera( 168 | yfov=np.pi / 4.0, aspectRatio=r.viewport_width / r.viewport_height 169 | ) 170 | p = _compute_initial_camera_pose(s) 171 | t = Trackball(p, (r.viewport_width, r.viewport_height), s.scale, s.centroid) 172 | t.rotate(rotation_amount, np.array([0.0, 0.0, 1.0])) 173 | s.add(cam, pose=t.pose) 174 | 175 | if render_gripper: 176 | gripper_trimesh = trimesh.load("peract_colab/meshes/hand.dae", force="mesh") 177 | gripper_trimesh.vertices *= gripper_mesh_scale 178 | radii = np.linalg.norm( 179 | gripper_trimesh.vertices - gripper_trimesh.center_mass, axis=1 180 | ) 181 | gripper_trimesh.visual.vertex_colors = trimesh.visual.interpolate( 182 | radii * gripper_mesh_scale, color_map="winter" 183 | ) 184 | gripper_mesh = pyrender.Mesh.from_trimesh( 185 | gripper_trimesh, poses=np.array([gripper_pose]), smooth=False 186 | ) 187 | s.add(gripper_mesh) 188 | color, depth = r.render(s) 189 | return color.copy() 190 | 191 | 192 | def get_gripper_render_pose( 193 | voxel_scale, scene_bound_origin, continuous_trans, continuous_quat 194 | ): 195 | # finger tip to gripper offset 196 | offset = np.array( 197 | [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0.1 * voxel_scale], [0, 0, 0, 1]] 198 | ) 199 | 200 | # scale and translate by origin 201 | translation = (continuous_trans - (np.array(scene_bound_origin[:3]))) * voxel_scale 202 | mat = np.eye(4, 4) 203 | mat[:3, :3] = Rotation.from_quat( 204 | [continuous_quat[0], continuous_quat[1], continuous_quat[2], continuous_quat[3]] 205 | ).as_matrix() 206 | offset_mat = np.matmul(mat, offset) 207 | mat[:3, 3] = translation - offset_mat[:3, 3] 208 | return mat 209 | -------------------------------------------------------------------------------- /assets/hair_dryer_behind_table.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/assets/hair_dryer_behind_table.gif -------------------------------------------------------------------------------- /assets/hair_dryer_completion.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/assets/hair_dryer_completion.gif -------------------------------------------------------------------------------- /assets/hair_dryer_completion_blender.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/assets/hair_dryer_completion_blender.gif -------------------------------------------------------------------------------- /assets/hair_dryer_completion_blender_legend.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/assets/hair_dryer_completion_blender_legend.png -------------------------------------------------------------------------------- /assets/hair_dryer_scene.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/assets/hair_dryer_scene.png -------------------------------------------------------------------------------- /assets/matterport.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/assets/matterport.png -------------------------------------------------------------------------------- /assets/teaser.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/assets/teaser.gif -------------------------------------------------------------------------------- /assets/vn_poster_relevancies.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/assets/vn_poster_relevancies.png -------------------------------------------------------------------------------- /datagen/README.md: -------------------------------------------------------------------------------- 1 | # THOR Data Generation Instructions 2 | 3 | This data generation code has been tested on Ubuntu 18.04 with a NVIDIA GTX 1080 and Unity version 2019.4.38f1. 4 | 5 | ## Compiling a Custom Version of THOR 6 | 7 | In [Unity](https://unity.com/) version 2019.4.38f1, open the Unity project from our [custom THOR repo](https://github.com/huy-ha/ai2thor). 8 | Our custom THOR extends it with a [`SceneVolumeExporter`](https://github.com/huy-ha/ai2thor/blob/main/unity/Assets/SceneVolumeExporter.cs) for exporting 3D scenes from THOR. 9 | You can modify the output directory by changing [`rootPath`](https://github.com/huy-ha/ai2thor/blob/main/unity/Assets/SceneVolumeExporter.cs#L8) to somewhere convenient. 10 | This will be the path used for `--path_to_exported_scenes` in `generate_thor_data.py`. 11 | You can also export at a higher [resolution](https://github.com/huy-ha/ai2thor/blob/main/unity/Assets/SceneVolumeExporter.cs#L355), which will take longer to export. 12 | 13 | When building the project, make sure to select all scenes. 14 | 15 | ![](assets/unity-build.png) 16 | 17 | This will be used for `--path_to_custom_unity` in `generate_thor_data.py`. 18 | 19 | ## Generating data 20 | 21 | ### Export/download all 3D scenes 22 | 23 | `SceneVolumeExporter` works by [naively grid sampling THOR scenes with a collision probe](https://github.com/huy-ha/ai2thor/blob/main/unity/Assets/SceneVolumeExporter.cs#L378-L420), then dumps the collision results to `.txt` files. 24 | It also [queries all receptacles](https://github.com/huy-ha/ai2thor/blob/main/unity/Assets/SceneVolumeExporter.cs#L131-L172) and dumps this information to a `.txt` file. 25 | The result is the un-transformed 3D semantic occupancy of all 3D scenes at the `rootPath`. 26 | While this is costly, it only needs to be done once. 27 | Alternatively, download our pre-processed exported 3D scenes. 
28 | 29 | ```sh 30 | wget https://semantic-abstraction.cs.columbia.edu/downloads/groundtruth3dpoints.tar.lz4 -O - | tar --use-compress-program=lz4 -xf - -C ./ 31 | ``` 32 | 33 | 34 | When `generate_thor_data.py` is run, it first checks that all scenes' ground truth 3D semantic occupancies have been exported to `path_to_exported_scenes`. 35 | 36 | Note: if you get "Reading from AI2-THOR backend timed out" errors, you can remove the timeout limit in [AI2THOR's python server](https://github.com/allenai/ai2thor/blob/main/ai2thor/fifo_server.py#L130-L131). 37 | 38 | 39 | ### Generating SemAbs data points 40 | 41 | After installing `ai2thor` 42 | 43 | ```sh 44 | pip install ai2thor==4.2.0 45 | ``` 46 | you're ready to run the data generation script. It works as follows: 47 | 1. choose a random THOR room. 48 | 2. choose a random viewpoint in that room. 49 | 3. transform and filter the ground truth 3D semantic occupancy at `path_to_exported_scenes` to match the viewpoint. 50 | 4. preprocess the values returned by THOR (segmentation maps, object class names, TSDF generation, etc.). 51 | 5. detect all spatial relations. 52 | 53 | An example command might look like this: 54 | ```sh 55 | python generate_thor_data.py --dataset_dir_path /path/to/dataset --num_pts 50000 --path_to_custom_unity /path/to/unity/exec --path_to_exported_scenes /path/to/3dscenes 56 | ``` 57 | If you have issues launching THOR, try disabling multi-processing by using the `--local` flag. 58 | 59 | ### Generating relevancies 60 | 61 | The last step is to generate the relevancies for the dumped data points: 62 | ``` 63 | python generate_relevancy.py dataset /path/to/dataset num_processes 64 | ``` -------------------------------------------------------------------------------- /datagen/assets/unity-build.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/datagen/assets/unity-build.png -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | from tqdm import tqdm 4 | import torch 5 | import os 6 | import pickle 7 | from dataset import ObjectLocalizationDataset, SceneCompletionDataset 8 | from train_vool import get_losses as vool_get_losses, approach as vool_approaches 9 | from train_ovssc import get_losses as ovssc_get_losses, approach as ovssc_approaches 10 | import utils 11 | from torch.utils.data import DataLoader 12 | import pandas as pd 13 | import torch.distributed as dist 14 | from torch.utils.data.distributed import DistributedSampler 15 | 16 | if __name__ == "__main__": 17 | parser = utils.config_parser() 18 | parser.add_argument("--task", choices=["ovssc", "vool"], required=True) 19 | args = parser.parse_args() 20 | with open(os.path.dirname(args.load) + "/args.pkl", "rb") as file: 21 | exp_args = pickle.load(file) 22 | for arg in vars(exp_args): 23 | if any(arg == s for s in ["device", "file_path", "load", "gpus", "task"]): 24 | continue 25 | setattr(args, arg, getattr(exp_args, arg)) 26 | args.domain_randomization = False 27 | args.scene_bounds = torch.tensor(args.scene_bounds) 28 | args.batch_size = 1 29 | args.num_workers = 8 30 | args.balance_spatial_sampling = False 31 | args.detailed_analysis = True 32 | ddp = len(args.gpus) > 1 33 | approaches = ovssc_approaches if args.task == "ovssc" else vool_approaches 34 |
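# pick the dataset class that matches the evaluation task: scene completion for OVSSC, object localization for VOOL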
dataset_class = ( 35 | SceneCompletionDataset if args.task == "ovssc" else ObjectLocalizationDataset 36 | ) 37 | exp_dict = utils.setup_experiment( 38 | args=args, 39 | net_class=approaches[args.approach], 40 | dataset_class=dataset_class, 41 | split_file_path=args.file_path 42 | + ("/vool_split.pkl" if args.task == "vool" else "/ssc_split.pkl"), 43 | return_vis=True, 44 | ddp=ddp, 45 | ) 46 | net = exp_dict["net"] 47 | net.eval() 48 | net.requires_grad = False 49 | epoch = exp_dict["start_epoch"] 50 | eval_detailed_stats = pd.DataFrame() 51 | with torch.no_grad(): 52 | for split, dataset in exp_dict["datasets"].items(): 53 | if split == "train": 54 | continue 55 | sampler = None 56 | if ddp: 57 | sampler = DistributedSampler( 58 | dataset=dataset, shuffle=False, drop_last=False 59 | ) 60 | sampler.set_epoch(0) 61 | loader = DataLoader( 62 | dataset=dataset, 63 | num_workers=args.num_workers, 64 | batch_size=1, 65 | sampler=sampler, 66 | ) 67 | detailed_stats = utils.loop( 68 | net=net, 69 | loader=loader, 70 | get_losses_fn=ovssc_get_losses 71 | if args.task == "ovssc" 72 | else vool_get_losses, 73 | **{ 74 | **vars(args), 75 | "optimizer": None, 76 | "lr_scheduler": None, 77 | "cutoffs": np.arange(-2.5, -0.0, 0.1), 78 | "pbar": tqdm( 79 | total=len(loader), 80 | dynamic_ncols=True, 81 | unit="batch", 82 | postfix=f"| {split.upper()} ", 83 | ), 84 | "detailed_analysis": True, 85 | }, 86 | ) 87 | detailed_stats["epoch"] = [epoch] * len(detailed_stats) 88 | detailed_stats["split"] = [split] * len(detailed_stats) 89 | eval_detailed_stats = pd.concat([eval_detailed_stats, detailed_stats]) 90 | if (ddp and dist.get_rank() == 0) or not ddp: 91 | stats_path = os.path.splitext(args.load)[0] + f"_eval_stats.pkl" 92 | eval_detailed_stats.to_pickle(stats_path) 93 | print("dumped stats to ", stats_path) 94 | -------------------------------------------------------------------------------- /fusion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2018 Andy Zeng 2 | # Source: https://github.com/andyzeng/tsdf-fusion-python/blob/master/fusion.py 3 | # BSD 2-Clause License 4 | 5 | # Copyright (c) 2019, Princeton University 6 | # All rights reserved. 7 | 8 | # Redistribution and use in source and binary forms, with or without 9 | # modification, are permitted provided that the following conditions are met: 10 | 11 | # 1. Redistributions of source code must retain the above copyright notice, this 12 | # list of conditions and the following disclaimer. 13 | 14 | # 2. Redistributions in binary form must reproduce the above copyright notice, 15 | # this list of conditions and the following disclaimer in the documentation 16 | # and/or other materials provided with the distribution. 17 | 18 | # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 19 | # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 20 | # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 21 | # DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 22 | # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 23 | # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 24 | # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 25 | # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 26 | # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 27 | # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 28 | 29 | import numpy as np 30 | from numba import njit, prange 31 | from skimage import measure 32 | 33 | 34 | class TSDFVolume: 35 | """Volumetric TSDF Fusion of RGB-D Images.""" 36 | 37 | def __init__(self, vol_bnds, voxel_size): 38 | """Constructor. 39 | Args: 40 | vol_bnds (ndarray): An ndarray of shape (3, 2). Specifies the 41 | xyz bounds (min/max) in meters. 42 | voxel_size (float): The volume discretization in meters. 43 | """ 44 | vol_bnds = np.asarray(vol_bnds) 45 | assert vol_bnds.shape == (3, 2), "[!] `vol_bnds` should be of shape (3, 2)." 46 | assert (vol_bnds[:, 0] < vol_bnds[:, 1]).all() 47 | 48 | # Define voxel volume parameters 49 | self._vol_bnds = vol_bnds 50 | self._voxel_size = float(voxel_size) 51 | self._trunc_margin = 5 * self._voxel_size # truncation on SDF 52 | self._color_const = 256 * 256 53 | 54 | # Adjust volume bounds and ensure C-order contiguous 55 | self._vol_dim = ( 56 | np.ceil((self._vol_bnds[:, 1] - self._vol_bnds[:, 0]) / self._voxel_size) 57 | .copy(order="C") 58 | .astype(int) 59 | ) 60 | self._vol_bnds[:, 1] = self._vol_bnds[:, 0] + self._vol_dim * self._voxel_size 61 | self._vol_origin = self._vol_bnds[:, 0].copy(order="C").astype(np.float32) 62 | 63 | # Initialize pointers to voxel volume in CPU memory 64 | # Assume all unobserved regions are occupied 65 | self._tsdf_vol_cpu = -np.ones(self._vol_dim).astype(np.float32) 66 | # for computing the cumulative moving average of observations per voxel 67 | self._weight_vol_cpu = np.zeros(self._vol_dim).astype(np.float32) 68 | self._color_vol_cpu = np.zeros(self._vol_dim).astype(np.float32) 69 | 70 | # Get voxel grid coordinates 71 | xv, yv, zv = np.meshgrid( 72 | range(self._vol_dim[0]), 73 | range(self._vol_dim[1]), 74 | range(self._vol_dim[2]), 75 | indexing="ij", 76 | ) 77 | self.vox_coords = ( 78 | np.concatenate( 79 | [xv.reshape(1, -1), yv.reshape(1, -1), zv.reshape(1, -1)], axis=0 80 | ) 81 | .astype(int) 82 | .T 83 | ) 84 | 85 | @staticmethod 86 | @njit(parallel=True) 87 | def vox2world(vol_origin, vox_coords, vox_size): 88 | """Convert voxel grid coordinates to world coordinates.""" 89 | vol_origin = vol_origin.astype(np.float32) 90 | vox_coords = vox_coords.astype(np.float32) 91 | cam_pts = np.empty_like(vox_coords, dtype=np.float32) 92 | for i in prange(vox_coords.shape[0]): 93 | for j in range(3): 94 | cam_pts[i, j] = vol_origin[j] + (vox_size * vox_coords[i, j]) 95 | return cam_pts 96 | 97 | @staticmethod 98 | @njit(parallel=True) 99 | def cam2pix(cam_pts, intr): 100 | """Convert camera coordinates to pixel coordinates.""" 101 | intr = intr.astype(np.float32) 102 | fx, fy = intr[0, 0], intr[1, 1] 103 | cx, cy = intr[0, 2], intr[1, 2] 104 | pix = np.empty((cam_pts.shape[0], 2), dtype=np.int64) 105 | for i in prange(cam_pts.shape[0]): 106 | pix[i, 0] = int(np.round((cam_pts[i, 0] * fx / cam_pts[i, 2]) + cx)) 107 | pix[i, 1] = int(np.round((cam_pts[i, 1] * fy / cam_pts[i, 2]) + cy)) 108 | return pix 109 | 110 | @staticmethod 111 | 
@njit(parallel=True) 112 | def integrate_tsdf(tsdf_vol, dist, w_old, obs_weight): 113 | """Integrate the TSDF volume.""" 114 | tsdf_vol_int = np.empty_like(tsdf_vol, dtype=np.float32) 115 | w_new = np.empty_like(w_old, dtype=np.float32) 116 | for i in prange(len(tsdf_vol)): 117 | w_new[i] = w_old[i] + obs_weight 118 | tsdf_vol_int[i] = (w_old[i] * tsdf_vol[i] + obs_weight * dist[i]) / w_new[i] 119 | return tsdf_vol_int, w_new 120 | 121 | def integrate(self, color_im, depth_im, cam_intr, cam_pose, obs_weight=1.0): 122 | """Integrate an RGB-D frame into the TSDF volume. 123 | Args: 124 | color_im (ndarray): An RGB image of shape (H, W, 3). 125 | depth_im (ndarray): A depth image of shape (H, W). 126 | cam_intr (ndarray): The camera intrinsics matrix of shape (3, 3). 127 | cam_pose (ndarray): The camera pose (i.e. extrinsics) of shape (4, 4). 128 | obs_weight (float): The weight to assign for the current observation. A higher 129 | value 130 | """ 131 | im_h, im_w = depth_im.shape 132 | 133 | # Fold RGB color image into a single channel image 134 | color_im = color_im.astype(np.float32) 135 | color_im = np.floor( 136 | color_im[..., 2] * self._color_const 137 | + color_im[..., 1] * 256 138 | + color_im[..., 0] 139 | ) 140 | 141 | # Convert voxel grid coordinates to pixel coordinates 142 | cam_pts = self.vox2world(self._vol_origin, self.vox_coords, self._voxel_size) 143 | cam_pts = rigid_transform(cam_pts, np.linalg.inv(cam_pose)) 144 | pix_z = cam_pts[:, 2] 145 | pix = self.cam2pix(cam_pts, cam_intr) 146 | pix_x, pix_y = pix[:, 0], pix[:, 1] 147 | 148 | # Eliminate pixels outside view frustum 149 | valid_pix = np.logical_and( 150 | pix_x >= 0, 151 | np.logical_and( 152 | pix_x < im_w, 153 | np.logical_and(pix_y >= 0, np.logical_and(pix_y < im_h, pix_z > 0)), 154 | ), 155 | ) 156 | depth_val = np.zeros(pix_x.shape) 157 | depth_val[valid_pix] = depth_im[pix_y[valid_pix], pix_x[valid_pix]] 158 | 159 | # Integrate TSDF 160 | depth_diff = depth_val - pix_z 161 | valid_pts = np.logical_and(depth_val > 0, depth_diff >= -self._trunc_margin) 162 | dist = np.maximum(-1, np.minimum(1, depth_diff / self._trunc_margin)) 163 | valid_vox_x = self.vox_coords[valid_pts, 0] 164 | valid_vox_y = self.vox_coords[valid_pts, 1] 165 | valid_vox_z = self.vox_coords[valid_pts, 2] 166 | w_old = self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] 167 | tsdf_vals = self._tsdf_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] 168 | valid_dist = dist[valid_pts] 169 | tsdf_vol_new, w_new = self.integrate_tsdf( 170 | tsdf_vals, valid_dist, w_old, obs_weight 171 | ) 172 | self._weight_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = w_new 173 | self._tsdf_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = tsdf_vol_new 174 | 175 | # Integrate color 176 | old_color = self._color_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] 177 | old_b = np.floor(old_color / self._color_const) 178 | old_g = np.floor((old_color - old_b * self._color_const) / 256) 179 | old_r = old_color - old_b * self._color_const - old_g * 256 180 | new_color = color_im[pix_y[valid_pts], pix_x[valid_pts]] 181 | new_b = np.floor(new_color / self._color_const) 182 | new_g = np.floor((new_color - new_b * self._color_const) / 256) 183 | new_r = new_color - new_b * self._color_const - new_g * 256 184 | new_b = np.minimum( 185 | 255.0, np.round((w_old * old_b + obs_weight * new_b) / w_new) 186 | ) 187 | new_g = np.minimum( 188 | 255.0, np.round((w_old * old_g + obs_weight * new_g) / w_new) 189 | ) 190 | new_r = np.minimum( 191 | 255.0, np.round((w_old * old_r + 
obs_weight * new_r) / w_new) 192 | ) 193 | self._color_vol_cpu[valid_vox_x, valid_vox_y, valid_vox_z] = ( 194 | new_b * self._color_const + new_g * 256 + new_r 195 | ) 196 | 197 | def get_volume(self): 198 | # Fold RGB color image into a single channel image 199 | color_vol = np.zeros([3] + list(self._color_vol_cpu.shape)).astype(np.uint8) 200 | color_vol[2, ...] = np.floor(self._color_vol_cpu / self._color_const) 201 | color_vol[1, ...] = np.floor( 202 | (self._color_vol_cpu - color_vol[2, ...] * self._color_const) / 256 203 | ) 204 | color_vol[0, ...] = ( 205 | self._color_vol_cpu 206 | - color_vol[2, ...] * self._color_const 207 | - color_vol[1, ...] * 256 208 | ) 209 | return self._tsdf_vol_cpu, color_vol 210 | 211 | def get_point_cloud(self): 212 | """Extract a point cloud from the voxel volume.""" 213 | tsdf_vol, color_vol = self.get_volume() 214 | 215 | # Marching cubes 216 | verts = measure.marching_cubes_lewiner(tsdf_vol, level=0)[0] 217 | verts_ind = np.round(verts).astype(int) 218 | verts = verts * self._voxel_size + self._vol_origin 219 | 220 | # Get vertex colors 221 | rgb_vals = color_vol[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]] 222 | colors_b = np.floor(rgb_vals / self._color_const) 223 | colors_g = np.floor((rgb_vals - colors_b * self._color_const) / 256) 224 | colors_r = rgb_vals - colors_b * self._color_const - colors_g * 256 225 | colors = np.floor(np.asarray([colors_r, colors_g, colors_b])).T 226 | colors = colors.astype(np.uint8) 227 | 228 | pc = np.hstack([verts, colors]) 229 | return pc 230 | 231 | def get_mesh(self): 232 | """Compute a mesh from the voxel volume using marching cubes.""" 233 | tsdf_vol, color_vol = self.get_volume() 234 | 235 | # Marching cubes 236 | verts, faces, norms, vals = measure.marching_cubes_lewiner(tsdf_vol, level=0) 237 | verts_ind = np.round(verts).astype(int) 238 | # voxel grid coordinates to world coordinates 239 | verts = verts * self._voxel_size + self._vol_origin 240 | 241 | # Get vertex colors 242 | rgb_vals = color_vol[verts_ind[:, 0], verts_ind[:, 1], verts_ind[:, 2]] 243 | colors_b = np.floor(rgb_vals / self._color_const) 244 | colors_g = np.floor((rgb_vals - colors_b * self._color_const) / 256) 245 | colors_r = rgb_vals - colors_b * self._color_const - colors_g * 256 246 | colors = np.floor(np.asarray([colors_r, colors_g, colors_b])).T 247 | colors = colors.astype(np.uint8) 248 | return verts, faces, norms, colors 249 | 250 | 251 | def rigid_transform(xyz, transform): 252 | """Applies a rigid transform to an (N, 3) pointcloud.""" 253 | xyz_h = np.hstack([xyz, np.ones((len(xyz), 1), dtype=np.float32)]) 254 | xyz_t_h = np.dot(transform, xyz_h.T).T 255 | return xyz_t_h[:, :3] 256 | 257 | 258 | def get_view_frustum(depth_im, cam_intr, cam_pose): 259 | """Get corners of 3D camera view frustum of depth image""" 260 | im_h = depth_im.shape[0] 261 | im_w = depth_im.shape[1] 262 | max_depth = np.max(depth_im) 263 | view_frust_pts = np.array( 264 | [ 265 | (np.array([0, 0, 0, im_w, im_w]) - cam_intr[0, 2]) 266 | * np.array([0, max_depth, max_depth, max_depth, max_depth]) 267 | / cam_intr[0, 0], 268 | (np.array([0, 0, im_h, 0, im_h]) - cam_intr[1, 2]) 269 | * np.array([0, max_depth, max_depth, max_depth, max_depth]) 270 | / cam_intr[1, 1], 271 | np.array([0, max_depth, max_depth, max_depth, max_depth]), 272 | ] 273 | ) 274 | view_frust_pts = rigid_transform(view_frust_pts.T, cam_pose).T 275 | return view_frust_pts 276 | 277 | 278 | def meshwrite(filename, verts, faces, norms, colors): 279 | """Save a 3D mesh to a polygon .ply 
file.""" 280 | # Write header 281 | ply_file = open(filename, "w") 282 | ply_file.write("ply\n") 283 | ply_file.write("format ascii 1.0\n") 284 | ply_file.write("element vertex %d\n" % (verts.shape[0])) 285 | ply_file.write("property float x\n") 286 | ply_file.write("property float y\n") 287 | ply_file.write("property float z\n") 288 | ply_file.write("property float nx\n") 289 | ply_file.write("property float ny\n") 290 | ply_file.write("property float nz\n") 291 | ply_file.write("property uchar red\n") 292 | ply_file.write("property uchar green\n") 293 | ply_file.write("property uchar blue\n") 294 | ply_file.write("element face %d\n" % (faces.shape[0])) 295 | ply_file.write("property list uchar int vertex_index\n") 296 | ply_file.write("end_header\n") 297 | 298 | # Write vertex list 299 | for i in range(verts.shape[0]): 300 | ply_file.write( 301 | "%f %f %f %f %f %f %d %d %d\n" 302 | % ( 303 | verts[i, 0], 304 | verts[i, 1], 305 | verts[i, 2], 306 | norms[i, 0], 307 | norms[i, 1], 308 | norms[i, 2], 309 | colors[i, 0], 310 | colors[i, 1], 311 | colors[i, 2], 312 | ) 313 | ) 314 | 315 | # Write face list 316 | for i in range(faces.shape[0]): 317 | ply_file.write("3 %d %d %d\n" % (faces[i, 0], faces[i, 1], faces[i, 2])) 318 | 319 | ply_file.close() 320 | 321 | 322 | def pcwrite(filename, xyzrgb): 323 | """Save a point cloud to a polygon .ply file.""" 324 | xyz = xyzrgb[:, :3] 325 | rgb = xyzrgb[:, 3:].astype(np.uint8) 326 | 327 | # Write header 328 | ply_file = open(filename, "w") 329 | ply_file.write("ply\n") 330 | ply_file.write("format ascii 1.0\n") 331 | ply_file.write("element vertex %d\n" % (xyz.shape[0])) 332 | ply_file.write("property float x\n") 333 | ply_file.write("property float y\n") 334 | ply_file.write("property float z\n") 335 | ply_file.write("property uchar red\n") 336 | ply_file.write("property uchar green\n") 337 | ply_file.write("property uchar blue\n") 338 | ply_file.write("end_header\n") 339 | 340 | # Write vertex list 341 | for i in range(xyz.shape[0]): 342 | ply_file.write( 343 | "%f %f %f %d %d %d\n" 344 | % ( 345 | xyz[i, 0], 346 | xyz[i, 1], 347 | xyz[i, 2], 348 | rgb[i, 0], 349 | rgb[i, 1], 350 | rgb[i, 2], 351 | ) 352 | ) 353 | -------------------------------------------------------------------------------- /generate_relevancy.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from pathlib import Path 3 | import h5py 4 | import torch 5 | from tqdm import tqdm 6 | import ray 7 | from utils import write_to_hdf5 8 | from filelock import FileLock 9 | import numpy as np 10 | from CLIP.clip import ClipWrapper, saliency_configs, imagenet_templates 11 | from dataset import synonyms, deref_h5py 12 | import typer 13 | import imageio 14 | from matplotlib import pyplot as plt 15 | import cv2 16 | from time import time 17 | 18 | app = typer.Typer() 19 | 20 | 21 | def resize_and_add_data(dataset, data): 22 | data_shape = np.array(data.shape) 23 | dataset_shape = np.array(dataset.shape) 24 | assert (dataset_shape[1:] == data_shape[1:]).all() 25 | dataset.resize(dataset_shape[0] + data_shape[0], axis=0) 26 | dataset[-data_shape[0] :, ...] = data 27 | return [ 28 | dataset.regionref[dataset_shape[0] + i, ...] 
29 | for i in np.arange(0, data_shape[0]) 30 | ] 31 | 32 | 33 | def get_datastructure(image_shape, relevancy_shape, tsdf_dim, num_output_pts, **kwargs): 34 | image_shape = list(image_shape) 35 | relevancy_shape = list(relevancy_shape) 36 | return { 37 | "rgb": {"dtype": "uint8", "item_shape": image_shape + [3]}, 38 | "depth": {"dtype": "f", "item_shape": image_shape}, 39 | "seg": {"dtype": "i", "item_shape": image_shape}, 40 | "saliencies": {"dtype": "f", "item_shape": relevancy_shape}, 41 | "tsdf_value_pts": {"dtype": "f", "item_shape": [np.prod(tsdf_dim)]}, 42 | "tsdf_xyz_pts": {"dtype": "f", "item_shape": [np.prod(tsdf_dim), 3]}, 43 | "full_xyz_pts": {"dtype": "f", "item_shape": [num_output_pts, 3]}, 44 | "full_objid_pts": {"dtype": "i", "item_shape": [num_output_pts]}, 45 | } 46 | 47 | 48 | def init_dataset(file_path, data_structure): 49 | with h5py.File(file_path, mode="w") as file: 50 | # setup 51 | for key, data_info in data_structure.items(): 52 | file.create_dataset( 53 | name=key, 54 | shape=tuple([0] + data_info["item_shape"]), 55 | dtype=data_info["dtype"], 56 | chunks=tuple([1] + data_info["item_shape"]), 57 | compression="gzip", 58 | compression_opts=9, 59 | maxshape=tuple([None] + data_info["item_shape"]), 60 | ) 61 | 62 | 63 | @ray.remote 64 | def generate_saliency_helper( 65 | clip_wrapper, rgb_inputs, prompts, text_labels, scene_path, replace 66 | ): 67 | saliencies = { 68 | rgb_name: { 69 | saliency_config_name: ray.get( 70 | clip_wrapper.get_clip_saliency.remote( 71 | img=rgb, 72 | text_labels=text_labels, 73 | prompts=prompts 74 | if "imagenet_prompt_ensemble" 75 | not in saliency_config(img_dim=min(rgb.shape[:2])) 76 | or not saliency_config(img_dim=min(rgb.shape[:2]))[ 77 | "imagenet_prompt_ensemble" 78 | ] 79 | else imagenet_templates, 80 | **saliency_config(img_dim=min(rgb.shape[:2])), 81 | ) 82 | ) 83 | for saliency_config_name, saliency_config in saliency_configs.items() 84 | } 85 | for rgb_name, rgb in rgb_inputs.items() 86 | } 87 | with FileLock(scene_path + ".lock"): 88 | with h5py.File(scene_path, mode="a") as f: 89 | saliency_group = f["data"].create_group("saliencies") 90 | for rgb_name, rgb_saliencies in saliencies.items(): 91 | for ( 92 | saliency_config_name, 93 | (config_saliency, text_label_features), 94 | ) in rgb_saliencies.items(): 95 | storage_dims = np.array(f["saliencies"].shape)[1:] 96 | config_saliency = torch.nn.functional.interpolate( 97 | config_saliency[:, None, :, :], 98 | size=tuple(storage_dims), 99 | mode="nearest-exact" 100 | # mode='bilinear', 101 | # align_corners=False 102 | )[:, 0] 103 | 104 | config_saliency = torch.cat( 105 | [config_saliency, config_saliency.mean(dim=0, keepdim=True)], 106 | dim=0, 107 | ) 108 | text_label_features = torch.cat( 109 | [ 110 | text_label_features, 111 | text_label_features.mean(dim=0, keepdim=True), 112 | ], 113 | dim=0, 114 | ) 115 | text_label_features /= text_label_features.norm( 116 | dim=-1, keepdim=True 117 | ) 118 | write_to_hdf5( 119 | saliency_group, 120 | key=rgb_name 121 | + "|" 122 | + saliency_config_name 123 | + "|saliency_text_labels", 124 | value=np.array(text_labels + ["mean"]).astype("S"), 125 | replace=replace, 126 | ) 127 | write_to_hdf5( 128 | saliency_group, 129 | key=rgb_name 130 | + "|" 131 | + saliency_config_name 132 | + "|saliency_text_label_features", 133 | value=text_label_features, 134 | replace=replace, 135 | ) 136 | region_references = resize_and_add_data( 137 | dataset=f["saliencies"], data=config_saliency 138 | ) 139 | write_to_hdf5( 140 | saliency_group, 141 | 
key=rgb_name + "|" + saliency_config_name, 142 | dtype=h5py.regionref_dtype, 143 | value=region_references, 144 | replace=replace, 145 | ) 146 | return clip_wrapper 147 | 148 | 149 | @app.command() 150 | def dataset( 151 | file_path: str, 152 | num_processes: int, 153 | local: bool, 154 | prompts: List[str] = ["a render of a {} in a game engine."], 155 | replace=False, 156 | ): 157 | if "matterport" in file_path or "nyu" in file_path: 158 | prompts = ["a photograph of a {} in a home."] 159 | print(prompts) 160 | tasks = [] 161 | ray.init(log_to_driver=True, local_mode=local) 162 | num_cuda_devices = torch.cuda.device_count() 163 | assert num_cuda_devices > 0 164 | print(f"[INFO] FOUND {num_cuda_devices} CUDA DEVICE") 165 | wrapper_actor_cls = ray.remote(ClipWrapper) 166 | available_clip_wrappers = [ 167 | wrapper_actor_cls.options(num_gpus=num_cuda_devices / num_processes).remote( 168 | clip_model_type="ViT-B/32", device="cuda" 169 | ) 170 | for _ in range(num_processes) 171 | ] 172 | 173 | scene_paths = list(reversed(sorted(map(str, Path(file_path).rglob("*.hdf5"))))) 174 | if replace: 175 | if input("Replace = True. Delete existing relevancies? [y/n]") != "y": 176 | exit() 177 | for scene_path in tqdm( 178 | scene_paths, dynamic_ncols=True, desc="deleting existing relevancies" 179 | ): 180 | try: 181 | with h5py.File(scene_path, mode="a") as f: 182 | for k in f["data"]: 183 | if "salienc" in k: 184 | del f[f"data/{k}"] 185 | if "saliencies" in f: 186 | data_shape = list(f["saliencies"].shape[1:]) 187 | del f["saliencies"] 188 | f.create_dataset( 189 | name="saliencies", 190 | shape=tuple([0] + data_shape), 191 | dtype="f", 192 | chunks=tuple([1] + data_shape), 193 | compression="gzip", 194 | compression_opts=9, 195 | maxshape=tuple([None] + data_shape), 196 | ) 197 | except Exception as e: 198 | print(e, scene_path) 199 | exit() 200 | for scene_path in tqdm( 201 | scene_paths, dynamic_ncols=True, desc="generating relevancies", smoothing=0.001 202 | ): 203 | assert len(available_clip_wrappers) > 0 204 | try: 205 | with h5py.File(scene_path, mode="a") as f: 206 | scene_already_done = "saliencies" in f["data"] 207 | if not scene_already_done or replace: 208 | if scene_already_done: 209 | for k in f["data"]: 210 | if "salienc" in k: 211 | del f[f"data/{k}"] 212 | data_shape = f["saliencies"].shape[1:] 213 | if "saliencies" in f: 214 | del f["saliencies"] 215 | f.create_dataset( 216 | name="saliencies", 217 | shape=tuple([0] + data_shape), 218 | dtype="f", 219 | chunks=tuple([1] + data_shape), 220 | compression="gzip", 221 | compression_opts=9, 222 | maxshape=tuple([None] + data_shape), 223 | ) 224 | 225 | if "data/visible_scene_obj_labels" in f: 226 | del f["data/visible_scene_obj_labels"] 227 | objid_to_class = np.array(f[f"data/objid_to_class"]).astype(str) 228 | text_labels = objid_to_class.copy() 229 | scene_has_groundtruth = ( 230 | "seg" in f["data"] and "full_objid_pts" in f["data"] 231 | ) 232 | visible_scene_obj_labels = text_labels.copy() 233 | if scene_has_groundtruth: 234 | objids_in_scene = list( 235 | set( 236 | deref_h5py( 237 | dataset=f["full_objid_pts"], 238 | refs=f["data/full_objid_pts"], 239 | ) 240 | .astype(int) 241 | .reshape(-1) 242 | ) 243 | - {-1} 244 | ) # remove empty 245 | scene_object_labels = text_labels.copy()[objids_in_scene] 246 | 247 | # remove objects which are not in view 248 | gt_seg = deref_h5py(dataset=f["seg"], refs=f["data"]["seg"])[0] 249 | visible_obj_ids = list(map(int, set(np.unique(gt_seg)) - {-1})) 250 | visible_obj_labels = 
text_labels[visible_obj_ids] 251 | visible_scene_obj_labels = list( 252 | set(visible_obj_labels).intersection( 253 | set(scene_object_labels) 254 | ) 255 | ) 256 | visible_scene_obj_labels = list( 257 | sorted( 258 | set( 259 | map( 260 | lambda c: c.split("[")[0].lstrip().rstrip(), 261 | visible_scene_obj_labels, 262 | ) 263 | ) 264 | ) 265 | ) 266 | # visible_scene_obj_labels used to filter 267 | # objects both visible and in scene 268 | text_labels = visible_obj_labels.copy() 269 | text_labels = set(text_labels) 270 | 271 | # create saliency maps necessary for descriptions 272 | if ( 273 | "descriptions" in f["data"] 274 | and len(np.array(f["data/descriptions/spatial_relation_name"])) 275 | > 0 276 | ): 277 | target_obj_names = np.array( 278 | f["data/descriptions/target_obj_name"] 279 | ).astype(str) 280 | reference_obj_names = np.array( 281 | f["data/descriptions/reference_obj_name"] 282 | ).astype(str) 283 | spatial_relation_names = np.array( 284 | f["data/descriptions/spatial_relation_name"] 285 | ).astype(str) 286 | text_labels = text_labels.union( 287 | target_obj_names.tolist() + reference_obj_names.tolist() 288 | ) 289 | 290 | # gradcam for clip spatial 291 | descriptions = "" 292 | for desc_part in [ 293 | target_obj_names, 294 | " ", 295 | spatial_relation_names, 296 | " a ", 297 | reference_obj_names, 298 | ]: 299 | descriptions = np.char.add(descriptions, desc_part) 300 | text_labels = text_labels.union(descriptions) 301 | # descriptions with synonyms 302 | descriptions = "" 303 | for desc_part in [ 304 | np.array( 305 | list( 306 | map( 307 | lambda x: x 308 | if x not in synonyms.keys() 309 | else synonyms[x], 310 | target_obj_names, 311 | ) 312 | ) 313 | ), 314 | " ", 315 | spatial_relation_names, 316 | " a ", 317 | np.array( 318 | list( 319 | map( 320 | lambda x: x 321 | if x not in synonyms.keys() 322 | else synonyms[x], 323 | reference_obj_names, 324 | ) 325 | ) 326 | ), 327 | ]: 328 | descriptions = np.char.add(descriptions, desc_part) 329 | text_labels = text_labels.union(descriptions) 330 | text_labels = set( 331 | map(lambda c: c.split("[")[0].lstrip().rstrip(), text_labels) 332 | ) 333 | 334 | # do synonyms 335 | text_labels = text_labels.union( 336 | map( 337 | lambda text_label: synonyms[text_label], 338 | filter( 339 | lambda text_label: text_label in synonyms, text_labels 340 | ), 341 | ) 342 | ) 343 | for remove_label in {"unlabelled", "empty", "out of bounds"}: 344 | if remove_label in text_labels: 345 | text_labels.remove(remove_label) 346 | text_labels = list(sorted(text_labels)) 347 | 348 | rgb_inputs = {"rgb": np.array(f["rgb"][f["data"]["rgb"][0]][0])} 349 | if ( 350 | "domain_randomized_rgb" in f["data"] 351 | and len(np.array(f["data/domain_randomized_rgb"])[0].shape) > 1 352 | ): 353 | rgb_inputs["domain_randomized_rgb"] = np.array( 354 | f["data/domain_randomized_rgb"] 355 | )[0] 356 | write_to_hdf5( 357 | f["data"], 358 | key="visible_scene_obj_labels", 359 | value=np.array(visible_scene_obj_labels).astype("S"), 360 | replace=replace, 361 | ) 362 | clip_wrapper = available_clip_wrappers.pop() 363 | tasks.append( 364 | generate_saliency_helper.remote( 365 | clip_wrapper=clip_wrapper, 366 | scene_path=scene_path, 367 | rgb_inputs=rgb_inputs, 368 | text_labels=text_labels, 369 | prompts=prompts, 370 | replace=replace, 371 | ) 372 | ) 373 | except Exception as e: 374 | print(e) 375 | print(scene_path, "invalid hdf5 file") 376 | if len(available_clip_wrappers) == 0: 377 | readies, tasks = ray.wait(tasks, num_returns=1) 378 | num_readies = len(readies) 
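# a finished task returns its ClipWrapper actor so it can be reused; if fetching it fails, the except branch below spawns the same number of fresh actors to keep the pool at one wrapper per process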
379 | try: 380 | available_clip_wrappers.extend(ray.get(readies)) 381 | except Exception as e: 382 | print(e) 383 | available_clip_wrappers.extend( 384 | [ 385 | wrapper_actor_cls.options( 386 | num_gpus=num_cuda_devices / num_processes 387 | ).remote(clip_model_type="ViT-B/32", device="cuda") 388 | for _ in range(num_readies) 389 | ] 390 | ) 391 | ray.get(tasks) 392 | 393 | 394 | @app.command() 395 | def image( 396 | file_path: str = typer.Argument( 397 | default="matterport.png", help="path of image file" 398 | ), 399 | labels: List[str] = typer.Option( 400 | default=[ 401 | "basketball jersey", 402 | "nintendo switch", 403 | "television", 404 | "ping pong table", 405 | "vase", 406 | "fireplace", 407 | "abstract painting of a vespa", 408 | "carpet", 409 | "wall", 410 | ], 411 | help='list of object categories (e.g.: "nintendo switch")', 412 | ), 413 | prompts: List[str] = typer.Option( 414 | default=["a photograph of a {} in a home."], 415 | help="prompt template to use with CLIP.", 416 | ), 417 | ): 418 | """ 419 | Generates a multi-scale relevancy for image at `file_path`. 420 | """ 421 | img = np.array(imageio.imread(file_path)) 422 | assert img.dtype == np.uint8 423 | h, w, c = img.shape 424 | start = time() 425 | grads = ClipWrapper.get_clip_saliency( 426 | img=img, 427 | text_labels=np.array(labels), 428 | prompts=prompts, 429 | **saliency_configs["ours"](h), 430 | )[0] 431 | print(f"get gradcam took {float(time() - start)} seconds", grads.shape) 432 | grads -= grads.mean(axis=0) 433 | grads = grads.cpu().numpy() 434 | fig, axes = plt.subplots(3, 3) 435 | axes = axes.flatten() 436 | vmin = 0.002 437 | cmap = plt.get_cmap("jet") 438 | vmax = 0.008 439 | for ax, label_grad, label in zip(axes, grads, labels): 440 | ax.axis("off") 441 | ax.imshow(img) 442 | ax.set_title(label, fontsize=12) 443 | grad = np.clip((label_grad - vmin) / (vmax - vmin), a_min=0.0, a_max=1.0) 444 | colored_grad = cmap(grad) 445 | grad = 1 - grad 446 | colored_grad[..., -1] = grad * 0.7 447 | ax.imshow(colored_grad) 448 | plt.tight_layout(pad=0) 449 | plt.savefig("grads.png") 450 | print("dumped relevancy to grads.png") 451 | plt.show() 452 | 453 | 454 | if __name__ == "__main__": 455 | app() 456 | -------------------------------------------------------------------------------- /grads.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/grads.png -------------------------------------------------------------------------------- /matterport.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/matterport.png -------------------------------------------------------------------------------- /plot_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from matplotlib.patches import Patch 3 | import matplotlib.pyplot as plt 4 | import io 5 | from PIL import Image 6 | import open3d as o3d 7 | from skimage.measure import block_reduce 8 | import matplotlib.cm as cm 9 | import matplotlib as mpl 10 | 11 | 12 | def plot_to_png(fig): 13 | buf = io.BytesIO() 14 | plt.savefig(buf, format="png") 15 | buf.seek(0) 16 | img = np.array(Image.open(buf)).astype(np.uint8) 17 | return img 18 | 19 | 20 | def set_view_and_save_img(fig, ax, views): 21 | for elev, azim in views: 22 | 
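# rotate the 3D axes to each requested (elevation, azimuth) view and yield the rendered figure as a PNG image array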
ax.view_init(elev=elev, azim=azim) 23 | yield plot_to_png(fig) 24 | 25 | 26 | def plot_pointcloud( 27 | xyz, 28 | features, 29 | object_labels=None, 30 | background_color=(0.1, 0.1, 0.1, 0.99), 31 | num_points=50000, 32 | views=[(45, 135)], 33 | pts_size=3, 34 | alpha=0.5, 35 | plot_empty=False, 36 | visualize_ghost_points=False, 37 | object_colors=None, 38 | delete_fig=True, 39 | show_plot=False, 40 | bounds=[[-1, -1, -0.1], [1, 1, 1.9]], 41 | ): 42 | is_semantic = len(features.shape) == 1 43 | if type(alpha) is float: 44 | alpha = np.ones(xyz.shape[0]).astype(np.float32) * alpha 45 | if not plot_empty and is_semantic and object_labels is not None: 46 | mask = np.ones_like(alpha).astype(bool) 47 | for remove_label in ["empty", "unlabelled", "out of bounds"]: 48 | if remove_label in object_labels.tolist(): 49 | remove_idx = object_labels.tolist().index(remove_label) 50 | mask = np.logical_and(mask, features != remove_idx) 51 | xyz = xyz[mask, :] 52 | features = features[mask, ...] 53 | alpha = alpha[mask] 54 | if type(pts_size) != int and type(pts_size) != float: 55 | pts_size = pts_size[mask] 56 | # subsample 57 | if xyz.shape[0] > num_points: 58 | indices = np.random.choice(xyz.shape[0], size=num_points, replace=False) 59 | xyz = xyz[indices, :] 60 | features = features[indices, ...] 61 | alpha = alpha[indices] 62 | if type(pts_size) != int and type(pts_size) != float: 63 | pts_size = pts_size[indices] 64 | 65 | fig = plt.figure(figsize=(6, 6), dpi=160) 66 | ax = fig.add_subplot(111, projection="3d") 67 | x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2] 68 | ax.set_facecolor(background_color) 69 | ax.w_xaxis.set_pane_color(background_color) 70 | ax.w_yaxis.set_pane_color(background_color) 71 | ax.w_zaxis.set_pane_color(background_color) 72 | # ax._axis3don = False 73 | 74 | if is_semantic and object_labels is not None: 75 | object_ids = list(np.unique(features)) 76 | object_labels = object_labels[object_ids].tolist() 77 | if object_colors is not None: 78 | object_colors = object_colors[object_ids] 79 | features = features.astype(np.int) 80 | # repack object ids 81 | repacked_obj_ids = np.zeros(features.shape).astype(np.uint32) 82 | for i, j in enumerate(object_ids): 83 | repacked_obj_ids[features == j] = i 84 | features = repacked_obj_ids 85 | 86 | object_ids = list(np.unique(features)) 87 | colors = np.zeros((len(features), 4)).astype(np.uint8) 88 | if object_colors is None: 89 | cmap = plt.get_cmap("tab20") 90 | object_colors = (255 * cmap(np.array(object_ids) % 20)).astype(np.uint8) 91 | for obj_id in np.unique(features): 92 | colors[features == obj_id, :] = object_colors[obj_id] 93 | colors = colors.astype(float) / 255.0 94 | object_colors = object_colors.astype(float) / 255 95 | handles = [ 96 | Patch(facecolor=c, edgecolor="grey", label=label) 97 | for label, c in zip(object_labels, object_colors) 98 | ] 99 | 100 | l = ax.legend( 101 | handles=handles, 102 | labels=object_labels, 103 | loc="lower center", 104 | bbox_to_anchor=(0.5, 0), 105 | ncol=4, 106 | facecolor=(0, 0, 0, 0.1), 107 | fontsize=8, 108 | framealpha=0, 109 | ) 110 | plt.setp(l.get_texts(), color=(0.8, 0.8, 0.8)) 111 | else: 112 | colors = features.astype(float) 113 | if colors.max() > 1.0: 114 | colors /= 255.0 115 | assert colors.max() <= 1.0 116 | # ensure alpha has same dims as colors 117 | if colors.shape[-1] == 4: 118 | colors[:, -1] = alpha 119 | ax.scatter(x, y, z, c=colors, s=pts_size) 120 | if visualize_ghost_points: 121 | x, y, z = np.array(np.unique(xyz, axis=0)).T 122 | ax.scatter(x, y, z, color=[1.0, 1.0, 1.0, 
0.02], s=pts_size) 123 | 124 | # Hide axes ticks 125 | ax.set_xticks([]) 126 | ax.set_yticks([]) 127 | ax.set_zticks([]) 128 | ax.axes.set_xlim3d(left=bounds[0][0], right=bounds[1][0]) 129 | ax.axes.set_ylim3d(bottom=bounds[0][1], top=bounds[1][1]) 130 | ax.axes.set_zlim3d(bottom=bounds[0][2], top=bounds[1][2]) 131 | plt.tight_layout(pad=0) 132 | imgs = list(set_view_and_save_img(fig, ax, views)) 133 | if show_plot: 134 | plt.show() 135 | if delete_fig: 136 | plt.close(fig) 137 | return imgs 138 | 139 | 140 | # meshes = [] 141 | # for class_id in np.unique(features): 142 | # mask = features == class_id 143 | # pcd = o3d.geometry.PointCloud() 144 | # pcd.points = o3d.utility.Vector3dVector(xyz[mask, :]) 145 | # pcd.estimate_normals( 146 | # search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30)) 147 | # radii = [0.005, 0.01, 0.02, 0.04] 148 | # rec_mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting( 149 | # pcd, o3d.utility.DoubleVector(radii)) 150 | # rec_mesh.paint_uniform_color(object_colors[class_id][:3]) 151 | # meshes.append(rec_mesh) 152 | # o3d.visualization.draw_geometries(meshes) 153 | 154 | 155 | def view_tsdf(tsdf, simplify=True): 156 | main_color = "#00000055" 157 | mpl.rcParams["text.color"] = main_color 158 | mpl.rcParams["axes.labelcolor"] = main_color 159 | mpl.rcParams["xtick.color"] = main_color 160 | mpl.rcParams["ytick.color"] = main_color 161 | mpl.rc("axes", edgecolor=main_color) 162 | mpl.rcParams["grid.color"] = "#00000033" 163 | 164 | if simplify: 165 | tsdf = block_reduce(tsdf, block_size=(8, 8, 8), func=np.mean) 166 | print("block reduced", tsdf.shape) 167 | 168 | x = np.arange(tsdf.shape[0])[:, None, None] 169 | y = np.arange(tsdf.shape[1])[None, :, None] 170 | z = np.arange(tsdf.shape[2])[None, None, :] 171 | x, y, z = np.broadcast_arrays(x, y, z) 172 | 173 | c = cm.plasma((tsdf.ravel() + 1)) 174 | alphas = (tsdf.ravel() < 0).astype(float) 175 | c[..., -1] = alphas 176 | 177 | fig = plt.figure() 178 | ax = fig.gca(projection="3d") 179 | ax.scatter(x.ravel(), y.ravel(), z.ravel(), c=c, s=1) 180 | ax.w_xaxis.set_pane_color((0.0, 0.0, 0.0, 0.0)) 181 | ax.w_yaxis.set_pane_color((0.0, 0.0, 0.0, 0.0)) 182 | ax.w_zaxis.set_pane_color((0.0, 0.0, 0.0, 0.0)) 183 | 184 | # Hide axes ticks 185 | ax.tick_params(axis="x", colors=(0.0, 0.0, 0.0, 0.0)) 186 | ax.tick_params(axis="y", colors=(0.0, 0.0, 0.0, 0.0)) 187 | ax.tick_params(axis="z", colors=(0.0, 0.0, 0.0, 0.0)) 188 | ax.view_init(20, -110) 189 | 190 | plt.show() 191 | -------------------------------------------------------------------------------- /point_cloud.py: -------------------------------------------------------------------------------- 1 | import pybullet_data 2 | import numpy as np 3 | from numba import njit, prange 4 | import pybullet as p 5 | import matplotlib.pyplot as plt 6 | 7 | 8 | def transform_pointcloud(xyz_pts, rigid_transform): 9 | """Apply rigid transformation to 3D pointcloud. 
10 | Args: 11 | xyz_pts: Nx3 float array of 3D points 12 | rigid_transform: 3x4 or 4x4 float array defining a rigid transformation (rotation and translation) 13 | Returns: 14 | xyz_pts: Nx3 float array of transformed 3D points 15 | """ 16 | xyz_pts = np.dot(rigid_transform[:3, :3], xyz_pts.T) # apply rotation 17 | # apply translation 18 | xyz_pts = xyz_pts + np.tile( 19 | rigid_transform[:3, 3].reshape(3, 1), (1, xyz_pts.shape[1]) 20 | ) 21 | return xyz_pts.T 22 | 23 | 24 | def filter_pts_bounds(xyz, bounds): 25 | mask = xyz[:, 0] >= bounds[0, 0] 26 | mask = np.logical_and(mask, xyz[:, 0] <= bounds[1, 0]) 27 | mask = np.logical_and(mask, xyz[:, 1] >= bounds[0, 1]) 28 | mask = np.logical_and(mask, xyz[:, 1] <= bounds[1, 1]) 29 | mask = np.logical_and(mask, xyz[:, 2] >= bounds[0, 2]) 30 | mask = np.logical_and(mask, xyz[:, 2] <= bounds[1, 2]) 31 | return mask 32 | 33 | 34 | def get_pointcloud(depth_img, color_img, cam_intr, cam_pose=None): 35 | """Get 3D pointcloud from depth image. 36 | 37 | Args: 38 | depth_img: HxW float array of depth values in meters aligned with color_img 39 | color_img: HxWx3 uint8 array of color image 40 | cam_intr: 3x3 float array of camera intrinsic parameters 41 | cam_pose: (optional) 3x4 float array of camera pose matrix 42 | 43 | Returns: 44 | cam_pts: Nx3 float array of 3D points in camera/world coordinates 45 | color_pts: Nx3 uint8 array of color points 46 | """ 47 | 48 | img_h = depth_img.shape[0] 49 | img_w = depth_img.shape[1] 50 | 51 | # Project depth into 3D pointcloud in camera coordinates 52 | pixel_x, pixel_y = np.meshgrid( 53 | np.linspace(0, img_w - 1, img_w), np.linspace(0, img_h - 1, img_h) 54 | ) 55 | cam_pts_x = np.multiply(pixel_x - cam_intr[0, 2], depth_img / cam_intr[0, 0]) 56 | cam_pts_y = np.multiply(pixel_y - cam_intr[1, 2], depth_img / cam_intr[1, 1]) 57 | cam_pts_z = depth_img 58 | cam_pts = ( 59 | np.array([cam_pts_x, cam_pts_y, cam_pts_z]).transpose(1, 2, 0).reshape(-1, 3) 60 | ) 61 | 62 | if cam_pose is not None: 63 | cam_pts = transform_pointcloud(cam_pts, cam_pose) 64 | color_pts = None if color_img is None else color_img.reshape(-1, 3) 65 | # TODO check memory leak here 66 | return cam_pts, color_pts 67 | 68 | 69 | def project_pts_to_2d(pts, camera_view_matrix, camera_intrisic): 70 | """Project points to 2D. 71 | Args: 72 | pts: Nx3 float array of 3D points in world coordinates. 73 | camera_view_matrix: 4x4 float array. A wrd2cam transformation defining camera's totation and translation. 74 | camera_intrisic: 3x3 float array. [ [f,0,0],[0,f,0],[0,0,1] ]. f is focal length. 75 | Returns: 76 | coord_2d: Nx3 float array of 2D pixel. 
(w, h, d) the last one is depth 77 | """ 78 | pts_c = transform_pointcloud(pts, camera_view_matrix[0:3, :]) 79 | rot_algix = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0]]) 80 | pts_c = transform_pointcloud(pts_c, rot_algix) # Nx3 81 | coord_2d = np.dot(camera_intrisic, pts_c.T) # 3xN 82 | coord_2d[0:2, :] = coord_2d[0:2, :] / np.tile(coord_2d[2, :], (2, 1)) 83 | coord_2d[2, :] = pts_c[:, 2] 84 | coord_2d = np.array([coord_2d[1], coord_2d[0], coord_2d[2]]) 85 | return coord_2d.T 86 | 87 | 88 | def check_pts_in_frustum(xyz_pts, depth, cam_pose, cam_intr): 89 | # xyz_pts (N,3) 90 | cam_pts = transform_pointcloud( 91 | xyz_pts=xyz_pts, rigid_transform=np.linalg.inv(cam_pose) 92 | ) 93 | cam_pts_x = cam_pts[..., 0] 94 | cam_pts_y = cam_pts[..., 1] 95 | pix_z = cam_pts[..., 2] 96 | 97 | pix_x = (cam_intr[0, 0] / pix_z) * cam_pts_x + cam_intr[0, 2] 98 | pix_y = (cam_intr[1, 1] / pix_z) * cam_pts_y + cam_intr[1, 2] 99 | 100 | # camera to pixel space 101 | h, w = depth.shape 102 | 103 | valid_pix = np.logical_and( 104 | pix_x >= 0, 105 | np.logical_and( 106 | pix_x < w, np.logical_and(pix_y >= 0, np.logical_and(pix_y < h, pix_z > 0)) 107 | ), 108 | ) 109 | in_frustum_mask = valid_pix.reshape(-1) 110 | return in_frustum_mask 111 | 112 | 113 | def meshwrite(filename, verts, colors, faces=None): 114 | """Save 3D mesh to a polygon .ply file. 115 | Args: 116 | filename: string; path to mesh file. (suffix should be .ply) 117 | verts: [N, 3]. Coordinates of each vertex 118 | colors: [N, 3]. RGB or each vertex. (type: uint8) 119 | faces: (optional) [M, 4] 120 | """ 121 | # Write header 122 | ply_file = open(filename, "w") 123 | ply_file.write("ply\n") 124 | ply_file.write("format ascii 1.0\n") 125 | ply_file.write("element vertex %d\n" % (verts.shape[0])) 126 | ply_file.write("property float x\n") 127 | ply_file.write("property float y\n") 128 | ply_file.write("property float z\n") 129 | ply_file.write("property uchar red\n") 130 | ply_file.write("property uchar green\n") 131 | ply_file.write("property uchar blue\n") 132 | if faces is not None: 133 | ply_file.write("element face %d\n" % (faces.shape[0])) 134 | ply_file.write("end_header\n") 135 | 136 | # Write vertex list 137 | for i in range(verts.shape[0]): 138 | ply_file.write( 139 | "%f %f %f %d %d %d\n" 140 | % ( 141 | verts[i, 0], 142 | verts[i, 1], 143 | verts[i, 2], 144 | colors[i, 0], 145 | colors[i, 1], 146 | colors[i, 2], 147 | ) 148 | ) 149 | 150 | # Write face list 151 | if faces is not None: 152 | for i in range(faces.shape[0]): 153 | ply_file.write( 154 | "4 %d %d %d %d\n" % (faces[i, 0], faces[i, 1], faces[i, 2], faces[i, 3]) 155 | ) 156 | 157 | ply_file.close() 158 | 159 | 160 | @njit(parallel=True) 161 | def cam2pix(cam_pts, intr): 162 | """Convert camera coordinates to pixel coordinates.""" 163 | intr = intr.astype(np.float32) 164 | fx, fy = intr[0, 0], intr[1, 1] 165 | cx, cy = intr[0, 2], intr[1, 2] 166 | pix = np.empty((cam_pts.shape[0], 2), dtype=np.int64) 167 | for i in prange(cam_pts.shape[0]): 168 | pix[i, 0] = int(np.round((cam_pts[i, 0] * fx / cam_pts[i, 2]) + cx)) 169 | pix[i, 1] = int(np.round((cam_pts[i, 1] * fy / cam_pts[i, 2]) + cy)) 170 | return pix 171 | 172 | 173 | def compute_empty_mask( 174 | scene_bounds, depth_img, intrinsic_matrix, extrinsic_matrix, voxel_resolution=20 175 | ): 176 | # parts taken from 177 | # https://github.com/andyzeng/tsdf-fusion-python/blob/3f22a940d90f684145b1f29b1feaa92e09eb1db6/fusion.py#L170 178 | 179 | # start off all empty 180 | grid_shape = [voxel_resolution] * 3 181 | mask = 
np.ones(grid_shape).astype(int) 182 | # get volume points 183 | lc = scene_bounds[0] 184 | uc = scene_bounds[1] 185 | 186 | # get voxel indices 187 | grid_idxs = np.stack( 188 | np.meshgrid(*[np.arange(0, dim) for dim in grid_shape]), axis=-1 189 | ) 190 | 191 | # voxel indices to world pts 192 | idx_scale = np.array(grid_shape) - 1 193 | scales = (uc - lc) / idx_scale 194 | offsets = lc 195 | grid_points = grid_idxs.astype(float) * scales + offsets 196 | 197 | flattened_grid_points = grid_points.reshape(-1, 3) 198 | print(flattened_grid_points.min(axis=0), flattened_grid_points.max(axis=0)) 199 | 200 | # world pts to camera centric frame pts 201 | xyz_h = np.hstack( 202 | [ 203 | flattened_grid_points, 204 | np.ones((len(flattened_grid_points), 1), dtype=np.float32), 205 | ] 206 | ) 207 | xyz_t_h = np.dot(np.linalg.inv(extrinsic_matrix), xyz_h.T).T 208 | cam_pts = xyz_t_h[:, :3] 209 | pix_z = cam_pts[:, 2] 210 | pix = cam2pix(cam_pts, intrinsic_matrix) 211 | pix_x, pix_y = pix[:, 0], pix[:, 1] 212 | im_w, im_h = depth_img.shape 213 | 214 | valid_pix = np.logical_and( 215 | pix_x >= 0, 216 | np.logical_and( 217 | pix_x < im_w, 218 | np.logical_and(pix_y >= 0, np.logical_and(pix_y < im_h, pix_z > 0)), 219 | ), 220 | ) 221 | inframe_indices = grid_idxs.reshape(-1, 3)[valid_pix, :] 222 | 223 | # depth_val = np.zeros(pix_x.shape) 224 | # depth_val[valid_pix] = depth_img[pix_y[valid_pix], pix_x[valid_pix]] 225 | observed_indices = inframe_indices[ 226 | (depth_img[pix_y[valid_pix], pix_x[valid_pix]] > pix_z[valid_pix]) 227 | ] 228 | 229 | print("before:", mask.mean(), mask.shape, observed_indices.shape) 230 | for idx in observed_indices: 231 | mask[tuple(idx)] = 0 232 | print(mask.mean()) 233 | print(observed_indices.shape, mask.shape) 234 | # mask[observed_indices] = 0 235 | print("after:", mask.mean()) 236 | 237 | ax = plt.figure().add_subplot(projection="3d") 238 | ax.voxels(mask) 239 | # pts = grid_points[mask, :] 240 | # ax.scatter(pts[:, 0], pts[:, 1], pts[:, 2]) 241 | plt.show() 242 | return mask.astype(bool) 243 | 244 | 245 | def subsample(seg_pts, num_pts, random_state, balanced=True): 246 | probabilities = np.ones(seg_pts.shape).astype(np.float64) 247 | if balanced: 248 | unique_semantic_ids = np.unique(seg_pts) 249 | num_semantic_ids = len(unique_semantic_ids) 250 | for semantic_id in unique_semantic_ids: 251 | mask = seg_pts == semantic_id 252 | probabilities[mask] = 1.0 / (int((mask).sum().item()) * num_semantic_ids) 253 | else: 254 | probabilities /= probabilities.sum() 255 | indices = random_state.choice( 256 | seg_pts.shape[0], size=num_pts, replace=False, p=probabilities 257 | ) 258 | return indices 259 | 260 | 261 | if __name__ == "__main__": 262 | # TODO change this to filter input sampled points out based on 263 | # view point 264 | from datagen.simulation.asset import make_object, occluder_objects, partnet_objs 265 | from datagen.simulation import Camera 266 | 267 | object_keys = [k for k in occluder_objects] 268 | object_def = occluder_objects[object_keys[10]] 269 | p.connect(p.GUI) 270 | p.resetDebugVisualizerCamera( 271 | cameraDistance=4.0, 272 | cameraYaw=270, 273 | cameraPitch=-20, 274 | cameraTargetPosition=(0, 0, 0.4), 275 | ) 276 | p.setAdditionalSearchPath(pybullet_data.getDataPath()) 277 | p.setRealTimeSimulation(False) 278 | p.resetSimulation() 279 | p.setGravity(0, 0, -9.8) 280 | planeUid = p.loadURDF(fileName="plane.urdf", useFixedBase=True) 281 | occluder_obj = make_object(**object_def) 282 | camera = Camera(position=[-1, 1, 1], lookat=[0, 0, 0.5]) 283 | view 
= camera.get_image(return_pose=True, segmentation_mask=True) 284 | mask = compute_empty_mask( 285 | scene_bounds=np.array([[-1.0, -1.0, -0.1], [1.0, 1.0, 1.9]]), 286 | depth_img=view[1], 287 | intrinsic_matrix=view[-2], 288 | extrinsic_matrix=view[-1], 289 | ) 290 | -------------------------------------------------------------------------------- /scene_files/arkit_kitchen.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/scene_files/arkit_kitchen.pkl -------------------------------------------------------------------------------- /scene_files/arkit_vn_poster.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/scene_files/arkit_vn_poster.pkl -------------------------------------------------------------------------------- /scene_files/matterport_hallway.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/scene_files/matterport_hallway.pkl -------------------------------------------------------------------------------- /scene_files/matterport_kitchen.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/scene_files/matterport_kitchen.pkl -------------------------------------------------------------------------------- /scene_files/matterport_living_room.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/scene_files/matterport_living_room.pkl -------------------------------------------------------------------------------- /scene_files/walle.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/real-stanford/semantic-abstraction/e76495e9ad49db75649394ac466fd3fec057c5b6/scene_files/walle.pkl -------------------------------------------------------------------------------- /semabs.yml: -------------------------------------------------------------------------------- 1 | name: semabs 2 | channels: 3 | - pyg 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - blas=1.0=mkl 11 | - blosc=1.21.0=h4ff587b_1 12 | - bottleneck=1.3.5=py38h7deecbd_0 13 | - brotli=1.0.9=h5eee18b_7 14 | - brotli-bin=1.0.9=h5eee18b_7 15 | - brotlipy=0.7.0=py38h27cfd23_1003 16 | - brunsli=0.1=h2531618_0 17 | - bzip2=1.0.8=h7f98852_4 18 | - c-ares=1.18.1=h7f8727e_0 19 | - ca-certificates=2022.10.11=h06a4308_0 20 | - certifi=2022.9.24=py38h06a4308_0 21 | - cffi=1.15.0=py38h7f8727e_0 22 | - cfitsio=3.470=h5893167_7 23 | - charls=2.2.0=h2531618_0 24 | - charset-normalizer=2.1.1=pyhd8ed1ab_0 25 | - cloudpickle=2.0.0=pyhd3eb1b0_0 26 | - cryptography=38.0.1=py38h9ce1e76_0 27 | - cudatoolkit=11.6.0=hecad31d_10 28 | - cycler=0.11.0=pyhd3eb1b0_0 29 | - cytoolz=0.12.0=py38h5eee18b_0 30 | - dask-core=2022.7.0=py38h06a4308_0 31 | - ffmpeg=4.3.2=hca11adc_0 32 | - fftw=3.3.9=h27cfd23_1 33 | - fonttools=4.25.0=pyhd3eb1b0_0 34 | - freetype=2.12.1=h4a9f257_0 35 | - fsspec=2022.10.0=py38h06a4308_0 36 
| - gettext=0.21.0=hf68c758_0 37 | - giflib=5.2.1=h516909a_2 38 | - gmp=6.2.1=h58526e2_0 39 | - gnutls=3.6.15=he1e5248_0 40 | - icu=58.2=he6710b0_3 41 | - idna=3.4=pyhd8ed1ab_0 42 | - imagecodecs=2021.8.26=py38hf0132c2_1 43 | - importlib-metadata=4.11.3=py38h06a4308_0 44 | - importlib_metadata=4.11.3=hd3eb1b0_0 45 | - intel-openmp=2021.4.0=h06a4308_3561 46 | - jpeg=9e=h7f8727e_0 47 | - jxrlib=1.1=h7b6447c_2 48 | - kiwisolver=1.4.2=py38h295c915_0 49 | - krb5=1.19.2=hac12032_0 50 | - lame=3.100=h7b6447c_0 51 | - lcms2=2.12=h3be6417_0 52 | - lerc=3.0=h9c3ff4c_0 53 | - libaec=1.0.4=he6710b0_1 54 | - libbrotlicommon=1.0.9=h5eee18b_7 55 | - libbrotlidec=1.0.9=h5eee18b_7 56 | - libbrotlienc=1.0.9=h5eee18b_7 57 | - libcurl=7.85.0=h91b91d3_0 58 | - libdeflate=1.8=h7f98852_0 59 | - libedit=3.1.20210910=h7f8727e_0 60 | - libev=4.33=h7f8727e_1 61 | - libffi=3.2.1=hf484d3e_1007 62 | - libgcc-ng=11.2.0=h1234567_1 63 | - libgfortran-ng=11.2.0=h00389a5_1 64 | - libgfortran5=11.2.0=h1234567_1 65 | - libgomp=11.2.0=h1234567_1 66 | - libidn2=2.3.2=h7f8727e_0 67 | - libllvm11=11.1.0=h9e868ea_6 68 | - libnghttp2=1.46.0=hce63b2e_0 69 | - libpng=1.6.37=hed695b0_2 70 | - libssh2=1.10.0=h8f2d780_0 71 | - libstdcxx-ng=11.2.0=h1234567_1 72 | - libtasn1=4.16.0=h27cfd23_0 73 | - libtiff=4.4.0=hecacb30_1 74 | - libunistring=0.9.10=h14c3975_0 75 | - libwebp=1.2.4=h11a3e52_0 76 | - libwebp-base=1.2.4=h5eee18b_0 77 | - libxml2=2.9.14=h74e7548_0 78 | - libzopfli=1.0.3=he6710b0_0 79 | - llvmlite=0.39.1=py38he621ea3_0 80 | - locket=1.0.0=py38h06a4308_0 81 | - lz4-c=1.9.3=h9c3ff4c_1 82 | - matplotlib-base=3.5.3=py38hf590b9c_0 83 | - mkl=2021.4.0=h06a4308_640 84 | - mkl-service=2.4.0=py38h497a2fe_0 85 | - mkl_fft=1.3.1=py38h8666266_1 86 | - mkl_random=1.2.2=py38h1abd341_0 87 | - munkres=1.1.4=py_0 88 | - ncurses=6.3=h5eee18b_3 89 | - nettle=3.7.3=hbbd107a_1 90 | - networkx=2.8.4=py38h06a4308_0 91 | - numba=0.56.3=py38h417a72b_0 92 | - numexpr=2.8.3=py38h807cd23_0 93 | - numpy=1.22.3=py38he7a7128_0 94 | - numpy-base=1.22.3=py38hf524024_0 95 | - openh264=2.1.1=h780b84a_0 96 | - openjpeg=2.4.0=h3ad879b_0 97 | - openssl=1.1.1s=h7f8727e_0 98 | - packaging=21.3=pyhd3eb1b0_0 99 | - pandas=1.4.4=py38h6a678d5_0 100 | - partd=1.2.0=pyhd3eb1b0_1 101 | - pillow=9.2.0=py38hace64e9_1 102 | - pip=22.2.2=py38h06a4308_0 103 | - pycparser=2.21=pyhd8ed1ab_0 104 | - pyopenssl=22.1.0=pyhd8ed1ab_0 105 | - pyparsing=3.0.9=py38h06a4308_0 106 | - pysocks=1.7.1=py38h578d9bd_5 107 | - python=3.8.0=h0371630_2 108 | - python-dateutil=2.8.2=pyhd3eb1b0_0 109 | - python_abi=3.8=2_cp38 110 | - pytorch=1.12.1=py3.8_cuda11.6_cudnn8.3.2_0 111 | - pytorch-mutex=1.0=cuda 112 | - pytorch-scatter=2.0.9=py38_torch_1.12.0_cu116 113 | - pytz=2022.1=py38h06a4308_0 114 | - pywavelets=1.3.0=py38h7f8727e_0 115 | - readline=7.0=h7b6447c_5 116 | - requests=2.28.1=pyhd8ed1ab_1 117 | - scikit-image=0.19.2=py38h51133e4_0 118 | - scipy=1.9.3=py38h14f4228_0 119 | - seaborn=0.12.0=py38h06a4308_0 120 | - setuptools=65.5.0=py38h06a4308_0 121 | - six=1.16.0=pyh6c4a22f_0 122 | - snappy=1.1.9=h295c915_0 123 | - sqlite=3.33.0=h62c20be_0 124 | - tbb=2021.6.0=hdb19cb5_0 125 | - tifffile=2021.7.2=pyhd3eb1b0_2 126 | - tk=8.6.12=h1ccaba5_0 127 | - toolz=0.12.0=py38h06a4308_0 128 | - torchaudio=0.12.1=py38_cu116 129 | - torchvision=0.13.1=py38_cu116 130 | - tqdm=4.64.1=py38h06a4308_0 131 | - typing_extensions=4.4.0=pyha770c72_0 132 | - urllib3=1.26.12=py38h06a4308_0 133 | - wheel=0.37.1=pyhd3eb1b0_0 134 | - x264=1!161.3030=h7f98852_1 135 | - xz=5.2.6=h5eee18b_0 136 | - yaml=0.2.5=h7b6447c_0 
137 | - zfp=0.5.5=h2531618_6 138 | - zlib=1.2.13=h5eee18b_0 139 | - zstd=1.5.2=ha4553b6_0 140 | - pip: 141 | - addict==2.4.0 142 | - aiosignal==1.2.0 143 | - asttokens==2.1.0 144 | - attrs==22.1.0 145 | - backcall==0.2.0 146 | - click==8.0.4 147 | - configargparse==1.5.3 148 | - dash==2.7.0 149 | - dash-core-components==2.0.0 150 | - dash-html-components==2.0.0 151 | - dash-table==5.0.0 152 | - debugpy==1.6.3 153 | - decorator==5.1.1 154 | - distlib==0.3.6 155 | - entrypoints==0.4 156 | - executing==1.2.0 157 | - fastjsonschema==2.16.2 158 | - filelock==3.8.0 159 | - flask==2.2.2 160 | - frozenlist==1.3.1 161 | - ftfy==6.1.1 162 | - greenlet==2.0.1 163 | - grpcio==1.43.0 164 | - h5py==3.7.0 165 | - imageio==2.22.4 166 | - imageio-ffmpeg==0.4.7 167 | - importlib-resources==5.10.0 168 | - ipykernel==6.17.0 169 | - ipython==8.6.0 170 | - ipywidgets==8.0.2 171 | - itsdangerous==2.1.2 172 | - jedi==0.18.1 173 | - jinja2==3.1.2 174 | - joblib==1.2.0 175 | - jsonschema==4.17.0 176 | - jupyter-client==7.4.4 177 | - jupyter-core==4.11.2 178 | - jupyterlab-widgets==3.0.3 179 | - markupsafe==2.1.1 180 | - matplotlib-inline==0.1.6 181 | - msgpack==1.0.4 182 | - nbformat==5.5.0 183 | - neovim==0.3.1 184 | - nest-asyncio==1.5.6 185 | - open3d==0.16.0 186 | - opencv-python==4.6.0.66 187 | - parso==0.8.3 188 | - pexpect==4.8.0 189 | - pickleshare==0.7.5 190 | - pkgutil-resolve-name==1.3.10 191 | - platformdirs==2.5.3 192 | - plotly==5.11.0 193 | - prompt-toolkit==3.0.32 194 | - protobuf==3.20.1 195 | - psutil==5.9.3 196 | - ptyprocess==0.7.0 197 | - pure-eval==0.2.2 198 | - pybullet==3.2.5 199 | - pygments==2.13.0 200 | - pynvim==0.4.3 201 | - pyquaternion==0.9.9 202 | - pyrsistent==0.19.2 203 | - pyyaml==6.0 204 | - pyzmq==24.0.1 205 | - ray==2.0.1 206 | - regex==2022.10.31 207 | - rich 208 | - sacremoses==0.0.53 209 | - scikit-learn==1.1.3 210 | - stack-data==0.6.0 211 | - tabulate==0.9.0 212 | - tenacity==8.1.0 213 | - tensorboardx==2.5.1 214 | - threadpoolctl==3.1.0 215 | - tokenizers==0.10.3 216 | - torchtyping==0.1.4 217 | - tornado==6.2 218 | - traitlets==5.5.0 219 | - transformers==4.5.1 220 | - transforms3d==0.4.1 221 | - typeguard==2.13.3 222 | - typer 223 | - virtualenv==20.16.6 224 | - wcwidth==0.2.5 225 | - werkzeug==2.2.2 226 | - widgetsnbextension==4.0.3 227 | - zipp==3.10.0 -------------------------------------------------------------------------------- /summarize.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import rich 3 | import pickle 4 | from dataset import synonyms 5 | import numpy as np 6 | from rich.console import Console 7 | from rich.table import Table 8 | 9 | test_objs = set( 10 | map(lambda l: l.rstrip().lstrip(), open("test_semantic_classes.txt", "r")) 11 | ) 12 | 13 | 14 | def summarize_ovssc(metric="voxel32x32x32_iou"): 15 | ssc_approaches = { 16 | "Semantic Aware": pickle.load( 17 | open("models/semaware/ovssc/ovssc_eval_stats.pkl", "rb") 18 | ), 19 | "SemAbs + [Chefer et al]": pickle.load( 20 | open("models/chefer_et_al/ovssc/ovssc_eval_stats.pkl", "rb") 21 | ), 22 | "Ours": pickle.load( 23 | open( 24 | "models/ours/ovssc/ovssc_eval_stats.pkl", 25 | "rb", 26 | ) 27 | ), 28 | } 29 | 30 | ovssc_stats = { 31 | "approach": [], 32 | "novel rooms": [], 33 | "novel visual": [], 34 | "novel vocab": [], 35 | "novel class": [], 36 | } 37 | pd.options.display.float_format = "{:,.3f}".format 38 | for approach, approach_stats in ssc_approaches.items(): 39 | # approach_stats = approach_stats[approach_stats.label!=''] 40 
| approach_stats["room_id"] = approach_stats["scene_id"].apply( 41 | lambda s: int(s.split("_")[0].split("FloorPlan")[1]) 42 | ) 43 | approach_stats[metric] = approach_stats[metric] * 100 44 | cutoff_analysis = approach_stats.groupby("cutoff")[[metric]].mean() 45 | best_cutoff = cutoff_analysis[metric].idxmax() 46 | df = approach_stats[approach_stats.cutoff == best_cutoff] 47 | novel_class_mask = df.label.isin(test_objs) 48 | novel_vocab_mask = df.label.isin(synonyms.values()) 49 | ovssc_stats["approach"].append(approach) 50 | 51 | novel_rooms_df = df[(df.split == "unseen_instances") & (~novel_class_mask)] 52 | mean_per_room = np.array(novel_rooms_df.groupby("room_id")[metric].mean()) 53 | ovssc_stats["novel rooms"].append(mean_per_room.mean()) 54 | 55 | novel_rooms_dr_df = df[ 56 | (df.split == "unseen_instances_dr") & (~novel_class_mask) 57 | ] 58 | mean_per_room = np.array(novel_rooms_dr_df.groupby("room_id")[metric].mean()) 59 | ovssc_stats["novel visual"].append(mean_per_room.mean()) 60 | 61 | unseen_class_df = df[novel_class_mask] 62 | mean_per_label = unseen_class_df.groupby("label")[metric].mean() 63 | ovssc_stats["novel class"].append(np.array(mean_per_label).mean()) 64 | 65 | unseen_vocab_df = df[ 66 | (df.split == "unseen_instances_synonyms") & novel_vocab_mask 67 | ] 68 | mean_per_label = unseen_vocab_df.groupby("label")[metric].mean() 69 | ovssc_stats["novel vocab"].append(np.array(mean_per_label).mean()) 70 | ovssc_stats = pd.DataFrame.from_dict(ovssc_stats) 71 | table = Table(title="OVSSC THOR", box=rich.box.MINIMAL_DOUBLE_HEAD) 72 | table.add_column("Approach", justify="left") 73 | table.add_column("Novel Room", justify="right") 74 | table.add_column("Novel Visual", justify="right") 75 | table.add_column("Novel Vocab", justify="right") 76 | table.add_column("Novel Class", justify="right") 77 | for row in ovssc_stats.to_csv().split("\n")[1:-1]: 78 | approach, novel_room, novel_visual, novel_vocab, novel_class = row.split(",")[ 79 | 1: 80 | ] 81 | table.add_row( 82 | approach, 83 | f"{float(novel_room):.01f}", 84 | f"{float(novel_visual):.01f}", 85 | f"{float(novel_vocab):.01f}", 86 | f"{float(novel_class):.01f}", 87 | end_section=approach == "SemAbs + [Chefer et al]", 88 | style="green" if approach == "Ours" else "white", 89 | ) 90 | console = Console() 91 | console.print(table) 92 | 93 | 94 | def summarize_vool(metric="voxel32x32x32_iou"): 95 | vool_approaches = { 96 | "Semantic Aware": pickle.load( 97 | open("models/semaware/vool/vool_eval_stats.pkl", "rb") 98 | ), 99 | "ClipSpatial": pickle.load( 100 | open("models/clipspatial/vool/vool_eval_stats.pkl", "rb") 101 | ), 102 | "SemAbs + [Chefer et al]": pickle.load( 103 | open("models/chefer_et_al/vool/vool_eval_stats.pkl", "rb") 104 | ), 105 | "Ours": pickle.load(open("models/ours/vool/vool_eval_stats.pkl", "rb")), 106 | } 107 | vool_stats = { 108 | "approach": [], 109 | "relation": [], 110 | "novel rooms": [], 111 | "novel visual": [], 112 | "novel vocab": [], 113 | "novel class": [], 114 | } 115 | relations = vool_approaches["Ours"].spatial_relation_name.unique() 116 | for approach in vool_approaches.keys(): 117 | approach_stats = vool_approaches[approach] 118 | approach_stats["room_id"] = approach_stats["scene_id"].apply( 119 | lambda s: int(s.split("_")[0].split("FloorPlan")[1]) 120 | ) 121 | cutoff_analysis = approach_stats.groupby("cutoff")[[metric]].mean() 122 | best_cutoff = cutoff_analysis[metric].idxmax() 123 | approach_stats[metric] = approach_stats[metric] * 100 124 | for relation in relations: 125 | if relation 
== "[pad]": 126 | continue 127 | df = approach_stats[approach_stats.cutoff == best_cutoff] 128 | df = df[df.spatial_relation_name == relation] 129 | 130 | novel_vocab_mask = df.target_obj_name.isin( 131 | synonyms.values() 132 | ) | df.reference_obj_name.isin(synonyms.values()) 133 | novel_class_mask = df.target_obj_name.isin( 134 | test_objs 135 | ) | df.reference_obj_name.isin(test_objs) 136 | 137 | vool_stats["approach"].append(approach) 138 | vool_stats["relation"].append(relation) 139 | novel_rooms_df = df[(df.split == "unseen_instances") & (~novel_class_mask)] 140 | mean_per_room = np.array(novel_rooms_df.groupby("room_id")[metric].mean()) 141 | vool_stats["novel rooms"].append(np.nanmean(mean_per_room)) 142 | novel_rooms_dr_df = df[ 143 | (df.split == "unseen_instances_dr") & (~novel_class_mask) 144 | ] 145 | mean_per_room = np.array( 146 | novel_rooms_dr_df.groupby("room_id")[metric].mean() 147 | ) 148 | vool_stats["novel visual"].append(np.nanmean(mean_per_room)) 149 | 150 | unseen_class_df = df[novel_class_mask] 151 | vool_stats["novel class"].append(np.nanmean(unseen_class_df[metric])) 152 | unseen_vocab_df = df[ 153 | (df.split == "unseen_instances_synonyms") & novel_vocab_mask 154 | ] 155 | vool_stats["novel vocab"].append(np.nanmean(unseen_vocab_df[metric])) 156 | vool_stats = pd.DataFrame.from_dict(vool_stats) 157 | for approach_i, approach in enumerate(vool_approaches.keys()): 158 | mean_df = pd.DataFrame.from_dict( 159 | { 160 | "approach": [approach], 161 | "relation": ["mean"], 162 | **{ 163 | split: [ 164 | np.array( 165 | vool_stats[(vool_stats.approach == approach)][[split]] 166 | ).mean() 167 | ] 168 | for split in [ 169 | "novel rooms", 170 | "novel visual", 171 | "novel vocab", 172 | "novel class", 173 | ] 174 | }, 175 | } 176 | ) 177 | vool_stats = pd.concat( 178 | [ 179 | vool_stats.iloc[0 : (approach_i + 1) * 6 + approach_i], 180 | mean_df, 181 | vool_stats.iloc[(approach_i + 1) * 6 + approach_i :], 182 | ] 183 | ) 184 | table = Table(title="FULL VOOL THOR", box=rich.box.MINIMAL_DOUBLE_HEAD) 185 | table.add_column("Approach", justify="left") 186 | table.add_column("Spatial Relation", justify="left") 187 | table.add_column("Novel Room", justify="right") 188 | table.add_column("Novel Visual", justify="right") 189 | table.add_column("Novel Vocab", justify="right") 190 | table.add_column("Novel Class", justify="right") 191 | last_approach = "" 192 | for row in vool_stats.to_csv().split("\n")[1:-1]: 193 | ( 194 | approach, 195 | spatial_relation, 196 | novel_room, 197 | novel_visual, 198 | novel_vocab, 199 | novel_class, 200 | ) = row.split(",")[1:] 201 | table.add_row( 202 | approach if approach != last_approach else "", 203 | spatial_relation, 204 | f"{float(novel_room):.01f}", 205 | f"{float(novel_visual):.01f}", 206 | f"{float(novel_vocab):.01f}", 207 | f"{float(novel_class):.01f}", 208 | end_section=spatial_relation == "mean", 209 | style=("green" if approach == "Ours" else "white"), 210 | ) 211 | last_approach = approach 212 | console = Console() 213 | console.print(table) 214 | 215 | 216 | def summarize_nyuv2(metric="voxel60x60x60_iou"): 217 | ssc_approaches = { 218 | "Ours (Supervised)": pickle.load( 219 | open( 220 | "models/ours/ovssc/ovssc_eval_stats_supervised_nyu_merged.pkl", 221 | "rb", 222 | ) 223 | ), 224 | "Ours (Zeroshot)": pickle.load( 225 | open( 226 | "models/ours/ovssc/ovssc_eval_stats_zs_nyu_merged.pkl", 227 | "rb", 228 | ) 229 | ), 230 | } 231 | classes = [ 232 | "ceiling", 233 | "floor", 234 | "wall", 235 | "window", 236 | "chair", 237 | 
"bed", 238 | "sofa", 239 | "table", 240 | "tvs", 241 | "furn", 242 | "objs", 243 | "mean", 244 | ] 245 | table = Table(title="OVSSC NYU", box=rich.box.MINIMAL_DOUBLE_HEAD) 246 | table.add_column("Approach", justify="left") 247 | for c in classes: 248 | table.add_column(c.title(), justify="right") 249 | for approach, approach_stats in ssc_approaches.items(): 250 | approach_stats[metric] = approach_stats[metric] * 100 251 | cutoff_analysis = approach_stats.groupby("cutoff")[[metric]].mean() 252 | best_cutoff = cutoff_analysis[metric].idxmax() 253 | df = approach_stats[approach_stats.cutoff == best_cutoff] 254 | row = [approach] 255 | for c in classes: 256 | if c != "mean": 257 | row.append(f"{df[df.label == c][metric].mean():.01f}") 258 | else: 259 | row.append( 260 | f'{np.array(df.groupby("label")[metric].mean()).mean():.01f}' 261 | ) 262 | table.add_row( 263 | *row, 264 | end_section=approach == "Ours (Supervised)", 265 | style="green" if approach == "Ours (Zeroshot)" else "white", 266 | ) 267 | console = Console() 268 | console.print(table) 269 | 270 | 271 | if __name__ == "__main__": 272 | summarize_ovssc() 273 | summarize_vool() 274 | summarize_nyuv2() 275 | -------------------------------------------------------------------------------- /test_semantic_classes.txt: -------------------------------------------------------------------------------- 1 | pot 2 | mug 3 | safe 4 | teddy bear 5 | basket ball 6 | wine bottle 7 | -------------------------------------------------------------------------------- /test_thor_rooms.txt: -------------------------------------------------------------------------------- 1 | FloorPlan26_physics 2 | FloorPlan27_physics 3 | FloorPlan28_physics 4 | FloorPlan29_physics 5 | FloorPlan30_physics 6 | FloorPlan226_physics 7 | FloorPlan227_physics 8 | FloorPlan228_physics 9 | FloorPlan229_physics 10 | FloorPlan230_physics 11 | FloorPlan326_physics 12 | FloorPlan327_physics 13 | FloorPlan328_physics 14 | FloorPlan329_physics 15 | FloorPlan330_physics 16 | FloorPlan426_physics 17 | FloorPlan427_physics 18 | FloorPlan428_physics 19 | FloorPlan429_physics 20 | FloorPlan430_physics 21 | -------------------------------------------------------------------------------- /train_ovssc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.nn.functional import binary_cross_entropy_with_logits 4 | from net import SemAbs3D, SemanticAwareOVSSC 5 | import utils 6 | import pandas as pd 7 | from dataset import SceneCompletionDataset 8 | from typing import Dict, Tuple, Union 9 | 10 | 11 | def get_detailed_stats( 12 | prediction, 13 | gt_label, 14 | xyz_pts, 15 | patch_labels, 16 | scene_ids, 17 | scene_bounds, 18 | ignore_pts, 19 | detailed_analysis=False, 20 | eval_device="cuda", 21 | **kwargs, 22 | ): 23 | num_scenes, num_patches = patch_labels.shape 24 | retvals = { 25 | "scene_id": np.array([[scene_id] * num_patches for scene_id in scene_ids]) 26 | .reshape(-1) 27 | .tolist(), 28 | "label": patch_labels.reshape(-1).tolist(), 29 | } 30 | retvals.update( 31 | { 32 | f"point_{k}": v 33 | for k, v in utils.prediction_analysis( 34 | prediction=prediction.to(eval_device), 35 | label=gt_label.to(eval_device), 36 | ignore=ignore_pts.to(eval_device), 37 | ).items() 38 | } 39 | ) 40 | voxelized_pts = utils.voxelize_points( 41 | prediction=prediction, 42 | label=gt_label, 43 | xyz_pts=xyz_pts, 44 | voxel_shape=(32, 32, 32), 45 | scene_bounds=scene_bounds, 46 | ignore_pts=ignore_pts, 47 | ) 48 | 
retvals.update( 49 | { 50 | "voxel32x32x32_" + k: v 51 | for k, v in utils.prediction_analysis( 52 | **{k: v.to(eval_device) for k, v in voxelized_pts.items()} 53 | ).items() 54 | } 55 | ) 56 | if detailed_analysis: 57 | voxelized_pts = utils.voxelize_points( 58 | prediction=prediction, 59 | label=gt_label, 60 | xyz_pts=xyz_pts, 61 | voxel_shape=(64, 64, 64), 62 | scene_bounds=scene_bounds, 63 | ignore_pts=ignore_pts, 64 | ) 65 | retvals.update( 66 | { 67 | "voxel64x64x64_" + k: v 68 | for k, v in utils.prediction_analysis( 69 | **{k: v.to(eval_device) for k, v in voxelized_pts.items()} 70 | ).items() 71 | } 72 | ) 73 | for i, label in enumerate(patch_labels.reshape(-1).tolist()): 74 | if label == "": # skip padding classes 75 | for k in retvals.keys(): 76 | if "voxel" in k or "point" in k: 77 | retvals[k][i] = np.NAN 78 | return pd.DataFrame.from_dict(retvals) 79 | 80 | 81 | def get_losses( 82 | net, 83 | batch: dict, 84 | cutoffs=[0], 85 | balance_positive_negative: bool = False, 86 | **kwargs, 87 | ) -> Tuple[Dict[str, Union[float, torch.Tensor]], pd.DataFrame]: 88 | stats = {} 89 | num_pts = batch["output_xyz_pts"].shape[2] 90 | if num_pts <= 500000: 91 | outputs = net(**batch) 92 | else: 93 | num_patches = 1 94 | # probably CUDA OOM 95 | outputs = torch.cat( 96 | [ 97 | net( 98 | **{ 99 | **batch, 100 | "input_feature_pts": batch["input_feature_pts"][ 101 | :, patch_i * num_patches : (patch_i + 1) * num_patches, ... 102 | ] 103 | if batch["input_feature_pts"].shape[1] 104 | == batch["output_xyz_pts"].shape[1] 105 | else batch["input_feature_pts"], 106 | "output_xyz_pts": batch["output_xyz_pts"][ 107 | :, patch_i * num_patches : (patch_i + 1) * num_patches, ... 108 | ], 109 | "semantic_class_features": batch["semantic_class_features"][ 110 | :, patch_i * num_patches : (patch_i + 1) * num_patches, ... 111 | ], 112 | } 113 | ) 114 | for patch_i in range(len(batch["patch_labels"]) // num_patches + 1) 115 | if np.prod( 116 | batch["output_xyz_pts"][ 117 | :, patch_i * num_patches : (patch_i + 1) * num_patches, ... 
118 | ].shape 119 | ) 120 | > 0 121 | ], 122 | dim=1, 123 | ) 124 | 125 | batch["patch_labels"] = np.array(batch["patch_labels"]).T 126 | padding_mask = torch.from_numpy(batch["patch_labels"] == "").bool() 127 | batch["out_of_bounds_pts"] = batch["out_of_bounds_pts"].view(outputs.shape) 128 | ignore_pts_mask = torch.zeros_like(outputs).bool() 129 | # ignore all padding labels 130 | ignore_pts_mask[padding_mask] = True 131 | # ignore all points out of bounds 132 | ignore_pts_mask = torch.logical_or(ignore_pts_mask, batch["out_of_bounds_pts"]) 133 | # don't eval on points outside of frustum 134 | ignore_pts_mask = torch.logical_or( 135 | ignore_pts_mask, batch["out_of_frustum_pts_mask"] 136 | ) 137 | stats["loss"] = binary_cross_entropy_with_logits( 138 | outputs[~ignore_pts_mask], 139 | batch["output_label_pts"][~ignore_pts_mask], 140 | weight=utils.get_bce_weight( 141 | output_label_pts=batch["output_label_pts"], 142 | balance_positive_negative=balance_positive_negative, 143 | )[~ignore_pts_mask], 144 | ) 145 | with torch.no_grad(): 146 | vision_accuracy_mask = ( 147 | (outputs > 0.0).long() == batch["output_label_pts"] 148 | ).float() 149 | stats["accuracy"] = vision_accuracy_mask[~ignore_pts_mask].mean() 150 | detailed_stats = [ 151 | get_detailed_stats( 152 | prediction=outputs > cutoff, 153 | gt_label=batch["output_label_pts"].bool(), 154 | xyz_pts=batch["output_xyz_pts"], 155 | ignore_pts=ignore_pts_mask, 156 | patch_labels=batch["patch_labels"], 157 | scene_ids=batch["scene_id"], 158 | eval_device=net.device, 159 | **kwargs, 160 | ) 161 | for cutoff in cutoffs 162 | ] 163 | for detailed_stat, cutoff in zip(detailed_stats, cutoffs): 164 | detailed_stat["cutoff"] = [cutoff] * len(detailed_stat) 165 | detailed_stats = pd.concat(detailed_stats) 166 | for k in detailed_stats.columns: 167 | if "iou" in k: 168 | stats[k] = detailed_stats[k].mean() 169 | return stats, detailed_stats 170 | 171 | 172 | approach = { 173 | "semantic_abstraction": SemAbs3D, 174 | "semantic_aware": SemanticAwareOVSSC, 175 | } 176 | 177 | 178 | if __name__ == "__main__": 179 | parser = utils.config_parser() 180 | parser.add_argument("--log", type=str, required=True) 181 | parser.add_argument( 182 | "--approach", choices=approach.keys(), default="semantic_abstraction" 183 | ) 184 | args = parser.parse_args() 185 | if args.approach == "semantic_aware": 186 | args.network_inputs = ["rgb"] 187 | utils.train( 188 | get_losses_fn=get_losses, 189 | **utils.setup_experiment( 190 | args=args, 191 | ddp=len(args.gpus) > 1, 192 | net_class=approach[args.approach], 193 | dataset_class=SceneCompletionDataset, 194 | split_file_path=args.file_path + "/ssc_split.pkl", 195 | ), 196 | **vars(args), 197 | ) 198 | -------------------------------------------------------------------------------- /train_vool.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Tuple, Union 2 | import numpy as np 3 | from dataset import ObjectLocalizationDataset 4 | from net import ( 5 | SemAbsVOOL, 6 | ClipSpatialVOOL, 7 | SemanticAwareVOOL, 8 | ) 9 | import utils 10 | from torch.nn.functional import binary_cross_entropy_with_logits 11 | import torch 12 | import pandas as pd 13 | 14 | 15 | def get_detailed_stats( 16 | prediction, 17 | gt_label, 18 | xyz_pts, 19 | scene_ids, 20 | target_obj_names, 21 | reference_obj_names, 22 | spatial_relation_names, 23 | scene_bounds, 24 | ignore_pts, 25 | detailed_analysis=False, 26 | eval_device="cuda", 27 | **kwargs, 28 | ): 29 | num_scenes, num_descs = 
gt_label.shape[:2] 30 | 31 | retvals = { 32 | "scene_id": np.array([[scene_id] * num_descs for scene_id in scene_ids]) 33 | .reshape(-1) 34 | .tolist(), 35 | "target_obj_name": np.array(target_obj_names).T.reshape(-1).tolist(), 36 | "reference_obj_name": np.array(reference_obj_names).T.reshape(-1).tolist(), 37 | "spatial_relation_name": np.array(spatial_relation_names) 38 | .T.reshape(-1) 39 | .tolist(), 40 | } 41 | retvals.update( 42 | { 43 | f"point_{k}": v 44 | for k, v in utils.prediction_analysis( 45 | prediction=prediction.to(eval_device), 46 | label=gt_label.to(eval_device), 47 | ignore=ignore_pts.to(eval_device), 48 | ).items() 49 | } 50 | ) 51 | num_desc_b = 10 52 | outputs = [] 53 | for i in np.arange(0, num_descs + num_desc_b + 1, num_desc_b): 54 | if np.prod(prediction[:, i : i + num_desc_b].shape) == 0: 55 | continue 56 | outputs.append( 57 | utils.voxelize_points( 58 | prediction=prediction[:, i : i + num_desc_b], 59 | label=gt_label[:, i : i + num_desc_b], 60 | xyz_pts=xyz_pts[:, i : i + num_desc_b], 61 | voxel_shape=(32, 32, 32), 62 | scene_bounds=scene_bounds, 63 | ignore_pts=ignore_pts[:, i : i + num_desc_b], 64 | device=eval_device, 65 | ) 66 | ) 67 | voxelized_pts = { 68 | k: torch.cat([output[k] for output in outputs], dim=1) 69 | for k in outputs[0].keys() 70 | } 71 | retvals.update( 72 | { 73 | "voxel32x32x32_" + k: v 74 | for k, v in utils.prediction_analysis( 75 | **{k: v.to(eval_device) for k, v in voxelized_pts.items()} 76 | ).items() 77 | } 78 | ) 79 | if detailed_analysis: 80 | outputs = [] 81 | for i in np.arange(0, num_descs + num_desc_b + 1, num_desc_b): 82 | if np.prod(prediction[:, i : i + num_desc_b].shape) == 0: 83 | continue 84 | outputs.append( 85 | utils.voxelize_points( 86 | prediction=prediction[:, i : i + num_desc_b], 87 | label=gt_label[:, i : i + num_desc_b], 88 | xyz_pts=xyz_pts[:, i : i + num_desc_b], 89 | voxel_shape=(64, 64, 64), 90 | scene_bounds=scene_bounds, 91 | ignore_pts=ignore_pts[:, i : i + num_desc_b], 92 | device=eval_device, 93 | ) 94 | ) 95 | voxelized_pts = { 96 | k: torch.cat([output[k] for output in outputs], dim=1) 97 | for k in outputs[0].keys() 98 | } 99 | retvals.update( 100 | { 101 | "voxel64x64x64_" + k: v 102 | for k, v in utils.prediction_analysis( 103 | **{k: v.to(eval_device) for k, v in voxelized_pts.items()} 104 | ).items() 105 | } 106 | ) 107 | 108 | for i, spatial_relation in enumerate( 109 | np.array(spatial_relation_names).T.reshape(-1) 110 | ): 111 | if spatial_relation == "[pad]": # skip padding classes 112 | for k in retvals.keys(): 113 | if "voxel" in k or "point" in k: 114 | retvals[k][i] = np.NAN 115 | return pd.DataFrame.from_dict(retvals) 116 | 117 | 118 | def get_losses( 119 | net, batch: dict, cutoffs=[-2.0], balance_positive_negative: bool = False, **kwargs 120 | ) -> Tuple[Dict[str, Union[float, torch.Tensor]], pd.DataFrame]: 121 | stats = {} 122 | batch_size, total_num_descs, num_pts = batch["output_label_pts"].shape 123 | if num_pts <= 500000: 124 | outputs = net(**batch) 125 | else: 126 | num_descs = 1 127 | # probably CUDA OOM 128 | outputs = torch.cat( 129 | [ 130 | net( 131 | **{ 132 | **batch, 133 | "input_target_saliency_pts": batch["input_target_saliency_pts"][ 134 | :, desc_i * num_descs : (desc_i + 1) * num_descs, ... 
135 | ], 136 | "input_reference_saliency_pts": batch[ 137 | "input_reference_saliency_pts" 138 | ][:, desc_i * num_descs : (desc_i + 1) * num_descs, ...], 139 | "input_description_saliency_pts": batch[ 140 | "input_description_saliency_pts" 141 | ][:, desc_i * num_descs : (desc_i + 1) * num_descs, ...], 142 | "output_xyz_pts": batch["output_xyz_pts"][ 143 | :, desc_i * num_descs : (desc_i + 1) * num_descs, ... 144 | ], 145 | "spatial_relation_name": ( 146 | np.array(batch["spatial_relation_name"]) 147 | .T[:, desc_i * num_descs : (desc_i + 1) * num_descs] 148 | .T 149 | ), 150 | } 151 | ) 152 | for desc_i in range(total_num_descs // num_descs + 1) 153 | if np.prod( 154 | batch["output_xyz_pts"][ 155 | :, desc_i * num_descs : (desc_i + 1) * num_descs, ... 156 | ].shape 157 | ) 158 | > 0 159 | ], 160 | dim=1, 161 | ) 162 | 163 | padding_mask = torch.from_numpy( 164 | np.array(batch["spatial_relation_name"]).T == "[pad]" 165 | ).bool() 166 | ignore_pts_mask = torch.zeros_like(outputs).bool() 167 | # ignore all padding labels 168 | ignore_pts_mask[padding_mask] = True 169 | # ignore all points out of bounds 170 | ignore_pts_mask = torch.logical_or(ignore_pts_mask, batch["out_of_bounds_pts"]) 171 | stats["loss"] = binary_cross_entropy_with_logits( 172 | outputs, 173 | batch["output_label_pts"], 174 | weight=utils.get_bce_weight( 175 | output_label_pts=batch["output_label_pts"], 176 | balance_positive_negative=balance_positive_negative, 177 | ), 178 | ) 179 | 180 | with torch.no_grad(): 181 | accuracy = ((outputs > 0.0).long() == batch["output_label_pts"]).float()[ 182 | ~ignore_pts_mask 183 | ] 184 | stats["accuracy"] = accuracy.mean() 185 | detailed_stats = [ 186 | get_detailed_stats( 187 | prediction=outputs > cutoff, 188 | gt_label=batch["output_label_pts"].bool(), 189 | xyz_pts=batch["output_xyz_pts"], 190 | ignore_pts=ignore_pts_mask, 191 | target_obj_names=batch["target_obj_name"], 192 | reference_obj_names=batch["reference_obj_name"], 193 | spatial_relation_names=batch["spatial_relation_name"], 194 | scene_ids=batch["scene_id"], 195 | eval_device=net.device, 196 | **kwargs, 197 | ) 198 | for cutoff in cutoffs 199 | ] 200 | for detailed_stat, cutoff in zip(detailed_stats, cutoffs): 201 | detailed_stat["cutoff"] = [cutoff] * len(detailed_stat) 202 | detailed_stats = pd.concat(detailed_stats) 203 | for k in detailed_stats.columns: 204 | if "iou" in k: 205 | stats[k] = detailed_stats[k].mean() 206 | return stats, detailed_stats 207 | 208 | 209 | approach = { 210 | "semantic_abstraction": SemAbsVOOL, 211 | "semantic_aware": SemanticAwareVOOL, 212 | "clip_spatial": ClipSpatialVOOL, 213 | } 214 | 215 | if __name__ == "__main__": 216 | parser = utils.config_parser() 217 | parser.add_argument("--log", type=str, required=True) 218 | parser.add_argument( 219 | "--approach", choices=approach.keys(), default="semantic_abstraction" 220 | ) 221 | args = parser.parse_args() 222 | if args.approach == "semantic_aware": 223 | args.network_inputs = ["rgb"] 224 | utils.train( 225 | get_losses_fn=get_losses, 226 | **utils.setup_experiment( 227 | args=args, 228 | net_class=approach[args.approach], 229 | dataset_class=ObjectLocalizationDataset, 230 | split_file_path=args.file_path + "/vool_split.pkl", 231 | ), 232 | **vars(args), 233 | ) 234 | --------------------------------------------------------------------------------
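Note on the evaluation pattern above: summarize.py and the get_losses routines in train_ovssc.py / train_vool.py score predictions at several logit cutoffs, keep the cutoff with the highest mean IoU, and then average IoU per room, label, or relation. A minimal, self-contained sketch of that cutoff-selection step is given below; the column names (cutoff, label, voxel32x32x32_iou) mirror the eval dataframes used in summarize.py, but the rows and numeric values are toy data invented purely for illustration, not taken from the repo's *_eval_stats.pkl files.

import pandas as pd

# Toy stand-in for an *_eval_stats.pkl dataframe: one row per (label, cutoff) pair.
# Values are made up for illustration only.
stats = pd.DataFrame(
    {
        "cutoff": [0.0, 0.0, 0.5, 0.5],
        "label": ["mug", "sofa", "mug", "sofa"],
        "voxel32x32x32_iou": [0.40, 0.60, 0.55, 0.35],
    }
)

# Pick the cutoff whose mean IoU is highest (summarize.py does this per approach).
best_cutoff = stats.groupby("cutoff")["voxel32x32x32_iou"].mean().idxmax()

# Then report the class-averaged IoU at that cutoff, in percent.
best = stats[stats.cutoff == best_cutoff]
print(best_cutoff, 100 * best.groupby("label")["voxel32x32x32_iou"].mean().mean())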