├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yaml │ └── installation.yaml └── workflows │ └── pre-commit.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── cython_setup.py ├── docs ├── FAQ.md └── figs │ ├── eval_benchmark.png │ ├── torchsparse.png │ └── train_benchmark.png ├── examples ├── README.md ├── backbones.py ├── converter.py ├── example.py ├── mmdetection3d │ ├── README.md │ ├── configs │ │ └── README.md │ ├── converted_models │ │ └── README.md │ ├── demo.ipynb │ ├── scripts │ │ └── run_evaluation │ │ │ └── SECOND.sh │ ├── setup.py │ └── ts_plugin │ │ ├── __init__.py │ │ └── models │ │ ├── __init__.py │ │ ├── backbones │ │ ├── __init__.py │ │ └── resnet.py │ │ ├── layers │ │ ├── __init__.py │ │ └── sparse_block.py │ │ ├── middle_encoders │ │ ├── __init__.py │ │ ├── sparse_encoder.py │ │ ├── sparse_unet.py │ │ └── voxel_set_abstraction.py │ │ └── roi_heads │ │ └── bbox_heads │ │ ├── __init__.py │ │ └── parta2_bbox_head.py ├── openpcdet │ ├── README.md │ ├── cfgs_templates │ │ ├── kitti_models │ │ │ ├── PartA2_plugin.yaml │ │ │ ├── pv_rcnn_plugin.yaml │ │ │ ├── second_plugin.yaml │ │ │ └── voxel_rcnn_car_plugin.yaml │ │ └── nuscenes_models │ │ │ └── cbgs_voxel0075_voxelnext.yaml │ ├── converted_models │ │ └── README.md │ ├── converter_voxelnext.py │ ├── demo.ipynb │ ├── pcdet_plugin │ │ ├── __init__.py │ │ └── models │ │ │ ├── __init__.py │ │ │ ├── backbones_2d │ │ │ ├── __init__.py │ │ │ └── map_to_bev │ │ │ │ ├── __init__.py │ │ │ │ └── height_compression.py │ │ │ ├── backbones_3d │ │ │ ├── __init__.py │ │ │ ├── backbone3d.py │ │ │ ├── backbone_voxel_next.py │ │ │ ├── pfe.py │ │ │ └── unet.py │ │ │ ├── dense_heads │ │ │ ├── __init__.py │ │ │ └── voxel_next_head.py │ │ │ ├── detectors │ │ │ ├── __init__.py │ │ │ └── detector3d_template.py │ │ │ └── roi_heads │ │ │ ├── __init__.py │ │ │ └── partA2_head.py │ └── setup.py └── performance.py ├── install.py ├── requirements.txt ├── setup.cfg ├── setup.py ├── tests ├── python │ ├── __init__.py │ ├── test_single_layer_conv.py │ ├── test_to_dense.py │ └── test_utils.py └── test.py └── torchsparse ├── __init__.py ├── backbones ├── __init__.py ├── modules │ ├── __init__.py │ └── blocks.py ├── resnet.py └── unet.py ├── backend ├── convolution │ ├── convolution_backward_wgrad_implicit_gemm_cuda.cu │ ├── convolution_backward_wgrad_implicit_gemm_cuda.h │ ├── convolution_backward_wgrad_implicit_gemm_sorted_cuda.cu │ ├── convolution_backward_wgrad_implicit_gemm_sorted_cuda.h │ ├── convolution_forward_fetch_on_demand_cuda.cu │ ├── convolution_forward_fetch_on_demand_cuda.h │ ├── convolution_forward_implicit_gemm_cuda.cu │ ├── convolution_forward_implicit_gemm_cuda.h │ ├── convolution_forward_implicit_gemm_sorted_cuda.cu │ ├── convolution_forward_implicit_gemm_sorted_cuda.h │ ├── convolution_gather_scatter_cpu.cpp │ ├── convolution_gather_scatter_cpu.h │ ├── convolution_gather_scatter_cuda.cu │ └── convolution_gather_scatter_cuda.h ├── devoxelize │ ├── devoxelize_cpu.cpp │ ├── devoxelize_cpu.h │ ├── devoxelize_cuda.cu │ └── devoxelize_cuda.h ├── hash │ ├── hash_cpu.cpp │ ├── hash_cpu.h │ ├── hash_cuda.cu │ └── hash_cuda.h ├── hashmap │ ├── hashmap_cpu.cpp │ ├── hashmap_cpu.hpp │ └── hashmap_cuda.cuh ├── others │ ├── count_cpu.cpp │ ├── count_cpu.h │ ├── count_cuda.cu │ ├── count_cuda.h │ ├── downsample_cuda.cu │ ├── downsample_cuda.h │ ├── exclusive_scan_cuda.cu │ ├── exclusive_scan_cuda.h │ ├── query_cpu.cpp │ ├── query_cpu.h │ ├── query_cuda.cu │ ├── query_cuda.h │ ├── reduce_bitmask_cuda.cu │ ├── reduce_bitmask_cuda.h 
│ ├── reorder_map_cuda.cu │ ├── reorder_map_cuda.h │ ├── sparsemapping_cuda.cu │ └── sparsemapping_cuda.h ├── pybind_cpu.cpp ├── pybind_cuda.cu ├── utils │ ├── atomic.cuh │ └── memory.cuh └── voxelize │ ├── voxelize_cpu.cpp │ ├── voxelize_cpu.h │ ├── voxelize_cuda.cu │ └── voxelize_cuda.h ├── backends.py ├── nn ├── __init__.py ├── functional │ ├── __init__.py │ ├── activation.py │ ├── conv │ │ ├── __init__.py │ │ ├── conv.py │ │ ├── conv_config.py │ │ ├── conv_mode.py │ │ ├── func │ │ │ ├── __init__.py │ │ │ ├── fetch_on_demand.py │ │ │ ├── gather_scatter.py │ │ │ └── implicit_gemm.py │ │ ├── hash │ │ │ ├── __init__.py │ │ │ ├── hash.py │ │ │ └── query.py │ │ ├── kmap │ │ │ ├── __init__.py │ │ │ ├── build_kmap.py │ │ │ ├── downsample.py │ │ │ ├── func │ │ │ │ ├── __init__.py │ │ │ │ ├── hashmap.py │ │ │ │ └── hashmap_on_the_fly.py │ │ │ └── upsample.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── collections.py │ │ │ └── compat.py │ ├── count.py │ ├── crop.py │ ├── devoxelize.py │ ├── hash.py │ ├── pooling.py │ ├── query.py │ └── voxelize.py ├── modules │ ├── __init__.py │ ├── activation.py │ ├── bev.py │ ├── conv.py │ ├── crop.py │ ├── norm.py │ └── pooling.py └── utils │ ├── __init__.py │ ├── apply.py │ └── kernel.py ├── operators.py ├── tensor.py ├── utils ├── __init__.py ├── collate.py ├── quantize.py ├── tensor_cache.py ├── to_dense.py ├── tune.py └── utils.py └── version.py /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: 🐞 Bug 2 | description: Report unexpected behavior with the library 3 | title: "[BUG] " 4 | labels: [Bug, Needs Triage] 5 | body: 6 | - type: checkboxes 7 | attributes: 8 | label: Is there an existing issue for this? 9 | description: Please search to see if an issue already exists for the bug you encountered. 10 | options: 11 | - label: I have searched the existing issues 12 | required: true 13 | - type: textarea 14 | attributes: 15 | label: Current Behavior 16 | description: A concise description of what you're experiencing. 17 | validations: 18 | required: false 19 | - type: textarea 20 | attributes: 21 | label: Expected Behavior 22 | description: A concise description of what you expected to happen. 23 | validations: 24 | required: false 25 | - type: textarea 26 | attributes: 27 | label: Environment 28 | description: | 29 | How to find these values: 30 | - **GCC**: `gcc --version` 31 | - **NVCC**: `nvcc --version` 32 | - **PyTorch**: `python -c "import torch; print(torch.__version__)"` 33 | - **PyTorch CUDA**: `python -c "import torch; print(torch.version.cuda)"` 34 | - **TorchSparse**: `python -c "import torchsparse; print(torchsparse.__version__)"` 35 | value: | 36 | - GCC: 37 | - NVCC: 38 | - PyTorch: 39 | - PyTorch CUDA: 40 | - TorchSparse: 41 | render: markdown 42 | validations: 43 | required: false 44 | - type: textarea 45 | attributes: 46 | label: Anything else? 47 | description: | 48 | Links? References? Anything that will give us more context about the issue you are encountering! 49 | 50 | Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in. 51 | validations: 52 | required: false 53 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/installation.yaml: -------------------------------------------------------------------------------- 1 | name: Installation 2 | description: Get help with installation issues, including missing imports. 
3 | title: "[Installation] <title>" 4 | labels: [installation] 5 | body: 6 | - type: checkboxes 7 | attributes: 8 | label: Is there an existing issue for this? 9 | description: Please search to see if an issue already exists for the bug you encountered. 10 | options: 11 | - label: I have searched the existing issues 12 | required: true 13 | - type: checkboxes 14 | attributes: 15 | label: Have you followed all the steps in the FAQ? 16 | description: Please follow all the steps [in the FAQ](../blob/master/docs/FAQ.md) before filing an issue. 17 | options: 18 | - label: I have tried the steps in the FAQ. 19 | required: true 20 | - type: textarea 21 | attributes: 22 | label: Current Behavior 23 | description: A concise description of what you're experiencing. 24 | validations: 25 | required: false 26 | - type: textarea 27 | attributes: 28 | label: Error Line 29 | description: | 30 | Look through the log of `FORCE_CUDA=1 pip install --no-cache-dir git+https://github.com/mit-han-lab/torchsparse.git` to find the line that is causing the build to fail. 31 | For example: `fatal error: cuda_runtime_api.h: No such file or directory` is one such compilation error message. 32 | validations: 33 | required: true 34 | - type: textarea 35 | attributes: 36 | label: Environment 37 | description: | 38 | How to find these values: 39 | - **GCC**: `gcc --version` 40 | - **NVCC**: `nvcc --version` 41 | - **PyTorch**: `python -c "import torch; print(torch.__version__)"` 42 | - **PyTorch CUDA**: `python -c "import torch; print(torch.version.cuda)"` 43 | value: | 44 | - GCC: 45 | - NVCC: 46 | - PyTorch: 47 | - PyTorch CUDA: 48 | render: markdown 49 | validations: 50 | required: true 51 | - type: textarea 52 | attributes: 53 | label: Full Error Log 54 | description: Provide the full error log of both your import error and the log of `FORCE_CUDA=1 pip install --no-cache-dir git+https://github.com/mit-han-lab/torchsparse.git`. 
55 | value: | 56 | <details> 57 | <summary>Error Log</summary> 58 | 59 | [PUT YOUR ERROR LOG HERE] 60 | </details> 61 | validations: 62 | required: false 63 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yaml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [master] 7 | 8 | jobs: 9 | pre-commit: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - uses: actions/checkout@v2 13 | - uses: actions/setup-python@v2 14 | - uses: pre-commit/action@v2.0.3 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode/ 2 | build/ 3 | *.pyc 4 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.0.1 4 | hooks: 5 | - id: trailing-whitespace 6 | name: (Common) Remove trailing whitespaces 7 | - id: mixed-line-ending 8 | name: (Common) Fix mixed line ending 9 | args: [--fix=lf] 10 | - id: end-of-file-fixer 11 | name: (Common) Remove extra EOF newlines 12 | - id: check-merge-conflict 13 | name: (Common) Check for merge conflicts 14 | - id: requirements-txt-fixer 15 | name: (Common) Sort "requirements.txt" 16 | - id: fix-encoding-pragma 17 | name: (Python) Remove encoding pragmas 18 | args: [--remove] 19 | - id: double-quote-string-fixer 20 | name: (Python) Fix double-quoted strings 21 | - id: debug-statements 22 | name: (Python) Check for debugger imports 23 | - id: check-json 24 | name: (JSON) Check syntax 25 | - id: check-yaml 26 | name: (YAML) Check syntax 27 | - id: check-toml 28 | name: (TOML) Check syntax 29 | - repo: https://github.com/executablebooks/mdformat 30 | rev: 0.7.7 31 | hooks: 32 | - id: mdformat 33 | name: (Markdown) Format with mdformat 34 | - repo: https://github.com/asottile/pyupgrade 35 | rev: v3.0.0 36 | hooks: 37 | - id: pyupgrade 38 | name: (Python) Update syntax for newer versions 39 | args: [--py36-plus] 40 | - repo: https://github.com/google/yapf 41 | rev: v0.31.0 42 | hooks: 43 | - id: yapf 44 | name: (Python) Format with yapf 45 | - repo: https://github.com/pycqa/isort 46 | rev: 5.8.0 47 | hooks: 48 | - id: isort 49 | name: (Python) Sort imports with isort 50 | - repo: https://github.com/pycqa/flake8 51 | rev: 3.9.2 52 | hooks: 53 | - id: flake8 54 | name: (Python) Check with flake8 55 | additional_dependencies: 56 | - flake8-bugbear 57 | - flake8-comprehensions 58 | - flake8-docstrings 59 | - flake8-executable 60 | - flake8-quotes 61 | - repo: https://github.com/pre-commit/mirrors-mypy 62 | rev: v0.902 63 | hooks: 64 | - id: mypy 65 | name: (Python) Check with mypy 66 | additional_dependencies: 67 | - tokenize-rt 68 | - types-pyyaml 69 | - types-toml 70 | - repo: https://github.com/pre-commit/mirrors-clang-format 71 | rev: v13.0.0 72 | hooks: 73 | - id: clang-format 74 | name: (C/C++/CUDA) Format with clang-format 75 | args: [-style=google, -i] 76 | types_or: [c, c++, cuda] 77 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020-2021 TorchSparse Contributors 4 | 5 | Permission is hereby granted, free of 
charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /cython_setup.py: -------------------------------------------------------------------------------- 1 | # cython: language_level=3 2 | import glob 3 | import os 4 | import sys 5 | 6 | import torch 7 | import torch.cuda 8 | from setuptools import find_packages, setup 9 | from torch.utils.cpp_extension import ( 10 | CUDA_HOME, 11 | BuildExtension, 12 | CppExtension, 13 | CUDAExtension, 14 | ) 15 | 16 | # from torchsparse import __version__ 17 | 18 | from Cython.Build import cythonize 19 | 20 | cython_clean_flag = False 21 | 22 | version_file = open("./torchsparse/version.py") 23 | version = version_file.read().split("'")[1] 24 | print("torchsparse version:", version) 25 | 26 | if (torch.cuda.is_available() and CUDA_HOME is not None) or ( 27 | os.getenv("FORCE_CUDA", "0") == "1" 28 | ): 29 | device = "cuda" 30 | pybind_fn = f"pybind_{device}.cu" 31 | else: 32 | device = "cpu" 33 | pybind_fn = f"pybind_{device}.cpp" 34 | 35 | sources = [os.path.join("torchsparse", "backend", pybind_fn)] 36 | for fpath in glob.glob(os.path.join("torchsparse", "backend", "**", "*")): 37 | if (fpath.endswith("_cpu.cpp") and device in ["cpu", "cuda"]) or ( 38 | fpath.endswith("_cuda.cu") and device == "cuda" 39 | ): 40 | sources.append(fpath) 41 | 42 | pyx_files = [] 43 | for root, dirnames, filenames in os.walk("torchsparse"): 44 | for filename in filenames: 45 | file_path = os.path.join(root, filename) 46 | if file_path.endswith(".py"): 47 | file_path2 = file_path + "x" 48 | os.system("mv " + file_path + " " + file_path2) 49 | os.system("sed -i '1s/^/# cython: language_level=3\\n/' " + file_path2) 50 | pyx_files.append(file_path2) 51 | 52 | if pyx_files == []: 53 | for root, dirnames, filenames in os.walk("torchsparse"): 54 | for filename in filenames: 55 | file_path = os.path.join(root, filename) 56 | if file_path.endswith(".pyx"): 57 | pyx_files.append(file_path) 58 | 59 | extension_type = CUDAExtension if device == "cuda" else CppExtension 60 | extra_compile_args = { 61 | "cxx": ["-g", "-O3", "-fopenmp", "-lgomp"], 62 | "nvcc": ["-O3", "-std=c++17"], 63 | } 64 | 65 | setup( 66 | name="torchsparse", 67 | version=version, 68 | packages=find_packages(), 69 | ext_modules=cythonize( 70 | [ 71 | extension_type( 72 | "torchsparse.backend", sources, extra_compile_args=extra_compile_args 73 | ), 74 | ] 75 | + pyx_files 76 | ), 77 | install_requires=[ 78 | "numpy", 79 | 
"backports.cached_property", 80 | "tqdm", 81 | "typing-extensions", 82 | "wheel", 83 | "rootpath", 84 | "attributedict", 85 | ], 86 | cmdclass={"build_ext": BuildExtension}, 87 | zip_safe=False, 88 | ) 89 | 90 | # Clean up 91 | if cython_clean_flag: 92 | for root, dirnames, filenames in os.walk("torchsparse"): 93 | for filename in filenames: 94 | file_path = os.path.join(root, filename) 95 | if file_path.endswith(".c"): 96 | os.system("rm " + file_path) 97 | if file_path.endswith(".pyx"): 98 | os.system("rm " + file_path) 99 | -------------------------------------------------------------------------------- /docs/FAQ.md: -------------------------------------------------------------------------------- 1 | ## Frequently Asked Questions 2 | 3 | Before posting an issue, please go through the following troubleshooting steps on your own: 4 | 5 | - Check whether the issue is TorchSparse specific or environment specific. Try creating an isolated environment via Docker or on another computer and see if the error persists. If using TorchSparse as a dependancy of another project, ensure the downstream project is compatible with the version of TorchSparse that you installed. 6 | 7 | - Read the error logs line-by-line: if it's a compilation error, the problem will be shown in the log. Often, compilation issues will come from incorrectly configured environment, such as an improper NVCC or PyTorch installation, rather than incompatibility with this library. Please paste the full log message of `pip install -v git+https://github.com/mit-han-lab/torchsparse.git` when you submit the issue. 8 | 9 | - Try [completely uninstalling CUDA](https://askubuntu.com/q/530043) and make sure that there are no additional CUDA installations: 10 | 11 | ```bash 12 | ls /usr/local/cuda* -d 13 | ``` 14 | 15 | - Then, follow **all** of the steps for toolkit installation in the [CUDA installation guide](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html), especially the [post installation actions](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#post-installation-actions) to set your `LD_LIBRARY_PATH` and `PATH`. 16 | 17 | - Ensure that PyTorch and NVCC are using the same version of CUDA: 18 | 19 | ```bash 20 | nvcc --version 21 | python -c "import torch; print(torch.version.cuda);" 22 | ``` 23 | 24 | - If you're trying to cross-compile the library (i.e. compiling for a different GPU than the one in the system at build time, such as in a docker build), make use of the `TORCH_CUDA_ARCH_LIST` environmental variable. You can use [this chart](http://arnon.dk/matching-sm-architectures-arch-and-gencode-for-various-nvidia-cards/) to find your architecture/gencode. For example, if you want to compile for a Turing-architecture GPU, you would do: 25 | 26 | ```bash 27 | TORCH_CUDA_ARCH_LIST="7.0;7.5" pip install --upgrade git+https://github.com/mit-han-lab/torchsparse.git 28 | ``` 29 | 30 | - If you see `Killed` in the compilation log, it's likely the compilation failed due to out of memory as a result of parallel compilation. 
You can limit the number of CPUs the compiler will use by setting the `MAX_JOBS` environment variable before installation: 31 | 32 | ```bash 33 | MAX_JOBS=2 pip install --upgrade git+https://github.com/mit-han-lab/torchsparse.git 34 | ``` 35 | -------------------------------------------------------------------------------- /docs/figs/eval_benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/torchsparse/385f5ce8718fcae93540511b7f5832f4e71fd835/docs/figs/eval_benchmark.png -------------------------------------------------------------------------------- /docs/figs/torchsparse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/torchsparse/385f5ce8718fcae93540511b7f5832f4e71fd835/docs/figs/torchsparse.png -------------------------------------------------------------------------------- /docs/figs/train_benchmark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mit-han-lab/torchsparse/385f5ce8718fcae93540511b7f5832f4e71fd835/docs/figs/train_benchmark.png -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Containers 2 | A Docker image, `ioeddk/torchsparse_plugin_demo:latest`, is provided with the required environment installed, including MMDetection3D, OpenPCDet, TorchSparse, the plugins, and PyTorch, based on the NVIDIA CUDA 12.1 image. 3 | The datasets are not included in the image and need to be bind-mounted into the container at startup, for example with the following command: 4 | ```bash 5 | docker run -it --gpus all --mount type=bind,source=<kitti_dataset_root>,target=/root/data/kitti --mount type=bind,source=<nuscenes_dataset_root>,target=/root/data/nuscenes ioeddk/torchsparse_plugin_demo:latest 6 | ``` 7 | The command above mounts the KITTI and nuScenes dataset roots when starting the container. 8 | 9 | Using this container is the simplest way to start the demo of this plugin since all the dependencies are installed and the paths are configured. You can simply open `/root/repo/torchsparse-dev/examples/mmdetection3d/demo.ipynb` or `/root/repo/torchsparse-dev/examples/openpcdet/demo.ipynb` and run all cells to run the demo. The helper functions in the demo are defined to automatically load the pretrained checkpoints, do the conversions, and run the evaluation. 10 | 11 | If you are not using the container, please follow the tutorial below to run the demo. The same steps are also included in the demo notebooks. 12 | 13 | # Convert the Module Weights 14 | The weight dimensions in TorchSparse differ from those in SpConv, so parameter dimension conversion is required to use the TorchSparse backend. The conversion script can be found in `examples/converter.py`. The `convert_weights` function has the header `def convert_weights(ckpt_before: str, ckpt_after: str, cfg_path: str, v_spconv: int = 1, framework: str = "mmdet3d")`: 15 | - `ckpt_before`: the pretrained checkpoint of your module, typically downloaded from the MMDetection3D or OpenPCDet model zoo. 16 | - `ckpt_after`: the output path for the converted checkpoint. 17 | - `cfg_path`: the path to the config file of the MMDetection3D or OpenPCDet model to be converted.
It is required since the converter creates an instance of the model, finds all the sparse convolution layers, and converts the weights of those layers. 18 | - `v_spconv`: the version of SpConv that the original model is built upon. Valid versions are 1 or 2. 19 | - `framework`: choose between `mmdet3d` and `openpc`. A minimal Python sketch of calling `convert_weights` directly is included at the end of this README. 20 | 21 | ## Example Conversion Commands 22 | ### MMDetection3D 23 | ```bash 24 | python examples/converter.py --ckpt_before ../mmdetection3d/models/PV-RCNN/pv_rcnn_8xb2-80e_kitti-3d-3class_20221117_234428-b384d22f.pth --cfg_path ../mmdetection3d/pv_rcnn/pv_rcnn_8xb2-80e_kitti-3d-3class.py --ckpt_after ./converted/PV-RCNN/pv_rcnn_8xb2-80e_kitti-3d-3class_20221117_234428-b384d22f.pth --v_spconv 1 --framework mmdet3d 25 | ``` 26 | 27 | ### OpenPCDet 28 | ```bash 29 | python examples/converter.py --ckpt_before ../OpenPCDet/models/SECOND/second_7862.pth --cfg_path ../OpenPCDet/tools/cfgs/kitti_models/second.yaml --ckpt_after ./converted/SECOND/second_7862.pth --v_spconv 1 --framework openpc 30 | ``` 31 | 32 | # Run Evaluation 33 | Use the `test.py` that comes with MMDetection3D or OpenPCDet to run the evaluation, and provide the converted checkpoint as the model weights. For MMDetection3D models, you need to provide extra arguments to replace certain layers with TorchSparse's (see how to replace them in `examples/mmdetection3d/demo.ipynb`). For OpenPCDet, the config files with those layers replaced are in `examples/openpcdet/cfgs`; to use them, see `examples/openpcdet/demo.ipynb`. An additional step is to add `import ts_plugin` in `mmdetection3d/tools/test.py` and add `import pcdet_plugin` in `OpenPCDet/tools/test.py` to activate the plugins before running the evaluation. 34 | 35 | # Details 36 | Please see `examples/mmdetection3d/demo.ipynb` and `examples/openpcdet/demo.ipynb` for more details.
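# Calling the Converter from Python

The converter can also be used as an API rather than through the command line. Below is a minimal sketch of such a call, assuming it is run from the `examples/` directory so that `converter.py` is importable; the checkpoint and config paths are the same placeholders used in the CLI example above and should be replaced with your own files.

```python
# Minimal sketch: calling the converter as an API instead of the CLI.
# Assumes the script is run from the `examples/` directory (where converter.py lives);
# the paths below are illustrative placeholders, not files shipped with this repository.
from converter import convert_weights

convert_weights(
    ckpt_before="../OpenPCDet/models/SECOND/second_7862.pth",     # pretrained SpConv checkpoint
    ckpt_after="./converted/SECOND/second_7862.pth",              # output path for the converted weights
    cfg_path="../OpenPCDet/tools/cfgs/kitti_models/second.yaml",  # model config used to build the network
    v_spconv=1,                                                   # SpConv version of the original model
    framework="openpc",                                           # "mmdet3d" or "openpc"
)
```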
37 | -------------------------------------------------------------------------------- /examples/backbones.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | 5 | from torchsparse import SparseTensor 6 | from torchsparse.backbones import SparseResNet21D, SparseResUNet42 7 | from torchsparse.utils.quantize import sparse_quantize 8 | 9 | 10 | @torch.no_grad() 11 | def main() -> None: 12 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 13 | from torchsparse.nn import functional as F 14 | 15 | F.set_kmap_mode("hashmap") 16 | 17 | for backbone in [SparseResNet21D, SparseResUNet42]: 18 | print(f"{backbone.__name__}:") 19 | model: nn.Module = backbone(in_channels=4, width_multiplier=1.0) 20 | model = model.to(device).eval() 21 | 22 | # generate data 23 | input_size, voxel_size = 10000, 0.2 24 | inputs = np.random.uniform(-100, 100, size=(input_size, 4)) 25 | pcs, feats = inputs[:, :3], inputs 26 | pcs -= np.min(pcs, axis=0, keepdims=True) 27 | pcs, indices = sparse_quantize(pcs, voxel_size, return_index=True) 28 | coords = np.zeros((pcs.shape[0], 4)) 29 | coords[:, 1:4] = pcs[:, :3] 30 | coords[:, 0] = 0 31 | coords = torch.as_tensor(coords, dtype=torch.int) 32 | feats = torch.as_tensor(feats[indices], dtype=torch.float) 33 | input = SparseTensor(coords=coords, feats=feats).to(device) 34 | 35 | # forward 36 | outputs = model(input) 37 | 38 | # print feature shapes 39 | for k, output in enumerate(outputs): 40 | print(f"output[{k}].F.shape = {output.feats.shape}") 41 | 42 | 43 | if __name__ == "__main__": 44 | main() 45 | -------------------------------------------------------------------------------- /examples/example.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from typing import Any, Dict 4 | 5 | import numpy as np 6 | import torch 7 | import torch.utils.data 8 | from torch import nn 9 | from torch.cuda import amp 10 | 11 | import torchsparse 12 | from torchsparse import SparseTensor 13 | from torchsparse import nn as spnn 14 | from torchsparse.nn import functional as F 15 | from torchsparse.utils.collate import sparse_collate_fn 16 | from torchsparse.utils.quantize import sparse_quantize 17 | 18 | 19 | class RandomDataset: 20 | def __init__(self, input_size: int, voxel_size: float) -> None: 21 | self.input_size = input_size 22 | self.voxel_size = voxel_size 23 | 24 | def __getitem__(self, _: int) -> Dict[str, Any]: 25 | inputs = np.random.uniform(-100, 100, size=(self.input_size, 4)) 26 | labels = np.random.choice(10, size=self.input_size) 27 | 28 | coords, feats = inputs[:, :3], inputs 29 | coords -= np.min(coords, axis=0, keepdims=True) 30 | coords, indices = sparse_quantize(coords, self.voxel_size, return_index=True) 31 | 32 | coords = torch.tensor(coords, dtype=torch.int) 33 | feats = torch.tensor(feats[indices], dtype=torch.float) 34 | labels = torch.tensor(labels[indices], dtype=torch.long) 35 | 36 | input = SparseTensor(coords=coords, feats=feats) 37 | label = SparseTensor(coords=coords, feats=labels) 38 | return {"input": input, "label": label} 39 | 40 | def __len__(self): 41 | return 100 42 | 43 | 44 | if __name__ == "__main__": 45 | conv_config = F.get_default_conv_config() 46 | # conv_config.dataflow = F.Dataflow.GatherScatter 47 | F.set_global_conv_config(conv_config) 48 | 49 | parser = argparse.ArgumentParser() 50 | parser.add_argument("--device", type=str, default="cuda") 51 | 
parser.add_argument("--amp_enabled", action="store_true") 52 | args = parser.parse_args() 53 | 54 | random.seed(0) 55 | np.random.seed(0) 56 | torch.manual_seed(0) 57 | 58 | dataset = RandomDataset(input_size=10000, voxel_size=0.2) 59 | dataflow = torch.utils.data.DataLoader( 60 | dataset, 61 | batch_size=2, 62 | collate_fn=sparse_collate_fn, 63 | ) 64 | 65 | model = nn.Sequential( 66 | spnn.Conv3d(4, 32, 3), 67 | spnn.BatchNorm(32), 68 | spnn.ReLU(True), 69 | spnn.Conv3d(32, 64, 2, stride=2), 70 | spnn.BatchNorm(64), 71 | spnn.ReLU(True), 72 | spnn.Conv3d(64, 64, 2, stride=2, transposed=True), 73 | spnn.BatchNorm(64), 74 | spnn.ReLU(True), 75 | spnn.Conv3d(64, 32, 3), 76 | spnn.BatchNorm(32), 77 | spnn.ReLU(True), 78 | spnn.Conv3d(32, 10, 1), 79 | ).to(args.device) 80 | 81 | criterion = nn.CrossEntropyLoss() 82 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) 83 | scaler = amp.GradScaler(enabled=args.amp_enabled) 84 | 85 | for k, feed_dict in enumerate(dataflow): 86 | inputs = feed_dict["input"].to(device=args.device) 87 | labels = feed_dict["label"].to(device=args.device) 88 | 89 | with amp.autocast(enabled=args.amp_enabled): 90 | outputs = model(inputs) 91 | loss = criterion(outputs.feats, labels.feats) 92 | 93 | print(f"[step {k + 1}] loss = {loss.item()}") 94 | 95 | optimizer.zero_grad() 96 | scaler.scale(loss).backward() 97 | scaler.step(optimizer) 98 | scaler.update() 99 | 100 | # enable torchsparse 2.0 inference 101 | model.eval() 102 | # enable fused and locality-aware memory access optimization 103 | torchsparse.backends.benchmark = True # type: ignore 104 | 105 | with torch.no_grad(): 106 | for k, feed_dict in enumerate(dataflow): 107 | inputs = feed_dict["input"].to(device=args.device).half() 108 | labels = feed_dict["label"].to(device=args.device) 109 | 110 | with amp.autocast(enabled=True): 111 | outputs = model(inputs) 112 | loss = criterion(outputs.feats, labels.feats) 113 | 114 | print(f"[inference step {k + 1}] loss = {loss.item()}") 115 | -------------------------------------------------------------------------------- /examples/mmdetection3d/README.md: -------------------------------------------------------------------------------- 1 | # TorchSparse for MMDetection3D Plugin Demo 2 | 3 | This tutorial demonstrates how to evaluate TorchSparse integrated MMDetection3D models. Follow the steps below to install dependencies, configure paths, convert model weights, and run the demo. 4 | 5 | ## Dependencies 6 | 7 | 1. **MMDetection3D Installation**: Follow the [MMDetection3D documentation](https://mmdetection3d.readthedocs.io/en/latest/get_started.html). 8 | 2. **Dataset Preparation**: Pre-process the datasets as described [here](https://mmdetection3d.readthedocs.io/en/latest/user_guides/dataset_prepare.html). 9 | 3. **TorchSparse Installation**: Install [TorchSparse](https://github.com/mit-han-lab/torchsparse). 10 | 4. **Install TorchSparse Plugin for MMDetection3D**: 11 | 1. Clone this repository. 12 | 2. Navigate to `examples/mmdetection3d` and run `pip install -v -e .`. 13 | 14 | ## Notes 15 | 16 | - For model evaluation, change the data root in the original MMDetection3D's model config to the full path of the corresponding dataset root. 17 | 18 | ## Steps 19 | 20 | 1. Install the dependencies. 21 | 2. Specify the base paths and model registry. 22 | 3. **IMPORTANT,** Activate the plugin: In `mmdetection3d/tools/test.py`, add `import ts_plugin` as the last import statement to activate the plugin. 23 | 4. Run the evaluation. 
24 | 25 | ## Supported Models 26 | 27 | - SECOND 28 | - PV-RCNN 29 | - CenterPoint 30 | - Part-A2 31 | 32 | ## Convert Module Weights 33 | The dimensions of TorchSparse differ from SpConv, so parameter dimension conversion is required. You can use `convert_weights_cmd()` in converter.py as a command line tool or use `convert_weights()` as an API. Both functions have five parameters: 34 | 35 | 1. `ckpt_before`: Path to the input SpConv checkpoint file. 36 | 2. `ckpt_after`: Path where the converted TorchSparse checkpoint will be saved. 37 | 3. `cfg_path`: Path to the MMDetection3D configuration file of the model. 38 | 4. `v_spconv`: Version of SpConv used in the original model (1 or 2). 39 | 5. `framework`: Choose between `'openpc'` and `'mmdet3d'`, defaults to `'mmdet3d'`. 40 | 41 | These parameters allow the converter to locate the input model, specify the output location, understand the model's architecture, and apply the appropriate conversion method to the specific sparse convolution layers. 42 | 43 | Example conversion commands: 44 | ```bash 45 | python examples/converter.py --ckpt_before ../mmdetection3d/models/PV-RCNN/pv_rcnn_8xb2-80e_kitti-3d-3class_20221117_234428-b384d22f.pth --cfg_path ../mmdetection3d/pv_rcnn/pv_rcnn_8xb2-80e_kitti-3d-3class.py --ckpt_after ./converted/PV-RCNN/pv_rcnn_8xb2-80e_kitti-3d-3class_20221117_234428-b384d22f.pth --v_spconv 1 --framework mmdet3d 46 | ``` 47 | 48 | 49 | # Run a demo 50 | In your Conda environment, run: 51 | ```bash 52 | python <test_file_path> <cfg_path> <torchsparse_model_path> <cfg_options> --task lidar_det 53 | ``` 54 | 55 | - `test_file_path`: The `tools/test.py` file in the mmdet3d repository. 56 | - `cfg_path`: The path to the MMDetection3D model config for your model. 57 | - `torchsparse_model_path`: The path to the converted TorchSparse model checkpoint. 58 | - `cfg_options`: The plugin requires the use of MMDetection3D `cfg_options` to swap certain model layers for the plugin layers. `cfg_options` examples are below: 59 | 60 | ## SECOND 61 | `cfg_options`: 62 | ```bash 63 | "--cfg-options test_evaluator.pklfile_prefix=outputs/torchsparse/second --cfg-options model.middle_encoder.type=SparseEncoderTS" 64 | ``` 65 | 66 | ## PV-RCNN 67 | `cfg_options`: 68 | ```bash 69 | "--cfg-options test_evaluator.pklfile_prefix=outputs/torchsparse/pv_rcnn --cfg-options model.middle_encoder.type=SparseEncoderTS --cfg-options model.points_encoder.type=VoxelSetAbstractionTS" 70 | ``` 71 | 72 | ## CenterPoint Voxel 0.1 Circular NMS 73 | 74 | Update the path of the NuScenes dataset in the MMDetection3D dataset config `configs/_base_/datasets/nus-3d.py`. 75 | 76 | `cfg_options`: 77 | ```bash 78 | "--cfg-options model.pts_middle_encoder.type=SparseEncoderTS" 79 | ``` -------------------------------------------------------------------------------- /examples/mmdetection3d/configs/README.md: -------------------------------------------------------------------------------- 1 | This folder contains the configs used to carry out the demo in MMDetection3D. -------------------------------------------------------------------------------- /examples/mmdetection3d/converted_models/README.md: -------------------------------------------------------------------------------- 1 | Default model conversion base folder for the demo. Please create the relative path to each specific model under this directory.
2 | -------------------------------------------------------------------------------- /examples/mmdetection3d/scripts/run_evaluation/SECOND.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export MMDET3D_HOME="/home/yingqi/repo/mmdetection3d" && python ${MMDET3D_HOME}/tools/test.py ${MMDET3D_HOME}/configs/second/second_hv_secfpn_8xb6-80e_kitti-3d-3class.py /home/ioeddk/GitHub/torchsparse-dev/examples/mmdetection3d/pretrained_models/backup/second/second_hv_secfpn_8xb6-80e_kitti-3d-3class-b086d0a3-converted.pth --cfg-options test_evaluator.pklfile_prefix=outputs/torchsparse/second --cfg-options model.middle_encoder.type=SparseEncoderTS --task lidar_det -------------------------------------------------------------------------------- /examples/mmdetection3d/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='ts_plugin', 5 | version='0.1', 6 | packages=find_packages(), 7 | ) 8 | -------------------------------------------------------------------------------- /examples/mmdetection3d/ts_plugin/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | 3 | from mmengine.registry import MODELS 4 | 5 | from torchsparse.nn import BatchNorm 6 | 7 | MODELS.register_module('TorchSparseBatchNorm', force=True, module=BatchNorm) -------------------------------------------------------------------------------- /examples/mmdetection3d/ts_plugin/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .layers import * 2 | from .middle_encoders import * 3 | from .roi_heads.bbox_heads import * 4 | from .backbones import * -------------------------------------------------------------------------------- /examples/mmdetection3d/ts_plugin/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import BasicBlockTS -------------------------------------------------------------------------------- /examples/mmdetection3d/ts_plugin/models/layers/__init__.py: -------------------------------------------------------------------------------- 1 | from .sparse_block import SparseBasicBlockTS, replace_feature_ts, make_sparse_convmodule_ts 2 | -------------------------------------------------------------------------------- /examples/mmdetection3d/ts_plugin/models/layers/sparse_block.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torchsparse.nn as spnn 3 | 4 | from ..backbones.resnet import BasicBlockTS 5 | from mmcv.cnn import build_conv_layer, build_norm_layer 6 | 7 | import logging 8 | 9 | def replace_feature_ts(out, new_features): 10 | out.feats = new_features 11 | return out 12 | 13 | 14 | class SparseBasicBlockTS(BasicBlockTS): 15 | """Sparse basic block for PartA^2. 16 | 17 | Sparse basic block implemented with submanifold sparse convolution. 18 | 19 | Args: 20 | inplanes (int): Inplanes of block. 21 | planes (int): Planes of block. 22 | stride (int or Tuple[int]): Stride of the first block. Defaults to 1. 23 | downsample (Module, optional): Down sample module for block. 24 | Defaults to None. 25 | indice_key (str): Indice key for spconv. Default to None. 26 | conv_cfg (:obj:`ConfigDict` or dict, optional): Config dict for 27 | convolution layer. Defaults to None. 
28 | norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for 29 | normalization layer. Defaults to None. 30 | """ 31 | 32 | expansion = 1 33 | 34 | def __init__( 35 | self, 36 | inplanes, 37 | planes, 38 | stride=1, 39 | downsample=None, 40 | conv_cfg=None, 41 | norm_cfg=None, 42 | act_cfg=None, 43 | ): 44 | BasicBlockTS.__init__( 45 | self, 46 | inplanes, 47 | planes, 48 | stride=stride, 49 | downsample=downsample, 50 | conv_cfg=conv_cfg, 51 | norm_cfg=norm_cfg, 52 | ) 53 | if act_cfg is not None: 54 | if act_cfg == "swish": 55 | self.relu = spnn.SiLU(inplace=True) 56 | else: 57 | self.relu = spnn.ReLU(inplace=True) 58 | 59 | 60 | 61 | def make_sparse_convmodule_ts( 62 | in_channels, 63 | out_channels, 64 | kernel_size, 65 | stride=1, 66 | padding=0, 67 | conv_type="TorchSparseConv3d", 68 | norm_cfg=None, 69 | order=("conv", "norm", "act"), 70 | activation_type="relu", 71 | indice_key=None, 72 | transposed=False 73 | ): 74 | """Make sparse convolution module. 75 | 76 | Args: 77 | in_channels (int): The number of input channels. 78 | out_channels (int): The number of out channels. 79 | kernel_size (int | Tuple[int]): Kernel size of convolution. 80 | indice_key (str): The indice key used for sparse tensor. 81 | stride (int or tuple[int]): The stride of convolution. 82 | padding (int or tuple[int]): The padding number of input. 83 | conv_type (str): Sparse conv type in spconv. Defaults to 'SubMConv3d'. 84 | norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for 85 | normalization layer. Defaults to None. 86 | order (Tuple[str]): The order of conv/norm/activation layers. It is a 87 | sequence of "conv", "norm" and "act". Common examples are 88 | ("conv", "norm", "act") and ("act", "conv", "norm"). 89 | Defaults to ('conv', 'norm', 'act'). 90 | 91 | Returns: 92 | spconv.SparseSequential: sparse convolution module. 
93 | """ 94 | assert isinstance(order, tuple) and len(order) <= 3 95 | assert set(order) | {"conv", "norm", "act"} == {"conv", "norm", "act"} 96 | 97 | conv_cfg = {"type": conv_type} 98 | 99 | if norm_cfg is None: 100 | norm_cfg = dict(type='BN1d') 101 | 102 | layers = [] 103 | for layer in order: 104 | if layer == "conv": 105 | layers.append( 106 | build_conv_layer( 107 | cfg=conv_cfg, 108 | in_channels=in_channels, 109 | out_channels=out_channels, 110 | kernel_size=kernel_size, 111 | stride=stride, 112 | padding=padding, 113 | bias=False, 114 | transposed=transposed, 115 | ) 116 | # spnn.Conv3d( 117 | # in_channels=in_channels, 118 | # out_channels=out_channels, 119 | # kernel_size=kernel_size, 120 | # stride=stride, 121 | # padding=padding, 122 | # bias=False, 123 | # transposed=transposed) 124 | ) 125 | elif layer == "norm": 126 | assert norm_cfg is not None, "norm_cfg must be provided" 127 | layers.append(build_norm_layer(norm_cfg, out_channels)[1]) 128 | elif layer == "act": 129 | if activation_type == "relu": 130 | layers.append(spnn.ReLU(inplace=True)) 131 | elif activation_type == "swish": 132 | layers.append(spnn.SiLU(inplace=True)) 133 | else: 134 | raise NotImplementedError 135 | layers = nn.Sequential(*layers) 136 | logging.info("Made TorchSparse Module") 137 | return layers -------------------------------------------------------------------------------- /examples/mmdetection3d/ts_plugin/models/middle_encoders/__init__.py: -------------------------------------------------------------------------------- 1 | from .sparse_encoder import SparseEncoderTS 2 | from .voxel_set_abstraction import VoxelSetAbstractionTS 3 | from .sparse_unet import SparseUNetTS -------------------------------------------------------------------------------- /examples/mmdetection3d/ts_plugin/models/roi_heads/bbox_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .parta2_bbox_head import PartA2BboxHeadTS -------------------------------------------------------------------------------- /examples/openpcdet/README.md: -------------------------------------------------------------------------------- 1 | # TorchSparse for OpenPCDet Plugin Demo 2 | 3 | This tutorial demonstrates how to evaluate TorchSparse integrated OpenPCDet models. Follow the steps below to install dependencies, configure paths, convert model weights, and run the demo. 4 | 5 | ## Dependencies 6 | 7 | 1. **Conda**: Ensure Conda is installed. 8 | 2. **OpenPCDet Installation**: Follow the [OpenPCDet documentation](https://github.com/open-mmlab/OpenPCDet/blob/master/docs/INSTALL.md). 9 | 3. **Dataset Preparation**: Pre-process the datasets as described [here](https://github.com/open-mmlab/OpenPCDet/blob/master/docs/GETTING_STARTED.md). 10 | 4. **TorchSparse Installation**: Install [TorchSparse](https://github.com/mit-han-lab/torchsparse). 11 | 5. **Install TorchSparse Plugin for OpenPCDet**: 12 | 1. Clone this repository. 13 | 2. Define the environment variable `PCDET_BASE` to point to the installation path of OpenPCDet. 14 | 3. Navigate to `examples/openpcdet` and run `pip install -v -e .`. 15 | 16 | ## Notes 17 | 18 | - You may need to disable PyTorch JIT compile to avoid errors. Add the following to the import section of the relevant `.py` file: 19 | ```python 20 | import torch 21 | torch.jit._state.disable() 22 | ``` 23 | - Modify dataset paths in the model config to absolute paths to avoid `FileNotFoundError`. 24 | 25 | ## Steps 26 | 27 | 1. Install the dependencies. 28 | 2. 
Specify the base paths and model registry. 29 | 3. **IMPORTANT,** Activate the plugin: In `OpenPCDet/tools/test.py`, add `import pcdet_plugin` as the last import statement to activate the plugin. 30 | 4. Run the evaluation. 31 | 32 | ## Supported Models 33 | 34 | - KITTI: SECOND, PV-RCNN, Part-A2 35 | - NuScenes: VoxelNeXt 36 | 37 | ## Load the Weight Conversion Module 38 | The dimensions of TorchSparse differ from SpConv, so parameter dimension conversion is required. You can use `convert_weights_cmd()` in converter.py as a command line tool or use `convert_weights()` as an API. Both functions have five parameters: 39 | 40 | 1. `ckpt_before`: Path to the input SpConv checkpoint file. 41 | 2. `ckpt_after`: Path where the converted TorchSparse checkpoint will be saved. 42 | 3. `cfg_path`: Path to the configuration file of the model. 43 | 4. `v_spconv`: Version of SpConv used in the original model (1 or 2). 44 | 5. `framework`: Choose between `'openpc'` and `'mmdet3d'`, defaults to `'mmdet3d'`. 45 | 46 | These parameters allow the converter to locate the input model, specify the output location, understand the model's architecture, and apply the appropriate conversion method to the specific sparse convolution layers. 47 | 48 | Example conversion commands: 49 | ```bash 50 | python examples/converter.py --ckpt_before ../OpenPCDet/models/SECOND/second_7862.pth --cfg_path ../OpenPCDet/tools/cfgs/kitti_models/second.yaml --ckpt_after ./converted/SECOND/second_7862.pth --v_spconv 1 --framework openpc 51 | ``` 52 | 53 | 54 | ## Run the Evaluation 55 | In your Conda environment with all the dependencies installed, run the following for the evaluation: 56 | ```bash 57 | python <test_file_path> --cfg_file <torchsparse_cfg_path> --ckpt <torchsparse_model_path> 58 | ``` 59 | 60 | - `test_file_path`: the evaluation script (`tools/test.py`) in OpenPCDet. 61 | - `torchsparse_cfg_path`: the config file of the model, in the `examples/openpcdet/cfgs` folder of this repository. 62 | - `torchsparse_model_path`: the converted TorchSparse checkpoint. 63 | 64 | 65 | ### VoxelNeXt 66 | VoxelNeXt requires `examples/openpcdet/converter_voxelnext.py` as a model converter, rather than the general `converter.py`.
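### Checking a Converted Checkpoint

As an optional sanity check before evaluation, you can open a converted checkpoint and inspect a few parameter shapes. The sketch below is only illustrative: the path is the placeholder from the conversion example above, and the `model_state` key is an assumption about the checkpoint layout rather than something guaranteed by this plugin.

```python
# Optional sanity check (sketch): peek at a converted TorchSparse checkpoint.
import torch

ckpt = torch.load("./converted/SECOND/second_7862.pth", map_location="cpu")
# Unwrap a possible "model_state" wrapper (assumed layout); otherwise treat it as a flat state dict.
state = ckpt.get("model_state", ckpt) if isinstance(ckpt, dict) else ckpt

# Print the first few parameter names and shapes to confirm the conversion produced tensors.
for name, tensor in list(state.items())[:10]:
    if hasattr(tensor, "shape"):
        print(f"{name}: {tuple(tensor.shape)}")
```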
67 | -------------------------------------------------------------------------------- /examples/openpcdet/cfgs_templates/kitti_models/PartA2_plugin.yaml: -------------------------------------------------------------------------------- 1 | CLASS_NAMES: ['Car', 'Pedestrian', 'Cyclist'] 2 | 3 | DATA_CONFIG: 4 | _BASE_CONFIG_: {{ pcdet_base_path }}/tools/cfgs/dataset_configs/kitti_dataset.yaml 5 | 6 | 7 | MODEL: 8 | NAME: PartA2Net 9 | 10 | VFE: 11 | NAME: MeanVFE 12 | 13 | BACKBONE_3D: 14 | NAME: UNetV2TS 15 | 16 | MAP_TO_BEV: 17 | NAME: HeightCompressionTS 18 | NUM_BEV_FEATURES: 256 19 | 20 | BACKBONE_2D: 21 | NAME: BaseBEVBackbone 22 | 23 | LAYER_NUMS: [5, 5] 24 | LAYER_STRIDES: [1, 2] 25 | NUM_FILTERS: [128, 256] 26 | UPSAMPLE_STRIDES: [1, 2] 27 | NUM_UPSAMPLE_FILTERS: [256, 256] 28 | 29 | DENSE_HEAD: 30 | NAME: AnchorHeadSingle 31 | CLASS_AGNOSTIC: False 32 | 33 | USE_DIRECTION_CLASSIFIER: True 34 | DIR_OFFSET: 0.78539 35 | DIR_LIMIT_OFFSET: 0.0 36 | NUM_DIR_BINS: 2 37 | 38 | ANCHOR_GENERATOR_CONFIG: [ 39 | { 40 | 'class_name': 'Car', 41 | 'anchor_sizes': [[3.9, 1.6, 1.56]], 42 | 'anchor_rotations': [0, 1.57], 43 | 'anchor_bottom_heights': [-1.78], 44 | 'align_center': False, 45 | 'feature_map_stride': 8, 46 | 'matched_threshold': 0.6, 47 | 'unmatched_threshold': 0.45 48 | }, 49 | { 50 | 'class_name': 'Pedestrian', 51 | 'anchor_sizes': [[0.8, 0.6, 1.73]], 52 | 'anchor_rotations': [0, 1.57], 53 | 'anchor_bottom_heights': [-1.78], 54 | 'align_center': False, 55 | 'feature_map_stride': 8, 56 | 'matched_threshold': 0.5, 57 | 'unmatched_threshold': 0.35 58 | }, 59 | { 60 | 'class_name': 'Cyclist', 61 | 'anchor_sizes': [[1.76, 0.6, 1.73]], 62 | 'anchor_rotations': [0, 1.57], 63 | 'anchor_bottom_heights': [-1.78], 64 | 'align_center': False, 65 | 'feature_map_stride': 8, 66 | 'matched_threshold': 0.5, 67 | 'unmatched_threshold': 0.35 68 | } 69 | ] 70 | 71 | TARGET_ASSIGNER_CONFIG: 72 | NAME: AxisAlignedTargetAssigner 73 | POS_FRACTION: -1.0 74 | SAMPLE_SIZE: 512 75 | NORM_BY_NUM_EXAMPLES: False 76 | MATCH_HEIGHT: False 77 | BOX_CODER: ResidualCoder 78 | 79 | LOSS_CONFIG: 80 | LOSS_WEIGHTS: { 81 | 'cls_weight': 1.0, 82 | 'loc_weight': 2.0, 83 | 'dir_weight': 0.2, 84 | 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] 85 | } 86 | 87 | POINT_HEAD: 88 | NAME: PointIntraPartOffsetHead 89 | CLS_FC: [] 90 | PART_FC: [] 91 | CLASS_AGNOSTIC: True 92 | TARGET_CONFIG: 93 | GT_EXTRA_WIDTH: [0.2, 0.2, 0.2] 94 | LOSS_CONFIG: 95 | LOSS_REG: smooth-l1 96 | LOSS_WEIGHTS: { 97 | 'point_cls_weight': 1.0, 98 | 'point_part_weight': 1.0 99 | } 100 | 101 | ROI_HEAD: 102 | NAME: PartA2FCHeadTS 103 | CLASS_AGNOSTIC: True 104 | 105 | SHARED_FC: [256, 256, 256] 106 | CLS_FC: [256, 256] 107 | REG_FC: [256, 256] 108 | DP_RATIO: 0.3 109 | 110 | SEG_MASK_SCORE_THRESH: 0.3 111 | 112 | NMS_CONFIG: 113 | TRAIN: 114 | NMS_TYPE: nms_gpu 115 | MULTI_CLASSES_NMS: False 116 | NMS_PRE_MAXSIZE: 9000 117 | NMS_POST_MAXSIZE: 512 118 | NMS_THRESH: 0.8 119 | TEST: 120 | NMS_TYPE: nms_gpu 121 | MULTI_CLASSES_NMS: False 122 | NMS_PRE_MAXSIZE: 1024 123 | NMS_POST_MAXSIZE: 100 124 | NMS_THRESH: 0.7 125 | 126 | ROI_AWARE_POOL: 127 | POOL_SIZE: 12 128 | NUM_FEATURES: 128 129 | MAX_POINTS_PER_VOXEL: 128 130 | 131 | TARGET_CONFIG: 132 | BOX_CODER: ResidualCoder 133 | ROI_PER_IMAGE: 128 134 | FG_RATIO: 0.5 135 | 136 | SAMPLE_ROI_BY_EACH_CLASS: True 137 | CLS_SCORE_TYPE: roi_iou 138 | 139 | CLS_FG_THRESH: 0.75 140 | CLS_BG_THRESH: 0.25 141 | CLS_BG_THRESH_LO: 0.1 142 | HARD_BG_RATIO: 0.8 143 | 144 | REG_FG_THRESH: 0.65 145 | 146 | LOSS_CONFIG: 147 
| CLS_LOSS: BinaryCrossEntropy 148 | REG_LOSS: smooth-l1 149 | CORNER_LOSS_REGULARIZATION: True 150 | LOSS_WEIGHTS: { 151 | 'rcnn_cls_weight': 1.0, 152 | 'rcnn_reg_weight': 1.0, 153 | 'rcnn_corner_weight': 1.0, 154 | 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] 155 | } 156 | 157 | POST_PROCESSING: 158 | RECALL_THRESH_LIST: [0.3, 0.5, 0.7] 159 | SCORE_THRESH: 0.1 160 | OUTPUT_RAW_SCORE: False 161 | 162 | EVAL_METRIC: kitti 163 | 164 | NMS_CONFIG: 165 | MULTI_CLASSES_NMS: False 166 | NMS_TYPE: nms_gpu 167 | NMS_THRESH: 0.1 168 | NMS_PRE_MAXSIZE: 4096 169 | NMS_POST_MAXSIZE: 500 170 | 171 | 172 | OPTIMIZATION: 173 | BATCH_SIZE_PER_GPU: 4 174 | NUM_EPOCHS: 80 175 | 176 | OPTIMIZER: adam_onecycle 177 | LR: 0.01 178 | WEIGHT_DECAY: 0.01 179 | MOMENTUM: 0.9 180 | 181 | MOMS: [0.95, 0.85] 182 | PCT_START: 0.4 183 | DIV_FACTOR: 10 184 | DECAY_STEP_LIST: [35, 45] 185 | LR_DECAY: 0.1 186 | LR_CLIP: 0.0000001 187 | 188 | LR_WARMUP: False 189 | WARMUP_EPOCH: 1 190 | 191 | GRAD_NORM_CLIP: 10 192 | -------------------------------------------------------------------------------- /examples/openpcdet/cfgs_templates/kitti_models/second_plugin.yaml: -------------------------------------------------------------------------------- 1 | CLASS_NAMES: ['Car', 'Pedestrian', 'Cyclist'] 2 | 3 | DATA_CONFIG: 4 | _BASE_CONFIG_: {{ pcdet_base_path }}/tools/cfgs/dataset_configs/kitti_dataset.yaml 5 | 6 | 7 | MODEL: 8 | NAME: SECONDNet 9 | 10 | VFE: 11 | NAME: MeanVFE 12 | 13 | BACKBONE_3D: 14 | NAME: VoxelBackBone8xTS 15 | 16 | MAP_TO_BEV: 17 | NAME: HeightCompressionTS 18 | NUM_BEV_FEATURES: 256 19 | 20 | BACKBONE_2D: 21 | NAME: BaseBEVBackbone 22 | 23 | LAYER_NUMS: [5, 5] 24 | LAYER_STRIDES: [1, 2] 25 | NUM_FILTERS: [128, 256] 26 | UPSAMPLE_STRIDES: [1, 2] 27 | NUM_UPSAMPLE_FILTERS: [256, 256] 28 | 29 | DENSE_HEAD: 30 | NAME: AnchorHeadSingle 31 | CLASS_AGNOSTIC: False 32 | 33 | USE_DIRECTION_CLASSIFIER: True 34 | DIR_OFFSET: 0.78539 35 | DIR_LIMIT_OFFSET: 0.0 36 | NUM_DIR_BINS: 2 37 | 38 | ANCHOR_GENERATOR_CONFIG: [ 39 | { 40 | 'class_name': 'Car', 41 | 'anchor_sizes': [[3.9, 1.6, 1.56]], 42 | 'anchor_rotations': [0, 1.57], 43 | 'anchor_bottom_heights': [-1.78], 44 | 'align_center': False, 45 | 'feature_map_stride': 8, 46 | 'matched_threshold': 0.6, 47 | 'unmatched_threshold': 0.45 48 | }, 49 | { 50 | 'class_name': 'Pedestrian', 51 | 'anchor_sizes': [[0.8, 0.6, 1.73]], 52 | 'anchor_rotations': [0, 1.57], 53 | 'anchor_bottom_heights': [-0.6], 54 | 'align_center': False, 55 | 'feature_map_stride': 8, 56 | 'matched_threshold': 0.5, 57 | 'unmatched_threshold': 0.35 58 | }, 59 | { 60 | 'class_name': 'Cyclist', 61 | 'anchor_sizes': [[1.76, 0.6, 1.73]], 62 | 'anchor_rotations': [0, 1.57], 63 | 'anchor_bottom_heights': [-0.6], 64 | 'align_center': False, 65 | 'feature_map_stride': 8, 66 | 'matched_threshold': 0.5, 67 | 'unmatched_threshold': 0.35 68 | } 69 | ] 70 | 71 | TARGET_ASSIGNER_CONFIG: 72 | NAME: AxisAlignedTargetAssigner 73 | POS_FRACTION: -1.0 74 | SAMPLE_SIZE: 512 75 | NORM_BY_NUM_EXAMPLES: False 76 | MATCH_HEIGHT: False 77 | BOX_CODER: ResidualCoder 78 | 79 | LOSS_CONFIG: 80 | LOSS_WEIGHTS: { 81 | 'cls_weight': 1.0, 82 | 'loc_weight': 2.0, 83 | 'dir_weight': 0.2, 84 | 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] 85 | } 86 | 87 | POST_PROCESSING: 88 | RECALL_THRESH_LIST: [0.3, 0.5, 0.7] 89 | SCORE_THRESH: 0.1 90 | OUTPUT_RAW_SCORE: False 91 | 92 | EVAL_METRIC: kitti 93 | 94 | NMS_CONFIG: 95 | MULTI_CLASSES_NMS: False 96 | NMS_TYPE: nms_gpu 97 | NMS_THRESH: 0.01 98 | NMS_PRE_MAXSIZE: 4096 99 | 
NMS_POST_MAXSIZE: 500 100 | 101 | 102 | OPTIMIZATION: 103 | BATCH_SIZE_PER_GPU: 4 104 | NUM_EPOCHS: 80 105 | 106 | OPTIMIZER: adam_onecycle 107 | LR: 0.003 108 | WEIGHT_DECAY: 0.01 109 | MOMENTUM: 0.9 110 | 111 | MOMS: [0.95, 0.85] 112 | PCT_START: 0.4 113 | DIV_FACTOR: 10 114 | DECAY_STEP_LIST: [35, 45] 115 | LR_DECAY: 0.1 116 | LR_CLIP: 0.0000001 117 | 118 | LR_WARMUP: False 119 | WARMUP_EPOCH: 1 120 | 121 | GRAD_NORM_CLIP: 10 122 | -------------------------------------------------------------------------------- /examples/openpcdet/cfgs_templates/nuscenes_models/cbgs_voxel0075_voxelnext.yaml: -------------------------------------------------------------------------------- 1 | CLASS_NAMES: ['car','truck', 'construction_vehicle', 'bus', 'trailer', 2 | 'barrier', 'motorcycle', 'bicycle', 'pedestrian', 'traffic_cone'] 3 | 4 | DATA_CONFIG: 5 | _BASE_CONFIG_: {{ pcdet_base_path }}/tools/cfgs/dataset_configs/nuscenes_dataset.yaml 6 | POINT_CLOUD_RANGE: [-54.0, -54.0, -5.0, 54.0, 54.0, 3.0] 7 | INFO_PATH: { 8 | 'train': [nuscenes_infos_10sweeps_train.pkl], 9 | 'test': [nuscenes_infos_10sweeps_val.pkl], 10 | } 11 | DATA_AUGMENTOR: 12 | DISABLE_AUG_LIST: ['placeholder'] 13 | AUG_CONFIG_LIST: 14 | - NAME: gt_sampling 15 | DB_INFO_PATH: 16 | - nuscenes_dbinfos_10sweeps_withvelo.pkl 17 | USE_SHARED_MEMORY: False #True # set it to True to speed up (it costs about 15GB shared memory) 18 | DB_DATA_PATH: 19 | - nuscenes_dbinfos_10sweeps_withvelo_global.pkl.npy 20 | PREPARE: { 21 | filter_by_min_points: [ 22 | 'car:5','truck:5', 'construction_vehicle:5', 'bus:5', 'trailer:5', 23 | 'barrier:5', 'motorcycle:5', 'bicycle:5', 'pedestrian:5', 'traffic_cone:5' 24 | ], 25 | } 26 | 27 | SAMPLE_GROUPS: [ 28 | 'car:2','truck:2', 'construction_vehicle:2', 'bus:2', 'trailer:2', 29 | 'barrier:2', 'motorcycle:2', 'bicycle:2', 'pedestrian:2', 'traffic_cone:2' 30 | ] 31 | 32 | NUM_POINT_FEATURES: 5 33 | DATABASE_WITH_FAKELIDAR: False 34 | REMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0] 35 | LIMIT_WHOLE_SCENE: True 36 | 37 | - NAME: random_world_flip 38 | ALONG_AXIS_LIST: ['x', 'y'] 39 | 40 | - NAME: random_world_rotation 41 | WORLD_ROT_ANGLE: [-0.78539816, 0.78539816] 42 | 43 | - NAME: random_world_scaling 44 | WORLD_SCALE_RANGE: [0.9, 1.1] 45 | 46 | - NAME: random_world_translation 47 | NOISE_TRANSLATE_STD: [0.5, 0.5, 0.5] 48 | 49 | 50 | DATA_PROCESSOR: 51 | - NAME: mask_points_and_boxes_outside_range 52 | REMOVE_OUTSIDE_BOXES: True 53 | 54 | - NAME: shuffle_points 55 | SHUFFLE_ENABLED: { 56 | 'train': True, 57 | 'test': True 58 | } 59 | 60 | - NAME: transform_points_to_voxels 61 | VOXEL_SIZE: [0.075, 0.075, 0.2] 62 | MAX_POINTS_PER_VOXEL: 10 63 | MAX_NUMBER_OF_VOXELS: { 64 | 'train': 120000, 65 | 'test': 160000 66 | } 67 | 68 | 69 | MODEL: 70 | NAME: VoxelNeXt 71 | 72 | VFE: 73 | NAME: MeanVFE 74 | 75 | BACKBONE_3D: 76 | NAME: VoxelResBackBone8xVoxelNeXtTS 77 | 78 | DENSE_HEAD: 79 | NAME: VoxelNeXtHeadTS 80 | CLASS_AGNOSTIC: False 81 | INPUT_FEATURES: 128 82 | 83 | CLASS_NAMES_EACH_HEAD: [ 84 | ['car'], 85 | ['truck', 'construction_vehicle'], 86 | ['bus', 'trailer'], 87 | ['barrier'], 88 | ['motorcycle', 'bicycle'], 89 | ['pedestrian', 'traffic_cone'], 90 | ] 91 | 92 | SHARED_CONV_CHANNEL: 128 93 | KERNEL_SIZE_HEAD: 1 94 | 95 | USE_BIAS_BEFORE_NORM: True 96 | NUM_HM_CONV: 2 97 | SEPARATE_HEAD_CFG: 98 | HEAD_ORDER: ['center', 'center_z', 'dim', 'rot', 'vel'] 99 | HEAD_DICT: { 100 | 'center': {'out_channels': 2, 'num_conv': 2}, 101 | 'center_z': {'out_channels': 1, 'num_conv': 2}, 102 | 'dim': {'out_channels': 3, 'num_conv': 2}, 103 
| 'rot': {'out_channels': 2, 'num_conv': 2}, 104 | 'vel': {'out_channels': 2, 'num_conv': 2}, 105 | } 106 | 107 | TARGET_ASSIGNER_CONFIG: 108 | FEATURE_MAP_STRIDE: 8 109 | NUM_MAX_OBJS: 500 110 | GAUSSIAN_OVERLAP: 0.1 111 | MIN_RADIUS: 2 112 | 113 | LOSS_CONFIG: 114 | LOSS_WEIGHTS: { 115 | 'cls_weight': 1.0, 116 | 'loc_weight': 0.25, 117 | 'code_weights': [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.2, 0.2, 1.0, 1.0] 118 | } 119 | 120 | POST_PROCESSING: 121 | SCORE_THRESH: 0.1 122 | POST_CENTER_LIMIT_RANGE: [-61.2, -61.2, -10.0, 61.2, 61.2, 10.0] 123 | MAX_OBJ_PER_SAMPLE: 500 124 | NMS_CONFIG: 125 | NMS_TYPE: nms_gpu 126 | NMS_THRESH: 0.2 127 | NMS_PRE_MAXSIZE: 1000 128 | NMS_POST_MAXSIZE: 83 129 | 130 | POST_PROCESSING: 131 | RECALL_THRESH_LIST: [0.3, 0.5, 0.7] 132 | 133 | EVAL_METRIC: kitti 134 | 135 | 136 | 137 | OPTIMIZATION: 138 | BATCH_SIZE_PER_GPU: 4 139 | NUM_EPOCHS: 20 140 | 141 | OPTIMIZER: adam_onecycle 142 | LR: 0.001 143 | WEIGHT_DECAY: 0.01 144 | MOMENTUM: 0.9 145 | 146 | MOMS: [0.95, 0.85] 147 | PCT_START: 0.4 148 | DIV_FACTOR: 10 149 | DECAY_STEP_LIST: [35, 45] 150 | LR_DECAY: 0.1 151 | LR_CLIP: 0.0000001 152 | 153 | LR_WARMUP: False 154 | WARMUP_EPOCH: 1 155 | 156 | GRAD_NORM_CLIP: 10 157 | -------------------------------------------------------------------------------- /examples/openpcdet/converted_models/README.md: -------------------------------------------------------------------------------- 1 | Default model conversion base folder for the demo. Please create the relative path to each specific model under this directory. 2 | -------------------------------------------------------------------------------- /examples/openpcdet/pcdet_plugin/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import * 2 | 3 | # Apply Monkey Patch 4 | # Monkey Patch here 5 | import pcdet.models.backbones_3d as pcd_backbones_3d 6 | import pcdet.models.backbones_3d.pfe as pcd_pfe 7 | import pcdet.models.backbones_2d.map_to_bev as pcd_map_to_bev 8 | import pcdet.models.dense_heads as pcd_dense_heads 9 | import pcdet.models.roi_heads as pcd_roi_heads 10 | 11 | import pcdet_plugin.models.backbones_3d as pcd_plugin_backbones_3d 12 | import pcdet_plugin.models.backbones_3d.pfe as pcd_plugin_pfe 13 | import pcdet_plugin.models.backbones_3d.unet as pcd_plugin_spconv_unet 14 | import pcdet_plugin.models.backbones_3d.backbone_voxel_next as pcd_plugin_backbone_voxel_next 15 | import pcdet_plugin.models.backbones_2d.map_to_bev as pcd_plugin_map_to_bev 16 | import pcdet_plugin.models.dense_heads.voxel_next_head as pcd_plugin_voxel_next_head 17 | import pcdet_plugin.models.roi_heads.partA2_head as pcd_plugin_partA2_head 18 | 19 | import pcdet_plugin.models.detectors.detector3d_template as pcd_plugin_detector3d_template 20 | 21 | pcd_backbones_3d.__all__['VoxelBackBone8xTS'] = pcd_plugin_backbones_3d.VoxelBackBone8xTS 22 | pcd_backbones_3d.__all__['UNetV2TS'] = pcd_plugin_spconv_unet.UNetV2TS 23 | pcd_backbones_3d.__all__['VoxelResBackBone8xVoxelNeXtTS'] = pcd_plugin_backbone_voxel_next.VoxelResBackBone8xVoxelNeXtTS 24 | pcd_map_to_bev.__all__['HeightCompressionTS'] = pcd_plugin_map_to_bev.HeightCompressionTS 25 | pcd_pfe.__all__['VoxelSetAbstractionTS'] = pcd_plugin_pfe.VoxelSetAbstractionTS 26 | pcd_dense_heads.__all__['VoxelNeXtHeadTS'] = pcd_plugin_voxel_next_head.VoxelNeXtHeadTS 27 | pcd_roi_heads.__all__['PartA2FCHeadTS'] = pcd_plugin_partA2_head.PartA2FCHeadTS 28 | 29 | # Monkey patch the detector 3d template 30 | import pcdet.models.detectors as 
pcd_detectors 31 | 32 | pcd_detectors.__all__['Detector3DTemplate']._load_state_dict = pcd_plugin_detector3d_template.Detector3DTemplate._load_state_dict 33 | # pcd_detectors.detector3d_template.Detector3DTemplate = pcd_plugin_detector3d_template.Detector3DTemplate -------------------------------------------------------------------------------- /examples/openpcdet/pcdet_plugin/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones_2d import map_to_bev 2 | from .backbones_3d import pfe 3 | from .backbones_3d import backbone3d 4 | from .dense_heads import voxel_next_head 5 | from .detectors import detector3d_template 6 | from .roi_heads import partA2_head -------------------------------------------------------------------------------- /examples/openpcdet/pcdet_plugin/models/backbones_2d/__init__.py: -------------------------------------------------------------------------------- 1 | from .map_to_bev import height_compression -------------------------------------------------------------------------------- /examples/openpcdet/pcdet_plugin/models/backbones_2d/map_to_bev/__init__.py: -------------------------------------------------------------------------------- 1 | from .height_compression import HeightCompressionTS 2 | # from pcdet.models.backbones_2d.map_to_bev.__init__ import __all__ 3 | 4 | # # Try register a pcdet model this way. 5 | # __all__['HeightCompressionTS'] = HeightCompressionTS 6 | -------------------------------------------------------------------------------- /examples/openpcdet/pcdet_plugin/models/backbones_2d/map_to_bev/height_compression.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class HeightCompressionTS(nn.Module): 4 | def __init__(self, model_cfg, **kwargs): 5 | super().__init__() 6 | self.model_cfg = model_cfg 7 | self.num_bev_features = self.model_cfg.NUM_BEV_FEATURES 8 | 9 | def forward(self, batch_dict): 10 | """ 11 | Args: 12 | batch_dict: 13 | encoded_spconv_tensor: sparse tensor 14 | Returns: 15 | batch_dict: 16 | spatial_features: 17 | 18 | """ 19 | encoded_spconv_tensor = batch_dict['encoded_spconv_tensor'] 20 | spatial_features = encoded_spconv_tensor.dense() 21 | 22 | N, D, H, W, C = spatial_features.shape 23 | spatial_features = spatial_features.permute(0, 2, 3, 4, 1).contiguous().reshape(N, H, W, C*D).permute(0, 3, 1, 2).contiguous() 24 | 25 | batch_dict['spatial_features'] = spatial_features 26 | batch_dict['spatial_features_stride'] = batch_dict['encoded_spconv_tensor_stride'] 27 | return batch_dict 28 | 29 | -------------------------------------------------------------------------------- /examples/openpcdet/pcdet_plugin/models/backbones_3d/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbone3d import VoxelBackBone8xTS 2 | from .unet import UNetV2TS 3 | from .pfe import VoxelSetAbstractionTS 4 | from .backbone_voxel_next import VoxelResBackBone8xVoxelNeXtTS 5 | -------------------------------------------------------------------------------- /examples/openpcdet/pcdet_plugin/models/backbones_3d/backbone3d.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchsparse 6 | import torchsparse.nn as spnn 7 | 8 | import os, logging 9 | 10 | from pcdet.models.backbones_3d.__init__ import __all__ 11 | 12 | def post_act_block_ts(in_channels, 
out_channels, kernel_size, indice_key=None, stride=1, padding=0, 13 | conv_type='tsconv', norm_fn=None): 14 | 15 | if conv_type == 'tsconv': 16 | conv = spnn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False) 17 | elif conv_type == 'inverseconv': 18 | conv = spnn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, bias=False, transposed=True) 19 | else: 20 | raise NotImplementedError 21 | 22 | m = nn.Sequential( 23 | conv, 24 | norm_fn(out_channels), 25 | spnn.ReLU(), 26 | ) 27 | 28 | return m 29 | 30 | 31 | class VoxelBackBone8xTS(nn.Module): 32 | def __init__(self, model_cfg, input_channels, grid_size, **kwargs): 33 | super().__init__() 34 | self.model_cfg = model_cfg 35 | norm_fn = partial(spnn.BatchNorm, eps=1e-3, momentum=0.01) 36 | 37 | self.sparse_shape = grid_size[::-1] + [1, 0, 0] 38 | 39 | self.conv_input = nn.Sequential( 40 | spnn.Conv3d(input_channels, 16, 3, padding=1, bias=False), 41 | spnn.BatchNorm(16), 42 | spnn.ReLU(), 43 | ) 44 | 45 | block = post_act_block_ts 46 | 47 | self.conv1 = nn.Sequential( 48 | block(16, 16, 3, norm_fn=norm_fn, padding=1, indice_key='subm1'), 49 | ) 50 | 51 | self.conv2 = nn.Sequential( 52 | # [1600, 1408, 41] <- [800, 704, 21] 53 | block(16, 32, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv2', conv_type='tsconv'), 54 | block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'), 55 | block(32, 32, 3, norm_fn=norm_fn, padding=1, indice_key='subm2'), 56 | ) 57 | 58 | self.conv3 = nn.Sequential( 59 | # [800, 704, 21] <- [400, 352, 11] 60 | block(32, 64, 3, norm_fn=norm_fn, stride=2, padding=1, indice_key='spconv3', conv_type='tsconv'), 61 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'), 62 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm3'), 63 | ) 64 | 65 | self.conv4 = nn.Sequential( 66 | # [400, 352, 11] <- [200, 176, 5] 67 | block(64, 64, 3, norm_fn=norm_fn, stride=2, padding=(0, 1, 1), indice_key='spconv4', conv_type='tsconv'), 68 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'), 69 | block(64, 64, 3, norm_fn=norm_fn, padding=1, indice_key='subm4'), 70 | ) 71 | 72 | last_pad = 0 73 | last_pad = self.model_cfg.get('last_pad', last_pad) 74 | self.conv_out = nn.Sequential( 75 | # [200, 150, 5] -> [200, 150, 2] 76 | spnn.Conv3d(64, 128, (3, 1, 1), stride=(2, 1, 1), padding=last_pad, 77 | bias=False), 78 | norm_fn(128), 79 | spnn.ReLU(), 80 | ) 81 | self.num_point_features = 128 82 | self.backbone_channels = { 83 | 'x_conv1': 16, 84 | 'x_conv2': 32, 85 | 'x_conv3': 64, 86 | 'x_conv4': 64 87 | } 88 | 89 | logging.warning('Built VoxelBackBone8x for TorchSparse') 90 | 91 | 92 | def forward(self, batch_dict): 93 | """ 94 | Args: 95 | batch_dict: 96 | batch_size: int 97 | vfe_features: (num_voxels, C) 98 | voxel_coords: (num_voxels, 4), [batch_idx, z_idx, y_idx, x_idx] 99 | Returns: 100 | batch_dict: 101 | encoded_spconv_tensor: sparse tensor 102 | """ 103 | voxel_features, voxel_coords = batch_dict['voxel_features'], batch_dict['voxel_coords'] 104 | batch_size = batch_dict['batch_size'] 105 | # input_sp_tensor = spconv.SparseConvTensor( 106 | # features=voxel_features, 107 | # indices=voxel_coords.int(), 108 | # spatial_shape=self.sparse_shape, 109 | # batch_size=batch_size 110 | # ) 111 | voxel_coords = voxel_coords.int() 112 | # input_sp_tensor = spconv.SparseConvTensor(voxel_features, coors, self.sparse_shape, batch_size) 113 | spatial_range = (voxel_coords[:, 0].max().item() + 1,) + tuple(self.sparse_shape) 114 | input_sp_tensor 
= torchsparse.SparseTensor(voxel_features, voxel_coords, spatial_range=spatial_range) # dimension match 115 | 116 | 117 | x = self.conv_input(input_sp_tensor) 118 | 119 | x_conv1 = self.conv1(x) 120 | x_conv2 = self.conv2(x_conv1) 121 | x_conv3 = self.conv3(x_conv2) 122 | x_conv4 = self.conv4(x_conv3) 123 | 124 | # for detection head 125 | # [200, 176, 5] -> [200, 176, 2] 126 | out = self.conv_out(x_conv4) 127 | 128 | batch_dict.update({ 129 | 'encoded_spconv_tensor': out, 130 | 'encoded_spconv_tensor_stride': 8 131 | }) 132 | batch_dict.update({ 133 | 'multi_scale_3d_features': { 134 | 'x_conv1': x_conv1, 135 | 'x_conv2': x_conv2, 136 | 'x_conv3': x_conv3, 137 | 'x_conv4': x_conv4, 138 | } 139 | }) 140 | batch_dict.update({ 141 | 'multi_scale_3d_strides': { 142 | 'x_conv1': 1, 143 | 'x_conv2': 2, 144 | 'x_conv3': 4, 145 | 'x_conv4': 8, 146 | } 147 | }) 148 | 149 | return batch_dict 150 | 151 | __all__['VoxelBackBone8xTS'] = VoxelBackBone8xTS -------------------------------------------------------------------------------- /examples/openpcdet/pcdet_plugin/models/dense_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .voxel_next_head import VoxelNeXtHeadTS 2 | -------------------------------------------------------------------------------- /examples/openpcdet/pcdet_plugin/models/detectors/__init__.py: -------------------------------------------------------------------------------- 1 | from .detector3d_template import Detector3DTemplate -------------------------------------------------------------------------------- /examples/openpcdet/pcdet_plugin/models/roi_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .partA2_head import PartA2FCHeadTS -------------------------------------------------------------------------------- /examples/openpcdet/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from jinja2 import Template 3 | import os 4 | 5 | config_template_paths = [ 6 | "./cfgs_templates/kitti_models/second_plugin.yaml", 7 | "./cfgs_templates/kitti_models/PartA2_plugin.yaml", 8 | "./cfgs_templates/kitti_models/pv_rcnn_plugin.yaml", 9 | "./cfgs_templates/kitti_models/voxel_rcnn_car_plugin.yaml", 10 | "./cfgs_templates/nuscenes_models/cbgs_voxel0075_voxelnext_mini.yaml" 11 | ] 12 | 13 | os.makedirs("./cfgs", exist_ok=True) 14 | os.makedirs("./cfgs/kitti_models", exist_ok=True) 15 | os.makedirs("./cfgs/nuscenes_models", exist_ok=True) 16 | 17 | # define PCDET_BASE 18 | if os.environ.get("PCDET_BASE") is None: 19 | # throw some exceptions to ask users to deifne the environment variable 20 | raise ValueError("Please define the environment variable PCDET_BASE") 21 | else: 22 | base = os.environ.get("PCDET_BASE") 23 | print(f"PCDET_BASE: {base}") 24 | for template_path in config_template_paths: 25 | curr_template = Template(open(template_path).read()) 26 | curr_template_rendered = curr_template.render(pcdet_base_path=base) 27 | 28 | file_name = os.path.basename(template_path) 29 | folder_path = os.path.dirname(template_path) 30 | folder_name = os.path.basename(folder_path) 31 | output_file_path = os.path.join("./cfgs", folder_name, file_name) 32 | with open(output_file_path, 'w') as file: 33 | file.write(curr_template_rendered) 34 | 35 | 36 | 37 | 38 | setup( 39 | name='pcdet_plugin', 40 | version='0.1', 41 | packages=find_packages(), 42 | ) 43 | 44 | # Define global initialize torchsparse 
backend 45 | # design init function, let pcdet traverse the folder we modified. Then in pcdet plugin reference folders. 46 | -------------------------------------------------------------------------------- /examples/performance.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | 3 | import numpy as np 4 | import torch 5 | import torch.autograd.profiler as profiler 6 | import torch.cuda 7 | import torch.nn as nn 8 | import torch.optim 9 | 10 | import torchsparse.nn as spnn 11 | from torchsparse import SparseTensor 12 | from torchsparse.utils.collate import sparse_collate_fn 13 | from torchsparse.utils.quantize import sparse_quantize 14 | 15 | 16 | def generate_random_point_cloud(size=100000, voxel_size=0.2): 17 | pc = np.random.randn(size, 4) 18 | pc[:, :3] = pc[:, :3] * 10 19 | labels = np.random.choice(10, size) 20 | coords, feats = pc[:, :3], pc 21 | coords -= np.min(coords, axis=0, keepdims=True) 22 | coords, indices = sparse_quantize(coords, voxel_size, return_index=True) 23 | 24 | coords = torch.tensor(coords, dtype=torch.int) 25 | feats = torch.tensor(feats[indices], dtype=torch.float) 26 | labels = torch.tensor(labels[indices], dtype=torch.long) 27 | 28 | input = SparseTensor(coords=coords, feats=feats) 29 | label = SparseTensor(coords=coords, feats=labels) 30 | 31 | feed_dict = {"input": input, "label": label} 32 | 33 | return feed_dict 34 | 35 | 36 | def generate_batched_random_point_clouds(size=100000, voxel_size=0.2, batch_size=2): 37 | batch = [] 38 | for _ in range(batch_size): 39 | batch.append(generate_random_point_cloud(size, voxel_size)) 40 | return sparse_collate_fn(batch) 41 | 42 | 43 | def dummy_train_3x3(device): 44 | model = nn.Sequential( 45 | spnn.Conv3d(4, 32, kernel_size=3, stride=1), 46 | spnn.Conv3d(32, 64, kernel_size=3, stride=1), 47 | spnn.Conv3d(64, 128, kernel_size=3, stride=1), 48 | spnn.Conv3d(128, 256, kernel_size=3, stride=1), 49 | spnn.Conv3d(256, 128, kernel_size=3, stride=1, transposed=True), 50 | spnn.Conv3d(128, 64, kernel_size=3, stride=1, transposed=True), 51 | spnn.Conv3d(64, 32, kernel_size=3, stride=1, transposed=True), 52 | spnn.Conv3d(32, 10, kernel_size=3, stride=1, transposed=True), 53 | ).to(device) 54 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 55 | criterion = nn.CrossEntropyLoss().to(device) 56 | 57 | print("Starting dummy_train_3x3...") 58 | time = datetime.now() 59 | with profiler.profile(profile_memory=True, use_cuda=True) as prof: 60 | with profiler.record_function("model_inference"): 61 | for _ in range(10): 62 | feed_dict = generate_batched_random_point_clouds() 63 | inputs = feed_dict["input"].to(device) 64 | targets = feed_dict["label"].F.to(device).long() 65 | outputs = model(inputs) 66 | optimizer.zero_grad() 67 | loss = criterion(outputs.F, targets) 68 | loss.backward() 69 | optimizer.step() 70 | # print('[step %d] loss = %f.'%(i, loss.item())) 71 | print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) 72 | prof.export_chrome_trace("trace_dummy_3x3.json") 73 | 74 | time = datetime.now() - time 75 | print("Finished dummy_train_3x3 in ", time) 76 | 77 | 78 | def dummy_train_3x1(device): 79 | model = nn.Sequential( 80 | spnn.Conv3d(4, 32, kernel_size=(3, 1, 3), stride=1), 81 | spnn.Conv3d(32, 64, kernel_size=(1, 3, 3), stride=1), 82 | spnn.Conv3d(64, 128, kernel_size=(3, 1, 3), stride=1), 83 | spnn.Conv3d(128, 256, kernel_size=(1, 3, 3), stride=1), 84 | spnn.Conv3d(256, 128, kernel_size=(3, 1, 3), stride=1, 
transposed=True), 85 | spnn.Conv3d(128, 64, kernel_size=(1, 3, 3), stride=1, transposed=True), 86 | spnn.Conv3d(64, 32, kernel_size=(3, 1, 3), stride=1, transposed=True), 87 | spnn.Conv3d(32, 10, kernel_size=(1, 3, 3), stride=1, transposed=True), 88 | ).to(device) 89 | optimizer = torch.optim.Adam(model.parameters(), lr=0.001) 90 | criterion = nn.CrossEntropyLoss().to(device) 91 | 92 | print("Starting dummy_train_3x1 ...") 93 | time = datetime.now() 94 | with profiler.profile(profile_memory=True, use_cuda=True) as prof: 95 | with profiler.record_function("model_inference"): 96 | for _ in range(10): 97 | feed_dict = generate_batched_random_point_clouds() 98 | inputs = feed_dict["input"].to(device) 99 | targets = feed_dict["label"].F.to(device).long() 100 | outputs = model(inputs) 101 | optimizer.zero_grad() 102 | loss = criterion(outputs.F, targets) 103 | loss.backward() 104 | optimizer.step() 105 | # print('[step %d] loss = %f.'%(i, loss.item())) 106 | print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) 107 | prof.export_chrome_trace("trace_dummy_3x1.json") 108 | 109 | time = datetime.now() - time 110 | print("Finished dummy_train_3x1 in ", time) 111 | 112 | 113 | if __name__ == "__main__": 114 | device = "cuda:0" if torch.cuda.is_available() else "cpu" 115 | 116 | dummy_train_3x1(device) 117 | dummy_train_3x3(device) 118 | -------------------------------------------------------------------------------- /install.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import os 3 | import torch 4 | 5 | __version__ = "2.1.0" 6 | 7 | 8 | def find_maximal_match(support_list: List, target): 9 | if target in support_list: 10 | return target 11 | else: 12 | max_match_version = None 13 | for item in support_list: 14 | if item <= target: 15 | max_match_version = item 16 | if max_match_version == None: 17 | max_match_version = support_list[0] 18 | print( 19 | f"[Warning] CUDA version {target} is too low, may not be well supported by torch_{torch.__version__}." 
20 | ) 21 | return max_match_version 22 | 23 | 24 | torch_cuda_mapping = dict( 25 | [ 26 | ("torch19", ["11.1"]), 27 | ("torch110", ["11.1", "11.3"]), 28 | ("torch111", ["11.3", "11.5"]), 29 | ("torch112", ["11.3", "11.6"]), 30 | ("torch113", ["11.6", "11.7"]), 31 | ("torch20", ["11.7", "11.8"]), 32 | ] 33 | ) 34 | 35 | torch_tag, _ = ("torch" + torch.__version__).rsplit(".", 1) 36 | torch_tag = torch_tag.replace(".", "") 37 | 38 | if torch.cuda.is_available(): 39 | cuda_version = torch.version.cuda 40 | support_cuda_list = torch_cuda_mapping[torch_tag] 41 | cuda_version = find_maximal_match(support_cuda_list, cuda_version) 42 | cuda_tag = "cu" + cuda_version 43 | else: 44 | cuda_tag = "cpu" 45 | cuda_tag = cuda_tag.replace(".", "") 46 | 47 | 48 | os.system( 49 | f"pip install --extra-index-url http://24.199.104.228/simple --trusted-host 24.199.104.228 torchsparse=={__version__}+{torch_tag}{cuda_tag} --force-reinstall" 50 | ) 51 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | backports.cached_property 2 | numpy 3 | tqdm 4 | typing-extensions 5 | wheel 6 | attributedict 7 | rootpath 8 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [yapf] 2 | based_on_style = google 3 | spaces_around_power_operator = true 4 | split_before_arithmetic_operator = true 5 | split_before_logical_operator = true 6 | split_before_bitwise_operator = true 7 | 8 | [isort] 9 | known_first_party = torchsparse 10 | 11 | [pydocstyle] 12 | convention = google 13 | 14 | [flake8] 15 | select = B, C, D, E, F, P, T4, W, B9 16 | ignore = D10, E501, E722, W503 17 | per-file-ignores = 18 | __init__.py: F401, F403 19 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | import torch 5 | import torch.cuda 6 | from setuptools import find_packages, setup 7 | from torch.utils.cpp_extension import ( 8 | CUDA_HOME, 9 | BuildExtension, 10 | CppExtension, 11 | CUDAExtension, 12 | ) 13 | 14 | # from torchsparse import __version__ 15 | 16 | version_file = open("./torchsparse/version.py") 17 | version = version_file.read().split("'")[1] 18 | print("torchsparse version:", version) 19 | 20 | if (torch.cuda.is_available() and CUDA_HOME is not None) or ( 21 | os.getenv("FORCE_CUDA", "0") == "1" 22 | ): 23 | device = "cuda" 24 | pybind_fn = f"pybind_{device}.cu" 25 | else: 26 | device = "cpu" 27 | pybind_fn = f"pybind_{device}.cpp" 28 | 29 | sources = [os.path.join("torchsparse", "backend", pybind_fn)] 30 | for fpath in glob.glob(os.path.join("torchsparse", "backend", "**", "*")): 31 | if (fpath.endswith("_cpu.cpp") and device in ["cpu", "cuda"]) or ( 32 | fpath.endswith("_cuda.cu") and device == "cuda" 33 | ): 34 | sources.append(fpath) 35 | 36 | extension_type = CUDAExtension if device == "cuda" else CppExtension 37 | extra_compile_args = { 38 | "cxx": ["-g", "-O3", "-fopenmp", "-lgomp"], 39 | "nvcc": ["-O3", "-std=c++17"], 40 | } 41 | 42 | setup( 43 | name="torchsparse", 44 | version=version, 45 | packages=find_packages(), 46 | ext_modules=[ 47 | extension_type( 48 | "torchsparse.backend", sources, extra_compile_args=extra_compile_args 49 | ) 50 | ], 51 | url="https://github.com/mit-han-lab/torchsparse", 52 | 
install_requires=[ 53 | "numpy", 54 | "backports.cached_property", 55 | "tqdm", 56 | "typing-extensions", 57 | "wheel", 58 | "rootpath", 59 | "torch", 60 | "torchvision" 61 | ], 62 | dependency_links=[ 63 | 'https://download.pytorch.org/whl/cu118' 64 | ], 65 | cmdclass={"build_ext": BuildExtension}, 66 | zip_safe=False, 67 | ) 68 | -------------------------------------------------------------------------------- /tests/python/__init__.py: -------------------------------------------------------------------------------- 1 | from .test_single_layer_conv import * 2 | from .test_to_dense import * 3 | -------------------------------------------------------------------------------- /tests/python/test_to_dense.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Tuple, Union, Optional, List 2 | 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | 7 | import torchsparse 8 | from torchsparse import nn as spnn 9 | from torchsparse.utils import make_ntuple, to_dense 10 | 11 | from .test_utils import generate_feature_map 12 | 13 | __all__ = ["test_to_dense_forward"] 14 | 15 | 16 | def test_to_dense_forward( 17 | batch_size: int = 1, 18 | shape: Union[int, Tuple[int, ...]] = 3, 19 | num_points: int = 6, 20 | channel: int = 4, 21 | device="cuda:0", 22 | ): 23 | 24 | np.random.seed(0) 25 | torch.manual_seed(0) 26 | 27 | torch_dtype = torch.float16 28 | np_dtype = np.float16 29 | 30 | shape = make_ntuple(shape, ndim=3) 31 | spatial_range = make_ntuple([batch_size, *shape], ndim=4) 32 | 33 | if num_points > np.prod(shape): 34 | print("Warning: num_points exceeds coords range!") 35 | print(" reduce num_points to %d!" % np.prod(shape)) 36 | num_points = np.prod(shape) 37 | num_points = [num_points] * batch_size 38 | 39 | sparse_dict = generate_feature_map(shape, num_points, channel, dtype=np_dtype) 40 | 41 | feats = np.ascontiguousarray(sparse_dict["feats"]) 42 | coords = np.ascontiguousarray(sparse_dict["coords"][:, [3, 0, 1, 2]]) # batch first 43 | ref_dense_feats = sparse_dict["dense_feats"].transpose(0, 2, 3, 4, 1) 44 | 45 | coords_t = torch.from_numpy(coords).int().to(device) 46 | feats_t = torch.from_numpy(feats).to(torch_dtype).to(device) 47 | 48 | output = to_dense(feats_t, coords_t, spatial_range).cpu().numpy() 49 | 50 | # print(output) 51 | # print(ref_dense_feats) 52 | 53 | max_adiff = np.max(np.abs(output - ref_dense_feats)) 54 | return max_adiff 55 | 56 | 57 | if __name__ == "__main__": 58 | max_adiff = test_to_dense_forward() 59 | print(max_adiff) 60 | -------------------------------------------------------------------------------- /tests/python/test_utils.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Tuple, Union, Optional, List 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def generate_feature_map( 8 | shape, 9 | num_points, 10 | num_channels, 11 | data_range=(-1, 1), 12 | with_dense=True, 13 | dtype=np.float16, 14 | ): 15 | dense_shape = shape 16 | ndim = len(dense_shape) 17 | num_points = np.array(num_points) 18 | batch_size = len(num_points) 19 | batch_indices = [] 20 | coords_total = np.stack(np.meshgrid(*[np.arange(0, s) for s in shape]), axis=-1) 21 | coords_total = coords_total.reshape(-1, ndim) 22 | 23 | for i in range(batch_size): 24 | np.random.shuffle(coords_total) 25 | inds_total = coords_total[: num_points[i]] 26 | inds_total = np.pad( 27 | inds_total, 28 | ((0, 0), (0, 1)), # batch last 29 | mode="constant", 30 | 
constant_values=i, 31 | ) 32 | batch_indices.append(inds_total) 33 | 34 | features = np.random.uniform( 35 | data_range[0], data_range[1], size=[num_points.sum(), num_channels] 36 | ).astype(dtype) 37 | 38 | sparse_dict = dict( 39 | [ 40 | ("feats", features), 41 | ] 42 | ) 43 | 44 | if with_dense: 45 | dense_feats = np.zeros([batch_size, num_channels, *dense_shape], dtype=dtype) 46 | start = 0 47 | for i, inds in enumerate(batch_indices): 48 | for j, ind in enumerate(inds): 49 | dense_slice = (i, slice(None), *ind[:-1]) 50 | dense_feats[dense_slice] = features[start + j] 51 | start += len(inds) 52 | sparse_dict["dense_feats"] = dense_feats 53 | batch_indices = np.concatenate(batch_indices, axis=0) 54 | sparse_dict["coords"] = batch_indices.astype(np.int32) 55 | 56 | return sparse_dict 57 | 58 | 59 | def sparse_tensor_to_dense( 60 | ts_tensor, 61 | shape, 62 | num_channels=None, 63 | dtype=np.float16, 64 | ): 65 | ts_pt = ts_tensor.F[: ts_tensor.C.shape[0]] 66 | ts_coords = ts_tensor.C 67 | 68 | np_ts_pt = np.array(ts_pt.detach().cpu()) 69 | np_ts_coords = np.array(ts_coords.detach().cpu()) 70 | 71 | if num_channels is None: 72 | num_channels = np_ts_pt.shape[1] 73 | 74 | np_ts_pt = np_ts_pt[:, 0:num_channels] 75 | 76 | batch_size = np.max(np_ts_coords[:, 0]) - np.min(np_ts_coords[:, 0]) + 1 77 | 78 | dense_feats = np.zeros([batch_size, num_channels, *shape], dtype=dtype) 79 | 80 | for j, coord in enumerate(np_ts_coords): 81 | dense_slice = (coord[0], slice(None), *coord[1:]) 82 | dense_feats[dense_slice] = np_ts_pt[j] 83 | 84 | return dense_feats 85 | 86 | 87 | def dense_to_subm(feats, coords): 88 | # batch_size = feats.shape[0] 89 | # num_channels = feats.shape[1] 90 | 91 | mask = np.zeros(feats.shape, dtype=np.int32) 92 | 93 | for j, coord in enumerate(coords): 94 | dense_slice = (coord[0], slice(None), *coord[1:]) 95 | mask[dense_slice] = 1 96 | 97 | subm_feats = feats * mask 98 | 99 | return subm_feats 100 | 101 | 102 | def dense_pad(dense_feats_t, kernel_size): 103 | dense_feats_t = torch.nn.functional.pad( 104 | dense_feats_t, 105 | (0, kernel_size - 1, 0, kernel_size - 1, 0, kernel_size - 1), 106 | "constant", 107 | 0, 108 | ) 109 | return dense_feats_t 110 | -------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from torchsparse.nn import functional as F 3 | from python import ( 4 | test_single_layer_convolution_forward, 5 | test_to_dense_forward, 6 | ) 7 | 8 | 9 | class SparseConvTestCase(unittest.TestCase): 10 | def test_single_layer(self): 11 | kernel_sizes = [2, 3, 5] 12 | strides = [1, 2, 3] 13 | acc_adiff = 0.0 14 | acc_rdiff = 0.0 15 | count = 0 16 | 17 | # hashmap mode by default 18 | for kernel_size in kernel_sizes: 19 | for stride in strides: 20 | mean_adiff, max_rdiff = test_single_layer_convolution_forward( 21 | kernel_size=kernel_size, stride=stride 22 | ) 23 | acc_adiff += mean_adiff 24 | acc_rdiff += max_rdiff 25 | count += 1 26 | 27 | # switch to hashmap_on_the_fly 28 | config = F.conv_config.get_default_conv_config() 29 | config.kmap_mode = "hashmap_on_the_fly" 30 | F.conv_config.set_global_conv_config(config) 31 | for kernel_size in kernel_sizes: 32 | for stride in strides: 33 | mean_adiff, max_rdiff = test_single_layer_convolution_forward( 34 | kernel_size=kernel_size, stride=stride 35 | ) 36 | acc_adiff += mean_adiff 37 | acc_rdiff += max_rdiff 38 | count += 1 39 | 40 | self.assertLessEqual(acc_adiff / count, 1e-4) 
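        # (Added note, not part of the original test:) acc_adiff / acc_rdiff are
        # accumulated over every kernel_size/stride combination and over both
        # kmap modes (the default "hashmap" mode and "hashmap_on_the_fly"), so
        # these bounds are aggregate tolerances averaged over all cases rather
        # than per-case limits.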
41 | self.assertLessEqual(acc_rdiff / count, 1e-2) 42 | 43 | 44 | class ToDenseTestCase(unittest.TestCase): 45 | def test_to_dense(self): 46 | max_adiff = test_to_dense_forward() 47 | self.assertLessEqual(max_adiff, 1e-5) 48 | 49 | 50 | if __name__ == "__main__": 51 | unittest.main() 52 | -------------------------------------------------------------------------------- /torchsparse/__init__.py: -------------------------------------------------------------------------------- 1 | import torchsparse.backends as backends 2 | 3 | from .operators import * 4 | from .tensor import * 5 | from .utils.tune import tune 6 | from .version import __version__ 7 | 8 | backends.init() 9 | -------------------------------------------------------------------------------- /torchsparse/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import * 2 | from .unet import * 3 | -------------------------------------------------------------------------------- /torchsparse/backbones/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .blocks import * 2 | -------------------------------------------------------------------------------- /torchsparse/backbones/modules/blocks.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | import numpy as np 4 | from torch import nn 5 | 6 | from torchsparse import SparseTensor 7 | from torchsparse import nn as spnn 8 | 9 | __all__ = ["SparseConvBlock", "SparseConvTransposeBlock", "SparseResBlock"] 10 | 11 | 12 | class SparseConvBlock(nn.Sequential): 13 | def __init__( 14 | self, 15 | in_channels: int, 16 | out_channels: int, 17 | kernel_size: Union[int, List[int], Tuple[int, ...]], 18 | stride: Union[int, List[int], Tuple[int, ...]] = 1, 19 | dilation: int = 1, 20 | ) -> None: 21 | super().__init__( 22 | spnn.Conv3d( 23 | in_channels, out_channels, kernel_size, stride=stride, dilation=dilation 24 | ), 25 | spnn.BatchNorm(out_channels), 26 | spnn.ReLU(True), 27 | ) 28 | 29 | 30 | class SparseConvTransposeBlock(nn.Sequential): 31 | def __init__( 32 | self, 33 | in_channels: int, 34 | out_channels: int, 35 | kernel_size: Union[int, List[int], Tuple[int, ...]], 36 | stride: Union[int, List[int], Tuple[int, ...]] = 1, 37 | dilation: int = 1, 38 | ) -> None: 39 | super().__init__( 40 | spnn.Conv3d( 41 | in_channels, 42 | out_channels, 43 | kernel_size, 44 | stride=stride, 45 | dilation=dilation, 46 | transposed=True, 47 | ), 48 | spnn.BatchNorm(out_channels), 49 | spnn.ReLU(True), 50 | ) 51 | 52 | 53 | class SparseResBlock(nn.Module): 54 | def __init__( 55 | self, 56 | in_channels: int, 57 | out_channels: int, 58 | kernel_size: Union[int, List[int], Tuple[int, ...]], 59 | stride: Union[int, List[int], Tuple[int, ...]] = 1, 60 | dilation: int = 1, 61 | ) -> None: 62 | super().__init__() 63 | self.main = nn.Sequential( 64 | spnn.Conv3d( 65 | in_channels, out_channels, kernel_size, dilation=dilation, stride=stride 66 | ), 67 | spnn.BatchNorm(out_channels), 68 | spnn.ReLU(True), 69 | spnn.Conv3d(out_channels, out_channels, kernel_size, dilation=dilation), 70 | spnn.BatchNorm(out_channels), 71 | ) 72 | 73 | if in_channels != out_channels or np.prod(stride) != 1: 74 | self.shortcut = nn.Sequential( 75 | spnn.Conv3d(in_channels, out_channels, 1, stride=stride), 76 | spnn.BatchNorm(out_channels), 77 | ) 78 | else: 79 | self.shortcut = nn.Identity() 80 | 81 | self.relu = spnn.ReLU(True) 82 | 83 | 
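    # (Added note, not in the original file:) a minimal usage sketch, assuming a
    # working torchsparse install and a SparseTensor `x` with 32-channel features:
    #
    #   block = SparseResBlock(32, 64, kernel_size=3, stride=2)
    #   y = block(x)  # forward computes relu(main(x) + shortcut(x)); y.F has 64 channels
    #
    # The shortcut branch is a strided 1x1x1 conv + BatchNorm whenever the channel
    # count or the stride changes, so both branches stay summable element-wise.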
def forward(self, x: SparseTensor) -> SparseTensor: 84 | x = self.relu(self.main(x) + self.shortcut(x)) 85 | return x 86 | -------------------------------------------------------------------------------- /torchsparse/backbones/resnet.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple, Union 2 | 3 | from torch import nn 4 | 5 | from torchsparse import SparseTensor 6 | 7 | from .modules import SparseConvBlock, SparseResBlock 8 | 9 | __all__ = ["SparseResNet21D"] 10 | 11 | 12 | class SparseResNet(nn.ModuleList): 13 | def __init__( 14 | self, 15 | blocks: List[ 16 | Tuple[int, int, Union[int, Tuple[int, ...]], Union[int, Tuple[int, ...]]] 17 | ], 18 | *, 19 | in_channels: int = 4, 20 | width_multiplier: float = 1.0, 21 | ) -> None: 22 | super().__init__() 23 | self.blocks = blocks 24 | self.in_channels = in_channels 25 | self.width_multiplier = width_multiplier 26 | 27 | for num_blocks, out_channels, kernel_size, stride in blocks: 28 | out_channels = int(out_channels * width_multiplier) 29 | blocks = [] 30 | for index in range(num_blocks): 31 | if index == 0: 32 | blocks.append( 33 | SparseConvBlock( 34 | in_channels, 35 | out_channels, 36 | kernel_size, 37 | stride=stride, 38 | ) 39 | ) 40 | else: 41 | blocks.append( 42 | SparseResBlock( 43 | in_channels, 44 | out_channels, 45 | kernel_size, 46 | ) 47 | ) 48 | in_channels = out_channels 49 | self.append(nn.Sequential(*blocks)) 50 | 51 | def forward(self, x: SparseTensor) -> List[SparseTensor]: 52 | outputs = [] 53 | for module in self: 54 | x = module(x) 55 | outputs.append(x) 56 | return outputs 57 | 58 | 59 | class SparseResNet21D(SparseResNet): 60 | def __init__(self, **kwargs) -> None: 61 | super().__init__( 62 | blocks=[ 63 | (3, 16, 3, 1), 64 | (3, 32, 3, 2), 65 | (3, 64, 3, 2), 66 | (3, 128, 3, 2), 67 | (1, 128, (1, 3, 1), (1, 2, 1)), 68 | ], 69 | **kwargs, 70 | ) 71 | -------------------------------------------------------------------------------- /torchsparse/backbones/unet.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from torch import nn 4 | 5 | import torchsparse 6 | from torchsparse import SparseTensor 7 | from torchsparse import nn as spnn 8 | 9 | from .modules import SparseConvBlock, SparseConvTransposeBlock, SparseResBlock 10 | 11 | __all__ = ["SparseResUNet42"] 12 | 13 | 14 | class SparseResUNet(nn.Module): 15 | def __init__( 16 | self, 17 | stem_channels: int, 18 | encoder_channels: List[int], 19 | decoder_channels: List[int], 20 | *, 21 | in_channels: int = 4, 22 | width_multiplier: float = 1.0, 23 | ) -> None: 24 | super().__init__() 25 | self.stem_channels = stem_channels 26 | self.encoder_channels = encoder_channels 27 | self.decoder_channels = decoder_channels 28 | self.in_channels = in_channels 29 | self.width_multiplier = width_multiplier 30 | 31 | num_channels = [stem_channels] + encoder_channels + decoder_channels 32 | num_channels = [int(width_multiplier * nc) for nc in num_channels] 33 | 34 | self.stem = nn.Sequential( 35 | spnn.Conv3d(in_channels, num_channels[0], 3), 36 | spnn.BatchNorm(num_channels[0]), 37 | spnn.ReLU(True), 38 | spnn.Conv3d(num_channels[0], num_channels[0], 3), 39 | spnn.BatchNorm(num_channels[0]), 40 | spnn.ReLU(True), 41 | ) 42 | 43 | # TODO(Zhijian): the current implementation of encoder and decoder 44 | # is hard-coded for 4 encoder stages and 4 decoder stages. We should 45 | # work on a more generic implementation in the future. 
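        # (Added note, not in the original file:) for SparseResUNet42 with the
        # default width_multiplier, num_channels works out to
        #   [32, 32, 64, 128, 256, 256, 128, 96, 96]
        # Encoder stage k maps num_channels[k] -> num_channels[k + 1]; decoder
        # stage k upsamples num_channels[k + 4] -> num_channels[k + 5] and fuses
        # it with the skip connection of width num_channels[3 - k].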
46 | 47 | self.encoders = nn.ModuleList() 48 | for k in range(4): 49 | self.encoders.append( 50 | nn.Sequential( 51 | SparseConvBlock( 52 | num_channels[k], 53 | num_channels[k], 54 | 2, 55 | stride=2, 56 | ), 57 | SparseResBlock(num_channels[k], num_channels[k + 1], 3), 58 | SparseResBlock(num_channels[k + 1], num_channels[k + 1], 3), 59 | ) 60 | ) 61 | 62 | self.decoders = nn.ModuleList() 63 | for k in range(4): 64 | self.decoders.append( 65 | nn.ModuleDict( 66 | { 67 | "upsample": SparseConvTransposeBlock( 68 | num_channels[k + 4], 69 | num_channels[k + 5], 70 | 2, 71 | stride=2, 72 | ), 73 | "fuse": nn.Sequential( 74 | SparseResBlock( 75 | num_channels[k + 5] + num_channels[3 - k], 76 | num_channels[k + 5], 77 | 3, 78 | ), 79 | SparseResBlock( 80 | num_channels[k + 5], 81 | num_channels[k + 5], 82 | 3, 83 | ), 84 | ), 85 | } 86 | ) 87 | ) 88 | 89 | def _unet_forward( 90 | self, 91 | x: SparseTensor, 92 | encoders: nn.ModuleList, 93 | decoders: nn.ModuleList, 94 | ) -> List[SparseTensor]: 95 | if not encoders and not decoders: 96 | return [x] 97 | 98 | # downsample 99 | xd = encoders[0](x) 100 | 101 | # inner recursion 102 | outputs = self._unet_forward(xd, encoders[1:], decoders[:-1]) 103 | yd = outputs[-1] 104 | 105 | # upsample and fuse 106 | u = decoders[-1]["upsample"](yd) 107 | y = decoders[-1]["fuse"](torchsparse.cat([u, x])) 108 | 109 | return [x] + outputs + [y] 110 | 111 | def forward(self, x: SparseTensor) -> List[SparseTensor]: 112 | return self._unet_forward(self.stem(x), self.encoders, self.decoders) 113 | 114 | 115 | class SparseResUNet42(SparseResUNet): 116 | def __init__(self, **kwargs) -> None: 117 | super().__init__( 118 | stem_channels=32, 119 | encoder_channels=[32, 64, 128, 256], 120 | decoder_channels=[256, 128, 96, 96], 121 | **kwargs, 122 | ) 123 | -------------------------------------------------------------------------------- /torchsparse/backend/convolution/convolution_backward_wgrad_implicit_gemm_cuda.h: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | at::Tensor conv_backward_wgrad_implicit_gemm_cuda( 4 | torch::Tensor _in_feats, torch::Tensor _kernel, 5 | torch::Tensor _out_in_map, const int split_k_iters, 6 | bool allow_tf32, bool allow_fp16); 7 | -------------------------------------------------------------------------------- /torchsparse/backend/convolution/convolution_backward_wgrad_implicit_gemm_sorted_cuda.h: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | at::Tensor conv_backward_wgrad_implicit_gemm_sorted_cuda( 4 | torch::Tensor _in_feats, torch::Tensor _kernel, 5 | torch::Tensor _out_in_map,torch::Tensor _reduced_mask, 6 | torch::Tensor _reorder_loc, const int split_k_iters, 7 | bool allow_tf32, bool allow_fp16); 8 | -------------------------------------------------------------------------------- /torchsparse/backend/convolution/convolution_forward_fetch_on_demand_cuda.h: -------------------------------------------------------------------------------- 1 | /* 2 | Please consider citing the following paper when using the code: 3 | 4 | @inproceedings{hong2023pcengine, 5 | title={{Exploiting Hardware Utilization and Adaptive Dataflow for Efficient Sparse Convolution in 3D Point Clouds}}, 6 | author={Hong, Ke and Yu, Zhongming and Dai, Guohao and Yang, Xinhao and Lian, Yaoxiu and Liu, Zehao and Xu, Ningyi and Wang, Yu}, 7 | booktitle={Sixth Conference on Machine Learning and Systems (MLSys)}, 8 | year={2023} 
9 | } 10 | */ 11 | 12 | #pragma once 13 | 14 | #include <torch/torch.h> 15 | 16 | at::Tensor conv_forward_fetch_on_demand_cuda( 17 | at::Tensor in_feat, at::Tensor kernel, 18 | at::Tensor neighbor_map, const int sum_nnz, 19 | at::Tensor neighbor_address, at::Tensor q_neighbor_address, 20 | const int output_size, const int qsum_nnz, const bool transpose, 21 | const bool allow_tf32, const bool allow_fp16); 22 | 23 | at::Tensor conv_forward_fetch_on_demand_no_fusion_cuda( 24 | at::Tensor in_feat, at::Tensor kernel, 25 | at::Tensor neighbor_map, at::Tensor neighbor_offset, 26 | const int sum_nnz, const int output_size, const bool transpose, 27 | const bool allow_tf32, const bool allow_fp16); 28 | -------------------------------------------------------------------------------- /torchsparse/backend/convolution/convolution_forward_implicit_gemm_cuda.h: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | at::Tensor conv_forward_implicit_gemm_cuda(torch::Tensor _in_feats, torch::Tensor _kernel, 4 | torch::Tensor _out_in_map, int num_out_feats, int num_out_channels, 5 | bool allow_tf32, bool allow_fp16); -------------------------------------------------------------------------------- /torchsparse/backend/convolution/convolution_forward_implicit_gemm_sorted_cuda.h: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | 3 | at::Tensor conv_forward_implicit_gemm_sorted_cuda( 4 | torch::Tensor _in_feats, torch::Tensor _kernel, 5 | torch::Tensor _out_in_map,torch::Tensor _reduced_mask, 6 | torch::Tensor _reorder_loc, 7 | int num_out_feats, int num_out_channels, 8 | bool allow_tf32, bool allow_fp16); 9 | -------------------------------------------------------------------------------- /torchsparse/backend/convolution/convolution_gather_scatter_cpu.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <torch/torch.h> 4 | 5 | void conv_forward_gather_scatter_cpu(at::Tensor in_feat, at::Tensor out_feat, 6 | at::Tensor kernel, at::Tensor neighbor_map, 7 | at::Tensor neighbor_offset, const bool transpose); 8 | 9 | void conv_backward_gather_scatter_cpu(at::Tensor in_feat, at::Tensor grad_in_feat, 10 | at::Tensor grad_out_feat, at::Tensor kernel, 11 | at::Tensor grad_kernel, at::Tensor neighbor_map, 12 | at::Tensor neighbor_offset, const bool transpose); 13 | -------------------------------------------------------------------------------- /torchsparse/backend/convolution/convolution_gather_scatter_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <torch/torch.h> 4 | 5 | at::Tensor conv_forward_gather_scatter_cuda( 6 | at::Tensor in_feat, at::Tensor kernel, at::Tensor neighbor_map, 7 | at::Tensor neighbor_offset, at::Tensor input_mask, at::Tensor output_mask, 8 | const int output_size, const float epsilon, const int mm_thresh, 9 | const int conv_mode, const bool transpose, at::Tensor buffer); 10 | 11 | at::Tensor conv_forward_gather_scatter_cuda_latest( 12 | at::Tensor in_feat, at::Tensor kernel, at::Tensor neighbor_map, 13 | at::Tensor neighbor_offset, at::Tensor input_mask, at::Tensor output_mask, 14 | const int output_size, const float epsilon, const int mm_thresh, 15 | const int conv_mode, const bool transpose, at::Tensor buffer); 16 | 17 | at::Tensor conv_forward_gather_scatter_cuda_fallback( 18 | at::Tensor in_feat, at::Tensor kernel, 
at::Tensor neighbor_map, 19 | const int output_size, const int conv_mode, at::Tensor neighbor_offset, 20 | const bool transpose); 21 | 22 | void conv_backward_gather_scatter_cuda(at::Tensor in_feat, at::Tensor grad_in_feat, 23 | at::Tensor grad_out_feat, at::Tensor kernel, 24 | at::Tensor grad_kernel, at::Tensor neighbor_map, 25 | at::Tensor neighbor_offset, 26 | const bool transpose); 27 | -------------------------------------------------------------------------------- /torchsparse/backend/devoxelize/devoxelize_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "devoxelize_cpu.h" 2 | 3 | #include <torch/torch.h> 4 | 5 | #include <vector> 6 | 7 | // make sure indices is int type 8 | // feat: (b,c,s) indices: (N, 3) batch_index: (N, ) -> out: (N, c) 9 | at::Tensor devoxelize_forward_cpu(const at::Tensor feat, 10 | const at::Tensor indices, 11 | const at::Tensor weight) { 12 | int c = feat.size(1); 13 | int N = indices.size(0); 14 | 15 | at::Tensor out = torch::zeros( 16 | {N, c}, at::device(feat.device()).dtype(at::ScalarType::Float)); 17 | #pragma omp parallel for 18 | for (int i = 0; i < N; i++) { 19 | int *indices_ = indices.data_ptr<int>() + i * 8; 20 | float *weight_ = weight.data_ptr<float>() + i * 8; 21 | for (int j = 0; j < c; j++) { 22 | float *feat_ = feat.data_ptr<float>() + j; 23 | float cur_feat; 24 | for (int k = 0; k < 8; k++) { 25 | cur_feat = (indices_[k] >= 0) ? feat_[indices_[k] * c] : 0; 26 | *(out.data_ptr<float>() + i * c + j) += weight_[k] * cur_feat; 27 | } 28 | } 29 | } 30 | return out; 31 | } 32 | 33 | // top_grad: (N, c), indices: (N, 3), batch_index: (N, ) -> bottom_grad: 34 | // (b,c,s), s=r^3 35 | at::Tensor devoxelize_backward_cpu(const at::Tensor top_grad, 36 | const at::Tensor indices, 37 | const at::Tensor weight, int n) { 38 | int c = top_grad.size(1); 39 | int N = top_grad.size(0); 40 | at::Tensor bottom_grad = torch::zeros( 41 | {n, c}, at::device(top_grad.device()).dtype(at::ScalarType::Float)); 42 | 43 | for (int i = 0; i < N; i++) { 44 | int *indices_ = indices.data_ptr<int>() + i * 8; 45 | float *weight_ = weight.data_ptr<float>() + i * 8; 46 | #pragma omp parallel for 47 | for (int j = 0; j < c; j++) { 48 | float *top_grad_ = top_grad.data_ptr<float>() + j; 49 | float cur_top_grad; 50 | for (int k = 0; k < 8; k++) { 51 | cur_top_grad = (indices_[k] >= 0) ? 
top_grad_[indices_[k] * c] : 0; 52 | *(bottom_grad.data_ptr<float>() + indices_[k] * c + j) += 53 | weight_[k] * cur_top_grad; 54 | } 55 | } 56 | } 57 | 58 | return bottom_grad; 59 | } 60 | -------------------------------------------------------------------------------- /torchsparse/backend/devoxelize/devoxelize_cpu.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <torch/torch.h> 4 | 5 | at::Tensor devoxelize_forward_cpu(const at::Tensor feat, 6 | const at::Tensor indices, 7 | const at::Tensor weight); 8 | 9 | at::Tensor devoxelize_backward_cpu(const at::Tensor top_grad, 10 | const at::Tensor indices, 11 | const at::Tensor weight, int n); 12 | -------------------------------------------------------------------------------- /torchsparse/backend/devoxelize/devoxelize_cuda.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <stdlib.h> 3 | #include <thrust/device_vector.h> 4 | #include <torch/extension.h> 5 | 6 | #include <THC/THCAtomics.cuh> 7 | 8 | // input features (n, c), indices (N, 8), weight (N, 8) -> output features (N, 9 | // c) 10 | template <typename scalar_t> 11 | __global__ void devoxelize_forward_kernel(int N, int c, 12 | const int *__restrict__ indices, 13 | const scalar_t *__restrict__ weight, 14 | const scalar_t *__restrict__ feat, 15 | scalar_t *__restrict__ out) { 16 | int index = blockIdx.x * blockDim.x + threadIdx.x; 17 | int i = index / c; 18 | int j = index % c; 19 | 20 | if (i < N) { 21 | const int *indices_ = indices + 8 * i; 22 | const scalar_t *weight_ = weight + 8 * i; 23 | const scalar_t *feat_ = feat + j; 24 | 25 | scalar_t cur_feat; 26 | for (int k = 0; k < 8; k++) { 27 | cur_feat = 0; 28 | if (indices_[k] >= 0) cur_feat = feat_[indices_[k] * c]; 29 | 30 | out[i * c + j] += weight_[k] * cur_feat; 31 | } 32 | } 33 | } 34 | 35 | // input weight (N, 8), indices (N, 8), top_grad (N, c) -> bottom grad (n, c) 36 | template <typename scalar_t> 37 | __global__ void devoxelize_backward_kernel( 38 | int N, int n, int c, const int *__restrict__ indices, 39 | const scalar_t *__restrict__ weight, const scalar_t *__restrict__ top_grad, 40 | scalar_t *__restrict__ bottom_grad) { 41 | int index = blockIdx.x * blockDim.x + threadIdx.x; 42 | int i = index / c; 43 | int j = index % c; 44 | 45 | if (i < N) { 46 | const int *indices_ = indices + 8 * i; 47 | const scalar_t *weight_ = weight + 8 * i; 48 | 49 | scalar_t cur_top_grad = top_grad[i * c + j]; 50 | 51 | #pragma unroll 52 | for (int k = 0; k < 8; k++) { 53 | if (indices_[k] >= 0) 54 | atomicAdd(&bottom_grad[indices_[k] * c + j], weight_[k] * cur_top_grad); 55 | } 56 | } 57 | } 58 | 59 | // make sure indices is int type 60 | // feat: (b,c,s) indices: (N, 3) batch_index: (N, ) -> out: (N, c) 61 | at::Tensor devoxelize_forward_cuda(const at::Tensor feat, 62 | const at::Tensor indices, 63 | const at::Tensor weight) { 64 | int c = feat.size(1); 65 | int N = indices.size(0); 66 | 67 | at::Tensor out = 68 | torch::zeros({N, c}, at::device(feat.device()).dtype(feat.dtype())); 69 | 70 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 71 | feat.scalar_type(), "devoxelize_forward_cuda", ([&] { 72 | devoxelize_forward_kernel<scalar_t><<<N, c>>>( 73 | N, c, indices.data_ptr<int>(), weight.data_ptr<scalar_t>(), 74 | feat.data_ptr<scalar_t>(), out.data_ptr<scalar_t>()); 75 | })); 76 | 77 | return out; 78 | } 79 | 80 | // top_grad: (N, c), indices: (N, 3), batch_index: (N, ) -> bottom_grad: 81 | // (b,c,s), s=r^3 82 
| at::Tensor devoxelize_backward_cuda(const at::Tensor top_grad, 83 | const at::Tensor indices, 84 | const at::Tensor weight, int n) { 85 | int c = top_grad.size(1); 86 | int N = top_grad.size(0); 87 | at::Tensor bottom_grad = torch::zeros( 88 | {n, c}, at::device(top_grad.device()).dtype(top_grad.dtype())); 89 | 90 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 91 | top_grad.scalar_type(), "devoxelize_backward_cuda", ([&] { 92 | devoxelize_backward_kernel<scalar_t><<<N, c>>>( 93 | N, n, c, indices.data_ptr<int>(), weight.data_ptr<scalar_t>(), 94 | top_grad.data_ptr<scalar_t>(), bottom_grad.data_ptr<scalar_t>()); 95 | })); 96 | 97 | return bottom_grad; 98 | } 99 | -------------------------------------------------------------------------------- /torchsparse/backend/devoxelize/devoxelize_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <torch/torch.h> 4 | 5 | at::Tensor devoxelize_forward_cuda(const at::Tensor feat, 6 | const at::Tensor indices, 7 | const at::Tensor weight); 8 | 9 | at::Tensor devoxelize_backward_cuda(const at::Tensor top_grad, 10 | const at::Tensor indices, 11 | const at::Tensor weight, int n); 12 | -------------------------------------------------------------------------------- /torchsparse/backend/hash/hash_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "hash_cpu.h" 2 | 3 | #include <torch/torch.h> 4 | 5 | #include <vector> 6 | 7 | void cpu_hash_wrapper(int N, const int *data, int64_t *out) { 8 | #pragma omp parallel for 9 | for (int i = 0; i < N; i++) { 10 | uint64_t hash = 14695981039346656037UL; 11 | for (int j = 0; j < 4; j++) { 12 | hash ^= (unsigned int)data[4 * i + j]; 13 | hash *= 1099511628211UL; 14 | } 15 | hash = (hash >> 60) ^ (hash & 0xFFFFFFFFFFFFFFF); 16 | out[i] = hash; 17 | } 18 | } 19 | 20 | void cpu_kernel_hash_wrapper(int N, int K, const int *data, 21 | const int *kernel_offset, int64_t *out) { 22 | for (int k = 0; k < K; k++) { 23 | #pragma omp parallel for 24 | for (int i = 0; i < N; i++) { 25 | int cur_coord[4]; 26 | for (int j = 0; j < 3; j++) { 27 | cur_coord[j] = data[i * 4 + j] + kernel_offset[k * 3 + j]; 28 | } 29 | cur_coord[3] = data[i * 4 + 3]; 30 | uint64_t hash = 14695981039346656037UL; 31 | for (int j = 0; j < 4; j++) { 32 | hash ^= (unsigned int)cur_coord[j]; 33 | hash *= 1099511628211UL; 34 | } 35 | hash = (hash >> 60) ^ (hash & 0xFFFFFFFFFFFFFFF); 36 | out[k * N + i] = hash; 37 | } 38 | } 39 | } 40 | 41 | at::Tensor hash_cpu(const at::Tensor idx) { 42 | int N = idx.size(0); 43 | at::Tensor out = 44 | torch::zeros({N}, at::device(idx.device()).dtype(at::ScalarType::Long)); 45 | cpu_hash_wrapper(N, idx.data_ptr<int>(), out.data_ptr<int64_t>()); 46 | return out; 47 | } 48 | 49 | at::Tensor kernel_hash_cpu(const at::Tensor idx, 50 | const at::Tensor kernel_offset) { 51 | int N = idx.size(0); 52 | int K = kernel_offset.size(0); 53 | at::Tensor out = torch::zeros( 54 | {K, N}, at::device(idx.device()).dtype(at::ScalarType::Long)); 55 | cpu_kernel_hash_wrapper(N, K, idx.data_ptr<int>(), 56 | kernel_offset.data_ptr<int>(), 57 | out.data_ptr<int64_t>()); 58 | return out; 59 | } 60 | -------------------------------------------------------------------------------- /torchsparse/backend/hash/hash_cpu.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <torch/torch.h> 4 | 5 | at::Tensor hash_cpu(const at::Tensor idx); 6 | 7 | at::Tensor kernel_hash_cpu(const 
at::Tensor idx, 8 | const at::Tensor kernel_offset); 9 | -------------------------------------------------------------------------------- /torchsparse/backend/hash/hash_cuda.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <stdlib.h> 3 | #include <torch/torch.h> 4 | 5 | #include <cmath> 6 | #include <vector> 7 | // hashing 8 | // input N*4 int32 tensor output N*1 int64 tensor 9 | __global__ void hash_kernel(int N, const int *__restrict__ data, 10 | int64_t *__restrict__ out) { 11 | int i = blockDim.x * blockIdx.x + threadIdx.x; 12 | if (i < N) { 13 | data += i * 4; 14 | uint64_t hash = 14695981039346656037UL; 15 | for (int j = 0; j < 4; j++) { 16 | hash ^= (unsigned int)data[j]; 17 | hash *= 1099511628211UL; 18 | } 19 | // hash = (hash >> 60) ^ (hash & 0xFFFFFFFFFFFFFFF); 20 | out[i] = hash; 21 | } 22 | } 23 | 24 | // kernel hashing: given data D and offset map K, generate D x K 25 | // input N*4 int32 tensor, |K|*3 int32 tensor, output |K|*N int64 tensor 26 | __global__ void kernel_hash_kernel(int N, int K, const int *__restrict__ data, 27 | const int *__restrict__ kernel_offset, 28 | int64_t *__restrict__ out) { 29 | extern __shared__ int kernel_offset_local[]; 30 | 31 | for (int i = 0; i < K * 3; i++) { 32 | kernel_offset_local[i] = kernel_offset[i]; 33 | } 34 | __syncthreads(); 35 | 36 | int idx = blockDim.x * blockIdx.x + threadIdx.x; 37 | int k = idx % K; 38 | int i = idx / K; 39 | int cur_coord[4]; 40 | if (i < N) { 41 | data += i * 4; 42 | for (int j = 1; j < 4; j++) { 43 | cur_coord[j] = data[j] + kernel_offset[k * 3 + j - 1]; 44 | } 45 | cur_coord[0] = data[0]; 46 | uint64_t hash = 14695981039346656037UL; 47 | for (int j = 0; j < 4; j++) { 48 | hash ^= (unsigned int)cur_coord[j]; 49 | hash *= 1099511628211UL; 50 | } 51 | // hash = (hash >> 60) ^ (hash & 0xFFFFFFFFFFFFFFF); 52 | out[k * N + i] = hash; 53 | } 54 | } 55 | 56 | void kernel_hash_wrapper(int N, int K, const int *data, 57 | const int *kernel_offset, int64_t *out) { 58 | kernel_hash_kernel<<<ceil((double)(N * K) / 512), 512, K * 3 * sizeof(int)>>>( 59 | N, K, data, kernel_offset, out); 60 | } 61 | 62 | void hash_wrapper(int N, const int *data, int64_t *out) { 63 | hash_kernel<<<ceil((double)N / 512), 512>>>(N, data, out); 64 | } 65 | 66 | at::Tensor hash_cuda(const at::Tensor idx) { 67 | int N = idx.size(0); 68 | at::Tensor out = 69 | torch::zeros({N}, at::device(idx.device()).dtype(at::ScalarType::Long)); 70 | hash_wrapper(N, idx.data_ptr<int>(), out.data_ptr<int64_t>()); 71 | return out; 72 | } 73 | 74 | at::Tensor kernel_hash_cuda(const at::Tensor idx, 75 | const at::Tensor kernel_offset) { 76 | int N = idx.size(0); 77 | int K = kernel_offset.size(0); 78 | at::Tensor out = torch::zeros( 79 | {K, N}, at::device(idx.device()).dtype(at::ScalarType::Long)); 80 | kernel_hash_wrapper(N, K, idx.data_ptr<int>(), kernel_offset.data_ptr<int>(), 81 | out.data_ptr<int64_t>()); 82 | return out; 83 | } 84 | -------------------------------------------------------------------------------- /torchsparse/backend/hash/hash_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <torch/torch.h> 4 | 5 | at::Tensor hash_cuda(const at::Tensor idx); 6 | 7 | at::Tensor kernel_hash_cuda(const at::Tensor idx, 8 | const at::Tensor kernel_offset); 9 | -------------------------------------------------------------------------------- /torchsparse/backend/hashmap/hashmap_cpu.cpp: 
-------------------------------------------------------------------------------- 1 | #include "hashmap_cpu.hpp" 2 | 3 | #include <chrono> 4 | #include <cstdio> 5 | #include <cstdlib> 6 | #include <stdexcept> 7 | 8 | void HashTableCPU::lookup_vals(const int64_t* const keys, 9 | int64_t* const results, const int n) { 10 | #pragma omp parallel for 11 | for (int idx = 0; idx < n; idx++) { 12 | int64_t key = keys[idx]; 13 | google::dense_hash_map<int64_t, int64_t>::iterator iter = hashmap.find(key); 14 | if (iter != hashmap.end()) { 15 | results[idx] = iter->second; 16 | } else { 17 | results[idx] = 0; 18 | } 19 | } 20 | } 21 | 22 | void HashTableCPU::insert_vals(const int64_t* const keys, 23 | const int64_t* const vals, const int n) { 24 | for (int i = 0; i < 10; i++) { 25 | printf("%d, %d, %d, %d\n", i, i < n, n, i < 10); 26 | // hashmap[(int)keys[idx]] = (int)vals[idx]+1; 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /torchsparse/backend/hashmap/hashmap_cpu.hpp: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <cmath> 4 | #include <cstdint> 5 | #include <cstdio> 6 | #include <cstdlib> 7 | #include <google/dense_hash_map> 8 | #include <vector> 9 | 10 | class HashTableCPU { 11 | private: 12 | google::dense_hash_map<int64_t, int64_t> hashmap; 13 | 14 | public: 15 | HashTableCPU() {} 16 | 17 | ~HashTableCPU() {} 18 | 19 | void insert_vals(const int64_t* const keys, const int64_t* const vals, 20 | const int n); 21 | 22 | void lookup_vals(const int64_t* const keys, int64_t* const results, 23 | const int n); 24 | }; 25 | -------------------------------------------------------------------------------- /torchsparse/backend/others/count_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "count_cpu.h" 2 | 3 | #include <torch/torch.h> 4 | 5 | #include <vector> 6 | 7 | at::Tensor count_cpu(const at::Tensor idx, const int s) { 8 | int N = idx.size(0); 9 | at::Tensor out = 10 | torch::zeros({s}, at::device(idx.device()).dtype(at::ScalarType::Int)); 11 | int *idx_ = idx.data_ptr<int>(); 12 | int *out_ = out.data_ptr<int>(); 13 | #pragma omp parallel for 14 | for (int i = 0; i < N; i++) { 15 | int cur_idx = idx_[i]; 16 | if (cur_idx < 0) { 17 | continue; 18 | } 19 | #pragma omp atomic 20 | out_[cur_idx]++; 21 | } 22 | return out; 23 | } 24 | -------------------------------------------------------------------------------- /torchsparse/backend/others/count_cpu.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <torch/torch.h> 4 | 5 | at::Tensor count_cpu(const at::Tensor idx, const int s); 6 | -------------------------------------------------------------------------------- /torchsparse/backend/others/count_cuda.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <stdlib.h> 3 | #include <torch/torch.h> 4 | 5 | #include <cmath> 6 | #include <vector> 7 | 8 | // counting 9 | // input N*3 int32 tensor output N*1 int64 tensor 10 | __global__ void count_kernel(int N, const int *__restrict__ data, 11 | int *__restrict__ out) { 12 | int i = blockDim.x * blockIdx.x + threadIdx.x; 13 | if (i < N && data[i] >= 0) { 14 | atomicAdd(&out[data[i]], 1); 15 | } 16 | } 17 | 18 | void count_wrapper(int N, const int *data, int *out) { 19 | count_kernel<<<ceil((double)N / 512), 512>>>(N, data, out); 20 | } 21 | 22 | // make sure 
indices is int type 23 | // feat: (b,c,n) indices: (b,n) -> out: (b,c,s), out_indices: (b,n) 24 | // (preprocessed indices) 25 | at::Tensor count_cuda(const at::Tensor idx, const int s) { 26 | int N = idx.size(0); 27 | at::Tensor out = 28 | torch::zeros({s}, at::device(idx.device()).dtype(at::ScalarType::Int)); 29 | count_wrapper(N, idx.data_ptr<int>(), out.data_ptr<int>()); 30 | return out; 31 | } 32 | -------------------------------------------------------------------------------- /torchsparse/backend/others/count_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <torch/torch.h> 4 | 5 | at::Tensor count_cuda(const at::Tensor idx, const int s); 6 | -------------------------------------------------------------------------------- /torchsparse/backend/others/downsample_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <torch/torch.h> 4 | 5 | at::Tensor downsample_cuda(at::Tensor _in_coords, at::Tensor _coords_max, 6 | at::Tensor _coords_min, at::Tensor _kernel_sizes, 7 | at::Tensor _stride, at::Tensor _padding); 8 | -------------------------------------------------------------------------------- /torchsparse/backend/others/exclusive_scan_cuda.cu: -------------------------------------------------------------------------------- 1 | #include <torch/torch.h> 2 | #include <torch/extension.h> 3 | 4 | #include "exclusive_scan_cuda.h" 5 | 6 | // to derive quantified address of activated features 7 | __global__ void exclusive_scan_for_kernel_quantified( 8 | const int kv, 9 | const int *input, 10 | const int q, 11 | // const int mid_kernel, 12 | int *output, 13 | int *qoutput 14 | // bool precompute_mid 15 | ){ 16 | // a thread for a scan 17 | const int id = threadIdx.x + 1; 18 | if (id >= kv){return;} 19 | int acc = 0; 20 | int qacc = 0; 21 | #pragma unroll 22 | for (int i = 0; i < id; i++){ 23 | // if (precompute_mid && i == mid_kernel){continue;} 24 | acc += input[i]; 25 | qacc += (input[i] + q - 1) / q * q; 26 | } 27 | output[id] = acc; 28 | qoutput[id] = qacc; 29 | } 30 | 31 | at::Tensor exclusive_scan_quantified_wrapper( 32 | const int k_vol, at::Tensor neighbor_offset, 33 | at::Tensor neighbor_address, at::Tensor q_neighbor_address){ 34 | 35 | int *knnz_ptr = neighbor_offset.data_ptr<int>(); 36 | int *kpos_ptr = neighbor_address.data_ptr<int>(); 37 | int *qkpos_ptr = q_neighbor_address.data_ptr<int>(); 38 | 39 | exclusive_scan_for_kernel_quantified<<<1, k_vol, 0, 0>>>( 40 | k_vol + 1, knnz_ptr, 128, kpos_ptr, qkpos_ptr 41 | ); 42 | // We must have a tensor as return val for Pybind. 
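// In effect, the kernel launched above computes, for each kernel offset id in [1, kv),
// an exclusive prefix sum over the per-offset neighbor counts plus a padded
// ("quantified") variant that rounds every count up to a multiple of q (q = 128 here):
//   kpos[id]  = sum_{i < id} knnz[i]
//   qkpos[id] = sum_{i < id} ceil(knnz[i] / q) * q
// so each offset's segment of gathered features starts at an address that is a
// multiple of q; these are presumably the nbaddrs / qnbaddrs arrays consumed by the
// fetch-on-demand convolution kernels.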
43 | at::Tensor status = at::zeros({1}); 44 | return status; 45 | } -------------------------------------------------------------------------------- /torchsparse/backend/others/exclusive_scan_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <torch/torch.h> 4 | 5 | at::Tensor exclusive_scan_quantified_wrapper( 6 | const int k_vol, at::Tensor neighbor_offset, 7 | at::Tensor neighbor_address, at::Tensor q_neighbor_address); -------------------------------------------------------------------------------- /torchsparse/backend/others/query_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "query_cpu.h" 2 | 3 | #include <torch/torch.h> 4 | 5 | #include <cmath> 6 | #include <google/dense_hash_map> 7 | #include <iostream> 8 | #include <vector> 9 | 10 | #include "../hashmap/hashmap_cpu.hpp" 11 | 12 | at::Tensor hash_query_cpu(const at::Tensor hash_query, 13 | const at::Tensor hash_target, 14 | const at::Tensor idx_target) { 15 | int n = hash_target.size(0); 16 | int n1 = hash_query.size(0); 17 | 18 | google::dense_hash_map<int64_t, int64_t> hashmap; 19 | hashmap.set_empty_key(0); 20 | at::Tensor out = torch::zeros( 21 | {n1}, at::device(hash_query.device()).dtype(at::ScalarType::Long)); 22 | for (int idx = 0; idx < n; idx++) { 23 | int64_t key = *(hash_target.data_ptr<int64_t>() + idx); 24 | int64_t val = *(idx_target.data_ptr<int64_t>() + idx) + 1; 25 | hashmap.insert(std::make_pair(key, val)); 26 | } 27 | #pragma omp parallel for 28 | for (int idx = 0; idx < n1; idx++) { 29 | int64_t key = *(hash_query.data_ptr<int64_t>() + idx); 30 | google::dense_hash_map<int64_t, int64_t>::iterator iter = hashmap.find(key); 31 | if (iter != hashmap.end()) { 32 | *(out.data_ptr<int64_t>() + idx) = iter->second; 33 | } 34 | } 35 | 36 | return out; 37 | } 38 | -------------------------------------------------------------------------------- /torchsparse/backend/others/query_cpu.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <torch/torch.h> 4 | 5 | at::Tensor hash_query_cpu(const at::Tensor hash_query, 6 | const at::Tensor hash_target, 7 | const at::Tensor idx_target); 8 | -------------------------------------------------------------------------------- /torchsparse/backend/others/query_cuda.cu: -------------------------------------------------------------------------------- 1 | #include <torch/torch.h> 2 | 3 | #include <cmath> 4 | #include <iostream> 5 | #include <vector> 6 | 7 | #include "../hashmap/hashmap_cuda.cuh" 8 | 9 | __global__ void convert_out_in_map_kernel(const int* out_in_map, int* out_in_map_t, int n, int kernel_volume){ 10 | int idx = blockIdx.x * blockDim.x + threadIdx.x; 11 | if(idx >= n * kernel_volume) return; 12 | int input_idx = out_in_map[idx]; 13 | if(input_idx < 0) return; 14 | out_in_map_t[idx % kernel_volume + input_idx * kernel_volume] = idx / kernel_volume; 15 | } 16 | 17 | __global__ void derive_bit_mask_from_out_in_map_kernel(int* out_in_map, int* bitmask, int valid_n, int n, int kernel_volume, int split_mask_num){ 18 | int tidx = blockIdx.x * blockDim.x + threadIdx.x; 19 | int idx = tidx / split_mask_num; 20 | if(idx >= valid_n) return; 21 | int split_mask_iter = tidx % split_mask_num; 22 | int split_mask_len = (kernel_volume + split_mask_num - 1) / split_mask_num; 23 | int* cur_out_in_map = out_in_map + kernel_volume * idx + split_mask_iter * split_mask_len; 24 | if (split_mask_iter == 
(split_mask_num - 1)) // The last tile 25 | split_mask_len = kernel_volume - split_mask_iter * split_mask_len; 26 | int cur_bitmask = 0; 27 | for(int i = 0; i < split_mask_len; i++){ 28 | cur_bitmask += (int)(cur_out_in_map[i] >= 0) * (int)(1u << i); 29 | } 30 | bitmask[split_mask_iter * n + idx] = cur_bitmask; 31 | } 32 | 33 | at::Tensor hash_query_cuda(const at::Tensor hash_query, 34 | const at::Tensor hash_target, 35 | const at::Tensor idx_target) { 36 | // return group_point_forward_gpu(points, indices); 37 | int n = hash_target.size(0); 38 | int n1 = hash_query.size(0); 39 | hashtable in_hash_table(n * 2); 40 | 41 | in_hash_table.insert_many(hash_target.data_ptr<int64_t>(), n); 42 | 43 | at::Tensor out = torch::zeros( 44 | {n1}, at::device(hash_query.device()).dtype(at::ScalarType::Int)); 45 | in_hash_table.lookup_many(hash_query.data_ptr<int64_t>(), out.data_ptr<int>(), n1); 46 | return out; 47 | } 48 | 49 | 50 | void convert_transposed_out_in_map(const at::Tensor out_in_map, 51 | at::Tensor out_in_map_t) { 52 | convert_out_in_map_kernel<<<(out_in_map.size(0) * out_in_map.size(1) + 255) / 256, 256>>>( 53 | out_in_map.data_ptr<int>(), out_in_map_t.data_ptr<int>(), out_in_map.size(0), out_in_map.size(1)); 54 | } 55 | 56 | 57 | 58 | 59 | at::Tensor derive_bitmask_from_out_in_map(const at::Tensor out_in_map, const int split_mask_num, int valid_n) { 60 | at::Tensor bitmask = torch::full( 61 | {split_mask_num, out_in_map.size(0)}, -1, at::device(out_in_map.device()).dtype(at::ScalarType::Int)); 62 | derive_bit_mask_from_out_in_map_kernel<<<(split_mask_num * out_in_map.size(0) + 255) / 256, 256>>>( 63 | out_in_map.data_ptr<int>(), bitmask.data_ptr<int>(), valid_n, out_in_map.size(0), out_in_map.size(1), split_mask_num); 64 | return bitmask; 65 | } 66 | -------------------------------------------------------------------------------- /torchsparse/backend/others/query_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <torch/torch.h> 4 | 5 | at::Tensor hash_query_cuda(const at::Tensor hash_query, 6 | const at::Tensor hash_target, 7 | const at::Tensor idx_target); 8 | void convert_transposed_out_in_map(const at::Tensor out_in_map, 9 | at::Tensor out_in_map_t); 10 | at::Tensor derive_bitmask_from_out_in_map(const at::Tensor out_in_map, const int split_mask_num, int valid_n); -------------------------------------------------------------------------------- /torchsparse/backend/others/reduce_bitmask_cuda.cu: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include "reduce_bitmask_cuda.h" 3 | 4 | 5 | // 1 block -- 4 warps -- 128 threads 6 | // 1 warp -- 8 output elements (4 threads for 1 reduced element in int32) 7 | // each thread reduce reduce_tile/4 int32 numbers 8 | // threads in 1st warp finish the final reduction of 4 numbers 9 | 10 | #define thd_per_blk 128 11 | #define output_per_blk 32 // thd_per_blk / 4 -> (4 threads for 1 reduced element in int32) 12 | 13 | extern "C" __global__ 14 | void __launch_bounds__(thd_per_blk) reduce_mask_cuda_int32( 15 | int* __restrict__ bitmask, 16 | int output_node_num, 17 | int reduced_row_num, 18 | int reduce_tile, 19 | int* __restrict__ reduced_bitmask) { 20 | 21 | int split_mask_iter = blockIdx.y; 22 | int thread_size = reduce_tile / 4; 23 | int blockIdx_x = (int)blockIdx.x; 24 | int threadIdx_x = (int)threadIdx.x; 25 | int laneid = (threadIdx_x & 31); 26 | int warpid = (threadIdx_x >> 5); 27 | 28 | int 
bitmask_local = 0; 29 | __shared__ int bitmask_shared[thd_per_blk]; 30 | int* final_reduce_ptr = bitmask_shared + (laneid << 2); 31 | 32 | int* bitmask_blk = bitmask + split_mask_iter * output_node_num; 33 | int* reduced_bitmask_blk = reduced_bitmask + split_mask_iter * reduced_row_num; 34 | int block_offset = blockIdx_x * thd_per_blk * thread_size; 35 | int thread_offset = block_offset + (threadIdx_x * thread_size); 36 | int load_len = min(thread_size, output_node_num - thread_offset); 37 | 38 | #pragma unroll 39 | for (int i = 0; i < load_len; i++) { 40 | int load_offset = i + thread_offset; 41 | bitmask_local = bitmask_local | bitmask_blk[load_offset]; 42 | } 43 | bitmask_shared[threadIdx_x] = bitmask_local; 44 | __syncthreads(); 45 | 46 | // final reduction 47 | if(warpid == 0){ 48 | #pragma unroll 49 | for(int i = 1; i < 4; i++){ 50 | final_reduce_ptr[0] = final_reduce_ptr[0] | final_reduce_ptr[i]; 51 | } 52 | int output_offset = (blockIdx_x << 5) + laneid; 53 | if (output_offset < reduced_row_num){ 54 | reduced_bitmask_blk[output_offset] = final_reduce_ptr[0]; 55 | } 56 | } 57 | } 58 | 59 | 60 | torch::Tensor reduce_bitmask_cuda( 61 | torch::Tensor _bitmask_int, 62 | int M_tile 63 | ){ 64 | if (M_tile % 4 != 0) 65 | { 66 | throw std::runtime_error("[Bitmask reduce] reduce tile size must be multiple of 4."); 67 | } 68 | int split_mask_num = _bitmask_int.size(0); 69 | int output_node_num = _bitmask_int.size(1); 70 | int reduced_row_num = (output_node_num - 1) / M_tile + 1; 71 | auto options = torch::TensorOptions().dtype(torch::kInt32).device(_bitmask_int.device()); 72 | torch::Tensor _reduced_bitmask_int = torch::zeros({split_mask_num, reduced_row_num}, options); 73 | 74 | auto bitmask_int = _bitmask_int.data_ptr<int>(); 75 | auto reduced_bitmask_int = _reduced_bitmask_int.data_ptr<int>(); 76 | 77 | dim3 num_blocks(((reduced_row_num - 1) / output_per_blk + 1), split_mask_num); 78 | dim3 num_threads(thd_per_blk); 79 | 80 | reduce_mask_cuda_int32<<<num_blocks, num_threads>>>( 81 | bitmask_int, output_node_num, reduced_row_num, M_tile, reduced_bitmask_int); 82 | 83 | return _reduced_bitmask_int; 84 | } 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /torchsparse/backend/others/reduce_bitmask_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <torch/torch.h> 3 | 4 | torch::Tensor reduce_bitmask_cuda( 5 | torch::Tensor _bitmask_int, 6 | int M_tile 7 | ); -------------------------------------------------------------------------------- /torchsparse/backend/others/reorder_map_cuda.cu: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include "reorder_map_cuda.h" 3 | 4 | #define cta_M 128 5 | #define thd_num 128 // 1 thd per row 6 | 7 | 8 | __global__ void __launch_bounds__(thd_num) reorder_out_in_map_kernel( 9 | int* __restrict__ out_in_map, 10 | int* __restrict__ reorder_loc, 11 | int M, // node num 12 | int kernel_volume, 13 | int split_mask_len, 14 | int* __restrict__ reorder_out_in_map 15 | ){ 16 | int index = blockIdx.x * blockDim.x + threadIdx.x; 17 | int output_row_idx = index / kernel_volume; 18 | int output_col_idx = index % kernel_volume; 19 | if (output_row_idx >= M) return; 20 | int split_mask_iter = output_col_idx / split_mask_len; 21 | int input_row_idx = reorder_loc[split_mask_iter * M + output_row_idx]; 22 | reorder_out_in_map[output_row_idx * kernel_volume + output_col_idx] = 
out_in_map[input_row_idx * kernel_volume + output_col_idx]; 23 | } 24 | 25 | at::Tensor reorder_out_in_map_cuda( 26 | torch::Tensor _out_in_map, 27 | torch::Tensor _reorder_loc 28 | ){ 29 | 30 | int M = _out_in_map.size(0); 31 | int kernel_volume = _out_in_map.size(1); 32 | int split_mask_num = _reorder_loc.size(0); 33 | int split_mask_len = (kernel_volume + split_mask_num - 1) / split_mask_num; 34 | 35 | auto options = 36 | torch::TensorOptions().dtype(_out_in_map.dtype()).device(_out_in_map.device()); 37 | at::Tensor _reorder_out_in_map = torch::empty({M, kernel_volume}, options); 38 | 39 | 40 | auto out_in_map = _out_in_map.data_ptr<int>(); 41 | auto reorder_loc = _reorder_loc.data_ptr<int>(); 42 | auto reorder_out_in_map = _reorder_out_in_map.data_ptr<int>(); 43 | 44 | reorder_out_in_map_kernel<<<(M + cta_M - 1) / cta_M * kernel_volume, cta_M>>>( 45 | out_in_map, reorder_loc, M, kernel_volume, split_mask_len, reorder_out_in_map); 46 | 47 | return _reorder_out_in_map; 48 | } -------------------------------------------------------------------------------- /torchsparse/backend/others/reorder_map_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include <torch/torch.h> 3 | 4 | at::Tensor reorder_out_in_map_cuda( 5 | torch::Tensor _out_in_map, 6 | torch::Tensor _reorder_loc 7 | ); -------------------------------------------------------------------------------- /torchsparse/backend/others/sparsemapping_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <torch/torch.h> 4 | #include "../hashmap/hashmap_cuda.cuh" 5 | 6 | std::vector<at::Tensor> build_mask_from_kmap(int n_points, int n_out_points, 7 | at::Tensor _kmap, 8 | at::Tensor _kmap_sizes); 9 | 10 | std::vector<at::Tensor> build_kernel_map_subm_hashmap( 11 | hashtable& table, 12 | at::Tensor _in_coords, at::Tensor _coords_min, at::Tensor _coords_max, 13 | at::Tensor _kernel_sizes, at::Tensor _stride, 14 | at::Tensor padding, bool to_insert); 15 | 16 | std::vector<at::Tensor> build_kernel_map_downsample_hashmap( 17 | hashtable& table, 18 | at::Tensor _in_coords, at::Tensor _coords_min, at::Tensor _coords_max, 19 | at::Tensor _kernel_sizes, at::Tensor _stride, 20 | at::Tensor _padding, bool to_insert); 21 | 22 | std::vector<at::Tensor> build_kernel_map_subm_hashmap_int32( 23 | hashtable32& table, 24 | at::Tensor _in_coords, at::Tensor _coords_min, at::Tensor _coords_max, 25 | at::Tensor _kernel_sizes, at::Tensor _stride, 26 | at::Tensor padding, bool to_insert); 27 | 28 | std::vector<at::Tensor> build_kernel_map_downsample_hashmap_int32( 29 | hashtable32& table, 30 | at::Tensor _in_coords, at::Tensor _coords_min, at::Tensor _coords_max, 31 | at::Tensor _kernel_sizes, at::Tensor _stride, 32 | at::Tensor _padding, bool to_insert); 33 | -------------------------------------------------------------------------------- /torchsparse/backend/pybind_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include <pybind11/pybind11.h> 2 | #include <torch/extension.h> 3 | #include <torch/serialize/tensor.h> 4 | 5 | #include "convolution/convolution_gather_scatter_cpu.h" 6 | #include "devoxelize/devoxelize_cpu.h" 7 | #include "hash/hash_cpu.h" 8 | #include "others/count_cpu.h" 9 | #include "others/query_cpu.h" 10 | #include "voxelize/voxelize_cpu.h" 11 | 12 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 13 | m.def("conv_forward_gather_scatter_cpu", &conv_forward_gather_scatter_cpu); 14 
| m.def("conv_backward_gather_scatter_cpu", &conv_backward_gather_scatter_cpu); 15 | m.def("voxelize_forward_cpu", &voxelize_forward_cpu); 16 | m.def("voxelize_backward_cpu", &voxelize_backward_cpu); 17 | m.def("devoxelize_forward_cpu", &devoxelize_forward_cpu); 18 | m.def("devoxelize_backward_cpu", &devoxelize_backward_cpu); 19 | m.def("hash_cpu", &hash_cpu); 20 | m.def("kernel_hash_cpu", &kernel_hash_cpu); 21 | m.def("hash_query_cpu", &hash_query_cpu); 22 | m.def("count_cpu", &count_cpu); 23 | } 24 | -------------------------------------------------------------------------------- /torchsparse/backend/pybind_cuda.cu: -------------------------------------------------------------------------------- 1 | #include <pybind11/pybind11.h> 2 | #include <torch/extension.h> 3 | #include <torch/serialize/tensor.h> 4 | 5 | #include "convolution/convolution_gather_scatter_cpu.h" 6 | #include "convolution/convolution_gather_scatter_cuda.h" 7 | #include "convolution/convolution_forward_fetch_on_demand_cuda.h" 8 | #include "convolution/convolution_forward_implicit_gemm_cuda.h" 9 | #include "convolution/convolution_forward_implicit_gemm_sorted_cuda.h" 10 | #include "convolution/convolution_backward_wgrad_implicit_gemm_cuda.h" 11 | #include "convolution/convolution_backward_wgrad_implicit_gemm_sorted_cuda.h" 12 | #include "devoxelize/devoxelize_cpu.h" 13 | #include "devoxelize/devoxelize_cuda.h" 14 | #include "hash/hash_cpu.h" 15 | #include "hash/hash_cuda.h" 16 | #include "others/count_cpu.h" 17 | #include "others/count_cuda.h" 18 | #include "others/downsample_cuda.h" 19 | #include "others/exclusive_scan_cuda.h" 20 | #include "others/query_cpu.h" 21 | #include "others/query_cuda.h" 22 | #include "others/reduce_bitmask_cuda.h" 23 | #include "others/reorder_map_cuda.h" 24 | #include "others/sparsemapping_cuda.h" 25 | #include "voxelize/voxelize_cpu.h" 26 | #include "voxelize/voxelize_cuda.h" 27 | #include "hashmap/hashmap_cuda.cuh" 28 | 29 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 30 | py::class_<hashtable>(m, "GPUHashTable") 31 | .def(py::init<const int>()) 32 | .def(py::init<torch::Tensor, torch::Tensor>()) 33 | .def("insert_vals", &hashtable::insert_vals) 34 | .def("lookup_vals", &hashtable::lookup_vals) 35 | .def("insert_coords", &hashtable::insert_coords) 36 | .def("lookup_coords", &hashtable::lookup_coords); 37 | py::class_<hashtable32>(m, "GPUHashTable32") 38 | .def(py::init<const int>()) 39 | .def(py::init<torch::Tensor, torch::Tensor>()) 40 | .def("insert_vals", &hashtable32::insert_vals) 41 | .def("lookup_vals", &hashtable32::lookup_vals) 42 | .def("insert_coords", &hashtable32::insert_coords) 43 | .def("lookup_coords", &hashtable32::lookup_coords); 44 | m.def("conv_forward_gather_scatter_cpu", &conv_forward_gather_scatter_cpu); 45 | m.def("conv_forward_gather_scatter_cuda", &conv_forward_gather_scatter_cuda); 46 | m.def("conv_forward_fetch_on_demand_cuda", &conv_forward_fetch_on_demand_cuda); 47 | m.def("conv_forward_fetch_on_demand_no_fusion_cuda", &conv_forward_fetch_on_demand_no_fusion_cuda); 48 | m.def("conv_forward_implicit_gemm_cuda", &conv_forward_implicit_gemm_cuda, py::arg("_in_feats"), py::arg("_kernel"), py::arg("_out_in_map"), py::arg("num_out_feats"),py::arg("num_out_channels"), py::arg("allow_tf32") = false, py::arg("allow_fp16") = true); 49 | m.def("conv_forward_implicit_gemm_sorted_cuda", &conv_forward_implicit_gemm_sorted_cuda, py::arg("_in_feats"), py::arg("_kernel"), py::arg("_out_in_map"), py::arg("_reduced_mask"), py::arg("_reorder_loc"), py::arg("num_out_feats"), 
py::arg("num_out_channels"), py::arg("allow_tf32") = false, py::arg("allow_fp16") = true); 50 | m.def("conv_backward_wgrad_implicit_gemm_cuda", &conv_backward_wgrad_implicit_gemm_cuda, py::arg("_in_feats"), py::arg("_kernel"), py::arg("_out_in_map"), py::arg("split_k_iters"), py::arg("allow_tf32") = false, py::arg("allow_fp16") = true); 51 | m.def("conv_backward_wgrad_implicit_gemm_sorted_cuda", &conv_backward_wgrad_implicit_gemm_sorted_cuda, py::arg("_in_feats"), py::arg("_kernel"), py::arg("_out_in_map"), py::arg("_reduced_mask"), py::arg("_reorder_loc"), py::arg("split_k_iters"), py::arg("allow_tf32") = false, py::arg("allow_fp16") = true); 52 | m.def("conv_backward_gather_scatter_cpu", &conv_backward_gather_scatter_cpu); 53 | m.def("conv_backward_gather_scatter_cuda", &conv_backward_gather_scatter_cuda); 54 | m.def("voxelize_forward_cpu", &voxelize_forward_cpu); 55 | m.def("voxelize_forward_cuda", &voxelize_forward_cuda); 56 | m.def("voxelize_backward_cpu", &voxelize_backward_cpu); 57 | m.def("voxelize_backward_cuda", &voxelize_backward_cuda); 58 | m.def("to_dense_forward_cuda", &to_dense_forward_cuda); 59 | m.def("to_dense_backward_cuda", &to_dense_backward_cuda); 60 | m.def("devoxelize_forward_cpu", &devoxelize_forward_cpu); 61 | m.def("devoxelize_forward_cuda", &devoxelize_forward_cuda); 62 | m.def("devoxelize_backward_cpu", &devoxelize_backward_cpu); 63 | m.def("devoxelize_backward_cuda", &devoxelize_backward_cuda); 64 | m.def("exclusive_scan_quantified_wrapper", &exclusive_scan_quantified_wrapper); 65 | m.def("hash_cpu", &hash_cpu); 66 | m.def("hash_cuda", &hash_cuda); 67 | m.def("kernel_hash_cpu", &kernel_hash_cpu); 68 | m.def("kernel_hash_cuda", &kernel_hash_cuda); 69 | m.def("hash_query_cpu", &hash_query_cpu); 70 | m.def("hash_query_cuda", &hash_query_cuda); 71 | m.def("convert_transposed_out_in_map", &convert_transposed_out_in_map); 72 | m.def("derive_bitmask_from_out_in_map", &derive_bitmask_from_out_in_map); 73 | m.def("reduce_bitmask_cuda", &reduce_bitmask_cuda); 74 | m.def("reorder_out_in_map_cuda", &reorder_out_in_map_cuda); 75 | m.def("build_kernel_map_subm_hashmap", &build_kernel_map_subm_hashmap); 76 | m.def("build_kernel_map_downsample_hashmap", &build_kernel_map_downsample_hashmap); 77 | m.def("build_kernel_map_subm_hashmap_int32", &build_kernel_map_subm_hashmap_int32); 78 | m.def("build_kernel_map_downsample_hashmap_int32", &build_kernel_map_downsample_hashmap_int32); 79 | m.def("build_mask_from_kmap", &build_mask_from_kmap); 80 | m.def("downsample_cuda", &downsample_cuda); 81 | m.def("count_cpu", &count_cpu); 82 | m.def("count_cuda", &count_cuda); 83 | } 84 | -------------------------------------------------------------------------------- /torchsparse/backend/utils/atomic.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | __device__ static uint64_t atomicExch(uint64_t *addr, uint64_t val) { 4 | return (uint64_t)atomicExch((unsigned long long int *)addr, 5 | (unsigned long long int)val); 6 | } 7 | -------------------------------------------------------------------------------- /torchsparse/backend/utils/memory.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | template <int bytes> 4 | struct global_load; 5 | 6 | template <> 7 | struct global_load<16> 8 | { 9 | __device__ __inline__ global_load(uint4 &D, void const *ptr, int pred_guard) 10 | { 11 | uint4 &data = *reinterpret_cast<uint4 *>(&D); 12 | asm volatile( 13 | "{\n" 14 | " .reg .pred p;\n" 15 
| " setp.ne.b32 p, %5, 0;\n" 16 | " mov.b32 %0, %6;\n" 17 | " mov.b32 %1, %7;\n" 18 | " mov.b32 %2, %8;\n" 19 | " mov.b32 %3, %9;\n" 20 | " @p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n" 21 | "}\n" 22 | : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) 23 | : "l"(ptr), "r"((int)(pred_guard & 1)), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w)); 24 | } 25 | }; 26 | 27 | template <> 28 | struct global_load<8> 29 | { 30 | __device__ __inline__ global_load(uint4 &D, void const *ptr, int pred_guard) 31 | { 32 | uint2 const *ptr_ldg = reinterpret_cast<uint2 const *>(ptr); 33 | #pragma unroll 34 | for (int ldg_idx = 0; ldg_idx < 2; ldg_idx++) 35 | { 36 | uint2 &data = *(reinterpret_cast<uint2 *>(&D) + ldg_idx); 37 | asm volatile( 38 | "{\n" 39 | " .reg .pred p;\n" 40 | " setp.ne.b32 p, %3, 0;\n" 41 | " mov.b32 %0, %4;\n" 42 | " mov.b32 %1, %5;\n" 43 | " @p ld.global.v2.u32 {%0, %1}, [%2];\n" 44 | "}\n" 45 | : "=r"(data.x), "=r"(data.y) 46 | : "l"(ptr_ldg + ldg_idx), "r"((int)(pred_guard & (1 << ldg_idx))), "r"(data.x), "r"(data.y)); 47 | } 48 | } 49 | }; 50 | 51 | template <> 52 | struct global_load<4> 53 | { 54 | __device__ __inline__ global_load(uint4 &D, void const *ptr, int pred_guard) 55 | { 56 | unsigned const *ptr_ldg = reinterpret_cast<unsigned const *>(ptr); 57 | #pragma unroll 58 | for (int ldg_idx = 0; ldg_idx < 4; ldg_idx++) 59 | { 60 | unsigned &data = *(reinterpret_cast<unsigned *>(&D) + ldg_idx); 61 | asm volatile( 62 | "{\n" 63 | " .reg .pred p;\n" 64 | " setp.ne.b32 p, %2, 0;\n" 65 | " mov.b32 %0, %3;\n" 66 | " @p ld.global.u32 %0, [%1];\n" 67 | "}\n" 68 | : "=r"(data) 69 | : "l"(ptr_ldg + ldg_idx), "r"((int)(pred_guard & (1 << ldg_idx))), "r"(data)); 70 | } 71 | } 72 | }; 73 | 74 | template <> 75 | struct global_load<2> 76 | { 77 | __device__ __inline__ global_load(uint4 &D, void const *ptr, int pred_guard) 78 | { 79 | uint16_t const *ptr_ldg = reinterpret_cast<uint16_t const *>(ptr); 80 | #pragma unroll 81 | for (int ldg_idx = 0; ldg_idx < 8; ldg_idx++) 82 | { 83 | uint16_t &data = *(reinterpret_cast<uint16_t *>(&D) + ldg_idx); 84 | asm volatile( 85 | "{\n" 86 | " .reg .pred p;\n" 87 | " setp.ne.b32 p, %2, 0;\n" 88 | " mov.b16 %0, %3;\n" 89 | " @p ld.global.u16 %0, [%1];\n" 90 | "}\n" 91 | : "=h"(data) 92 | : "l"(ptr_ldg + ldg_idx), "r"((int)(pred_guard & (1 << ldg_idx))), "h"(data)); 93 | } 94 | } 95 | }; 96 | 97 | -------------------------------------------------------------------------------- /torchsparse/backend/voxelize/voxelize_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "voxelize_cpu.h" 2 | 3 | #include <torch/torch.h> 4 | 5 | #include <vector> 6 | 7 | at::Tensor voxelize_forward_cpu(const at::Tensor inputs, const at::Tensor idx, 8 | const at::Tensor counts) { 9 | int N = inputs.size(0); 10 | int c = inputs.size(1); 11 | int N1 = counts.size(0); 12 | at::Tensor out = torch::zeros( 13 | {N1, c}, at::device(idx.device()).dtype(at::ScalarType::Float)); 14 | for (int i = 0; i < N; i++) { 15 | int pos = *(idx.data_ptr<int>() + i); 16 | if (*(counts.data_ptr<int>() + pos) == 0) continue; 17 | #pragma omp parallel for 18 | for (int j = 0; j < c; j++) { 19 | *(out.data_ptr<float>() + pos * c + j) += 20 | *(inputs.data_ptr<float>() + i * c + j) / 21 | (float)(*(counts.data_ptr<int>() + pos)); 22 | } 23 | } 24 | return out; 25 | } 26 | 27 | at::Tensor voxelize_backward_cpu(const at::Tensor top_grad, 28 | const at::Tensor idx, const at::Tensor counts, 29 | const int N) { 30 | int c = top_grad.size(1); 31 | 
at::Tensor bottom_grad = torch::zeros( 32 | {N, c}, at::device(idx.device()).dtype(at::ScalarType::Float)); 33 | for (int i = 0; i < N; i++) { 34 | if (*(counts.data_ptr<int>() + *(idx.data_ptr<int>() + i)) == 0) continue; 35 | #pragma omp parallel for 36 | for (int j = 0; j < c; j++) { 37 | *(bottom_grad.data_ptr<float>() + i * c + j) = 38 | *(top_grad.data_ptr<float>() + *(idx.data_ptr<int>() + i) * c + j) / 39 | (float)(*(counts.data_ptr<int>() + *(idx.data_ptr<int>() + i))); 40 | } 41 | } 42 | return bottom_grad; 43 | } 44 | -------------------------------------------------------------------------------- /torchsparse/backend/voxelize/voxelize_cpu.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <torch/torch.h> 4 | 5 | at::Tensor voxelize_forward_cpu(const at::Tensor inputs, const at::Tensor idx, 6 | const at::Tensor counts); 7 | 8 | at::Tensor voxelize_backward_cpu(const at::Tensor top_grad, 9 | const at::Tensor idx, const at::Tensor counts, 10 | const int N); 11 | -------------------------------------------------------------------------------- /torchsparse/backend/voxelize/voxelize_cuda.cu: -------------------------------------------------------------------------------- 1 | #include <stdio.h> 2 | #include <stdlib.h> 3 | #include <torch/torch.h> 4 | 5 | #include <THC/THCAtomics.cuh> 6 | #include <cmath> 7 | 8 | // to_dense: feats (N x C), coords (N x 4), output (B x H x W x D x C) 9 | // coords: batch, x, y, z 10 | template <typename scalar_t> 11 | __global__ void to_dense_forward_kernel(int N, int c, const scalar_t *__restrict__ feats, const int *__restrict__ coords, const int *__restrict__ range, scalar_t *__restrict__ out) 12 | { 13 | int index = blockDim.x * blockIdx.x + threadIdx.x; 14 | int i = index / c; 15 | int j = index % c; 16 | if (i < N) 17 | { 18 | const int *cur_coords = coords + 4 * i; 19 | int pos = cur_coords[0] * range[1] * range[2] * range[3] + cur_coords[1] * range[2] * range[3] + cur_coords[2] * range[3] + cur_coords[3]; 20 | out[pos * c + j] = feats[index]; 21 | } 22 | } 23 | 24 | // to_dense: top_grad (B x H x W x D x C), coords (N x 4), bottom_grad (N x C) 25 | template <typename scalar_t> 26 | __global__ void to_dense_backward_kernel(int N, int c, const scalar_t *__restrict__ top_grad, const int *__restrict__ coords, const int *__restrict__ range, scalar_t *__restrict__ bottom_grad) 27 | { 28 | int index = blockDim.x * blockIdx.x + threadIdx.x; 29 | int i = index / c; 30 | int j = index % c; 31 | if (i < N) 32 | { 33 | const int *cur_coords = coords + 4 * i; 34 | int pos = cur_coords[0] * range[1] * range[2] * range[3] + cur_coords[1] * range[2] * range[3] + cur_coords[2] * range[3] + cur_coords[3]; 35 | bottom_grad[index] = top_grad[pos * c + j]; 36 | } 37 | } 38 | 39 | template <typename scalar_t> 40 | __global__ void voxelize_forward_kernel(int N, int c, int s, 41 | const scalar_t *__restrict__ data, 42 | const int *__restrict__ idx, 43 | const int *__restrict__ counts, 44 | scalar_t *__restrict__ out) 45 | { 46 | int index = blockDim.x * blockIdx.x + threadIdx.x; 47 | int i = index / c; 48 | int j = index % c; 49 | if (i < N) 50 | { 51 | int pos = idx[i]; 52 | if (pos < 0 || pos >= s || counts[pos] == 0) 53 | return; 54 | atomicAdd(&out[pos * c + j], data[i * c + j] / float(counts[pos])); 55 | } 56 | } 57 | 58 | template <typename scalar_t> 59 | __global__ void voxelize_backward_kernel(int N, int c, int s, 60 | const scalar_t *__restrict__ top_grad, 61 | const int *__restrict__ idx, 62 | 
const int *__restrict__ counts, 63 | scalar_t *__restrict__ bottom_grad) 64 | { 65 | int index = blockDim.x * blockIdx.x + threadIdx.x; 66 | int i = index / c; 67 | int j = index % c; 68 | if (i < N) 69 | { 70 | int pos = idx[i]; 71 | if (pos < 0 || pos >= s || counts[pos] == 0) 72 | return; 73 | atomicAdd(&bottom_grad[i * c + j], 74 | top_grad[pos * c + j] / float(counts[pos])); 75 | } 76 | } 77 | 78 | at::Tensor voxelize_forward_cuda(const at::Tensor inputs, const at::Tensor idx, 79 | const at::Tensor counts) 80 | { 81 | int N = inputs.size(0); 82 | int c = inputs.size(1); 83 | int N1 = counts.size(0); 84 | 85 | at::Tensor out = 86 | torch::zeros({N1, c}, at::device(idx.device()).dtype(inputs.dtype())); 87 | 88 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 89 | inputs.scalar_type(), "voxelize_forward_cuda", ([&] 90 | { voxelize_forward_kernel<scalar_t><<<N, c>>>( 91 | N, c, N1, inputs.data_ptr<scalar_t>(), idx.data_ptr<int>(), 92 | counts.data_ptr<int>(), out.data_ptr<scalar_t>()); })); 93 | 94 | return out; 95 | } 96 | 97 | at::Tensor voxelize_backward_cuda(const at::Tensor top_grad, 98 | const at::Tensor idx, const at::Tensor counts, 99 | const int N) 100 | { 101 | int c = top_grad.size(1); 102 | int N1 = counts.size(0); 103 | 104 | at::Tensor bottom_grad = 105 | torch::zeros({N, c}, at::device(idx.device()).dtype(top_grad.dtype())); 106 | 107 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 108 | top_grad.scalar_type(), "voxelize_backward_cuda", ([&] 109 | { voxelize_backward_kernel<scalar_t><<<N, c>>>( 110 | N, c, N1, top_grad.data_ptr<scalar_t>(), idx.data_ptr<int>(), 111 | counts.data_ptr<int>(), bottom_grad.data_ptr<scalar_t>()); })); 112 | 113 | return bottom_grad; 114 | } 115 | 116 | void to_dense_forward_cuda(const at::Tensor inputs, const at::Tensor idx, 117 | const at::Tensor range, at::Tensor outputs) 118 | { 119 | int N = inputs.size(0); 120 | int c = inputs.size(1); 121 | 122 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 123 | inputs.scalar_type(), "to_dense_forward_cuda", ([&] 124 | { to_dense_forward_kernel<scalar_t><<<(N * c + 255) / 256, 256>>>( 125 | N, c, inputs.data_ptr<scalar_t>(), idx.data_ptr<int>(), 126 | range.data_ptr<int>(), outputs.data_ptr<scalar_t>()); })); 127 | } 128 | 129 | void to_dense_backward_cuda(const at::Tensor top_grad, 130 | const at::Tensor idx, const at::Tensor range, 131 | const at::Tensor bottom_grad) 132 | { 133 | int N = bottom_grad.size(0); 134 | int c = bottom_grad.size(1); 135 | 136 | AT_DISPATCH_FLOATING_TYPES_AND_HALF( 137 | top_grad.scalar_type(), "to_dense_backward_cuda", ([&] 138 | { to_dense_backward_kernel<scalar_t><<<(N * c + 255) / 256, 256>>>( 139 | N, c, top_grad.data_ptr<scalar_t>(), idx.data_ptr<int>(), 140 | range.data_ptr<int>(), bottom_grad.data_ptr<scalar_t>()); })); 141 | } 142 | -------------------------------------------------------------------------------- /torchsparse/backend/voxelize/voxelize_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include <torch/torch.h> 4 | 5 | at::Tensor voxelize_forward_cuda(const at::Tensor inputs, const at::Tensor idx, 6 | const at::Tensor counts); 7 | 8 | at::Tensor voxelize_backward_cuda(const at::Tensor top_grad, 9 | const at::Tensor idx, const at::Tensor counts, 10 | const int N); 11 | 12 | void to_dense_forward_cuda(const at::Tensor inputs, const at::Tensor idx, 13 | const at::Tensor range, at::Tensor outputs); 14 | 15 | void to_dense_backward_cuda(const at::Tensor top_grad, 16 | const at::Tensor idx, const at::Tensor range, 17 | const 
at::Tensor bottom_grad); 18 | -------------------------------------------------------------------------------- /torchsparse/backends.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def init(): 5 | global benchmark, allow_tf32, allow_fp16, device_capability, hash_rsv_ratio 6 | benchmark = False 7 | device_capability = torch.cuda.get_device_capability() 8 | device_capability = device_capability[0] * 100 + device_capability[1] * 10 9 | allow_tf32 = device_capability >= 800 10 | allow_fp16 = device_capability >= 750 11 | hash_rsv_ratio = 2 # default value, reserve 2x ( 2 * original_point_number) space for downsampling 12 | -------------------------------------------------------------------------------- /torchsparse/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .activation import * 2 | from .conv import * 3 | from .count import * 4 | from .crop import * 5 | from .devoxelize import * 6 | from .hash import * 7 | from .pooling import * 8 | from .query import * 9 | from .voxelize import * 10 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/activation.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | 3 | from torchsparse import SparseTensor 4 | from torchsparse.nn.utils import fapply 5 | 6 | __all__ = ["relu", "silu", "leaky_relu"] 7 | 8 | 9 | def relu(input: SparseTensor, inplace: bool = True) -> SparseTensor: 10 | return fapply(input, F.relu, inplace=inplace) 11 | 12 | 13 | def silu(input: SparseTensor, inplace: bool = True) -> SparseTensor: 14 | return fapply(input, F.silu, inplace=inplace) 15 | 16 | 17 | def leaky_relu( 18 | input: SparseTensor, negative_slope: float = 0.1, inplace: bool = True 19 | ) -> SparseTensor: 20 | return fapply(input, F.leaky_relu, negative_slope=negative_slope, inplace=inplace) 21 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/conv/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv import * 2 | from .conv_mode import * 3 | 4 | # from .conv_config import * 5 | from .conv_config import Dataflow 6 | from .hash import * 7 | from .kmap import * 8 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/conv/conv_config.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Tuple, Union 2 | from enum import Enum 3 | from .utils import AttributeDict 4 | from .conv_mode import ConvMode, get_kmap_mode, get_downsample_mode 5 | 6 | 7 | class Dataflow(Enum): 8 | ImplicitGEMM = 0 9 | GatherScatter = 1 10 | FetchOnDemand = 2 11 | CodedCSR = 3 12 | 13 | 14 | _global_conv_config = None 15 | _default_conv_config = AttributeDict( 16 | [ 17 | ("dataflow", Dataflow.ImplicitGEMM), 18 | ("ifsort", False), 19 | ("kmap_mode", "hashmap_on_the_fly"), 20 | ("downsample_mode", "spconv"), 21 | ("split_mask_num", 1), 22 | ("split_mask_num_bwd", 3), 23 | ("epsilon", 0.0), 24 | ("mm_thresh", 0), 25 | ("FOD_fusion", True), 26 | ] 27 | ) 28 | 29 | 30 | def 
keys_check(conv_config): 31 | flag = False 32 | if "dataflow" not in conv_config: 33 | flag = True 34 | conv_config["dataflow"] = _default_conv_config["dataflow"] 35 | if "ifsort" not in conv_config: 36 | flag = True 37 | conv_config["ifsort"] = _default_conv_config["ifsort"] 38 | if "kmap_mode" not in conv_config: 39 | flag = True 40 | conv_config["kmap_mode"] = _default_conv_config["kmap_mode"] 41 | if "downsample_mode" not in conv_config: 42 | flag = True 43 | conv_config["downsample_mode"] = _default_conv_config["downsample_mode"] 44 | if "split_mask_num" not in conv_config: 45 | flag = True 46 | conv_config["split_mask_num"] = _default_conv_config["split_mask_num"] 47 | if "split_mask_num_bwd" not in conv_config: 48 | flag = True 49 | conv_config["split_mask_num_bwd"] = _default_conv_config["split_mask_num_bwd"] 50 | if "epsilon" not in conv_config: 51 | flag = True 52 | conv_config["epsilon"] = _default_conv_config["epsilon"] 53 | if "mm_thresh" not in conv_config: 54 | flag = True 55 | conv_config["mm_thresh"] = _default_conv_config["mm_thresh"] 56 | if "FOD_fusion" not in conv_config: 57 | flag = True 58 | conv_config["FOD_fusion"] = _default_conv_config["FOD_fusion"] 59 | if flag == True: 60 | print( 61 | "Warning: Missing fields for ConvConfig. Use default configs for these fields." 62 | ) 63 | 64 | 65 | def get_global_conv_config(): 66 | global _global_conv_config 67 | return _global_conv_config 68 | 69 | 70 | def set_global_conv_config(conv_config): 71 | global _global_conv_config 72 | keys_check(conv_config) 73 | _global_conv_config = conv_config 74 | 75 | 76 | def clear_global_conv_config(): 77 | global _global_conv_config 78 | _global_conv_config = None 79 | 80 | 81 | def get_default_conv_config( 82 | conv_mode: ConvMode = ConvMode.mode0, training: bool = False 83 | ): 84 | config = _default_conv_config 85 | # if training: 86 | # config.ifsort = True 87 | if conv_mode == ConvMode.mode0: 88 | pass 89 | elif conv_mode == ConvMode.mode1: 90 | config.ifsort = True 91 | elif conv_mode == ConvMode.mode2: 92 | config.ifsort = True 93 | config.split_mask_num = 3 94 | return config 95 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/conv/conv_mode.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | _global_kmap_mode = "hashmap_on_the_fly" # or "hashmap" 4 | _global_downsample_mode = "spconv" # or "minkowski" 5 | 6 | 7 | def get_kmap_mode(): 8 | global _global_kmap_mode 9 | return _global_kmap_mode 10 | 11 | 12 | def set_kmap_mode(kmap_mode: str): 13 | global _global_kmap_mode 14 | if kmap_mode in ["hashmap_on_the_fly", "hashmap"]: 15 | _global_kmap_mode = kmap_mode 16 | else: 17 | assert ( 18 | 0 19 | ), f'Unsupport kmap_mode: {kmap_mode}. Please set kmap_mode to "hashmap_on_the_fly" or "hashmap".' 20 | 21 | 22 | def get_downsample_mode(): 23 | global _global_downsample_mode 24 | return _global_downsample_mode 25 | 26 | 27 | def set_downsample_mode(downsample_mode: str): 28 | global _global_downsample_mode 29 | if downsample_mode in ["spconv", "minkowski"]: 30 | _global_downsample_mode = downsample_mode 31 | else: 32 | assert ( 33 | 0 34 | ), f'Unsupport downsample_mode {downsample_mode}. Please set downsample_mode to "spconv" or "minkowski".' 
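# A minimal usage sketch (illustrative; it assumes these setters are re-exported
# through torchsparse.nn.functional by the wildcard imports in the package
# __init__ files):
#
#   import torchsparse.nn.functional as spf
#
#   spf.set_kmap_mode("hashmap_on_the_fly")  # module default; or "hashmap"
#   spf.set_downsample_mode("spconv")        # or "minkowski"
#   spf.set_conv_mode(1)                     # ConvMode.mode1: ifsort enabled in the forward pass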
35 | 36 | 37 | class ConvMode(Enum): 38 | mode0 = 0 # split=0 fwd & split=3 bwd 39 | mode1 = 1 # split=1 fwd & split=3 bwd 40 | mode2 = 2 # split=3 fwd & split=3 bwd 41 | 42 | 43 | _global_conv_mode = ConvMode.mode0 44 | 45 | 46 | def get_conv_mode(): 47 | global _global_conv_mode 48 | return _global_conv_mode 49 | 50 | 51 | def set_conv_mode(conv_mode): 52 | global _global_conv_mode 53 | if isinstance(conv_mode, int): 54 | if conv_mode == 0: 55 | _global_conv_mode = ConvMode.mode0 56 | elif conv_mode == 1: 57 | _global_conv_mode = ConvMode.mode1 58 | elif conv_mode == 2: 59 | _global_conv_mode = ConvMode.mode2 60 | else: 61 | assert 0, f"Undefined conv_mode:{conv_mode}" 62 | elif isinstance(conv_mode, ConvMode): 63 | _global_conv_mode = conv_mode 64 | else: 65 | assert 0, f"Unsupport conv_mode input type" 66 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/conv/func/__init__.py: -------------------------------------------------------------------------------- 1 | from .gather_scatter import * 2 | from .implicit_gemm import * 3 | from .fetch_on_demand import * 4 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/conv/func/fetch_on_demand.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | from torch.autograd import Function 5 | 6 | # from torch.cuda.amp import custom_bwd, custom_fwd 7 | 8 | import torchsparse 9 | import torchsparse.backend 10 | 11 | # TODO: Fetch_on_demand do not have backward kernels now. 12 | # Using Gather_Scatter for backward propogation. 13 | 14 | __all__ = ["FetchOnDemandConvolutionFuntion"] 15 | 16 | 17 | class FetchOnDemandConvolutionFuntion(Function): 18 | @staticmethod 19 | # @custom_fwd(cast_inputs=torch.half) 20 | def forward( 21 | ctx, 22 | input: torch.Tensor, 23 | weight: torch.Tensor, 24 | kmap: Dict, 25 | config: Dict, 26 | transposed: bool = False, 27 | ) -> torch.Tensor: 28 | 29 | """if transposed: 30 | input_nbmaps = kmap["nbmaps"][1, :] 31 | output_nbmaps = kmap["nbmaps"][0, :] 32 | else: 33 | input_nbmaps = kmap["nbmaps"][0, :] 34 | output_nbmaps = kmap["nbmaps"][1, :] 35 | 36 | M = nbmaps.size(0) 37 | nbmaps_t = torch.zeros((2, M), 38 | dtype=torch.int, device=input.device, requires_grad=False) 39 | for l in range(M): 40 | nbmaps_t[0, l] = nbmaps[l, 0] 41 | nbmaps_t[1, l] = nbmaps[l, 1]""" 42 | 43 | nbmaps = kmap["nbmaps"] 44 | nbsizes = kmap["nbsizes"] 45 | nbaddrs = kmap["nbaddrs"] 46 | qnbaddrs = kmap["qnbaddrs"] 47 | sizes = kmap["sizes"] 48 | qmapsize = kmap["qmapsize"] 49 | 50 | mapsize = nbmaps.size(1) 51 | 52 | input = input.contiguous() 53 | weight = weight.contiguous() 54 | 55 | # nbmaps = nbmaps.int().contiguous() 56 | # input_nbmaps = input_nbmaps.int().contiguous() 57 | # output_nbmaps = output_nbmaps.int().contiguous() 58 | # nbaddrs = nbaddrs.int().contiguous() 59 | # qnbaddrs = qnbaddrs.int().contiguous() 60 | # nbsizes = nbsizes.int().contiguous() 61 | 62 | if not input.device.type == "cuda": 63 | if not transposed: 64 | output = torch.zeros( 65 | sizes[1], weight.size(-1), dtype=input.dtype, device=input.device 66 | ) 67 | else: 68 | # TODO(Haotian): ensure the original, upsampled size to be the same. 
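# Note: kmap["sizes"] holds the (input, output) point counts for this kernel map,
# so the regular path above allocates sizes[1] output rows while this transposed
# branch writes back to the resolution of the matching downsampling layer and
# therefore allocates sizes[0] rows.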
69 | output = torch.zeros( 70 | sizes[0], weight.size(-1), dtype=input.dtype, device=input.device 71 | ) 72 | 73 | if input.device.type == "cuda": 74 | if torch.float16 in [input.dtype, weight.dtype]: 75 | input = input.to(torch.float16) 76 | weight = weight.to(torch.float16) 77 | 78 | if config["FOD_fusion"] == True: 79 | output = torchsparse.backend.conv_forward_fetch_on_demand_cuda( 80 | input, 81 | weight, 82 | nbmaps, 83 | mapsize, 84 | nbaddrs, 85 | qnbaddrs, 86 | sizes[1] if not transposed else sizes[0], 87 | qmapsize, 88 | transposed, 89 | torchsparse.backends.allow_tf32, 90 | torchsparse.backends.allow_fp16, 91 | ) 92 | else: 93 | output = ( 94 | torchsparse.backend.conv_forward_fetch_on_demand_no_fusion_cuda( 95 | input, 96 | weight, 97 | nbmaps, 98 | nbsizes.cpu(), 99 | mapsize, 100 | sizes[1] if not transposed else sizes[0], 101 | transposed, 102 | torchsparse.backends.allow_tf32, 103 | torchsparse.backends.allow_fp16, 104 | ) 105 | ) 106 | 107 | else: 108 | raise NotImplementedError 109 | 110 | ctx.for_backwards = (input, weight, nbmaps, nbsizes, transposed) 111 | return output.to(weight.dtype) 112 | 113 | @staticmethod 114 | # @custom_bwd 115 | def backward(ctx, grad_output: torch.Tensor): 116 | input, weight, nbmaps, nbsizes, transposed = ctx.for_backwards 117 | 118 | if grad_output.dtype != weight.dtype: 119 | grad_output = grad_output.to(weight.dtype) 120 | 121 | print( 122 | "[Warning] Fetch_On_Demand does not have backward kernels now. Use Gather-Scatter for backward." 123 | ) 124 | grad_input = torch.zeros_like(input) 125 | grad_weight = torch.zeros_like(weight) 126 | 127 | if grad_output.device.type == "cuda": 128 | torchsparse.backend.conv_backward_gather_scatter_cuda( 129 | input, 130 | grad_input, 131 | grad_output.contiguous(), 132 | weight, 133 | grad_weight, 134 | nbmaps, 135 | nbsizes.cpu(), 136 | transposed, 137 | ) 138 | elif grad_output.device.type == "cpu": 139 | torchsparse.backend.conv_backward_gather_scatter_cpu( 140 | input, 141 | grad_input, 142 | grad_output.contiguous(), 143 | weight, 144 | grad_weight, 145 | nbmaps, 146 | nbsizes.cpu(), 147 | transposed, 148 | ) 149 | else: 150 | raise NotImplementedError 151 | return (grad_input, grad_weight, None, None, None, None) 152 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/conv/func/gather_scatter.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | from torch.autograd import Function 5 | 6 | # from torch.cuda.amp import custom_bwd, custom_fwd 7 | 8 | import torchsparse 9 | import torchsparse.backend 10 | 11 | buffer = torch.Tensor() 12 | 13 | __all__ = ["GatherScatterConvolutionFuntion"] 14 | 15 | 16 | class GatherScatterConvolutionFuntion(Function): # TorchSparse_v2 17 | @staticmethod 18 | # @custom_fwd(cast_inputs=torch.half) 19 | def forward( 20 | ctx, 21 | input: torch.Tensor, 22 | weight: torch.Tensor, 23 | kmap: Dict, 24 | config: Dict, 25 | transposed: bool = False, 26 | ) -> torch.Tensor: 27 | nbmaps = kmap["nbmaps"] 28 | nbsizes = kmap["nbsizes"].cpu() 29 | sizes = kmap["sizes"] 30 | input_mask = kmap["input_mask"] 31 | output_mask = kmap["output_mask"] 32 | epsilon = config["epsilon"] 33 | mm_thresh = config["mm_thresh"] 34 | 35 | conv_mode = 0 36 | global buffer 37 | if torchsparse.backends.benchmark: # type: ignore 38 | conv_mode = 1 if (epsilon == 0.0 and mm_thresh == 0) else 2 39 | if buffer.shape[0] == 0 or buffer.dtype != input.dtype: 40 | 
buffer = torch.zeros( 41 | 4000000 * 64, 42 | dtype=input.dtype, 43 | device=input.device, 44 | requires_grad=False, 45 | ) 46 | 47 | input = input.contiguous() 48 | weight = weight.contiguous() 49 | nbmaps = nbmaps.int().contiguous() 50 | nbsizes = nbsizes.int().contiguous() 51 | 52 | if not input.device.type == "cuda": 53 | if not transposed: 54 | output = torch.zeros( 55 | sizes[1], weight.size(-1), dtype=input.dtype, device=input.device 56 | ) 57 | else: 58 | # TODO(Haotian): ensure the original, upsampled size to be the same. 59 | output = torch.zeros( 60 | sizes[0], weight.size(-1), dtype=input.dtype, device=input.device 61 | ) 62 | 63 | if input.device.type == "cuda": 64 | if torch.float16 in [input.dtype, weight.dtype]: 65 | input = input.to(torch.float16) 66 | weight = weight.to(torch.float16) 67 | 68 | output = torchsparse.backend.conv_forward_gather_scatter_cuda( 69 | input, 70 | weight, 71 | nbmaps, 72 | nbsizes.cpu(), 73 | input_mask, 74 | output_mask, 75 | sizes[1] if not transposed else sizes[0], 76 | epsilon, 77 | int(mm_thresh), 78 | conv_mode, 79 | transposed, 80 | buffer, 81 | ) 82 | elif input.device.type == "cpu": 83 | torchsparse.backend.conv_forward_gather_scatter_cpu( 84 | input, output, weight, nbmaps, nbsizes.cpu(), transposed 85 | ) 86 | else: 87 | # use the native pytorch XLA APIs for the TPU. 88 | cur_st = 0 89 | for kernel_idx in range(weight.shape[0]): 90 | cur_ed = cur_st + nbsizes[kernel_idx] 91 | in_map = nbmaps[cur_st:cur_ed, 0].long() 92 | out_map = nbmaps[cur_st:cur_ed, 1].long() 93 | cur_st += nbsizes[kernel_idx] 94 | 95 | if transposed: 96 | in_map, out_map = out_map, in_map 97 | 98 | cur_feat = input[in_map] 99 | cur_feat = torch.mm(cur_feat, weight[kernel_idx]) 100 | output[out_map] += cur_feat 101 | ctx.for_backwards = (input, weight, nbmaps, nbsizes, transposed) 102 | return output.to(weight.dtype) 103 | 104 | @staticmethod 105 | # @custom_bwd 106 | def backward(ctx, grad_output: torch.Tensor): 107 | input, weight, nbmaps, nbsizes, transposed = ctx.for_backwards 108 | 109 | if grad_output.dtype != weight.dtype: 110 | grad_output = grad_output.to(weight.dtype) 111 | 112 | grad_input = torch.zeros_like(input) 113 | grad_weight = torch.zeros_like(weight) 114 | 115 | if grad_output.device.type == "cuda": 116 | torchsparse.backend.conv_backward_gather_scatter_cuda( 117 | input, 118 | grad_input, 119 | grad_output.contiguous(), 120 | weight, 121 | grad_weight, 122 | nbmaps, 123 | nbsizes.cpu(), 124 | transposed, 125 | ) 126 | elif grad_output.device.type == "cpu": 127 | torchsparse.backend.conv_backward_gather_scatter_cpu( 128 | input, 129 | grad_input, 130 | grad_output.contiguous(), 131 | weight, 132 | grad_weight, 133 | nbmaps, 134 | nbsizes.cpu(), 135 | transposed, 136 | ) 137 | else: 138 | raise NotImplementedError 139 | return ( 140 | grad_input, 141 | grad_weight, 142 | None, 143 | None, 144 | None, 145 | ) 146 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/conv/hash/__init__.py: -------------------------------------------------------------------------------- 1 | from .hash import * 2 | from .query import * 3 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/conv/hash/hash.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | import torchsparse.backend 6 | 7 | __all__ = ["sphash"] 8 | 9 | 10 | def sphash( 11 | coords: torch.Tensor, 
offsets: Optional[torch.Tensor] = None 12 | ) -> torch.Tensor: 13 | assert coords.dtype == torch.int, coords.dtype 14 | assert coords.ndim == 2 and coords.shape[1] == 4, coords.shape 15 | coords = coords.contiguous() 16 | 17 | # TODO(Zhijian): We might be able to merge `hash_kernel` and `hash`. 18 | if offsets is None: 19 | if coords.device.type == "cuda": 20 | return torchsparse.backend.hash_cuda(coords) 21 | elif coords.device.type == "cpu": 22 | return torchsparse.backend.hash_cpu(coords) 23 | else: 24 | device = coords.device 25 | return torchsparse.backend.hash_cpu(coords.cpu()).to(device) 26 | else: 27 | assert offsets.dtype == torch.int, offsets.dtype 28 | assert offsets.ndim == 2 and offsets.shape[1] == 3, offsets.shape 29 | offsets = offsets.contiguous() 30 | 31 | if coords.device.type == "cuda": 32 | return torchsparse.backend.kernel_hash_cuda(coords, offsets) 33 | elif coords.device.type == "cpu": 34 | return torchsparse.backend.kernel_hash_cpu(coords, offsets) 35 | else: 36 | device = coords.device 37 | return torchsparse.backend.kernel_hash_cpu(coords.cpu(), offsets.cpu()).to( 38 | device 39 | ) 40 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/conv/hash/query.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torchsparse.backend 4 | 5 | __all__ = ["sphashquery", "convert_transposed_out_in_map"] 6 | 7 | 8 | def sphashquery(queries: torch.Tensor, references: torch.Tensor) -> torch.Tensor: 9 | queries = queries.contiguous() 10 | references = references.contiguous() 11 | 12 | sizes = queries.size() 13 | queries = queries.view(-1) 14 | 15 | indices = torch.arange(len(references), device=queries.device, dtype=torch.long) 16 | 17 | if queries.device.type == "cuda": 18 | hashtable = torchsparse.backend.GPUHashTable(references.shape[0] * 2) 19 | hashtable.insert_vals(references) 20 | output = hashtable.lookup_vals(queries) 21 | elif queries.device.type == "cpu": 22 | output = torchsparse.backend.hash_query_cpu(queries, references, indices) 23 | else: 24 | device = queries.device 25 | output = torchsparse.backend.hash_query_cpu( 26 | queries.cpu(), references.cpu(), indices.cpu() 27 | ).to(device) 28 | 29 | output = (output - 1).view(*sizes) 30 | if output.shape[0] % 128 != 0: 31 | output = torch.cat( 32 | [ 33 | output, 34 | torch.zeros( 35 | 128 - output.shape[0] % 128, 36 | output.shape[1], 37 | device=output.device, 38 | dtype=output.dtype, 39 | ) 40 | - 1, 41 | ], 42 | dim=0, 43 | ) 44 | return output 45 | 46 | 47 | def convert_transposed_out_in_map(out_in_map, size): 48 | out_in_map_t = torch.full( 49 | (size, out_in_map.shape[1]), 50 | fill_value=-1, 51 | device=out_in_map.device, 52 | dtype=torch.int32, 53 | ) 54 | torchsparse.backend.convert_transposed_out_in_map(out_in_map, out_in_map_t) 55 | return out_in_map_t 56 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/conv/kmap/__init__.py: -------------------------------------------------------------------------------- 1 | from .build_kmap import * 2 | from .downsample import * 3 | from .upsample import * 4 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/conv/kmap/downsample.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union, Optional 2 | 3 | import torch 4 | 5 | import torchsparse.backend 6 | from 
torchsparse.utils import make_ntuple, make_tensor 7 | 8 | __all__ = ["spdownsample"] 9 | 10 | 11 | def spdownsample( 12 | _coords: torch.Tensor, 13 | stride: Union[int, Tuple[int, ...]] = 2, 14 | kernel_size: Union[int, Tuple[int, ...]] = 2, 15 | padding: torch.Tensor = 0, 16 | spatial_range: Optional[Tuple[int]] = None, 17 | downsample_mode: str = "spconv", 18 | ) -> torch.Tensor: 19 | assert downsample_mode in ["spconv", "minkowski"] 20 | 21 | stride = make_ntuple(stride, ndim=3) 22 | kernel_size = make_ntuple(kernel_size, ndim=3) 23 | padding = make_ntuple(padding, ndim=3) 24 | 25 | sample_stride = tuple([stride[k] for k in range(3)]) 26 | sample_stride = make_tensor( 27 | sample_stride, dtype=torch.int, device=_coords.device 28 | ).unsqueeze(dim=0) 29 | 30 | if ( 31 | all(stride[k] in [1, kernel_size[k]] for k in range(3)) 32 | or downsample_mode == "minkowski" 33 | ): 34 | coords = _coords.clone() 35 | coords[:, 1:] = torch.div(coords[:, 1:], sample_stride.float()).floor() 36 | coords = torch.unique(coords, dim=0) 37 | return coords 38 | else: 39 | if _coords.device.type == "cuda": 40 | _coords = _coords.contiguous() 41 | 42 | padding_t = make_tensor(padding, dtype=torch.int, device=_coords.device) 43 | kernel_size_t = make_tensor( 44 | kernel_size, dtype=torch.int, device=_coords.device 45 | ) 46 | stride_t = make_tensor(stride, dtype=torch.int, device=_coords.device) 47 | 48 | if spatial_range is not None: 49 | coords_max_tuple = tuple(x - 1 for x in spatial_range) 50 | coords_max = make_tensor( 51 | coords_max_tuple, dtype=torch.int, device=_coords.device 52 | ) 53 | else: 54 | coords_max = _coords.max(0).values 55 | coords_max[1:] = ( 56 | coords_max[1:] + 2 * padding_t - (kernel_size_t - 1) 57 | ) // stride_t 58 | 59 | if torchsparse.tensor.get_allow_negative_coordinates(): 60 | coords_min = _coords.min(0).values 61 | coords_min[1:] = torch.div( 62 | coords_min[1:] - 2 * padding_t + (kernel_size_t - 1), stride_t 63 | ) 64 | else: 65 | coords_min = make_tensor( 66 | (0, 0, 0, 0), dtype=torch.int, device=_coords.device 67 | ) 68 | 69 | out_coords = torchsparse.backend.downsample_cuda( 70 | _coords, 71 | coords_max, 72 | coords_min, 73 | kernel_size_t, 74 | stride_t, 75 | padding_t, 76 | ) 77 | return out_coords 78 | else: 79 | raise NotImplementedError 80 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/conv/kmap/func/__init__.py: -------------------------------------------------------------------------------- 1 | from .hashmap import * 2 | from .hashmap_on_the_fly import * 3 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/conv/kmap/upsample.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union, Optional 2 | 3 | import torch 4 | 5 | import torchsparse.backend 6 | from torchsparse.utils import make_ntuple, make_tensor 7 | from torchsparse.nn.utils.kernel import get_kernel_offsets 8 | 9 | __all__ = ["spupsample_generative"] 10 | 11 | 12 | def spupsample_generative( 13 | _coords: torch.Tensor, 14 | stride: Union[int, Tuple[int, ...]] = 2, 15 | kernel_size: Union[int, Tuple[int, ...]] = 2, 16 | padding: torch.Tensor = 0, 17 | spatial_range: Optional[Tuple[int]] = None, 18 | ) -> torch.Tensor: 19 | stride = make_ntuple(stride, ndim=3) 20 | kernel_size = make_ntuple(kernel_size, ndim=3) 21 | padding = make_ntuple(padding, ndim=3) 22 | sample_stride = make_tensor( 23 | stride, dtype=torch.int, 
device=_coords.device 24 | ).unsqueeze(0) 25 | # stride and dilation are both 1 26 | kernel_offsets = get_kernel_offsets(kernel_size, 1, 1, device=_coords.device) 27 | coords = _coords.clone() 28 | coords[:, 1:] *= sample_stride 29 | coords = coords.unsqueeze(1).repeat(1, kernel_offsets.size(0), 1) 30 | coords[:, :, 1:] = coords[:, :, 1:] + kernel_offsets.unsqueeze(0) 31 | assert ( 32 | spatial_range is not None 33 | ), "spatial range must be specified in generative mode" 34 | for i in range(1, coords.size(-1)): 35 | coords[:, :, i].clamp_(min=0, max=spatial_range[i] - 1) 36 | coords = coords.reshape(-1, coords.size(-1)) 37 | coords = torch.unique(coords, dim=0) 38 | return coords 39 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/conv/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .collections import * 2 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/conv/utils/compat.py: -------------------------------------------------------------------------------- 1 | # Adapted from python-attributedict 2 | # https://github.com/grimen/python-attributedict/blob/master/attributedict/compat.py 3 | 4 | # Copyright (c) 2018 Jonas Grimfelt <grimen@gmail.com> 5 | 6 | # Permission is hereby granted, free of charge, to any person obtaining a copy 7 | # of this software and associated documentation files (the "Software"), to deal 8 | # in the Software without restriction, including without limitation the rights 9 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | # copies of the Software, and to permit persons to whom the Software is 11 | # furnished to do so, subject to the following conditions: 12 | 13 | # The above copyright notice and this permission notice shall be included in all 14 | # copies or substantial portions of the Software. 15 | 16 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | # SOFTWARE. 
23 | 24 | # ========================================= 25 | # DEPS 26 | # -------------------------------------- 27 | 28 | import rootpath 29 | 30 | # @see https://github.com/benjaminp/six/blob/master/six.py 31 | 32 | import sys 33 | import types 34 | 35 | 36 | # ========================================= 37 | # CONSTANTS 38 | # -------------------------------------- 39 | 40 | PY2 = sys.version_info[0] == 2 41 | PY3 = sys.version_info[0] == 3 42 | PY34 = sys.version_info[0:2] >= (3, 4) 43 | 44 | 45 | # ========================================= 46 | # VARIABLES 47 | # -------------------------------------- 48 | 49 | if PY3: 50 | string_types = (str,) 51 | integer_types = (int,) 52 | class_types = (type,) 53 | text_type = str 54 | binary_type = bytes 55 | 56 | else: 57 | string_types = (basestring,) 58 | integer_types = (int, long) 59 | class_types = (type, types.ClassType) 60 | text_type = unicode 61 | binary_type = str 62 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/count.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torchsparse.backend 4 | 5 | __all__ = ["spcount"] 6 | 7 | 8 | def spcount(coords: torch.Tensor, num: torch.Tensor) -> torch.Tensor: 9 | coords = coords.contiguous() 10 | if coords.device.type == "cuda": 11 | return torchsparse.backend.count_cuda(coords, num) 12 | elif coords.device.type == "cpu": 13 | return torchsparse.backend.count_cpu(coords, num) 14 | else: 15 | device = coords.device 16 | return torchsparse.backend.count_cpu(coords.cpu(), num).to(device) 17 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/crop.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | 5 | from torchsparse import SparseTensor 6 | 7 | __all__ = ["spcrop"] 8 | 9 | 10 | def spcrop( 11 | input: SparseTensor, 12 | coords_min: Optional[Tuple[int, ...]] = None, 13 | coords_max: Optional[Tuple[int, ...]] = None, 14 | ) -> SparseTensor: 15 | coords, feats, stride = input.coords, input.feats, input.stride 16 | 17 | mask = torch.ones((coords.shape[0], 3), dtype=torch.bool, device=coords.device) 18 | if coords_min is not None: 19 | coords_min = torch.tensor( 20 | coords_min, dtype=torch.int, device=coords.device 21 | ).unsqueeze(dim=0) 22 | mask &= coords[:, :3] >= coords_min 23 | if coords_max is not None: 24 | coords_max = torch.tensor( 25 | coords_max, dtype=torch.int, device=coords.device 26 | ).unsqueeze(dim=0) 27 | # Using "<" instead of "<=" is for backward compatibility (with 28 | # some existing detection codebases). We might need to reflect this 29 | # in the document or change it back to "<=" in the future. 
30 | mask &= coords[:, :3] < coords_max 31 | 32 | mask = torch.all(mask, dim=1) 33 | coords, feats = coords[mask], feats[mask] 34 | output = SparseTensor(coords=coords, feats=feats, stride=stride) 35 | return output 36 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/devoxelize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | # from torch.cuda.amp import custom_bwd, custom_fwd 5 | 6 | import torchsparse.backend 7 | 8 | __all__ = ["spdevoxelize", "calc_ti_weights"] 9 | 10 | 11 | def calc_ti_weights( 12 | coords: torch.Tensor, idx_query: torch.Tensor, scale: float = 1 13 | ) -> torch.Tensor: 14 | with torch.no_grad(): 15 | p = coords 16 | if scale != 1: 17 | pf = torch.floor(coords / scale) * scale 18 | else: 19 | pf = torch.floor(coords) 20 | pc = pf + scale 21 | 22 | x = p[:, 0].view(-1, 1) 23 | y = p[:, 1].view(-1, 1) 24 | z = p[:, 2].view(-1, 1) 25 | 26 | xf = pf[:, 0].view(-1, 1).float() 27 | yf = pf[:, 1].view(-1, 1).float() 28 | zf = pf[:, 2].view(-1, 1).float() 29 | 30 | xc = pc[:, 0].view(-1, 1).float() 31 | yc = pc[:, 1].view(-1, 1).float() 32 | zc = pc[:, 2].view(-1, 1).float() 33 | 34 | w0 = (xc - x) * (yc - y) * (zc - z) 35 | w1 = (xc - x) * (yc - y) * (z - zf) 36 | w2 = (xc - x) * (y - yf) * (zc - z) 37 | w3 = (xc - x) * (y - yf) * (z - zf) 38 | w4 = (x - xf) * (yc - y) * (zc - z) 39 | w5 = (x - xf) * (yc - y) * (z - zf) 40 | w6 = (x - xf) * (y - yf) * (zc - z) 41 | w7 = (x - xf) * (y - yf) * (z - zf) 42 | 43 | w = torch.cat([w0, w1, w2, w3, w4, w5, w6, w7], dim=1) 44 | # w = w.transpose(1, 0).contiguous() 45 | if scale != 1: 46 | w /= scale**3 47 | w[idx_query == -1] = 0 48 | w /= torch.sum(w, dim=1).unsqueeze(1) + 1e-8 49 | return w 50 | 51 | 52 | class DevoxelizeFunction(Function): 53 | @staticmethod 54 | # @custom_fwd(cast_inputs=torch.half) 55 | def forward( 56 | ctx, feats: torch.Tensor, coords: torch.Tensor, weights: torch.Tensor 57 | ) -> torch.Tensor: 58 | feats = feats.contiguous() 59 | coords = coords.contiguous().int() 60 | weights = weights.contiguous() 61 | 62 | if feats.device.type == "cuda": 63 | output = torchsparse.backend.devoxelize_forward_cuda(feats, coords, weights) 64 | elif feats.device.type == "cpu": 65 | output = torchsparse.backend.devoxelize_forward_cpu(feats, coords, weights) 66 | else: 67 | device = feats.device 68 | output = torchsparse.backend.devoxelize_forward_cpu( 69 | feats.cpu(), coords.cpu(), weights.cpu() 70 | ).to(device) 71 | 72 | ctx.for_backwards = (coords, weights, feats.shape[0]) 73 | return output.to(feats.dtype) 74 | 75 | @staticmethod 76 | # @custom_bwd 77 | def backward(ctx, grad_output: torch.Tensor): 78 | coords, weights, input_size = ctx.for_backwards 79 | grad_output = grad_output.contiguous() 80 | 81 | if grad_output.device.type == "cuda": 82 | grad_feats = torchsparse.backend.devoxelize_backward_cuda( 83 | grad_output, coords, weights, input_size 84 | ) 85 | elif grad_output.device.type == "cpu": 86 | grad_feats = torchsparse.backend.devoxelize_backward_cpu( 87 | grad_output, coords, weights, input_size 88 | ) 89 | else: 90 | device = grad_output.device 91 | grad_feats = torchsparse.backend.devoxelize_backward_cpu( 92 | grad_output.cpu(), coords.cpu(), weights.cpu(), input_size 93 | ).to(device) 94 | 95 | return grad_feats, None, None 96 | 97 | 98 | def spdevoxelize( 99 | feats: torch.Tensor, coords: torch.Tensor, weights: torch.Tensor 100 | ) -> torch.Tensor: 101 
| return DevoxelizeFunction.apply(feats, coords, weights) 102 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/hash.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | 5 | import torchsparse.backend 6 | 7 | __all__ = ["sphash"] 8 | 9 | 10 | def sphash( 11 | coords: torch.Tensor, offsets: Optional[torch.Tensor] = None 12 | ) -> torch.Tensor: 13 | assert coords.dtype == torch.int, coords.dtype 14 | assert coords.ndim == 2 and coords.shape[1] == 4, coords.shape 15 | coords = coords.contiguous() 16 | 17 | # TODO(Zhijian): We might be able to merge `hash_kernel` and `hash`. 18 | if offsets is None: 19 | if coords.device.type == "cuda": 20 | return torchsparse.backend.hash_cuda(coords) 21 | elif coords.device.type == "cpu": 22 | return torchsparse.backend.hash_cpu(coords) 23 | else: 24 | device = coords.device 25 | return torchsparse.backend.hash_cpu(coords.cpu()).to(device) 26 | else: 27 | assert offsets.dtype == torch.int, offsets.dtype 28 | assert offsets.ndim == 2 and offsets.shape[1] == 3, offsets.shape 29 | offsets = offsets.contiguous() 30 | 31 | if coords.device.type == "cuda": 32 | return torchsparse.backend.kernel_hash_cuda(coords, offsets) 33 | elif coords.device.type == "cpu": 34 | return torchsparse.backend.kernel_hash_cpu(coords, offsets) 35 | else: 36 | device = coords.device 37 | return torchsparse.backend.kernel_hash_cpu(coords.cpu(), offsets.cpu()).to( 38 | device 39 | ) 40 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from torchsparse import SparseTensor 4 | 5 | __all__ = ["global_avg_pool", "global_max_pool"] 6 | 7 | 8 | def global_avg_pool(inputs: SparseTensor) -> torch.Tensor: 9 | batch_size = torch.max(inputs.coords[:, 0]).item() + 1 10 | outputs = [] 11 | for k in range(batch_size): 12 | input = inputs.feats[inputs.coords[:, 0] == k] 13 | output = torch.mean(input, dim=0) 14 | outputs.append(output) 15 | outputs = torch.stack(outputs, dim=0) 16 | return outputs 17 | 18 | 19 | def global_max_pool(inputs: SparseTensor) -> torch.Tensor: 20 | batch_size = torch.max(inputs.coords[:, 0]).item() + 1 21 | outputs = [] 22 | for k in range(batch_size): 23 | input = inputs.feats[inputs.coords[:, 0] == k] 24 | output = torch.max(input, dim=0)[0] 25 | outputs.append(output) 26 | outputs = torch.stack(outputs, dim=0) 27 | return outputs 28 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/query.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import torchsparse.backend 4 | 5 | __all__ = ["sphashquery"] 6 | 7 | 8 | def sphashquery(queries: torch.Tensor, references: torch.Tensor) -> torch.Tensor: 9 | queries = queries.contiguous() 10 | references = references.contiguous() 11 | 12 | sizes = queries.size() 13 | queries = queries.view(-1) 14 | 15 | hashmap_keys = torch.zeros( 16 | 2 * references.shape[0], dtype=torch.int64, device=references.device 17 | ) 18 | hashmap_vals = torch.zeros( 19 | 2 * references.shape[0], dtype=torch.int32, device=references.device 20 | ) 21 | hashmap = torchsparse.backend.GPUHashTable(hashmap_keys, hashmap_vals) 22 | hashmap.insert_vals(references) 23 | 24 | if queries.device.type == "cuda": 25 | output = 
hashmap.lookup_vals(queries)[: queries.shape[0]] 26 | elif queries.device.type == "cpu": 27 | indices = torch.arange(len(references), device=queries.device, dtype=torch.long) 28 | output = torchsparse.backend.hash_query_cpu(queries, references, indices) 29 | else: 30 | device = queries.device 31 | indices = torch.arange(len(references), device=queries.device, dtype=torch.long) 32 | output = torchsparse.backend.hash_query_cpu( 33 | queries.cpu(), references.cpu(), indices.cpu() 34 | ).to(device) 35 | 36 | output = (output - 1).view(*sizes) 37 | return output 38 | -------------------------------------------------------------------------------- /torchsparse/nn/functional/voxelize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | # from torch.cuda.amp import custom_bwd, custom_fwd 5 | 6 | import torchsparse.backend 7 | 8 | __all__ = ["spvoxelize"] 9 | 10 | 11 | class VoxelizeFunction(Function): 12 | @staticmethod 13 | # @custom_fwd(cast_inputs=torch.half) 14 | def forward( 15 | ctx, feats: torch.Tensor, coords: torch.Tensor, counts: torch.Tensor 16 | ) -> torch.Tensor: 17 | feats = feats.contiguous() 18 | coords = coords.contiguous().int() 19 | 20 | if feats.device.type == "cuda": 21 | output = torchsparse.backend.voxelize_forward_cuda(feats, coords, counts) 22 | elif feats.device.type == "cpu": 23 | output = torchsparse.backend.voxelize_forward_cpu(feats, coords, counts) 24 | else: 25 | device = feats.device 26 | output = torchsparse.backend.voxelize_forward_cpu( 27 | feats.cpu(), coords.cpu(), counts.cpu() 28 | ).to(device) 29 | 30 | ctx.for_backwards = (coords, counts, feats.shape[0]) 31 | return output.to(feats.dtype) 32 | 33 | @staticmethod 34 | # @custom_bwd 35 | def backward(ctx, grad_output: torch.Tensor): 36 | coords, counts, input_size = ctx.for_backwards 37 | grad_output = grad_output.contiguous() 38 | 39 | if grad_output.device.type == "cuda": 40 | grad_feats = torchsparse.backend.voxelize_backward_cuda( 41 | grad_output, coords, counts, input_size 42 | ) 43 | elif grad_output.device.type == "cpu": 44 | grad_feats = torchsparse.backend.voxelize_backward_cpu( 45 | grad_output, coords, counts, input_size 46 | ) 47 | else: 48 | device = grad_output.device 49 | grad_feats = torchsparse.backend.voxelize_backward_cpu( 50 | grad_output.cpu(), coords.cpu(), counts.cpu(), input_size 51 | ).to(device) 52 | 53 | return grad_feats, None, None 54 | 55 | 56 | def spvoxelize( 57 | feats: torch.Tensor, coords: torch.Tensor, counts: torch.Tensor 58 | ) -> torch.Tensor: 59 | return VoxelizeFunction.apply(feats, coords, counts) 60 | -------------------------------------------------------------------------------- /torchsparse/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .activation import * 2 | from .bev import * 3 | from .conv import * 4 | from .crop import * 5 | from .norm import * 6 | from .pooling import * 7 | -------------------------------------------------------------------------------- /torchsparse/nn/modules/activation.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from torchsparse import SparseTensor 4 | from torchsparse.nn.utils import fapply 5 | 6 | __all__ = ["ReLU", "LeakyReLU", "SiLU"] 7 | 8 | 9 | class ReLU(nn.ReLU): 10 | def forward(self, input: SparseTensor) -> SparseTensor: 11 | return fapply(input, super().forward) 12 | 13 | 14 | class 
LeakyReLU(nn.LeakyReLU): 15 | def forward(self, input: SparseTensor) -> SparseTensor: 16 | return fapply(input, super().forward) 17 | 18 | 19 | class SiLU(nn.SiLU): 20 | def forward(self, input: SparseTensor) -> SparseTensor: 21 | return fapply(input, super().forward) 22 | -------------------------------------------------------------------------------- /torchsparse/nn/modules/conv.py: -------------------------------------------------------------------------------- 1 | import math 2 | import sys 3 | from typing import Dict, List, Tuple, Union 4 | 5 | if sys.version_info >= (3, 8): 6 | from functools import cached_property 7 | else: 8 | from backports.cached_property import cached_property 9 | 10 | import numpy as np 11 | import torch 12 | from torch import nn 13 | 14 | import torchsparse 15 | from torchsparse import SparseTensor 16 | from torchsparse.nn import functional as F 17 | from torchsparse.utils import make_ntuple 18 | 19 | __all__ = ["Conv3d"] 20 | 21 | 22 | class Conv3d(nn.Module): 23 | def __init__( 24 | self, 25 | in_channels: int, 26 | out_channels: int, 27 | kernel_size: Union[int, List[int], Tuple[int, ...]] = 3, 28 | stride: Union[int, List[int], Tuple[int, ...]] = 1, 29 | padding: Union[int, Tuple[int, ...]] = 0, 30 | dilation: int = 1, 31 | bias: bool = False, 32 | transposed: bool = False, 33 | generative: bool = False, 34 | config: Dict = None, 35 | ) -> None: 36 | super().__init__() 37 | self.in_channels = in_channels 38 | self.out_channels = out_channels 39 | self.kernel_size = make_ntuple(kernel_size, ndim=3) 40 | self.stride = make_ntuple(stride, ndim=3) 41 | self.dilation = dilation 42 | _padding = make_ntuple(padding, 3) 43 | self.padding = () 44 | for i in range(3): 45 | if self.kernel_size[i] % 2 == 1 and self.stride[i] == 1: 46 | self.padding += ((self.kernel_size[i] - 1) // 2,) 47 | else: 48 | self.padding += (_padding[i],) 49 | self.transposed = transposed 50 | self.generative = generative 51 | if self.generative: 52 | assert self.transposed 53 | 54 | self._config = config 55 | 56 | self.kernel_volume = int(np.prod(self.kernel_size)) 57 | if ( 58 | self.kernel_volume > 1 59 | or self.kernel_volume == 1 60 | and self.stride != (1, 1, 1) 61 | ): 62 | self.kernel = nn.Parameter( 63 | torch.zeros(self.kernel_volume, in_channels, out_channels) 64 | ) 65 | else: 66 | self.kernel = nn.Parameter(torch.zeros(in_channels, out_channels)) 67 | if bias: 68 | self.bias = nn.Parameter(torch.Tensor(out_channels)) 69 | else: 70 | self.register_parameter("bias", None) 71 | self.reset_parameters() 72 | 73 | def extra_repr(self) -> str: 74 | s = "{in_channels}, {out_channels}, kernel_size={kernel_size}" 75 | if self.stride != (1,) * len(self.stride): 76 | s += ", stride={stride}" 77 | if self.dilation != 1: 78 | s += ", dilation={dilation}" 79 | if self.bias is None: 80 | s += ", bias=False" 81 | if self.transposed: 82 | s += ", transposed=True" 83 | if self.generative: 84 | s += ", generative=True" 85 | return s.format(**self.__dict__) 86 | 87 | def reset_parameters(self) -> None: 88 | std = 1 / math.sqrt( 89 | (self.out_channels if self.transposed else self.in_channels) 90 | * self.kernel_volume 91 | ) 92 | self.kernel.data.uniform_(-std, std) 93 | if self.bias is not None: 94 | self.bias.data.uniform_(-std, std) 95 | 96 | def forward(self, input: SparseTensor) -> SparseTensor: 97 | 98 | return F.conv3d( 99 | input, 100 | weight=self.kernel, 101 | kernel_size=self.kernel_size, 102 | bias=self.bias, 103 | stride=self.stride, 104 | padding=self.padding, 105 | 
dilation=self.dilation, 106 | transposed=self.transposed, 107 | generative=self.generative, 108 | config=self._config, 109 | training=self.training, 110 | ) 111 | -------------------------------------------------------------------------------- /torchsparse/nn/modules/crop.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | from torch import nn 4 | 5 | from torchsparse import SparseTensor 6 | from torchsparse.nn import functional as F 7 | 8 | __all__ = ["SparseCrop"] 9 | 10 | 11 | class SparseCrop(nn.Module): 12 | def __init__( 13 | self, 14 | coords_min: Optional[Tuple[int, ...]] = None, 15 | coords_max: Optional[Tuple[int, ...]] = None, 16 | ) -> None: 17 | super().__init__() 18 | self.coords_min = coords_min 19 | self.coords_max = coords_max 20 | 21 | def forward(self, input: SparseTensor) -> SparseTensor: 22 | return F.spcrop(input, self.coords_min, self.coords_max) 23 | -------------------------------------------------------------------------------- /torchsparse/nn/modules/norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from torchsparse import SparseTensor 5 | from torchsparse.nn.utils import fapply 6 | 7 | __all__ = ["BatchNorm", "GroupNorm", "InstanceNorm"] 8 | 9 | 10 | class InstanceNorm(nn.InstanceNorm1d): 11 | def forward(self, input: SparseTensor) -> SparseTensor: 12 | return fapply(input, super().forward) 13 | 14 | 15 | class BatchNorm(nn.BatchNorm1d): 16 | def forward(self, input: SparseTensor) -> SparseTensor: 17 | return fapply(input, super().forward) 18 | 19 | 20 | class GroupNorm(nn.GroupNorm): 21 | def forward(self, input: SparseTensor) -> SparseTensor: 22 | coords, feats, stride = input.coords, input.feats, input.stride 23 | 24 | batch_size = torch.max(coords[:, 0]).item() + 1 25 | num_channels = feats.shape[1] 26 | 27 | # PyTorch's GroupNorm function expects the input to be in (N, C, *) 28 | # format where N is batch size, and C is number of channels. "feats" 29 | # is not in that format. So, we extract the feats corresponding to 30 | # each sample, bring it to the format expected by PyTorch's GroupNorm 31 | # function, and invoke it. 
32 | nfeats = torch.zeros_like(feats) 33 | for k in range(batch_size): 34 | indices = coords[:, 0] == k 35 | bfeats = feats[indices] 36 | bfeats = bfeats.transpose(0, 1).reshape(1, num_channels, -1) 37 | bfeats = super().forward(bfeats) 38 | bfeats = bfeats.reshape(num_channels, -1).transpose(0, 1) 39 | nfeats[indices] = bfeats 40 | 41 | output = SparseTensor( 42 | coords=coords, 43 | feats=nfeats, 44 | stride=stride, 45 | spatial_range=input.spatial_range, 46 | ) 47 | output._caches = input._caches 48 | return output 49 | -------------------------------------------------------------------------------- /torchsparse/nn/modules/pooling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from torchsparse import SparseTensor 5 | from torchsparse.nn import functional as F 6 | 7 | __all__ = ["GlobalAvgPool", "GlobalMaxPool"] 8 | 9 | 10 | class GlobalAvgPool(nn.Module): 11 | def forward(self, input: SparseTensor) -> torch.Tensor: 12 | return F.global_avg_pool(input) 13 | 14 | 15 | class GlobalMaxPool(nn.Module): 16 | def forward(self, input: SparseTensor) -> torch.Tensor: 17 | return F.global_max_pool(input) 18 | -------------------------------------------------------------------------------- /torchsparse/nn/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .apply import * 2 | from .kernel import * 3 | -------------------------------------------------------------------------------- /torchsparse/nn/utils/apply.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | 5 | from torchsparse import SparseTensor 6 | 7 | __all__ = ["fapply"] 8 | 9 | 10 | def fapply( 11 | input: SparseTensor, fn: Callable[..., torch.Tensor], *args, **kwargs 12 | ) -> SparseTensor: 13 | feats = fn(input.feats, *args, **kwargs) 14 | output = SparseTensor( 15 | coords=input.coords, 16 | feats=feats, 17 | stride=input.stride, 18 | spatial_range=input.spatial_range, 19 | ) 20 | output._caches = input._caches 21 | return output 22 | -------------------------------------------------------------------------------- /torchsparse/nn/utils/kernel.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from torchsparse.utils import make_ntuple, make_tensor 7 | 8 | __all__ = ["get_kernel_offsets"] 9 | 10 | 11 | def get_kernel_offsets( 12 | size: Union[int, Tuple[int, ...]], 13 | stride: Union[int, Tuple[int, ...]] = 1, 14 | dilation: Union[int, Tuple[int, ...]] = 1, 15 | device="cpu", 16 | ) -> torch.Tensor: 17 | size = make_ntuple(size, ndim=3) 18 | stride = make_ntuple(stride, ndim=3) 19 | dilation = make_ntuple(dilation, ndim=3) 20 | 21 | offsets = [ 22 | (np.arange(-size[k] // 2 + 1, size[k] // 2 + 1) * stride[k] * dilation[k]) 23 | for k in range(3) 24 | ] 25 | 26 | # This condition check is only to make sure that our weight layout is 27 | # compatible with `MinkowskiEngine`. 
28 | if np.prod(size) % 2 == 1: 29 | offsets = tuple( 30 | [(x, y, z) for z in offsets[2] for y in offsets[1] for x in offsets[0]] 31 | ) 32 | else: 33 | offsets = tuple( 34 | [(x, y, z) for x in offsets[0] for y in offsets[1] for z in offsets[2]] 35 | ) 36 | 37 | offsets = make_tensor(offsets, dtype=torch.int, device=device) 38 | return offsets 39 | -------------------------------------------------------------------------------- /torchsparse/operators.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | import torch 4 | 5 | from torchsparse.tensor import SparseTensor 6 | 7 | # from torch_scatter import scatter_sum 8 | 9 | __all__ = ["cat", "generative_add"] 10 | 11 | 12 | def cat(inputs: List[SparseTensor]) -> SparseTensor: 13 | feats = torch.cat([input.feats for input in inputs], dim=1) 14 | output = SparseTensor(coords=inputs[0].coords, feats=feats, stride=inputs[0].stride) 15 | output._caches = inputs[0]._caches 16 | return output 17 | 18 | 19 | def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): 20 | if dim < 0: 21 | dim = other.dim() + dim 22 | if src.dim() == 1: 23 | for _ in range(0, dim): 24 | src = src.unsqueeze(0) 25 | for _ in range(src.dim(), other.dim()): 26 | src = src.unsqueeze(-1) 27 | src = src.expand(other.size()) 28 | return src 29 | 30 | 31 | def scatter_sum( 32 | src: torch.Tensor, 33 | index: torch.Tensor, 34 | dim: int = -1, 35 | out: Optional[torch.Tensor] = None, 36 | dim_size: Optional[int] = None, 37 | ) -> torch.Tensor: 38 | index = broadcast(index, src, dim) 39 | if out is None: 40 | size = list(src.size()) 41 | if dim_size is not None: 42 | size[dim] = dim_size 43 | elif index.numel() == 0: 44 | size[dim] = 0 45 | else: 46 | size[dim] = int(index.max()) + 1 47 | out = torch.zeros(size, dtype=src.dtype, device=src.device) 48 | return out.scatter_add_(dim, index, src) 49 | else: 50 | return out.scatter_add_(dim, index, src) 51 | 52 | 53 | def generative_add(a: SparseTensor, b: SparseTensor) -> SparseTensor: 54 | input_a = a if a.F.size(0) >= b.F.size(0) else b 55 | input_b = b if a.F.size(0) >= b.F.size(0) else a 56 | union_coords = torch.cat([input_a.C, input_b.C], dim=0) 57 | union_features = torch.cat([input_a.F, input_b.F], dim=0) 58 | unique_coords, unique_idx = torch.unique(union_coords, dim=0, return_inverse=True) 59 | out_feature = scatter_sum(union_features, unique_idx, dim=0) 60 | out_tensor = SparseTensor( 61 | out_feature, unique_coords, input_a.s, spatial_range=input_a.spatial_range 62 | ) 63 | out_tensor._caches = input_a._caches 64 | return out_tensor 65 | -------------------------------------------------------------------------------- /torchsparse/tensor.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Tuple, Union, Optional, List 2 | 3 | import torch 4 | 5 | from torchsparse.utils import make_ntuple, to_dense 6 | from torchsparse.utils.tensor_cache import ( 7 | TensorCache, 8 | TensorCacheMode, 9 | get_global_tensor_cache, 10 | set_global_tensor_cache, 11 | get_tensor_cache_mode, 12 | ) 13 | 14 | __all__ = ["SparseTensor"] 15 | 16 | _allow_negative_coordinates = False 17 | 18 | 19 | def get_allow_negative_coordinates(): 20 | global _allow_negative_coordinates 21 | return _allow_negative_coordinates 22 | 23 | 24 | def set_allow_negative_coordinates(allow_negative_coordinates): 25 | global _allow_negative_coordinates 26 | _allow_negative_coordinates = allow_negative_coordinates 27 | 28 | 
29 | class SparseTensor: 30 | def __init__( 31 | self, 32 | feats: torch.Tensor, 33 | coords: torch.Tensor, 34 | stride: Union[int, Tuple[int, ...]] = 1, 35 | spatial_range: Union[int, Tuple[int, ...]] = None, 36 | ) -> None: 37 | self.feats = feats 38 | self.coords = coords 39 | self.stride = make_ntuple(stride, ndim=3) 40 | if spatial_range is None: 41 | self.spatial_range = None 42 | else: 43 | self.spatial_range = make_ntuple(spatial_range, ndim=len(spatial_range)) 44 | 45 | if get_tensor_cache_mode() == TensorCacheMode.GLOBAL_TENSOR_CACHE: 46 | _caches = get_global_tensor_cache() 47 | if _caches is None: 48 | _caches = TensorCache() 49 | set_global_tensor_cache(_caches) 50 | self._caches = _caches 51 | else: 52 | self._caches = TensorCache() 53 | 54 | @property 55 | def F(self) -> torch.Tensor: 56 | return self.feats 57 | 58 | @F.setter 59 | def F(self, feats: torch.Tensor) -> None: 60 | self.feats = feats 61 | 62 | @property 63 | def C(self) -> torch.Tensor: 64 | return self.coords 65 | 66 | @C.setter 67 | def C(self, coords: torch.Tensor) -> None: 68 | self.coords = coords 69 | 70 | @property 71 | def s(self) -> Tuple[int, ...]: 72 | return self.stride 73 | 74 | @s.setter 75 | def s(self, stride: Union[int, Tuple[int, ...]]) -> None: 76 | self.stride = make_ntuple(stride, ndim=3) 77 | 78 | def cpu(self): 79 | self.coords = self.coords.cpu() 80 | self.feats = self.feats.cpu() 81 | return self 82 | 83 | def cuda(self): 84 | self.coords = self.coords.cuda() 85 | self.feats = self.feats.cuda() 86 | return self 87 | 88 | def half(self): 89 | self.feats = self.feats.half() 90 | return self 91 | 92 | def detach(self): 93 | self.coords = self.coords.detach() 94 | self.feats = self.feats.detach() 95 | return self 96 | 97 | def to(self, device, non_blocking: bool = True): 98 | self.coords = self.coords.to(device, non_blocking=non_blocking) 99 | self.feats = self.feats.to(device, non_blocking=non_blocking) 100 | return self 101 | 102 | def dense(self): 103 | assert self.spatial_range is not None 104 | return to_dense(self.feats, self.coords, self.spatial_range) 105 | 106 | def __add__(self, other): 107 | output = SparseTensor( 108 | coords=self.coords, 109 | feats=self.feats + other.feats, 110 | stride=self.stride, 111 | spatial_range=self.spatial_range, 112 | ) 113 | output._caches = self._caches 114 | return output 115 | 116 | class PointTensor: 117 | def __init__(self, feats, coords, idx_query=None, weights=None): 118 | self.F = feats 119 | self.C = coords 120 | self.idx_query = idx_query if idx_query is not None else {} 121 | self.weights = weights if weights is not None else {} 122 | self.additional_features = {} 123 | self.additional_features['idx_query'] = {} 124 | self.additional_features['counts'] = {} 125 | 126 | def cuda(self): 127 | self.F = self.F.cuda() 128 | self.C = self.C.cuda() 129 | return self 130 | 131 | def detach(self): 132 | self.F = self.F.detach() 133 | self.C = self.C.detach() 134 | return self 135 | 136 | def to(self, device, non_blocking=True): 137 | self.F = self.F.to(device, non_blocking=non_blocking) 138 | self.C = self.C.to(device, non_blocking=non_blocking) 139 | return self 140 | 141 | def __add__(self, other): 142 | tensor = PointTensor(self.F + other.F, self.C, self.idx_query, 143 | self.weights) 144 | tensor.additional_features = self.additional_features 145 | return tensor 146 | 147 | -------------------------------------------------------------------------------- /torchsparse/utils/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .utils import * 2 | from .to_dense import * 3 | -------------------------------------------------------------------------------- /torchsparse/utils/collate.py: -------------------------------------------------------------------------------- 1 | from typing import Any, List 2 | 3 | import numpy as np 4 | import torch 5 | 6 | from torchsparse import SparseTensor 7 | 8 | __all__ = ["sparse_collate", "sparse_collate_fn"] 9 | 10 | 11 | def sparse_collate(inputs: List[SparseTensor]) -> SparseTensor: 12 | coords, feats = [], [] 13 | stride = inputs[0].stride 14 | 15 | for k, x in enumerate(inputs): 16 | if isinstance(x.coords, np.ndarray): 17 | x.coords = torch.tensor(x.coords) 18 | if isinstance(x.feats, np.ndarray): 19 | x.feats = torch.tensor(x.feats) 20 | 21 | assert isinstance(x.coords, torch.Tensor), type(x.coords) 22 | assert isinstance(x.feats, torch.Tensor), type(x.feats) 23 | assert x.stride == stride, (x.stride, stride) 24 | 25 | input_size = x.coords.shape[0] 26 | batch = torch.full((input_size, 1), k, device=x.coords.device, dtype=torch.int) 27 | coords.append(torch.cat((batch, x.coords), dim=1)) 28 | feats.append(x.feats) 29 | 30 | coords = torch.cat(coords, dim=0) 31 | feats = torch.cat(feats, dim=0) 32 | output = SparseTensor(coords=coords, feats=feats, stride=stride) 33 | return output 34 | 35 | 36 | def sparse_collate_fn(inputs: List[Any]) -> Any: 37 | if isinstance(inputs[0], dict): 38 | output = {} 39 | for name in inputs[0].keys(): 40 | if isinstance(inputs[0][name], dict): 41 | output[name] = sparse_collate_fn([input[name] for input in inputs]) 42 | elif isinstance(inputs[0][name], np.ndarray): 43 | output[name] = torch.stack( 44 | [torch.tensor(input[name]) for input in inputs], dim=0 45 | ) 46 | elif isinstance(inputs[0][name], torch.Tensor): 47 | output[name] = torch.stack([input[name] for input in inputs], dim=0) 48 | elif isinstance(inputs[0][name], SparseTensor): 49 | output[name] = sparse_collate([input[name] for input in inputs]) 50 | else: 51 | output[name] = [input[name] for input in inputs] 52 | return output 53 | else: 54 | return inputs 55 | -------------------------------------------------------------------------------- /torchsparse/utils/quantize.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | from typing import List, Tuple, Union 3 | 4 | import numpy as np 5 | 6 | __all__ = ["sparse_quantize"] 7 | 8 | 9 | def ravel_hash(x: np.ndarray) -> np.ndarray: 10 | assert x.ndim == 2, x.shape 11 | 12 | x = x - np.min(x, axis=0) 13 | x = x.astype(np.uint64, copy=False) 14 | xmax = np.max(x, axis=0).astype(np.uint64) + 1 15 | 16 | h = np.zeros(x.shape[0], dtype=np.uint64) 17 | for k in range(x.shape[1] - 1): 18 | h += x[:, k] 19 | h *= xmax[k + 1] 20 | h += x[:, -1] 21 | return h 22 | 23 | 24 | def sparse_quantize( 25 | coords, 26 | voxel_size: Union[float, Tuple[float, ...]] = 1, 27 | *, 28 | return_index: bool = False, 29 | return_inverse: bool = False 30 | ) -> List[np.ndarray]: 31 | if isinstance(voxel_size, (float, int)): 32 | voxel_size = tuple(repeat(voxel_size, 3)) 33 | assert isinstance(voxel_size, tuple) and len(voxel_size) == 3 34 | 35 | voxel_size = np.array(voxel_size) 36 | coords = np.floor(coords / voxel_size).astype(np.int32) 37 | 38 | _, indices, inverse_indices = np.unique( 39 | ravel_hash(coords), return_index=True, return_inverse=True 40 | ) 41 | coords = coords[indices] 42 | 43 | outputs = 
[coords] 44 | if return_index: 45 | outputs += [indices] 46 | if return_inverse: 47 | outputs += [inverse_indices] 48 | return outputs[0] if len(outputs) == 1 else outputs 49 | -------------------------------------------------------------------------------- /torchsparse/utils/tensor_cache.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Tuple, Union 2 | from enum import Enum 3 | import copy 4 | 5 | 6 | class TensorCacheMode(Enum): 7 | SEPARATE_TENSOR_CACHE = 0 8 | GLOBAL_TENSOR_CACHE = 1 9 | 10 | 11 | _tensor_cache_mode = TensorCacheMode.SEPARATE_TENSOR_CACHE 12 | _global_tensor_cache = None 13 | 14 | 15 | def set_tensor_cache_mode(mode: TensorCacheMode): 16 | r""" 17 | _tensor_cache_mode is set SEPARATE_TENSOR_CACHE by default 18 | if _tensor_cache_mode is set to GLOBAL_TENSOR_CACHE 19 | the _global_tensor_cache must be cleared after each forward/backward 20 | """ 21 | assert isinstance( 22 | mode, TensorCacheMode 23 | ), f"Input must be an instance of TensorCacheMode" 24 | global _tensor_cache_mode 25 | _tensor_cache_mode = mode 26 | 27 | 28 | def get_tensor_cache_mode() -> TensorCacheMode: 29 | global _tensor_cache_mode 30 | return copy.deepcopy(_tensor_cache_mode) 31 | 32 | 33 | class TensorCache: 34 | def __init__( 35 | self, 36 | ) -> None: 37 | self.cmaps: Dict[Tuple[int, ...], Tuple[torch.Tensor, Tuple[int, ...]]] = {} 38 | self.kmaps: Dict[Tuple[Any, ...], Any] = {} 39 | self.hashmaps: Dict[Tuple[int, ...], Tuple[Any, ...]] = {} 40 | 41 | 42 | def get_global_tensor_cache(): 43 | global _global_tensor_cache 44 | return _global_tensor_cache 45 | 46 | 47 | def set_global_tensor_cache(tensor_cache): 48 | global _global_tensor_cache 49 | _global_tensor_cache = tensor_cache 50 | 51 | 52 | def clear_global_tensor_cache(): 53 | r""" 54 | if _tensor_cache_mode is set to GLOBAL_TENSOR_CACHE 55 | the _global_tensor_cache must be cleared after each forward/backward 56 | """ 57 | global _global_tensor_cache 58 | _global_tensor_cache = None 59 | -------------------------------------------------------------------------------- /torchsparse/utils/to_dense.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | 4 | # from torch.cuda.amp import custom_bwd, custom_fwd 5 | from typing import Tuple 6 | 7 | import torchsparse.backend 8 | from torchsparse.utils.utils import make_tensor 9 | 10 | __all__ = ["to_dense"] 11 | 12 | 13 | class ToDenseFunction(Function): 14 | @staticmethod 15 | # @custom_fwd(cast_inputs=torch.half) 16 | def forward( 17 | ctx, 18 | feats: torch.Tensor, 19 | coords: torch.Tensor, 20 | spatial_range: Tuple[int], 21 | ) -> torch.Tensor: 22 | feats = feats.contiguous() 23 | coords = coords.contiguous().int() 24 | outputs = torch.zeros( 25 | spatial_range + (feats.size(1),), dtype=feats.dtype, device=feats.device 26 | ) 27 | spatial_range = make_tensor(spatial_range, dtype=torch.int, device=feats.device) 28 | 29 | if feats.device.type == "cuda": 30 | torchsparse.backend.to_dense_forward_cuda( 31 | feats, coords, spatial_range, outputs 32 | ) 33 | else: 34 | raise NotImplementedError 35 | 36 | ctx.for_backwards = (coords, spatial_range) 37 | return outputs.to(feats.dtype) 38 | 39 | @staticmethod 40 | # @custom_bwd 41 | def backward(ctx, grad_output: torch.Tensor): 42 | coords, spatial_range = ctx.for_backwards 43 | grad_output = grad_output.contiguous() 44 | grad_feats = torch.zeros( 45 | coords.size(0), 46 | 
grad_output.size(-1), 47 | dtype=grad_output.dtype, 48 | device=grad_output.device, 49 | ) 50 | 51 | if grad_output.device.type == "cuda": 52 | torchsparse.backend.to_dense_backward_cuda( 53 | grad_output, coords, spatial_range, grad_feats 54 | ) 55 | else: 56 | raise NotImplementedError 57 | 58 | return grad_feats, None, None 59 | 60 | 61 | def to_dense( 62 | feats: torch.Tensor, coords: torch.Tensor, spatial_range: Tuple[int] 63 | ) -> torch.Tensor: 64 | return ToDenseFunction.apply(feats, coords, spatial_range) 65 | -------------------------------------------------------------------------------- /torchsparse/utils/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | from typing import List, Tuple, Union 3 | from functools import lru_cache 4 | import torch 5 | 6 | __all__ = ["make_ntuple", "make_tensor", "make_divisible"] 7 | 8 | 9 | def make_ntuple( 10 | x: Union[int, List[int], Tuple[int, ...], torch.Tensor], ndim: int 11 | ) -> Tuple[int, ...]: 12 | if isinstance(x, int): 13 | x = tuple(repeat(x, ndim)) 14 | elif isinstance(x, list): 15 | x = tuple(x) 16 | elif isinstance(x, torch.Tensor): 17 | x = tuple(x.view(-1).cpu().numpy().tolist()) 18 | 19 | assert isinstance(x, tuple) and len(x) == ndim, x 20 | return x 21 | 22 | 23 | @lru_cache() 24 | def make_tensor(x: Tuple[int, ...], dtype: torch.dtype, device) -> torch.Tensor: 25 | return torch.tensor(x, dtype=dtype, device=device) 26 | 27 | 28 | def make_divisible(x: int, divisor: int): 29 | return (x + divisor - 1) // divisor * divisor 30 | -------------------------------------------------------------------------------- /torchsparse/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '2.1.0' 2 | --------------------------------------------------------------------------------
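The files above define the full user-facing path from raw points to a sparse convolutional network: sparse_quantize voxelizes coordinates, sparse_collate batches samples into a SparseTensor with (N, 4) integer coordinates, and the modules in torchsparse/nn/modules operate on that tensor. The following is a minimal usage sketch under stated assumptions, not a file from this repository: it assumes the package and its compiled backend are installed, uses random points in place of real data, and introduces random_sample purely as a hypothetical helper for illustration.

import numpy as np
import torch

from torchsparse import SparseTensor
from torchsparse.nn.modules import BatchNorm, Conv3d, GlobalAvgPool, ReLU
from torchsparse.utils.collate import sparse_collate
from torchsparse.utils.quantize import sparse_quantize


def random_sample(num_points: int = 10000, voxel_size: float = 0.2) -> SparseTensor:
    # Random xyz points stand in for a real point cloud; shift them so the
    # voxel coordinates are non-negative (the default assumption; see
    # set_allow_negative_coordinates in torchsparse/tensor.py).
    xyz = np.random.uniform(-10.0, 10.0, size=(num_points, 3))
    xyz -= xyz.min(axis=0)
    feats = np.random.randn(num_points, 4).astype(np.float32)
    # Keep one point per occupied voxel; return_index gives the survivors.
    coords, indices = sparse_quantize(xyz, voxel_size, return_index=True)
    return SparseTensor(
        feats=torch.from_numpy(feats[indices]),
        coords=torch.from_numpy(coords),
    )


# The optimized convolution kernels target CUDA; CPU fallback paths exist in
# the backend but are not the primary target.
device = "cuda" if torch.cuda.is_available() else "cpu"

# sparse_collate prepends the batch index to each coordinate, producing the
# (N, 4) int32 coordinates expected by sphash and the kernel-map builders.
inputs = sparse_collate([random_sample(), random_sample()]).to(device)

model = torch.nn.Sequential(
    Conv3d(4, 32, kernel_size=3, stride=1),
    BatchNorm(32),
    ReLU(True),
    Conv3d(32, 64, kernel_size=2, stride=2),
    BatchNorm(64),
    ReLU(True),
    GlobalAvgPool(),
).to(device)

outputs = model(inputs)  # dense (batch_size, 64) tensor after global pooling
print(outputs.shape)

Because every module returns a new SparseTensor through fapply, standard torch.nn.Sequential composition works unchanged; GlobalAvgPool is the only module in this stack that returns a dense tensor.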
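A second hypothetical sketch, reusing model, random_sample, sparse_collate, and device from the sketch above, shows the global tensor-cache mode defined in torchsparse/utils/tensor_cache.py: with GLOBAL_TENSOR_CACHE enabled, SparseTensors created afterwards share one TensorCache of coordinate, kernel, and hash maps, and the docstrings require clearing that cache after each forward/backward pass.

from torchsparse.utils.tensor_cache import (
    TensorCacheMode,
    clear_global_tensor_cache,
    set_tensor_cache_mode,
)

set_tensor_cache_mode(TensorCacheMode.GLOBAL_TENSOR_CACHE)

# SparseTensors created from now on share a single global TensorCache instead
# of each carrying its own per-tensor cache.
inputs = sparse_collate([random_sample(), random_sample()]).to(device)
outputs = model(inputs)

# Per the docstrings in tensor_cache.py, the global cache must be cleared
# after each forward/backward pass in this mode.
clear_global_tensor_cache()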