├── .clang-format ├── .gitignore ├── LICENSE ├── README.md ├── csrc ├── cpu │ ├── bucket_fps │ │ ├── dynamic │ │ │ ├── Interval.h │ │ │ ├── KDLineTree.h │ │ │ ├── KDNode.h │ │ │ ├── KDTreeBase.h │ │ │ ├── Point.h │ │ │ └── utils.h │ │ ├── static │ │ │ ├── Interval.h │ │ │ ├── KDLineTree.h │ │ │ ├── KDNode.h │ │ │ ├── KDTreeBase.h │ │ │ ├── Point.h │ │ │ └── utils.h │ │ ├── wrapper.cpp │ │ └── wrapper.h │ └── fpsample_cpu.cpp ├── cuda │ └── fpsample_cuda.cpp ├── fpsample.cpp ├── fpsample_autograd.cpp ├── fpsample_meta.cpp └── utils.h ├── setup.py └── torch_fpsample ├── __init__.py └── fps.py /.clang-format: -------------------------------------------------------------------------------- 1 | --- 2 | BasedOnStyle: LLVM 3 | IndentWidth: 4 4 | TabWidth: 4 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | 164 | # vscode 165 | .vscode/ 166 | 167 | tmp_* 168 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 AyajiLin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PyTorch fpsample 2 | 3 | PyTorch efficient farthest point sampling (FPS) implementation, adopted from [fpsample](https://github.com/leonardodalinky/fpsample). 4 | 5 | **Currently, this project is under heavy development and not ready for production use. 
If you would like to contribute a GPU implementation, please feel free to contact me or open a PR.** 6 | 7 | > [!NOTE] 8 | > Since PyTorch encapsulates native multithreading, this project is expected to perform much better than the *fpsample* implementation. 9 | 10 | ## Installation 11 | 12 | ```bash 13 | # Install from GitHub 14 | pip install git+https://github.com/leonardodalinky/pytorch_fpsample 15 | 16 | # Build locally 17 | pip install . 18 | ``` 19 | 20 | ## Usage 21 | 22 | ```python 23 | import torch 24 | import torch_fpsample 25 | x = torch.rand(64, 2048, 3) 26 | # random sample 27 | sampled_points, indices = torch_fpsample.sample(x, 1024) 28 | # random sample with a specific tree height 29 | sampled_points, indices = torch_fpsample.sample(x, 1024, h=5) 30 | # sample starting from a given point index (int) 31 | sampled_points, indices = torch_fpsample.sample(x, 1024, start_idx=0) 32 | 33 | >>> sampled_points.size(), indices.size() 34 | (torch.Size([64, 1024, 3]), torch.Size([64, 1024])) 35 | ``` 36 | 37 | > [!WARNING] 38 | > The GPU version is not implemented yet; only CPU mode is available. 39 | 40 | ## Reference 41 | Bucket-based farthest point sampling (QuickFPS) was proposed in the following paper. This implementation is based on the authors' repositories ([CPU](https://github.com/hanm2019/bucket-based_farthest-point-sampling_CPU) & [GPU](https://github.com/hanm2019/bucket-based_farthest-point-sampling_GPU)). 42 | ```bibtex 43 | @article{han2023quickfps, 44 | title={QuickFPS: Architecture and Algorithm Co-Design for Farthest Point Sampling in Large-Scale Point Clouds}, 45 | author={Han, Meng and Wang, Liang and Xiao, Limin and Zhang, Hao and Zhang, Chenhao and Xu, Xiangrong and Zhu, Jianfeng}, 46 | journal={IEEE Transactions on Computer-Aided Design of Integrated Circuits and Systems}, 47 | year={2023}, 48 | publisher={IEEE} 49 | } 50 | ``` 51 | 52 | Thanks to the authors for their great work.
-------------------------------------------------------------------------------- /csrc/cpu/bucket_fps/dynamic/Interval.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 韩萌 on 2022/6/14. 3 | // Refactored by AyajiLin on 2024/09/03. 4 | // 5 | 6 | #pragma once 7 | 8 | namespace quickfps::dynamic { 9 | template class Interval { // Closed range [low, high] on one axis; KDNode stores one per dimension as its bounding box. 10 | public: 11 | S low, high; 12 | Interval() : low(0), high(0) {}; 13 | Interval(S low, S high) : low(low), high(high) {}; 14 | Interval(const Interval &o) : low(o.low), high(o.high) {}; 15 | }; 16 | } // namespace quickfps::dynamic 17 | -------------------------------------------------------------------------------- /csrc/cpu/bucket_fps/dynamic/KDLineTree.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by hanm on 22-6-15. 3 | // Refactored by AyajiLin on 2024/09/03. 4 | // 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | 11 | #include "KDTreeBase.h" 12 | 13 | namespace quickfps::dynamic { 14 | 15 | template 16 | class KDLineTree : public KDTreeBase { // KD-tree whose leaf buckets are also collected in KDNode_list; sampling scans that list. 17 | public: 18 | using typename KDTreeBase::_Point; 19 | using typename KDTreeBase::_Points; 20 | using typename KDTreeBase::NodePtr; 21 | 22 | KDLineTree(_Points data, size_t pointSize, size_t treeHigh, 23 | _Points samplePoints); 24 | ~KDLineTree(); 25 | 26 | std::vector KDNode_list; // leaf buckets in registration order; bucket i has idx == i 27 | 28 | size_t high_; // depth limit: leftNode() makes nodes at this depth leaves 29 | 30 | _Point max_point() override; 31 | 32 | void update_distance(const _Point &ref_point) override; 33 | 34 | void sample(size_t sample_num) override; 35 | 36 | bool leftNode(size_t high, size_t count) const override { // leaf if the depth limit is reached or one point remains 37 | return high == this->high_ || count == 1; 38 | }; 39 | 40 | void addNode(NodePtr p) override; 41 | }; 42 | 43 | template 44 | KDLineTree::KDLineTree(_Points data, size_t pointSize, size_t treeHigh, 45 | _Points samplePoints) 46 | : KDTreeBase(data, pointSize, samplePoints), high_(treeHigh) { 47 | KDNode_list.clear(); 48 | } 49 | 50 | template
KDLineTree::~KDLineTree() { 51 | KDNode_list.clear(); 52 | } 53 | 54 | template 55 | typename KDLineTree::_Point KDLineTree::max_point() { // Return the bucket max_point with the largest dis, i.e. the next sample. 56 | _Point tmpPoint(this->dim()); 57 | S max_distance = std::numeric_limits::lowest(); 58 | for (const auto &bucket : KDNode_list) { 59 | if (bucket->max_point.dis > max_distance) { 60 | max_distance = bucket->max_point.dis; 61 | tmpPoint = bucket->max_point; 62 | } 63 | } 64 | return tmpPoint; 65 | } 66 | 67 | template 68 | void KDLineTree::update_distance(const _Point &ref_point) { // Queue ref_point on every bucket, then let each bucket refresh its distances. 69 | for (const auto &bucket : KDNode_list) { 70 | bucket->send_delay_point(ref_point); 71 | bucket->update_distance(); 72 | } 73 | } 74 | 75 | template 76 | void KDLineTree::sample(size_t sample_num) { // Fill sample_points[1..sample_num-1]; sample_points[0] is set by KDTreeBase::init(). 77 | for (size_t i = 1; i < sample_num; i++) { 78 | _Point ref_point = this->max_point(); 79 | this->sample_points[i] = ref_point; 80 | this->update_distance(ref_point); 81 | } 82 | } 83 | 84 | template void KDLineTree::addNode(NodePtr p) { // Register leaf bucket p and store its list position in p->idx. 85 | size_t nodeIdx = KDNode_list.size(); 86 | p->idx = nodeIdx; 87 | KDNode_list.push_back(p); 88 | } 89 | 90 | } // namespace quickfps::dynamic 91 | -------------------------------------------------------------------------------- /csrc/cpu/bucket_fps/dynamic/KDNode.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 韩萌 on 2022/6/14. 3 | // Refactored by AyajiLin on 2024/09/03.
4 | // 5 | 6 | #pragma once 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "Interval.h" 13 | #include "Point.h" 14 | 15 | namespace quickfps::dynamic { 16 | 17 | template class KDNode { // KD-tree node; a leaf (bucket) owns the point range points[pointLeft, pointRight). 18 | public: 19 | using _Point = Point; 20 | using _Points = _Point *; 21 | _Points points; 22 | size_t pointLeft, pointRight; 23 | size_t idx; // bucket index, assigned by the tree's addNode() for leaves 24 | 25 | std::vector> bboxs; // one [low, high] interval per axis 26 | std::vector<_Point> waitpoints; // reference points queued by send_delay_point() 27 | std::vector<_Point> delaypoints; // reference points whose processing is deferred 28 | _Point max_point; // point with the largest dis in this subtree 29 | KDNode *left; 30 | KDNode *right; 31 | 32 | KDNode(); 33 | 34 | KDNode(const KDNode &a); 35 | 36 | KDNode(const std::vector> &bboxs); 37 | 38 | void init(const _Point &ref); 39 | 40 | size_t dim() const { return max_point.dim(); }; 41 | 42 | void updateMaxPoint(const _Point &lpoint, const _Point &rpoint) { 43 | if (lpoint.dis > rpoint.dis) 44 | this->max_point = lpoint; 45 | else 46 | this->max_point = rpoint; 47 | } 48 | 49 | S bound_distance(const _Point &ref_point) const; // squared distance from ref_point to bboxs (0 if inside) 50 | 51 | void send_delay_point(const _Point &point) { 52 | this->waitpoints.push_back(point); 53 | } 54 | 55 | void update_distance(); 56 | 57 | void reset(); 58 | 59 | size_t size() const; 60 | }; 61 | // Initializers follow member declaration order; idx starts at 0 so copying a node never reads an indeterminate value. 62 | template 63 | KDNode::KDNode() 64 | : points(nullptr), pointLeft(0), pointRight(0), idx(0), max_point(0), 65 | left(nullptr), right(nullptr) {} 66 | 67 | template 68 | KDNode::KDNode(const std::vector> &other_bboxs) 69 | : points(nullptr), pointLeft(0), pointRight(0), idx(0), bboxs(other_bboxs), 70 | max_point(other_bboxs.size()), left(nullptr), right(nullptr) {} 71 | 72 | template 73 | KDNode::KDNode(const KDNode &a) 74 | : points(a.points), pointLeft(a.pointLeft), pointRight(a.pointRight), 75 | idx(a.idx), bboxs(a.bboxs), waitpoints(a.waitpoints), 76 | delaypoints(a.delaypoints), max_point(a.max_point), left(a.left), 77 | right(a.right) {} 78 | 79 | template void KDNode::init(const _Point &ref) { // Clear queues, fold ref into every point's dis and recompute max_point. 80 | waitpoints.clear(); 81 | delaypoints.clear(); 82 | if (this->left && this->right) { 83 | this->left->init(ref); 84 |
this->right->init(ref); 85 | updateMaxPoint(this->left->max_point, this->right->max_point); 86 | } else { 87 | S dis; 88 | S maxdis = std::numeric_limits::lowest(); 89 | for (size_t i = pointLeft; i < pointRight; i++) { 90 | dis = points[i].updatedistance(ref); 91 | if (dis > maxdis) { 92 | maxdis = dis; 93 | max_point = points[i]; 94 | } 95 | } 96 | } 97 | } 98 | 99 | template 100 | S KDNode::bound_distance(const _Point &ref_point) const { 101 | S bound_dis(0); 102 | S dim_distance; 103 | for (size_t cur_dim = 0; cur_dim < dim(); cur_dim++) { 104 | dim_distance = 0; 105 | if (ref_point[cur_dim] > this->bboxs[cur_dim].high) 106 | dim_distance = ref_point[cur_dim] - this->bboxs[cur_dim].high; 107 | else if (ref_point[cur_dim] < this->bboxs[cur_dim].low) 108 | dim_distance = this->bboxs[cur_dim].low - ref_point[cur_dim]; 109 | bound_dis += powi(dim_distance, 2); 110 | } 111 | return bound_dis; 112 | } 113 | 114 | template void KDNode::update_distance() { // Drain waitpoints: prune by bounding box, defer when max_point is unaffected, else recurse or rescan the bucket. 115 | for (const auto &ref_point : this->waitpoints) { 116 | S lastmax_distance = this->max_point.dis; 117 | S cur_distance = this->max_point.distance(ref_point); 118 | // cur_distance > 119 | // lastmax_distance means this node's max_point will not be updated 120 | if (cur_distance > lastmax_distance) { 121 | S boundary_distance = bound_distance(ref_point); 122 | if (boundary_distance < lastmax_distance) 123 | this->delaypoints.push_back(ref_point); 124 | } else { 125 | if (this->right && this->left) { 126 | if (!delaypoints.empty()) { 127 | for (const auto &delay_point : delaypoints) { 128 | this->left->send_delay_point(delay_point); 129 | this->right->send_delay_point(delay_point); 130 | } 131 | delaypoints.clear(); 132 | } 133 | this->left->send_delay_point(ref_point); 134 | this->left->update_distance(); 135 | 136 | this->right->send_delay_point(ref_point); 137 | this->right->update_distance(); 138 | 139 | updateMaxPoint(this->left->max_point, this->right->max_point); 140 | } else { 141 | S dis; 142 | S maxdis; 143 |
this->delaypoints.push_back(ref_point); 144 | for (const auto &delay_point : delaypoints) { 145 | maxdis = std::numeric_limits::lowest(); 146 | for (size_t i = pointLeft; i < pointRight; i++) { 147 | dis = points[i].updatedistance(delay_point); 148 | if (dis > maxdis) { 149 | maxdis = dis; 150 | max_point = points[i]; 151 | } 152 | } 153 | } 154 | this->delaypoints.clear(); 155 | } 156 | } 157 | } 158 | this->waitpoints.clear(); 159 | } 160 | 161 | template void KDNode::reset() { 162 | for (size_t i = pointLeft; i < pointRight; i++) { 163 | points[i].reset(); 164 | } 165 | this->waitpoints.clear(); 166 | this->delaypoints.clear(); 167 | this->max_point.reset(); 168 | if (this->left && this->right) { 169 | this->left->reset(); 170 | this->right->reset(); 171 | } 172 | } 173 | 174 | template size_t KDNode::size() const { 175 | if (this->left && this->right) 176 | return this->left->size() + this->right->size(); 177 | return (pointRight - pointLeft); 178 | } 179 | 180 | } // namespace quickfps::dynamic 181 | -------------------------------------------------------------------------------- /csrc/cpu/bucket_fps/dynamic/KDTreeBase.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 韩萌 on 2022/6/14. 3 | // Refactored by AyajiLin on 2024/09/03.
4 | // 5 | 6 | #pragma once 7 | 8 | #include "KDNode.h" 9 | #include "Point.h" 10 | #include 11 | #include 12 | #include 13 | 14 | namespace quickfps::dynamic { 15 | 16 | template class KDTreeBase { // Owns the KD-tree built over points_; subclasses supply the leaf test (leftNode) and bucket bookkeeping (addNode). 17 | public: 18 | using _Point = Point; 19 | using _Points = _Point *; 20 | using NodePtr = KDNode *; 21 | using _Interval = Interval; 22 | 23 | size_t pointSize; 24 | _Points sample_points; 25 | NodePtr root_; 26 | _Points points_; 27 | 28 | public: 29 | KDTreeBase(_Points data, size_t pointSize, _Points samplePoints); 30 | 31 | virtual ~KDTreeBase(); // virtual: deleting a subclass through a KDTreeBase pointer must run the subclass destructor 32 | 33 | void buildKDtree(); 34 | 35 | NodePtr get_root() const { return this->root_; }; 36 | 37 | void init(const _Point &ref); 38 | 39 | size_t dim() const { return points_[0].dim(); } 40 | 41 | virtual _Point max_point() = 0; 42 | 43 | virtual void sample(size_t sample_num) = 0; 44 | 45 | protected: 46 | void deleteNode(NodePtr node_p); 47 | virtual void addNode(NodePtr p) = 0; 48 | virtual bool leftNode(size_t high, size_t count) const = 0; 49 | virtual void update_distance(const _Point &ref_point) = 0; 50 | 51 | NodePtr divideTree(ssize_t left, ssize_t right, 52 | const std::vector<_Interval> &bboxs, size_t curr_high); 53 | 54 | size_t planeSplit(ssize_t left, ssize_t right, size_t split_dim, 55 | T split_val); 56 | 57 | T qSelectMedian(size_t dim, size_t left, size_t right); 58 | static size_t findSplitDim(const std::vector<_Interval> &bboxs, size_t dim); 59 | inline std::vector<_Interval> computeBoundingBox(size_t left, size_t right); 60 | }; 61 | 62 | template 63 | KDTreeBase::KDTreeBase(_Points data, size_t pointSize, 64 | _Points samplePoints) 65 | : pointSize(pointSize), sample_points(samplePoints), root_(nullptr), 66 | points_(data) {} 67 | 68 | template KDTreeBase::~KDTreeBase() { 69 | if (root_ != nullptr) 70 | deleteNode(root_); 71 | } 72 | 73 | template 74 | void KDTreeBase::deleteNode(NodePtr node_p) { // post-order delete of node_p's subtree 75 | if (node_p->left) 76 | deleteNode(node_p->left); 77 | if (node_p->right) 78 | deleteNode(node_p->right);
79 | delete node_p; 80 | } 81 | 82 | template void KDTreeBase::buildKDtree() { // build the tree over points_[0, pointSize) and store it in root_ 83 | size_t left = 0; 84 | size_t right = pointSize; 85 | std::vector<_Interval> bboxs = this->computeBoundingBox(left, right); 86 | this->root_ = divideTree(left, right, bboxs, 0); 87 | } 88 | 89 | template 90 | typename KDTreeBase::NodePtr 91 | KDTreeBase::divideTree(ssize_t left, ssize_t right, 92 | const std::vector<_Interval> &bboxs, 93 | size_t curr_high) { // leaves go to addNode(); inner nodes split the widest axis at qSelectMedian() 94 | NodePtr node = new KDNode(bboxs); 95 | 96 | ssize_t count = right - left; 97 | if (this->leftNode(curr_high, count)) { 98 | node->pointLeft = left; 99 | node->pointRight = right; 100 | node->points = this->points_; 101 | this->addNode(node); 102 | return node; 103 | } else { 104 | size_t split_dim = this->findSplitDim(bboxs, dim()); 105 | T split_val = this->qSelectMedian(split_dim, left, right); 106 | 107 | size_t split_delta = planeSplit(left, right, split_dim, split_val); 108 | 109 | std::vector<_Interval> bbox_cur = 110 | this->computeBoundingBox(left, left + split_delta); 111 | node->left = 112 | this->divideTree(left, left + split_delta, bbox_cur, curr_high + 1); 113 | bbox_cur = this->computeBoundingBox(left + split_delta, right); 114 | node->right = this->divideTree(left + split_delta, right, bbox_cur, 115 | curr_high + 1); 116 | return node; 117 | } 118 | } 119 | 120 | template 121 | size_t KDTreeBase::planeSplit(ssize_t left, ssize_t right, 122 | size_t split_dim, T split_val) { // move pos[split_dim] < split_val to the front; return left-part size clamped to [1, count - 1] 123 | ssize_t start = left; 124 | ssize_t end = right - 1; 125 | 126 | for (;;) { 127 | while (start <= end && points_[start].pos[split_dim] < split_val) 128 | ++start; 129 | while (start <= end && points_[end].pos[split_dim] >= split_val) 130 | --end; 131 | 132 | if (start > end) 133 | break; 134 | std::swap(points_[start], points_[end]); 135 | ++start; 136 | --end; 137 | } 138 | 139 | ssize_t lim1 = start - left; 140 | if (start == left) 141 | lim1 = 1; 142 | if (start == right) 143 | lim1 = (right - left - 1); 144 | 145 | return
static_cast(lim1); 146 | } 147 | 148 | template 149 | T KDTreeBase::qSelectMedian(size_t dim, size_t left, size_t right) { // NOTE: despite the name, returns the mean of pos[dim] over [left, right) 150 | T sum = std::accumulate(this->points_ + left, this->points_ + right, 0.0, 151 | [dim](const T &acc, const _Point &point) { 152 | return acc + point.pos[dim]; 153 | }); 154 | return sum / (right - left); 155 | } 156 | 157 | template 158 | size_t KDTreeBase::findSplitDim(const std::vector<_Interval> &bboxs, 159 | size_t dim) { // first axis with the widest bbox span (0 if every span is 0) 160 | T min_, max_; 161 | T span = 0; 162 | size_t best_dim = 0; 163 | 164 | for (size_t cur_dim = 0; cur_dim < dim; cur_dim++) { 165 | min_ = bboxs[cur_dim].low; 166 | max_ = bboxs[cur_dim].high; 167 | T cur_span = (max_ - min_); 168 | 169 | if (cur_span > span) { 170 | best_dim = cur_dim; 171 | span = cur_span; 172 | } 173 | } 174 | 175 | return best_dim; 176 | } 177 | 178 | template 179 | inline std::vector> 180 | KDTreeBase::computeBoundingBox(size_t left, size_t right) { // per-axis [min, max] over points_[left, right) 181 | std::vector min_vals(this->dim(), std::numeric_limits::max()); 182 | std::vector max_vals(this->dim(), std::numeric_limits::lowest()); 183 | 184 | for (size_t i = left; i < right; ++i) { 185 | const _Point &pos = points_[i]; 186 | 187 | for (size_t cur_dim = 0; cur_dim < this->dim(); cur_dim++) { 188 | T val = pos[cur_dim]; 189 | min_vals[cur_dim] = std::min(min_vals[cur_dim], val); 190 | max_vals[cur_dim] = std::max(max_vals[cur_dim], val); 191 | } 192 | } 193 | 194 | std::vector<_Interval> bboxs(dim()); 195 | 196 | for (size_t cur_dim = 0; cur_dim < dim(); cur_dim++) { 197 | bboxs[cur_dim].low = min_vals[cur_dim]; 198 | bboxs[cur_dim].high = max_vals[cur_dim]; 199 | } 200 | 201 | return bboxs; 202 | } 203 | 204 | template 205 | void KDTreeBase::init(const _Point &ref) { // seed sample_points[0] with ref and initialise all distances 206 | this->sample_points[0] = ref; 207 | this->root_->init(ref); 208 | } 209 | 210 | } // namespace quickfps::dynamic 211 | -------------------------------------------------------------------------------- /csrc/cpu/bucket_fps/dynamic/Point.h:
-------------------------------------------------------------------------------- 1 | // 2 | // Created by 韩萌 on 2022/6/14. 3 | // Refactored by AyajiLin on 2024/09/03. 4 | // 5 | 6 | #pragma once 7 | 8 | #include "utils.h" 9 | #include 10 | #include 11 | #include 12 | 13 | namespace quickfps::dynamic { 14 | 15 | template class Point { // Coordinates (pos), running minimum squared distance to the chosen samples (dis), and an id. 16 | public: 17 | std::vector pos; // x, y, z, ... 18 | S dis; 19 | size_t id; 20 | 21 | Point(size_t dim); 22 | Point(const std::vector &pos, size_t id); 23 | Point(const std::vector &pos, size_t id, S dis); 24 | Point(const Point &obj); 25 | ~Point() {}; 26 | 27 | bool operator<(const Point &aii) const; 28 | 29 | constexpr T operator[](size_t i) const { return pos.at(i); } 30 | 31 | Point &operator=(const Point &obj) { 32 | this->pos = obj.pos; 33 | this->dis = obj.dis; 34 | this->id = obj.id; 35 | return *this; 36 | } 37 | 38 | constexpr size_t dim() const { return pos.size(); } 39 | 40 | constexpr S distance(const Point &b) const { // squared Euclidean distance; const because it only reads pos 41 | S ret = 0; 42 | for (size_t i = 0; i < pos.size(); i++) { 43 | S temp = pos[i] - b.pos[i]; 44 | ret += temp * temp; 45 | } 46 | return ret; 47 | } 48 | 49 | void reset(); 50 | 51 | S updatedistance(const Point &ref); 52 | 53 | S updateDistanceAndCount(const Point &ref, size_t &count); 54 | }; 55 | 56 | template 57 | Point::Point(size_t dim) 58 | : pos(dim, 0), dis(std::numeric_limits::max()), id(0) {} 59 | 60 | template 61 | Point::Point(const std::vector &pos, size_t id) 62 | : pos(pos), dis(std::numeric_limits::max()), id(id) {} 63 | 64 | template 65 | Point::Point(const std::vector &pos, size_t id, S dis) 66 | : pos(pos), dis(dis), id(id) {} 67 | 68 | template 69 | Point::Point(const Point &obj) : pos(obj.pos), dis(obj.dis), id(obj.id) {} 70 | 71 | template 72 | bool Point::operator<(const Point &aii) const { 73 | return dis < aii.dis; 74 | } 75 | 76 | template 77 | S Point::updatedistance(const Point &ref) { // dis = min(dis, distance(ref)); returns the new dis 78 | this->dis = std::min(this->dis, this->distance(ref)); 79 | return this->dis; 80 | } 81 |
82 | template 83 | S Point::updateDistanceAndCount(const Point &ref, size_t &count) { // like updatedistance, but increments count when dis shrinks 84 | S tempDistance = this->distance(ref); 85 | if (tempDistance < this->dis) { 86 | this->dis = tempDistance; 87 | count++; 88 | } 89 | return this->dis; 90 | } 91 | 92 | template void Point::reset() { 93 | this->dis = std::numeric_limits::max(); 94 | } 95 | 96 | } // namespace quickfps::dynamic 97 | -------------------------------------------------------------------------------- /csrc/cpu/bucket_fps/dynamic/utils.h: -------------------------------------------------------------------------------- 1 | // Refactored by AyajiLin on 2024/09/03. 2 | 3 | #pragma once 4 | #include 5 | #include 6 | 7 | namespace quickfps::dynamic { 8 | using ssize_t = std::make_signed_t; 9 | 10 | template 11 | inline constexpr T powi(const T base, const size_t exponent) { 12 | // base raised to a non-negative integer power by recursion; powi(base, 0) == 1 13 | return (exponent == 0) ? 1 : (base * powi(base, exponent - 1)); 14 | } 15 | } // namespace quickfps::dynamic 16 | -------------------------------------------------------------------------------- /csrc/cpu/bucket_fps/static/Interval.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 韩萌 on 2022/6/14. 3 | // Refactored by AyajiLin on 2023/9/16. 4 | // 5 | 6 | #pragma once 7 | 8 | namespace quickfps { 9 | template class Interval { // Closed range [low, high] on one axis; KDNode stores one per dimension as its bounding box. 10 | public: 11 | S low, high; 12 | Interval() : low(0), high(0) {}; 13 | Interval(S low, S high) : low(low), high(high) {}; 14 | Interval(const Interval &o) : low(o.low), high(o.high) {}; 15 | }; 16 | } // namespace quickfps 17 | -------------------------------------------------------------------------------- /csrc/cpu/bucket_fps/static/KDLineTree.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by hanm on 22-6-15. 3 | // Refactored by AyajiLin on 2023/9/16.
4 | // 5 | 6 | #pragma once 7 | 8 | #include 9 | #include 10 | 11 | #include "KDTreeBase.h" 12 | 13 | namespace quickfps { 14 | 15 | template 16 | class KDLineTree : public KDTreeBase { // KD-tree whose leaf buckets are also collected in KDNode_list; sampling scans that list. 17 | public: 18 | using typename KDTreeBase::_Point; 19 | using typename KDTreeBase::_Points; 20 | using typename KDTreeBase::NodePtr; 21 | 22 | KDLineTree(_Points data, size_t pointSize, size_t treeHigh, 23 | _Points samplePoints); 24 | ~KDLineTree(); 25 | 26 | std::vector KDNode_list; // leaf buckets in registration order; bucket i has idx == i 27 | 28 | size_t high_; // depth limit: leftNode() makes nodes at this depth leaves 29 | 30 | _Point max_point() override; 31 | 32 | void update_distance(const _Point &ref_point) override; 33 | 34 | void sample(size_t sample_num) override; 35 | 36 | bool leftNode(size_t high, size_t count) const override { // leaf if the depth limit is reached or one point remains 37 | return high == this->high_ || count == 1; 38 | }; 39 | 40 | void addNode(NodePtr p) override; 41 | }; 42 | 43 | template 44 | KDLineTree::KDLineTree(_Points data, size_t pointSize, 45 | size_t treeHigh, _Points samplePoints) 46 | : KDTreeBase(data, pointSize, samplePoints), high_(treeHigh) { 47 | KDNode_list.clear(); 48 | } 49 | 50 | template 51 | KDLineTree::~KDLineTree() { 52 | KDNode_list.clear(); 53 | } 54 | 55 | template 56 | typename KDLineTree::_Point KDLineTree::max_point() { // Return the bucket max_point with the largest dis, i.e. the next sample. 57 | _Point tmpPoint; 58 | S max_distance = std::numeric_limits::lowest(); 59 | for (const auto &bucket : KDNode_list) { 60 | if (bucket->max_point.dis > max_distance) { 61 | max_distance = bucket->max_point.dis; 62 | tmpPoint = bucket->max_point; 63 | } 64 | } 65 | return tmpPoint; 66 | } 67 | 68 | template 69 | void KDLineTree::update_distance(const _Point &ref_point) { // Queue ref_point on every bucket, then let each bucket refresh its distances. 70 | for (const auto &bucket : KDNode_list) { 71 | bucket->send_delay_point(ref_point); 72 | bucket->update_distance(); 73 | } 74 | } 75 | 76 | template 77 | void KDLineTree::sample(size_t sample_num) { // Fill sample_points[1..sample_num-1]; sample_points[0] must already be set. 78 | _Point ref_point; 79 | for (size_t i = 1; i < sample_num; i++) { 80 | ref_point = this->max_point(); 81 | this->sample_points[i] = ref_point; 82 | this->update_distance(ref_point); 83 | } 84 | } 85 |
86 | template 87 | void KDLineTree::addNode(NodePtr p) { // Register leaf bucket p and store its list position in p->idx. 88 | size_t nodeIdx = KDNode_list.size(); 89 | p->idx = nodeIdx; 90 | KDNode_list.push_back(p); 91 | } 92 | 93 | } // namespace quickfps 94 | -------------------------------------------------------------------------------- /csrc/cpu/bucket_fps/static/KDNode.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 韩萌 on 2022/6/14. 3 | // Refactored by AyajiLin on 2023/9/16. 4 | // 5 | 6 | #pragma once 7 | #include 8 | #include 9 | #include 10 | #include 11 | 12 | #include "Interval.h" 13 | #include "Point.h" 14 | 15 | namespace quickfps { 16 | 17 | template class KDNode { // KD-tree node; a leaf (bucket) owns the point range points[pointLeft, pointRight). 18 | public: 19 | using _Point = Point; 20 | using _Points = _Point *; 21 | _Points points; 22 | size_t pointLeft, pointRight; 23 | size_t idx; // bucket index, assigned by the tree's addNode() for leaves 24 | 25 | std::array, DIM> bboxs; // one [low, high] interval per axis 26 | std::vector<_Point> waitpoints; // reference points queued by send_delay_point() 27 | std::vector<_Point> delaypoints; // reference points whose processing is deferred 28 | _Point max_point; // point with the largest dis in this subtree 29 | KDNode *left; 30 | KDNode *right; 31 | 32 | KDNode(); 33 | 34 | KDNode(const KDNode &a); 35 | 36 | KDNode(const std::array, DIM> &bboxs); 37 | 38 | void init(const _Point &ref); 39 | 40 | void updateMaxPoint(const _Point &lpoint, const _Point &rpoint) { 41 | if (lpoint.dis > rpoint.dis) 42 | this->max_point = lpoint; 43 | else 44 | this->max_point = rpoint; 45 | } 46 | 47 | S bound_distance(const _Point &ref_point) const; // squared distance from ref_point to bboxs (0 if inside) 48 | 49 | void send_delay_point(const _Point &point) { 50 | this->waitpoints.push_back(point); 51 | } 52 | 53 | void update_distance(); 54 | 55 | void reset(); 56 | 57 | size_t size() const; 58 | }; 59 | // idx starts at 0 so copying a node never reads an indeterminate value. 60 | template 61 | KDNode::KDNode() 62 | : points(nullptr), pointLeft(0), pointRight(0), idx(0), left(nullptr), 63 | right(nullptr) {} 64 | 65 | template 66 | KDNode::KDNode(const std::array, DIM> &other_bboxs) 67 | : points(nullptr), pointLeft(0), pointRight(0), idx(0), left(nullptr), 68 | right(nullptr) { 69 | std::copy(other_bboxs.cbegin(), other_bboxs.cend(), this->bboxs.begin()); 70 | } 71 | 72 | template 73
| KDNode::KDNode(const KDNode &a) 74 | : points(a.points), pointLeft(a.pointLeft), pointRight(a.pointRight), 75 | idx(a.idx), bboxs(a.bboxs), waitpoints(a.waitpoints), 76 | delaypoints(a.delaypoints), max_point(a.max_point), left(a.left), 77 | right(a.right) { 78 | // Copy-construct every member in declaration order: std::copy into the 79 | // still-empty waitpoints/delaypoints would write past their end, and 80 | // max_point has to be copied as well. 81 | } 82 | 83 | template 84 | void KDNode::init(const _Point &ref) { // Clear queues, fold ref into every point's dis and recompute max_point. 85 | waitpoints.clear(); 86 | delaypoints.clear(); 87 | if (this->left && this->right) { 88 | this->left->init(ref); 89 | this->right->init(ref); 90 | updateMaxPoint(this->left->max_point, this->right->max_point); 91 | } else { 92 | S dis; 93 | S maxdis = std::numeric_limits::lowest(); 94 | for (size_t i = pointLeft; i < pointRight; i++) { 95 | dis = points[i].updatedistance(ref); 96 | if (dis > maxdis) { 97 | maxdis = dis; 98 | max_point = points[i]; 99 | } 100 | } 101 | } 102 | } 103 | 104 | template 105 | S KDNode::bound_distance(const _Point &ref_point) const { 106 | S bound_dis(0); 107 | S dim_distance; 108 | for (size_t cur_dim = 0; cur_dim < DIM; cur_dim++) { 109 | dim_distance = 0; 110 | if (ref_point[cur_dim] > this->bboxs[cur_dim].high) 111 | dim_distance = ref_point[cur_dim] - this->bboxs[cur_dim].high; 112 | else if (ref_point[cur_dim] < this->bboxs[cur_dim].low) 113 | dim_distance = this->bboxs[cur_dim].low - ref_point[cur_dim]; 114 | bound_dis += powi(dim_distance, 2); 115 | } 116 | return bound_dis; 117 | } 118 | 119 | template 120 | void KDNode::update_distance() { // Drain waitpoints: prune by bounding box, defer when max_point is unaffected, else recurse or rescan the bucket. 121 | for (const auto &ref_point : this->waitpoints) { 122 | S lastmax_distance = this->max_point.dis; 123 | S cur_distance = this->max_point.distance(ref_point); 124 | // cur_distance > 125 | // lastmax_distance means this node's max_point will not be updated 126 | if (cur_distance > lastmax_distance) { 127 | S boundary_distance = bound_distance(ref_point); 128 | if (boundary_distance < lastmax_distance) 129 |
this->delaypoints.push_back(ref_point); 130 | } else { 131 | if (this->right && this->left) { 132 | if (!delaypoints.empty()) { 133 | for (const auto &delay_point : delaypoints) { 134 | this->left->send_delay_point(delay_point); 135 | this->right->send_delay_point(delay_point); 136 | } 137 | delaypoints.clear(); 138 | } 139 | this->left->send_delay_point(ref_point); 140 | this->left->update_distance(); 141 | 142 | this->right->send_delay_point(ref_point); 143 | this->right->update_distance(); 144 | 145 | updateMaxPoint(this->left->max_point, this->right->max_point); 146 | } else { 147 | S dis; 148 | S maxdis; 149 | this->delaypoints.push_back(ref_point); 150 | for (const auto &delay_point : delaypoints) { 151 | maxdis = std::numeric_limits::lowest(); 152 | for (size_t i = pointLeft; i < pointRight; i++) { 153 | dis = points[i].updatedistance(delay_point); 154 | if (dis > maxdis) { 155 | maxdis = dis; 156 | max_point = points[i]; 157 | } 158 | } 159 | } 160 | this->delaypoints.clear(); 161 | } 162 | } 163 | } 164 | this->waitpoints.clear(); 165 | } 166 | 167 | template void KDNode::reset() { 168 | for (size_t i = pointLeft; i < pointRight; i++) { 169 | points[i].reset(); 170 | } 171 | this->waitpoints.clear(); 172 | this->delaypoints.clear(); 173 | this->max_point.reset(); 174 | if (this->left && this->right) { 175 | this->left->reset(); 176 | this->right->reset(); 177 | } 178 | } 179 | 180 | template 181 | size_t KDNode::size() const { 182 | if (this->left && this->right) 183 | return this->left->size() + this->right->size(); 184 | return (pointRight - pointLeft); 185 | } 186 | 187 | } // namespace quickfps 188 | -------------------------------------------------------------------------------- /csrc/cpu/bucket_fps/static/KDTreeBase.h: -------------------------------------------------------------------------------- 1 | // 2 | // Created by 韩萌 on 2022/6/14. 3 | // Refactored by AyajiLin on 2023/9/16.
4 | // 5 | 6 | #pragma once 7 | 8 | #include "KDNode.h" 9 | #include "Point.h" 10 | #include 11 | #include 12 | #include 13 | 14 | namespace quickfps { 15 | /* Base class of the static KD-tree used by bucket FPS: builds the tree over a caller-owned point array (reordering that array in place) and leaves leaf registration, the stop-splitting rule, sampling and distance propagation to subclasses. */ 16 | template class KDTreeBase { 17 | public: 18 | using _Point = Point; 19 | using _Points = _Point *; 20 | using NodePtr = KDNode *; 21 | using _Interval = Interval; 22 | 23 | size_t pointSize; /* number of points in points_ */ 24 | _Points sample_points; /* output buffer; init() stores the first sample at index 0 */ 25 | NodePtr root_; /* tree root; the destructor frees the whole tree */ 26 | _Points points_; /* input points (not freed here), reordered by buildKDtree() */ 27 | 28 | public: 29 | KDTreeBase(_Points data, size_t pointSize, _Points samplePoints); 30 | /* NOTE(review): the class has virtual methods but a non-virtual destructor, and the implicit copy constructor would copy root_ so that both objects delete the same tree. Consider a virtual destructor and deleted copy operations. */ 31 | ~KDTreeBase(); 32 | 33 | void buildKDtree(); 34 | 35 | NodePtr get_root() const { return this->root_; }; 36 | 37 | void init(const _Point &ref); 38 | 39 | virtual _Point max_point() = 0; 40 | 41 | virtual void sample(size_t sample_num) = 0; 42 | /* Hooks: divideTree() calls addNode() for every leaf it creates and stops splitting when leftNode(height, count) is true. 
*/ 43 | protected: 44 | void deleteNode(NodePtr node_p); 45 | virtual void addNode(NodePtr p) = 0; 46 | virtual bool leftNode(size_t high, size_t count) const = 0; 47 | virtual void update_distance(const _Point &ref_point) = 0; 48 | 49 | NodePtr divideTree(ssize_t left, ssize_t right, 50 | const std::array<_Interval, DIM> &bboxs, 51 | size_t curr_high); 52 | 53 | size_t planeSplit(ssize_t left, ssize_t right, size_t split_dim, 54 | T split_val); 55 | 56 | T qSelectMedian(size_t dim, size_t left, size_t right); 57 | static size_t findSplitDim(const std::array<_Interval, DIM> &bboxs); 58 | inline std::array<_Interval, DIM> computeBoundingBox(size_t left, 59 | size_t right); 60 | }; 61 | /* Only stores the buffers; buildKDtree() must run before init(), which dereferences root_. */ 62 | template 63 | KDTreeBase::KDTreeBase(_Points data, size_t pointSize, 64 | _Points samplePoints) 65 | : pointSize(pointSize), sample_points(samplePoints), root_(nullptr), 66 | points_(data) {} 67 | /* Frees every node; the point buffers belong to the caller. */ 68 | template 69 | KDTreeBase::~KDTreeBase() { 70 | if (root_ != nullptr) 71 | deleteNode(root_); 72 | } 73 | /* Deletes node_p after recursively deleting its children. */ 74 | template 75 | void KDTreeBase::deleteNode(NodePtr node_p) { 76 | if (node_p->left) 77 | deleteNode(node_p->left); 78 | if (node_p->right) 79 | deleteNode(node_p->right); 80 | delete node_p; 81 | } 82 | /* Builds the tree over all pointSize points, starting from their overall bounding box. */ 83 | template 84
| void KDTreeBase::buildKDtree() { 85 | size_t left = 0; 86 | size_t right = pointSize; 87 | std::array<_Interval, DIM> bboxs = this->computeBoundingBox(left, right); 88 | this->root_ = divideTree(left, right, bboxs, 0); 89 | } 90 | /* Builds the subtree for points_[left, right) at depth curr_high. A leaf is made when leftNode(curr_high, count) is true; otherwise the range is split on the widest bbox dimension at the mean coordinate (qSelectMedian) and each half gets its own tight bounding box. */ 91 | template 92 | typename KDTreeBase::NodePtr 93 | KDTreeBase::divideTree(ssize_t left, ssize_t right, 94 | const std::array<_Interval, DIM> &bboxs, 95 | size_t curr_high) { 96 | NodePtr node = new KDNode(bboxs); 97 | /* inner nodes keep points == nullptr and an empty point range */ 98 | ssize_t count = right - left; 99 | if (this->leftNode(curr_high, count)) { 100 | node->pointLeft = left; 101 | node->pointRight = right; 102 | node->points = this->points_; 103 | this->addNode(node); 104 | return node; 105 | } else { 106 | size_t split_dim = this->findSplitDim(bboxs); 107 | T split_val = this->qSelectMedian(split_dim, left, right); 108 | /* split_delta = size of the left part (see planeSplit) */ 109 | size_t split_delta = planeSplit(left, right, split_dim, split_val); 110 | 111 | std::array<_Interval, DIM> bbox_cur = 112 | this->computeBoundingBox(left, left + split_delta); 113 | node->left = 114 | this->divideTree(left, left + split_delta, bbox_cur, curr_high + 1); 115 | bbox_cur = this->computeBoundingBox(left + split_delta, right); 116 | node->right = this->divideTree(left + split_delta, right, bbox_cur, 117 | curr_high + 1); 118 | return node; 119 | } 120 | } 121 | /* Partitions points_[left, right) in place so that points with pos[split_dim] < split_val come first, and returns the size of that first part. The result is forced to 1 when the first part is empty and to (right - left - 1) when it covers the whole range, so for two or more points neither child is empty. 
*/ 122 | template 123 | size_t KDTreeBase::planeSplit(ssize_t left, ssize_t right, 124 | size_t split_dim, T split_val) { 125 | ssize_t start = left; 126 | ssize_t end = right - 1; 127 | 128 | for (;;) { 129 | while (start <= end && points_[start].pos[split_dim] < split_val) 130 | ++start; 131 | while (start <= end && points_[end].pos[split_dim] >= split_val) 132 | --end; 133 | 134 | if (start > end) 135 | break; 136 | std::swap(points_[start], points_[end]); 137 | ++start; 138 | --end; 139 | } 140 | 141 | ssize_t lim1 = start - left; 142 | if (start == left) 143 | lim1 = 1; 144 | if (start == right) 145 | lim1 = (right - left - 1); 146 | 147 | return static_cast(lim1); 148 | } 149 | 150
| template 151 | T KDTreeBase::qSelectMedian(size_t dim, size_t left, size_t right) { 152 | /* Despite its name this returns the arithmetic mean (not the median) of pos[dim] over points_[left, right); divideTree() uses it as the split value. The 0.0 seed makes std::accumulate carry a double, but the lambda takes and returns T, so each partial sum is rounded to T (a no-op when T is double). */ T sum = std::accumulate(this->points_ + left, this->points_ + right, 0.0, 153 | [dim](const T &acc, const _Point &point) { 154 | return acc + point.pos[dim]; 155 | }); 156 | return sum / (right - left); 157 | } 158 | /* Index of the dimension with the largest bbox span; the first one wins ties and 0 is returned when no span is positive. */ 159 | template 160 | size_t 161 | KDTreeBase::findSplitDim(const std::array<_Interval, DIM> &bboxs) { 162 | T min_, max_; 163 | T span = 0; 164 | size_t best_dim = 0; 165 | 166 | for (size_t cur_dim = 0; cur_dim < DIM; cur_dim++) { 167 | min_ = bboxs[cur_dim].low; 168 | max_ = bboxs[cur_dim].high; 169 | T cur_span = (max_ - min_); 170 | 171 | if (cur_span > span) { 172 | best_dim = cur_dim; 173 | span = cur_span; 174 | } 175 | } 176 | 177 | return best_dim; 178 | } 179 | /* Per-dimension [min, max] over points_[left, right); an empty range yields inverted intervals (low = max(), high = lowest()). */ 180 | template 181 | inline std::array, DIM> 182 | KDTreeBase::computeBoundingBox(size_t left, size_t right) { 183 | T min_vals[DIM]; 184 | T max_vals[DIM]; 185 | std::fill(min_vals, min_vals + DIM, std::numeric_limits::max()); 186 | std::fill(max_vals, max_vals + DIM, std::numeric_limits::lowest()); 187 | 188 | for (size_t i = left; i < right; ++i) { 189 | const _Point &pos = points_[i]; 190 | 191 | for (size_t cur_dim = 0; cur_dim < DIM; cur_dim++) { 192 | T val = pos[cur_dim]; 193 | min_vals[cur_dim] = std::min(min_vals[cur_dim], val); 194 | max_vals[cur_dim] = std::max(max_vals[cur_dim], val); 195 | } 196 | } 197 | 198 | std::array<_Interval, DIM> bboxs; 199 | 200 | for (size_t cur_dim = 0; 
cur_dim < DIM; cur_dim++) { 201 | bboxs[cur_dim].low = min_vals[cur_dim]; 202 | bboxs[cur_dim].high = max_vals[cur_dim]; 203 | } 204 | 205 | return bboxs; 206 | } 207 | /* Stores ref as the first sample and folds it into every point's distance through the root's init(). */ 208 | template 209 | void KDTreeBase::init(const _Point &ref) { 210 | this->sample_points[0] = ref; 211 | this->root_->init(ref); 212 | } 213 | 214 | } // namespace quickfps 215 | -------------------------------------------------------------------------------- /csrc/cpu/bucket_fps/static/Point.h:
-------------------------------------------------------------------------------- 1 | // 2 | // Created by 韩萌 on 2022/6/14. 3 | // Refactored by AyajiLin on 2023/9/16. 4 | // 5 | 6 | #pragma once 7 | 8 | #include "utils.h" 9 | #include 10 | #include 11 | #include 12 | 13 | namespace quickfps { 14 | /* A point: DIM coordinates, its index in the input (id) and its current squared distance to the nearest reference point applied so far (dis). */ 15 | template class Point { 16 | public: 17 | T pos[DIM]; // coordinates (DIM values) 18 | S dis; /* squared distance to the nearest reference point; reset() sets it to max() */ 19 | size_t id; /* index of the point in the original input */ 20 | 21 | Point(); 22 | Point(const T pos[DIM], size_t id); 23 | Point(const T pos[DIM], size_t id, S dis); 24 | Point(const Point &obj); 25 | ~Point() {}; 26 | /* Orders points by dis. */ 27 | bool operator<(const Point &aii) const; 28 | 29 | constexpr T operator[](size_t i) const { return pos[i]; } 30 | 31 | Point &operator=(const Point &obj) { 32 | std::copy(obj.pos, obj.pos + DIM, this->pos); 33 | this->dis = obj.dis; 34 | this->id = obj.id; 35 | return *this; 36 | } 37 | /* Squared Euclidean distance to b (no square root is taken). NOTE(review): distance() and _distance() are non-const members, so they cannot be called on a const Point. */ 38 | constexpr S distance(const Point &b) { return _distance(b, DIM); } 39 | 40 | void reset(); 41 | /* Lowers dis to distance(ref) when that is smaller; returns the new dis. */ 42 | S updatedistance(const Point &ref); 43 | /* Like updatedistance(), but also increments count when dis decreased. */ 44 | S updateDistanceAndCount(const Point &ref, size_t &count); 45 | /* _distance(b, n): sum of squared coordinate differences over the first n dimensions, computed recursively. */ 46 | private: 47 | constexpr S _distance(const Point &b, size_t dim_left) { 48 | return (dim_left == 0) 49 | ?
0 50 | : powi((this->pos[dim_left - 1] - b.pos[dim_left - 1]), 2) + 51 | _distance(b, dim_left - 1); 52 | } 53 | }; 54 | /* Default: all coordinates 0, id 0, dis = max(). */ 55 | template 56 | Point::Point() : dis(std::numeric_limits::max()), id(0) { 57 | std::fill(pos, pos + DIM, 0); 58 | } 59 | /* Copies DIM coordinates from pos; dis starts at max(). */ 60 | template 61 | Point::Point(const T pos[DIM], size_t id) 62 | : dis(std::numeric_limits::max()), id(id) { 63 | std::copy(pos, pos + DIM, this->pos); 64 | } 65 | /* Copies DIM coordinates from pos with an explicit initial dis. */ 66 | template 67 | Point::Point(const T pos[DIM], size_t id, S dis) : dis(dis), id(id) { 68 | std::copy(pos, pos + DIM, this->pos); 69 | } 70 | 71 | template 72 | Point::Point(const Point &obj) : dis(obj.dis), id(obj.id) { 73 | std::copy(obj.pos, obj.pos + DIM, this->pos); 74 | } 75 | 76 | template 77 | bool Point::operator<(const Point &aii) const { 78 | return dis < aii.dis; 79 | } 80 | 81 | template 82 | S Point::updatedistance(const Point &ref) { 83 | this->dis = std::min(this->dis, this->distance(ref)); 84 | return this->dis; 85 | } 86 | 87 | template 88 | S Point::updateDistanceAndCount(const Point &ref, size_t &count) { 89 | S tempDistance = this->distance(ref); 90 | if (tempDistance < this->dis) { 91 | this->dis = tempDistance; 92 | count++; 93 | } 94 | return this->dis; 95 | } 96 | /* Forgets all reference points: dis goes back to max(). */ 97 | template void Point::reset() { 98 | this->dis = std::numeric_limits::max(); 99 | } 100 | 101 | } // namespace quickfps 102 | -------------------------------------------------------------------------------- /csrc/cpu/bucket_fps/static/utils.h: -------------------------------------------------------------------------------- 1 | // Refactored by AyajiLin on 2023/9/16. 2 | 3 | #pragma once 4 | #include 5 | #include 6 | 7 | namespace quickfps { 8 | using ssize_t = std::make_signed_t; 9 | /* base raised to a non-negative integer power by repeated multiplication (1 for exponent 0). 
*/ 10 | template 11 | inline constexpr T powi(const T base, const size_t exponent) { 12 | // (parentheses not required in next line) 13 | return (exponent == 0) ?
1 : (base * powi(base, exponent - 1)); 14 | } 15 | } // namespace quickfps 16 | -------------------------------------------------------------------------------- /csrc/cpu/bucket_fps/wrapper.cpp: -------------------------------------------------------------------------------- 1 | #include "wrapper.h" 2 | #include "dynamic/KDLineTree.h" 3 | #include "static/KDLineTree.h" 4 | #include 5 | #include 6 | #include 7 | #include 8 | #include 9 | 10 | #ifndef BUCKET_FPS_MAX_DIM 11 | #define BUCKET_FPS_MAX_DIM 8 12 | #endif 13 | constexpr size_t max_dim = BUCKET_FPS_MAX_DIM; /* dims up to this use a compile-time DIM; larger dims use the runtime-dim path */ 14 | 15 | using quickfps::KDLineTree; 16 | using quickfps::Point; 17 | 18 | template 19 | using DynPoint = quickfps::dynamic::Point; 20 | 21 | ////////////////// 22 | // // 23 | // Static // 24 | // // 25 | ////////////////// 26 | /* Copies n_points row-major rows of `dim` floats into Points whose id is the row index. */ 27 | template 28 | std::vector> raw_data_to_points(const float *raw_data, 29 | size_t n_points, size_t dim) { 30 | std::vector> points; 31 | points.reserve(n_points); 32 | for (size_t i = 0; i < n_points; i++) { 33 | const float *ptr = raw_data + i * dim; 34 | points.push_back(Point(ptr, i)); 35 | } 36 | return points; 37 | } 38 | /* Fixed-DIM FPS: builds a KDLineTree with the given height parameter, starts from the point whose id is start_idx and writes the ids of the n_samples samples to sampled_point_indices. start_idx must be < n_points (bucket_fps_kdline checks it); otherwise find_if returns end(), which is then dereferenced. 
*/ 39 | template 40 | void kdline_sample(const float *raw_data, size_t n_points, size_t dim, 41 | size_t n_samples, size_t start_idx, size_t height, 42 | int64_t *sampled_point_indices) { 43 | auto points = raw_data_to_points(raw_data, n_points, dim); 44 | auto sampled_points = std::make_unique[]>(n_samples); 45 | KDLineTree tree(points.data(), n_points, height, 46 | sampled_points.get()); 47 | tree.buildKDtree(); 48 | // NOTE: building the KD-tree reorders `points`, 49 | // so the start point has to be located by its ID 50 | auto start_point = 51 | *std::find_if(points.begin(), points.end(), 52 | [=](auto &p) { return p.id == start_idx; }); 53 | tree.init(start_point); 54 | tree.sample(n_samples); 55 | for (size_t i = 0; i < n_samples; i++) { 56 | sampled_point_indices[i] = sampled_points[i].id; 57 | } 58 | } 59 | 60 | /////////////////// 61 | // // 62 |
// Dynamic // 63 | // // 64 | /////////////////// 65 | 66 | template 67 | std::vector> 68 | raw_data_to_points_varlen(const float *raw_data, size_t n_points, size_t dim) { 69 | std::vector> points; 70 | points.reserve(n_points); 71 | for (size_t i = 0; i < n_points; i++) { 72 | const float *ptr = raw_data + i * dim; 73 | points.push_back(DynPoint(std::vector(ptr, ptr + dim), i)); 74 | } 75 | return points; 76 | } 77 | 78 | template 79 | void kdline_sample_varlen(const float *raw_data, size_t n_points, size_t dim, 80 | size_t n_samples, size_t start_idx, size_t height, 81 | int64_t *sampled_point_indices) { 82 | auto points = raw_data_to_points_varlen(raw_data, n_points, dim); 83 | auto sampled_points = 84 | std::vector>(n_samples, DynPoint(dim)); 85 | quickfps::dynamic::KDLineTree tree(points.data(), n_points, height, 86 | sampled_points.data()); 87 | tree.buildKDtree(); 88 | auto start_point = 89 | *std::find_if(points.begin(), points.end(), 90 | [=](auto &p) { return p.id == start_idx; }); 91 | tree.init(start_point); 92 | tree.sample(n_samples); 93 | for (size_t i = 0; i < n_samples; i++) { 94 | sampled_point_indices[i] = sampled_points[i].id; 95 | } 96 | } 97 | 98 | //////////////////////////////////////// 99 | // // 100 | // Compile Time Function Helper // 101 | // // 102 | //////////////////////////////////////// 103 | using KDLineFuncType = void (*)(const float *, size_t, size_t, size_t, size_t, 104 | size_t, int64_t *); 105 | 106 | template 107 | constexpr std::array mapIndices(M &&m, std::index_sequence) { 108 | std::array result { m.template operator()()... 
}; 109 | return result; 110 | } 111 | 112 | template 113 | constexpr std::array map(M m) { 114 | return mapIndices(m, std::make_index_sequence()); 115 | } 116 | 117 | template struct kdline_func_helper { 118 | template KDLineFuncType operator()() { 119 | return &kdline_sample; 120 | } 121 | }; 122 | 123 | ///////////////// 124 | // // 125 | // API // 126 | // // 127 | ///////////////// 128 | 129 | void bucket_fps_kdline(const float *raw_data, size_t n_points, size_t dim, 130 | size_t n_samples, size_t start_idx, size_t height, 131 | int64_t *sampled_point_indices) { 132 | TORCH_CHECK(dim > 0, "dim should be larger than 0"); 133 | TORCH_CHECK(n_points != 0, "n_points should be larger than 0"); 134 | TORCH_CHECK(n_samples != 0, "n_samples should be larger than 0"); 135 | TORCH_CHECK(height != 0, "height should be larger than 0"); 136 | TORCH_CHECK(start_idx < n_points, 137 | "start_idx should be smaller than n_points"); 138 | if (dim <= max_dim) { 139 | auto func_arr = 140 | map(kdline_func_helper{}); 141 | func_arr[dim - 1](raw_data, n_points, dim, n_samples, start_idx, height, 142 | sampled_point_indices); 143 | } else { 144 | // var dim 145 | kdline_sample_varlen(raw_data, n_points, dim, n_samples, 146 | start_idx, height, sampled_point_indices); 147 | } 148 | } 149 | -------------------------------------------------------------------------------- /csrc/cpu/bucket_fps/wrapper.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | #include 3 | #include 4 | 5 | void bucket_fps_kdline(const float *raw_data, size_t n_points, size_t dim, 6 | size_t n_samples, size_t start_idx, size_t height, 7 | int64_t *sampled_point_indices); 8 | -------------------------------------------------------------------------------- /csrc/cpu/fpsample_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include "../utils.h" 4 | #include "bucket_fps/wrapper.h" 5 | 6 | using 
torch::Tensor; 7 | 8 | /////////////// 9 | // // 10 | // CPU // 11 | // // 12 | /////////////// 13 | 14 | std::tuple sample_cpu(const Tensor &x, int64_t k, 15 | torch::optional h, 16 | torch::optional start_idx) { 17 | TORCH_CHECK(x.device().is_cpu(), "x must be a CPU tensor, but found on ", 18 | x.device()); 19 | TORCH_CHECK(x.dim() >= 2, 20 | "x must have at least 2 dims, but got size: ", x.sizes()); 21 | TORCH_CHECK(k >= 1, "k must be greater than or equal to 1, but got ", k); 22 | auto [old_size, x_reshaped_raw] = bnorm_reshape(x); 23 | auto x_reshaped = x_reshaped_raw.to(torch::kFloat32).contiguous(); 24 | 25 | auto height = h.value_or(5); 26 | 27 | torch::Tensor cur_start_idx; 28 | if (start_idx.has_value()) { 29 | cur_start_idx = torch::ones({x_reshaped.size(0)}, 30 | x_reshaped.options().dtype(torch::kInt64)) * 31 | start_idx.value(); 32 | } else { 33 | cur_start_idx = 34 | torch::randint(0, x_reshaped.size(-2), {x_reshaped.size(0)}, 35 | x_reshaped.options().dtype(torch::kInt64)); 36 | } 37 | 38 | Tensor ret_indices = torch::empty( 39 | {x_reshaped.size(0), k}, x_reshaped.options().dtype(torch::kInt64)); 40 | 41 | if (x_reshaped.size(0) == 1) { 42 | // single batch 43 | bucket_fps_kdline(x_reshaped.const_data_ptr(), 44 | x_reshaped.size(1), x_reshaped.size(2), k, 45 | cur_start_idx.const_data_ptr()[0], height, 46 | ret_indices.mutable_data_ptr()); 47 | } else { 48 | torch::parallel_for( 49 | 0, x_reshaped.size(0), 0, [&](int64_t start, int64_t end) { 50 | for (auto i = start; i < end; i++) { 51 | bucket_fps_kdline( 52 | x_reshaped[i].const_data_ptr(), 53 | x_reshaped.size(1), x_reshaped.size(2), k, 54 | cur_start_idx.const_data_ptr()[i], height, 55 | ret_indices[i].mutable_data_ptr()); 56 | } 57 | }); 58 | } 59 | 60 | Tensor ret_tensor = torch::gather( 61 | x_reshaped_raw, 1, 62 | ret_indices.view({ret_indices.size(0), ret_indices.size(1), 1}) 63 | .repeat({1, 1, x_reshaped.size(2)})); 64 | 65 | // reshape to original size 66 | auto ret_tensor_sizes = 
old_size.vec(); 67 | ret_tensor_sizes[ret_tensor_sizes.size() - 2] = k; 68 | auto ret_indices_sizes = old_size.vec(); 69 | ret_indices_sizes.pop_back(); 70 | ret_indices_sizes[ret_indices_sizes.size() - 1] = k; 71 | 72 | return std::make_tuple( 73 | ret_tensor.view(ret_tensor_sizes), 74 | ret_indices.view(ret_indices_sizes).to(torch::kLong)); 75 | } 76 | 77 | TORCH_LIBRARY_IMPL(torch_fpsample, CPU, m) { m.impl("sample", &sample_cpu); } 78 | -------------------------------------------------------------------------------- /csrc/cuda/fpsample_cuda.cpp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/leonardodalinky/pytorch_fpsample/7bb2fda82d4f726c5c8dcaaa5bcd8f6d2546cea9/csrc/cuda/fpsample_cuda.cpp -------------------------------------------------------------------------------- /csrc/fpsample.cpp: -------------------------------------------------------------------------------- 1 | #if __cplusplus < 201703L 2 | #error "C++17 is required" 3 | #endif 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #define STR_(x) #x 10 | #define STR(x) STR_(x) 11 | 12 | TORCH_LIBRARY(torch_fpsample, m) { 13 | m.def("sample(Tensor self, int k, int? h=None, int? start_idx=None) -> (Tensor, Tensor)"); 14 | } 15 | 16 | PYBIND11_MODULE(_core, m) { 17 | m.attr("CPP_VERSION") = __cplusplus; 18 | m.attr("PYTORCH_VERSION") = STR(TORCH_VERSION_MAJOR) "." STR( 19 | TORCH_VERSION_MINOR) "." STR(TORCH_VERSION_PATCH); 20 | m.attr("PYBIND11_VERSION") = STR(PYBIND11_VERSION_MAJOR) "." STR( 21 | PYBIND11_VERSION_MINOR) "." 
STR(PYBIND11_VERSION_PATCH); 22 | } 23 | -------------------------------------------------------------------------------- /csrc/fpsample_autograd.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | using torch::Tensor; 5 | using FuncType = std::tuple(const Tensor &, int64_t, 6 | torch::optional, 7 | torch::optional); 8 | 9 | //////////////////// 10 | // // 11 | // Autograd // 12 | // // 13 | //////////////////// 14 | using torch::autograd::AutogradContext; 15 | using torch::autograd::Variable; 16 | using torch::autograd::variable_list; 17 | class FPSampleFunction : public torch::autograd::Function { 18 | public: 19 | static variable_list forward(AutogradContext *ctx, const Tensor &x, 20 | int64_t k, torch::optional h, 21 | torch::optional start_idx) { 22 | torch::AutoDispatchBelowADInplaceOrView guard; 23 | static auto op = torch::Dispatcher::singleton() 24 | .findSchemaOrThrow("torch_fpsample::sample", "") 25 | .typed(); 26 | auto results = op.call(x, k, h, start_idx); 27 | auto ret_tensor = std::get<0>(results); 28 | auto ret_indices = std::get<1>(results); 29 | ctx->save_for_backward({ret_indices}); 30 | ctx->saved_data["x_sizes"] = x.sizes(); 31 | return {ret_tensor, ret_indices}; 32 | } 33 | 34 | static variable_list backward(AutogradContext *ctx, 35 | variable_list grad_outputs) { 36 | auto saved_tensors = ctx->get_saved_variables(); 37 | auto ret_indices = saved_tensors[0]; 38 | auto x_sizes = ctx->saved_data["x_sizes"].toIntVector(); 39 | auto grad_output = grad_outputs[0]; 40 | 41 | auto tmp_repeat_sizes = x_sizes; 42 | std::fill(tmp_repeat_sizes.begin(), tmp_repeat_sizes.end() - 1, 1); 43 | 44 | auto grad_input = torch::scatter( 45 | torch::zeros(x_sizes, grad_output.options()), -2, 46 | ret_indices.unsqueeze(-1).repeat(tmp_repeat_sizes), grad_output); 47 | 48 | return {grad_input, Variable(), Variable(), Variable(), Variable()}; 49 | } 50 | }; 51 | 52 | std::tuple sample_autograd(const 
Tensor &x, int64_t k, 53 | torch::optional h, 54 | torch::optional start_idx) { 55 | auto results = FPSampleFunction::apply(x, k, h, start_idx); 56 | return std::make_tuple(results[0], results[1]); 57 | } 58 | 59 | TORCH_LIBRARY_IMPL(torch_fpsample, Autograd, m) { 60 | m.impl("sample", &sample_autograd); 61 | } 62 | -------------------------------------------------------------------------------- /csrc/fpsample_meta.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | using torch::Tensor; 5 | 6 | std::tuple sample_meta(const Tensor &x, int64_t k, 7 | torch::optional h, 8 | torch::optional start_idx) { 9 | TORCH_CHECK(x.dim() >= 2, 10 | "x must have at least 2 dims, but got size: ", x.sizes()); 11 | TORCH_CHECK(k >= 1, "k must be greater than or equal to 1, but got ", k); 12 | auto tmp_s1 = x.sizes().vec(); 13 | tmp_s1[tmp_s1.size() - 2] = k; 14 | auto tmp_s2 = x.sizes().vec(); 15 | tmp_s2.pop_back(); 16 | tmp_s2[tmp_s2.size() - 1] = k; 17 | return std::make_tuple( 18 | torch::empty(tmp_s1, x.options()), 19 | torch::empty(tmp_s2, x.options().dtype(torch::kLong))); 20 | } 21 | 22 | TORCH_LIBRARY_IMPL(torch_fpsample, Meta, m) { m.impl("sample", &sample_meta); } 23 | -------------------------------------------------------------------------------- /csrc/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | inline std::tuple 6 | bnorm_reshape(const torch::Tensor &t) { 7 | if (t.dim() > 2) { 8 | // reshape to (..., rows, cols) 9 | return {t.sizes(), t.view({-1, t.size(-2), t.size(-1)})}; 10 | } else if (t.dim() == 2) { 11 | // reshape to (1, rows, cols) 12 | return {t.sizes(), t.view({1, t.size(0), t.size(1)})}; 13 | } else { 14 | TORCH_CHECK(false, 15 | "x must have at least 2 dims, but got size: ", t.sizes()); 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /setup.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import platform 3 | import sys 4 | 5 | from setuptools import find_packages, setup 6 | from torch.__config__ import parallel_info 7 | from torch.utils import cpp_extension 8 | 9 | __version__ = "0.1.0" 10 | # WITH_CUDA = os.getenv("WITH_CUDA", "1") == "1" 11 | WITH_CUDA = False 12 | 13 | sources = [ 14 | "csrc/fpsample.cpp", 15 | "csrc/fpsample_autograd.cpp", 16 | "csrc/fpsample_meta.cpp", 17 | "csrc/cpu/fpsample_cpu.cpp", 18 | "csrc/cpu/bucket_fps/wrapper.cpp", 19 | ] 20 | extra_compile_args = {"cxx": ["-O3"]} 21 | extra_link_args = [] 22 | 23 | # OpenMP 24 | info = parallel_info() 25 | if "backend: OpenMP" in info and "OpenMP not found" not in info and sys.platform != "darwin": 26 | extra_compile_args["cxx"] += ["-DAT_PARALLEL_OPENMP"] 27 | if sys.platform == "win32": 28 | extra_compile_args["cxx"] += ["/openmp"] 29 | else: 30 | extra_compile_args["cxx"] += ["-fopenmp"] 31 | else: 32 | print("Compiling without OpenMP...") 33 | 34 | # Compile for mac arm64 35 | if sys.platform == "darwin": 36 | extra_compile_args["cxx"] += ["-D_LIBCPP_DISABLE_AVAILABILITY"] 37 | if platform.machine() == "arm64": 38 | extra_compile_args["cxx"] += ["-arch", "arm64"] 39 | extra_link_args += ["-arch", "arm64"] 40 | 41 | 42 | if WITH_CUDA: 43 | # TODO 44 | raise NotImplementedError("CUDA is not supported yet.") 45 | sources += [] 46 | ext_modules = [ 47 | cpp_extension.CUDAExtension( 48 | name="torch_fpsample._core", 49 | include_dirs=["csrc"], 50 | sources=sources, 51 | extra_compile_args=extra_compile_args, 52 | extra_link_args=extra_link_args, 53 | ) 54 | ] 55 | else: 56 | ext_modules = [ 57 | cpp_extension.CppExtension( 58 | name="torch_fpsample._core", 59 | include_dirs=["csrc"], 60 | sources=sources, 61 | extra_compile_args=extra_compile_args, 62 | extra_link_args=extra_link_args, 63 | ) 64 | ] 65 | 66 | 67 | setup( 68 | name="torch_fpsample", 69 | version=__version__, 70 | 
author="Leonard Lin", 71 | author_email="leonard.keilin@gmail.com", 72 | description="PyTorch implementation of fpsample.", 73 | ext_modules=ext_modules, 74 | keywords=["pytorch", "farthest", "furthest", "sampling", "sample", "fps"], 75 | packages=find_packages(), 76 | package_data={"": ["*.pyi"]}, 77 | cmdclass={"build_ext": cpp_extension.BuildExtension}, 78 | python_requires=">=3.8", 79 | install_requires=["torch>=2.0"], 80 | ) 81 | -------------------------------------------------------------------------------- /torch_fpsample/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib.machinery 2 | import os.path as osp 3 | 4 | import torch 5 | 6 | from .fps import sample 7 | 8 | torch.ops.load_library( 9 | importlib.machinery.PathFinder().find_spec(f"_core", [osp.dirname(__file__)]).origin 10 | ) 11 | -------------------------------------------------------------------------------- /torch_fpsample/fps.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | 5 | 6 | def sample( 7 | x: torch.Tensor, 8 | k: int, 9 | h: Optional[int] = None, 10 | start_idx: Optional[int] = None, 11 | ) -> Tuple[torch.Tensor, torch.LongTensor]: 12 | """Farthest Point Sampling (FPS) algorithm. 13 | 14 | Args: 15 | x (torch.Tensor): (*, N, D) input points tensor. 16 | k (int): Number of points to sample. 17 | h (int, optional): Maximum height for the bucket sampling. Defaults to 5. 18 | See https://github.com/leonardodalinky/fpsample#usage for details. 19 | start_idx (int, optional): Index of the point to start sampling from. Defaults to None. 20 | backend (str, optional): Backend to use for sampling. Defaults to "bucket". 21 | Available options are: `bucket`, `naive`. 22 | 23 | Returns: 24 | (torch.Tensor, torch.LongTensor): (Batched) sampled points tensor and (batched) indices of the sampled points. 
25 | """ 26 | return torch.ops.torch_fpsample.sample(x, k, h, start_idx) 27 | --------------------------------------------------------------------------------