├── .gitignore ├── LICENSE ├── README.md ├── data └── .gitignore ├── env.yml ├── evaluation ├── .gitignore ├── __init__.py └── evaluation_metrics.py ├── models ├── autoencoder.py ├── common.py ├── diffusion.py ├── encoders │ ├── __init__.py │ ├── pointcnn.py │ └── pointnet.py ├── flow.py ├── vae_flow.py └── vae_gaussian.py ├── pretrained └── .gitignore ├── results ├── .gitignore └── README.md ├── teaser.png ├── test_ae.py ├── test_gen.py ├── train_ae.py ├── train_gen.py └── utils ├── data.py ├── dataset.py ├── misc.py └── transform.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | .DS_Store 132 | /playgrounds 133 | /logs* 134 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Shitong Luo 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diffusion-Point-Cloud (PointCNN version) 2 | 3 | This project is based on the open source implementation of the paper [**“Diffusion Probabilistic Models for 3D Point Cloud Generation”**](https://arxiv.org/abs/2103.01458), extending its original version and replacing the **backbone** of point cloud feature extraction from **PointNet** to **PointCNN**. This version achieves better generation quality and diversity on several 3D point cloud datasets. 4 | 5 | --- 6 | 7 | ## Project Introduction 8 | 9 | In the original paper, the authors applied diffusion probabilistic models to the task of 3D point cloud generation and proposed a denoising model based on PointNet as a feature extraction network. By defining the forward denoising process in the training phase, the model learns the inverse denoising process, so that high-fidelity target point clouds can be gradually sampled from Gaussian noise point clouds during inference. 10 | 11 | However, although **PointNet** is simple and effective to implement, its ability to express local structures is relatively limited. To this end, we replaced **PointNet** with [**PointCNN**](https://arxiv.org/abs/1801.07791) to enhance the ability to extract local neighborhood geometric information, thereby achieving better performance in generating finer local details and shape diversity. 12 | 13 | --- 14 | 15 | ## Major updates 16 | 17 | 1. **Feature extraction network: switch from PointNet to PointCNN** 18 | - **PointCNN** introduces the X-Conv operation, first performs a learnable transformation on the neighborhood point set, and then performs a convolution-like aggregation, so that the model can better capture the local geometric structure and the relationship between points. 19 | - Compared with PointNet, which only uses MLP for each point and performs global pooling, PointCNN can more effectively retain and integrate local-global information and improve the representation of complex 3D shapes. 20 | - With this change, the model has better performance in local detail restoration and generation diversity. 21 | 22 | 2. **Training stability** 23 | - Further optimize the hyperparameters, including batch size, learning rate, etc., to adapt to the deep network structure of PointCNN. 24 | - Experiments show that there is a certain degree of improvement in common evaluation indicators (such as Coverage, MMD, Chamfer Distance, etc.). 25 | 26 | 3. **Overall performance improvement** 27 | - Compared with the original PointNet version, the generated 3D point cloud is more realistic and natural in both overall structure and local details. 28 | 29 | --- 30 | 31 | ## Environment requirements 32 | 33 | - Python 3.7+ 34 | - PyTorch >= 1.7 (compatible with CUDA 10.2 / 11.x) 35 | - Common scientific computing and visualization libraries such as Numpy, Scipy, Matplotlib 36 | - [Open3D](http://www.open3d.org/) (optional, used for point cloud operations, etc.) 37 | - [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/) (if your PointCNN implementation relies on PyG's neighborhood search and other functions) 38 | 39 | **[Option 1]** Please first install the required libraries according to [env.yml](./env.yml) in this repository or according to the dependencies listed in the main branch: 40 | ```bash 41 | # Create the environment 42 | conda env create -f env.yml 43 | # Activate the environment 44 | conda activate dpm-pc-gen 45 | ``` 46 | **[Option 2]** Or you may setup the environment manually (**If you are using GPUs that only work with CUDA 11 or greater**). 47 | 48 | Our model only depends on the following commonly used packages, all of which can be installed via conda. 49 | 50 | | Package | Version | 51 | | ------------ | -------------------------------- | 52 | | PyTorch | ≥ 1.7.0 | 53 | | h5py | *not specified* (we used 4.61.1) | 54 | | tqdm | *not specified* | 55 | | tensorboard | *not specified* (we used 2.5.0) | 56 | | numpy | *not specified* (we used 1.20.2) | 57 | | scipy | *not specified* (we used 1.6.2) | 58 | | scikit-learn | *not specified* (we used 0.24.2) | 59 | 60 | 61 | ## Data preparation 62 | 63 | ### Dataset 64 | 65 | - **It is recommended to use ShapeNet, ModelNet and other common 3D shape datasets for experiments. ** 66 | - **Download and unzip the corresponding dataset to the `data/` directory (you can also specify the path yourself) according to actual needs. ** 67 | 68 | ### Preprocessing 69 | 70 | - For each 3D object, downsample/normalize it to a fixed number of points (such as 1024 points) as needed, and convert it to `.xyz` or `.npy` format. 71 | - The above steps can be completed in the script `data_preprocess.py`, and the preprocessing results are stored in the specified folder. 72 | - For details, please refer to the main branch and expect to remain consistent 73 | 74 | ## Configuration file 75 | 76 | - Set model hyperparameters, training hyperparameters, dataset path and other information in `configs/pointcnn_config.yaml`. 77 | - Core parameters include: 78 | - `num_points`: The number of points in each point cloud (such as 1024). 79 | - `batch_size`: Training batch size. 80 | - `learning_rate`: Initial learning rate. 81 | - `diffusion_steps`: The number of steps in the diffusion process. 82 | - `model`: Specify **PointCNN** as the feature extraction network. 83 | 84 | 85 | ## About the EMD Metric 86 | 87 | We have removed the EMD module due to GPU compatability issues. The legacy code can be found on the `emd-cd` branch. 88 | 89 | If you have to compute the EMD score or compare our model with others, we strongly advise you to use your own code to compute the metrics. The generation and decoding results will be saved to the `results` folder after each test run. 90 | 91 | ## Training 92 | 93 | ```bash 94 | # Train an auto-encoder 95 | python train_ae.py 96 | 97 | # Train a generator 98 | python train_gen.py 99 | ``` 100 | 101 | You may specify the value of arguments. Please find the available arguments in the script. 102 | 103 | Note that `--categories` can take `all` (use all the categories in the dataset), `airplane`, `chair` (use a single category), or `airplane,chair` (use multiple categories, separated by commas). 104 | 105 | ### Notes on the Metrics 106 | 107 | Note that the metrics computed during the validation stage in the training script (`train_gen.py`, `train_ae.py`) are not comparable to the metrics reported by the test scripts (`test_gen.py`, `test_ae.py`). ***If you train your own models, please evaluate them using the test scripts***. The differences include: 108 | 1. The scale of Chamfer distance in the training script is different. In the test script, we renormalize the bounding boxes of all the point clouds before calculating the metrics (Line 100, `test_gen.py`). However, in the validation stage of training, we do not renormalize the point clouds. 109 | 2. During the validation stage of training, we only use a subset of the validation set (400 point clouds) to compute the metrics and generates only 400 point clouds (controlled by the `--test_size` parameter). Limiting the number to 400 is for saving time. However, the actual size of the `airplane` validation set is 607, larger than 400. Less point clouds mean that it is less likely to find similar point clouds in the validation set for a generated point cloud. Hence, it would lead to a worse Minimum-Matching-Distance (MMD) score even if we renormalize the shapes during the validation stage in the training script. 110 | 111 | 112 | ## Testing 113 | 114 | ```bash 115 | # Test an auto-encoder 116 | python test_ae.py --ckpt ./pretrained/AE_all.pt --categories all 117 | 118 | # Test a generator 119 | python test_gen.py --ckpt ./pretrained/GEN_airplane.pt --categories airplane 120 | ``` 121 | 122 | 123 | 124 | 125 | ## Experimental results and performance 126 | 127 | Compared with the original **PointNet** version, **PointCNN** as the backbone network can capture richer local geometric structures, thus achieving improvements in **Coverage (COV)**, **Minimum Matching Distance (MMD)**, and **1-NNA** indicators: 128 | 129 | The following are the local test results of the current setting on the Airplane data: 130 | 131 | | Method | COV-CD (↑) | COV-EMD (↑) | MMD-CD (↓) | MMD-EMD (↓) | 1-NNA-CD (↓) | 1-NNA-EMD (↓) | 132 | |-----------------------|-----------|------------|-----------|------------|--------------|---------------| 133 | | **PointNet (Original)** | 48.71% | 45.47% | 3.276 | 1.061 | 64.83% | 75.12% | 134 | | **PointCNN (This project)** | 48.83% | 45.60% | 3.109 | 0.998 | 64.56% | 75.05% | 135 | 136 | 137 | ## References 138 | 139 | - [Diffusion Probabilistic Models for 3D Point Cloud Generation](https://arxiv.org/abs/2103.01458) 140 | Shitong Luo, Wei Hu 141 | 142 | - [PointCNN: Convolution On X-Transformed Points](https://arxiv.org/abs/1801.07791) 143 | Yangyan Li, Rui Bu, Mingchao Sun, Wei Wu, Xinhan Di, Baoquan Chen 144 | 145 | ## Acknowledgements 146 | 147 | - Thanks to the original open source project author for providing the basic framework and reference implementation. 148 | - Thanks to all developers who have contributed to the open source community. 149 | 150 | If you encounter any problems while using or reproducing this project, please [submit an issue](https://github.com/luost26/diffusion-point-cloud/issues) or contact the author. 151 | 152 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /env.yml: -------------------------------------------------------------------------------- 1 | name: dpm-pc-gen 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=4.5=1_gnu 9 | - absl-py=0.13.0=pyhd8ed1ab_0 10 | - aiohttp=3.7.4.post0=py37h5e8e339_0 11 | - async-timeout=3.0.1=py_1000 12 | - attrs=21.2.0=pyhd8ed1ab_0 13 | - blas=1.0=mkl 14 | - blinker=1.4=py_1 15 | - brotlipy=0.7.0=py37h5e8e339_1001 16 | - c-ares=1.17.1=h7f98852_1 17 | - ca-certificates=2021.5.30=ha878542_0 18 | - cached-property=1.5.2=hd8ed1ab_1 19 | - cached_property=1.5.2=pyha770c72_1 20 | - cachetools=4.2.2=pyhd8ed1ab_0 21 | - certifi=2021.5.30=py37h89c1867_0 22 | - cffi=1.14.5=py37hc58025e_0 23 | - chardet=4.0.0=py37h89c1867_1 24 | - click=8.0.1=py37h89c1867_0 25 | - cryptography=3.4.7=py37h5d9358c_0 26 | - cudatoolkit=10.1.243=h6bb024c_0 27 | - dataclasses=0.8=pyhc8e2a94_1 28 | - freetype=2.10.4=h5ab3b9f_0 29 | - google-auth=1.32.0=pyh6c4a22f_0 30 | - google-auth-oauthlib=0.4.1=py_2 31 | - grpcio=1.38.1=py37hb27c1af_0 32 | - h5py=3.3.0=nompi_py37ha3df211_100 33 | - hdf5=1.10.6=nompi_h7c3c948_1111 34 | - idna=2.10=pyh9f0ad1d_0 35 | - importlib-metadata=4.6.0=py37h89c1867_0 36 | - intel-openmp=2021.2.0=h06a4308_610 37 | - joblib=1.0.1=pyhd8ed1ab_0 38 | - jpeg=9b=h024ee3a_2 39 | - krb5=1.19.1=hcc1bbae_0 40 | - lcms2=2.12=h3be6417_0 41 | - ld_impl_linux-64=2.35.1=h7274673_9 42 | - libblas=3.9.0=9_mkl 43 | - libcblas=3.9.0=9_mkl 44 | - libcurl=7.77.0=h2574ce0_0 45 | - libedit=3.1.20191231=he28a2e2_2 46 | - libev=4.33=h516909a_1 47 | - libffi=3.3=he6710b0_2 48 | - libgcc-ng=9.3.0=h5101ec6_17 49 | - libgfortran-ng=7.5.0=h14aa051_19 50 | - libgfortran4=7.5.0=h14aa051_19 51 | - libgomp=9.3.0=h5101ec6_17 52 | - libnghttp2=1.43.0=h812cca2_0 53 | - libpng=1.6.37=hbc83047_0 54 | - libprotobuf=3.17.2=h780b84a_0 55 | - libssh2=1.9.0=ha56f1ee_6 56 | - libstdcxx-ng=9.3.0=hd4cf53a_17 57 | - libtiff=4.2.0=h85742a9_0 58 | - libwebp-base=1.2.0=h27cfd23_0 59 | - llvm-openmp=8.0.1=hc9558a2_0 60 | - lz4-c=1.9.3=h2531618_0 61 | - markdown=3.3.4=pyhd8ed1ab_0 62 | - mkl=2021.2.0=h06a4308_296 63 | - mkl-service=2.3.0=py37h27cfd23_1 64 | - mkl_fft=1.3.0=py37h42c9631_2 65 | - mkl_random=1.2.1=py37ha9443f7_2 66 | - multidict=5.1.0=py37h5e8e339_1 67 | - ncurses=6.2=he6710b0_1 68 | - ninja=1.10.2=hff7bd54_1 69 | - numpy=1.20.2=py37h2d18471_0 70 | - numpy-base=1.20.2=py37hfae3a4d_0 71 | - oauthlib=3.1.1=pyhd8ed1ab_0 72 | - olefile=0.46=py37_0 73 | - openmp=8.0.1=0 74 | - openssl=1.1.1k=h7f98852_0 75 | - pillow=8.2.0=py37he98fc37_0 76 | - pip=21.1.3=py37h06a4308_0 77 | - point_cloud_utils=0.18.0=py37h6dcda5c_1 78 | - protobuf=3.17.2=py37hcd2ae1e_0 79 | - pyasn1=0.4.8=py_0 80 | - pyasn1-modules=0.2.7=py_0 81 | - pycparser=2.20=pyh9f0ad1d_2 82 | - pyjwt=2.1.0=pyhd8ed1ab_0 83 | - pyopenssl=20.0.1=pyhd8ed1ab_0 84 | - pysocks=1.7.1=py37h89c1867_3 85 | - python=3.7.10=h12debd9_4 86 | - python_abi=3.7=2_cp37m 87 | - pytorch=1.6.0=py3.7_cuda10.1.243_cudnn7.6.3_0 88 | - pyu2f=0.1.5=pyhd8ed1ab_0 89 | - readline=8.1=h27cfd23_0 90 | - requests=2.25.1=pyhd3deb0d_0 91 | - requests-oauthlib=1.3.0=pyh9f0ad1d_0 92 | - rsa=4.7.2=pyh44b312d_0 93 | - scikit-learn=0.24.2=py37h18a542f_0 94 | - scipy=1.6.2=py37had2a1c9_1 95 | - setuptools=52.0.0=py37h06a4308_0 96 | - six=1.16.0=pyhd3eb1b0_0 97 | - sqlite=3.36.0=hc218d9a_0 98 | - tensorboard=2.5.0=pyhd8ed1ab_0 99 | - tensorboard-data-server=0.6.0=py37h7f0c10b_0 100 | - tensorboard-plugin-wit=1.8.0=pyh44b312d_0 101 | - threadpoolctl=2.1.0=pyh5ca1d4c_0 102 | - tk=8.6.10=hbc83047_0 103 | - torchvision=0.7.0=py37_cu101 104 | - tqdm=4.61.1=pyhd8ed1ab_0 105 | - typing-extensions=3.10.0.0=hd8ed1ab_0 106 | - typing_extensions=3.10.0.0=pyha770c72_0 107 | - urllib3=1.26.6=pyhd8ed1ab_0 108 | - werkzeug=2.0.1=pyhd8ed1ab_0 109 | - wheel=0.36.2=pyhd3eb1b0_0 110 | - xz=5.2.5=h7b6447c_0 111 | - yarl=1.6.3=py37h5e8e339_1 112 | - zipp=3.4.1=pyhd8ed1ab_0 113 | - zlib=1.2.11=h7b6447c_3 114 | - zstd=1.4.9=haebb681_0 115 | -------------------------------------------------------------------------------- /evaluation/.gitignore: -------------------------------------------------------------------------------- 1 | StructuralLosses 2 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluation_metrics import * 2 | from .evaluation_metrics import _pairwise_EMD_CD_, _jsdiv 3 | -------------------------------------------------------------------------------- /evaluation/evaluation_metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | From https://github.com/stevenygd/PointFlow/tree/master/metrics 3 | """ 4 | import torch 5 | import numpy as np 6 | import warnings 7 | from scipy.stats import entropy 8 | from sklearn.neighbors import NearestNeighbors 9 | from numpy.linalg import norm 10 | from tqdm.auto import tqdm 11 | 12 | 13 | _EMD_NOT_IMPL_WARNED = False 14 | def emd_approx(sample, ref): 15 | global _EMD_NOT_IMPL_WARNED 16 | emd = torch.zeros([sample.size(0)]).to(sample) 17 | if not _EMD_NOT_IMPL_WARNED: 18 | _EMD_NOT_IMPL_WARNED = True 19 | print('\n\n[WARNING]') 20 | print(' * EMD is not implemented due to GPU compatability issue.') 21 | print(' * We will set all EMD to zero by default.') 22 | print(' * You may implement your own EMD in the function `emd_approx` in ./evaluation/evaluation_metrics.py') 23 | print('\n') 24 | return emd 25 | 26 | 27 | # Borrow from https://github.com/ThibaultGROUEIX/AtlasNet 28 | def distChamfer(a, b): 29 | x, y = a, b 30 | bs, num_points, points_dim = x.size() 31 | xx = torch.bmm(x, x.transpose(2, 1)) 32 | yy = torch.bmm(y, y.transpose(2, 1)) 33 | zz = torch.bmm(x, y.transpose(2, 1)) 34 | diag_ind = torch.arange(0, num_points).to(a).long() 35 | rx = xx[:, diag_ind, diag_ind].unsqueeze(1).expand_as(xx) 36 | ry = yy[:, diag_ind, diag_ind].unsqueeze(1).expand_as(yy) 37 | P = (rx.transpose(2, 1) + ry - 2 * zz) 38 | return P.min(1)[0], P.min(2)[0] 39 | 40 | 41 | def EMD_CD(sample_pcs, ref_pcs, batch_size, reduced=True): 42 | N_sample = sample_pcs.shape[0] 43 | N_ref = ref_pcs.shape[0] 44 | assert N_sample == N_ref, "REF:%d SMP:%d" % (N_ref, N_sample) 45 | 46 | cd_lst = [] 47 | emd_lst = [] 48 | iterator = range(0, N_sample, batch_size) 49 | 50 | for b_start in tqdm(iterator, desc='EMD-CD'): 51 | b_end = min(N_sample, b_start + batch_size) 52 | sample_batch = sample_pcs[b_start:b_end] 53 | ref_batch = ref_pcs[b_start:b_end] 54 | 55 | dl, dr = distChamfer(sample_batch, ref_batch) 56 | cd_lst.append(dl.mean(dim=1) + dr.mean(dim=1)) 57 | 58 | emd_batch = emd_approx(sample_batch, ref_batch) 59 | emd_lst.append(emd_batch) 60 | 61 | if reduced: 62 | cd = torch.cat(cd_lst).mean() 63 | emd = torch.cat(emd_lst).mean() 64 | else: 65 | cd = torch.cat(cd_lst) 66 | emd = torch.cat(emd_lst) 67 | 68 | results = { 69 | 'MMD-CD': cd, 70 | 'MMD-EMD': emd, 71 | } 72 | return results 73 | 74 | 75 | def _pairwise_EMD_CD_(sample_pcs, ref_pcs, batch_size, verbose=True): 76 | N_sample = sample_pcs.shape[0] 77 | N_ref = ref_pcs.shape[0] 78 | all_cd = [] 79 | all_emd = [] 80 | iterator = range(N_sample) 81 | if verbose: 82 | iterator = tqdm(iterator, desc='Pairwise EMD-CD') 83 | for sample_b_start in iterator: 84 | sample_batch = sample_pcs[sample_b_start] 85 | 86 | cd_lst = [] 87 | emd_lst = [] 88 | sub_iterator = range(0, N_ref, batch_size) 89 | # if verbose: 90 | # sub_iterator = tqdm(sub_iterator, leave=False) 91 | for ref_b_start in sub_iterator: 92 | ref_b_end = min(N_ref, ref_b_start + batch_size) 93 | ref_batch = ref_pcs[ref_b_start:ref_b_end] 94 | 95 | batch_size_ref = ref_batch.size(0) 96 | point_dim = ref_batch.size(2) 97 | sample_batch_exp = sample_batch.view(1, -1, point_dim).expand( 98 | batch_size_ref, -1, -1) 99 | sample_batch_exp = sample_batch_exp.contiguous() 100 | 101 | dl, dr = distChamfer(sample_batch_exp, ref_batch) 102 | cd_lst.append((dl.mean(dim=1) + dr.mean(dim=1)).view(1, -1)) 103 | 104 | emd_batch = emd_approx(sample_batch_exp, ref_batch) 105 | emd_lst.append(emd_batch.view(1, -1)) 106 | 107 | cd_lst = torch.cat(cd_lst, dim=1) 108 | emd_lst = torch.cat(emd_lst, dim=1) 109 | all_cd.append(cd_lst) 110 | all_emd.append(emd_lst) 111 | 112 | all_cd = torch.cat(all_cd, dim=0) # N_sample, N_ref 113 | all_emd = torch.cat(all_emd, dim=0) # N_sample, N_ref 114 | 115 | return all_cd, all_emd 116 | 117 | 118 | # Adapted from https://github.com/xuqiantong/ 119 | # GAN-Metrics/blob/master/framework/metric.py 120 | def knn(Mxx, Mxy, Myy, k, sqrt=False): 121 | n0 = Mxx.size(0) 122 | n1 = Myy.size(0) 123 | label = torch.cat((torch.ones(n0), torch.zeros(n1))).to(Mxx) 124 | M = torch.cat([ 125 | torch.cat((Mxx, Mxy), 1), 126 | torch.cat((Mxy.transpose(0, 1), Myy), 1)], 0) 127 | if sqrt: 128 | M = M.abs().sqrt() 129 | INFINITY = float('inf') 130 | val, idx = (M + torch.diag(INFINITY * torch.ones(n0 + n1).to(Mxx))).topk( 131 | k, 0, False) 132 | 133 | count = torch.zeros(n0 + n1).to(Mxx) 134 | for i in range(0, k): 135 | count = count + label.index_select(0, idx[i]) 136 | pred = torch.ge(count, (float(k) / 2) * torch.ones(n0 + n1).to(Mxx)).float() 137 | 138 | s = { 139 | 'tp': (pred * label).sum(), 140 | 'fp': (pred * (1 - label)).sum(), 141 | 'fn': ((1 - pred) * label).sum(), 142 | 'tn': ((1 - pred) * (1 - label)).sum(), 143 | } 144 | 145 | s.update({ 146 | 'precision': s['tp'] / (s['tp'] + s['fp'] + 1e-10), 147 | 'recall': s['tp'] / (s['tp'] + s['fn'] + 1e-10), 148 | 'acc_t': s['tp'] / (s['tp'] + s['fn'] + 1e-10), 149 | 'acc_f': s['tn'] / (s['tn'] + s['fp'] + 1e-10), 150 | 'acc': torch.eq(label, pred).float().mean(), 151 | }) 152 | return s 153 | 154 | 155 | def lgan_mmd_cov(all_dist): 156 | N_sample, N_ref = all_dist.size(0), all_dist.size(1) 157 | min_val_fromsmp, min_idx = torch.min(all_dist, dim=1) 158 | min_val, _ = torch.min(all_dist, dim=0) 159 | mmd = min_val.mean() 160 | mmd_smp = min_val_fromsmp.mean() 161 | cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref) 162 | cov = torch.tensor(cov).to(all_dist) 163 | return { 164 | 'lgan_mmd': mmd, 165 | 'lgan_cov': cov, 166 | 'lgan_mmd_smp': mmd_smp, 167 | } 168 | 169 | 170 | def lgan_mmd_cov_match(all_dist): 171 | N_sample, N_ref = all_dist.size(0), all_dist.size(1) 172 | min_val_fromsmp, min_idx = torch.min(all_dist, dim=1) 173 | min_val, _ = torch.min(all_dist, dim=0) 174 | mmd = min_val.mean() 175 | mmd_smp = min_val_fromsmp.mean() 176 | cov = float(min_idx.unique().view(-1).size(0)) / float(N_ref) 177 | cov = torch.tensor(cov).to(all_dist) 178 | return { 179 | 'lgan_mmd': mmd, 180 | 'lgan_cov': cov, 181 | 'lgan_mmd_smp': mmd_smp, 182 | }, min_idx.view(-1) 183 | 184 | 185 | def compute_all_metrics(sample_pcs, ref_pcs, batch_size): 186 | results = {} 187 | 188 | print("Pairwise EMD CD") 189 | M_rs_cd, M_rs_emd = _pairwise_EMD_CD_(ref_pcs, sample_pcs, batch_size) 190 | 191 | ## CD 192 | res_cd = lgan_mmd_cov(M_rs_cd.t()) 193 | results.update({ 194 | "%s-CD" % k: v for k, v in res_cd.items() 195 | }) 196 | 197 | ## EMD 198 | # res_emd = lgan_mmd_cov(M_rs_emd.t()) 199 | # results.update({ 200 | # "%s-EMD" % k: v for k, v in res_emd.items() 201 | # }) 202 | 203 | for k, v in results.items(): 204 | print('[%s] %.8f' % (k, v.item())) 205 | 206 | M_rr_cd, M_rr_emd = _pairwise_EMD_CD_(ref_pcs, ref_pcs, batch_size) 207 | M_ss_cd, M_ss_emd = _pairwise_EMD_CD_(sample_pcs, sample_pcs, batch_size) 208 | 209 | # 1-NN results 210 | ## CD 211 | one_nn_cd_res = knn(M_rr_cd, M_rs_cd, M_ss_cd, 1, sqrt=False) 212 | results.update({ 213 | "1-NN-CD-%s" % k: v for k, v in one_nn_cd_res.items() if 'acc' in k 214 | }) 215 | ## EMD 216 | # one_nn_emd_res = knn(M_rr_emd, M_rs_emd, M_ss_emd, 1, sqrt=False) 217 | # results.update({ 218 | # "1-NN-EMD-%s" % k: v for k, v in one_nn_emd_res.items() if 'acc' in k 219 | # }) 220 | 221 | return results 222 | 223 | 224 | ####################################################### 225 | # JSD : from https://github.com/optas/latent_3d_points 226 | ####################################################### 227 | def unit_cube_grid_point_cloud(resolution, clip_sphere=False): 228 | """Returns the center coordinates of each cell of a 3D grid with 229 | resolution^3 cells, that is placed in the unit-cube. If clip_sphere it True 230 | it drops the "corner" cells that lie outside the unit-sphere. 231 | """ 232 | grid = np.ndarray((resolution, resolution, resolution, 3), np.float32) 233 | spacing = 1.0 / float(resolution - 1) 234 | for i in range(resolution): 235 | for j in range(resolution): 236 | for k in range(resolution): 237 | grid[i, j, k, 0] = i * spacing - 0.5 238 | grid[i, j, k, 1] = j * spacing - 0.5 239 | grid[i, j, k, 2] = k * spacing - 0.5 240 | 241 | if clip_sphere: 242 | grid = grid.reshape(-1, 3) 243 | grid = grid[norm(grid, axis=1) <= 0.5] 244 | 245 | return grid, spacing 246 | 247 | 248 | def jsd_between_point_cloud_sets( 249 | sample_pcs, ref_pcs, resolution=28): 250 | """Computes the JSD between two sets of point-clouds, 251 | as introduced in the paper 252 | ```Learning Representations And Generative Models For 3D Point Clouds```. 253 | Args: 254 | sample_pcs: (np.ndarray S1xR2x3) S1 point-clouds, each of R1 points. 255 | ref_pcs: (np.ndarray S2xR2x3) S2 point-clouds, each of R2 points. 256 | resolution: (int) grid-resolution. Affects granularity of measurements. 257 | """ 258 | in_unit_sphere = True 259 | sample_grid_var = entropy_of_occupancy_grid( 260 | sample_pcs, resolution, in_unit_sphere)[1] 261 | ref_grid_var = entropy_of_occupancy_grid( 262 | ref_pcs, resolution, in_unit_sphere)[1] 263 | return jensen_shannon_divergence(sample_grid_var, ref_grid_var) 264 | 265 | 266 | def entropy_of_occupancy_grid( 267 | pclouds, grid_resolution, in_sphere=False, verbose=False): 268 | """Given a collection of point-clouds, estimate the entropy of 269 | the random variables corresponding to occupancy-grid activation patterns. 270 | Inputs: 271 | pclouds: (numpy array) #point-clouds x points per point-cloud x 3 272 | grid_resolution (int) size of occupancy grid that will be used. 273 | """ 274 | epsilon = 10e-4 275 | bound = 0.5 + epsilon 276 | if abs(np.max(pclouds)) > bound or abs(np.min(pclouds)) > bound: 277 | if verbose: 278 | warnings.warn('Point-clouds are not in unit cube.') 279 | 280 | if in_sphere and np.max(np.sqrt(np.sum(pclouds ** 2, axis=2))) > bound: 281 | if verbose: 282 | warnings.warn('Point-clouds are not in unit sphere.') 283 | 284 | grid_coordinates, _ = unit_cube_grid_point_cloud(grid_resolution, in_sphere) 285 | grid_coordinates = grid_coordinates.reshape(-1, 3) 286 | grid_counters = np.zeros(len(grid_coordinates)) 287 | grid_bernoulli_rvars = np.zeros(len(grid_coordinates)) 288 | nn = NearestNeighbors(n_neighbors=1).fit(grid_coordinates) 289 | 290 | for pc in tqdm(pclouds, desc='JSD'): 291 | _, indices = nn.kneighbors(pc) 292 | indices = np.squeeze(indices) 293 | for i in indices: 294 | grid_counters[i] += 1 295 | indices = np.unique(indices) 296 | for i in indices: 297 | grid_bernoulli_rvars[i] += 1 298 | 299 | acc_entropy = 0.0 300 | n = float(len(pclouds)) 301 | for g in grid_bernoulli_rvars: 302 | if g > 0: 303 | p = float(g) / n 304 | acc_entropy += entropy([p, 1.0 - p]) 305 | 306 | return acc_entropy / len(grid_counters), grid_counters 307 | 308 | 309 | def jensen_shannon_divergence(P, Q): 310 | if np.any(P < 0) or np.any(Q < 0): 311 | raise ValueError('Negative values.') 312 | if len(P) != len(Q): 313 | raise ValueError('Non equal size.') 314 | 315 | P_ = P / np.sum(P) # Ensure probabilities. 316 | Q_ = Q / np.sum(Q) 317 | 318 | e1 = entropy(P_, base=2) 319 | e2 = entropy(Q_, base=2) 320 | e_sum = entropy((P_ + Q_) / 2.0, base=2) 321 | res = e_sum - ((e1 + e2) / 2.0) 322 | 323 | res2 = _jsdiv(P_, Q_) 324 | 325 | if not np.allclose(res, res2, atol=10e-5, rtol=0): 326 | warnings.warn('Numerical values of two JSD methods don\'t agree.') 327 | 328 | return res 329 | 330 | 331 | def _jsdiv(P, Q): 332 | """another way of computing JSD""" 333 | 334 | def _kldiv(A, B): 335 | a = A.copy() 336 | b = B.copy() 337 | idx = np.logical_and(a > 0, b > 0) 338 | a = a[idx] 339 | b = b[idx] 340 | return np.sum([v for v in a * np.log2(a / b)]) 341 | 342 | P_ = P / np.sum(P) 343 | Q_ = Q / np.sum(Q) 344 | 345 | M = 0.5 * (P_ + Q_) 346 | 347 | return 0.5 * (_kldiv(P_, M) + _kldiv(Q_, M)) 348 | 349 | 350 | if __name__ == '__main__': 351 | a = torch.randn([16, 2048, 3]).cuda() 352 | b = torch.randn([16, 2048, 3]).cuda() 353 | print(EMD_CD(a, b, batch_size=8)) 354 | -------------------------------------------------------------------------------- /models/autoencoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module 3 | 4 | from .encoders import * 5 | from .diffusion import * 6 | 7 | 8 | class AutoEncoder(Module): 9 | 10 | def __init__(self, args): 11 | super().__init__() 12 | self.args = args 13 | self.encoder = PointNetEncoder(zdim=args.latent_dim) 14 | self.diffusion = DiffusionPoint( 15 | net = PointwiseNet(point_dim=3, context_dim=args.latent_dim, residual=args.residual), 16 | var_sched = VarianceSchedule( 17 | num_steps=args.num_steps, 18 | beta_1=args.beta_1, 19 | beta_T=args.beta_T, 20 | mode=args.sched_mode 21 | ) 22 | ) 23 | 24 | def encode(self, x): 25 | """ 26 | Args: 27 | x: Point clouds to be encoded, (B, N, d). 28 | """ 29 | code, _ = self.encoder(x) 30 | return code 31 | 32 | def decode(self, code, num_points, flexibility=0.0, ret_traj=False): 33 | return self.diffusion.sample(num_points, code, flexibility=flexibility, ret_traj=ret_traj) 34 | 35 | def get_loss(self, x): 36 | code = self.encode(x) 37 | loss = self.diffusion.get_loss(x, code) 38 | return loss 39 | -------------------------------------------------------------------------------- /models/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module, Linear 3 | from torch.optim.lr_scheduler import LambdaLR 4 | import numpy as np 5 | 6 | def reparameterize_gaussian(mean, logvar): 7 | std = torch.exp(0.5 * logvar) 8 | eps = torch.randn(std.size()).to(mean) 9 | return mean + std * eps 10 | 11 | 12 | def gaussian_entropy(logvar): 13 | const = 0.5 * float(logvar.size(1)) * (1. + np.log(np.pi * 2)) 14 | ent = 0.5 * logvar.sum(dim=1, keepdim=False) + const 15 | return ent 16 | 17 | 18 | def standard_normal_logprob(z): 19 | dim = z.size(-1) 20 | log_z = -0.5 * dim * np.log(2 * np.pi) 21 | return log_z - z.pow(2) / 2 22 | 23 | 24 | def truncated_normal_(tensor, mean=0, std=1, trunc_std=2): 25 | """ 26 | Taken from https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15 27 | """ 28 | size = tensor.shape 29 | tmp = tensor.new_empty(size + (4,)).normal_() 30 | valid = (tmp < trunc_std) & (tmp > -trunc_std) 31 | ind = valid.max(-1, keepdim=True)[1] 32 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 33 | tensor.data.mul_(std).add_(mean) 34 | return tensor 35 | 36 | 37 | class ConcatSquashLinear(Module): 38 | def __init__(self, dim_in, dim_out, dim_ctx): 39 | super(ConcatSquashLinear, self).__init__() 40 | self._layer = Linear(dim_in, dim_out) 41 | self._hyper_bias = Linear(dim_ctx, dim_out, bias=False) 42 | self._hyper_gate = Linear(dim_ctx, dim_out) 43 | 44 | def forward(self, ctx, x): 45 | gate = torch.sigmoid(self._hyper_gate(ctx)) 46 | bias = self._hyper_bias(ctx) 47 | # if x.dim() == 3: 48 | # gate = gate.unsqueeze(1) 49 | # bias = bias.unsqueeze(1) 50 | ret = self._layer(x) * gate + bias 51 | return ret 52 | 53 | 54 | def get_linear_scheduler(optimizer, start_epoch, end_epoch, start_lr, end_lr): 55 | def lr_func(epoch): 56 | if epoch <= start_epoch: 57 | return 1.0 58 | elif epoch <= end_epoch: 59 | total = end_epoch - start_epoch 60 | delta = epoch - start_epoch 61 | frac = delta / total 62 | return (1-frac) * 1.0 + frac * (end_lr / start_lr) 63 | else: 64 | return end_lr / start_lr 65 | return LambdaLR(optimizer, lr_lambda=lr_func) 66 | 67 | def lr_func(epoch): 68 | if epoch <= start_epoch: 69 | return 1.0 70 | elif epoch <= end_epoch: 71 | total = end_epoch - start_epoch 72 | delta = epoch - start_epoch 73 | frac = delta / total 74 | return (1-frac) * 1.0 + frac * (end_lr / start_lr) 75 | else: 76 | return end_lr / start_lr 77 | -------------------------------------------------------------------------------- /models/diffusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.nn import Module, Parameter, ModuleList 4 | import numpy as np 5 | 6 | from .common import * 7 | 8 | 9 | class VarianceSchedule(Module): 10 | 11 | def __init__(self, num_steps, beta_1, beta_T, mode='linear'): 12 | super().__init__() 13 | assert mode in ('linear', ) 14 | self.num_steps = num_steps 15 | self.beta_1 = beta_1 16 | self.beta_T = beta_T 17 | self.mode = mode 18 | 19 | if mode == 'linear': 20 | betas = torch.linspace(beta_1, beta_T, steps=num_steps) 21 | 22 | betas = torch.cat([torch.zeros([1]), betas], dim=0) # Padding 23 | 24 | alphas = 1 - betas 25 | log_alphas = torch.log(alphas) 26 | for i in range(1, log_alphas.size(0)): # 1 to T 27 | log_alphas[i] += log_alphas[i - 1] 28 | alpha_bars = log_alphas.exp() 29 | 30 | sigmas_flex = torch.sqrt(betas) 31 | sigmas_inflex = torch.zeros_like(sigmas_flex) 32 | for i in range(1, sigmas_flex.size(0)): 33 | sigmas_inflex[i] = ((1 - alpha_bars[i-1]) / (1 - alpha_bars[i])) * betas[i] 34 | sigmas_inflex = torch.sqrt(sigmas_inflex) 35 | 36 | self.register_buffer('betas', betas) 37 | self.register_buffer('alphas', alphas) 38 | self.register_buffer('alpha_bars', alpha_bars) 39 | self.register_buffer('sigmas_flex', sigmas_flex) 40 | self.register_buffer('sigmas_inflex', sigmas_inflex) 41 | 42 | def uniform_sample_t(self, batch_size): 43 | ts = np.random.choice(np.arange(1, self.num_steps+1), batch_size) 44 | return ts.tolist() 45 | 46 | def get_sigmas(self, t, flexibility): 47 | assert 0 <= flexibility and flexibility <= 1 48 | sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (1 - flexibility) 49 | return sigmas 50 | 51 | 52 | class PointwiseNet(Module): 53 | 54 | def __init__(self, point_dim, context_dim, residual): 55 | super().__init__() 56 | self.act = F.leaky_relu 57 | self.residual = residual 58 | self.layers = ModuleList([ 59 | ConcatSquashLinear(3, 128, context_dim+3), 60 | ConcatSquashLinear(128, 256, context_dim+3), 61 | ConcatSquashLinear(256, 512, context_dim+3), 62 | ConcatSquashLinear(512, 256, context_dim+3), 63 | ConcatSquashLinear(256, 128, context_dim+3), 64 | ConcatSquashLinear(128, 3, context_dim+3) 65 | ]) 66 | 67 | def forward(self, x, beta, context): 68 | """ 69 | Args: 70 | x: Point clouds at some timestep t, (B, N, d). 71 | beta: Time. (B, ). 72 | context: Shape latents. (B, F). 73 | """ 74 | batch_size = x.size(0) 75 | beta = beta.view(batch_size, 1, 1) # (B, 1, 1) 76 | context = context.view(batch_size, 1, -1) # (B, 1, F) 77 | 78 | time_emb = torch.cat([beta, torch.sin(beta), torch.cos(beta)], dim=-1) # (B, 1, 3) 79 | ctx_emb = torch.cat([time_emb, context], dim=-1) # (B, 1, F+3) 80 | 81 | out = x 82 | for i, layer in enumerate(self.layers): 83 | out = layer(ctx=ctx_emb, x=out) 84 | if i < len(self.layers) - 1: 85 | out = self.act(out) 86 | 87 | if self.residual: 88 | return x + out 89 | else: 90 | return out 91 | 92 | 93 | class DiffusionPoint(Module): 94 | 95 | def __init__(self, net, var_sched:VarianceSchedule): 96 | super().__init__() 97 | self.net = net 98 | self.var_sched = var_sched 99 | 100 | def get_loss(self, x_0, context, t=None): 101 | """ 102 | Args: 103 | x_0: Input point cloud, (B, N, d). 104 | context: Shape latent, (B, F). 105 | """ 106 | batch_size, _, point_dim = x_0.size() 107 | if t == None: 108 | t = self.var_sched.uniform_sample_t(batch_size) 109 | alpha_bar = self.var_sched.alpha_bars[t] 110 | beta = self.var_sched.betas[t] 111 | 112 | c0 = torch.sqrt(alpha_bar).view(-1, 1, 1) # (B, 1, 1) 113 | c1 = torch.sqrt(1 - alpha_bar).view(-1, 1, 1) # (B, 1, 1) 114 | 115 | e_rand = torch.randn_like(x_0) # (B, N, d) 116 | e_theta = self.net(c0 * x_0 + c1 * e_rand, beta=beta, context=context) 117 | 118 | loss = F.mse_loss(e_theta.view(-1, point_dim), e_rand.view(-1, point_dim), reduction='mean') 119 | return loss 120 | 121 | def sample(self, num_points, context, point_dim=3, flexibility=0.0, ret_traj=False): 122 | batch_size = context.size(0) 123 | x_T = torch.randn([batch_size, num_points, point_dim]).to(context.device) 124 | traj = {self.var_sched.num_steps: x_T} 125 | for t in range(self.var_sched.num_steps, 0, -1): 126 | z = torch.randn_like(x_T) if t > 1 else torch.zeros_like(x_T) 127 | alpha = self.var_sched.alphas[t] 128 | alpha_bar = self.var_sched.alpha_bars[t] 129 | sigma = self.var_sched.get_sigmas(t, flexibility) 130 | 131 | c0 = 1.0 / torch.sqrt(alpha) 132 | c1 = (1 - alpha) / torch.sqrt(1 - alpha_bar) 133 | 134 | x_t = traj[t] 135 | beta = self.var_sched.betas[[t]*batch_size] 136 | e_theta = self.net(x_t, beta=beta, context=context) 137 | x_next = c0 * (x_t - c1 * e_theta) + sigma * z 138 | traj[t-1] = x_next.detach() # Stop gradient and save trajectory. 139 | traj[t] = traj[t].cpu() # Move previous output to CPU memory. 140 | if not ret_traj: 141 | del traj[t] 142 | 143 | if ret_traj: 144 | return traj 145 | else: 146 | return traj[0] 147 | 148 | -------------------------------------------------------------------------------- /models/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from .pointnet import * 2 | -------------------------------------------------------------------------------- /models/encoders/pointcnn.py: -------------------------------------------------------------------------------- 1 | from math import ceil 2 | from typing import Optional 3 | 4 | import torch 5 | from torch import Tensor 6 | from torch.nn import functional as F 7 | from torch.nn import ELU 8 | from torch.nn import BatchNorm1d as BN 9 | from torch.nn import Conv1d 10 | from torch.nn import Linear as L 11 | from torch.nn import Sequential as S 12 | 13 | 14 | class Reshape(torch.nn.Module): 15 | def __init__(self, *shape): 16 | super().__init__() 17 | self.shape = shape 18 | 19 | def forward(self, x: Tensor) -> Tensor: 20 | """""" # noqa: D419 21 | x = x.view(*self.shape) 22 | return x 23 | 24 | def __repr__(self) -> str: 25 | shape = ', '.join([str(dim) for dim in self.shape]) 26 | return f'{self.__class__.__name__}({shape})' 27 | 28 | 29 | class XConv(torch.nn.Module): 30 | 31 | def __init__(self, in_channels: int, out_channels: int, dim: int = 3, 32 | kernel_size: int = 32, hidden_channels: Optional[int] = None, 33 | dilation: int = 1, bias: bool = True, num_workers: int = 1): 34 | super().__init__() 35 | 36 | self.in_channels = in_channels 37 | if hidden_channels is None: 38 | hidden_channels = in_channels // 4 39 | assert hidden_channels > 0 40 | self.hidden_channels = hidden_channels 41 | self.out_channels = out_channels 42 | self.dim = dim 43 | self.kernel_size = kernel_size 44 | self.dilation = dilation 45 | self.num_workers = num_workers 46 | 47 | C_in, C_delta, C_out = in_channels, hidden_channels, out_channels 48 | D, K = dim, kernel_size 49 | 50 | self.mlp1 = S( 51 | L(dim, C_delta), 52 | ELU(), 53 | BN(C_delta), 54 | L(C_delta, C_delta), 55 | ELU(), 56 | BN(C_delta), 57 | Reshape(-1, K, C_delta), 58 | ) 59 | 60 | self.mlp2 = S( 61 | L(D * K, K**2), 62 | ELU(), 63 | BN(K**2), 64 | Reshape(-1, K, K), 65 | Conv1d(K, K**2, K, groups=K), 66 | ELU(), 67 | BN(K**2), 68 | Reshape(-1, K, K), 69 | Conv1d(K, K**2, K, groups=K), 70 | BN(K**2), 71 | Reshape(-1, K, K), 72 | ) 73 | 74 | C_in = C_in + C_delta 75 | depth_multiplier = int(ceil(C_out / C_in)) 76 | self.conv = S( 77 | Conv1d(C_in, C_in * depth_multiplier, K, groups=C_in), 78 | Reshape(-1, C_in * depth_multiplier), 79 | L(C_in * depth_multiplier, C_out, bias=bias), 80 | ) 81 | 82 | def forward(self, x: Optional[Tensor], pos: Tensor): 83 | """ 84 | Args: 85 | x: (B, N, F_in): Node features. 86 | pos: (B, N, D): Node positions. 87 | Returns: 88 | (B, N, F_out): Transformed node features. 89 | """ 90 | B, N, D = pos.size() 91 | 92 | # (1) Compute the k-NN graph using topk. 93 | pw_dist = torch.cdist(pos, pos) # (B, N, N) 94 | knn_idx = torch.topk(pw_dist, self.kernel_size * self.dilation, largest=False)[1] # (B, N, K * dilation) 95 | 96 | if self.dilation > 1: 97 | knn_idx = knn_idx[:, :, ::self.dilation] # (B, N, K) 98 | 99 | pos_i = pos[:, :, None, :] 100 | pos_j = pos[:, None, :, :].repeat(1, N, 1, 1) # (B, N, N, D) 101 | pos_j = torch.gather(pos_j, 2, knn_idx[:, :, :, None].expand(-1, -1, -1, D)) 102 | d = pos_i - pos_j # (B, N, K, D) 103 | 104 | x_star = self.mlp1(d.reshape(B * N * self.kernel_size, -1)).reshape(B, N, self.kernel_size, -1) # (B, N, K, C_delta) 105 | if x is not None: 106 | x = x[:, None, :, :].repeat(1, N, 1, 1) # (B, N, 1, F_in) 107 | x = torch.gather(x, 2, knn_idx[:, :, :, None].expand(-1, -1, -1, self.in_channels)) 108 | x_star = torch.cat([x_star, x], dim=-1) # (B, N, K, C_delta + F_in) 109 | 110 | x_star = x_star.reshape(B * N, self.kernel_size, -1) 111 | x_star = x_star.transpose(1, 2).contiguous() # (B * N, C_delta + C_in, K) 112 | 113 | transform_matrix = self.mlp2(d.reshape(B * N, self.kernel_size * D)) # (B * N, K, K) 114 | 115 | x_transformed = torch.matmul(x_star, transform_matrix) # (B * N, C_delta + C_in, K) 116 | 117 | out = self.conv(x_transformed) # (B * N, C_out) 118 | out = out.view(B, N, -1).contiguous() # (B, N, C_out) 119 | return out 120 | 121 | 122 | class PointCNNEncoder(torch.nn.Module): 123 | def __init__(self, zdim, input_dim=3): 124 | super().__init__() 125 | assert input_dim == 3, "Only 3D point clouds are supported." 126 | self.conv1 = XConv(0, 48, dim=3, kernel_size=8, hidden_channels=32) 127 | self.conv2 = XConv(48, 96, dim=3, kernel_size=12, hidden_channels=64, 128 | dilation=2) 129 | self.conv3 = XConv(96, 192, dim=3, kernel_size=16, hidden_channels=128, 130 | dilation=2) 131 | self.conv4 = XConv(192, 384, dim=3, kernel_size=16, 132 | hidden_channels=256, dilation=2) 133 | 134 | self.lin1 = L(384, 256) 135 | self.lin2 = L(256, 128) 136 | self.lin3 = L(128, zdim * 2) 137 | 138 | def forward(self, pos): 139 | x = F.relu(self.conv1(None, pos)) 140 | 141 | randidx = torch.randperm(x.size(1))[:x.size(1) // 2] 142 | x = x[:, randidx, :] 143 | pos = pos[:, randidx, :] 144 | 145 | x = F.relu(self.conv2(x, pos)) 146 | 147 | randidx = torch.randperm(x.size(1))[:x.size(1) // 2] 148 | x = x[:, randidx, :] 149 | pos = pos[:, randidx, :] 150 | 151 | x = F.relu(self.conv3(x, pos)) 152 | x = F.relu(self.conv4(x, pos)) 153 | 154 | x = x.mean(dim=1) # (B, C) 155 | 156 | x = F.relu(self.lin1(x)) 157 | x = F.relu(self.lin2(x)) 158 | x = F.dropout(x, p=0.5, training=self.training) 159 | x = self.lin3(x) 160 | m, v = x.chunk(2, dim=-1) 161 | return m, v 162 | 163 | 164 | if __name__ == '__main__': 165 | pc = torch.randn(2, 1024, 3) 166 | net = PointCNNEncoder(128) 167 | out = net(pc) 168 | print(out) 169 | -------------------------------------------------------------------------------- /models/encoders/pointnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class PointNetEncoder(nn.Module): 7 | def __init__(self, zdim, input_dim=3): 8 | super().__init__() 9 | self.zdim = zdim 10 | self.conv1 = nn.Conv1d(input_dim, 128, 1) 11 | self.conv2 = nn.Conv1d(128, 128, 1) 12 | self.conv3 = nn.Conv1d(128, 256, 1) 13 | self.conv4 = nn.Conv1d(256, 512, 1) 14 | self.bn1 = nn.BatchNorm1d(128) 15 | self.bn2 = nn.BatchNorm1d(128) 16 | self.bn3 = nn.BatchNorm1d(256) 17 | self.bn4 = nn.BatchNorm1d(512) 18 | 19 | # Mapping to [c], cmean 20 | self.fc1_m = nn.Linear(512, 256) 21 | self.fc2_m = nn.Linear(256, 128) 22 | self.fc3_m = nn.Linear(128, zdim) 23 | self.fc_bn1_m = nn.BatchNorm1d(256) 24 | self.fc_bn2_m = nn.BatchNorm1d(128) 25 | 26 | # Mapping to [c], cmean 27 | self.fc1_v = nn.Linear(512, 256) 28 | self.fc2_v = nn.Linear(256, 128) 29 | self.fc3_v = nn.Linear(128, zdim) 30 | self.fc_bn1_v = nn.BatchNorm1d(256) 31 | self.fc_bn2_v = nn.BatchNorm1d(128) 32 | 33 | def forward(self, x): 34 | x = x.transpose(1, 2) 35 | x = F.relu(self.bn1(self.conv1(x))) 36 | x = F.relu(self.bn2(self.conv2(x))) 37 | x = F.relu(self.bn3(self.conv3(x))) 38 | x = self.bn4(self.conv4(x)) 39 | x = torch.max(x, 2, keepdim=True)[0] 40 | x = x.view(-1, 512) 41 | 42 | m = F.relu(self.fc_bn1_m(self.fc1_m(x))) 43 | m = F.relu(self.fc_bn2_m(self.fc2_m(m))) 44 | m = self.fc3_m(m) 45 | v = F.relu(self.fc_bn1_v(self.fc1_v(x))) 46 | v = F.relu(self.fc_bn2_v(self.fc2_v(v))) 47 | v = self.fc3_v(v) 48 | 49 | # Returns both mean and logvariance, just ignore the latter in deteministic cases. 50 | return m, v 51 | 52 | -------------------------------------------------------------------------------- /models/flow.py: -------------------------------------------------------------------------------- 1 | import types 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class CouplingLayer(nn.Module): 8 | 9 | def __init__(self, d, intermediate_dim, swap=False): 10 | nn.Module.__init__(self) 11 | self.d = d - (d // 2) 12 | self.swap = swap 13 | self.net_s_t = nn.Sequential( 14 | nn.Linear(self.d, intermediate_dim), 15 | nn.ReLU(inplace=True), 16 | nn.Linear(intermediate_dim, intermediate_dim), 17 | nn.ReLU(inplace=True), 18 | nn.Linear(intermediate_dim, (d - self.d) * 2), 19 | ) 20 | 21 | def forward(self, x, logpx=None, reverse=False): 22 | 23 | if self.swap: 24 | x = torch.cat([x[:, self.d:], x[:, :self.d]], 1) 25 | 26 | in_dim = self.d 27 | out_dim = x.shape[1] - self.d 28 | 29 | s_t = self.net_s_t(x[:, :in_dim]) 30 | scale = torch.sigmoid(s_t[:, :out_dim] + 2.) 31 | shift = s_t[:, out_dim:] 32 | 33 | logdetjac = torch.sum(torch.log(scale).view(scale.shape[0], -1), 1, keepdim=True) 34 | 35 | if not reverse: 36 | y1 = x[:, self.d:] * scale + shift 37 | delta_logp = -logdetjac 38 | else: 39 | y1 = (x[:, self.d:] - shift) / scale 40 | delta_logp = logdetjac 41 | 42 | y = torch.cat([x[:, :self.d], y1], 1) if not self.swap else torch.cat([y1, x[:, :self.d]], 1) 43 | 44 | if logpx is None: 45 | return y 46 | else: 47 | return y, logpx + delta_logp 48 | 49 | 50 | class SequentialFlow(nn.Module): 51 | """A generalized nn.Sequential container for normalizing flows. 52 | """ 53 | 54 | def __init__(self, layersList): 55 | super(SequentialFlow, self).__init__() 56 | self.chain = nn.ModuleList(layersList) 57 | 58 | def forward(self, x, logpx=None, reverse=False, inds=None): 59 | if inds is None: 60 | if reverse: 61 | inds = range(len(self.chain) - 1, -1, -1) 62 | else: 63 | inds = range(len(self.chain)) 64 | 65 | if logpx is None: 66 | for i in inds: 67 | x = self.chain[i](x, reverse=reverse) 68 | return x 69 | else: 70 | for i in inds: 71 | x, logpx = self.chain[i](x, logpx, reverse=reverse) 72 | return x, logpx 73 | 74 | 75 | def build_latent_flow(args): 76 | chain = [] 77 | for i in range(args.latent_flow_depth): 78 | chain.append(CouplingLayer(args.latent_dim, args.latent_flow_hidden_dim, swap=(i % 2 == 0))) 79 | return SequentialFlow(chain) 80 | 81 | 82 | ################## 83 | ## SpectralNorm ## 84 | ################## 85 | 86 | POWER_ITERATION_FN = "spectral_norm_power_iteration" 87 | 88 | 89 | class SpectralNorm(object): 90 | def __init__(self, name='weight', dim=0, eps=1e-12): 91 | self.name = name 92 | self.dim = dim 93 | self.eps = eps 94 | 95 | def compute_weight(self, module, n_power_iterations): 96 | if n_power_iterations < 0: 97 | raise ValueError( 98 | 'Expected n_power_iterations to be non-negative, but ' 99 | 'got n_power_iterations={}'.format(n_power_iterations) 100 | ) 101 | 102 | weight = getattr(module, self.name + '_orig') 103 | u = getattr(module, self.name + '_u') 104 | v = getattr(module, self.name + '_v') 105 | weight_mat = weight 106 | if self.dim != 0: 107 | # permute dim to front 108 | weight_mat = weight_mat.permute(self.dim, * [d for d in range(weight_mat.dim()) if d != self.dim]) 109 | height = weight_mat.size(0) 110 | weight_mat = weight_mat.reshape(height, -1) 111 | with torch.no_grad(): 112 | for _ in range(n_power_iterations): 113 | # Spectral norm of weight equals to `u^T W v`, where `u` and `v` 114 | # are the first left and right singular vectors. 115 | # This power iteration produces approximations of `u` and `v`. 116 | v = F.normalize(torch.matmul(weight_mat.t(), u), dim=0, eps=self.eps) 117 | u = F.normalize(torch.matmul(weight_mat, v), dim=0, eps=self.eps) 118 | setattr(module, self.name + '_u', u) 119 | setattr(module, self.name + '_v', v) 120 | 121 | sigma = torch.dot(u, torch.matmul(weight_mat, v)) 122 | weight = weight / sigma 123 | setattr(module, self.name, weight) 124 | 125 | def remove(self, module): 126 | weight = getattr(module, self.name) 127 | delattr(module, self.name) 128 | delattr(module, self.name + '_u') 129 | delattr(module, self.name + '_orig') 130 | module.register_parameter(self.name, torch.nn.Parameter(weight)) 131 | 132 | def get_update_method(self, module): 133 | def update_fn(module, n_power_iterations): 134 | self.compute_weight(module, n_power_iterations) 135 | 136 | return update_fn 137 | 138 | def __call__(self, module, unused_inputs): 139 | del unused_inputs 140 | self.compute_weight(module, n_power_iterations=0) 141 | 142 | # requires_grad might be either True or False during inference. 143 | if not module.training: 144 | r_g = getattr(module, self.name + '_orig').requires_grad 145 | setattr(module, self.name, getattr(module, self.name).detach().requires_grad_(r_g)) 146 | 147 | @staticmethod 148 | def apply(module, name, dim, eps): 149 | fn = SpectralNorm(name, dim, eps) 150 | weight = module._parameters[name] 151 | height = weight.size(dim) 152 | 153 | u = F.normalize(weight.new_empty(height).normal_(0, 1), dim=0, eps=fn.eps) 154 | v = F.normalize(weight.new_empty(int(weight.numel() / height)).normal_(0, 1), dim=0, eps=fn.eps) 155 | delattr(module, fn.name) 156 | module.register_parameter(fn.name + "_orig", weight) 157 | # We still need to assign weight back as fn.name because all sorts of 158 | # things may assume that it exists, e.g., when initializing weights. 159 | # However, we can't directly assign as it could be an nn.Parameter and 160 | # gets added as a parameter. Instead, we register weight.data as a 161 | # buffer, which will cause weight to be included in the state dict 162 | # and also supports nn.init due to shared storage. 163 | module.register_buffer(fn.name, weight.data) 164 | module.register_buffer(fn.name + "_u", u) 165 | module.register_buffer(fn.name + "_v", v) 166 | 167 | setattr(module, POWER_ITERATION_FN, types.MethodType(fn.get_update_method(module), module)) 168 | 169 | module.register_forward_pre_hook(fn) 170 | return fn 171 | 172 | 173 | def inplace_spectral_norm(module, name='weight', dim=None, eps=1e-12): 174 | r"""Applies spectral normalization to a parameter in the given module. 175 | .. math:: 176 | \mathbf{W} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})} \\ 177 | \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} 178 | Spectral normalization stabilizes the training of discriminators (critics) 179 | in Generaive Adversarial Networks (GANs) by rescaling the weight tensor 180 | with spectral norm :math:`\sigma` of the weight matrix calculated using 181 | power iteration method. If the dimension of the weight tensor is greater 182 | than 2, it is reshaped to 2D in power iteration method to get spectral 183 | norm. This is implemented via a hook that calculates spectral norm and 184 | rescales weight before every :meth:`~Module.forward` call. 185 | See `Spectral Normalization for Generative Adversarial Networks`_ . 186 | .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 187 | Args: 188 | module (nn.Module): containing module 189 | name (str, optional): name of weight parameter 190 | n_power_iterations (int, optional): number of power iterations to 191 | calculate spectal norm 192 | dim (int, optional): dimension corresponding to number of outputs, 193 | the default is 0, except for modules that are instances of 194 | ConvTranspose1/2/3d, when it is 1 195 | eps (float, optional): epsilon for numerical stability in 196 | calculating norms 197 | Returns: 198 | The original module with the spectal norm hook 199 | Example:: 200 | >>> m = spectral_norm(nn.Linear(20, 40)) 201 | Linear (20 -> 40) 202 | >>> m.weight_u.size() 203 | torch.Size([20]) 204 | """ 205 | if dim is None: 206 | if isinstance(module, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)): 207 | dim = 1 208 | else: 209 | dim = 0 210 | SpectralNorm.apply(module, name, dim=dim, eps=eps) 211 | return module 212 | 213 | 214 | def remove_spectral_norm(module, name='weight'): 215 | r"""Removes the spectral normalization reparameterization from a module. 216 | Args: 217 | module (nn.Module): containing module 218 | name (str, optional): name of weight parameter 219 | Example: 220 | >>> m = spectral_norm(nn.Linear(40, 10)) 221 | >>> remove_spectral_norm(m) 222 | """ 223 | for k, hook in module._forward_pre_hooks.items(): 224 | if isinstance(hook, SpectralNorm) and hook.name == name: 225 | hook.remove(module) 226 | del module._forward_pre_hooks[k] 227 | return module 228 | 229 | raise ValueError("spectral_norm of '{}' not found in {}".format(name, module)) 230 | 231 | 232 | def add_spectral_norm(model, logger=None): 233 | """Applies spectral norm to all modules within the scope of a CNF.""" 234 | 235 | def apply_spectral_norm(module): 236 | if 'weight' in module._parameters: 237 | if logger: logger.info("Adding spectral norm to {}".format(module)) 238 | inplace_spectral_norm(module, 'weight') 239 | 240 | def find_coupling_layer(module): 241 | if isinstance(module, CouplingLayer): 242 | module.apply(apply_spectral_norm) 243 | else: 244 | for child in module.children(): 245 | find_coupling_layer(child) 246 | 247 | find_coupling_layer(model) 248 | 249 | 250 | def spectral_norm_power_iteration(model, n_power_iterations=1): 251 | 252 | def recursive_power_iteration(module): 253 | if hasattr(module, POWER_ITERATION_FN): 254 | getattr(module, POWER_ITERATION_FN)(n_power_iterations) 255 | 256 | model.apply(recursive_power_iteration) 257 | 258 | -------------------------------------------------------------------------------- /models/vae_flow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module 3 | 4 | from .common import * 5 | from .encoders import * 6 | from .diffusion import * 7 | from .flow import * 8 | 9 | 10 | class FlowVAE(Module): 11 | 12 | def __init__(self, args): 13 | super().__init__() 14 | self.args = args 15 | self.encoder = PointNetEncoder(args.latent_dim) 16 | self.flow = build_latent_flow(args) 17 | self.diffusion = DiffusionPoint( 18 | net = PointwiseNet(point_dim=3, context_dim=args.latent_dim, residual=args.residual), 19 | var_sched = VarianceSchedule( 20 | num_steps=args.num_steps, 21 | beta_1=args.beta_1, 22 | beta_T=args.beta_T, 23 | mode=args.sched_mode 24 | ) 25 | ) 26 | 27 | def get_loss(self, x, kl_weight, writer=None, it=None): 28 | """ 29 | Args: 30 | x: Input point clouds, (B, N, d). 31 | """ 32 | batch_size, _, _ = x.size() 33 | # print(x.size()) 34 | z_mu, z_sigma = self.encoder(x) 35 | z = reparameterize_gaussian(mean=z_mu, logvar=z_sigma) # (B, F) 36 | 37 | # H[Q(z|X)] 38 | entropy = gaussian_entropy(logvar=z_sigma) # (B, ) 39 | 40 | # P(z), Prior probability, parameterized by the flow: z -> w. 41 | w, delta_log_pw = self.flow(z, torch.zeros([batch_size, 1]).to(z), reverse=False) 42 | log_pw = standard_normal_logprob(w).view(batch_size, -1).sum(dim=1, keepdim=True) # (B, 1) 43 | log_pz = log_pw - delta_log_pw.view(batch_size, 1) # (B, 1) 44 | 45 | # Negative ELBO of P(X|z) 46 | neg_elbo = self.diffusion.get_loss(x, z) 47 | 48 | # Loss 49 | loss_entropy = -entropy.mean() 50 | loss_prior = -log_pz.mean() 51 | loss_recons = neg_elbo 52 | loss = kl_weight*(loss_entropy + loss_prior) + neg_elbo 53 | 54 | if writer is not None: 55 | writer.add_scalar('train/loss_entropy', loss_entropy, it) 56 | writer.add_scalar('train/loss_prior', loss_prior, it) 57 | writer.add_scalar('train/loss_recons', loss_recons, it) 58 | writer.add_scalar('train/z_mean', z_mu.mean(), it) 59 | writer.add_scalar('train/z_mag', z_mu.abs().max(), it) 60 | writer.add_scalar('train/z_var', (0.5*z_sigma).exp().mean(), it) 61 | 62 | return loss 63 | 64 | def sample(self, w, num_points, flexibility, truncate_std=None): 65 | batch_size, _ = w.size() 66 | if truncate_std is not None: 67 | w = truncated_normal_(w, mean=0, std=1, trunc_std=truncate_std) 68 | # Reverse: z <- w. 69 | z = self.flow(w, reverse=True).view(batch_size, -1) 70 | samples = self.diffusion.sample(num_points, context=z, flexibility=flexibility) 71 | return samples 72 | -------------------------------------------------------------------------------- /models/vae_gaussian.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import Module 3 | 4 | from .common import * 5 | from .encoders import * 6 | from .diffusion import * 7 | 8 | 9 | class GaussianVAE(Module): 10 | 11 | def __init__(self, args): 12 | super().__init__() 13 | self.args = args 14 | self.encoder = PointNetEncoder(args.latent_dim) 15 | self.diffusion = DiffusionPoint( 16 | net = PointwiseNet(point_dim=3, context_dim=args.latent_dim, residual=args.residual), 17 | var_sched = VarianceSchedule( 18 | num_steps=args.num_steps, 19 | beta_1=args.beta_1, 20 | beta_T=args.beta_T, 21 | mode=args.sched_mode 22 | ) 23 | ) 24 | 25 | def get_loss(self, x, writer=None, it=None, kl_weight=1.0): 26 | """ 27 | Args: 28 | x: Input point clouds, (B, N, d). 29 | """ 30 | batch_size, _, _ = x.size() 31 | z_mu, z_sigma = self.encoder(x) 32 | z = reparameterize_gaussian(mean=z_mu, logvar=z_sigma) # (B, F) 33 | log_pz = standard_normal_logprob(z).sum(dim=1) # (B, ), Independence assumption 34 | entropy = gaussian_entropy(logvar=z_sigma) # (B, ) 35 | loss_prior = (- log_pz - entropy).mean() 36 | 37 | loss_recons = self.diffusion.get_loss(x, z) 38 | 39 | loss = kl_weight * loss_prior + loss_recons 40 | 41 | if writer is not None: 42 | writer.add_scalar('train/loss_entropy', -entropy.mean(), it) 43 | writer.add_scalar('train/loss_prior', -log_pz.mean(), it) 44 | writer.add_scalar('train/loss_recons', loss_recons, it) 45 | 46 | return loss 47 | 48 | def sample(self, z, num_points, flexibility, truncate_std=None): 49 | """ 50 | Args: 51 | z: Input latent, normal random samples with mean=0 std=1, (B, F) 52 | """ 53 | if truncate_std is not None: 54 | z = truncated_normal_(z, mean=0, std=1, trunc_std=truncate_std) 55 | samples = self.diffusion.sample(num_points, context=z, flexibility=flexibility) 56 | return samples 57 | -------------------------------------------------------------------------------- /pretrained/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /results/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !README.md 4 | -------------------------------------------------------------------------------- /results/README.md: -------------------------------------------------------------------------------- 1 | # Results 2 | 3 | Generation and decoding results will be saved here. 4 | -------------------------------------------------------------------------------- /teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/luost26/diffusion-point-cloud/b037c5fb5c27ad016b1ebbe648275b0cffc6cf73/teaser.png -------------------------------------------------------------------------------- /test_ae.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import torch 5 | from tqdm.auto import tqdm 6 | 7 | from utils.dataset import * 8 | from utils.misc import * 9 | from utils.data import * 10 | from models.autoencoder import * 11 | from evaluation import EMD_CD 12 | 13 | 14 | # Arguments 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--ckpt', type=str, default='./pretrained/AE_airplane.pt') 17 | parser.add_argument('--categories', type=str_list, default=['airplane']) 18 | parser.add_argument('--save_dir', type=str, default='./results') 19 | parser.add_argument('--device', type=str, default='cuda') 20 | # Datasets and loaders 21 | parser.add_argument('--dataset_path', type=str, default='./data/shapenet.hdf5') 22 | parser.add_argument('--batch_size', type=int, default=128) 23 | args = parser.parse_args() 24 | 25 | # Logging 26 | save_dir = os.path.join(args.save_dir, 'AE_Ours_%s_%d' % ('_'.join(args.categories), int(time.time())) ) 27 | if not os.path.exists(save_dir): 28 | os.makedirs(save_dir) 29 | logger = get_logger('test', save_dir) 30 | for k, v in vars(args).items(): 31 | logger.info('[ARGS::%s] %s' % (k, repr(v))) 32 | 33 | # Checkpoint 34 | ckpt = torch.load(args.ckpt) 35 | seed_all(ckpt['args'].seed) 36 | 37 | # Datasets and loaders 38 | logger.info('Loading datasets...') 39 | test_dset = ShapeNetCore( 40 | path=args.dataset_path, 41 | cates=args.categories, 42 | split='test', 43 | scale_mode=ckpt['args'].scale_mode 44 | ) 45 | test_loader = DataLoader(test_dset, batch_size=args.batch_size, num_workers=0) 46 | 47 | # Model 48 | logger.info('Loading model...') 49 | model = AutoEncoder(ckpt['args']).to(args.device) 50 | model.load_state_dict(ckpt['state_dict']) 51 | 52 | all_ref = [] 53 | all_recons = [] 54 | for i, batch in enumerate(tqdm(test_loader)): 55 | ref = batch['pointcloud'].to(args.device) 56 | shift = batch['shift'].to(args.device) 57 | scale = batch['scale'].to(args.device) 58 | model.eval() 59 | with torch.no_grad(): 60 | code = model.encode(ref) 61 | recons = model.decode(code, ref.size(1), flexibility=ckpt['args'].flexibility).detach() 62 | 63 | ref = ref * scale + shift 64 | recons = recons * scale + shift 65 | 66 | all_ref.append(ref.detach().cpu()) 67 | all_recons.append(recons.detach().cpu()) 68 | 69 | all_ref = torch.cat(all_ref, dim=0) 70 | all_recons = torch.cat(all_recons, dim=0) 71 | 72 | logger.info('Saving point clouds...') 73 | np.save(os.path.join(save_dir, 'ref.npy'), all_ref.numpy()) 74 | np.save(os.path.join(save_dir, 'out.npy'), all_recons.numpy()) 75 | 76 | logger.info('Start computing metrics...') 77 | metrics = EMD_CD(all_recons.to(args.device), all_ref.to(args.device), batch_size=args.batch_size) 78 | cd, emd = metrics['MMD-CD'].item(), metrics['MMD-EMD'].item() 79 | logger.info('CD: %.12f' % cd) 80 | logger.info('EMD: %.12f' % emd) 81 | -------------------------------------------------------------------------------- /test_gen.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import math 4 | import argparse 5 | import torch 6 | from tqdm.auto import tqdm 7 | 8 | from utils.dataset import * 9 | from utils.misc import * 10 | from utils.data import * 11 | from models.vae_gaussian import * 12 | from models.vae_flow import * 13 | from models.flow import add_spectral_norm, spectral_norm_power_iteration 14 | from evaluation import * 15 | 16 | def normalize_point_clouds(pcs, mode, logger): 17 | if mode is None: 18 | logger.info('Will not normalize point clouds.') 19 | return pcs 20 | logger.info('Normalization mode: %s' % mode) 21 | for i in tqdm(range(pcs.size(0)), desc='Normalize'): 22 | pc = pcs[i] 23 | if mode == 'shape_unit': 24 | shift = pc.mean(dim=0).reshape(1, 3) 25 | scale = pc.flatten().std().reshape(1, 1) 26 | elif mode == 'shape_bbox': 27 | pc_max, _ = pc.max(dim=0, keepdim=True) # (1, 3) 28 | pc_min, _ = pc.min(dim=0, keepdim=True) # (1, 3) 29 | shift = ((pc_min + pc_max) / 2).view(1, 3) 30 | scale = (pc_max - pc_min).max().reshape(1, 1) / 2 31 | pc = (pc - shift) / scale 32 | pcs[i] = pc 33 | return pcs 34 | 35 | 36 | # Arguments 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument('--ckpt', type=str, default='./pretrained/GEN_airplane.pt') 39 | parser.add_argument('--categories', type=str_list, default=['airplane']) 40 | parser.add_argument('--save_dir', type=str, default='./results') 41 | parser.add_argument('--device', type=str, default='cuda') 42 | # Datasets and loaders 43 | parser.add_argument('--dataset_path', type=str, default='./data/shapenet.hdf5') 44 | parser.add_argument('--batch_size', type=int, default=128) 45 | # Sampling 46 | parser.add_argument('--sample_num_points', type=int, default=2048) 47 | parser.add_argument('--normalize', type=str, default='shape_bbox', choices=[None, 'shape_unit', 'shape_bbox']) 48 | parser.add_argument('--seed', type=int, default=9988) 49 | args = parser.parse_args() 50 | 51 | 52 | # Logging 53 | save_dir = os.path.join(args.save_dir, 'GEN_Ours_%s_%d' % ('_'.join(args.categories), int(time.time())) ) 54 | if not os.path.exists(save_dir): 55 | os.makedirs(save_dir) 56 | logger = get_logger('test', save_dir) 57 | for k, v in vars(args).items(): 58 | logger.info('[ARGS::%s] %s' % (k, repr(v))) 59 | 60 | # Checkpoint 61 | ckpt = torch.load(args.ckpt) 62 | seed_all(args.seed) 63 | 64 | # Datasets and loaders 65 | logger.info('Loading datasets...') 66 | test_dset = ShapeNetCore( 67 | path=args.dataset_path, 68 | cates=args.categories, 69 | split='test', 70 | scale_mode=args.normalize, 71 | ) 72 | test_loader = DataLoader(test_dset, batch_size=args.batch_size, num_workers=0) 73 | 74 | # Model 75 | logger.info('Loading model...') 76 | if ckpt['args'].model == 'gaussian': 77 | model = GaussianVAE(ckpt['args']).to(args.device) 78 | elif ckpt['args'].model == 'flow': 79 | model = FlowVAE(ckpt['args']).to(args.device) 80 | logger.info(repr(model)) 81 | # if ckpt['args'].spectral_norm: 82 | # add_spectral_norm(model, logger=logger) 83 | model.load_state_dict(ckpt['state_dict']) 84 | 85 | # Reference Point Clouds 86 | ref_pcs = [] 87 | for i, data in enumerate(test_dset): 88 | ref_pcs.append(data['pointcloud'].unsqueeze(0)) 89 | ref_pcs = torch.cat(ref_pcs, dim=0) 90 | 91 | # Generate Point Clouds 92 | gen_pcs = [] 93 | for i in tqdm(range(0, math.ceil(len(test_dset) / args.batch_size)), 'Generate'): 94 | with torch.no_grad(): 95 | z = torch.randn([args.batch_size, ckpt['args'].latent_dim]).to(args.device) 96 | x = model.sample(z, args.sample_num_points, flexibility=ckpt['args'].flexibility) 97 | gen_pcs.append(x.detach().cpu()) 98 | gen_pcs = torch.cat(gen_pcs, dim=0)[:len(test_dset)] 99 | if args.normalize is not None: 100 | gen_pcs = normalize_point_clouds(gen_pcs, mode=args.normalize, logger=logger) 101 | 102 | # Save 103 | logger.info('Saving point clouds...') 104 | np.save(os.path.join(save_dir, 'out.npy'), gen_pcs.numpy()) 105 | 106 | # Compute metrics 107 | with torch.no_grad(): 108 | results = compute_all_metrics(gen_pcs.to(args.device), ref_pcs.to(args.device), args.batch_size) 109 | results = {k:v.item() for k, v in results.items()} 110 | jsd = jsd_between_point_cloud_sets(gen_pcs.cpu().numpy(), ref_pcs.cpu().numpy()) 111 | results['jsd'] = jsd 112 | 113 | for k, v in results.items(): 114 | logger.info('%s: %.12f' % (k, v)) 115 | -------------------------------------------------------------------------------- /train_ae.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import torch.utils.tensorboard 5 | from torch.nn.utils import clip_grad_norm_ 6 | from tqdm.auto import tqdm 7 | 8 | from utils.dataset import * 9 | from utils.misc import * 10 | from utils.data import * 11 | from utils.transform import * 12 | from models.autoencoder import * 13 | from evaluation import EMD_CD 14 | 15 | 16 | # Arguments 17 | parser = argparse.ArgumentParser() 18 | # Model arguments 19 | parser.add_argument('--latent_dim', type=int, default=256) 20 | parser.add_argument('--num_steps', type=int, default=200) 21 | parser.add_argument('--beta_1', type=float, default=1e-4) 22 | parser.add_argument('--beta_T', type=float, default=0.05) 23 | parser.add_argument('--sched_mode', type=str, default='linear') 24 | parser.add_argument('--flexibility', type=float, default=0.0) 25 | parser.add_argument('--residual', type=eval, default=True, choices=[True, False]) 26 | parser.add_argument('--resume', type=str, default=None) 27 | 28 | # Datasets and loaders 29 | parser.add_argument('--dataset_path', type=str, default='./data/shapenet.hdf5') 30 | parser.add_argument('--categories', type=str_list, default=['airplane']) 31 | parser.add_argument('--scale_mode', type=str, default='shape_unit') 32 | parser.add_argument('--train_batch_size', type=int, default=128) 33 | parser.add_argument('--val_batch_size', type=int, default=32) 34 | parser.add_argument('--rotate', type=eval, default=False, choices=[True, False]) 35 | 36 | # Optimizer and scheduler 37 | parser.add_argument('--lr', type=float, default=1e-3) 38 | parser.add_argument('--weight_decay', type=float, default=0) 39 | parser.add_argument('--max_grad_norm', type=float, default=10) 40 | parser.add_argument('--end_lr', type=float, default=1e-4) 41 | parser.add_argument('--sched_start_epoch', type=int, default=150*THOUSAND) 42 | parser.add_argument('--sched_end_epoch', type=int, default=300*THOUSAND) 43 | 44 | # Training 45 | parser.add_argument('--seed', type=int, default=2020) 46 | parser.add_argument('--logging', type=eval, default=True, choices=[True, False]) 47 | parser.add_argument('--log_root', type=str, default='./logs_ae') 48 | parser.add_argument('--device', type=str, default='cuda') 49 | parser.add_argument('--max_iters', type=int, default=float('inf')) 50 | parser.add_argument('--val_freq', type=float, default=1000) 51 | parser.add_argument('--tag', type=str, default=None) 52 | parser.add_argument('--num_val_batches', type=int, default=-1) 53 | parser.add_argument('--num_inspect_batches', type=int, default=1) 54 | parser.add_argument('--num_inspect_pointclouds', type=int, default=4) 55 | args = parser.parse_args() 56 | seed_all(args.seed) 57 | 58 | # Logging 59 | if args.logging: 60 | log_dir = get_new_log_dir(args.log_root, prefix='AE_', postfix='_' + args.tag if args.tag is not None else '') 61 | logger = get_logger('train', log_dir) 62 | writer = torch.utils.tensorboard.SummaryWriter(log_dir) 63 | ckpt_mgr = CheckpointManager(log_dir) 64 | else: 65 | logger = get_logger('train', None) 66 | writer = BlackHole() 67 | ckpt_mgr = BlackHole() 68 | logger.info(args) 69 | 70 | # Datasets and loaders 71 | transform = None 72 | if args.rotate: 73 | transform = RandomRotate(180, ['pointcloud'], axis=1) 74 | logger.info('Transform: %s' % repr(transform)) 75 | logger.info('Loading datasets...') 76 | train_dset = ShapeNetCore( 77 | path=args.dataset_path, 78 | cates=args.categories, 79 | split='train', 80 | scale_mode=args.scale_mode, 81 | transform=transform, 82 | ) 83 | val_dset = ShapeNetCore( 84 | path=args.dataset_path, 85 | cates=args.categories, 86 | split='val', 87 | scale_mode=args.scale_mode, 88 | transform=transform, 89 | ) 90 | train_iter = get_data_iterator(DataLoader( 91 | train_dset, 92 | batch_size=args.train_batch_size, 93 | num_workers=0, 94 | )) 95 | val_loader = DataLoader(val_dset, batch_size=args.val_batch_size, num_workers=0) 96 | 97 | 98 | # Model 99 | logger.info('Building model...') 100 | if args.resume is not None: 101 | logger.info('Resuming from checkpoint...') 102 | ckpt = torch.load(args.resume) 103 | model = AutoEncoder(ckpt['args']).to(args.device) 104 | model.load_state_dict(ckpt['state_dict']) 105 | else: 106 | model = AutoEncoder(args).to(args.device) 107 | logger.info(repr(model)) 108 | 109 | 110 | # Optimizer and scheduler 111 | optimizer = torch.optim.Adam(model.parameters(), 112 | lr=args.lr, 113 | weight_decay=args.weight_decay 114 | ) 115 | scheduler = get_linear_scheduler( 116 | optimizer, 117 | start_epoch=args.sched_start_epoch, 118 | end_epoch=args.sched_end_epoch, 119 | start_lr=args.lr, 120 | end_lr=args.end_lr 121 | ) 122 | 123 | # Train, validate 124 | def train(it): 125 | # Load data 126 | batch = next(train_iter) 127 | x = batch['pointcloud'].to(args.device) 128 | 129 | # Reset grad and model state 130 | optimizer.zero_grad() 131 | model.train() 132 | 133 | # Forward 134 | loss = model.get_loss(x) 135 | 136 | # Backward and optimize 137 | loss.backward() 138 | orig_grad_norm = clip_grad_norm_(model.parameters(), args.max_grad_norm) 139 | optimizer.step() 140 | scheduler.step() 141 | 142 | logger.info('[Train] Iter %04d | Loss %.6f | Grad %.4f ' % (it, loss.item(), orig_grad_norm)) 143 | writer.add_scalar('train/loss', loss, it) 144 | writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], it) 145 | writer.add_scalar('train/grad_norm', orig_grad_norm, it) 146 | writer.flush() 147 | 148 | def validate_loss(it): 149 | 150 | all_refs = [] 151 | all_recons = [] 152 | for i, batch in enumerate(tqdm(val_loader, desc='Validate')): 153 | if args.num_val_batches > 0 and i >= args.num_val_batches: 154 | break 155 | ref = batch['pointcloud'].to(args.device) 156 | shift = batch['shift'].to(args.device) 157 | scale = batch['scale'].to(args.device) 158 | with torch.no_grad(): 159 | model.eval() 160 | code = model.encode(ref) 161 | recons = model.decode(code, ref.size(1), flexibility=args.flexibility) 162 | all_refs.append(ref * scale + shift) 163 | all_recons.append(recons * scale + shift) 164 | 165 | all_refs = torch.cat(all_refs, dim=0) 166 | all_recons = torch.cat(all_recons, dim=0) 167 | metrics = EMD_CD(all_recons, all_refs, batch_size=args.val_batch_size) 168 | cd, emd = metrics['MMD-CD'].item(), metrics['MMD-EMD'].item() 169 | 170 | logger.info('[Val] Iter %04d | CD %.6f | EMD %.6f ' % (it, cd, emd)) 171 | writer.add_scalar('val/cd', cd, it) 172 | writer.add_scalar('val/emd', emd, it) 173 | writer.flush() 174 | 175 | return cd 176 | 177 | def validate_inspect(it): 178 | sum_n = 0 179 | sum_chamfer = 0 180 | for i, batch in enumerate(tqdm(val_loader, desc='Inspect')): 181 | x = batch['pointcloud'].to(args.device) 182 | model.eval() 183 | code = model.encode(x) 184 | recons = model.decode(code, x.size(1), flexibility=args.flexibility).detach() 185 | 186 | sum_n += x.size(0) 187 | if i >= args.num_inspect_batches: 188 | break # Inspect only 5 batch 189 | 190 | writer.add_mesh('val/pointcloud', recons[:args.num_inspect_pointclouds], global_step=it) 191 | writer.flush() 192 | 193 | # Main loop 194 | logger.info('Start training...') 195 | try: 196 | it = 1 197 | while it <= args.max_iters: 198 | train(it) 199 | if it % args.val_freq == 0 or it == args.max_iters: 200 | with torch.no_grad(): 201 | cd_loss = validate_loss(it) 202 | validate_inspect(it) 203 | opt_states = { 204 | 'optimizer': optimizer.state_dict(), 205 | 'scheduler': scheduler.state_dict(), 206 | } 207 | ckpt_mgr.save(model, args, cd_loss, opt_states, step=it) 208 | it += 1 209 | 210 | except KeyboardInterrupt: 211 | logger.info('Terminating...') 212 | -------------------------------------------------------------------------------- /train_gen.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | import torch 5 | import torch.utils.tensorboard 6 | from torch.utils.data import DataLoader 7 | from torch.nn.utils import clip_grad_norm_ 8 | from tqdm.auto import tqdm 9 | 10 | from utils.dataset import * 11 | from utils.misc import * 12 | from utils.data import * 13 | from models.vae_gaussian import * 14 | from models.vae_flow import * 15 | from models.flow import add_spectral_norm, spectral_norm_power_iteration 16 | from evaluation import * 17 | 18 | 19 | # Arguments 20 | parser = argparse.ArgumentParser() 21 | # Model arguments 22 | parser.add_argument('--model', type=str, default='flow', choices=['flow', 'gaussian']) 23 | parser.add_argument('--latent_dim', type=int, default=256) 24 | parser.add_argument('--num_steps', type=int, default=100) 25 | parser.add_argument('--beta_1', type=float, default=1e-4) 26 | parser.add_argument('--beta_T', type=float, default=0.02) 27 | parser.add_argument('--sched_mode', type=str, default='linear') 28 | parser.add_argument('--flexibility', type=float, default=0.0) 29 | parser.add_argument('--truncate_std', type=float, default=2.0) 30 | parser.add_argument('--latent_flow_depth', type=int, default=14) 31 | parser.add_argument('--latent_flow_hidden_dim', type=int, default=256) 32 | parser.add_argument('--num_samples', type=int, default=4) 33 | parser.add_argument('--sample_num_points', type=int, default=2048) 34 | parser.add_argument('--kl_weight', type=float, default=0.001) 35 | parser.add_argument('--residual', type=eval, default=True, choices=[True, False]) 36 | parser.add_argument('--spectral_norm', type=eval, default=False, choices=[True, False]) 37 | 38 | # Datasets and loaders 39 | parser.add_argument('--dataset_path', type=str, default='./data/shapenet.hdf5') 40 | parser.add_argument('--categories', type=str_list, default=['airplane']) 41 | parser.add_argument('--scale_mode', type=str, default='shape_unit') 42 | parser.add_argument('--train_batch_size', type=int, default=128) 43 | parser.add_argument('--val_batch_size', type=int, default=64) 44 | 45 | # Optimizer and scheduler 46 | parser.add_argument('--lr', type=float, default=2e-3) 47 | parser.add_argument('--weight_decay', type=float, default=0) 48 | parser.add_argument('--max_grad_norm', type=float, default=10) 49 | parser.add_argument('--end_lr', type=float, default=1e-4) 50 | parser.add_argument('--sched_start_epoch', type=int, default=200*THOUSAND) 51 | parser.add_argument('--sched_end_epoch', type=int, default=400*THOUSAND) 52 | 53 | # Training 54 | parser.add_argument('--seed', type=int, default=2020) 55 | parser.add_argument('--logging', type=eval, default=True, choices=[True, False]) 56 | parser.add_argument('--log_root', type=str, default='./logs_gen') 57 | parser.add_argument('--device', type=str, default='cuda') 58 | parser.add_argument('--max_iters', type=int, default=float('inf')) 59 | parser.add_argument('--val_freq', type=int, default=1000) 60 | parser.add_argument('--test_freq', type=int, default=30*THOUSAND) 61 | parser.add_argument('--test_size', type=int, default=400) 62 | parser.add_argument('--tag', type=str, default=None) 63 | args = parser.parse_args() 64 | seed_all(args.seed) 65 | 66 | # Logging 67 | if args.logging: 68 | log_dir = get_new_log_dir(args.log_root, prefix='GEN_', postfix='_' + args.tag if args.tag is not None else '') 69 | logger = get_logger('train', log_dir) 70 | writer = torch.utils.tensorboard.SummaryWriter(log_dir) 71 | ckpt_mgr = CheckpointManager(log_dir) 72 | log_hyperparams(writer, args) 73 | else: 74 | logger = get_logger('train', None) 75 | writer = BlackHole() 76 | ckpt_mgr = BlackHole() 77 | logger.info(args) 78 | 79 | # Datasets and loaders 80 | logger.info('Loading datasets...') 81 | 82 | train_dset = ShapeNetCore( 83 | path=args.dataset_path, 84 | cates=args.categories, 85 | split='train', 86 | scale_mode=args.scale_mode, 87 | ) 88 | val_dset = ShapeNetCore( 89 | path=args.dataset_path, 90 | cates=args.categories, 91 | split='val', 92 | scale_mode=args.scale_mode, 93 | ) 94 | train_iter = get_data_iterator(DataLoader( 95 | train_dset, 96 | batch_size=args.train_batch_size, 97 | num_workers=0, 98 | )) 99 | 100 | # Model 101 | logger.info('Building model...') 102 | if args.model == 'gaussian': 103 | model = GaussianVAE(args).to(args.device) 104 | elif args.model == 'flow': 105 | model = FlowVAE(args).to(args.device) 106 | logger.info(repr(model)) 107 | if args.spectral_norm: 108 | add_spectral_norm(model, logger=logger) 109 | 110 | # Optimizer and scheduler 111 | optimizer = torch.optim.Adam(model.parameters(), 112 | lr=args.lr, 113 | weight_decay=args.weight_decay 114 | ) 115 | scheduler = get_linear_scheduler( 116 | optimizer, 117 | start_epoch=args.sched_start_epoch, 118 | end_epoch=args.sched_end_epoch, 119 | start_lr=args.lr, 120 | end_lr=args.end_lr 121 | ) 122 | 123 | # Train, validate and test 124 | def train(it): 125 | # Load data 126 | batch = next(train_iter) 127 | x = batch['pointcloud'].to(args.device) 128 | 129 | # Reset grad and model state 130 | optimizer.zero_grad() 131 | model.train() 132 | if args.spectral_norm: 133 | spectral_norm_power_iteration(model, n_power_iterations=1) 134 | 135 | # Forward 136 | kl_weight = args.kl_weight 137 | loss = model.get_loss(x, kl_weight=kl_weight, writer=writer, it=it) 138 | 139 | # Backward and optimize 140 | loss.backward() 141 | orig_grad_norm = clip_grad_norm_(model.parameters(), args.max_grad_norm) 142 | optimizer.step() 143 | scheduler.step() 144 | 145 | logger.info('[Train] Iter %04d | Loss %.6f | Grad %.4f | KLWeight %.4f' % ( 146 | it, loss.item(), orig_grad_norm, kl_weight 147 | )) 148 | writer.add_scalar('train/loss', loss, it) 149 | writer.add_scalar('train/kl_weight', kl_weight, it) 150 | writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], it) 151 | writer.add_scalar('train/grad_norm', orig_grad_norm, it) 152 | writer.flush() 153 | 154 | def validate_inspect(it): 155 | z = torch.randn([args.num_samples, args.latent_dim]).to(args.device) 156 | x = model.sample(z, args.sample_num_points, flexibility=args.flexibility) #, truncate_std=args.truncate_std) 157 | writer.add_mesh('val/pointcloud', x, global_step=it) 158 | writer.flush() 159 | logger.info('[Inspect] Generating samples...') 160 | 161 | def test(it): 162 | ref_pcs = [] 163 | for i, data in enumerate(val_dset): 164 | if i >= args.test_size: 165 | break 166 | ref_pcs.append(data['pointcloud'].unsqueeze(0)) 167 | ref_pcs = torch.cat(ref_pcs, dim=0) 168 | 169 | gen_pcs = [] 170 | for i in tqdm(range(0, math.ceil(args.test_size / args.val_batch_size)), 'Generate'): 171 | with torch.no_grad(): 172 | z = torch.randn([args.val_batch_size, args.latent_dim]).to(args.device) 173 | x = model.sample(z, args.sample_num_points, flexibility=args.flexibility) 174 | gen_pcs.append(x.detach().cpu()) 175 | gen_pcs = torch.cat(gen_pcs, dim=0)[:args.test_size] 176 | 177 | # Denormalize point clouds, all shapes have zero mean. 178 | # [WARNING]: Do NOT denormalize! 179 | # ref_pcs *= val_dset.stats['std'] 180 | # gen_pcs *= val_dset.stats['std'] 181 | 182 | with torch.no_grad(): 183 | results = compute_all_metrics(gen_pcs.to(args.device), ref_pcs.to(args.device), args.val_batch_size) 184 | results = {k:v.item() for k, v in results.items()} 185 | jsd = jsd_between_point_cloud_sets(gen_pcs.cpu().numpy(), ref_pcs.cpu().numpy()) 186 | results['jsd'] = jsd 187 | 188 | # CD related metrics 189 | writer.add_scalar('test/Coverage_CD', results['lgan_cov-CD'], global_step=it) 190 | writer.add_scalar('test/MMD_CD', results['lgan_mmd-CD'], global_step=it) 191 | writer.add_scalar('test/1NN_CD', results['1-NN-CD-acc'], global_step=it) 192 | # EMD related metrics 193 | # writer.add_scalar('test/Coverage_EMD', results['lgan_cov-EMD'], global_step=it) 194 | # writer.add_scalar('test/MMD_EMD', results['lgan_mmd-EMD'], global_step=it) 195 | # writer.add_scalar('test/1NN_EMD', results['1-NN-EMD-acc'], global_step=it) 196 | # JSD 197 | writer.add_scalar('test/JSD', results['jsd'], global_step=it) 198 | 199 | # logger.info('[Test] Coverage | CD %.6f | EMD %.6f' % (results['lgan_cov-CD'], results['lgan_cov-EMD'])) 200 | # logger.info('[Test] MinMatDis | CD %.6f | EMD %.6f' % (results['lgan_mmd-CD'], results['lgan_mmd-EMD'])) 201 | # logger.info('[Test] 1NN-Accur | CD %.6f | EMD %.6f' % (results['1-NN-CD-acc'], results['1-NN-EMD-acc'])) 202 | logger.info('[Test] Coverage | CD %.6f | EMD n/a' % (results['lgan_cov-CD'], )) 203 | logger.info('[Test] MinMatDis | CD %.6f | EMD n/a' % (results['lgan_mmd-CD'], )) 204 | logger.info('[Test] 1NN-Accur | CD %.6f | EMD n/a' % (results['1-NN-CD-acc'], )) 205 | logger.info('[Test] JsnShnDis | %.6f ' % (results['jsd'])) 206 | 207 | # Main loop 208 | logger.info('Start training...') 209 | try: 210 | it = 1 211 | while it <= args.max_iters: 212 | train(it) 213 | if it % args.val_freq == 0 or it == args.max_iters: 214 | validate_inspect(it) 215 | opt_states = { 216 | 'optimizer': optimizer.state_dict(), 217 | 'scheduler': scheduler.state_dict(), 218 | } 219 | ckpt_mgr.save(model, args, 0, others=opt_states, step=it) 220 | if it % args.test_freq == 0 or it == args.max_iters: 221 | test(it) 222 | it += 1 223 | 224 | except KeyboardInterrupt: 225 | logger.info('Terminating...') 226 | -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, random_split 3 | 4 | 5 | def get_train_val_test_datasets(dataset, train_ratio, val_ratio): 6 | assert (train_ratio + val_ratio) <= 1 7 | train_size = int(len(dataset) * train_ratio) 8 | val_size = int(len(dataset) * val_ratio) 9 | test_size = len(dataset) - train_size - val_size 10 | 11 | train_set, val_set, test_set = random_split(dataset, [train_size, val_size, test_size]) 12 | return train_set, val_set, test_set 13 | 14 | 15 | def get_train_val_test_loaders(dataset, train_ratio, val_ratio, train_batch_size, val_test_batch_size, num_workers): 16 | train_set, val_set, test_set = get_train_val_test_datasets(dataset, train_ratio, val_ratio) 17 | 18 | train_loader = DataLoader(train_set, train_batch_size, shuffle=True, num_workers=num_workers) 19 | val_loader = DataLoader(val_set, val_test_batch_size, shuffle=False, num_workers=num_workers) 20 | test_loader = DataLoader(test_set, val_test_batch_size, shuffle=False, num_workers=num_workers) 21 | 22 | return train_loader, val_loader, test_loader 23 | 24 | 25 | def get_data_iterator(iterable): 26 | """Allows training with DataLoaders in a single infinite loop: 27 | for i, data in enumerate(inf_generator(train_loader)): 28 | """ 29 | iterator = iterable.__iter__() 30 | while True: 31 | try: 32 | yield iterator.__next__() 33 | except StopIteration: 34 | iterator = iterable.__iter__() 35 | -------------------------------------------------------------------------------- /utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from copy import copy 4 | import torch 5 | from torch.utils.data import Dataset 6 | import numpy as np 7 | import h5py 8 | from tqdm.auto import tqdm 9 | 10 | 11 | synsetid_to_cate = { 12 | '02691156': 'airplane', '02773838': 'bag', '02801938': 'basket', 13 | '02808440': 'bathtub', '02818832': 'bed', '02828884': 'bench', 14 | '02876657': 'bottle', '02880940': 'bowl', '02924116': 'bus', 15 | '02933112': 'cabinet', '02747177': 'can', '02942699': 'camera', 16 | '02954340': 'cap', '02958343': 'car', '03001627': 'chair', 17 | '03046257': 'clock', '03207941': 'dishwasher', '03211117': 'monitor', 18 | '04379243': 'table', '04401088': 'telephone', '02946921': 'tin_can', 19 | '04460130': 'tower', '04468005': 'train', '03085013': 'keyboard', 20 | '03261776': 'earphone', '03325088': 'faucet', '03337140': 'file', 21 | '03467517': 'guitar', '03513137': 'helmet', '03593526': 'jar', 22 | '03624134': 'knife', '03636649': 'lamp', '03642806': 'laptop', 23 | '03691459': 'speaker', '03710193': 'mailbox', '03759954': 'microphone', 24 | '03761084': 'microwave', '03790512': 'motorcycle', '03797390': 'mug', 25 | '03928116': 'piano', '03938244': 'pillow', '03948459': 'pistol', 26 | '03991062': 'pot', '04004475': 'printer', '04074963': 'remote_control', 27 | '04090263': 'rifle', '04099429': 'rocket', '04225987': 'skateboard', 28 | '04256520': 'sofa', '04330267': 'stove', '04530566': 'vessel', 29 | '04554684': 'washer', '02992529': 'cellphone', 30 | '02843684': 'birdhouse', '02871439': 'bookshelf', 31 | # '02858304': 'boat', no boat in our dataset, merged into vessels 32 | # '02834778': 'bicycle', not in our taxonomy 33 | } 34 | cate_to_synsetid = {v: k for k, v in synsetid_to_cate.items()} 35 | 36 | 37 | class ShapeNetCore(Dataset): 38 | 39 | GRAVITATIONAL_AXIS = 1 40 | 41 | def __init__(self, path, cates, split, scale_mode, transform=None): 42 | super().__init__() 43 | assert isinstance(cates, list), '`cates` must be a list of cate names.' 44 | assert split in ('train', 'val', 'test') 45 | assert scale_mode is None or scale_mode in ('global_unit', 'shape_unit', 'shape_bbox', 'shape_half', 'shape_34') 46 | self.path = path 47 | if 'all' in cates: 48 | cates = cate_to_synsetid.keys() 49 | self.cate_synsetids = [cate_to_synsetid[s] for s in cates] 50 | self.cate_synsetids.sort() 51 | self.split = split 52 | self.scale_mode = scale_mode 53 | self.transform = transform 54 | 55 | self.pointclouds = [] 56 | self.stats = None 57 | 58 | self.get_statistics() 59 | self.load() 60 | 61 | def get_statistics(self): 62 | 63 | basename = os.path.basename(self.path) 64 | dsetname = basename[:basename.rfind('.')] 65 | stats_dir = os.path.join(os.path.dirname(self.path), dsetname + '_stats') 66 | os.makedirs(stats_dir, exist_ok=True) 67 | 68 | if len(self.cate_synsetids) == len(cate_to_synsetid): 69 | stats_save_path = os.path.join(stats_dir, 'stats_all.pt') 70 | else: 71 | stats_save_path = os.path.join(stats_dir, 'stats_' + '_'.join(self.cate_synsetids) + '.pt') 72 | if os.path.exists(stats_save_path): 73 | self.stats = torch.load(stats_save_path) 74 | return self.stats 75 | 76 | with h5py.File(self.path, 'r') as f: 77 | pointclouds = [] 78 | for synsetid in self.cate_synsetids: 79 | for split in ('train', 'val', 'test'): 80 | pointclouds.append(torch.from_numpy(f[synsetid][split][...])) 81 | 82 | all_points = torch.cat(pointclouds, dim=0) # (B, N, 3) 83 | B, N, _ = all_points.size() 84 | mean = all_points.view(B*N, -1).mean(dim=0) # (1, 3) 85 | std = all_points.view(-1).std(dim=0) # (1, ) 86 | 87 | self.stats = {'mean': mean, 'std': std} 88 | torch.save(self.stats, stats_save_path) 89 | return self.stats 90 | 91 | def load(self): 92 | 93 | def _enumerate_pointclouds(f): 94 | for synsetid in self.cate_synsetids: 95 | cate_name = synsetid_to_cate[synsetid] 96 | for j, pc in enumerate(f[synsetid][self.split]): 97 | yield torch.from_numpy(pc), j, cate_name 98 | 99 | with h5py.File(self.path, mode='r') as f: 100 | for pc, pc_id, cate_name in _enumerate_pointclouds(f): 101 | 102 | if self.scale_mode == 'global_unit': 103 | shift = pc.mean(dim=0).reshape(1, 3) 104 | scale = self.stats['std'].reshape(1, 1) 105 | elif self.scale_mode == 'shape_unit': 106 | shift = pc.mean(dim=0).reshape(1, 3) 107 | scale = pc.flatten().std().reshape(1, 1) 108 | elif self.scale_mode == 'shape_half': 109 | shift = pc.mean(dim=0).reshape(1, 3) 110 | scale = pc.flatten().std().reshape(1, 1) / (0.5) 111 | elif self.scale_mode == 'shape_34': 112 | shift = pc.mean(dim=0).reshape(1, 3) 113 | scale = pc.flatten().std().reshape(1, 1) / (0.75) 114 | elif self.scale_mode == 'shape_bbox': 115 | pc_max, _ = pc.max(dim=0, keepdim=True) # (1, 3) 116 | pc_min, _ = pc.min(dim=0, keepdim=True) # (1, 3) 117 | shift = ((pc_min + pc_max) / 2).view(1, 3) 118 | scale = (pc_max - pc_min).max().reshape(1, 1) / 2 119 | else: 120 | shift = torch.zeros([1, 3]) 121 | scale = torch.ones([1, 1]) 122 | 123 | pc = (pc - shift) / scale 124 | 125 | self.pointclouds.append({ 126 | 'pointcloud': pc, 127 | 'cate': cate_name, 128 | 'id': pc_id, 129 | 'shift': shift, 130 | 'scale': scale 131 | }) 132 | 133 | # Deterministically shuffle the dataset 134 | self.pointclouds.sort(key=lambda data: data['id'], reverse=False) 135 | random.Random(2020).shuffle(self.pointclouds) 136 | 137 | def __len__(self): 138 | return len(self.pointclouds) 139 | 140 | def __getitem__(self, idx): 141 | data = {k:v.clone() if isinstance(v, torch.Tensor) else copy(v) for k, v in self.pointclouds[idx].items()} 142 | if self.transform is not None: 143 | data = self.transform(data) 144 | return data 145 | 146 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import random 5 | import time 6 | import logging 7 | import logging.handlers 8 | 9 | THOUSAND = 1000 10 | MILLION = 1000000 11 | 12 | 13 | class BlackHole(object): 14 | def __setattr__(self, name, value): 15 | pass 16 | def __call__(self, *args, **kwargs): 17 | return self 18 | def __getattr__(self, name): 19 | return self 20 | 21 | 22 | class CheckpointManager(object): 23 | 24 | def __init__(self, save_dir, logger=BlackHole()): 25 | super().__init__() 26 | os.makedirs(save_dir, exist_ok=True) 27 | self.save_dir = save_dir 28 | self.ckpts = [] 29 | self.logger = logger 30 | 31 | for f in os.listdir(self.save_dir): 32 | if f[:4] != 'ckpt': 33 | continue 34 | _, score, it = f.split('_') 35 | it = it.split('.')[0] 36 | self.ckpts.append({ 37 | 'score': float(score), 38 | 'file': f, 39 | 'iteration': int(it), 40 | }) 41 | 42 | def get_worst_ckpt_idx(self): 43 | idx = -1 44 | worst = float('-inf') 45 | for i, ckpt in enumerate(self.ckpts): 46 | if ckpt['score'] >= worst: 47 | idx = i 48 | worst = ckpt['score'] 49 | return idx if idx >= 0 else None 50 | 51 | def get_best_ckpt_idx(self): 52 | idx = -1 53 | best = float('inf') 54 | for i, ckpt in enumerate(self.ckpts): 55 | if ckpt['score'] <= best: 56 | idx = i 57 | best = ckpt['score'] 58 | return idx if idx >= 0 else None 59 | 60 | def get_latest_ckpt_idx(self): 61 | idx = -1 62 | latest_it = -1 63 | for i, ckpt in enumerate(self.ckpts): 64 | if ckpt['iteration'] > latest_it: 65 | idx = i 66 | latest_it = ckpt['iteration'] 67 | return idx if idx >= 0 else None 68 | 69 | def save(self, model, args, score, others=None, step=None): 70 | 71 | if step is None: 72 | fname = 'ckpt_%.6f_.pt' % float(score) 73 | else: 74 | fname = 'ckpt_%.6f_%d.pt' % (float(score), int(step)) 75 | path = os.path.join(self.save_dir, fname) 76 | 77 | torch.save({ 78 | 'args': args, 79 | 'state_dict': model.state_dict(), 80 | 'others': others 81 | }, path) 82 | 83 | self.ckpts.append({ 84 | 'score': score, 85 | 'file': fname 86 | }) 87 | 88 | return True 89 | 90 | def load_best(self): 91 | idx = self.get_best_ckpt_idx() 92 | if idx is None: 93 | raise IOError('No checkpoints found.') 94 | ckpt = torch.load(os.path.join(self.save_dir, self.ckpts[idx]['file'])) 95 | return ckpt 96 | 97 | def load_latest(self): 98 | idx = self.get_latest_ckpt_idx() 99 | if idx is None: 100 | raise IOError('No checkpoints found.') 101 | ckpt = torch.load(os.path.join(self.save_dir, self.ckpts[idx]['file'])) 102 | return ckpt 103 | 104 | def load_selected(self, file): 105 | ckpt = torch.load(os.path.join(self.save_dir, file)) 106 | return ckpt 107 | 108 | 109 | def seed_all(seed): 110 | torch.manual_seed(seed) 111 | np.random.seed(seed) 112 | random.seed(seed) 113 | 114 | 115 | def get_logger(name, log_dir=None): 116 | logger = logging.getLogger(name) 117 | logger.setLevel(logging.DEBUG) 118 | formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s') 119 | 120 | stream_handler = logging.StreamHandler() 121 | stream_handler.setLevel(logging.DEBUG) 122 | stream_handler.setFormatter(formatter) 123 | logger.addHandler(stream_handler) 124 | 125 | if log_dir is not None: 126 | file_handler = logging.FileHandler(os.path.join(log_dir, 'log.txt')) 127 | file_handler.setLevel(logging.INFO) 128 | file_handler.setFormatter(formatter) 129 | logger.addHandler(file_handler) 130 | 131 | return logger 132 | 133 | 134 | def get_new_log_dir(root='./logs', postfix='', prefix=''): 135 | log_dir = os.path.join(root, prefix + time.strftime('%Y_%m_%d__%H_%M_%S', time.localtime()) + postfix) 136 | os.makedirs(log_dir) 137 | return log_dir 138 | 139 | 140 | def int_tuple(argstr): 141 | return tuple(map(int, argstr.split(','))) 142 | 143 | 144 | def str_tuple(argstr): 145 | return tuple(argstr.split(',')) 146 | 147 | 148 | def int_list(argstr): 149 | return list(map(int, argstr.split(','))) 150 | 151 | 152 | def str_list(argstr): 153 | return list(argstr.split(',')) 154 | 155 | 156 | def log_hyperparams(writer, args): 157 | from torch.utils.tensorboard.summary import hparams 158 | vars_args = {k:v if isinstance(v, str) else repr(v) for k, v in vars(args).items()} 159 | exp, ssi, sei = hparams(vars_args, {}) 160 | writer.file_writer.add_summary(exp) 161 | writer.file_writer.add_summary(ssi) 162 | writer.file_writer.add_summary(sei) 163 | -------------------------------------------------------------------------------- /utils/transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import math 4 | import random 5 | import numbers 6 | import random 7 | from itertools import repeat 8 | 9 | 10 | class Center(object): 11 | r"""Centers node positions around the origin.""" 12 | 13 | def __init__(self, attr): 14 | self.attr = attr 15 | 16 | def __call__(self, data): 17 | for key in self.attr: 18 | data[key] = data[key] - data[key].mean(dim=-2, keepdim=True) 19 | return data 20 | 21 | def __repr__(self): 22 | return '{}()'.format(self.__class__.__name__) 23 | 24 | 25 | class NormalizeScale(object): 26 | r"""Centers and normalizes node positions to the interval :math:`(-1, 1)`. 27 | """ 28 | 29 | def __init__(self, attr): 30 | self.center = Center(attr=attr) 31 | self.attr = attr 32 | 33 | def __call__(self, data): 34 | data = self.center(data) 35 | 36 | for key in self.attr: 37 | scale = (1 / data[key].abs().max()) * 0.999999 38 | data[key] = data[key] * scale 39 | 40 | return data 41 | 42 | 43 | class FixedPoints(object): 44 | r"""Samples a fixed number of :obj:`num` points and features from a point 45 | cloud. 46 | Args: 47 | num (int): The number of points to sample. 48 | replace (bool, optional): If set to :obj:`False`, samples fixed 49 | points without replacement. In case :obj:`num` is greater than 50 | the number of points, duplicated points are kept to a 51 | minimum. (default: :obj:`True`) 52 | """ 53 | 54 | def __init__(self, num, replace=True): 55 | self.num = num 56 | self.replace = replace 57 | # warnings.warn('FixedPoints is not deterministic') 58 | 59 | def __call__(self, data): 60 | num_nodes = data['pos'].size(0) 61 | data['dense'] = data['pos'] 62 | 63 | if self.replace: 64 | choice = np.random.choice(num_nodes, self.num, replace=True) 65 | else: 66 | choice = torch.cat([ 67 | torch.randperm(num_nodes) 68 | for _ in range(math.ceil(self.num / num_nodes)) 69 | ], dim=0)[:self.num] 70 | 71 | for key, item in data.items(): 72 | if torch.is_tensor(item) and item.size(0) == num_nodes and key != 'dense': 73 | data[key] = item[choice] 74 | 75 | return data 76 | 77 | def __repr__(self): 78 | return '{}({}, replace={})'.format(self.__class__.__name__, self.num, 79 | self.replace) 80 | 81 | 82 | class LinearTransformation(object): 83 | r"""Transforms node positions with a square transformation matrix computed 84 | offline. 85 | Args: 86 | matrix (Tensor): tensor with shape :math:`[D, D]` where :math:`D` 87 | corresponds to the dimensionality of node positions. 88 | """ 89 | 90 | def __init__(self, matrix, attr): 91 | assert matrix.dim() == 2, ( 92 | 'Transformation matrix should be two-dimensional.') 93 | assert matrix.size(0) == matrix.size(1), ( 94 | 'Transformation matrix should be square. Got [{} x {}] rectangular' 95 | 'matrix.'.format(*matrix.size())) 96 | 97 | self.matrix = matrix 98 | self.attr = attr 99 | 100 | def __call__(self, data): 101 | for key in self.attr: 102 | pos = data[key].view(-1, 1) if data[key].dim() == 1 else data[key] 103 | 104 | assert pos.size(-1) == self.matrix.size(-2), ( 105 | 'Node position matrix and transformation matrix have incompatible ' 106 | 'shape.') 107 | 108 | data[key] = torch.matmul(pos, self.matrix.to(pos.dtype).to(pos.device)) 109 | 110 | return data 111 | 112 | def __repr__(self): 113 | return '{}({})'.format(self.__class__.__name__, self.matrix.tolist()) 114 | 115 | 116 | class RandomRotate(object): 117 | r"""Rotates node positions around a specific axis by a randomly sampled 118 | factor within a given interval. 119 | Args: 120 | degrees (tuple or float): Rotation interval from which the rotation 121 | angle is sampled. If :obj:`degrees` is a number instead of a 122 | tuple, the interval is given by :math:`[-\mathrm{degrees}, 123 | \mathrm{degrees}]`. 124 | axis (int, optional): The rotation axis. (default: :obj:`0`) 125 | """ 126 | 127 | def __init__(self, degrees, attr, axis=0): 128 | if isinstance(degrees, numbers.Number): 129 | degrees = (-abs(degrees), abs(degrees)) 130 | assert isinstance(degrees, (tuple, list)) and len(degrees) == 2 131 | self.degrees = degrees 132 | self.axis = axis 133 | self.attr = attr 134 | 135 | def __call__(self, data): 136 | degree = math.pi * random.uniform(*self.degrees) / 180.0 137 | sin, cos = math.sin(degree), math.cos(degree) 138 | 139 | if self.axis == 0: 140 | matrix = [[1, 0, 0], [0, cos, sin], [0, -sin, cos]] 141 | elif self.axis == 1: 142 | matrix = [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]] 143 | else: 144 | matrix = [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]] 145 | return LinearTransformation(torch.tensor(matrix), attr=self.attr)(data) 146 | 147 | def __repr__(self): 148 | return '{}({}, axis={})'.format(self.__class__.__name__, self.degrees, 149 | self.axis) 150 | 151 | 152 | class AddNoise(object): 153 | 154 | def __init__(self, std=0.01, noiseless_item_key='clean'): 155 | self.std = std 156 | self.key = noiseless_item_key 157 | 158 | def __call__(self, data): 159 | data[self.key] = data['pos'] 160 | data['pos'] = data['pos'] + torch.normal(mean=0, std=self.std, size=data['pos'].size()) 161 | return data 162 | 163 | 164 | class AddRandomNoise(object): 165 | 166 | def __init__(self, std_range=[0, 0.10], noiseless_item_key='clean'): 167 | self.std_range = std_range 168 | self.key = noiseless_item_key 169 | 170 | def __call__(self, data): 171 | noise_std = random.uniform(*self.std_range) 172 | data[self.key] = data['pos'] 173 | data['pos'] = data['pos'] + torch.normal(mean=0, std=noise_std, size=data['pos'].size()) 174 | return data 175 | 176 | 177 | class AddNoiseForEval(object): 178 | 179 | def __init__(self, stds=[0.0, 0.01, 0.02, 0.03, 0.05, 0.10, 0.15]): 180 | self.stds = stds 181 | self.keys = ['noisy_%.2f' % s for s in stds] 182 | 183 | def __call__(self, data): 184 | data['clean'] = data['pos'] 185 | for noise_std in self.stds: 186 | data['noisy_%.2f' % noise_std] = data['pos'] + torch.normal(mean=0, std=noise_std, size=data['pos'].size()) 187 | return data 188 | 189 | 190 | class IdentityTransform(object): 191 | 192 | def __call__(self, data): 193 | return data 194 | 195 | 196 | class RandomScale(object): 197 | r"""Scales node positions by a randomly sampled factor :math:`s` within a 198 | given interval, *e.g.*, resulting in the transformation matrix 199 | .. math:: 200 | \begin{bmatrix} 201 | s & 0 & 0 \\ 202 | 0 & s & 0 \\ 203 | 0 & 0 & s \\ 204 | \end{bmatrix} 205 | for three-dimensional positions. 206 | Args: 207 | scales (tuple): scaling factor interval, e.g. :obj:`(a, b)`, then scale 208 | is randomly sampled from the range 209 | :math:`a \leq \mathrm{scale} \leq b`. 210 | """ 211 | 212 | def __init__(self, scales, attr): 213 | assert isinstance(scales, (tuple, list)) and len(scales) == 2 214 | self.scales = scales 215 | self.attr = attr 216 | 217 | def __call__(self, data): 218 | scale = random.uniform(*self.scales) 219 | for key in self.attr: 220 | data[key] = data[key] * scale 221 | return data 222 | 223 | def __repr__(self): 224 | return '{}({})'.format(self.__class__.__name__, self.scales) 225 | 226 | 227 | class RandomTranslate(object): 228 | r"""Translates node positions by randomly sampled translation values 229 | within a given interval. In contrast to other random transformations, 230 | translation is applied separately at each position. 231 | Args: 232 | translate (sequence or float or int): Maximum translation in each 233 | dimension, defining the range 234 | :math:`(-\mathrm{translate}, +\mathrm{translate})` to sample from. 235 | If :obj:`translate` is a number instead of a sequence, the same 236 | range is used for each dimension. 237 | """ 238 | 239 | def __init__(self, translate, attr): 240 | self.translate = translate 241 | self.attr = attr 242 | 243 | def __call__(self, data): 244 | (n, dim), t = data['pos'].size(), self.translate 245 | if isinstance(t, numbers.Number): 246 | t = list(repeat(t, times=dim)) 247 | assert len(t) == dim 248 | 249 | ts = [] 250 | for d in range(dim): 251 | ts.append(data['pos'].new_empty(n).uniform_(-abs(t[d]), abs(t[d]))) 252 | 253 | for key in self.attr: 254 | data[key] = data[key] + torch.stack(ts, dim=-1) 255 | 256 | return data 257 | 258 | def __repr__(self): 259 | return '{}({})'.format(self.__class__.__name__, self.translate) 260 | 261 | 262 | class Rotate(object): 263 | r"""Rotates node positions around a specific axis by a randomly sampled 264 | factor within a given interval. 265 | Args: 266 | degrees (tuple or float): Rotation interval from which the rotation 267 | angle is sampled. If :obj:`degrees` is a number instead of a 268 | tuple, the interval is given by :math:`[-\mathrm{degrees}, 269 | \mathrm{degrees}]`. 270 | axis (int, optional): The rotation axis. (default: :obj:`0`) 271 | """ 272 | 273 | def __init__(self, degree, attr, axis=0): 274 | self.degree = degree 275 | self.axis = axis 276 | self.attr = attr 277 | 278 | def __call__(self, data): 279 | degree = math.pi * self.degree / 180.0 280 | sin, cos = math.sin(degree), math.cos(degree) 281 | 282 | if self.axis == 0: 283 | matrix = [[1, 0, 0], [0, cos, sin], [0, -sin, cos]] 284 | elif self.axis == 1: 285 | matrix = [[cos, 0, -sin], [0, 1, 0], [sin, 0, cos]] 286 | else: 287 | matrix = [[cos, sin, 0], [-sin, cos, 0], [0, 0, 1]] 288 | return LinearTransformation(torch.tensor(matrix), attr=self.attr)(data) 289 | 290 | def __repr__(self): 291 | return '{}({}, axis={})'.format(self.__class__.__name__, self.degrees, 292 | self.axis) 293 | --------------------------------------------------------------------------------