├── .gitignore ├── LICENSE ├── README.md ├── docker └── pytorch.dockerfile ├── pyproject.toml ├── requirements.txt └── src └── vqmpt ├── __init__.py ├── modules ├── __init__.py ├── autoregressive.py ├── context_encoder.py ├── decoder.py ├── encoder.py ├── env_encoder.py ├── quantizer.py └── sublayers.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # VSCode files 163 | .vscode/** 164 | .devcontainer.json -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2023, Jacob John Johnson 4 | 5 | Redistribution and use in source and binary forms, with or without 6 | modification, are permitted provided that the following conditions are met: 7 | 8 | 1. Redistributions of source code must retain the above copyright notice, this 9 | list of conditions and the following disclaimer. 10 | 11 | 2. Redistributions in binary form must reproduce the above copyright notice, 12 | this list of conditions and the following disclaimer in the documentation 13 | and/or other materials provided with the distribution. 14 | 15 | 3. Neither the name of the copyright holder nor the names of its 16 | contributors may be used to endorse or promote products derived from 17 | this software without specific prior written permission. 18 | 19 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 20 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 21 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 22 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 23 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 24 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 25 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 26 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 27 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vector Quantized - Motion Planning Transformers 2 | 3 | This github repo contains the models, and helper functions for generating sampling distributions using [VQ-MPT](https://sites.google.com/ucsd.edu/vq-mpt/home). 4 | 5 | ## Installing the package 6 | 7 | To install the package, clone this repo to your local machine. 
8 | 9 | ``` 10 | git clone https://github.com/jacobjj/vqmpt.git 11 | ``` 12 | 13 | To install the package, go to the cloned repo and run the following command. 14 | 15 | ``` 16 | pip install -e . 17 | ``` 18 | 19 | ### Running inside a container 20 | We provide dockerfiles for running our models, but you need to have our [base](https://drive.google.com/file/d/1DFC5nKoPTKF6ASZHnqi5FA8NF54LKD7A/view?usp=sharing) image downloaded and [loaded](https://docs.docker.com/engine/reference/commandline/load/) on your system before you build the container. Afterwards, either clone or attach this repo inside the container to get started with using our models. 21 | 22 | ## Loading models 23 | 24 | You can get the pre-trained models for the Panda robot from here - [Panda Models](https://drive.google.com/file/d/1B0KVBxYBi0fCQcvagponF6j_2TikZfN7/view?usp=sharing) 25 | 26 | To load the models, use the following: 27 | 28 | ``` 29 | from vqmpt import utils 30 | import torch 31 | 32 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 33 | quantizer_model, decoder_model, context_env_encoder, ar_model = utils.get_inference_models( 34 | decoder_model_folder, 35 | ar_model_folder, 36 | device, 37 | n_e=2048, 38 | e_dim=8, 39 | ) 40 | ``` 41 | 42 | ## Getting distributions 43 | To get the distribution, use the `get_search_dist` function in `utils.py`. Normalize the start and goal configurations to [0, 1] and stack them together to form a 2 x n_dim array. 44 | 45 | ``` 46 | search_dist_mu, search_dist_sigma = utils.get_search_dist( 47 | norm_start_n_goal, 48 | depth_points, 49 | context_encoder, 50 | decoder_model, 51 | ar_model, 52 | quantizer_model, 53 | num_keys=2048, 54 | device=device, 55 | ) 56 | ``` 57 | 58 | ## Sampling from the distribution 59 | 60 | If you are using the distribution with OMPL, here is an example of a custom sampling function (besides OMPL, it also uses `numpy`, `scipy.stats`, and `torch.distributions`). 61 | 62 | ``` 63 | from ompl import base as ob 64 | 65 | class StateSamplerRegion(ob.StateSampler): 66 | '''A class to sample robot joint configurations from a given distribution. 67 | ''' 68 | def __init__(self, space, qMin=None, qMax=None, dist_mu=None, dist_sigma=None): 69 | ''' 70 | If dist_mu is None, then set the sampler as a uniform sampler.
71 | :param space: an object of type ompl.base.Space 72 | :param qMin: np.array of minimum joint bound 73 | :param qMax: np.array of maximum joint bound 74 | :param region: np.array of points to sample from 75 | ''' 76 | super(StateSamplerRegion, self).__init__(space) 77 | self.name_ ='region' 78 | self.q_min = qMin 79 | self.q_max = qMax 80 | if dist_mu is None: 81 | self.X = None 82 | self.U = stats.uniform(np.zeros_like(qMin), np.ones_like(qMax)) 83 | else: 84 | self.seq_num = dist_mu.shape[0] 85 | self.X = MultivariateNormal(dist_mu, dist_sigma) 86 | 87 | 88 | def get_random_samples(self): 89 | '''Generates a random sample from the list of points 90 | ''' 91 | index = 0 92 | random_samples = np.random.permutation(self.X.sample()*(self.q_max-self.q_min)+self.q_min) 93 | 94 | while True: 95 | yield random_samples[index, :] 96 | index += 1 97 | if index==self.seq_num: 98 | random_samples = np.random.permutation(self.X.sample()*(self.q_max-self.q_min)+self.q_min) 99 | index = 0 100 | 101 | def sampleUniform(self, state): 102 | '''Generate a sample from uniform distribution or key-points 103 | :param state: ompl.base.Space object 104 | ''' 105 | if self.X is None: 106 | sample_pos = ((self.q_max-self.q_min)*self.U.rvs()+self.q_min)[0] 107 | else: 108 | sample_pos = next(self.get_random_samples()) 109 | for i, val in enumerate(sample_pos): 110 | state[i] = float(val) 111 | return True 112 | ``` 113 | -------------------------------------------------------------------------------- /docker/pytorch.dockerfile: -------------------------------------------------------------------------------- 1 | FROM ompl:focal-1.6-mod AS BUILDER 2 | 3 | FROM nvcr.io/nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 as BASE 4 | 5 | COPY --from=BUILDER /usr/local/include/ompl-1.6 /usr/include/ompl-1.6 6 | COPY --from=BUILDER /usr/local/lib/libompl* /usr/local/lib/ 7 | COPY --from=BUILDER /usr/lib/libtriangle* /usr/lib/ 8 | COPY --from=BUILDER /usr/local/bin/ompl_benchmark_statistics.py /usr/bin/ompl_benchmark_statistics.py 9 | COPY --from=BUILDER /usr/lib/python3/dist-packages/ompl /usr/lib/python3/dist-packages/ompl 10 | 11 | ENV DEBIAN_FRONTEND=noninteractive 12 | 13 | # ----- Files required for OMPL ---------- 14 | RUN apt-get update && apt-get install -y \ 15 | libboost-serialization-dev \ 16 | libboost-filesystem-dev \ 17 | libboost-numpy-dev \ 18 | libboost-system-dev \ 19 | libboost-program-options-dev \ 20 | libboost-python-dev \ 21 | libboost-test-dev \ 22 | libflann-dev \ 23 | libode-dev \ 24 | libeigen3-dev \ 25 | python3-pip\ 26 | && rm -rf /var/lib/apt/lists/* 27 | 28 | RUN apt-get update && apt-get install -y \ 29 | pypy3 \ 30 | wget && \ 31 | # Install spot 32 | wget -O /etc/apt/trusted.gpg.d/lrde.gpg https://www.lrde.epita.fr/repo/debian.gpg && \ 33 | echo 'deb http://www.lrde.epita.fr/repo/debian/ stable/' >> /etc/apt/sources.list && \ 34 | apt-get update && \ 35 | apt-get install -y libspot-dev && \ 36 | pip3 install pygccxml pyplusplus 37 | 38 | RUN python3 -m pip install -U pip 39 | 40 | # ----------------------------------------- 41 | RUN pip install torch \ 42 | torchvision \ 43 | torchaudio --index-url https://download.pytorch.org/whl/cu118 44 | 45 | # Libgl1 used for open3d 46 | RUN apt-get update && apt-get install -y \ 47 | git \ 48 | libgl1 \ 49 | && rm -rf /var/lib/apt/lists/* 50 | 51 | RUN pip install einops \ 52 | open3d 53 | 54 | # Install torch_geometric 55 | RUN pip install torch_geometric 56 | 57 | # Install additional dependencies 58 | RUN pip install pyg_lib \ 59 | torch_scatter \ 60 
| torch_sparse \ 61 | torch_cluster \ 62 | torch_spline_conv \ 63 | -f https://data.pyg.org/whl/torch-2.0.0+cu118.html -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "vqmpt" 3 | version = "0.0.1" 4 | authors = [ 5 | { name="Jacob J. Johnson", email="jjj025@ucsd.edu" }, 6 | ] 7 | description = "A package defining VQ-MPT models and sampling functions" 8 | readme = "README.md" 9 | requires-python = ">=3.8" 10 | classifiers = [ 11 | "Programming Language :: Python :: 3", 12 | "License :: OSI Approved :: BSD License", 13 | "Operating System :: OS Independent", 14 | ] 15 | [project.urls] 16 | "Homepage" = "https://github.com/pypa/sampleproject" 17 | "Bug Tracker" = "https://github.com/pypa/sampleproject/issues" 18 | [build-system] 19 | requires = ["setuptools"] 20 | build-backend = "setuptools.build_meta" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | addict==2.4.0 2 | ansi2html==1.8.0 3 | asttokens==2.4.1 4 | attrs==23.1.0 5 | backcall==0.2.0 6 | blinker==1.7.0 7 | certifi==2022.12.7 8 | charset-normalizer==2.1.1 9 | click==8.1.7 10 | cmake==3.25.0 11 | comm==0.1.4 12 | ConfigArgParse==1.7 13 | contourpy==1.1.1 14 | cycler==0.12.1 15 | dash==2.14.1 16 | dash-core-components==2.0.0 17 | dash-html-components==2.0.0 18 | dash-table==5.0.0 19 | decorator==5.1.1 20 | einops==0.7.0 21 | executing==2.0.1 22 | fastjsonschema==2.18.1 23 | filelock==3.9.0 24 | Flask==3.0.0 25 | fonttools==4.44.0 26 | idna==3.4 27 | importlib-metadata==6.8.0 28 | importlib-resources==6.1.0 29 | ipython==8.12.3 30 | ipywidgets==8.1.1 31 | itsdangerous==2.1.2 32 | jedi==0.19.1 33 | Jinja2==3.1.2 34 | joblib==1.3.2 35 | jsonschema==4.19.2 36 | jsonschema-specifications==2023.7.1 37 | jupyter_core==5.5.0 38 | jupyterlab-widgets==3.0.9 39 | kiwisolver==1.4.5 40 | lit==15.0.7 41 | MarkupSafe==2.1.2 42 | matplotlib==3.7.3 43 | matplotlib-inline==0.1.6 44 | mpmath==1.3.0 45 | nbformat==5.7.0 46 | nest-asyncio==1.5.8 47 | networkx==3.0 48 | numpy==1.24.1 49 | open3d==0.17.0 50 | packaging==23.2 51 | pandas==2.0.3 52 | parso==0.8.3 53 | pexpect==4.8.0 54 | pickleshare==0.7.5 55 | Pillow==9.3.0 56 | pkgutil_resolve_name==1.3.10 57 | platformdirs==3.11.0 58 | plotly==5.18.0 59 | prompt-toolkit==3.0.39 60 | psutil==5.9.6 61 | ptyprocess==0.7.0 62 | pure-eval==0.2.2 63 | pyg-lib==0.3.0+pt20cu118 64 | pygccxml==2.4.0 65 | Pygments==2.16.1 66 | pyparsing==3.1.1 67 | pyplusplus==1.8.5 68 | pyquaternion==0.9.9 69 | python-dateutil==2.8.2 70 | pytz==2023.3.post1 71 | PyYAML==6.0.1 72 | referencing==0.30.2 73 | requests==2.28.1 74 | retrying==1.3.4 75 | rpds-py==0.10.6 76 | scikit-learn==1.3.2 77 | scipy==1.10.1 78 | six==1.16.0 79 | stack-data==0.6.3 80 | sympy==1.12 81 | tenacity==8.2.3 82 | threadpoolctl==3.2.0 83 | torch==2.0.1+cu118 84 | torch-cluster==1.6.3+pt20cu118 85 | torch-scatter==2.1.2+pt20cu118 86 | torch-sparse==0.6.18+pt20cu118 87 | torch-spline-conv==1.2.2+pt20cu118 88 | torch_geometric==2.4.0 89 | torchaudio==2.0.2+cu118 90 | torchvision==0.15.2+cu118 91 | tqdm==4.66.1 92 | traitlets==5.13.0 93 | triton==2.0.0 94 | typing_extensions==4.4.0 95 | tzdata==2023.3 96 | urllib3==1.26.13 97 | wcwidth==0.2.9 98 | Werkzeug==3.0.1 99 | widgetsnbextension==4.0.9 100 | zipp==3.17.0 101 |
-------------------------------------------------------------------------------- /src/vqmpt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ucsdarclab/vqmpt/9e022f8209e9c78fa177446b3ba978bfb35ad033/src/vqmpt/__init__.py -------------------------------------------------------------------------------- /src/vqmpt/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ucsdarclab/vqmpt/9e022f8209e9c78fa177446b3ba978bfb35ad033/src/vqmpt/modules/__init__.py -------------------------------------------------------------------------------- /src/vqmpt/modules/autoregressive.py: -------------------------------------------------------------------------------- 1 | """ The autoregressive model for predicting steps. 2 | """ 3 | import torch.nn as nn 4 | import torch_geometric.utils as tg_utils 5 | 6 | from . import encoder 7 | from . import env_encoder 8 | from . import context_encoder 9 | 10 | 11 | class AutoRegressiveModel(nn.Module): 12 | """Get the encoder input and convert it to set of logits values.""" 13 | 14 | def __init__( 15 | self, d_k, d_v, d_model, d_inner, n_layers, n_heads, num_keys, dropout=0.1 16 | ): 17 | """ 18 | :param d_k: dimension of the key. 19 | :param d_v: dimension of the value. 20 | :param d_inner: dimension of the latent vector. 21 | :param d_model: dimension of the latent vector. 22 | :param n_layers: Number of self-attention layers. 23 | :param n_heads: number of heads for self-attention layers. 24 | :param dropout: dropout for fully connected layer. 25 | """ 26 | super().__init__() 27 | 28 | self.layer_stack = nn.ModuleList( 29 | [ 30 | encoder.EncoderLayerPreNorm( 31 | d_model, d_inner, n_heads, d_k, d_v, dropout 32 | ) 33 | for _ in range(n_layers) 34 | ] 35 | ) 36 | 37 | # Add layer norm to the final layer 38 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 39 | 40 | # Implement logit function. 41 | self.class_pred = nn.Linear(d_model, num_keys) 42 | 43 | def forward(self, enc_output, slf_attn_mask=None): 44 | """ 45 | The forward module: 46 | :param enc_input: the i/p to the encoder. 47 | :param slf_attn_mask: mask for the self-attn. 48 | """ 49 | for attn_layer in self.layer_stack: 50 | enc_output = attn_layer(enc_output, slf_attn_mask) 51 | 52 | # Add layer normalization to the final layer 53 | enc_output = self.layer_norm(enc_output) 54 | 55 | # pass through logit function. 56 | enc_output = self.class_pred(enc_output) 57 | return enc_output 58 | 59 | 60 | class EnvContextCrossAttModel(nn.Module): 61 | """Given the context and environment model, return the cross attention model.""" 62 | 63 | def __init__(self, env_params, context_params, robot="2D"): 64 | """ 65 | :param env_params: A dictionary with values for the following keys for the envirnoment encoder 66 | {n_layers, n_heads, d_k, d_v, d_model, d_inner, dropout, n_position} 67 | :param context_params: A dict with values for the following keys for the context encoder. 68 | {} 69 | """ 70 | super().__init__() 71 | 72 | # Define Environment model. 73 | if robot == "2D": 74 | self.env_encoder = env_encoder.EnvEncoder(**env_params) # type: ignore 75 | 76 | if robot == "6D" or robot == "14D" or robot == "7D": 77 | self.env_encoder = env_encoder.FeatureExtractor(**env_params) # type: ignore 78 | 79 | self.robot = robot 80 | 81 | # Translate context embedding and do cross-attention. 
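# context_params is expected to carry the ContextEncoder constructor
# arguments (d_context, d_k, d_v, d_model, d_inner, n_layers, n_heads, and
# optionally dropout); the start/goal context is projected to d_model and
# cross-attends over the environment encoding (see context_encoder.ContextEncoder).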
82 | self.context_encoder = context_encoder.ContextEncoder(**context_params) # type: ignore 83 | 84 | def forward(self, env_input, start_goal_input): 85 | cross_encoding_output = None 86 | # Pass the input through the encoder. 87 | if self.robot == "2D": 88 | env_encoding_output = self.env_encoder(env_input) 89 | # Take the cross attention model. 90 | (cross_encoding_output,) = self.context_encoder( 91 | start_goal_input, env_encoding_output 92 | ) 93 | 94 | if self.robot == "6D" or self.robot == "14D" or self.robot == "7D": 95 | (h, _, batch), _ = self.env_encoder(env_input) 96 | env_encoding_output, dec_mask = tg_utils.to_dense_batch(h, batch) 97 | # Take the cross attention model 98 | (cross_encoding_output,) = self.context_encoder( 99 | start_goal_input, env_encoding_output, env_encoding_mask=dec_mask 100 | ) 101 | 102 | return cross_encoding_output 103 | -------------------------------------------------------------------------------- /src/vqmpt/modules/context_encoder.py: -------------------------------------------------------------------------------- 1 | """ Define the context encoder for the network. 2 | """ 3 | from torch import nn 4 | 5 | from . import decoder 6 | 7 | 8 | class ContextEncoder(nn.Module): 9 | """Converting s/g points to planning context.""" 10 | 11 | def __init__( 12 | self, d_context, d_k, d_v, d_model, d_inner, n_layers, n_heads, dropout=0.1 13 | ): 14 | """ 15 | :param d_context: input size of the context map. 16 | :param d_k: dimension of the key. 17 | :param d_v: dimension of the value. 18 | :param d_model: dimension of the latent vector. 19 | :param d_inner: dimension of fully connected layer. 20 | :param n_layers: number of self-attention layers. 21 | :param n_heads: number of heads for self-attention layers. 22 | :param dropout: dropout for fully connected layer. 23 | """ 24 | super().__init__() 25 | 26 | # Convert the context to latent embedding 27 | self.to_latent_embedding = nn.Sequential( 28 | nn.Linear(d_context, d_inner), 29 | nn.ReLU(), 30 | nn.Linear(d_inner, d_model), 31 | nn.Dropout(dropout), 32 | nn.LayerNorm(d_model), 33 | ) 34 | 35 | self.layer_stack = nn.ModuleList( 36 | [ 37 | decoder.DecoderLayer( 38 | d_model, d_inner, n_heads, d_k, d_v, dropout=dropout 39 | ) 40 | for _ in range(n_layers) 41 | ] 42 | ) 43 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 44 | 45 | def forward(self, context, env_encoding, env_encoding_mask=None): 46 | """ 47 | :param context: the s/g pairs or textk embeddings. 48 | :param env_encoding: Environment encoding from PC++ or FCN layers 49 | """ 50 | # pass the context through the feed-forward network. 51 | context_embedding = self.to_latent_embedding(context) 52 | 53 | # Pass the environment embedding through the decoder layer. 54 | for cross_layer in self.layer_stack: 55 | context_embedding = cross_layer( 56 | context_embedding, env_encoding, dec_enc_attn_mask=env_encoding_mask 57 | ) 58 | 59 | context_embedding = self.layer_norm(context_embedding) 60 | return (context_embedding,) 61 | -------------------------------------------------------------------------------- /src/vqmpt/modules/decoder.py: -------------------------------------------------------------------------------- 1 | """ VAE style decoder. 2 | """ 3 | from torch import nn 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | from . 
import sublayers 8 | 9 | 10 | class DecoderLayer(nn.Module): 11 | """Compose with three layers""" 12 | 13 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 14 | """ 15 | Initialize the Layer 16 | :param d_model: Dimension of input/output this layer. 17 | :param d_inner: Dimension of hidden layer of the position wise FFN 18 | :param n_head: Number of self-attention modules. 19 | :param d_k: Dimension of each Key. 20 | :param d_v: Dimension of each Value. 21 | :param dropout: Argument to the dropout layer. 22 | """ 23 | super().__init__() 24 | self.enc_attn = sublayers.MultiHeadAttentionPreNorm( 25 | n_head, d_model, d_k, d_v, dropout=dropout 26 | ) 27 | self.pos_ffn = sublayers.PositionwiseFeedForwardPreNorm( 28 | d_model, d_inner, dropout=dropout 29 | ) 30 | 31 | def forward(self, dec_input, enc_output, dec_enc_attn_mask=None): 32 | """ 33 | Callback function 34 | :param dec_input: 35 | :param enc_output: 36 | :param slf_attn_mask: 37 | :param dec_enc_attn_mask: 38 | """ 39 | dec_output = self.enc_attn( 40 | dec_input, enc_output, enc_output, dec_mask=dec_enc_attn_mask 41 | ) 42 | dec_output = self.pos_ffn(dec_output) 43 | return dec_output 44 | 45 | 46 | class DecoderPreNormGeneral(nn.Module): 47 | """Decoder that takes the latent encoding and generates joint samples. 48 | with non-zero cross correlation variables. 49 | """ 50 | 51 | def __init__(self, e_dim, h_dim, c_space_dim, dropout=0.5): 52 | """ 53 | :param e_dim: Dimension of the dictionary vectors. 54 | :param h_dim: Dimension of the feedforward networks hidden vector. 55 | :param c_space_dim: Dimension of the c-space. 56 | :param dropout: Dropout value for the fullyconnected layer. 57 | """ 58 | super().__init__() 59 | self.pos_ffn = sublayers.PositionwiseFeedForward(e_dim, h_dim, dropout) 60 | 61 | # Layers for returning mean and variance 62 | self.mu = nn.Sequential(nn.Linear(e_dim, c_space_dim), nn.Tanh()) 63 | self.diag = nn.Sequential( 64 | nn.Linear(e_dim, c_space_dim), 65 | ) 66 | self.l = nn.Sequential( 67 | nn.Linear(e_dim, int(c_space_dim * (c_space_dim - 1) / 2)) 68 | ) 69 | self.register_buffer( 70 | "l_index", torch.tril_indices(c_space_dim, c_space_dim, offset=-1) 71 | ) 72 | self.c_space_dim = c_space_dim 73 | 74 | def forward(self, z_q): 75 | """Returns the decoded mean and variance. 76 | :param z_q: Latent encoding vectors. 77 | :returns tuple: mean and diagonal variance vectors. 78 | """ 79 | z_q = self.pos_ffn(z_q) 80 | D = torch.diag_embed(F.softplus(self.diag(z_q))) 81 | L_linear = self.l(z_q) 82 | L = torch.diag_embed( 83 | torch.ones( 84 | (z_q.shape[0], z_q.shape[1], self.c_space_dim), 85 | device=z_q.device 86 | ) 87 | ) 88 | L[:, :, self.l_index[0], self.l_index[1]] = L_linear # type: ignore 89 | covar = torch.matmul(torch.matmul(L, D), L.transpose(2, 3)) 90 | return (self.mu(z_q) + 1) / 2, covar 91 | 92 | def get_sigma_sqrroot(self, z_q): 93 | """ 94 | Returns the square root of the covariance matrix. 95 | :param z_q: a projected dictionary tensor of dimension d_model. 
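:returns torch.Tensor: L @ sqrt(D), a square root of the covariance matrix L D L^T computed in forward().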
96 | """ 97 | z_q = self.pos_ffn(z_q) 98 | D_sqrroot = torch.diag_embed(torch.sqrt(F.softplus(self.diag(z_q)))) 99 | L_linear = self.l(z_q) 100 | L = torch.diag_embed( 101 | torch.ones( 102 | (z_q.shape[0], z_q.shape[1], self.c_space_dim), 103 | device=z_q.device 104 | ) 105 | ) 106 | L[:, :, self.l_index[0], self.l_index[1]] = L_linear # type: ignore 107 | return L @ D_sqrroot 108 | 109 | def get_mean(self, z_q): 110 | """ 111 | Returns the mean given projected latent-vector dictonary 112 | :param z_q: A projected dictionary tensor of dimension d_model 113 | """ 114 | z_q = self.pos_ffn(z_q) 115 | return (self.mu(z_q) + 1) / 2 116 | -------------------------------------------------------------------------------- /src/vqmpt/modules/encoder.py: -------------------------------------------------------------------------------- 1 | """Define the Layers 2 | Derived from - 3 | https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/132907dd272e2cc92e3c10e6c4e783a87ff8893d/transformer/Layers.py 4 | """ 5 | 6 | import numpy as np 7 | 8 | from torch import nn 9 | import torch 10 | import torch.utils.checkpoint 11 | 12 | from . import sublayers 13 | 14 | 15 | class EncoderLayer(nn.Module): 16 | """Single Encoder layer, that consists of a MHA layers and positiion-wise 17 | feedforward layer. 18 | """ 19 | 20 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 21 | """ 22 | Initialize the module. 23 | :param d_model: Dimension of input/output of this layer 24 | :param d_inner: Dimension of the hidden layer of hte position-wise feedforward layer 25 | :param n_head: Number of self-attention modules 26 | :param d_k: Dimension of each Key 27 | :param d_v: Dimension of each Value 28 | :param dropout: Argument to the dropout layer. 29 | """ 30 | super().__init__() 31 | self.slf_attn = sublayers.MultiHeadAttention( 32 | n_head, d_model, d_k, d_v, dropout=dropout 33 | ) 34 | self.pos_ffn = sublayers.PositionwiseFeedForward( 35 | d_model, d_inner, dropout=dropout 36 | ) 37 | 38 | def forward(self, enc_input, slf_attn_mask=None): 39 | """ 40 | The forward module: 41 | :param enc_input: The input to the encoder. 42 | :param slf_attn_mask: 43 | """ 44 | 45 | # With Gradient Checking 46 | enc_output = torch.utils.checkpoint.checkpoint( 47 | self.slf_attn, enc_input, enc_input, enc_input, slf_attn_mask 48 | ) 49 | 50 | enc_output = self.pos_ffn(enc_output) 51 | return enc_output 52 | 53 | 54 | class EncoderLayerPreNorm(nn.Module): 55 | """Single Encoder layer, that consists of a MHA layers and positiion-wise 56 | feedforward layer. 57 | """ 58 | 59 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 60 | """ 61 | Initialize the module. 62 | :param d_model: Dimension of input/output of this layer 63 | :param d_inner: Dimension of the hidden layer of hte position-wise 64 | feedforward layer 65 | :param n_head: Number of self-attention modules 66 | :param d_k: Dimension of each Key 67 | :param d_v: Dimension of each Value 68 | :param dropout: Argument to the dropout layer. 69 | """ 70 | super(EncoderLayerPreNorm, self).__init__() 71 | self.slf_attn = sublayers.MultiHeadAttentionPreNorm( 72 | n_head, d_model, d_k, d_v, dropout=dropout 73 | ) 74 | self.pos_ffn = sublayers.PositionwiseFeedForwardPreNorm( 75 | d_model, d_inner, dropout=dropout 76 | ) 77 | 78 | def forward(self, enc_input, slf_attn_mask=None): 79 | """ 80 | The forward module: 81 | :param enc_input: The input to the encoder. 82 | :param slf_attn_mask: mask for self-attn. 
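:returns torch.Tensor: the encoded output, with the same shape as enc_input.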
83 | """ 84 | # Without gradient Checking 85 | enc_output = self.slf_attn(enc_input, enc_input, enc_input, mask=slf_attn_mask) 86 | enc_output = self.pos_ffn(enc_output) 87 | return enc_output 88 | 89 | 90 | class PositionalEncoding(nn.Module): 91 | """Positional encoding""" 92 | 93 | def __init__(self, d_hid, n_position): 94 | """ 95 | Intialize the Encoder. 96 | :param d_hid: Dimesion of the attention features. 97 | :param n_position: Number of positions to consider. 98 | """ 99 | super().__init__() 100 | self.n_pos_sqrt = n_position 101 | 102 | # Not parameters 103 | self.register_buffer( 104 | "pos_table", self._get_sinusoid_encoding_table(n_position, d_hid) 105 | ) 106 | 107 | def _get_sinusoid_encoding_table(self, n_position, d_hid): 108 | """ 109 | Sinusoid position encoding table. 110 | :param n_position: 111 | :param d_hid: 112 | :returns 113 | """ 114 | 115 | def get_position_angle_vec(position): 116 | return [ 117 | position / np.power(10000, 2 * (hid_j // 2) / d_hid) 118 | for hid_j in range(d_hid) 119 | ] 120 | 121 | sinusoid_table = np.array( 122 | [get_position_angle_vec(pos_i) for pos_i in range(n_position)] 123 | ) 124 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 125 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 126 | return torch.Tensor(sinusoid_table[None, :]) 127 | 128 | def forward(self, x): 129 | """ 130 | Callback function 131 | :param x: 132 | """ 133 | pos_enc = self.pos_table[:, : x.size(1)].clone().detach() # type: ignore 134 | return x + pos_enc 135 | 136 | 137 | class EncoderPreNorm(nn.Module): 138 | """The encoder of the planner.""" 139 | 140 | def __init__( 141 | self, 142 | n_layers, 143 | n_heads, 144 | d_k, 145 | d_v, 146 | d_model, 147 | d_inner, 148 | c_space_dim, 149 | dropout, 150 | n_position, 151 | ): 152 | """ 153 | Intialize the encoder. 154 | :param n_layers: Number of layers of attention and fully connected 155 | layer. 156 | :param n_heads: Number of self attention modules. 157 | :param d_k: Dimension of each Key. 158 | :param d_v: Dimension of each Value. 159 | :param d_model: Dimension of input/output of encoder layer. 160 | :param d_inner: Dimension of the hidden layers of position wise FFN 161 | :param c_space_dim: Dimension of the c-space 162 | :param dropout: The value to the dropout argument. 163 | :param n_position: Total number of patches the model can handle. 164 | """ 165 | super().__init__() 166 | 167 | # Embedding 168 | self.to_embedding = nn.Sequential( 169 | nn.Linear(c_space_dim, d_model), 170 | nn.ReLU(), 171 | nn.Linear(d_model, d_model), 172 | ) 173 | # Position Encoding. 174 | # NOTE: Current setup for adding position encoding after patch 175 | # Embedding. 176 | self.position_enc = PositionalEncoding(d_model, n_position=n_position) 177 | 178 | self.dropout = nn.Dropout(p=dropout) 179 | self.layer_stack = nn.ModuleList( 180 | [ 181 | EncoderLayerPreNorm( 182 | d_model, d_inner, n_heads, d_k, d_v, dropout=dropout 183 | ) 184 | for _ in range(n_layers) 185 | ] 186 | ) 187 | 188 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 189 | 190 | def forward(self, input_sequence, returns_attns=False): 191 | """ 192 | The input of the Encoder should be of dim (b, c, h, w). 193 | :param input_sequence: Sequence of the trajectories. 194 | :param returns_attns: If True, the model returns slf_attns at each 195 | layer 196 | """ 197 | enc_slf_attn_list = [] 198 | # Get latent embedding. 
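# input_sequence has shape (batch, seq_len, c_space_dim); the two-layer MLP
# defined in to_embedding lifts it to (batch, seq_len, d_model) before the
# position encoding and the self-attention stack are applied.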
199 | enc_output = self.to_embedding(input_sequence) 200 | enc_output = self.layer_norm(enc_output) 201 | 202 | # Add position encoding. 203 | enc_output = self.position_enc(enc_output) 204 | 205 | enc_output = self.dropout(enc_output) 206 | enc_output = self.layer_norm(enc_output) 207 | 208 | for enc_layer in self.layer_stack: 209 | enc_output = enc_layer(enc_output, slf_attn_mask=None) 210 | 211 | # Final layer requires a layer-norm 212 | enc_output = self.layer_norm(enc_output) 213 | 214 | if returns_attns: 215 | return enc_output, enc_slf_attn_list 216 | return (enc_output,) 217 | -------------------------------------------------------------------------------- /src/vqmpt/modules/env_encoder.py: -------------------------------------------------------------------------------- 1 | """ Defining layers for converting maps to latent encodings. 2 | """ 3 | import torch 4 | from torch import nn 5 | import numpy as np 6 | from einops.layers.torch import Rearrange 7 | from einops import rearrange 8 | import torch_geometric.nn as tgnn 9 | 10 | 11 | class PositionalEncoding(nn.Module): 12 | """Positional encoding""" 13 | 14 | def __init__(self, d_hid, n_position): 15 | """ 16 | Intialize the Encoder. 17 | :param d_hid: Dimesion of the attention features. 18 | :param n_position: Number of positions to consider. 19 | :param train_shape: The 2D shape of the training model. 20 | """ 21 | super(PositionalEncoding, self).__init__() 22 | self.n_pos_sqrt = int(np.sqrt(n_position)) 23 | # Not a parameter 24 | self.register_buffer("hashIndex", self._get_hash_table(n_position)) 25 | self.register_buffer( 26 | "pos_table", self._get_sinusoid_encoding_table(n_position, d_hid) 27 | ) 28 | 29 | def _get_hash_table(self, n_position): 30 | """ 31 | A simple table converting 1D indexes to 2D grid. 32 | :param n_position: The number of positions on the grid. 33 | """ 34 | return rearrange( 35 | torch.arange(n_position), 36 | "(h w) -> h w", 37 | h=int(np.sqrt(n_position)), 38 | w=int(np.sqrt(n_position)), 39 | ) 40 | 41 | def _get_sinusoid_encoding_table(self, n_position, d_hid): 42 | """ 43 | Sinusoid position encoding table. 44 | :param n_position: 45 | :param d_hid: 46 | :returns 47 | """ 48 | # TODO: make it with torch instead of numpy 49 | def get_position_angle_vec(position): 50 | return [ 51 | position / np.power(10000, 2 * (hid_j // 2) / d_hid) 52 | for hid_j in range(d_hid) 53 | ] 54 | 55 | sinusoid_table = np.array( 56 | [get_position_angle_vec(pos_i) for pos_i in range(n_position)] 57 | ) 58 | sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i 59 | sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 60 | return torch.FloatTensor(sinusoid_table[None, :]) 61 | 62 | def forward(self, x, conv_shape): 63 | """ 64 | Callback function 65 | :param x: 66 | """ 67 | selectIndex = rearrange( 68 | self.hashIndex[: conv_shape[0], : conv_shape[1]], "h w -> (h w)" # type: ignore 69 | ) 70 | return x + torch.index_select(self.pos_table, dim=1, index=selectIndex) # type: ignore 71 | 72 | 73 | # Encoder for environment 74 | class EnvEncoder(nn.Module): 75 | """The environment encoder of the planner.""" 76 | 77 | def __init__(self, d_model, dropout, n_position): 78 | """ 79 | Intialize the encoder. 80 | :param n_layers: Number of layers of attention and fully connected layer. 81 | :param n_heads: Number of self attention modules. 82 | :param d_k: Dimension of each Key. 83 | :param d_v: Dimension of each Value. 84 | :param d_model: Dimension of input/output of encoder layer. 
85 | :param d_inner: Dimension of the hidden layers of position wise FFN 86 | :param dropout: The value to the dropout argument. 87 | :param n_position: Total number of patches the model can handle. 88 | :param train_shape: The shape of the output of the patch encodings. 89 | """ 90 | super().__init__() 91 | # Convert the image to and input embedding. 92 | # NOTE: This is one place where we can add convolution networks. 93 | # Convert the image to linear model 94 | 95 | # NOTE: Padding of 3 is added to the final layer to ensure that 96 | # the output of the network has receptive field across the entire map. 97 | # NOTE: pytorch doesn't have a good way to ensure automatic padding. 98 | # This allows only for a select few map sizes to be solved using this 99 | # method. 100 | self.to_patch_embedding = nn.Sequential( 101 | nn.Conv2d(1, 6, kernel_size=5), 102 | nn.MaxPool2d(kernel_size=2), 103 | nn.ReLU(), 104 | nn.Conv2d(6, 16, kernel_size=5), 105 | nn.MaxPool2d(kernel_size=2), 106 | nn.ReLU(), 107 | nn.Conv2d(16, d_model, kernel_size=5, stride=5, padding=3), 108 | ) 109 | 110 | self.reorder_dims = Rearrange("b c h w -> b (h w) c") 111 | # Position Encoding. 112 | # NOTE: Current setup for adding position encoding after patch 113 | # Embedding. 114 | self.position_enc = PositionalEncoding(d_model, n_position=n_position) 115 | 116 | self.dropout = nn.Dropout(p=dropout) 117 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 118 | 119 | def forward(self, input_map, returns_attns=False): 120 | """ 121 | The input of the Encoder should be of dim (b, c, h, w). 122 | :param input_map: The input map for planning. 123 | :param returns_attns: If True, the model returns slf_attns at 124 | each layer 125 | """ 126 | enc_output = self.to_patch_embedding(input_map) 127 | conv_map_shape = enc_output.shape[-2:] 128 | enc_output = self.reorder_dims(enc_output) 129 | 130 | enc_output = self.position_enc(enc_output, conv_map_shape) 131 | 132 | enc_output = self.dropout(enc_output) 133 | enc_output = self.layer_norm(enc_output) 134 | return enc_output 135 | 136 | 137 | # Point cloud encoder 138 | class SAModule(nn.Module): 139 | """The set abstraction layer""" 140 | 141 | def __init__(self, ratio, r, channels): 142 | """Initialization of the model. 143 | :param ratio: Amount of points dropped 144 | :param r: the radius of grouping 145 | :param channels: Shared weights for each latent vectors 146 | """ 147 | super(SAModule, self).__init__() 148 | self.ratio = ratio 149 | self.r = r 150 | mlp = nn.Sequential( 151 | *[ 152 | nn.Sequential( 153 | nn.Linear(c, channels[i + 1]), 154 | nn.ReLU(), 155 | nn.BatchNorm1d(channels[i + 1]), 156 | ) 157 | for i, c in enumerate(channels[:-1]) 158 | ] 159 | ) 160 | # NOTE: "Here, we do not really want to add self-loops to the graph as 161 | # we are operating in bipartite graphs. The real "self-loop" is already 162 | # added to tgnn.PointConv by the radius call." 
163 | # Ref: https://github.com/pyg-team/pytorch_geometric/issues/2558 164 | self.conv = tgnn.PointNetConv(local_nn=mlp, add_self_loops=False).jittable() 165 | 166 | def forward(self, x, pos, batch): 167 | """Forward propogation of the model.""" 168 | # Reduce the density of point cloud by farthest point sampling 169 | # random_start=False, This is to ensure origin is added to the graph 170 | idx = tgnn.fps(pos, batch, ratio=self.ratio, random_start=False) 171 | # row - indexes for y 172 | # col - indexes for x 173 | row, col = tgnn.radius( 174 | pos, pos[idx], self.r, batch, batch[idx], max_num_neighbors=64 175 | ) 176 | # readjust the indexes for creating edge_index. 177 | newRow = idx[row] 178 | edge_index = torch.stack([col, newRow], dim=0) 179 | x = self.conv(x, pos, edge_index) 180 | pos, batch = pos[idx], batch[idx] 181 | return x[idx], pos, batch, idx 182 | 183 | 184 | class FeatureExtractor(torch.nn.Module): 185 | """Extract features from using PointNet++ architecture""" 186 | 187 | def __init__(self, d_model): 188 | """Initialize the network. 189 | :param input_dim: dimension of the point cloud data point. 190 | :param d_model: dimension of the final latent layer 191 | """ 192 | super(FeatureExtractor, self).__init__() 193 | self.sa1_module = SAModule(0.75, 0.2, channels=[3 + 3, 64, 128]) 194 | self.sa2_module = SAModule(0.75, 0.4, channels=[128 + 3, 256, d_model]) 195 | 196 | def forward(self, data): 197 | """ 198 | :param data: An object of type torch_geometric.data.Batch 199 | :returns tuple: (latent_vector, tensor_point, batch) 200 | """ 201 | allIndex = torch.arange(data.pos.shape[0], device=data.pos.device) 202 | *h_pos_batch, idx = self.sa1_module(data.pos, data.pos, data.batch) 203 | allIndex = allIndex[idx] 204 | *h_pos_batch, idx = self.sa2_module(*h_pos_batch) 205 | allIndex = allIndex[idx] 206 | return h_pos_batch, allIndex 207 | -------------------------------------------------------------------------------- /src/vqmpt/modules/quantizer.py: -------------------------------------------------------------------------------- 1 | # Define the vector quantizer module. 2 | # Taken from - https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py 3 | # and https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | from einops import rearrange 10 | 11 | 12 | class VectorQuantizer(nn.Module): 13 | """A vector quantizer for storing the dictionary of sample points.""" 14 | 15 | def __init__(self, n_e, e_dim, latent_dim): 16 | """ 17 | :param n_e: Number of elements in the embedding. 18 | :param e_dim: Size of the latent embedding vector. 19 | :param latent_dim: Dimension of the encoder vector. 20 | """ 21 | super().__init__() 22 | 23 | self.n_e = n_e 24 | self.e_dim = e_dim 25 | 26 | # Define the linear layer. 27 | self.input_linear_map = nn.Linear(latent_dim, e_dim) 28 | self.output_linear_map = nn.Linear(e_dim, latent_dim) 29 | 30 | # Initialize the embedding. 
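# The codebook holds n_e vectors of dimension e_dim; encoder outputs of size
# latent_dim are projected down to e_dim by input_linear_map before the
# nearest-neighbour lookup and projected back by output_linear_map afterwards
# (see forward below).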
31 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 32 | nn.init.xavier_uniform_(self.embedding.weight) 33 | self.batch_norm = nn.BatchNorm1d(self.e_dim, affine=False) 34 | 35 | def forward(self, z, mask): 36 | """ 37 | Inputs the output of the encoder network z and maps it to a discrete 38 | one-hot vector that is the index of the closest embedding vector e_j 39 | 40 | z (continuous) -> z_q (discrete) 41 | 42 | z.shape = (batch, num_seq, latent_encoding) 43 | 44 | quantization pipeline: 45 | 1. get encoder output (B, S, E) 46 | 2. flatten input to (B*S, E) 47 | """ 48 | # flatten input vector 49 | z_flattened = rearrange(z, "B S E -> (B S) E") 50 | # pass through the input projection. 51 | z_flattened = self.input_linear_map(z_flattened) 52 | 53 | # Normalize input vectors. 54 | z_flattened = F.normalize(z_flattened) 55 | # Normalize embedding vectors. 56 | self.embedding.weight.data = F.normalize(self.embedding.weight.data) 57 | 58 | # =========== Since vectors are normalized ============== 59 | # distances from z to embeddings e_j (z - e)^2 = - e * z 60 | d = -torch.einsum( 61 | "bd,dn->bn", z_flattened, rearrange(self.embedding.weight, "n d -> d n") 62 | ) 63 | # ============================================================== 64 | 65 | min_encoding_indices = torch.argmin(d, dim=1) 66 | z_q_flattened = self.embedding(min_encoding_indices) 67 | 68 | # Preserve gradients through linear transform also 69 | z_q_flattened = z_flattened + z_q_flattened - z_flattened.detach() 70 | 71 | # Translate to output encoder shape 72 | z_q_flattened = self.output_linear_map(z_q_flattened) 73 | z_q = z_q_flattened.view(z.shape) 74 | 75 | perplexity = None 76 | min_encodings = None 77 | 78 | return z_q, (perplexity, min_encodings, min_encoding_indices) 79 | -------------------------------------------------------------------------------- /src/vqmpt/modules/sublayers.py: -------------------------------------------------------------------------------- 1 | """ Define the sublayers in encoder/decoder layer 2 | Derived From : 3 | https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/SubLayers.py 4 | """ 5 | 6 | import torch 7 | from torch import nn 8 | import torch.nn.functional as F 9 | 10 | from einops import repeat 11 | 12 | 13 | class ScaledDotProductAttention(nn.Module): 14 | """Scaled Dot-Product Attention""" 15 | 16 | def __init__(self, temperature, attn_dropout=0.1): 17 | """ 18 | Initialize the model. 19 | :param temperature: TODO .... 20 | :param attn_dropout: Argument to dropout after softmax(QK) 21 | """ 22 | super().__init__() 23 | self.temperature = temperature 24 | self.dropout = nn.Dropout(attn_dropout) 25 | 26 | def forward(self, q, k, v, mask=None, dec_mask=None): 27 | """ 28 | Callback Function: 29 | :param q: The Query matrix. 30 | :param k: The Key matrix. 31 | :param v: The value matrix. 32 | :param mask: The mask of the input. 
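:param dec_mask: optional mask applied along a single axis of the attention matrix, used to mask padded environment encodings during cross-attention.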
33 | :returns (output, attention): A tuple consisting of softmax(QK^T)V and softmax(QK^T) 34 | """ 35 | attn = torch.matmul(q / self.temperature, k.transpose(2, 3)) 36 | 37 | # mask size is expected tbo be b x n_head x n 38 | if mask is not None: 39 | row_mask = mask.unsqueeze(2) 40 | attn = attn.masked_fill(row_mask == 0, -1e9) 41 | col_mask = mask.unsqueeze(3) 42 | attn = attn.masked_fill(col_mask == 0, -1e9) 43 | 44 | if dec_mask is not None: 45 | row_mask = dec_mask.unsqueeze(2) 46 | attn = attn.masked_fill(row_mask == 0, -1e9) 47 | 48 | attn = self.dropout(F.softmax(attn, dim=-1)) 49 | output = torch.matmul(attn, v) 50 | 51 | return output 52 | 53 | 54 | class MultiHeadAttention(nn.Module): 55 | """Multi-Head Attention module""" 56 | 57 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 58 | """ 59 | Intialize the model. 60 | :param n_head: Number of self-attention modules 61 | :param d_model: Dimension of input/output of this layer 62 | :param d_k: Dimension of each Key 63 | :param d_v: Dimension of each Value 64 | :param dropout: 65 | """ 66 | super().__init__() 67 | 68 | self.n_head = n_head 69 | self.d_k = d_k 70 | self.d_v = d_v 71 | 72 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) 73 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) 74 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) 75 | self.fc = nn.Linear(n_head * d_v, d_model, bias=False) 76 | 77 | self.attention = ScaledDotProductAttention(temperature=d_k**0.5) 78 | 79 | self.dropout = nn.Dropout(dropout) 80 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 81 | 82 | def forward(self, q, k, v, mask=None): 83 | """ 84 | Callback function. 85 | :param q: The Query matrix. 86 | :param k: The Key matrix. 87 | :param v: The value matrix. 88 | :param mask: The mask to use. 89 | :returns (output, attention): A tuple consisting of network output and softmax(QK^T) 90 | """ 91 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 92 | sz_b_q = q.size(0) 93 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) 94 | 95 | residual = q 96 | 97 | # Pass through the pre-attention projection: b x lq x (n*dv) 98 | # Separate different heads: b x lq x n x dv 99 | q = self.w_qs(q).view(sz_b_q, len_q, n_head, d_k) 100 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 101 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 102 | 103 | # Transpose for attention dot product: b x n x lq x dv 104 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 105 | 106 | if mask is not None: 107 | mask = repeat(mask, "b lq -> b n lq", n=n_head) 108 | 109 | q = self.attention(q, k, v, mask=mask) 110 | 111 | # Transpose to move the head dimension back: b x lq x n x dv 112 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) 113 | q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1) 114 | q = self.dropout(self.fc(q)) 115 | q += residual 116 | 117 | q = self.layer_norm(q) 118 | 119 | return q 120 | 121 | 122 | class MultiHeadAttentionPreNorm(nn.Module): 123 | """Multi-Head Attention module""" 124 | 125 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 126 | """ 127 | Intialize the model. 
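Identical to MultiHeadAttention above, except that layer normalization is applied to q, k, and v before the projections (pre-norm), no final layer norm is applied to the output, and forward() accepts an additional dec_mask argument.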
128 | :param n_head: Number of self-attention modules 129 | :param d_model: Dimension of input/output of this layer 130 | :param d_k: Dimension of each Key 131 | :param d_v: Dimension of each Value 132 | :param dropout: 133 | """ 134 | super().__init__() 135 | 136 | self.n_head = n_head 137 | self.d_k = d_k 138 | self.d_v = d_v 139 | 140 | self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False) 141 | self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False) 142 | self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False) 143 | self.fc = nn.Linear(n_head * d_v, d_model, bias=False) 144 | 145 | self.attention = ScaledDotProductAttention(temperature=d_k**0.5) 146 | 147 | self.dropout = nn.Dropout(dropout) 148 | self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) 149 | 150 | def forward(self, q, k, v, mask=None, dec_mask=None): 151 | """ 152 | Callback function. 153 | :param q: The Query matrix. 154 | :param k: The Key matrix. 155 | :param v: The value matrix. 156 | :param mask: The mask to use. 157 | :param dec_mask: applies mask only on the query input. 158 | :returns (output, attention): A tuple consisting of network output and softmax(QK^T) 159 | """ 160 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 161 | sz_b_q = q.size(0) 162 | sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) 163 | 164 | residual = q 165 | 166 | # Pre-norm layers: 167 | q = self.layer_norm(q) 168 | k = self.layer_norm(k) 169 | v = self.layer_norm(v) 170 | 171 | # Pass through the pre-attention projection: b x lq x (n*dv) 172 | # Separate different heads: b x lq x n x dv 173 | q = self.w_qs(q).view(sz_b_q, len_q, n_head, d_k) 174 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 175 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 176 | 177 | # Transpose for attention dot product: b x n x lq x dv 178 | q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2) 179 | 180 | if mask is not None: 181 | mask = repeat(mask, "b lq -> b n lq", n=n_head) 182 | if dec_mask is not None: 183 | dec_mask = repeat(dec_mask, "b lq -> b n lq", n=n_head) 184 | 185 | q = self.attention(q, k, v, mask=mask, dec_mask=dec_mask) 186 | 187 | # Transpose to move the head dimension back: b x lq x n x dv 188 | # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv) 189 | q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1) 190 | q = self.dropout(self.fc(q)) 191 | q += residual 192 | 193 | return q 194 | 195 | 196 | class PositionwiseFeedForward(nn.Module): 197 | """A simple 2 layer with fully connected layer.""" 198 | 199 | def __init__(self, d_in, d_hid, dropout=0.1): 200 | """ 201 | Initialize the model. 202 | :param d_in: Dimension of the input/output of the model. 203 | :param d_hid: Dimension of the hidden layer. 204 | :param dropout: Argument to the dropout layer. 205 | """ 206 | super().__init__() 207 | self.w_1 = nn.Linear(d_in, d_hid) # position-wise 208 | self.w_2 = nn.Linear(d_hid, d_in) # position-wise 209 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) 210 | self.dropout = nn.Dropout(dropout) 211 | 212 | def forward(self, x): 213 | """ 214 | Callback function. 215 | :param x: The input to the function. 216 | :returns torch.array: An output of the same dimension as the input. 
217 | """ 218 | residual = x 219 | 220 | x = self.w_2(F.relu(self.w_1(x))) 221 | x = self.dropout(x) 222 | x += residual 223 | 224 | x = self.layer_norm(x) 225 | 226 | return x 227 | 228 | 229 | class PositionwiseFeedForwardPreNorm(nn.Module): 230 | """A simple 2 layer with fully connected layer.""" 231 | 232 | def __init__(self, d_in, d_hid, dropout=0.1): 233 | """ 234 | Initialize the model. 235 | :param d_in: Dimension of the input/output of the model. 236 | :param d_hid: Dimension of the hidden layer. 237 | :param dropout: Argument to the dropout layer. 238 | """ 239 | super().__init__() 240 | self.w_1 = nn.Linear(d_in, d_hid) # position-wise 241 | self.w_2 = nn.Linear(d_hid, d_in) # position-wise 242 | self.layer_norm = nn.LayerNorm(d_in, eps=1e-6) 243 | self.dropout = nn.Dropout(dropout) 244 | 245 | def forward(self, x): 246 | """ 247 | Callback function. 248 | :param x: The input to the function. 249 | :returns torch.array: An output of the same dimension as the input. 250 | """ 251 | residual = x 252 | x = self.layer_norm(x) 253 | 254 | x = self.w_2(F.relu(self.w_1(x))) 255 | x = self.dropout(x) 256 | x += residual 257 | return x 258 | -------------------------------------------------------------------------------- /src/vqmpt/utils.py: -------------------------------------------------------------------------------- 1 | """ Useful functions for planning using VQ-MPT models. 2 | """ 3 | 4 | import json 5 | from os import path as osp 6 | from torch.nn import functional as F 7 | 8 | import torch 9 | import numpy as np 10 | import torch_geometric.data as tg_data 11 | 12 | from .modules import quantizer 13 | from .modules import decoder 14 | from .modules import autoregressive 15 | 16 | 17 | def get_inference_models( 18 | decoder_model_folder, 19 | ar_model_folder, 20 | device, 21 | n_e=2048, 22 | e_dim=8, 23 | ): 24 | """ 25 | Return the quantizer, decoder, cross-attention, and auto-regressive models. 26 | :param decoder_model_folder: The folder where the decoder model is stored. 27 | :param ar_model_folder: The folder where AR model is stored. 28 | :param device: which device to load the models on. 29 | :param n_e: Number of dictonary variables to be used. 30 | :param e_dim: Dimension of the dictionary latent vector. 
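The decoder folder is expected to contain model_params.json and best_model.pkl; the AR folder is expected to contain cross_attn.json, ar_params.json, and best_model.pkl.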
31 | :returns tuple: quantizer model, decoder model, environment encoder, ar model 32 | """ 33 | # Define the decoder model 34 | with open(osp.join(decoder_model_folder, "model_params.json"), "r") as f: 35 | dict_model_params = json.load(f) 36 | 37 | decoder_model = decoder.DecoderPreNormGeneral( 38 | e_dim=dict_model_params["d_model"], 39 | h_dim=dict_model_params["d_inner"], 40 | c_space_dim=dict_model_params["c_space_dim"], 41 | ) 42 | 43 | quantizer_model = quantizer.VectorQuantizer( 44 | n_e=n_e, e_dim=e_dim, latent_dim=dict_model_params["d_model"] 45 | ) 46 | dec_file = osp.join(decoder_model_folder, "best_model.pkl") 47 | decoder_checkpoint = torch.load(dec_file, map_location=device) 48 | 49 | # Load model parameters and set it to eval 50 | for model, state_dict in zip( 51 | [quantizer_model, decoder_model], 52 | ["quantizer_state", "decoder_state"], 53 | ): 54 | model.load_state_dict(decoder_checkpoint[state_dict]) 55 | model.eval() 56 | model.to(device) 57 | 58 | # Define the AR + Cross attention model 59 | with open(osp.join(ar_model_folder, "cross_attn.json"), "r") as f: 60 | context_env_encoder_params = json.load(f) 61 | env_params = { 62 | "d_model": dict_model_params["d_model"], 63 | } 64 | context_env_encoder = autoregressive.EnvContextCrossAttModel( 65 | env_params, context_env_encoder_params, robot="6D" 66 | ) 67 | # Create the AR model 68 | with open(osp.join(ar_model_folder, "ar_params.json"), "r") as f: 69 | ar_params = json.load(f) 70 | ar_model = autoregressive.AutoRegressiveModel(**ar_params) 71 | 72 | # Load the parameters and set the model to eval 73 | ar_checkpoint = torch.load( 74 | osp.join(ar_model_folder, "best_model.pkl"), map_location=device 75 | ) 76 | for model, state_dict in zip( 77 | [context_env_encoder, ar_model], ["context_state", "ar_model_state"] 78 | ): 79 | model.load_state_dict(ar_checkpoint[state_dict]) 80 | model.eval() 81 | model.to(device) 82 | return quantizer_model, decoder_model, context_env_encoder, ar_model 83 | 84 | 85 | def get_beam_search_path( 86 | max_length, 87 | K, 88 | context_output, 89 | ar_model, 90 | quantizer_model, 91 | goal_index, 92 | device, 93 | ): 94 | """A beam search function, that stops when any of the paths reaches 95 | termination. 96 | :param max_length: Max length to search. 97 | :param K: Number of paths to keep. 98 | :param context_output: the tensor ecoding environment information. 99 | :param ar_model: nn.Model type for the Auto-Regressor. 100 | :param quantizer_model: For extracting the feature vector. 101 | :param goal_index: Index used to mark end of sequence 102 | :param device: device on which to do the processing. 103 | """ 104 | 105 | # Create place holder for input sequences. 106 | input_seq = torch.ones(K, max_length, 512, dtype=torch.float, device=device) * -1 107 | quant_keys = torch.ones(K, max_length) * -1 108 | mask = torch.zeros(K, max_length + 2, device=device) 109 | 110 | ar_model_input_i = torch.cat([context_output.repeat((K, 1, 1)), input_seq], dim=1) 111 | # mask the start/goal encoding and the prev. sequences. 
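# The context encoding spans the first two positions (start/goal tokens) and
# position 2 is the first, still empty, sequence slot; unmasking [:, :3] lets
# the AR model predict the first dictionary index, read from ar_output[:, 2, :]
# below.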
112 | mask[:, :3] = 1 113 | 114 | # Get first set of quant_keys 115 | ar_output = ar_model(ar_model_input_i, mask) 116 | intial_cost = F.log_softmax(ar_output[:, 2, :], dim=-1) 117 | # Do not terminate on the final dictionary 118 | intial_cost[:, goal_index] = -1e9 119 | path_cost, start_index = intial_cost.topk(k=K, dim=-1) 120 | start_index = start_index[0] 121 | path_cost = path_cost[0] 122 | input_seq[:, 1, :] = quantizer_model.output_linear_map( 123 | quantizer_model.embedding(start_index) 124 | ) 125 | quant_keys[:, 0] = start_index 126 | for i in range(1, max_length - 1): 127 | ar_model_input_i = torch.cat( 128 | [context_output.repeat((K, 1, 1)), input_seq], dim=1 129 | ) 130 | # mask the start/goal encoding and the prev. sequences. 131 | mask[:, : 3 + i] = 1 132 | 133 | ar_output = ar_model(ar_model_input_i, mask) 134 | 135 | # Get the sequence cost for the next step 136 | seq_cost = F.softmax(ar_output[:, 2 + i, :], dim=-1) 137 | # Make self-loops impossible by setting the cost really low 138 | seq_cost[:, quant_keys[:, i - 1].to(dtype=torch.int64)] = -1e9 139 | 140 | # Get the top set of possible sequences by flattening across batch 141 | # sizes. 142 | cur_cost = path_cost[:, None] + seq_cost 143 | nxt_cost, flatten_index = cur_cost.flatten().topk(K) 144 | # Reshape back into tensor size to get the approriate batch index and 145 | # word index. 146 | new_sequence = torch.as_tensor( 147 | np.array(np.unravel_index(flatten_index.cpu().numpy(), seq_cost.shape)).T 148 | ) 149 | 150 | # Update previous keys given the current prediction. 151 | quant_keys[:, :i] = quant_keys[new_sequence[:, 0], :i] 152 | # Update the current set of keys. 153 | quant_keys[:, i] = new_sequence[:, 1].to(dtype=torch.float) 154 | # Update the cost 155 | path_cost = nxt_cost 156 | 157 | # Break at the first sign of termination 158 | if (new_sequence[:, 1] == goal_index).any(): 159 | break 160 | 161 | # Select index 162 | select_index = new_sequence[:, 1] != goal_index 163 | 164 | # Update the input embedding. 165 | input_seq[select_index, : i + 1, :] = input_seq[ 166 | new_sequence[select_index, 0], : i + 1, : 167 | ] 168 | input_seq[select_index, i + 1, :] = quantizer_model.output_linear_map( 169 | quantizer_model.embedding(new_sequence[select_index, 1].to(device)) 170 | ) 171 | return quant_keys, path_cost, input_seq 172 | 173 | 174 | def get_search_dist( 175 | norm_start_n_goal, 176 | depth_points, 177 | context_encoder, 178 | decoder_model, 179 | ar_model, 180 | quantizer_model, 181 | num_keys, 182 | device, 183 | ): 184 | """ 185 | Get the search distribution for a given start and goal state. 186 | :param norm_start_n_goal: numpy tensor with the normalized start and 187 | goal 188 | :param depth_points: 3D Point cloud data passed as an numpy array 189 | :param context_encoder: context encoder model 190 | :param decoder_model: decoder model to retrive distributions 191 | :param ar_model: auto-regressive model 192 | :param quantizer_model: quantizer model 193 | :param num_keys: Total number of keys in the dictionary 194 | :param device: device on which to perform torch operations 195 | :returns (torch.tensor, torch.tensor, float): Returns an array of 196 | mean and covariance matrix 197 | """ 198 | # Get the context. 
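# norm_start_n_goal is expected to be a (2, c_space_dim) array normalized to
# [0, 1], with the goal in the last row (it is re-used below to append the
# goal to the predicted distribution); depth_points is an (N, 3) point cloud.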
199 | start_n_goal = torch.as_tensor( 200 | norm_start_n_goal, 201 | dtype=torch.float, 202 | ) 203 | map_data = tg_data.Data(pos=torch.as_tensor(depth_points, dtype=torch.float, device=device)) 204 | env_input = tg_data.Batch.from_data_list([map_data]) 205 | context_output = context_encoder(env_input, start_n_goal[None, :].to(device)) 206 | # Find the sequence of dict values using beam search 207 | goal_index = num_keys + 1 208 | quant_keys, _, input_seq = get_beam_search_path( 209 | 51, 210 | 3, 211 | context_output, 212 | ar_model, 213 | quantizer_model, 214 | goal_index, 215 | device, 216 | ) 217 | 218 | reached_goal = torch.stack(torch.where(quant_keys == goal_index), dim=1) 219 | # Get the distribution. 220 | if len(reached_goal) > 0: 221 | # Ignore the zero index, since it is encoding representation of start 222 | # vector. 223 | output_dist_mu, output_dist_sigma = decoder_model( 224 | input_seq[reached_goal[0, 0], 1:reached_goal[0, 1] + 1][None, :] 225 | ) 226 | dist_mu = output_dist_mu.detach().cpu() 227 | dist_sigma = output_dist_sigma.detach().cpu() 228 | # If only a single point is predicted, then reshape the vector to a 2D 229 | # tensor. 230 | if len(dist_mu.shape) == 1: 231 | dist_mu = dist_mu[None, :] 232 | dist_sigma = dist_sigma[None, :] 233 | # ========================== append search with goal ================ 234 | search_dist_mu = torch.zeros((reached_goal[0, 1] + 1, 7)) 235 | search_dist_mu[: reached_goal[0, 1], :] = dist_mu 236 | search_dist_mu[reached_goal[0, 1], :] = torch.tensor(norm_start_n_goal[-1]) 237 | search_dist_sigma = torch.diag_embed(torch.ones((reached_goal[0, 1] + 1, 7))) 238 | search_dist_sigma[: reached_goal[0, 1], :, :] = dist_sigma 239 | search_dist_sigma[reached_goal[0, 1], :, :] = ( 240 | search_dist_sigma[reached_goal[0, 1], :, :] * 0.01 241 | ) 242 | # ==================================================================== 243 | else: 244 | search_dist_mu = None 245 | search_dist_sigma = None 246 | return search_dist_mu, search_dist_sigma 247 | --------------------------------------------------------------------------------
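A minimal end-to-end sketch of how the utilities above fit together (editorial example: the model folder paths, joint limits, and point cloud below are placeholders, and a 7-DOF Panda-style arm is assumed):

```
import numpy as np
import torch
from torch.distributions import MultivariateNormal

from vqmpt import utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained models (folder paths are placeholders).
quantizer_model, decoder_model, context_encoder, ar_model = utils.get_inference_models(
    "path/to/decoder_model",
    "path/to/ar_model",
    device,
    n_e=2048,
    e_dim=8,
)

# Normalize the start and goal configurations to [0, 1] and stack them (2 x 7).
q_min, q_max = -np.pi * np.ones(7), np.pi * np.ones(7)
start, goal = np.zeros(7), 0.5 * np.ones(7)
norm_start_n_goal = (np.stack([start, goal]) - q_min) / (q_max - q_min)

# Stand-in point cloud; in practice this comes from a depth sensor.
depth_points = np.random.rand(2048, 3)

with torch.no_grad():
    dist_mu, dist_sigma = utils.get_search_dist(
        norm_start_n_goal,
        depth_points,
        context_encoder,
        decoder_model,
        ar_model,
        quantizer_model,
        num_keys=2048,
        device=device,
    )

# Each mean lives in the normalized [0, 1] space; un-normalize drawn samples.
if dist_mu is not None:
    samples = MultivariateNormal(dist_mu, dist_sigma).sample()
    joint_samples = samples.numpy() * (q_max - q_min) + q_min
```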