├── .github
│   └── workflows
│       ├── build_wheels.yml
│       └── cuda
│           ├── Linux-env.sh
│           ├── Linux.sh
│           ├── Windows-env.sh
│           ├── Windows.sh
│           ├── macOS-env.sh
│           └── macOS.sh
├── .gitignore
├── LICENSE.txt
├── MANIFEST.in
├── README.md
├── pyproject.toml
├── requirements.txt
├── setup.cfg
├── setup.py
├── tests
│   └── test_ops.py
├── tools
│   └── packaging
│       └── audit_torch_extension.py
└── torchpairwise
    ├── __init__.py
    ├── _ops.py
    ├── _version.py
    ├── csrc
    │   ├── macros.h
    │   ├── ops
    │   │   ├── additive_chi2_kernel.cpp
    │   │   ├── additive_chi2_kernel.h
    │   │   ├── autograd
    │   │   │   ├── additive_chi2_kernel.cpp
    │   │   │   ├── braycurtis_kernel.cpp
    │   │   │   ├── canberra_kernel.cpp
    │   │   │   ├── hausdorff_kernel.cpp
    │   │   │   ├── haversine_kernel.cpp
    │   │   │   ├── ppminkowski_kernel.cpp
    │   │   │   ├── prf_div_kernel.cpp
    │   │   │   ├── snr_kernel.cpp
    │   │   │   ├── sqjensenshannon_kernel.cpp
    │   │   │   ├── sqmahalanobis_kernel.cpp
    │   │   │   └── wminkowski_kernel.cpp
    │   │   ├── braycurtis.cpp
    │   │   ├── braycurtis.h
    │   │   ├── canberra.cpp
    │   │   ├── canberra.h
    │   │   ├── common
    │   │   │   ├── binary_ops.h
    │   │   │   ├── prf_div_mode.h
    │   │   │   └── reduction_ops.h
    │   │   ├── cpdist.cpp
    │   │   ├── cpdist.h
    │   │   ├── cpu
    │   │   │   ├── additive_chi2_kernel.cpp
    │   │   │   ├── binary_ops.h
    │   │   │   ├── braycurtis_kernel.cpp
    │   │   │   ├── canberra_kernel.cpp
    │   │   │   ├── cpu_helpers.h
    │   │   │   ├── hausdorff_kernel.cpp
    │   │   │   ├── haversine_kernel.cpp
    │   │   │   ├── pairwise_binary_kernels.cpp
    │   │   │   ├── ppminkowski_kernel.cpp
    │   │   │   ├── prf_div_kernel.cpp
    │   │   │   ├── prf_divide.h
    │   │   │   ├── reduction_ops.h
    │   │   │   ├── rel_entr.h
    │   │   │   ├── signum.h
    │   │   │   ├── snr_kernel.cpp
    │   │   │   ├── sqjensenshannon_kernel.cpp
    │   │   │   ├── sqmahalanobis_kernel.cpp
    │   │   │   └── wminkowski_kernel.cpp
    │   │   ├── cuda
    │   │   │   ├── additive_chi2_kernel.cu
    │   │   │   ├── binary_ops.cuh
    │   │   │   ├── braycurtis_kernel.cu
    │   │   │   ├── canberra_kernel.cu
    │   │   │   ├── cuda_helpers.h
    │   │   │   ├── hausdorff_kernel.cu
    │   │   │   ├── haversine_kernel.cu
    │   │   │   ├── pairwise_binary_kernels.cu
    │   │   │   ├── ppminkowski_kernel.cu
    │   │   │   ├── prf_div_kernel.cu
    │   │   │   ├── prf_divide.cuh
    │   │   │   ├── reduction_ops.cuh
    │   │   │   ├── rel_entr.cuh
    │   │   │   ├── signum.cuh
    │   │   │   ├── snr_kernel.cu
    │   │   │   ├── sqjensenshannon_kernel.cu
    │   │   │   ├── sqmahalanobis_kernel.cu
    │   │   │   └── wminkowski_kernel.cu
    │   │   ├── hausdorff.cpp
    │   │   ├── hausdorff.h
    │   │   ├── haversine.cpp
    │   │   ├── haversine.h
    │   │   ├── neighbors.cpp
    │   │   ├── neighbors.h
    │   │   ├── ops.h
    │   │   ├── pairwise_binary.cpp
    │   │   ├── pairwise_binary.h
    │   │   ├── pairwise_metrics.cpp
    │   │   ├── pairwise_metrics.h
    │   │   ├── ppminkowski.cpp
    │   │   ├── ppminkowski.h
    │   │   ├── prf_div.cpp
    │   │   ├── prf_div.h
    │   │   ├── snr.cpp
    │   │   ├── snr.h
    │   │   ├── sqjensenshannon.cpp
    │   │   ├── sqjensenshannon.h
    │   │   ├── sqmahalanobis.cpp
    │   │   ├── sqmahalanobis.h
    │   │   ├── utils
    │   │   │   ├── dispatch.h
    │   │   │   └── scalar_type_utils.h
    │   │   ├── wminkowski.cpp
    │   │   └── wminkowski.h
    │   ├── torchpairwise.cpp
    │   └── torchpairwise.h
    ├── extension.py
    └── ops
        ├── __init__.py
        ├── cpdist.py
        ├── cpdist.pyi
        └── metrics
            ├── __init__.py
            ├── neighbors.py
            ├── neighbors.pyi
            ├── pairwise.py
            └── pairwise.pyi

--------------------------------------------------------------------------------
/.github/workflows/build_wheels.yml:
--------------------------------------------------------------------------------
name: Build

on:
  push:
    tags:
      - 'v*'

jobs:
  build_sdist:
    name: Source distribution
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Build sdist
        run: pipx run build --sdist

      - name: Upload sdist
        uses: actions/upload-artifact@v4
        with:
          name: artifacts-sdist
          path: dist/*.tar.gz

  build_wheels:
    name: Wheel on ${{ matrix.os }} for cp-${{ matrix.python-version }}
    runs-on: ${{ matrix.os }}
    strategy:
      matrix:
        os: [ ubuntu-22.04, macos-14, windows-2019 ]
        python-version: [ '3.9', '3.10', '3.11', '3.12', '3.13' ]
        torch-version: [ 2.7.0 ]
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}

      - name: Free Disk Space (Ubuntu)
        if: ${{ runner.os == 'Linux' }}
        uses: jlumbroso/free-disk-space@main

      - name: Install ninja-build
        uses: seanmiddleditch/gha-setup-ninja@master

      - name: Install CUDA
        run: |
          bash .github/workflows/cuda/${{ runner.os }}.sh

      - name: Upgrade build tools
        run: |
          pip install --upgrade setuptools wheel

      - name: Install PyTorch ${{ matrix.torch-version }}
        run: |
          pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/${{ runner.os == 'macOS' && 'cpu' || 'cu128' }}
          python -c "import torch; print('PyTorch:', torch.__version__); print('CUDA:', torch.version.cuda)"

      - name: Build wheel
        run: |
          source .github/workflows/cuda/${{ runner.os }}-env.sh
          python setup.py bdist_wheel --dist-dir=wheelhouse
        shell:
          bash

      - name: Repair wheel (Ubuntu)
        if: ${{ runner.os == 'Linux' }}
        run: |
          mkdir dist
          mv wheelhouse/*.whl dist/
          pip install auditwheel patchelf
          python tools/packaging/audit_torch_extension.py repair --plat manylinux_2_34_x86_64 -w wheelhouse dist/*.whl
          rm -r dist

      - name: Upload wheels
        uses: actions/upload-artifact@v4
        with:
          name: artifacts-bdist-${{ matrix.os }}-cp${{ matrix.python-version }}
          path: ./wheelhouse/*.whl

  publish_on_pypi:
    if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
    needs: [ build_sdist, build_wheels ]
    runs-on: ubuntu-latest
    environment: pypi
    permissions:
      id-token: write
    steps:
      - name: Download wheels
        uses: actions/download-artifact@v4
        with:
          pattern: artifacts-*
          path: dist
          merge-multiple: true

      - name: Publish wheels to PyPI
        uses: pypa/gh-action-pypi-publish@v1.5.0
        with:
          user: __token__
          password: ${{ secrets.pypi_password }}

--------------------------------------------------------------------------------
/.github/workflows/cuda/Linux-env.sh:
--------------------------------------------------------------------------------
#! /bin/bash

export FORCE_CUDA=1
export TORCH_CUDA_ARCH_LIST='5.0 6.0 6.1 7.0 7.5 8.0 8.6 9.0 10.0 12.0+PTX'

--------------------------------------------------------------------------------
/.github/workflows/cuda/Linux.sh:
--------------------------------------------------------------------------------
#! /bin/bash

set -e
set -x

# Install CUDA 12.8:
echo "Installing CUDA 12.8"

OS=ubuntu2404

wget https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb

sudo apt-get update
sudo apt-get -y install cuda-nvcc-12-8 cuda-cudart-dev-12-8 libcublas-dev-12-8 libcurand-dev-12-8 libcusolver-dev-12-8 libcusparse-dev-12-8
sudo apt clean

export PATH=/usr/local/cuda/bin:$PATH
export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
export CUDA_HOME="/usr/local/cuda"

# test
nvcc -V

--------------------------------------------------------------------------------
/.github/workflows/cuda/Windows-env.sh:
--------------------------------------------------------------------------------
#! /bin/bash

export FORCE_CUDA=1
export TORCH_CUDA_ARCH_LIST='5.0 6.0 6.1 7.0 7.5 8.0 8.6 9.0 10.0 12.0+PTX'

--------------------------------------------------------------------------------
/.github/workflows/cuda/Windows.sh:
--------------------------------------------------------------------------------
#! /bin/bash

set -e
set -x

# Install CUDA 12.8 (requires thrust):
echo "Installing CUDA 12.8"
curl -L -o cuda.exe https://developer.download.nvidia.com/compute/cuda/12.8.1/network_installers/cuda_12.8.1_windows_network.exe
./cuda.exe -s nvcc_12.8 cudart_12.8 thrust_12.8 cublas_dev_12.8 curand_dev_12.8 cusolver_dev_12.8 cusparse_dev_12.8
rm cuda.exe

export PATH="/c/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v12.8/bin":$PATH
export CUDA_HOME="C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v12.8"

# test
nvcc -V

--------------------------------------------------------------------------------
/.github/workflows/cuda/macOS-env.sh:
--------------------------------------------------------------------------------
#! /bin/bash

export FORCE_ONLY_CPU=1

--------------------------------------------------------------------------------
/.github/workflows/cuda/macOS.sh:
--------------------------------------------------------------------------------
#! /bin/bash

set -e
set -x

echo "CUDA is not available on macOS"

echo "Installing libomp"
brew install llvm libomp
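The env scripts above are sourced by the "Build wheel" step before `python setup.py bdist_wheel` runs. To mirror that CI environment in a local build, the same variables can be set first; a minimal sketch (hypothetical, not a file in this repo — FORCE_CUDA is read by setup.py below, TORCH_CUDA_ARCH_LIST by torch.utils.cpp_extension):

```python
import os
import subprocess

# Mirror .github/workflows/cuda/Linux-env.sh for a local wheel build.
os.environ['FORCE_CUDA'] = '1'                  # setup.py uses this to force CUDAExtension
os.environ['TORCH_CUDA_ARCH_LIST'] = '8.6+PTX'  # consumed by torch.utils.cpp_extension; narrowed to one arch for speed
subprocess.run(['python', 'setup.py', 'bdist_wheel', '--dist-dir=wheelhouse'], check=True)
```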
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | ## Ignore Visual Studio temporary files, build results, and 2 | ## files generated by popular Visual Studio add-ons. 
3 | ## 4 | ## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore 5 | 6 | # User-specific files 7 | *.rsuser 8 | *.suo 9 | *.user 10 | *.userosscache 11 | *.sln.docstates 12 | 13 | # User-specific files (MonoDevelop/Xamarin Studio) 14 | *.userprefs 15 | 16 | # Build results 17 | [Dd]ebug/ 18 | [Dd]ebugPublic/ 19 | [Rr]elease/ 20 | [Rr]eleases/ 21 | x64/ 22 | x86/ 23 | [Aa][Rr][Mm]/ 24 | [Aa][Rr][Mm]64/ 25 | bld/ 26 | [Bb]in/ 27 | [Oo]bj/ 28 | [Ll]og/ 29 | 30 | # Visual Studio 2015/2017 cache/options directory 31 | .vs/ 32 | # Uncomment if you have tasks that create the project's static files in wwwroot 33 | #wwwroot/ 34 | 35 | # Visual Studio 2017 auto generated files 36 | Generated\ Files/ 37 | 38 | # MSTest tests Results 39 | [Tt]est[Rr]esult*/ 40 | [Bb]uild[Ll]og.* 41 | 42 | # NUNIT 43 | *.VisualState.xml 44 | TestResult.xml 45 | 46 | # Build Results of an ATL Project 47 | [Dd]ebugPS/ 48 | [Rr]eleasePS/ 49 | dlldata.c 50 | 51 | # Benchmark Results 52 | BenchmarkDotNet.Artifacts/ 53 | 54 | # .NET Core 55 | project.lock.json 56 | project.fragment.lock.json 57 | artifacts/ 58 | 59 | # StyleCop 60 | StyleCopReport.xml 61 | 62 | # Files built by Visual Studio 63 | *_i.c 64 | *_p.c 65 | *_h.h 66 | *.ilk 67 | *.meta 68 | *.obj 69 | *.iobj 70 | *.pch 71 | *.pdb 72 | *.ipdb 73 | *.pgc 74 | *.pgd 75 | *.rsp 76 | *.sbr 77 | *.tlb 78 | *.tli 79 | *.tlh 80 | *.tmp 81 | *.tmp_proj 82 | *_wpftmp.csproj 83 | *.log 84 | *.vspscc 85 | *.vssscc 86 | .builds 87 | *.pidb 88 | *.svclog 89 | *.scc 90 | 91 | # Chutzpah Test files 92 | _Chutzpah* 93 | 94 | # Visual C++ cache files 95 | ipch/ 96 | *.aps 97 | *.ncb 98 | *.opendb 99 | *.opensdf 100 | *.sdf 101 | *.cachefile 102 | *.VC.db 103 | *.VC.VC.opendb 104 | 105 | # Visual Studio profiler 106 | *.psess 107 | *.vsp 108 | *.vspx 109 | *.sap 110 | 111 | # Visual Studio Trace Files 112 | *.e2e 113 | 114 | # TFS 2012 Local Workspace 115 | $tf/ 116 | 117 | # Guidance Automation Toolkit 118 | *.gpState 119 | 120 | # ReSharper is a .NET coding add-in 121 | _ReSharper*/ 122 | *.[Rr]e[Ss]harper 123 | *.DotSettings.user 124 | 125 | # JustCode is a .NET coding add-in 126 | .JustCode 127 | 128 | # TeamCity is a build add-in 129 | _TeamCity* 130 | 131 | # DotCover is a Code Coverage Tool 132 | *.dotCover 133 | 134 | # AxoCover is a Code Coverage Tool 135 | .axoCover/* 136 | !.axoCover/settings.json 137 | 138 | # Visual Studio code coverage results 139 | *.coverage 140 | *.coveragexml 141 | 142 | # NCrunch 143 | _NCrunch_* 144 | .*crunch*.local.xml 145 | nCrunchTemp_* 146 | 147 | # MightyMoose 148 | *.mm.* 149 | AutoTest.Net/ 150 | 151 | # Web workbench (sass) 152 | .sass-cache/ 153 | 154 | # Installshield output folder 155 | [Ee]xpress/ 156 | 157 | # DocProject is a documentation generator add-in 158 | DocProject/buildhelp/ 159 | DocProject/Help/*.HxT 160 | DocProject/Help/*.HxC 161 | DocProject/Help/*.hhc 162 | DocProject/Help/*.hhk 163 | DocProject/Help/*.hhp 164 | DocProject/Help/Html2 165 | DocProject/Help/html 166 | 167 | # Click-Once directory 168 | publish/ 169 | 170 | # Publish Web Output 171 | *.[Pp]ublish.xml 172 | *.azurePubxml 173 | # Note: Comment the next line if you want to checkin your web deploy settings, 174 | # but database connection strings (with potential passwords) will be unencrypted 175 | *.pubxml 176 | *.publishproj 177 | 178 | # Microsoft Azure Web App publish settings. 
Comment the next line if you want to 179 | # checkin your Azure Web App publish settings, but sensitive information contained 180 | # in these scripts will be unencrypted 181 | PublishScripts/ 182 | 183 | # NuGet Packages 184 | *.nupkg 185 | # The packages folder can be ignored because of Package Restore 186 | **/[Pp]ackages/* 187 | # except build/, which is used as an MSBuild target. 188 | !**/[Pp]ackages/build/ 189 | # Uncomment if necessary however generally it will be regenerated when needed 190 | #!**/[Pp]ackages/repositories.config 191 | # NuGet v3's project.json files produces more ignorable files 192 | *.nuget.props 193 | *.nuget.targets 194 | 195 | # Microsoft Azure Build Output 196 | csx/ 197 | *.build.csdef 198 | 199 | # Microsoft Azure Emulator 200 | ecf/ 201 | rcf/ 202 | 203 | # Windows Store app package directories and files 204 | AppPackages/ 205 | BundleArtifacts/ 206 | Package.StoreAssociation.xml 207 | _pkginfo.txt 208 | *.appx 209 | 210 | # Visual Studio cache files 211 | # files ending in .cache can be ignored 212 | *.[Cc]ache 213 | # but keep track of directories ending in .cache 214 | !?*.[Cc]ache/ 215 | 216 | # Others 217 | ClientBin/ 218 | ~$* 219 | *~ 220 | *.dbmdl 221 | *.dbproj.schemaview 222 | *.jfm 223 | *.pfx 224 | *.publishsettings 225 | orleans.codegen.cs 226 | 227 | # Including strong name files can present a security risk 228 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 229 | #*.snk 230 | 231 | # Since there are multiple workflows, uncomment next line to ignore bower_components 232 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 233 | #bower_components/ 234 | 235 | # RIA/Silverlight projects 236 | Generated_Code/ 237 | 238 | # Backup & report files from converting an old project file 239 | # to a newer Visual Studio version. Backup files are not needed, 240 | # because we have git ;-) 241 | _UpgradeReport_Files/ 242 | Backup*/ 243 | UpgradeLog*.XML 244 | UpgradeLog*.htm 245 | ServiceFabricBackup/ 246 | *.rptproj.bak 247 | 248 | # SQL Server files 249 | *.mdf 250 | *.ldf 251 | *.ndf 252 | 253 | # Business Intelligence projects 254 | *.rdl.data 255 | *.bim.layout 256 | *.bim_*.settings 257 | *.rptproj.rsuser 258 | *- Backup*.rdl 259 | 260 | # Microsoft Fakes 261 | FakesAssemblies/ 262 | 263 | # GhostDoc plugin setting file 264 | *.GhostDoc.xml 265 | 266 | # Node.js Tools for Visual Studio 267 | .ntvs_analysis.dat 268 | node_modules/ 269 | 270 | # Visual Studio 6 build log 271 | *.plg 272 | 273 | # Visual Studio 6 workspace options file 274 | *.opt 275 | 276 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 
277 | *.vbw 278 | 279 | # Visual Studio LightSwitch build output 280 | **/*.HTMLClient/GeneratedArtifacts 281 | **/*.DesktopClient/GeneratedArtifacts 282 | **/*.DesktopClient/ModelManifest.xml 283 | **/*.Server/GeneratedArtifacts 284 | **/*.Server/ModelManifest.xml 285 | _Pvt_Extensions 286 | 287 | # Paket dependency manager 288 | .paket/paket.exe 289 | paket-files/ 290 | 291 | # FAKE - F# Make 292 | .fake/ 293 | 294 | # JetBrains Rider 295 | .idea/ 296 | *.sln.iml 297 | 298 | # CodeRush personal settings 299 | .cr/personal 300 | 301 | # Python Tools for Visual Studio (PTVS) 302 | __pycache__/ 303 | *.pyc 304 | *.pyd 305 | *.so 306 | 307 | # Cake - Uncomment if you are using it 308 | # tools/** 309 | # !tools/packages.config 310 | 311 | # Tabs Studio 312 | *.tss 313 | 314 | # Telerik's JustMock configuration file 315 | *.jmconfig 316 | 317 | # BizTalk build output 318 | *.btp.cs 319 | *.btm.cs 320 | *.odx.cs 321 | *.xsd.cs 322 | 323 | # OpenCover UI analysis results 324 | OpenCover/ 325 | 326 | # Azure Stream Analytics local run output 327 | ASALocalRun/ 328 | 329 | # MSBuild Binary and Structured Log 330 | *.binlog 331 | 332 | # NVidia Nsight GPU debugger configuration file 333 | *.nvuser 334 | 335 | # MFractors (Xamarin productivity tool) working folder 336 | .mfractor/ 337 | 338 | # Local History for Visual Studio 339 | .localhistory/ 340 | 341 | # BeatPulse healthcheck temp database 342 | healthchecksdb 343 | 344 | # Cmake 345 | cmake-build-*/ 346 | CMakeLists.txt 347 | 348 | # Python build 349 | build/ 350 | dist/ 351 | wheelhouse/ 352 | *.egg-info/ 353 | 354 | # secret 355 | develop/ 356 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Hoang-Nhat Tran 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
recursive-include torchpairwise *.cpp
recursive-include torchpairwise *.cu
include README.md
include LICENSE.txt
include requirements*.txt

recursive-exclude * __pycache__
recursive-exclude * *.py[co]

--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
[build-system]
requires = [
    "setuptools>=42",
    "wheel",
    # "--find-links https://download.pytorch.org/whl/cu128; sys_platform == \"linux\" or sys_platform == \"win32\"",
    "torch>=2.7.0,<2.8.0",
]

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch>=2.7.0,<2.8.0

--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
[metadata]
name = torchpairwise
description = Pairwise Metrics for PyTorch
long_description = file: README.md
long_description_content_type = text/markdown
license = MIT
license_files = LICENSE.txt
author = Hoang-Nhat Tran (inspiros)
author_email = hnhat.tran@gmail.com
url = https://github.com/inspiros/torchpairwise
download_url = https://test.pypi.org/project/torchpairwise
project_urls =
    Source = https://github.com/inspiros/torchpairwise
classifiers =
    Development Status :: 4 - Beta
    Environment :: GPU :: NVIDIA CUDA :: 12
    Intended Audience :: Developers
    Intended Audience :: Education
    Intended Audience :: Science/Research
    License :: OSI Approved :: MIT License
    Programming Language :: Python :: 3
    Programming Language :: Python :: 3 :: Only
    Programming Language :: Python :: 3.9
    Programming Language :: Python :: 3.10
    Programming Language :: Python :: 3.11
    Programming Language :: Python :: 3.12
    Programming Language :: Python :: 3.13
    Operating System :: OS Independent
    Topic :: Scientific/Engineering :: Mathematics
    Topic :: Scientific/Engineering :: Artificial Intelligence
keywords = pairwise_metric,pairwise_distance,kernel_function

[options]
zip_safe = False
include_package_data = True
packages = find:
python_requires = >=3.9
setup_requires = torch>=2.7.0,<2.8.0
install_requires = torch>=2.7.0,<2.8.0

[options.extras_require]
examples = scipy; scikit-learn; tqdm
tests = scipy; scikit-learn; tqdm

[options.packages.find]
exclude =
    examples*
    tools*
    docs*
    tests*
    resources*

[options.package_data]
* = *.h, *.hpp, *.cuh, *.c, *.cpp, *.cu, *.
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
import glob
import os
import sys

import torch
from setuptools import setup
from torch.utils.cpp_extension import CUDAExtension
from torch.utils.cpp_extension import CUDA_HOME
from torch.utils.cpp_extension import CppExtension

PACKAGE_ROOT = 'torchpairwise'


def get_version(version_file='_version.py'):
    import importlib.util
    version_file_path = os.path.abspath(os.path.join(PACKAGE_ROOT, version_file))
    try:
        spec = importlib.util.spec_from_file_location('_version', version_file_path)
        version_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(version_module)
        return str(version_module.__version__)
    except Exception:
        return '0.0.0'


def get_parallel_options(backend=None):
    parallel_extra_compile_args = []
    parallel_define_macros = []

    if backend is not None:
        if backend.upper() not in ['OPENMP', 'NATIVE', 'NATIVE_TBB']:
            raise ValueError('Parallel backend options are OPENMP, NATIVE, or NATIVE_TBB. '
                             f'Got unknown backend {backend}.')
    else:  # detect torch parallel backend
        parallel_info_string = torch.__config__.parallel_info()
        parallel_info_array = parallel_info_string.splitlines()
        backend_lines = [line for line in parallel_info_array
                         if line.startswith('ATen parallel backend:')]
        if len(backend_lines):
            backend = backend_lines[0].rsplit(': ')[1]

    backend = backend.lower() if backend is not None else ''
    if backend == 'openmp':
        parallel_define_macros += [('AT_PARALLEL_OPENMP', None)]
        if sys.platform == 'darwin':
            parallel_extra_compile_args.append('-Xpreprocessor')
        parallel_extra_compile_args.append('/openmp' if sys.platform == 'win32' else '-fopenmp')
        if sys.platform == 'darwin':
            parallel_extra_compile_args.append('-lomp')
    elif backend.startswith('native'):
        if backend.endswith('tbb'):
            parallel_define_macros += [('AT_PARALLEL_NATIVE_TBB', None)]
        else:
            parallel_define_macros += [('AT_PARALLEL_NATIVE', None)]
    return parallel_extra_compile_args, parallel_define_macros
def get_extensions():
    extensions_dir = os.path.join(PACKAGE_ROOT, 'csrc')

    main_file = (glob.glob(os.path.join(extensions_dir, '*.cpp')) +
                 glob.glob(os.path.join(extensions_dir, 'ops', '*.cpp'))
                 )

    source_cpu = (glob.glob(os.path.join(extensions_dir, 'ops', 'cpu', '*.cpp')) +
                  glob.glob(os.path.join(extensions_dir, 'ops', 'dispatch', '*.cpp')) +
                  glob.glob(os.path.join(extensions_dir, 'ops', 'autograd', '*.cpp'))
                  )

    source_cuda = glob.glob(os.path.join(extensions_dir, 'ops', 'cuda', '*.cu'))
    source_cuda += glob.glob(os.path.join(extensions_dir, 'ops', 'autocast', '*.cpp'))

    sources = main_file + source_cpu
    extension = CppExtension
    extra_compile_args = {'cxx': []}
    extra_compile_args['cxx'].append('/std:c++17' if sys.platform == 'win32' else '-std=c++17')
    define_macros = []

    print('Compiling extensions with following flags:')
    force_cuda = os.getenv('FORCE_CUDA', '0') == '1'
    print(f'  FORCE_CUDA: {force_cuda}')
    debug_mode = os.getenv('DEBUG', '0') == '1'
    print(f'  DEBUG: {debug_mode}')

    nvcc_flags = os.getenv('NVCC_FLAGS', '')
    print(f'  NVCC_FLAGS: {nvcc_flags}')

    # enable cpu parallel
    parallel_extra_compile_args, parallel_define_macros = get_parallel_options('openmp')
    extra_compile_args['cxx'] += parallel_extra_compile_args
    define_macros += parallel_define_macros

    # enable cuda
    if (torch.cuda.is_available() and CUDA_HOME is not None) or force_cuda:
        extension = CUDAExtension
        sources += source_cuda
        define_macros += [('WITH_CUDA', None)]
        if nvcc_flags == '':
            nvcc_flags = []
        else:
            nvcc_flags = nvcc_flags.split(' ')
        nvcc_flags.append('-std=c++17')
        extra_compile_args['nvcc'] = nvcc_flags

    if sys.platform == 'win32':
        define_macros += [(f'{PACKAGE_ROOT}_EXPORTS', None)]
        define_macros += [('USE_PYTHON', None)]
        extra_compile_args['cxx'].append('/MP')

    if debug_mode:
        print('Compiling in debug mode')
        extra_compile_args['cxx'].append('-g')
        extra_compile_args['cxx'].append('-O0')
        if 'nvcc' in extra_compile_args:
            # we have to remove any existing '-O*' and '-g' flags before appending the debug ones
            nvcc_flags = extra_compile_args['nvcc']
            extra_compile_args['nvcc'] = [f for f in nvcc_flags if not ('-O' in f or '-g' in f)]
            extra_compile_args['nvcc'].append('-O0')
            extra_compile_args['nvcc'].append('-g')

    include_dirs = [extensions_dir]
    ext_modules = [
        extension(
            f'{PACKAGE_ROOT}._C',
            sources=sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
        )
    ]
    return ext_modules


def setup_package():
    setup(
        version=get_version(),
        ext_modules=get_extensions(),
        cmdclass={
            'build_ext': torch.utils.cpp_extension.BuildExtension
        },
    )


if __name__ == '__main__':
    setup_package()
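When no backend is forced, get_parallel_options() scrapes torch.__config__.parallel_info() for the ATen backend line. A small sketch of what that parsing sees (illustrative only; the exact report varies by torch build):

```python
import torch

# parallel_info() returns a multi-line report; one line names the ATen backend,
# e.g. "ATen parallel backend: OpenMP", which get_parallel_options() extracts.
for line in torch.__config__.parallel_info().splitlines():
    if line.startswith('ATen parallel backend:'):
        print(line.rsplit(': ')[1])  # -> 'OpenMP' on a typical Linux build
```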
--------------------------------------------------------------------------------
/tests/test_ops.py:
--------------------------------------------------------------------------------
import scipy.spatial.distance as sci_dist
import sklearn.metrics.pairwise as sklearn_pw
import torch

import torchpairwise


def test_boolean_kernels(device='cuda'):
    x1 = torch.randint(0, 3, (10, 5), device=device)
    x2 = torch.randint(0, 3, (8, 5), device=device)

    print('dice')
    output = torchpairwise.dice_distances(x1, x2)
    py_output = sklearn_pw.pairwise_distances(x1.detach().cpu(),
                                              x2.detach().cpu(),
                                              metric='dice')
    print(output.detach().cpu() - py_output)

    print('hamming')
    output = torchpairwise.hamming_distances(x1, x2)
    py_output = sklearn_pw.pairwise_distances(x1.detach().cpu(),
                                              x2.detach().cpu(),
                                              metric='hamming')
    print(output.detach().cpu() - py_output)

    print('jaccard')
    output = torchpairwise.jaccard_distances(x1, x2)
    py_output = sklearn_pw.pairwise_distances(x1.detach().cpu(),
                                              x2.detach().cpu(),
                                              metric='jaccard')
    print(output.detach().cpu() - py_output)

    print('kulsinski')
    output = torchpairwise.kulsinski_distances(x1, x2)
    py_output = sklearn_pw.pairwise_distances(x1.detach().cpu(),
                                              x2.detach().cpu(),
                                              metric='kulsinski')
    print(output.detach().cpu() - py_output)

    print('rogerstanimoto')
    output = torchpairwise.rogerstanimoto_distances(x1, x2)
    py_output = sklearn_pw.pairwise_distances(x1.detach().cpu(),
                                              x2.detach().cpu(),
                                              metric='rogerstanimoto')
    print(output.detach().cpu() - py_output)

    print('russellrao')
    output = torchpairwise.russellrao_distances(x1, x2)
    py_output = sklearn_pw.pairwise_distances(x1.detach().cpu(),
                                              x2.detach().cpu(),
                                              metric='russellrao')
    print(output.detach().cpu() - py_output)

    print('sokalmichener')
    output = torchpairwise.sokalmichener_distances(x1, x2)
    py_output = sklearn_pw.pairwise_distances(x1.detach().cpu(),
                                              x2.detach().cpu(),
                                              metric='sokalmichener')
    print(output.detach().cpu() - py_output)

    print('sokalsneath')
    output = torchpairwise.sokalsneath_distances(x1, x2)
    py_output = sklearn_pw.pairwise_distances(x1.detach().cpu(),
                                              x2.detach().cpu(),
                                              metric='sokalsneath')
    print(output.detach().cpu() - py_output)

    print('yule')
    output = torchpairwise.yule_distances(x1, x2)
    py_output = sklearn_pw.pairwise_distances(x1.detach().cpu(),
                                              x2.detach().cpu(),
                                              metric='yule')
    print(output.detach().cpu() - py_output)


def test_floating_kernels(dtype=torch.float64, device='cuda'):
    x1 = torch.rand(10, 5, dtype=dtype, device=device)
    x2 = torch.rand(8, 5, dtype=dtype, device=device)

    print('additive_chi2_kernel')
    output = torchpairwise.additive_chi2_kernel(x1, x2)
    py_output = sklearn_pw.additive_chi2_kernel(x1.detach().cpu(),
                                                x2.detach().cpu())
    print(output.detach().cpu() - py_output)

    print('mahalanobis')
    VI = torch.full((x1.size(-1), x2.size(-1)), 0.1, dtype=dtype, device=device)
    output = torchpairwise.mahalanobis_distances(x1, x2, VI)
    py_output = sklearn_pw.pairwise_distances(x1.detach().cpu(),
                                              x2.detach().cpu(),
                                              metric='mahalanobis',
                                              VI=VI.detach().cpu())
    print(output.detach().cpu() - py_output)

    print('seuclidean')
    V = torch.full((x1.size(-1),), 0.1, dtype=dtype, device=device)
    output = torchpairwise.seuclidean_distances(x1, x2, V)
    py_output = sklearn_pw.pairwise_distances(x1.detach().cpu(),
                                              x2.detach().cpu(),
                                              metric='seuclidean',
                                              V=V.detach().cpu())
    print(output.detach().cpu() - py_output)

    print('braycurtis')
    output = torchpairwise.braycurtis_distances(x1, x2)
    py_output = sklearn_pw.pairwise_distances(x1.detach().cpu(),
                                              x2.detach().cpu(),
                                              metric='braycurtis')
    print(output.detach().cpu() - py_output)

    print('canberra')
    output = torchpairwise.canberra_distances(x1, x2)
    py_output = sklearn_pw.pairwise_distances(x1.detach().cpu(),
                                              x2.detach().cpu(),
                                              metric='canberra')
    print(output.detach().cpu() - py_output)

    print('cosine')
    output = torchpairwise.cosine_distances(x1, x2)
    py_output = sklearn_pw.pairwise_distances(x1.detach().cpu(),
                                              x2.detach().cpu(),
                                              metric='cosine')
    print(output.detach().cpu() - py_output)

    print('correlation')
    output = torchpairwise.correlation_distances(x1, x2)
    py_output = sklearn_pw.pairwise_distances(x1.detach().cpu(),
                                              x2.detach().cpu(),
                                              metric='correlation')
    print(output.detach().cpu() - py_output)

    print('jensenshannon')
    output = torchpairwise.jensenshannon_distances(x1, x2)
    py_output = sci_dist.cdist(x1.detach().cpu(),
                               x2.detach().cpu(),
                               metric='jensenshannon')
    print(output.detach().cpu() - py_output)

    print('snr')
    output = torchpairwise.snr_distances(x1, x2)
    py_output = sci_dist.cdist(x1.detach().cpu(),
                               x2.detach().cpu(),
                               metric=lambda u, v: (v - u).var() / u.var())
    print(output.detach().cpu() - py_output)

    print('directed_hausdorff')
    x1 = torch.rand(10, 9, 3)
    x2 = torch.rand(8, 7, 3)
    output, x1_inds, x2_inds = torchpairwise.directed_hausdorff_distances(x1, x2, shuffle=True)
    print(output)
    print(x1_inds)
    print(x2_inds)
    py_output = sci_dist.directed_hausdorff(x1[-1].detach().cpu(),
                                            x2[-1].detach().cpu())
    print(py_output)
    # print(output.detach().cpu() - py_output)

    gen = torch.Generator(device=device)
    gen.manual_seed(1)
    x1 = x1.to(dtype=torch.float64, device=device)
    x2 = x2.to(dtype=torch.float64, device=device)
    x1.requires_grad_()
    x2.requires_grad_()
    grad_correct = torch.autograd.gradcheck(
        lambda x, y: torchpairwise.directed_hausdorff_distances(x, y, shuffle=True, generator=gen), inputs=(x1, x2))
    print('grad_correct:', grad_correct)


def test_cdist(dtype=torch.float64, device='cuda'):
    x1 = torch.rand(10, 5, dtype=dtype, device=device)
    x2 = torch.rand(8, 5, dtype=dtype, device=device)

    output = torchpairwise.cdist(x1, x2, metric='manhattan')
    print(output)
    print(output - torchpairwise.manhattan_distances(x1, x2))


if __name__ == '__main__':
    # test_boolean_kernels()
    test_floating_kernels()
    # test_cdist()
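These tests print raw differences against the scipy/scikit-learn references for visual inspection; when adapting them into strict tests, a tolerance-based assertion is sturdier. A sketch (not part of the repo):

```python
import sklearn.metrics.pairwise as sklearn_pw
import torch

import torchpairwise

x1 = torch.rand(10, 5, dtype=torch.float64)
x2 = torch.rand(8, 5, dtype=torch.float64)
expected = torch.as_tensor(
    sklearn_pw.pairwise_distances(x1.numpy(), x2.numpy(), metric='cosine'))
# assert_close replaces eyeballing printed differences with an explicit tolerance check.
torch.testing.assert_close(torchpairwise.cosine_distances(x1, x2), expected)
```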
--------------------------------------------------------------------------------
/tools/packaging/audit_torch_extension.py:
--------------------------------------------------------------------------------
import sys

from auditwheel.main import main

LIBTORCH_LIBRARIES = [
    'libtorch.so',
    'libtorch_cpu.so',
    'libtorch_cuda.so',
    'libtorch_python.so',
    'libc10.so',
    'libc10_cuda.so',
]

try:
    from auditwheel.policy import _POLICIES as POLICIES

    for p in POLICIES:
        p['lib_whitelist'].extend(LIBTORCH_LIBRARIES)
except ImportError:
    for lib in LIBTORCH_LIBRARIES:
        sys.argv.append('--exclude')
        sys.argv.append(lib)


if __name__ == "__main__":
    sys.exit(main())
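This wrapper is driven exactly like auditwheel itself, with the libtorch/c10 shared objects left external because the installed torch package provides them at runtime. A local equivalent of the CI "Repair wheel" step (hypothetical invocation, mirroring build_wheels.yml):

```python
import glob
import subprocess

# Repair every wheel in dist/, writing manylinux wheels to wheelhouse/.
for wheel in glob.glob('dist/*.whl'):
    subprocess.run(['python', 'tools/packaging/audit_torch_extension.py',
                    'repair', '--plat', 'manylinux_2_34_x86_64',
                    '-w', 'wheelhouse', wheel], check=True)
```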
--------------------------------------------------------------------------------
/torchpairwise/__init__.py:
--------------------------------------------------------------------------------
from .extension import _HAS_OPS, _assert_has_ops, has_ops, with_cuda, cuda_version

_assert_has_ops()

from . import ops
from ._ops import _ops
from ._version import __version__
from .ops import *

--------------------------------------------------------------------------------
/torchpairwise/_ops.py:
--------------------------------------------------------------------------------
from torch._ops import _OpNamespace

from .extension import _assert_has_ops

__all__ = ['_ops']


class _TorchPairwiseOpNameSpace(_OpNamespace):

    def __init__(self):
        super(_TorchPairwiseOpNameSpace, self).__init__('torchpairwise')

    def __getattr__(self, op_name):
        _assert_has_ops()
        return super(_TorchPairwiseOpNameSpace, self).__getattr__(op_name)


_ops = _TorchPairwiseOpNameSpace()
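Since _TorchPairwiseOpNameSpace only adds an ops-availability check on top of torch._ops._OpNamespace, attribute access on _ops resolves to the operators registered under the torchpairwise namespace in csrc. A minimal sketch (assuming the compiled extension is importable; _braycurtis is one of the schemas registered below):

```python
import torch
import torchpairwise  # loads torchpairwise._C and registers the operators
from torchpairwise._ops import _ops

x1, x2 = torch.rand(10, 5), torch.rand(8, 5)
# Both calls resolve to the same registered schema "torchpairwise::_braycurtis".
out_a = _ops._braycurtis(x1, x2)
out_b = torch.ops.torchpairwise._braycurtis(x1, x2)
assert torch.equal(out_a, out_b)
```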
--------------------------------------------------------------------------------
/torchpairwise/_version.py:
--------------------------------------------------------------------------------
__version__ = '0.3.0'

--------------------------------------------------------------------------------
/torchpairwise/csrc/macros.h:
--------------------------------------------------------------------------------
#pragma once

#if defined(_WIN32) && !defined(torchpairwise_BUILD_STATIC_LIBS)
#if defined(torchpairwise_EXPORTS)
#define TORCHPAIRWISE_API __declspec(dllexport)
#else
#define TORCHPAIRWISE_API __declspec(dllimport)
#endif
#else
#define TORCHPAIRWISE_API
#endif

#if (defined __cpp_inline_variables) || __cplusplus >= 201703L
#define TORCHPAIRWISE_INLINE_VARIABLE inline
#else
#ifdef _MSC_VER
#define TORCHPAIRWISE_INLINE_VARIABLE __declspec(selectany)
#define HINT_MSVC_LINKER_INCLUDE_SYMBOL
#else
#define TORCHPAIRWISE_INLINE_VARIABLE __attribute__((weak))
#endif
#endif

--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/additive_chi2_kernel.cpp:
--------------------------------------------------------------------------------
#include "additive_chi2_kernel.h"

#include <torch/library.h>

namespace torchpairwise {
    namespace ops {
        at::Tensor _additive_chi2_kernel(
                const at::Tensor &x1,
                const at::Tensor &x2) {
            static auto op = c10::Dispatcher::singleton()
                    .findSchemaOrThrow("torchpairwise::_additive_chi2_kernel", "")
                    .typed<decltype(_additive_chi2_kernel)>();
            return op.call(x1, x2);
        }

        namespace detail {
            std::tuple<at::Tensor, at::Tensor> __additive_chi2_kernel_backward(
                    const at::Tensor &grad,
                    const at::Tensor &x1,
                    const at::Tensor &x2) {
                static auto op =
                        c10::Dispatcher::singleton()
                                .findSchemaOrThrow("torchpairwise::__additive_chi2_kernel_backward", "")
                                .typed<decltype(__additive_chi2_kernel_backward)>();
                return op.call(grad, x1, x2);
            }
        }

        TORCH_LIBRARY_FRAGMENT(torchpairwise, m) {
            m.def(TORCH_SELECTIVE_SCHEMA(
                    "torchpairwise::_additive_chi2_kernel(Tensor x1, Tensor x2) -> Tensor")
            );
            m.def(TORCH_SELECTIVE_SCHEMA(
                    "torchpairwise::__additive_chi2_kernel_backward(Tensor grad, Tensor x1, Tensor x2) -> (Tensor grad_x1, Tensor grad_x2)")
            );
        }
    } // namespace ops
} // namespace torchpairwise

--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/additive_chi2_kernel.h:
--------------------------------------------------------------------------------
#pragma once

#include <ATen/ATen.h>

namespace torchpairwise {
    namespace ops {
        at::Tensor _additive_chi2_kernel(
                const at::Tensor &x1,
                const at::Tensor &x2);

        namespace detail {
            std::tuple<at::Tensor, at::Tensor> __additive_chi2_kernel_backward(
                    const at::Tensor &grad,
                    const at::Tensor &x1,
                    const at::Tensor &x2);
        }
    }
}

--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/autograd/additive_chi2_kernel.cpp:
--------------------------------------------------------------------------------
#include "../additive_chi2_kernel.h"

#include <torch/autograd.h>
#include <torch/types.h>

namespace torchpairwise {
    namespace ops {
        namespace {
            class _AdditiveChi2KernelFunction
                    : public torch::autograd::Function<_AdditiveChi2KernelFunction> {
            public:
                static torch::autograd::Variable forward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::Variable &x1,
                        const torch::autograd::Variable &x2) {
                    at::AutoDispatchBelowADInplaceOrView g;

                    ctx->save_for_backward({x1, x2});

                    auto output = _additive_chi2_kernel(x1, x2);

                    return output;
                }

                static torch::autograd::variable_list backward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::variable_list &grad_output) {
                    auto saved = ctx->get_saved_variables();
                    auto x1 = saved[0];
                    auto x2 = saved[1];

                    auto grads = detail::__additive_chi2_kernel_backward(
                            grad_output[0],
                            x1,
                            x2);
                    auto grad_x1 = std::get<0>(grads);
                    auto grad_x2 = std::get<1>(grads);

                    return {
                            grad_x1,
                            grad_x2,
                    };
                }
            };

            at::Tensor _additive_chi2_kernel_autograd(
                    const at::Tensor &x1,
                    const at::Tensor &x2) {
                return _AdditiveChi2KernelFunction::apply(x1, x2);
            }
        } // namespace

        TORCH_LIBRARY_IMPL(torchpairwise, Autograd, m) {
            m.impl(
                    TORCH_SELECTIVE_NAME("torchpairwise::_additive_chi2_kernel"),
                    TORCH_FN(_additive_chi2_kernel_autograd));
        }
    }
}
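Every op in csrc/ops follows this pattern: a schema-registered entry point, a detail::*_backward companion, and an autograd::Function that wires the two together under the Autograd dispatch key. The wiring can be exercised end-to-end from Python with gradcheck, mirroring tests/test_ops.py (a sketch):

```python
import torch
import torchpairwise

x1 = torch.rand(10, 5, dtype=torch.float64, requires_grad=True)
x2 = torch.rand(8, 5, dtype=torch.float64, requires_grad=True)
# gradcheck drives the registered Autograd kernel, which in turn calls
# torchpairwise::__additive_chi2_kernel_backward for the analytic gradients.
torch.autograd.gradcheck(torchpairwise.additive_chi2_kernel, (x1, x2))
```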
--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/autograd/braycurtis_kernel.cpp:
--------------------------------------------------------------------------------
#include "../braycurtis.h"

#include <torch/autograd.h>
#include <torch/types.h>

namespace torchpairwise {
    namespace ops {
        namespace {
            class BrayCurtisDistancesFunction
                    : public torch::autograd::Function<BrayCurtisDistancesFunction> {
            public:
                static torch::autograd::Variable forward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::Variable &x1,
                        const torch::autograd::Variable &x2) {
                    at::AutoDispatchBelowADInplaceOrView g;

                    ctx->save_for_backward({x1, x2});

                    auto output = _braycurtis(x1, x2);

                    return output;
                }

                static torch::autograd::variable_list backward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::variable_list &grad_output) {
                    auto saved = ctx->get_saved_variables();
                    auto x1 = saved[0];
                    auto x2 = saved[1];

                    auto grads = detail::__braycurtis_backward(
                            grad_output[0],
                            x1,
                            x2);
                    auto grad_x1 = std::get<0>(grads);
                    auto grad_x2 = std::get<1>(grads);

                    return {
                            grad_x1,
                            grad_x2,
                    };
                }
            };

            at::Tensor _braycurtis_autograd(
                    const at::Tensor &x1,
                    const at::Tensor &x2) {
                return BrayCurtisDistancesFunction::apply(x1, x2);
            }
        } // namespace

        TORCH_LIBRARY_IMPL(torchpairwise, Autograd, m) {
            m.impl(
                    TORCH_SELECTIVE_NAME("torchpairwise::_braycurtis"),
                    TORCH_FN(_braycurtis_autograd));
        }
    }
}

--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/autograd/canberra_kernel.cpp:
--------------------------------------------------------------------------------
#include "../canberra.h"

#include <torch/autograd.h>
#include <torch/types.h>

namespace torchpairwise {
    namespace ops {
        namespace {
            class CanberraDistancesFunction
                    : public torch::autograd::Function<CanberraDistancesFunction> {
            public:
                static torch::autograd::Variable forward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::Variable &x1,
                        const torch::autograd::Variable &x2) {
                    at::AutoDispatchBelowADInplaceOrView g;

                    ctx->save_for_backward({x1, x2});

                    auto output = _canberra(x1, x2);

                    return output;
                }

                static torch::autograd::variable_list backward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::variable_list &grad_output) {
                    auto saved = ctx->get_saved_variables();
                    auto x1 = saved[0];
                    auto x2 = saved[1];

                    auto grads = detail::__canberra_backward(
                            grad_output[0],
                            x1,
                            x2);
                    auto grad_x1 = std::get<0>(grads);
                    auto grad_x2 = std::get<1>(grads);

                    return {
                            grad_x1,
                            grad_x2,
                    };
                }
            };

            at::Tensor _canberra_autograd(
                    const at::Tensor &x1,
                    const at::Tensor &x2) {
                return CanberraDistancesFunction::apply(x1, x2);
            }
        } // namespace

        TORCH_LIBRARY_IMPL(torchpairwise, Autograd, m) {
            m.impl(
                    TORCH_SELECTIVE_NAME("torchpairwise::_canberra"),
                    TORCH_FN(_canberra_autograd));
        }
    }
}
--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/autograd/hausdorff_kernel.cpp:
--------------------------------------------------------------------------------
#include "../hausdorff.h"

#include <torch/autograd.h>
#include <torch/types.h>

namespace torchpairwise {
    namespace ops {
        namespace {
            class DirectedHausdorffDistancesFunction
                    : public torch::autograd::Function<DirectedHausdorffDistancesFunction> {
            public:
                static torch::autograd::variable_list forward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::Variable &x1,
                        const torch::autograd::Variable &x2,
                        bool shuffle,
                        c10::optional<at::Generator> generator) {
                    at::AutoDispatchBelowADInplaceOrView g;

                    ctx->save_for_backward({x1, x2});
                    ctx->saved_data["shuffle"] = shuffle;
                    ctx->saved_data["generator"] = generator.has_value() ? c10::make_optional(generator->clone())
                                                                         : generator;

                    at::Tensor output, x1_indices, x2_indices;
                    std::tie(output, x1_indices, x2_indices) = _directed_hausdorff(x1, x2, shuffle, generator);
                    ctx->mark_non_differentiable({x1_indices, x2_indices});

                    return {
                            output,
                            x1_indices,
                            x2_indices,
                    };
                }

                static torch::autograd::variable_list backward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::variable_list &grad_output) {
                    auto saved = ctx->get_saved_variables();
                    auto x1 = saved[0];
                    auto x2 = saved[1];
                    auto shuffle = ctx->saved_data["shuffle"].toBool();
                    auto generator = ctx->saved_data["generator"].toOptional<at::Generator>();

                    auto grads = detail::__directed_hausdorff_backward(
                            grad_output[0],
                            x1,
                            x2,
                            shuffle,
                            generator);
                    auto grad_x1 = std::get<0>(grads);
                    auto grad_x2 = std::get<1>(grads);

                    return {
                            grad_x1,
                            grad_x2,
                            torch::autograd::Variable(),
                            torch::autograd::Variable(),
                    };
                }
            };

            std::tuple<at::Tensor, at::Tensor, at::Tensor> _directed_hausdorff_autograd(
                    const at::Tensor &x1,
                    const at::Tensor &x2,
                    bool shuffle,
                    c10::optional<at::Generator> generator) {
                auto result = DirectedHausdorffDistancesFunction::apply(x1, x2, shuffle, generator);
                return std::make_tuple(result[0], result[1], result[2]);
            }
        } // namespace

        TORCH_LIBRARY_IMPL(torchpairwise, Autograd, m) {
            m.impl(
                    TORCH_SELECTIVE_NAME("torchpairwise::_directed_hausdorff"),
                    TORCH_FN(_directed_hausdorff_autograd));
        }
    }
}
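Note that forward() clones the generator before stashing it in saved_data, so backward() can replay the same shuffled scan order. From Python, reseeding should therefore reproduce both values and index tensors; a sketch (assuming the generator state fully determines the shuffle):

```python
import torch
import torchpairwise

x1, x2 = torch.rand(10, 9, 3), torch.rand(8, 7, 3)
gen = torch.Generator().manual_seed(1)
out1, i1, j1 = torchpairwise.directed_hausdorff_distances(x1, x2, shuffle=True, generator=gen)
gen.manual_seed(1)  # same seed => same shuffled scan order and argmin/argmax indices
out2, i2, j2 = torchpairwise.directed_hausdorff_distances(x1, x2, shuffle=True, generator=gen)
assert torch.equal(out1, out2) and torch.equal(i1, i2) and torch.equal(j1, j2)
```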
--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/autograd/haversine_kernel.cpp:
--------------------------------------------------------------------------------
#include "../haversine.h"

#include <torch/autograd.h>
#include <torch/types.h>

namespace torchpairwise {
    namespace ops {
        namespace {
            class HaversineDistancesFunction
                    : public torch::autograd::Function<HaversineDistancesFunction> {
            public:
                static torch::autograd::Variable forward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::Variable &x1,
                        const torch::autograd::Variable &x2) {
                    at::AutoDispatchBelowADInplaceOrView g;

                    auto output = _haversine(x1, x2);

                    ctx->save_for_backward({x1, x2});

                    return output;
                }

                static torch::autograd::variable_list backward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::variable_list &grad_output) {
                    auto saved = ctx->get_saved_variables();
                    auto x1 = saved[0];
                    auto x2 = saved[1];

                    auto grads = detail::__haversine_backward(
                            grad_output[0], x1, x2);
                    auto grad_x1 = std::get<0>(grads);
                    auto grad_x2 = std::get<1>(grads);

                    return {
                            grad_x1,
                            grad_x2,
                    };
                }
            };

            at::Tensor _haversine_autograd(
                    const at::Tensor &x1,
                    const at::Tensor &x2) {
                return HaversineDistancesFunction::apply(x1, x2);
            }
        } // namespace

        TORCH_LIBRARY_IMPL(torchpairwise, Autograd, m) {
            m.impl(
                    TORCH_SELECTIVE_NAME("torchpairwise::_haversine"),
                    TORCH_FN(_haversine_autograd));
        }
    }
}

--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/autograd/ppminkowski_kernel.cpp:
--------------------------------------------------------------------------------
#include "../ppminkowski.h"

#include <torch/autograd.h>
#include <torch/types.h>

namespace torchpairwise {
    namespace ops {
        namespace {
            class PPowerMinkowskiFunction
                    : public torch::autograd::Function<PPowerMinkowskiFunction> {
            public:
                static torch::autograd::Variable forward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::Variable &x1,
                        const torch::autograd::Variable &x2,
                        double p) {
                    at::AutoDispatchBelowADInplaceOrView g;

                    auto output = _ppminkowski(x1, x2, p);

                    ctx->save_for_backward({x1, x2});
                    ctx->saved_data["p"] = p;

                    return output;
                }

                static torch::autograd::variable_list backward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::variable_list &grad_output) {
                    auto saved = ctx->get_saved_variables();
                    auto x1 = saved[0];
                    auto x2 = saved[1];
                    double p = ctx->saved_data["p"].toDouble();

                    auto grads = detail::__ppminkowski_backward(
                            grad_output[0], x1, x2, p);
                    auto grad_x1 = std::get<0>(grads);
                    auto grad_x2 = std::get<1>(grads);

                    return {
                            grad_x1,
                            grad_x2,
                            torch::autograd::Variable(),
                    };
                }
            };

            at::Tensor _ppminkowski_autograd(
                    const at::Tensor &x1,
                    const at::Tensor &x2,
                    double p) {
                return PPowerMinkowskiFunction::apply(x1, x2, p);
            }
        } // namespace

        TORCH_LIBRARY_IMPL(torchpairwise, Autograd, m) {
            m.impl(
                    TORCH_SELECTIVE_NAME("torchpairwise::_ppminkowski"),
                    TORCH_FN(_ppminkowski_autograd));
        }
    }
}

--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/autograd/snr_kernel.cpp:
--------------------------------------------------------------------------------
#include "../snr.h"

#include <torch/autograd.h>
#include <torch/types.h>

namespace torchpairwise {
    namespace ops {
        namespace {
            class SignalToNoiseRatioFunction
                    : public torch::autograd::Function<SignalToNoiseRatioFunction> {
            public:
                static torch::autograd::Variable forward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::Variable &x1,
                        const torch::autograd::Variable &x2) {
                    at::AutoDispatchBelowADInplaceOrView g;

                    ctx->save_for_backward({x1, x2});

                    auto output = _snr(x1, x2);

                    return output;
                }

                static torch::autograd::variable_list backward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::variable_list &grad_output) {
                    auto saved = ctx->get_saved_variables();
                    auto x1 = saved[0];
                    auto x2 = saved[1];

                    auto grads = detail::__snr_backward(
                            grad_output[0],
                            x1,
                            x2);
                    auto grad_x1 = std::get<0>(grads);
                    auto grad_x2 = std::get<1>(grads);

                    return {
                            grad_x1,
                            grad_x2,
                    };
                }
            };

            at::Tensor _snr_autograd(
                    const at::Tensor &x1,
                    const at::Tensor &x2) {
                return SignalToNoiseRatioFunction::apply(x1, x2);
            }
        } // namespace

        TORCH_LIBRARY_IMPL(torchpairwise, Autograd, m) {
            m.impl(
                    TORCH_SELECTIVE_NAME("torchpairwise::_snr"),
                    TORCH_FN(_snr_autograd));
        }
    }
}
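For reference, the SNR distance differentiated here matches the reference lambda used in tests/test_ops.py:

```latex
d_{\mathrm{snr}}(u, v) = \frac{\operatorname{Var}(v - u)}{\operatorname{Var}(u)}
```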
--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/autograd/sqjensenshannon_kernel.cpp:
--------------------------------------------------------------------------------
#include "../sqjensenshannon.h"

#include <torch/autograd.h>
#include <torch/types.h>

namespace torchpairwise {
    namespace ops {
        namespace {
            class SquaredJensenShannonDistancesFunction
                    : public torch::autograd::Function<SquaredJensenShannonDistancesFunction> {
            public:
                static torch::autograd::Variable forward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::Variable &x1,
                        const torch::autograd::Variable &x2,
                        c10::optional<double> base) {
                    at::AutoDispatchBelowADInplaceOrView g;

                    ctx->save_for_backward({x1, x2});
                    ctx->saved_data["base"] = base;

                    auto output = _sqjensenshannon(x1, x2, base);

                    return output;
                }

                static torch::autograd::variable_list backward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::variable_list &grad_output) {
                    auto saved = ctx->get_saved_variables();
                    auto x1 = saved[0];
                    auto x2 = saved[1];
                    auto base = ctx->saved_data["base"].toOptional<double>();

                    auto grads = detail::__sqjensenshannon_backward(
                            grad_output[0],
                            x1,
                            x2,
                            base);
                    auto grad_x1 = std::get<0>(grads);
                    auto grad_x2 = std::get<1>(grads);

                    return {
                            grad_x1,
                            grad_x2,
                            torch::autograd::Variable(),
                    };
                }
            };

            at::Tensor _sqjensenshannon_autograd(
                    const at::Tensor &x1,
                    const at::Tensor &x2,
                    c10::optional<double> base) {
                return SquaredJensenShannonDistancesFunction::apply(x1, x2, base);
            }
        } // namespace

        TORCH_LIBRARY_IMPL(torchpairwise, Autograd, m) {
            m.impl(
                    TORCH_SELECTIVE_NAME("torchpairwise::_sqjensenshannon"),
                    TORCH_FN(_sqjensenshannon_autograd));
        }
    }
}

--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/autograd/sqmahalanobis_kernel.cpp:
--------------------------------------------------------------------------------
#include "../sqmahalanobis.h"

#include <torch/autograd.h>
#include <torch/types.h>

namespace torchpairwise {
    namespace ops {
        namespace {
            class SquaredMahalanobisDistancesFunction
                    : public torch::autograd::Function<SquaredMahalanobisDistancesFunction> {
            public:
                static torch::autograd::Variable forward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::Variable &x1,
                        const torch::autograd::Variable &x2,
                        const torch::autograd::Variable &VI) {
                    at::AutoDispatchBelowADInplaceOrView g;

                    auto output = _sqmahalanobis(x1, x2, VI);

                    ctx->save_for_backward({x1, x2, VI});

                    return output;
                }

                static torch::autograd::variable_list backward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::variable_list &grad_output) {
                    auto saved = ctx->get_saved_variables();
                    auto x1 = saved[0];
                    auto x2 = saved[1];
                    auto VI = saved[2];

                    auto grads = detail::__sqmahalanobis_backward(
                            grad_output[0], x1, x2, VI);
                    auto grad_x1 = std::get<0>(grads);
                    auto grad_x2 = std::get<1>(grads);
                    auto grad_VI = std::get<2>(grads);

                    return {
                            grad_x1,
                            grad_x2,
                            grad_VI,
                    };
                }
            };

            at::Tensor _sqmahalanobis_autograd(
                    const at::Tensor &x1,
                    const at::Tensor &x2,
                    const at::Tensor &VI) {
                return SquaredMahalanobisDistancesFunction::apply(x1, x2, VI);
            }
        } // namespace

        TORCH_LIBRARY_IMPL(torchpairwise, Autograd, m) {
            m.impl(
                    TORCH_SELECTIVE_NAME("torchpairwise::_sqmahalanobis"),
                    TORCH_FN(_sqmahalanobis_autograd));
        }
    }
}
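For reference, the quantity differentiated here (note grads flow to VI as well as to both inputs) is the squared Mahalanobis distance, with VI the inverse covariance matrix in the standard scipy parameterization the tests compare against:

```latex
d_{VI}^{2}(x_1, x_2) = (x_1 - x_2)^{\top} \, VI \, (x_1 - x_2)
```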
--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/autograd/wminkowski_kernel.cpp:
--------------------------------------------------------------------------------
#include "../wminkowski.h"

#include <torch/autograd.h>
#include <torch/types.h>

namespace torchpairwise {
    namespace ops {
        namespace {
            class WeightedMinkowskiFunction
                    : public torch::autograd::Function<WeightedMinkowskiFunction> {
            public:
                static torch::autograd::Variable forward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::Variable &x1,
                        const torch::autograd::Variable &x2,
                        const torch::autograd::Variable &w,
                        double p) {
                    at::AutoDispatchBelowADInplaceOrView g;

                    auto output = _wminkowski(x1, x2, w, p);

                    ctx->save_for_backward({x1, x2, w});
                    ctx->saved_data["p"] = p;

                    return output;
                }

                static torch::autograd::variable_list backward(
                        torch::autograd::AutogradContext *ctx,
                        const torch::autograd::variable_list &grad_output) {
                    auto saved = ctx->get_saved_variables();
                    auto x1 = saved[0];
                    auto x2 = saved[1];
                    auto w = saved[2];
                    double p = ctx->saved_data["p"].toDouble();

                    auto grads = detail::__wminkowski_backward(
                            grad_output[0], x1, x2, w, p);
                    auto grad_x1 = std::get<0>(grads);
                    auto grad_x2 = std::get<1>(grads);
                    auto grad_w = std::get<2>(grads);

                    return {
                            grad_x1,
                            grad_x2,
                            grad_w,
                            torch::autograd::Variable(),
                    };
                }
            };

            at::Tensor _wminkowski_autograd(
                    const at::Tensor &x1,
                    const at::Tensor &x2,
                    const at::Tensor &w,
                    double p) {
                return WeightedMinkowskiFunction::apply(x1, x2, w, p);
            }
        } // namespace

        TORCH_LIBRARY_IMPL(torchpairwise, Autograd, m) {
            m.impl(
                    TORCH_SELECTIVE_NAME("torchpairwise::_wminkowski"),
                    TORCH_FN(_wminkowski_autograd));
        }
    }
}

--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/braycurtis.cpp:
--------------------------------------------------------------------------------
#include "braycurtis.h"

#include <torch/library.h>

namespace torchpairwise {
    namespace ops {
        at::Tensor _braycurtis(
                const at::Tensor &x1,
                const at::Tensor &x2) {
            static auto op = c10::Dispatcher::singleton()
                    .findSchemaOrThrow("torchpairwise::_braycurtis", "")
                    .typed<decltype(_braycurtis)>();
            return op.call(x1, x2);
        }

        namespace detail {
            std::tuple<at::Tensor, at::Tensor> __braycurtis_backward(
                    const at::Tensor &grad,
                    const at::Tensor &x1,
                    const at::Tensor &x2) {
                static auto op =
                        c10::Dispatcher::singleton()
                                .findSchemaOrThrow("torchpairwise::__braycurtis_backward", "")
                                .typed<decltype(__braycurtis_backward)>();
                return op.call(grad, x1, x2);
            }
        }

        TORCH_LIBRARY_FRAGMENT(torchpairwise, m) {
            m.def(TORCH_SELECTIVE_SCHEMA(
                    "torchpairwise::_braycurtis(Tensor x1, Tensor x2) -> Tensor")
            );
            m.def(TORCH_SELECTIVE_SCHEMA(
                    "torchpairwise::__braycurtis_backward(Tensor grad, Tensor x1, Tensor x2) -> (Tensor grad_x1, Tensor grad_x2)")
            );
        }
    } // namespace ops
} // namespace torchpairwise

--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/braycurtis.h:
--------------------------------------------------------------------------------
#pragma once

#include <ATen/ATen.h>

namespace torchpairwise {
    namespace ops {
        at::Tensor _braycurtis(
                const at::Tensor &x1,
                const at::Tensor &x2);

        namespace detail {
            std::tuple<at::Tensor, at::Tensor> __braycurtis_backward(
14 |         const at::Tensor &x1,
15 |         const at::Tensor &x2);
16 | }
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/canberra.cpp:
--------------------------------------------------------------------------------
1 | #include "canberra.h"
2 |
3 | #include <torch/library.h>
4 |
5 | namespace torchpairwise {
6 | namespace ops {
7 | at::Tensor _canberra(
8 |         const at::Tensor &x1,
9 |         const at::Tensor &x2) {
10 |     static auto op = c10::Dispatcher::singleton()
11 |             .findSchemaOrThrow("torchpairwise::_canberra", "")
12 |             .typed<decltype(_canberra)>();
13 |     return op.call(x1, x2);
14 | }
15 |
16 | namespace detail {
17 | std::tuple<at::Tensor, at::Tensor> __canberra_backward(
18 |         const at::Tensor &grad,
19 |         const at::Tensor &x1,
20 |         const at::Tensor &x2) {
21 |     static auto op =
22 |             c10::Dispatcher::singleton()
23 |                     .findSchemaOrThrow("torchpairwise::__canberra_backward", "")
24 |                     .typed<decltype(__canberra_backward)>();
25 |     return op.call(grad, x1, x2);
26 | }
27 | }
28 |
29 | TORCH_LIBRARY_FRAGMENT(torchpairwise, m) {
30 |     m.def(TORCH_SELECTIVE_SCHEMA(
31 |             "torchpairwise::_canberra(Tensor x1, Tensor x2) -> Tensor")
32 |     );
33 |     m.def(TORCH_SELECTIVE_SCHEMA(
34 |             "torchpairwise::__canberra_backward(Tensor grad, Tensor x1, Tensor x2) -> (Tensor grad_x1, Tensor grad_x2)")
35 |     );
36 | }
37 | } // namespace ops
38 | } // namespace torchpairwise
39 |
--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/canberra.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | namespace torchpairwise {
6 | namespace ops {
7 | at::Tensor _canberra(
8 |         const at::Tensor &x1,
9 |         const at::Tensor &x2);
10 |
11 | namespace detail {
12 | std::tuple<at::Tensor, at::Tensor> __canberra_backward(
13 |         const at::Tensor &grad,
14 |         const at::Tensor &x1,
15 |         const at::Tensor &x2);
16 | }
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/common/binary_ops.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/Dispatch.h>
4 |
5 | #include "../utils/dispatch.h"
6 |
7 | namespace torchpairwise {
8 | namespace ops {
9 | enum BinaryOp {
10 |     // logical
11 |     And, Or, Xor,
12 |     // comparison
13 |     Equal, NotEqual, Less, Greater, LessEqual, GreaterEqual
14 | };
15 |
16 | template <BinaryOp op, bool with_namespace = true, bool with_signature = false>
17 | inline std::string op_name() {
18 |     std::string ns = with_namespace ? "torchpairwise::" : "";
19 |     std::string signature = with_signature ? "(Tensor x1, Tensor x2) -> Tensor" : "";
20 |     switch (op) {
21 |         case And:
22 |             return c10::str(ns, "pwand", signature);
23 |         case Or:
24 |             return c10::str(ns, "pwor", signature);
25 |         case Xor:
26 |             return c10::str(ns, "pwxor", signature);
27 |         case Equal:
28 |             return c10::str(ns, "pweq", signature);
29 |         case NotEqual:
30 |             return c10::str(ns, "pwne", signature);
31 |         case Less:
32 |             return c10::str(ns, "pwlt", signature);
33 |         case Greater:
34 |             return c10::str(ns, "pwgt", signature);
35 |         case LessEqual:
36 |             return c10::str(ns, "pwle", signature);
37 |         case GreaterEqual:
38 |             return c10::str(ns, "pwge", signature);
39 |         default:
40 |             return "[unknown_op]";
41 |     }
42 | }
43 |
44 | template <BinaryOp op>
45 | inline std::string op_schema_name() {
46 |     return op_name<op, true, false>();
47 | }
48 |
49 | template <BinaryOp op>
50 | inline std::string op_full_schema() {
51 |     return op_name<op, true, true>();
52 | }
53 | }
54 | }
55 |
56 | // dispatch macros
57 |
58 | #define TORCHPAIRWISE_DISPATCH_BINARY_OP_TYPES(BINARY_OP, TYPE, NAME, ...) \
59 |     if (BINARY_OP == And || BINARY_OP == Or || BINARY_OP == Xor) { \
60 |         AT_DISPATCH_BOOLEAN_TYPE(TYPE, NAME, __VA_ARGS__); \
61 |     } else if (BINARY_OP == Equal || BINARY_OP == NotEqual) { \
62 |         AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, \
63 |                                    TYPE, NAME, __VA_ARGS__); \
64 |     } else { \
65 |         AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, \
66 |                                    TYPE, NAME, __VA_ARGS__); \
67 |     }
68 |
69 | #define TORCHPAIRWISE_DISPATCH_CONSTEXPR_BINARY_OP_TYPES(BINARY_OP, TYPE, NAME, ...) \
70 |     if constexpr (BINARY_OP == And || BINARY_OP == Or || BINARY_OP == Xor) { \
71 |         AT_DISPATCH_BOOLEAN_TYPE(TYPE, NAME, __VA_ARGS__); \
72 |     } else if constexpr (BINARY_OP == Equal || BINARY_OP == NotEqual) { \
73 |         AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::Bool, \
74 |                                    TYPE, NAME, __VA_ARGS__); \
75 |     } else { \
76 |         AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, \
77 |                                    TYPE, NAME, __VA_ARGS__); \
78 |     }
79 |
--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/common/prf_div_mode.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <c10/util/Exception.h>
4 | #include <c10/util/string_view.h>
5 |
6 | namespace torchpairwise {
7 | namespace ops {
8 | enum PRFDivMode {
9 |     Zero,
10 |     Identity,
11 | };
12 |
13 | inline PRFDivMode get_prf_div_mode(c10::string_view mode) {
14 |     if (mode == "zero")
15 |         return Zero;
16 |     else if (mode == "identity") {
17 |         return Identity;
18 |     } else {
19 |         TORCH_CHECK(false,
20 |                     "mode must be either zero or identity. Got ",
21 |                     mode)
22 |     }
23 | }
24 | }
25 | }
26 |
27 | #define TORCHPAIRWISE_DISPATCH_PRF_DIV_MODE(MODE, ...) \
28 |     auto _mode = get_prf_div_mode(MODE); \
29 |     if (_mode == Zero) { \
30 |         static constexpr auto prf_div_mode = PRFDivMode::Zero; \
31 |         __VA_ARGS__(); \
32 |     } else if (_mode == Identity) { \
33 |         static constexpr auto prf_div_mode = PRFDivMode::Identity; \
34 |         __VA_ARGS__(); \
35 |     }
--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/common/reduction_ops.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "binary_ops.h"
4 |
5 | namespace torchpairwise {
6 | namespace ops {
7 | enum ReductionOp {
8 |     // logical
9 |     All, Any,
10 |     // arithmetics
11 |     Sum, Prod, Mean,
12 | };
13 |
14 | template <BinaryOp binary_op, ReductionOp reduction_op,
15 |           bool with_namespace = true, bool with_signature = false>
16 | inline std::string op_name() {
17 |     std::string ns = with_namespace ? "torchpairwise::" : "";
18 |     std::string signature = with_signature ? "(Tensor x1, Tensor x2) -> Tensor" : "";
19 |     static const std::string binary_prefix = op_name<binary_op, false, false>() + "_";
20 |     switch (reduction_op) {
21 |         case All:
22 |             return c10::str(ns, binary_prefix, "all", signature);
23 |         case Any:
24 |             return c10::str(ns, binary_prefix, "any", signature);
25 |         case Sum:
26 |             return c10::str(ns, binary_prefix, "sum", signature);
27 |         case Prod:
28 |             return c10::str(ns, binary_prefix, "prod", signature);
29 |         case Mean:
30 |             return c10::str(ns, binary_prefix, "mean", signature);
31 |         default:
32 |             return "[unknown_op]";
33 |     }
34 | }
35 |
36 | template <BinaryOp binary_op, ReductionOp reduction_op>
37 | inline std::string op_schema_name() {
38 |     return op_name<binary_op, reduction_op, true, false>();
39 | }
40 |
41 | template <BinaryOp binary_op, ReductionOp reduction_op>
42 | inline std::string op_full_schema() {
43 |     return op_name<binary_op, reduction_op, true, true>();
44 | }
45 | }
46 | }
47 |
48 | // dispatch macros
49 |
50 | #define TORCHPAIRWISE_DISPATCH_BINARY_REDUCTION_OP_TYPES(BINARY_OP, REDUCTION_OP, TYPE, NAME, ...) \
51 |     if (REDUCTION_OP == All || REDUCTION_OP == Any) { \
52 |         using output_t = bool; \
53 |         TORCHPAIRWISE_DISPATCH_BINARY_OP_TYPES(BINARY_OP, TYPE, NAME, __VA_ARGS__) \
54 |     } else { \
55 |         using output_t = float; \
56 |         TORCHPAIRWISE_DISPATCH_BINARY_OP_TYPES(BINARY_OP, TYPE, NAME, __VA_ARGS__) \
57 |     }
58 |
59 | #define TORCHPAIRWISE_DISPATCH_CONSTEXPR_BINARY_REDUCTION_OP_TYPES(BINARY_OP, REDUCTION_OP, TYPE, NAME, ...) \
60 |     if constexpr (REDUCTION_OP == All || REDUCTION_OP == Any) { \
61 |         using output_t = bool; \
62 |         TORCHPAIRWISE_DISPATCH_CONSTEXPR_BINARY_OP_TYPES(BINARY_OP, TYPE, NAME, __VA_ARGS__) \
63 |     } else { \
64 |         using output_t = float; \
65 |         TORCHPAIRWISE_DISPATCH_CONSTEXPR_BINARY_OP_TYPES(BINARY_OP, TYPE, NAME, __VA_ARGS__) \
66 |     }
67 |
--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/cpdist.h:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | #include "../macros.h"
6 |
7 | #define TORCHPAIRWISE_CPDIST_EXTRA_ARGS_SCHEMA_STR "Tensor? w=None, Tensor? V=None, Tensor? VI=None, float? p=None, float? base=None, bool? shuffle=None, Generator?
generator=None" 8 | 9 | namespace torchpairwise { 10 | namespace ops { 11 | TORCHPAIRWISE_API at::Tensor cdist( 12 | const at::Tensor &x1, 13 | const at::Tensor &x2, 14 | c10::string_view metric = "minkowski", 15 | const c10::optional &w = c10::nullopt, 16 | const c10::optional &V = c10::nullopt, 17 | const c10::optional &VI = c10::nullopt, 18 | c10::optional p = c10::nullopt, 19 | c10::optional base = c10::nullopt, 20 | c10::optional shuffle = c10::nullopt, 21 | c10::optional generator = c10::nullopt); 22 | 23 | TORCHPAIRWISE_API at::Tensor pdist( 24 | const at::Tensor &input, 25 | c10::string_view metric = "minkowski", 26 | const c10::optional &w = c10::nullopt, 27 | const c10::optional &V = c10::nullopt, 28 | const c10::optional &VI = c10::nullopt, 29 | c10::optional p = c10::nullopt, 30 | c10::optional base = c10::nullopt, 31 | c10::optional shuffle = c10::nullopt, 32 | c10::optional generator = c10::nullopt); 33 | } 34 | } 35 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/cpu/additive_chi2_kernel.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cpu_helpers.h" 5 | #include "../utils/dispatch.h" 6 | 7 | namespace torchpairwise { 8 | namespace ops { 9 | namespace { 10 | namespace impl { 11 | template 12 | void _additive_chi2_kernel_forward_kernel_impl( 13 | index_t n_kernels, 14 | const at::TensorAccessor x1, 15 | const at::TensorAccessor x2, 16 | at::TensorAccessor output) { 17 | CPU_1D_PARALLEL_KERNEL_LOOP(index, n_kernels) { 18 | index_t j = index % x2.size(1); 19 | index_t i = (index / x2.size(1)) % x1.size(1); 20 | index_t b = index / (x2.size(1) * x1.size(1)); 21 | 22 | scalar_t val = 0; 23 | for (index_t k = 0; k < x1.size(2); k++) { 24 | scalar_t nom = x1[b][i][k] + x2[b][j][k]; 25 | if (nom != 0) { 26 | scalar_t denom = x1[b][i][k] - x2[b][j][k]; 27 | val -= denom * denom / nom; 28 | } 29 | } 30 | output[b][i][j] = val; 31 | } 32 | } 33 | } // namespace impl 34 | 35 | at::Tensor _additive_chi2_kernel_forward_kernel( 36 | const at::Tensor &x1, 37 | const at::Tensor &x2) { 38 | at::CheckedFrom c = "_additive_chi2_kernel_forward"; 39 | auto args = { 40 | at::TensorArg(x1, "x1", 1), 41 | at::TensorArg(x2, "x2", 2)}; 42 | at::checkAllSameType(c, args); 43 | 44 | bool unbatched = x1.ndimension() == 2; 45 | TORCH_CHECK(unbatched || x1.ndimension() == 3, 46 | "x1 must be 2-D (unbatched) or 3-D (batched) tensor.") 47 | TORCH_CHECK(unbatched || x2.ndimension() == 3, 48 | "x2 must be 2-D (unbatched) or 3-D (batched) tensor.") 49 | TORCH_CHECK(unbatched || (x1.size(0) == x2.size(0)), 50 | "batch_size of x1 and x2 do not match.") 51 | TORCH_CHECK((unbatched && x1.size(1) == x2.size(1)) || 52 | (!unbatched && x1.size(2) == x2.size(2)), 53 | "feature dimension of x1 and x2 do not match.") 54 | 55 | auto x1_c = x1.contiguous(); 56 | auto x2_c = x2.contiguous(); 57 | if (unbatched) { 58 | x1_c = x1_c.unsqueeze(0); 59 | x2_c = x2_c.unsqueeze(0); 60 | } 61 | 62 | int64_t batch_size = x1_c.size(0); 63 | auto output = at::empty({batch_size, x1_c.size(1), x2_c.size(1)}, x1.options()); 64 | int64_t n_kernels = output.numel(); 65 | 66 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x1.scalar_type(), "_additive_chi2_kernel_forward_cpu", ([&] { 67 | TORCHPAIRWISE_DISPATCH_INDEX_TYPE_DEVICE(n_kernels, CPU, ([&] { 68 | auto output_accessor = 69 | output.accessor(); 70 | impl::_additive_chi2_kernel_forward_kernel_impl( 71 | n_kernels, 72 | x1_c.accessor(), 73 | x2_c.accessor(), 
74 | output_accessor); 75 | })); 76 | })); 77 | if (unbatched) 78 | output.squeeze_(0); 79 | return output; 80 | } 81 | 82 | namespace impl { 83 | template 84 | void _additive_chi2_kernel_backward_x1_kernel_impl( 85 | index_t n_kernels, 86 | const at::TensorAccessor grad_output, 87 | const at::TensorAccessor x1, 88 | const at::TensorAccessor x2, 89 | at::TensorAccessor grad_x1) { 90 | CPU_1D_PARALLEL_KERNEL_LOOP(index, n_kernels) { 91 | index_t k = index % x1.size(2); 92 | index_t i = (index / x1.size(2)) % x1.size(1); 93 | index_t b = index / (x1.size(2) * x1.size(1)); 94 | 95 | scalar_t val = 0; 96 | for (index_t j = 0; j < x2.size(1); j++) { 97 | scalar_t nom = x1[b][i][k] + x2[b][j][k]; 98 | if (nom != 0) { 99 | scalar_t denom = x1[b][i][k] - x2[b][j][k]; 100 | scalar_t weight = denom * denom / (nom * nom) - 2 * denom / nom; 101 | val += weight * grad_output[b][i][j]; 102 | } 103 | } 104 | grad_x1[b][i][k] = val; 105 | } 106 | } 107 | 108 | template 109 | void _additive_chi2_kernel_backward_x2_kernel_impl( 110 | index_t n_kernels, 111 | const at::TensorAccessor grad_output, 112 | const at::TensorAccessor x1, 113 | const at::TensorAccessor x2, 114 | at::TensorAccessor grad_x2) { 115 | CPU_1D_PARALLEL_KERNEL_LOOP(index, n_kernels) { 116 | index_t k = index % x2.size(2); 117 | index_t j = (index / x2.size(2)) % x2.size(1); 118 | index_t b = index / (x2.size(2) * x2.size(1)); 119 | 120 | scalar_t val = 0; 121 | for (index_t i = 0; i < x1.size(1); i++) { 122 | scalar_t nom = x1[b][i][k] + x2[b][j][k]; 123 | if (nom != 0) { 124 | scalar_t denom = x1[b][i][k] - x2[b][j][k]; 125 | scalar_t weight = 2 * denom / nom + denom * denom / (nom * nom); 126 | val += weight * grad_output[b][i][j]; 127 | } 128 | } 129 | grad_x2[b][j][k] = val; 130 | } 131 | } 132 | } // namespace impl 133 | 134 | std::tuple _additive_chi2_kernel_backward_kernel( 135 | const at::Tensor &grad_output, 136 | const at::Tensor &x1, 137 | const at::Tensor &x2) { 138 | bool unbatched = x1.ndimension() == 2; 139 | 140 | auto grad_output_c = grad_output.contiguous(); 141 | auto x1_c = x1.contiguous(); 142 | auto x2_c = x2.contiguous(); 143 | if (unbatched) { 144 | grad_output_c = grad_output_c.unsqueeze(0); 145 | x1_c = x1_c.unsqueeze(0); 146 | x2_c = x2_c.unsqueeze(0); 147 | } 148 | 149 | int64_t n_kernels; 150 | auto grad_x1 = at::zeros_like(x1_c); 151 | auto grad_x2 = at::zeros_like(x2_c); 152 | 153 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x1.scalar_type(), "_additive_chi2_kernel_backward_cpu", ([&] { 154 | n_kernels = x1_c.numel(); 155 | TORCHPAIRWISE_DISPATCH_INDEX_TYPE_DEVICE(n_kernels, CPU, ([&] { 156 | auto grad_x1_accessor = 157 | grad_x1.accessor(); 158 | impl::_additive_chi2_kernel_backward_x1_kernel_impl( 159 | n_kernels, 160 | grad_output_c.accessor(), 161 | x1_c.accessor(), 162 | x2_c.accessor(), 163 | grad_x1_accessor); 164 | })); 165 | 166 | n_kernels = x2_c.numel(); 167 | TORCHPAIRWISE_DISPATCH_INDEX_TYPE_DEVICE(n_kernels, CPU, ([&] { 168 | auto grad_x2_accessor = 169 | grad_x2.accessor(); 170 | impl::_additive_chi2_kernel_backward_x2_kernel_impl( 171 | n_kernels, 172 | grad_output_c.accessor(), 173 | x1_c.accessor(), 174 | x2_c.accessor(), 175 | grad_x2_accessor); 176 | })); 177 | })); 178 | if (unbatched) { 179 | grad_x1.squeeze_(0); 180 | grad_x2.squeeze_(0); 181 | } 182 | return std::make_tuple(grad_x1, grad_x2); 183 | } 184 | } 185 | 186 | TORCH_LIBRARY_IMPL(torchpairwise, CPU, m) { 187 | m.impl( 188 | TORCH_SELECTIVE_NAME("torchpairwise::_additive_chi2_kernel"), 189 | 
TORCH_FN(_additive_chi2_kernel_forward_kernel)); 190 | m.impl( 191 | TORCH_SELECTIVE_NAME("torchpairwise::__additive_chi2_kernel_backward"), 192 | TORCH_FN(_additive_chi2_kernel_backward_kernel)); 193 | } 194 | } 195 | } 196 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/cpu/binary_ops.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cpu_helpers.h" 4 | #include "../common/binary_ops.h" 5 | 6 | namespace torchpairwise { 7 | namespace ops { 8 | template 9 | __forceinline__ scalar_t call(scalar_t self, scalar_t other) { 10 | // logical 11 | if constexpr (op == And) 12 | return self & other; 13 | if constexpr (op == Or) 14 | return self | other; 15 | if constexpr (op == Xor) 16 | return self ^ other; 17 | // comparison 18 | if constexpr (op == Equal) 19 | return self == other; 20 | if constexpr (op == NotEqual) 21 | return self != other; 22 | if constexpr (op == Less) 23 | return self < other; 24 | if constexpr (op == Greater) 25 | return self > other; 26 | if constexpr (op == LessEqual) 27 | return self <= other; 28 | if constexpr (op == GreaterEqual) 29 | return self >= other; 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/cpu/canberra_kernel.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cpu_helpers.h" 5 | #include "signum.h" 6 | #include "../utils/dispatch.h" 7 | 8 | namespace torchpairwise { 9 | namespace ops { 10 | namespace { 11 | namespace impl { 12 | template 13 | void _canberra_forward_kernel_impl( 14 | index_t n_kernels, 15 | const at::TensorAccessor x1, 16 | const at::TensorAccessor x2, 17 | at::TensorAccessor output) { 18 | CPU_1D_PARALLEL_KERNEL_LOOP(index, n_kernels) { 19 | index_t j = index % x2.size(1); 20 | index_t i = (index / x2.size(1)) % x1.size(1); 21 | index_t b = index / (x2.size(1) * x1.size(1)); 22 | 23 | scalar_t val = 0; 24 | for (index_t k = 0; k < x1.size(2); k++) { 25 | scalar_t denom = fabs(x1[b][i][k]) + fabs(x2[b][j][k]); 26 | if (denom != 0) 27 | val += fabs(x1[b][i][k] - x2[b][j][k]) / denom; 28 | } 29 | output[b][i][j] = val; 30 | } 31 | } 32 | } // namespace impl 33 | 34 | at::Tensor _canberra_forward_kernel( 35 | const at::Tensor &x1, 36 | const at::Tensor &x2) { 37 | at::CheckedFrom c = "_canberra_forward"; 38 | auto args = { 39 | at::TensorArg(x1, "x1", 1), 40 | at::TensorArg(x2, "x2", 2)}; 41 | at::checkAllSameType(c, args); 42 | 43 | bool unbatched = x1.ndimension() == 2; 44 | TORCH_CHECK(unbatched || x1.ndimension() == 3, 45 | "x1 must be 2-D (unbatched) or 3-D (batched) tensor.") 46 | TORCH_CHECK(unbatched || x2.ndimension() == 3, 47 | "x2 must be 2-D (unbatched) or 3-D (batched) tensor.") 48 | TORCH_CHECK(unbatched || (x1.size(0) == x2.size(0)), 49 | "batch_size of x1 and x2 do not match.") 50 | TORCH_CHECK((unbatched && x1.size(1) == x2.size(1)) || 51 | (!unbatched && x1.size(2) == x2.size(2)), 52 | "feature dimension of x1 and x2 do not match.") 53 | 54 | auto x1_c = x1.contiguous(); 55 | auto x2_c = x2.contiguous(); 56 | if (unbatched) { 57 | x1_c = x1_c.unsqueeze(0); 58 | x2_c = x2_c.unsqueeze(0); 59 | } 60 | 61 | int64_t batch_size = x1_c.size(0); 62 | auto output = at::empty({batch_size, x1_c.size(1), x2_c.size(1)}, x1.options()); 63 | int64_t n_kernels = output.numel(); 64 | 65 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x1.scalar_type(), 
"_canberra_forward_cpu", ([&] { 66 | TORCHPAIRWISE_DISPATCH_INDEX_TYPE_DEVICE(n_kernels, CPU, ([&] { 67 | auto output_accessor = 68 | output.accessor(); 69 | impl::_canberra_forward_kernel_impl( 70 | n_kernels, 71 | x1_c.accessor(), 72 | x2_c.accessor(), 73 | output_accessor); 74 | })); 75 | })); 76 | if (unbatched) 77 | output.squeeze_(0); 78 | return output; 79 | } 80 | 81 | namespace impl { 82 | template 83 | void _canberra_backward_x1_kernel_impl( 84 | index_t n_kernels, 85 | const at::TensorAccessor grad_output, 86 | const at::TensorAccessor x1, 87 | const at::TensorAccessor x2, 88 | at::TensorAccessor grad_x1) { 89 | CPU_1D_PARALLEL_KERNEL_LOOP(index, n_kernels) { 90 | index_t k = index % x1.size(2); 91 | index_t i = (index / x1.size(2)) % x1.size(1); 92 | index_t b = index / (x1.size(2) * x1.size(1)); 93 | 94 | scalar_t val = 0; 95 | for (index_t j = 0; j < x2.size(1); j++) { 96 | scalar_t nom = fabs(x1[b][i][k]) + fabs(x2[b][j][k]); 97 | scalar_t denom = x1[b][i][k] - x2[b][j][k]; 98 | val += grad_output[b][i][j] * 99 | (m_signum(denom) / nom - m_signum(x1[b][i][k]) * fabs(denom) / nom / nom); 100 | } 101 | grad_x1[b][i][k] = val; 102 | } 103 | } 104 | 105 | template 106 | void _canberra_backward_x2_kernel_impl( 107 | index_t n_kernels, 108 | const at::TensorAccessor grad_output, 109 | const at::TensorAccessor x1, 110 | const at::TensorAccessor x2, 111 | at::TensorAccessor grad_x2) { 112 | CPU_1D_PARALLEL_KERNEL_LOOP(index, n_kernels) { 113 | index_t k = index % x2.size(2); 114 | index_t j = (index / x2.size(2)) % x2.size(1); 115 | index_t b = index / (x2.size(2) * x2.size(1)); 116 | 117 | scalar_t val = 0; 118 | for (index_t i = 0; i < x1.size(1); i++) { 119 | scalar_t nom = fabs(x1[b][i][k]) + fabs(x2[b][j][k]); 120 | scalar_t denom = x1[b][i][k] - x2[b][j][k]; 121 | val += grad_output[b][i][j] * 122 | (-m_signum(denom) / nom - m_signum(x2[b][j][k]) * fabs(denom) / nom / nom); 123 | } 124 | grad_x2[b][j][k] = val; 125 | } 126 | } 127 | } // namespace impl 128 | 129 | std::tuple _canberra_backward_kernel( 130 | const at::Tensor &grad_output, 131 | const at::Tensor &x1, 132 | const at::Tensor &x2) { 133 | bool unbatched = x1.ndimension() == 2; 134 | 135 | auto grad_output_c = grad_output.contiguous(); 136 | auto x1_c = x1.contiguous(); 137 | auto x2_c = x2.contiguous(); 138 | if (unbatched) { 139 | grad_output_c = grad_output_c.unsqueeze(0); 140 | x1_c = x1_c.unsqueeze(0); 141 | x2_c = x2_c.unsqueeze(0); 142 | } 143 | 144 | int64_t n_kernels; 145 | auto grad_x1 = at::zeros_like(x1_c); 146 | auto grad_x2 = at::zeros_like(x2_c); 147 | 148 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x1.scalar_type(), "_canberra_backward_cpu", ([&] { 149 | n_kernels = x1_c.numel(); 150 | TORCHPAIRWISE_DISPATCH_INDEX_TYPE_DEVICE(n_kernels, CPU, ([&] { 151 | auto grad_x1_accessor = 152 | grad_x1.accessor(); 153 | impl::_canberra_backward_x1_kernel_impl( 154 | n_kernels, 155 | grad_output_c.accessor(), 156 | x1_c.accessor(), 157 | x2_c.accessor(), 158 | grad_x1_accessor); 159 | })); 160 | 161 | n_kernels = x2_c.numel(); 162 | TORCHPAIRWISE_DISPATCH_INDEX_TYPE_DEVICE(n_kernels, CPU, ([&] { 163 | auto grad_x2_accessor = 164 | grad_x2.accessor(); 165 | impl::_canberra_backward_x2_kernel_impl( 166 | n_kernels, 167 | grad_output_c.accessor(), 168 | x1_c.accessor(), 169 | x2_c.accessor(), 170 | grad_x2_accessor); 171 | })); 172 | })); 173 | if (unbatched) { 174 | grad_x1.squeeze_(0); 175 | grad_x2.squeeze_(0); 176 | } 177 | return std::make_tuple(grad_x1, grad_x2); 178 | } 179 | } 180 | 181 | 
TORCH_LIBRARY_IMPL(torchpairwise, CPU, m) { 182 | m.impl( 183 | TORCH_SELECTIVE_NAME("torchpairwise::_canberra"), 184 | TORCH_FN(_canberra_forward_kernel)); 185 | m.impl( 186 | TORCH_SELECTIVE_NAME("torchpairwise::__canberra_backward"), 187 | TORCH_FN(_canberra_backward_kernel)); 188 | } 189 | } 190 | } 191 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/cpu/cpu_helpers.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #ifdef AT_PARALLEL_OPENMP 7 | #ifndef OMP_NUM_THREADS 8 | #define OMP_NUM_THREADS 32 9 | #endif 10 | 11 | #if defined(_MSC_VER) 12 | #define _PRAGMA_OMP_PARALLEL_FOR __pragma(omp parallel for num_threads(OMP_NUM_THREADS)) 13 | #else 14 | #define _PRAGMA_OMP_PARALLEL_FOR _Pragma("omp parallel for num_threads(OMP_NUM_THREADS)") 15 | #endif 16 | #else 17 | #define _PRAGMA_OMP_PARALLEL_FOR 18 | #endif 19 | 20 | // regular for loop 21 | #define CPU_1D_KERNEL_LOOP_BETWEEN_T(i, start, end, index_t) \ 22 | for (index_t i = start; i < end; ++i) 23 | 24 | #define CPU_1D_KERNEL_LOOP_T(i, n, index_t) \ 25 | CPU_1D_KERNEL_LOOP_BETWEEN_T(i, 0, n, index_t) 26 | 27 | #define CPU_1D_KERNEL_LOOP_BETWEEN(i, start, end) \ 28 | CPU_1D_KERNEL_LOOP_BETWEEN_T(i, start, end, \ 29 | std::remove_cv_t>>) 30 | 31 | #define CPU_1D_KERNEL_LOOP(i, n) \ 32 | CPU_1D_KERNEL_LOOP_T(i, n, std::remove_cv_t>) 33 | 34 | // openmp parallel for loop 35 | #define CPU_1D_PARALLEL_KERNEL_LOOP_BETWEEN_T(i, start, end, index_t) \ 36 | _PRAGMA_OMP_PARALLEL_FOR \ 37 | CPU_1D_KERNEL_LOOP_BETWEEN_T(i, start, end, index_t) 38 | 39 | #define CPU_1D_PARALLEL_KERNEL_LOOP_T(i, n, index_t) \ 40 | CPU_1D_PARALLEL_KERNEL_LOOP_BETWEEN_T(i, 0, n, index_t) 41 | 42 | #define CPU_1D_PARALLEL_KERNEL_LOOP_BETWEEN(i, start, end) \ 43 | CPU_1D_PARALLEL_KERNEL_LOOP_BETWEEN_T(i, start, end, \ 44 | std::remove_cv_t>>) 45 | 46 | #define CPU_1D_PARALLEL_KERNEL_LOOP(i, n) \ 47 | CPU_1D_PARALLEL_KERNEL_LOOP_T(i, n, std::remove_cv_t>) 48 | 49 | // inline 50 | #if defined(_MSC_VER) 51 | #define __forceinline__ __forceinline 52 | #elif defined(__GNUC__) && !defined(__clang__) 53 | #define __forceinline__ __attribute__((always_inline)) inline 54 | #else 55 | #define __forceinline__ inline 56 | #endif 57 | 58 | using std::min; 59 | using std::max; 60 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/cpu/prf_divide.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cpu_helpers.h" 4 | #include "../common/prf_div_mode.h" 5 | 6 | namespace torchpairwise { 7 | namespace ops { 8 | template 9 | __forceinline__ constexpr T prf_divide(const T &x, const T &y) { 10 | if constexpr (mode == Zero) 11 | return y != T(0) ? x / y : T(0); 12 | else if constexpr (mode == Identity) 13 | return y != T(0) ? 
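// Annotation (not in the original source): "prf" presumably follows
// scikit-learn's _prf_divide helper (precision/recall/F-score): in Zero mode
// x / 0 yields 0, while the Identity branch completed below yields x
// unchanged, so downstream metrics stay finite when a denominator vanishes.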
x / y : x; 14 | else 15 | return x / y; 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/cpu/reduction_ops.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "cpu_helpers.h" 6 | #include "../common/reduction_ops.h" 7 | 8 | namespace torchpairwise { 9 | namespace ops { 10 | template 11 | constexpr scalar_t identity_value() { 12 | if constexpr (op == All) 13 | return static_cast(true); 14 | if constexpr (op == Any) 15 | return static_cast(false); 16 | if constexpr (op == Sum || op == Mean) 17 | return static_cast(0); 18 | if constexpr (op == Prod) 19 | return static_cast(1); 20 | } 21 | 22 | template 23 | __forceinline__ output_t call(input_ts... args) { 24 | // logical 25 | if constexpr (op == All) 26 | return (... && args); 27 | if constexpr (op == Any) 28 | return (... || args); 29 | // arithmetics 30 | if constexpr (op == Sum) 31 | return (... + args); 32 | if constexpr (op == Prod) 33 | return (... * args); 34 | if constexpr (op == Mean) 35 | return (... + args) / static_cast(sizeof...(args)); 36 | } 37 | 38 | template 39 | __forceinline__ output_t call(const at::TensorAccessor args) { 40 | output_t output = identity_value(); 41 | for (int64_t i = 0; i < args.size(0); i++) { 42 | // logical 43 | if constexpr (op == All) 44 | output &= args[i]; 45 | if constexpr (op == Any) 46 | output |= args[i]; 47 | // arithmetics 48 | if constexpr (op == Sum || op == Mean) 49 | output += args[i]; 50 | if constexpr (op == Prod) 51 | output *= args[i]; 52 | } 53 | if constexpr (op == Mean) { 54 | output /= static_cast(args.size(0)); 55 | } 56 | return output; 57 | } 58 | 59 | template 60 | __forceinline__ void accumulate_call(output_t* val, input_t arg) { 61 | // logical 62 | if constexpr (op == All) 63 | *val &= arg; 64 | if constexpr (op == Any) 65 | *val |= arg; 66 | // arithmetics 67 | if constexpr (op == Sum || op == Mean) 68 | *val += arg; 69 | if constexpr (op == Prod) 70 | *val *= arg; 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/cpu/rel_entr.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cpu_helpers.h" 4 | #include "../utils/scalar_type_utils.h" 5 | 6 | namespace torchpairwise { 7 | namespace ops { 8 | template 9 | __forceinline__ constexpr T rel_entr(const T &x, const T &y) { 10 | if (std::isnan(x)) 11 | return x; 12 | else if (x > T(0) && y > T(0)) 13 | return x * log(x / y); 14 | else if (x == T(0) && y >= T(0)) 15 | return 0; 16 | else 17 | return c10::CPPTypeLimits::upper_bound(); 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/cpu/signum.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cpu_helpers.h" 4 | #include "../utils/scalar_type_utils.h" 5 | 6 | namespace torchpairwise { 7 | namespace ops { 8 | template 9 | typename std::enable_if::value, int>::type 10 | __forceinline__ constexpr m_signum(const T &x) { 11 | return T(0) < x; 12 | } 13 | 14 | template 15 | typename std::enable_if::value, int>::type 16 | __forceinline__ constexpr m_signum(const T &x) { 17 | return (T(0) < x) - (x < T(0)); 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- 
/torchpairwise/csrc/ops/cpu/sqjensenshannon_kernel.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cpu_helpers.h" 5 | #include "rel_entr.h" 6 | #include "../utils/dispatch.h" 7 | 8 | namespace torchpairwise { 9 | namespace ops { 10 | namespace { 11 | namespace impl { 12 | template 13 | void _sqjensenshannon_forward_kernel_impl( 14 | index_t n_kernels, 15 | const at::TensorAccessor x1, 16 | const at::TensorAccessor x2, 17 | at::TensorAccessor output) { 18 | CPU_1D_PARALLEL_KERNEL_LOOP(index, n_kernels) { 19 | index_t j = index % x2.size(1); 20 | index_t i = (index / x2.size(1)) % x1.size(1); 21 | index_t b = index / (x2.size(1) * x1.size(1)); 22 | 23 | scalar_t val = 0, m; 24 | for (int64_t k = 0; k < x1.size(2); k++) { 25 | m = (x1[b][i][k] + x2[b][j][k]) / static_cast(2); 26 | val += rel_entr(x1[b][i][k], m) + rel_entr(x2[b][j][k], m); 27 | } 28 | output[b][i][j] = val; 29 | } 30 | } 31 | } // namespace impl 32 | 33 | at::Tensor _sqjensenshannon_forward_kernel( 34 | const at::Tensor &x1, 35 | const at::Tensor &x2, 36 | c10::optional base) { 37 | at::CheckedFrom c = "_sqjensenshannon_forward"; 38 | auto args = { 39 | at::TensorArg(x1, "x1", 1), 40 | at::TensorArg(x2, "x2", 2)}; 41 | at::checkAllSameType(c, args); 42 | 43 | bool unbatched = x1.ndimension() == 2; 44 | TORCH_CHECK(unbatched || x1.ndimension() == 3, 45 | "x1 must be 2-D (unbatched) or 3-D (batched) tensor.") 46 | TORCH_CHECK(unbatched || x2.ndimension() == 3, 47 | "x2 must be 2-D (unbatched) or 3-D (batched) tensor.") 48 | TORCH_CHECK(unbatched || (x1.size(0) == x2.size(0)), 49 | "batch_size of x1 and x2 do not match.") 50 | TORCH_CHECK((unbatched && x1.size(1) == x2.size(1)) || 51 | (!unbatched && x1.size(2) == x2.size(2)), 52 | "feature dimension of x1 and x2 do not match.") 53 | 54 | auto x1_c = x1.contiguous(); 55 | auto x2_c = x2.contiguous(); 56 | if (unbatched) { 57 | x1_c = x1_c.unsqueeze(0); 58 | x2_c = x2_c.unsqueeze(0); 59 | } 60 | 61 | int64_t batch_size = x1_c.size(0); 62 | auto output = at::empty({batch_size, x1_c.size(1), x2_c.size(1)}, x1.options()); 63 | int64_t n_kernels = output.numel(); 64 | 65 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x1.scalar_type(), "_sqjensenshannon_forward_cpu", ([&] { 66 | TORCHPAIRWISE_DISPATCH_INDEX_TYPE_DEVICE(n_kernels, CPU, ([&] { 67 | auto output_accessor = 68 | output.accessor(); 69 | impl::_sqjensenshannon_forward_kernel_impl( 70 | n_kernels, 71 | x1_c.accessor(), 72 | x2_c.accessor(), 73 | output_accessor); 74 | output.div_(base.has_value() ? 
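// Annotation (not in the original source): the rel_entr sum accumulated above
// is the raw KL(x1, m) + KL(x2, m) in natural log; the divisor completed
// below rescales by 2 * ln(base) when a base is given (else by 2), which
// lines up with scipy.spatial.distance.jensenshannon(x1, x2, base) squared.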
2 * log(static_cast(base.value())) : 2); 75 | })); 76 | })); 77 | if (unbatched) 78 | output.squeeze_(0); 79 | return output; 80 | } 81 | 82 | namespace impl { 83 | template 84 | void _sqjensenshannon_backward_x1_kernel_impl( 85 | index_t n_kernels, 86 | const at::TensorAccessor grad_output, 87 | const at::TensorAccessor x1, 88 | const at::TensorAccessor x2, 89 | at::TensorAccessor grad_x1) { 90 | CPU_1D_PARALLEL_KERNEL_LOOP(index, n_kernels) { 91 | index_t k = index % x1.size(2); 92 | index_t i = (index / x1.size(2)) % x1.size(1); 93 | index_t b = index / (x1.size(2) * x1.size(1)); 94 | 95 | scalar_t val = 0, sum, m; 96 | for (index_t j = 0; j < x2.size(1); j++) { 97 | sum = x1[b][i][k] + x2[b][j][k]; 98 | m = sum / static_cast(2); 99 | if (m > static_cast(0)) { 100 | if (x1[b][i][k] > static_cast(0)) { 101 | val += grad_output[b][i][j] * 102 | (x2[b][j][k] + sum * log(x1[b][i][k] / m)) / sum; 103 | } 104 | if (x2[b][j][k] > static_cast(0)) { 105 | val += grad_output[b][i][j] * -x2[b][j][k] / sum; 106 | } 107 | } 108 | } 109 | grad_x1[b][i][k] += val; 110 | } 111 | } 112 | 113 | template 114 | void _sqjensenshannon_backward_x2_kernel_impl( 115 | index_t n_kernels, 116 | const at::TensorAccessor grad_output, 117 | const at::TensorAccessor x1, 118 | const at::TensorAccessor x2, 119 | at::TensorAccessor grad_x2) { 120 | CPU_1D_PARALLEL_KERNEL_LOOP(index, n_kernels) { 121 | index_t k = index % x2.size(2); 122 | index_t j = (index / x2.size(2)) % x2.size(1); 123 | index_t b = index / (x2.size(2) * x2.size(1)); 124 | 125 | scalar_t val = 0, sum, m; 126 | for (index_t i = 0; i < x1.size(1); i++) { 127 | sum = x1[b][i][k] + x2[b][j][k]; 128 | m = sum / static_cast(2); 129 | if (m > static_cast(0)) { 130 | if (x1[b][i][k] > static_cast(0)) { 131 | val += grad_output[b][i][j] * -x1[b][i][k] / sum; 132 | } 133 | if (x2[b][j][k] > static_cast(0)) { 134 | val += grad_output[b][i][j] * 135 | (x1[b][i][k] + sum * log(x2[b][j][k] / m)) / sum; 136 | } 137 | } 138 | } 139 | grad_x2[b][j][k] += val; 140 | } 141 | } 142 | } // namespace impl 143 | 144 | std::tuple _sqjensenshannon_backward_kernel( 145 | const at::Tensor &grad_output, 146 | const at::Tensor &x1, 147 | const at::Tensor &x2, 148 | c10::optional base) { 149 | bool unbatched = x1.ndimension() == 2; 150 | 151 | auto grad_output_c = grad_output.contiguous(); 152 | auto x1_c = x1.contiguous(); 153 | auto x2_c = x2.contiguous(); 154 | if (unbatched) { 155 | grad_output_c = grad_output_c.unsqueeze(0); 156 | x1_c = x1_c.unsqueeze(0); 157 | x2_c = x2_c.unsqueeze(0); 158 | } 159 | 160 | int64_t n_kernels; 161 | auto grad_x1 = at::zeros_like(x1_c); 162 | auto grad_x2 = at::zeros_like(x2_c); 163 | 164 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x1.scalar_type(), "_sqjensenshannon_backward_cpu", ([&] { 165 | grad_output_c = grad_output_c.div( 166 | base.has_value() ? 
2 * log(static_cast(base.value())) : 2); 167 | 168 | n_kernels = x1_c.numel(); 169 | TORCHPAIRWISE_DISPATCH_INDEX_TYPE_DEVICE(n_kernels, CPU, ([&] { 170 | auto grad_x1_accessor = 171 | grad_x1.accessor(); 172 | impl::_sqjensenshannon_backward_x1_kernel_impl( 173 | n_kernels, 174 | grad_output_c.accessor(), 175 | x1_c.accessor(), 176 | x2_c.accessor(), 177 | grad_x1_accessor); 178 | })); 179 | 180 | n_kernels = x2_c.numel(); 181 | TORCHPAIRWISE_DISPATCH_INDEX_TYPE_DEVICE(n_kernels, CPU, ([&] { 182 | auto grad_x2_accessor = 183 | grad_x2.accessor(); 184 | impl::_sqjensenshannon_backward_x2_kernel_impl( 185 | n_kernels, 186 | grad_output_c.accessor(), 187 | x1_c.accessor(), 188 | x2_c.accessor(), 189 | grad_x2_accessor); 190 | })); 191 | })); 192 | if (unbatched) { 193 | grad_x1.squeeze_(0); 194 | grad_x2.squeeze_(0); 195 | } 196 | return std::make_tuple(grad_x1, grad_x2); 197 | } 198 | } 199 | 200 | TORCH_LIBRARY_IMPL(torchpairwise, CPU, m) { 201 | m.impl( 202 | TORCH_SELECTIVE_NAME("torchpairwise::_sqjensenshannon"), 203 | TORCH_FN(_sqjensenshannon_forward_kernel)); 204 | m.impl( 205 | TORCH_SELECTIVE_NAME("torchpairwise::__sqjensenshannon_backward"), 206 | TORCH_FN(_sqjensenshannon_backward_kernel)); 207 | } 208 | } 209 | } 210 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/cuda/additive_chi2_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_helpers.h" 5 | #include "../utils/dispatch.h" 6 | 7 | namespace torchpairwise { 8 | namespace ops { 9 | namespace { 10 | constexpr unsigned int GET_THREADS() { 11 | return 1024; 12 | } 13 | 14 | namespace impl { 15 | template 16 | C10_LAUNCH_BOUNDS_1(1024) __global__ void _additive_chi2_kernel_forward_kernel_impl( 17 | index_t n_kernels, 18 | const at::GenericPackedTensorAccessor x1, 19 | const at::GenericPackedTensorAccessor x2, 20 | at::GenericPackedTensorAccessor output) { 21 | CUDA_1D_KERNEL_LOOP(index, n_kernels) { 22 | index_t j = index % x2.size(1); 23 | index_t i = (index / x2.size(1)) % x1.size(1); 24 | index_t b = index / (x2.size(1) * x1.size(1)); 25 | 26 | scalar_t val = 0; 27 | for (index_t k = 0; k < x1.size(2); k++) { 28 | scalar_t nom = x1[b][i][k] + x2[b][j][k]; 29 | if (nom != 0) { 30 | scalar_t denom = x1[b][i][k] - x2[b][j][k]; 31 | val -= denom * denom / nom; 32 | } 33 | } 34 | output[b][i][j] = val; 35 | } 36 | } 37 | } // namespace impl 38 | 39 | at::Tensor _additive_chi2_kernel_forward_kernel( 40 | const at::Tensor &x1, 41 | const at::Tensor &x2) { 42 | at::CheckedFrom c = "_additive_chi2_kernel_forward"; 43 | auto args = { 44 | at::TensorArg(x1, "x1", 1), 45 | at::TensorArg(x2, "x2", 2)}; 46 | at::checkAllSameGPU(c, args); 47 | at::checkAllSameType(c, args); 48 | 49 | at::cuda::CUDAGuard device_guard(x1.get_device()); 50 | bool unbatched = x1.ndimension() == 2; 51 | TORCH_CHECK(unbatched || x1.ndimension() == 3, 52 | "x1 must be 2-D (unbatched) or 3-D (batched) tensor.") 53 | TORCH_CHECK(unbatched || x2.ndimension() == 3, 54 | "x2 must be 2-D (unbatched) or 3-D (batched) tensor.") 55 | TORCH_CHECK(unbatched || (x1.size(0) == x2.size(0)), 56 | "batch_size of x1 and x2 do not match.") 57 | TORCH_CHECK((unbatched && x1.size(1) == x2.size(1)) || 58 | (!unbatched && x1.size(2) == x2.size(2)), 59 | "feature dimension of x1 and x2 do not match.") 60 | 61 | auto x1_c = x1.contiguous(); 62 | auto x2_c = x2.contiguous(); 63 | if (unbatched) { 64 | x1_c = x1_c.unsqueeze(0); 65 | x2_c = 
x2_c.unsqueeze(0); 66 | } 67 | 68 | int64_t batch_size = x1_c.size(0); 69 | auto output = at::empty({batch_size, x1_c.size(1), x2_c.size(1)}, x1.options()); 70 | int64_t n_kernels = output.numel(); 71 | 72 | const unsigned int threads = GET_THREADS(); 73 | const unsigned int blocks = GET_BLOCKS(threads, n_kernels); 74 | 75 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x1.scalar_type(), "_additive_chi2_kernel_forward_cuda", ([&] { 76 | TORCHPAIRWISE_DISPATCH_INDEX_TYPE_DEVICE(n_kernels, CUDA, ([&] { 77 | auto output_accessor = 78 | output.generic_packed_accessor(); 79 | impl::_additive_chi2_kernel_forward_kernel_impl<<>>( 80 | n_kernels, 81 | x1_c.generic_packed_accessor(), 82 | x2_c.generic_packed_accessor(), 83 | output_accessor); 84 | })); 85 | })); 86 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 87 | if (unbatched) 88 | output.squeeze_(0); 89 | return output; 90 | } 91 | 92 | namespace impl { 93 | template 94 | C10_LAUNCH_BOUNDS_1(1024) __global__ void _additive_chi2_kernel_backward_x1_kernel_impl( 95 | index_t n_kernels, 96 | const at::GenericPackedTensorAccessor grad_output, 97 | const at::GenericPackedTensorAccessor x1, 98 | const at::GenericPackedTensorAccessor x2, 99 | at::GenericPackedTensorAccessor grad_x1) { 100 | CUDA_1D_KERNEL_LOOP(index, n_kernels) { 101 | index_t k = index % x1.size(2); 102 | index_t i = (index / x1.size(2)) % x1.size(1); 103 | index_t b = index / (x1.size(2) * x1.size(1)); 104 | 105 | scalar_t val = 0; 106 | for (index_t j = 0; j < x2.size(1); j++) { 107 | scalar_t nom = x1[b][i][k] + x2[b][j][k]; 108 | if (nom != 0) { 109 | scalar_t denom = x1[b][i][k] - x2[b][j][k]; 110 | scalar_t weight = denom * denom / (nom * nom) - 2 * denom / nom; 111 | val += weight * grad_output[b][i][j]; 112 | } 113 | } 114 | grad_x1[b][i][k] = val; 115 | } 116 | } 117 | 118 | template 119 | C10_LAUNCH_BOUNDS_1(1024) __global__ void _additive_chi2_kernel_backward_x2_kernel_impl( 120 | index_t n_kernels, 121 | const at::GenericPackedTensorAccessor grad_output, 122 | const at::GenericPackedTensorAccessor x1, 123 | const at::GenericPackedTensorAccessor x2, 124 | at::GenericPackedTensorAccessor grad_x2) { 125 | CUDA_1D_KERNEL_LOOP(index, n_kernels) { 126 | index_t k = index % x2.size(2); 127 | index_t j = (index / x2.size(2)) % x2.size(1); 128 | index_t b = index / (x2.size(2) * x2.size(1)); 129 | 130 | scalar_t val = 0; 131 | for (index_t i = 0; i < x1.size(1); i++) { 132 | scalar_t nom = x1[b][i][k] + x2[b][j][k]; 133 | if (nom != 0) { 134 | scalar_t denom = x1[b][i][k] - x2[b][j][k]; 135 | scalar_t weight = 2 * denom / nom + denom * denom / (nom * nom); 136 | val += weight * grad_output[b][i][j]; 137 | } 138 | } 139 | grad_x2[b][j][k] = val; 140 | } 141 | } 142 | } // namespace impl 143 | 144 | std::tuple _additive_chi2_kernel_backward_kernel( 145 | const at::Tensor &grad_output, 146 | const at::Tensor &x1, 147 | const at::Tensor &x2) { 148 | at::cuda::CUDAGuard device_guard(grad_output.get_device()); 149 | bool unbatched = x1.ndimension() == 2; 150 | 151 | auto grad_output_c = grad_output.contiguous(); 152 | auto x1_c = x1.contiguous(); 153 | auto x2_c = x2.contiguous(); 154 | if (unbatched) { 155 | grad_output_c = grad_output_c.unsqueeze(0); 156 | x1_c = x1_c.unsqueeze(0); 157 | x2_c = x2_c.unsqueeze(0); 158 | } 159 | 160 | int64_t n_kernels; 161 | auto grad_x1 = at::zeros_like(x1_c); 162 | auto grad_x2 = at::zeros_like(x2_c); 163 | 164 | const unsigned int threads = GET_THREADS(); 165 | unsigned int blocks; 166 | 167 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x1.scalar_type(), 
"_additive_chi2_kernel_backward_cuda", ([&] { 168 | n_kernels = x1_c.numel(); 169 | blocks = GET_BLOCKS(threads, n_kernels); 170 | TORCHPAIRWISE_DISPATCH_INDEX_TYPE_DEVICE(n_kernels, CUDA, ([&] { 171 | auto grad_x1_accessor = 172 | grad_x1.generic_packed_accessor(); 173 | impl::_additive_chi2_kernel_backward_x1_kernel_impl<<>>( 174 | n_kernels, 175 | grad_output_c.generic_packed_accessor(), 176 | x1_c.generic_packed_accessor(), 177 | x2_c.generic_packed_accessor(), 178 | grad_x1_accessor); 179 | })); 180 | 181 | n_kernels = x2_c.numel(); 182 | blocks = GET_BLOCKS(threads, n_kernels); 183 | TORCHPAIRWISE_DISPATCH_INDEX_TYPE_DEVICE(n_kernels, CUDA, ([&] { 184 | auto grad_x2_accessor = 185 | grad_x2.generic_packed_accessor(); 186 | impl::_additive_chi2_kernel_backward_x2_kernel_impl<<>>( 187 | n_kernels, 188 | grad_output_c.generic_packed_accessor(), 189 | x1_c.generic_packed_accessor(), 190 | x2_c.generic_packed_accessor(), 191 | grad_x2_accessor); 192 | })); 193 | })); 194 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 195 | if (unbatched) { 196 | grad_x1.squeeze_(0); 197 | grad_x2.squeeze_(0); 198 | } 199 | return std::make_tuple(grad_x1, grad_x2); 200 | } 201 | } 202 | 203 | TORCH_LIBRARY_IMPL(torchpairwise, CUDA, m) { 204 | m.impl( 205 | TORCH_SELECTIVE_NAME("torchpairwise::_additive_chi2_kernel"), 206 | TORCH_FN(_additive_chi2_kernel_forward_kernel)); 207 | m.impl( 208 | TORCH_SELECTIVE_NAME("torchpairwise::__additive_chi2_kernel_backward"), 209 | TORCH_FN(_additive_chi2_kernel_backward_kernel)); 210 | } 211 | } 212 | } 213 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/cuda/binary_ops.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "cuda_helpers.h" 4 | #include "../common/binary_ops.h" 5 | 6 | namespace torchpairwise { 7 | namespace ops { 8 | template 9 | __forceinline__ __device__ scalar_t call(scalar_t self, scalar_t other) { 10 | // logical 11 | if constexpr (op == And) 12 | return self & other; 13 | if constexpr (op == Or) 14 | return self | other; 15 | if constexpr (op == Xor) 16 | return self ^ other; 17 | // comparison 18 | if constexpr (op == Equal) 19 | return self == other; 20 | if constexpr (op == NotEqual) 21 | return self != other; 22 | if constexpr (op == Less) 23 | return self < other; 24 | if constexpr (op == Greater) 25 | return self > other; 26 | if constexpr (op == LessEqual) 27 | return self <= other; 28 | if constexpr (op == GreaterEqual) 29 | return self >= other; 30 | } 31 | } 32 | } 33 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/cuda/canberra_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "cuda_helpers.h" 5 | #include "signum.cuh" 6 | #include "../utils/dispatch.h" 7 | 8 | namespace torchpairwise { 9 | namespace ops { 10 | namespace { 11 | constexpr unsigned int GET_THREADS() { 12 | return 1024; 13 | } 14 | 15 | namespace impl { 16 | template 17 | C10_LAUNCH_BOUNDS_1(1024) __global__ void _canberra_forward_kernel_impl( 18 | index_t n_kernels, 19 | const at::GenericPackedTensorAccessor x1, 20 | const at::GenericPackedTensorAccessor x2, 21 | at::GenericPackedTensorAccessor output) { 22 | CUDA_1D_KERNEL_LOOP(index, n_kernels) { 23 | index_t j = index % x2.size(1); 24 | index_t i = (index / x2.size(1)) % x1.size(1); 25 | index_t b = index / (x2.size(1) * x1.size(1)); 26 | 27 | scalar_t val = 0; 28 | 
for (index_t k = 0; k < x1.size(2); k++) { 29 | scalar_t denom = fabs(x1[b][i][k]) + fabs(x2[b][j][k]); 30 | if (denom != 0) 31 | val += fabs(x1[b][i][k] - x2[b][j][k]) / denom; 32 | } 33 | output[b][i][j] = val; 34 | } 35 | } 36 | } // namespace impl 37 | 38 | at::Tensor _canberra_forward_kernel( 39 | const at::Tensor &x1, 40 | const at::Tensor &x2) { 41 | at::CheckedFrom c = "_canberra_forward"; 42 | auto args = { 43 | at::TensorArg(x1, "x1", 1), 44 | at::TensorArg(x2, "x2", 2)}; 45 | at::checkAllSameGPU(c, args); 46 | at::checkAllSameType(c, args); 47 | 48 | at::cuda::CUDAGuard device_guard(x1.get_device()); 49 | bool unbatched = x1.ndimension() == 2; 50 | TORCH_CHECK(unbatched || x1.ndimension() == 3, 51 | "x1 must be 2-D (unbatched) or 3-D (batched) tensor.") 52 | TORCH_CHECK(unbatched || x2.ndimension() == 3, 53 | "x2 must be 2-D (unbatched) or 3-D (batched) tensor.") 54 | TORCH_CHECK(unbatched || (x1.size(0) == x2.size(0)), 55 | "batch_size of x1 and x2 do not match.") 56 | TORCH_CHECK((unbatched && x1.size(1) == x2.size(1)) || 57 | (!unbatched && x1.size(2) == x2.size(2)), 58 | "feature dimension of x1 and x2 do not match.") 59 | 60 | auto x1_c = x1.contiguous(); 61 | auto x2_c = x2.contiguous(); 62 | if (unbatched) { 63 | x1_c = x1_c.unsqueeze(0); 64 | x2_c = x2_c.unsqueeze(0); 65 | } 66 | 67 | int64_t batch_size = x1_c.size(0); 68 | auto output = at::empty({batch_size, x1_c.size(1), x2_c.size(1)}, x1.options()); 69 | int64_t n_kernels = output.numel(); 70 | 71 | const unsigned int threads = GET_THREADS(); 72 | const unsigned int blocks = GET_BLOCKS(threads, n_kernels); 73 | 74 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x1.scalar_type(), "_canberra_forward_cuda", ([&] { 75 | TORCHPAIRWISE_DISPATCH_INDEX_TYPE_DEVICE(n_kernels, CUDA, ([&] { 76 | auto output_accessor = 77 | output.generic_packed_accessor(); 78 | impl::_canberra_forward_kernel_impl<<>>( 79 | n_kernels, 80 | x1_c.generic_packed_accessor(), 81 | x2_c.generic_packed_accessor(), 82 | output_accessor); 83 | })); 84 | })); 85 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 86 | if (unbatched) 87 | output.squeeze_(0); 88 | return output; 89 | } 90 | 91 | namespace impl { 92 | template 93 | C10_LAUNCH_BOUNDS_1(1024) __global__ void _canberra_backward_x1_kernel_impl( 94 | index_t n_kernels, 95 | const at::GenericPackedTensorAccessor grad_output, 96 | const at::GenericPackedTensorAccessor x1, 97 | const at::GenericPackedTensorAccessor x2, 98 | at::GenericPackedTensorAccessor grad_x1) { 99 | CUDA_1D_KERNEL_LOOP(index, n_kernels) { 100 | index_t k = index % x1.size(2); 101 | index_t i = (index / x1.size(2)) % x1.size(1); 102 | index_t b = index / (x1.size(2) * x1.size(1)); 103 | 104 | scalar_t val = 0; 105 | for (index_t j = 0; j < x2.size(1); j++) { 106 | scalar_t nom = fabs(x1[b][i][k]) + fabs(x2[b][j][k]); 107 | scalar_t denom = x1[b][i][k] - x2[b][j][k]; 108 | val += grad_output[b][i][j] * 109 | (m_signum(denom) / nom - m_signum(x1[b][i][k]) * fabs(denom) / nom / nom); 110 | } 111 | grad_x1[b][i][k] = val; 112 | } 113 | } 114 | 115 | template 116 | C10_LAUNCH_BOUNDS_1(1024) __global__ void _canberra_backward_x2_kernel_impl( 117 | index_t n_kernels, 118 | const at::GenericPackedTensorAccessor grad_output, 119 | const at::GenericPackedTensorAccessor x1, 120 | const at::GenericPackedTensorAccessor x2, 121 | at::GenericPackedTensorAccessor grad_x2) { 122 | CUDA_1D_KERNEL_LOOP(index, n_kernels) { 123 | index_t k = index % x2.size(2); 124 | index_t j = (index / x2.size(2)) % x2.size(1); 125 | index_t b = index / (x2.size(2) * x2.size(1)); 126 | 127 | 
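// Annotation (not in the original source): the backward kernel is laid out
// transposed to the forward one: this thread owns a single input cell
// (b, j, k) of x2 and gathers contributions from every row i of x1, so
// grad_x2[b][j][k] is written exactly once and no atomics are required.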
scalar_t val = 0; 128 | for (index_t i = 0; i < x1.size(1); i++) { 129 | scalar_t nom = fabs(x1[b][i][k]) + fabs(x2[b][j][k]); 130 | scalar_t denom = x1[b][i][k] - x2[b][j][k]; 131 | val += grad_output[b][i][j] * 132 | (-m_signum(denom) / nom - m_signum(x2[b][j][k]) * fabs(denom) / nom / nom); 133 | } 134 | grad_x2[b][j][k] = val; 135 | } 136 | } 137 | } // namespace impl 138 | 139 | std::tuple _canberra_backward_kernel( 140 | const at::Tensor &grad_output, 141 | const at::Tensor &x1, 142 | const at::Tensor &x2) { 143 | at::cuda::CUDAGuard device_guard(grad_output.get_device()); 144 | bool unbatched = x1.ndimension() == 2; 145 | 146 | auto grad_output_c = grad_output.contiguous(); 147 | auto x1_c = x1.contiguous(); 148 | auto x2_c = x2.contiguous(); 149 | if (unbatched) { 150 | grad_output_c = grad_output_c.unsqueeze(0); 151 | x1_c = x1_c.unsqueeze(0); 152 | x2_c = x2_c.unsqueeze(0); 153 | } 154 | 155 | int64_t n_kernels; 156 | auto grad_x1 = at::zeros_like(x1_c); 157 | auto grad_x2 = at::zeros_like(x2_c); 158 | 159 | const unsigned int threads = GET_THREADS(); 160 | unsigned int blocks; 161 | 162 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x1.scalar_type(), "_canberra_backward_cuda", ([&] { 163 | n_kernels = x1_c.numel(); 164 | blocks = GET_BLOCKS(threads, n_kernels); 165 | TORCHPAIRWISE_DISPATCH_INDEX_TYPE_DEVICE(n_kernels, CUDA, ([&] { 166 | auto grad_x1_accessor = 167 | grad_x1.generic_packed_accessor(); 168 | impl::_canberra_backward_x1_kernel_impl<<>>( 169 | n_kernels, 170 | grad_output_c.generic_packed_accessor(), 171 | x1_c.generic_packed_accessor(), 172 | x2_c.generic_packed_accessor(), 173 | grad_x1_accessor); 174 | })); 175 | 176 | n_kernels = x2_c.numel(); 177 | blocks = GET_BLOCKS(threads, n_kernels); 178 | TORCHPAIRWISE_DISPATCH_INDEX_TYPE_DEVICE(n_kernels, CUDA, ([&] { 179 | auto grad_x2_accessor = 180 | grad_x2.generic_packed_accessor(); 181 | impl::_canberra_backward_x2_kernel_impl<<>>( 182 | n_kernels, 183 | grad_output_c.generic_packed_accessor(), 184 | x1_c.generic_packed_accessor(), 185 | x2_c.generic_packed_accessor(), 186 | grad_x2_accessor); 187 | })); 188 | })); 189 | C10_CUDA_KERNEL_LAUNCH_CHECK(); 190 | if (unbatched) { 191 | grad_x1.squeeze_(0); 192 | grad_x2.squeeze_(0); 193 | } 194 | return std::make_tuple(grad_x1, grad_x2); 195 | } 196 | } 197 | 198 | TORCH_LIBRARY_IMPL(torchpairwise, CUDA, m) { 199 | m.impl( 200 | TORCH_SELECTIVE_NAME("torchpairwise::_canberra"), 201 | TORCH_FN(_canberra_forward_kernel)); 202 | m.impl( 203 | TORCH_SELECTIVE_NAME("torchpairwise::__canberra_backward"), 204 | TORCH_FN(_canberra_backward_kernel)); 205 | } 206 | } 207 | } 208 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/cuda/cuda_helpers.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include 6 | #include 7 | #include 8 | 9 | #define CUDA_1D_KERNEL_LOOP_T(i, n, index_t) \ 10 | for (index_t i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); i += (blockDim.x * gridDim.x)) 11 | 12 | #define CUDA_1D_KERNEL_LOOP(i, n) \ 13 | CUDA_1D_KERNEL_LOOP_T(i, n, std::remove_cv_t>) 14 | 15 | inline unsigned int GET_BLOCKS( 16 | const unsigned int THREADS, 17 | const unsigned int N) { 18 | unsigned int kMaxGridNum = at::cuda::getCurrentDeviceProperties()->maxGridSize[0]; 19 | return std::min(kMaxGridNum, (N + THREADS - 1) / THREADS); 20 | } 21 | 22 | // Temporarily counter latest MSVC update that causes incompatibility with CUDA 23 | #if (_MSC_VER >= 1928) 24 | 
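// Annotation (not in the original source): _MSC_VER 1928 corresponds to
// VS 2019 16.8/16.9; mapping floor/ceil to the single-precision CUDA
// intrinsics below presumably sidesteps the math-overload breakage that
// toolchain combination triggers in device code, at the cost of shadowing
// the double-precision overloads in this translation unit.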
#define floor floorf 25 | #define ceil ceilf 26 | #endif 27 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/cuda/prf_divide.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../common/prf_div_mode.h" 4 | 5 | namespace torchpairwise { 6 | namespace ops { 7 | template 8 | __forceinline__ __device__ constexpr T prf_divide(const T &x, const T &y) { 9 | if constexpr (mode == Zero) 10 | return y != T(0) ? x / y : T(0); 11 | else if constexpr (mode == Identity) 12 | return y != T(0) ? x / y : x; 13 | else 14 | return x / y; 15 | } 16 | } 17 | } 18 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/cuda/reduction_ops.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "cuda_helpers.h" 6 | #include "../common/reduction_ops.h" 7 | 8 | namespace torchpairwise { 9 | namespace ops { 10 | template 11 | __device__ constexpr scalar_t identity_value() { 12 | if constexpr (op == All) 13 | return static_cast(true); 14 | if constexpr (op == Any) 15 | return static_cast(false); 16 | if constexpr (op == Sum || op == Mean) 17 | return static_cast(0); 18 | if constexpr (op == Prod) 19 | return static_cast(1); 20 | } 21 | 22 | template 23 | __forceinline__ __device__ output_t call(input_ts... args) { 24 | // logical 25 | if constexpr (op == All) 26 | return (... && args); 27 | if constexpr (op == Any) 28 | return (... || args); 29 | // arithmetics 30 | if constexpr (op == Sum) 31 | return (... + args); 32 | if constexpr (op == Prod) 33 | return (... * args); 34 | if constexpr (op == Mean) 35 | return (... + args) / static_cast(sizeof...(args)); 36 | } 37 | 38 | template class PtrTraits = at::DefaultPtrTraits, typename index_t = int64_t> 39 | __forceinline__ __device__ output_t call(const at::GenericPackedTensorAccessor args) { 40 | output_t output = identity_value(); 41 | for (int64_t i = 0; i < args.size(0); i++) { 42 | // logical 43 | if constexpr (op == All) 44 | output &= args[i]; 45 | if constexpr (op == Any) 46 | output |= args[i]; 47 | // arithmetics 48 | if constexpr (op == Sum || op == Mean) 49 | output += args[i]; 50 | if constexpr (op == Prod) 51 | output *= args[i]; 52 | } 53 | if constexpr (op == Mean) { 54 | output /= static_cast(args.size(0)); 55 | } 56 | return output; 57 | } 58 | 59 | template 60 | __forceinline__ __device__ void accumulate_call(output_t* val, input_t arg) { 61 | // logical 62 | if constexpr (op == All) 63 | *val &= arg; 64 | if constexpr (op == Any) 65 | *val |= arg; 66 | // arithmetics 67 | if constexpr (op == Sum || op == Mean) 68 | *val += arg; 69 | if constexpr (op == Prod) 70 | *val *= arg; 71 | } 72 | } 73 | } 74 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/cuda/rel_entr.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "../utils/scalar_type_utils.h" 7 | 8 | namespace torchpairwise { 9 | namespace ops { 10 | template 11 | __forceinline__ __device__ bool m_isnan(const T &_X) throw() { 12 | if constexpr (std::is_same_v || 13 | std::is_same_v) 14 | return _X.x == std::numeric_limits::quiet_NaN().x; 15 | else 16 | return isnan(_X); 17 | } 18 | 19 | template 20 | __forceinline__ __device__ constexpr T rel_entr(const T &x, const T &y) { 21 | if 
(m_isnan(x))
22 |         return x;
23 |     else if (x > T(0) && y > T(0))
24 |         return x * log(x / y);
25 |     else if (x == T(0) && y >= T(0))
26 |         return 0;
27 |     else
28 |         return c10::CPPTypeLimits<T>::upper_bound();
29 | }
30 | }
31 | }
32 |
--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/cuda/signum.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include "../utils/scalar_type_utils.h"
4 |
5 | namespace torchpairwise {
6 | namespace ops {
7 | template <typename T>
8 | typename std::enable_if<std::is_unsigned<T>::value, int>::type
9 | __forceinline__ __device__ constexpr m_signum(const T &x) {
10 |     return T(0) < x;
11 | }
12 |
13 | template <typename T>
14 | typename std::enable_if<std::is_signed<T>::value, int>::type
15 | __forceinline__ __device__ constexpr m_signum(const T &x) {
16 |     return (T(0) < x) - (x < T(0));
17 | }
18 | }
19 | }
20 |
--------------------------------------------------------------------------------
/torchpairwise/csrc/ops/hausdorff.cpp:
--------------------------------------------------------------------------------
1 | #include "hausdorff.h"
2 |
3 | #include <torch/library.h>
4 |
5 | namespace torchpairwise {
6 | namespace ops {
7 | std::tuple<at::Tensor, at::Tensor, at::Tensor> _directed_hausdorff(
8 |         const at::Tensor &x1,
9 |         const at::Tensor &x2,
10 |         bool shuffle,
11 |         c10::optional<at::Generator> generator) {
12 |     static auto op = c10::Dispatcher::singleton()
13 |             .findSchemaOrThrow("torchpairwise::_directed_hausdorff", "")
14 |             .typed<decltype(_directed_hausdorff)>();
15 |     return op.call(x1, x2, shuffle, generator);
16 | }
17 |
18 | namespace detail {
19 | std::tuple<at::Tensor, at::Tensor> __directed_hausdorff_backward(
20 |         const at::Tensor &grad,
21 |         const at::Tensor &x1,
22 |         const at::Tensor &x2,
23 |         bool shuffle,
24 |         c10::optional<at::Generator> generator) {
25 |     static auto op =
26 |             c10::Dispatcher::singleton()
27 |                     .findSchemaOrThrow("torchpairwise::__directed_hausdorff_backward", "")
28 |                     .typed<decltype(__directed_hausdorff_backward)>();
29 |     return op.call(grad, x1, x2, shuffle, generator);
30 | }
31 | }
32 |
33 | TORCH_LIBRARY_FRAGMENT(torchpairwise, m) {
34 |     m.def(TORCH_SELECTIVE_SCHEMA(
35 |             "torchpairwise::_directed_hausdorff(Tensor x1, Tensor x2, *, bool shuffle=False, Generator? generator=None) -> (Tensor output, Tensor x1_indices, Tensor x2_indices)")
36 |     );
37 |     m.def(TORCH_SELECTIVE_SCHEMA(
38 |             "torchpairwise::__directed_hausdorff_backward(Tensor grad, Tensor x1, Tensor x2, *, bool shuffle=False, Generator? generator=None) -> (Tensor grad_x1, Tensor grad_x2)")
generator=None) -> (Tensor grad_x1, Tensor grad_x2)") 39 | ); 40 | } 41 | } // namespace ops 42 | } // namespace torchpairwise 43 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/hausdorff.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace torchpairwise { 6 | namespace ops { 7 | std::tuple _directed_hausdorff( 8 | const at::Tensor &x1, 9 | const at::Tensor &x2, 10 | bool shuffle = false, 11 | c10::optional generator = c10::nullopt); 12 | 13 | namespace detail { 14 | std::tuple __directed_hausdorff_backward( 15 | const at::Tensor &grad, 16 | const at::Tensor &x1, 17 | const at::Tensor &x2, 18 | bool shuffle, 19 | c10::optional generator); 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/haversine.cpp: -------------------------------------------------------------------------------- 1 | #include "haversine.h" 2 | 3 | #include 4 | 5 | namespace torchpairwise { 6 | namespace ops { 7 | at::Tensor _haversine( 8 | const at::Tensor &x1, 9 | const at::Tensor &x2) { 10 | static auto op = c10::Dispatcher::singleton() 11 | .findSchemaOrThrow("torchpairwise::_haversine", "") 12 | .typed(); 13 | return op.call(x1, x2); 14 | } 15 | 16 | namespace detail { 17 | std::tuple __haversine_backward( 18 | const at::Tensor &grad, 19 | const at::Tensor &x1, 20 | const at::Tensor &x2) { 21 | static auto op = 22 | c10::Dispatcher::singleton() 23 | .findSchemaOrThrow("torchpairwise::__haversine_backward", "") 24 | .typed(); 25 | return op.call(grad, x1, x2); 26 | } 27 | } 28 | 29 | TORCH_LIBRARY_FRAGMENT(torchpairwise, m) { 30 | m.def(TORCH_SELECTIVE_SCHEMA( 31 | "torchpairwise::_haversine(Tensor x1, Tensor x2) -> Tensor") 32 | ); 33 | m.def(TORCH_SELECTIVE_SCHEMA( 34 | "torchpairwise::__haversine_backward(Tensor grad, Tensor x1, Tensor x2) -> (Tensor grad_x1, Tensor grad_x2)") 35 | ); 36 | } 37 | } // namespace ops 38 | } // namespace torchpairwise 39 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/haversine.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "../macros.h" 6 | 7 | namespace torchpairwise { 8 | namespace ops { 9 | at::Tensor _haversine( 10 | const at::Tensor &x1, 11 | const at::Tensor &x2); 12 | 13 | namespace detail { 14 | std::tuple __haversine_backward( 15 | const at::Tensor &grad, 16 | const at::Tensor &x1, 17 | const at::Tensor &x2); 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/neighbors.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | 5 | #include "pairwise_metrics.h" 6 | #include "cpdist.h" 7 | 8 | namespace torchpairwise { 9 | namespace ops { 10 | at::Tensor k_neighbors_mask( 11 | const at::Tensor &x1, 12 | const c10::optional &x2, 13 | int64_t k, 14 | c10::string_view metric, 15 | const c10::optional &w, 16 | const c10::optional &V, 17 | const c10::optional &VI, 18 | c10::optional p, 19 | c10::optional base, 20 | c10::optional shuffle, 21 | c10::optional generator) { 22 | C10_LOG_API_USAGE_ONCE("torchpairwise.csrc.ops.neighbors.k_neighbors_mask") 23 | at::Tensor x1_, x2_; 24 | std::tie(x1_, x2_) = utils::check_pairwise_inputs("k_neighbors_mask", x1, x2); 25 | 
TORCH_CHECK(k >= 0, 26 | "k must be non-negative. Got k=", 27 | k) 28 | at::NoGradGuard no_grad_guard; 29 | auto dists = ops::cdist(x1_, x2_, metric, w, V, VI, p, base, shuffle, generator); 30 | auto neighbors_inds = dists.argsort(1, false).slice(1, 0, k + 1); // keep k + 1 smallest so the self-match is discounted when x2 defaults to x1 31 | auto first_dim = at::arange(0, x1_.size(0), neighbors_inds.options()).view({-1, 1}).repeat({1, k + 1}); 32 | auto output = at::zeros({x1_.size(0), x2_.size(0)}, x1_.options().dtype(at::kBool)); 33 | output.index_put_({first_dim.flatten(), neighbors_inds.flatten()}, true); 34 | return output; 35 | } 36 | 37 | at::Tensor radius_neighbors_mask( 38 | const at::Tensor &x1, 39 | const c10::optional &x2, 40 | double epsilon, 41 | c10::string_view metric, 42 | const c10::optional &w, 43 | const c10::optional &V, 44 | const c10::optional &VI, 45 | c10::optional p, 46 | c10::optional base, 47 | c10::optional shuffle, 48 | c10::optional generator) { 49 | C10_LOG_API_USAGE_ONCE("torchpairwise.csrc.ops.neighbors.radius_neighbors_mask") 50 | at::Tensor x1_, x2_; 51 | std::tie(x1_, x2_) = utils::check_pairwise_inputs("radius_neighbors_mask", x1, x2); 52 | TORCH_CHECK(epsilon >= 0., 53 | "epsilon must be non-negative. Got epsilon=", 54 | epsilon) 55 | at::NoGradGuard no_grad_guard; 56 | auto dists = ops::cdist(x1_, x2_, metric, w, V, VI, p, base, shuffle, generator); 57 | return dists.le(epsilon); 58 | } 59 | 60 | TORCH_LIBRARY_FRAGMENT(torchpairwise, m) { 61 | // utilities 62 | m.def("torchpairwise::k_neighbors_mask(Tensor x1, Tensor? x2=None, int k=1, " 63 | "str metric=\"euclidean\", *, " TORCHPAIRWISE_CPDIST_EXTRA_ARGS_SCHEMA_STR ") -> Tensor", 64 | TORCH_FN(k_neighbors_mask)); 65 | m.def("torchpairwise::radius_neighbors_mask(Tensor x1, Tensor? x2=None, float epsilon=0., " 66 | "str metric=\"euclidean\", *, " TORCHPAIRWISE_CPDIST_EXTRA_ARGS_SCHEMA_STR ") -> Tensor", 67 | TORCH_FN(radius_neighbors_mask)); 68 | } 69 | } 70 | } 71 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/neighbors.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "../macros.h" 6 | 7 | namespace torchpairwise { 8 | namespace ops { 9 | TORCHPAIRWISE_API at::Tensor k_neighbors_mask( 10 | const at::Tensor &x1, 11 | const c10::optional &x2 = c10::nullopt, 12 | int64_t k = 1, 13 | c10::string_view metric = "euclidean", 14 | const c10::optional &w = c10::nullopt, 15 | const c10::optional &V = c10::nullopt, 16 | const c10::optional &VI = c10::nullopt, 17 | c10::optional p = c10::nullopt, 18 | c10::optional base = c10::nullopt, 19 | c10::optional shuffle = c10::nullopt, 20 | c10::optional generator = c10::nullopt); 21 | 22 | TORCHPAIRWISE_API at::Tensor radius_neighbors_mask( 23 | const at::Tensor &x1, 24 | const c10::optional &x2 = c10::nullopt, 25 | double epsilon = 0, 26 | c10::string_view metric = "euclidean", 27 | const c10::optional &w = c10::nullopt, 28 | const c10::optional &V = c10::nullopt, 29 | const c10::optional &VI = c10::nullopt, 30 | c10::optional p = c10::nullopt, 31 | c10::optional base = c10::nullopt, 32 | c10::optional shuffle = c10::nullopt, 33 | c10::optional generator = c10::nullopt); 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/ops.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "pairwise_metrics.h" 4 | #include "neighbors.h" 5 | #include "cpdist.h" 6 |
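From Python, the two mask ops above are reachable through torchpairwise.ops once the extension is built. A minimal usage sketch (the shapes and the k/epsilon values are illustrative, and it assumes the compiled _C extension loads successfully):

import torch
from torchpairwise.ops import k_neighbors_mask, radius_neighbors_mask

# toy inputs: 5 query rows, 8 reference rows, 3 features each
x1 = torch.randn(5, 3, dtype=torch.float64)
x2 = torch.randn(8, 3, dtype=torch.float64)

# (5, 8) boolean mask; the C++ body keeps the k + 1 smallest distances per
# row, which discounts the self-match when x2 defaults to x1
knn = k_neighbors_mask(x1, x2, k=2, metric="euclidean")
assert knn.sum(dim=1).eq(3).all()  # exactly k + 1 = 3 True entries per row

# (5, 8) boolean mask of all pairs within distance 1.5 (dists.le(epsilon))
ball = radius_neighbors_mask(x1, x2, epsilon=1.5, metric="euclidean")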
-------------------------------------------------------------------------------- /torchpairwise/csrc/ops/pairwise_binary.cpp: -------------------------------------------------------------------------------- 1 | #include "pairwise_binary.h" 2 | 3 | #include 4 | 5 | #include "common/binary_ops.h" 6 | #include "common/reduction_ops.h" 7 | 8 | namespace torchpairwise { 9 | namespace ops { 10 | namespace detail { 11 | template 12 | inline at::Tensor _pairwise_binary( 13 | const at::Tensor &x1, 14 | const at::Tensor &x2) { 15 | static auto op = c10::Dispatcher::singleton() 16 | .findSchemaOrThrow(op_schema_name().c_str(), "") 17 | .template typed)>(); 18 | return op.call(x1, x2); 19 | } 20 | 21 | template 22 | inline at::Tensor _pairwise_binary_reduction( 23 | const at::Tensor &x1, 24 | const at::Tensor &x2) { 25 | static auto op = c10::Dispatcher::singleton() 26 | .findSchemaOrThrow(op_schema_name().c_str(), "") 27 | .template typed)>(); 28 | return op.call(x1, x2); 29 | } 30 | } 31 | 32 | // ~~~~~ binary ~~~~~ 33 | // logical 34 | at::Tensor pwand( 35 | const at::Tensor &x1, 36 | const at::Tensor &x2) { 37 | return detail::_pairwise_binary(x1, x2); 38 | } 39 | 40 | at::Tensor pwor( 41 | const at::Tensor &x1, 42 | const at::Tensor &x2) { 43 | return detail::_pairwise_binary(x1, x2); 44 | } 45 | 46 | at::Tensor pwxor( 47 | const at::Tensor &x1, 48 | const at::Tensor &x2) { 49 | return detail::_pairwise_binary(x1, x2); 50 | } 51 | 52 | // comparison 53 | at::Tensor pweq( 54 | const at::Tensor &x1, 55 | const at::Tensor &x2) { 56 | return detail::_pairwise_binary(x1, x2); 57 | } 58 | 59 | at::Tensor pwne( 60 | const at::Tensor &x1, 61 | const at::Tensor &x2) { 62 | return detail::_pairwise_binary(x1, x2); 63 | } 64 | 65 | at::Tensor pwlt( 66 | const at::Tensor &x1, 67 | const at::Tensor &x2) { 68 | return detail::_pairwise_binary(x1, x2); 69 | } 70 | 71 | at::Tensor pwgt( 72 | const at::Tensor &x1, 73 | const at::Tensor &x2) { 74 | return detail::_pairwise_binary(x1, x2); 75 | } 76 | 77 | at::Tensor pwle( 78 | const at::Tensor &x1, 79 | const at::Tensor &x2) { 80 | return detail::_pairwise_binary(x1, x2); 81 | } 82 | 83 | at::Tensor pwge( 84 | const at::Tensor &x1, 85 | const at::Tensor &x2) { 86 | return detail::_pairwise_binary(x1, x2); 87 | } 88 | 89 | // ~~~~~ binary reduction ~~~~~ 90 | // logical sum 91 | at::Tensor pwand_sum( 92 | const at::Tensor &x1, 93 | const at::Tensor &x2) { 94 | return detail::_pairwise_binary_reduction(x1, x2); 95 | } 96 | 97 | at::Tensor pwor_sum( 98 | const at::Tensor &x1, 99 | const at::Tensor &x2) { 100 | return detail::_pairwise_binary_reduction(x1, x2); 101 | } 102 | 103 | at::Tensor pwxor_sum( 104 | const at::Tensor &x1, 105 | const at::Tensor &x2) { 106 | return detail::_pairwise_binary_reduction(x1, x2); 107 | } 108 | 109 | // comparison sum 110 | at::Tensor pweq_sum( 111 | const at::Tensor &x1, 112 | const at::Tensor &x2) { 113 | return detail::_pairwise_binary_reduction(x1, x2); 114 | } 115 | 116 | at::Tensor pwne_sum( 117 | const at::Tensor &x1, 118 | const at::Tensor &x2) { 119 | return detail::_pairwise_binary_reduction(x1, x2); 120 | } 121 | 122 | // comparison mean 123 | at::Tensor pweq_mean( 124 | const at::Tensor &x1, 125 | const at::Tensor &x2) { 126 | return detail::_pairwise_binary_reduction(x1, x2); 127 | } 128 | 129 | at::Tensor pwne_mean( 130 | const at::Tensor &x1, 131 | const at::Tensor &x2) { 132 | return detail::_pairwise_binary_reduction(x1, x2); 133 | } 134 | 135 | TORCH_LIBRARY_FRAGMENT(torchpairwise, m) { 136 | // ~~~~~ binary ~~~~~ 137 |
// logical 138 | m.def(op_full_schema().c_str()); 139 | m.def(op_full_schema().c_str()); 140 | m.def(op_full_schema().c_str()); 141 | // comparison 142 | m.def(op_full_schema().c_str()); 143 | m.def(op_full_schema().c_str()); 144 | m.def(op_full_schema().c_str()); 145 | m.def(op_full_schema().c_str()); 146 | m.def(op_full_schema().c_str()); 147 | m.def(op_full_schema().c_str()); 148 | 149 | // ~~~~~ binary reduction ~~~~~ 150 | // logical sum 151 | m.def(op_full_schema().c_str()); 152 | m.def(op_full_schema().c_str()); 153 | m.def(op_full_schema().c_str()); 154 | // comparison sum 155 | m.def(op_full_schema().c_str()); 156 | m.def(op_full_schema().c_str()); 157 | // comparison mean 158 | m.def(op_full_schema().c_str()); 159 | m.def(op_full_schema().c_str()); 160 | } 161 | } // namespace ops 162 | } // namespace torchpairwise 163 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/pairwise_binary.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "common/binary_ops.h" 6 | #include "common/reduction_ops.h" 7 | #include "../macros.h" 8 | 9 | namespace torchpairwise { 10 | namespace ops { 11 | namespace detail { 12 | template 13 | inline at::Tensor _pairwise_binary( 14 | const at::Tensor &x1, 15 | const at::Tensor &x2); 16 | 17 | template 18 | inline at::Tensor _pairwise_binary_reduction( 19 | const at::Tensor &x1, 20 | const at::Tensor &x2); 21 | } 22 | 23 | // ~~~~~ binary ~~~~~ 24 | // logical 25 | TORCHPAIRWISE_API at::Tensor pwand( 26 | const at::Tensor &x1, 27 | const at::Tensor &x2); 28 | 29 | TORCHPAIRWISE_API at::Tensor pwor( 30 | const at::Tensor &x1, 31 | const at::Tensor &x2); 32 | 33 | TORCHPAIRWISE_API at::Tensor pwxor( 34 | const at::Tensor &x1, 35 | const at::Tensor &x2); 36 | 37 | // comparison 38 | TORCHPAIRWISE_API at::Tensor pweq( 39 | const at::Tensor &x1, 40 | const at::Tensor &x2); 41 | 42 | TORCHPAIRWISE_API at::Tensor pwne( 43 | const at::Tensor &x1, 44 | const at::Tensor &x2); 45 | 46 | TORCHPAIRWISE_API at::Tensor pwlt( 47 | const at::Tensor &x1, 48 | const at::Tensor &x2); 49 | 50 | TORCHPAIRWISE_API at::Tensor pwgt( 51 | const at::Tensor &x1, 52 | const at::Tensor &x2); 53 | 54 | TORCHPAIRWISE_API at::Tensor pwle( 55 | const at::Tensor &x1, 56 | const at::Tensor &x2); 57 | 58 | TORCHPAIRWISE_API at::Tensor pwge( 59 | const at::Tensor &x1, 60 | const at::Tensor &x2); 61 | 62 | // ~~~~~ binary reduction ~~~~~ 63 | // logical sum 64 | TORCHPAIRWISE_API at::Tensor pwand_sum( 65 | const at::Tensor &x1, 66 | const at::Tensor &x2); 67 | 68 | TORCHPAIRWISE_API at::Tensor pwor_sum( 69 | const at::Tensor &x1, 70 | const at::Tensor &x2); 71 | 72 | TORCHPAIRWISE_API at::Tensor pwxor_sum( 73 | const at::Tensor &x1, 74 | const at::Tensor &x2); 75 | 76 | // comparison sum 77 | TORCHPAIRWISE_API at::Tensor pweq_sum( 78 | const at::Tensor &x1, 79 | const at::Tensor &x2); 80 | 81 | TORCHPAIRWISE_API at::Tensor pwne_sum( 82 | const at::Tensor &x1, 83 | const at::Tensor &x2); 84 | 85 | // comparison mean 86 | TORCHPAIRWISE_API at::Tensor pweq_mean( 87 | const at::Tensor &x1, 88 | const at::Tensor &x2); 89 | 90 | TORCHPAIRWISE_API at::Tensor pwne_mean( 91 | const at::Tensor &x1, 92 | const at::Tensor &x2); 93 | } 94 | } 95 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/ppminkowski.cpp: -------------------------------------------------------------------------------- 1 | #include 
"ppminkowski.h" 2 | 3 | #include 4 | 5 | namespace torchpairwise { 6 | namespace ops { 7 | at::Tensor _ppminkowski( 8 | const at::Tensor &x1, 9 | const at::Tensor &x2, 10 | double p) { 11 | static auto op = c10::Dispatcher::singleton() 12 | .findSchemaOrThrow("torchpairwise::_ppminkowski", "") 13 | .typed(); 14 | return op.call(x1, x2, p); 15 | } 16 | 17 | namespace detail { 18 | std::tuple __ppminkowski_backward( 19 | const at::Tensor &grad, 20 | const at::Tensor &x1, 21 | const at::Tensor &x2, 22 | double p) { 23 | static auto op = 24 | c10::Dispatcher::singleton() 25 | .findSchemaOrThrow("torchpairwise::__ppminkowski_backward", "") 26 | .typed(); 27 | return op.call(grad, x1, x2, p); 28 | } 29 | } 30 | 31 | TORCH_LIBRARY_FRAGMENT(torchpairwise, m) { 32 | m.def(TORCH_SELECTIVE_SCHEMA( 33 | "torchpairwise::_ppminkowski(Tensor x1, Tensor x2, float p=2) -> Tensor") 34 | ); 35 | m.def(TORCH_SELECTIVE_SCHEMA( 36 | "torchpairwise::__ppminkowski_backward(Tensor grad, Tensor x1, Tensor x2, float p) -> (Tensor grad_x1, Tensor grad_x2)") 37 | ); 38 | } 39 | } // namespace ops 40 | } // namespace torchpairwise 41 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/ppminkowski.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "../macros.h" 6 | 7 | namespace torchpairwise { 8 | namespace ops { 9 | TORCHPAIRWISE_API at::Tensor _ppminkowski( 10 | const at::Tensor &x1, 11 | const at::Tensor &x2, 12 | double p = 2); 13 | 14 | namespace detail { 15 | std::tuple __ppminkowski_backward( 16 | const at::Tensor &grad, 17 | const at::Tensor &x1, 18 | const at::Tensor &x2, 19 | double p); 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/prf_div.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace torchpairwise { 6 | namespace ops { 7 | at::Tensor prf_div( 8 | const at::Tensor &self, 9 | const at::Tensor &other, 10 | c10::string_view mode = "zero"); 11 | 12 | at::Tensor prf_ldiv( 13 | const at::Tensor &self, 14 | const at::Tensor &other, 15 | c10::string_view mode = "zero"); 16 | 17 | at::Tensor prf_div( 18 | const at::Tensor &self, 19 | const at::Scalar &other, 20 | c10::string_view mode = "zero"); 21 | 22 | at::Tensor prf_ldiv( 23 | const at::Tensor &self, 24 | const at::Scalar &other, 25 | c10::string_view mode = "zero"); 26 | 27 | at::Tensor prf_div( 28 | const at::Scalar &self, 29 | const at::Tensor &other, 30 | c10::string_view mode = "zero"); 31 | 32 | at::Tensor prf_ldiv( 33 | const at::Scalar &self, 34 | const at::Tensor &other, 35 | c10::string_view mode = "zero"); 36 | 37 | at::Tensor prf_div_( 38 | at::Tensor &self, 39 | const at::Tensor &other, 40 | c10::string_view mode = "zero"); 41 | 42 | at::Tensor prf_ldiv_( 43 | at::Tensor &self, 44 | const at::Tensor &other, 45 | c10::string_view mode = "zero"); 46 | 47 | at::Tensor prf_div_( 48 | at::Tensor &self, 49 | const at::Scalar &other, 50 | c10::string_view mode = "zero"); 51 | 52 | at::Tensor prf_ldiv_( 53 | at::Tensor &self, 54 | const at::Scalar &other, 55 | c10::string_view mode = "zero"); 56 | 57 | namespace detail { 58 | std::tuple _prf_div_backward( 59 | const at::Tensor &grad_output, 60 | const at::Tensor &self, 61 | const at::Tensor &other, 62 | c10::string_view mode = "zero"); 63 | 64 | std::tuple _prf_ldiv_backward( 65 | const at::Tensor 
&grad_output, 66 | const at::Tensor &self, 67 | const at::Tensor &other, 68 | c10::string_view mode = "zero"); 69 | 70 | at::Tensor _prf_div_backward( 71 | const at::Tensor &grad_output, 72 | const at::Tensor &self, 73 | const at::Scalar &other, 74 | c10::string_view mode = "zero"); 75 | 76 | at::Tensor _prf_ldiv_backward( 77 | const at::Tensor &grad_output, 78 | const at::Tensor &self, 79 | const at::Scalar &other, 80 | c10::string_view mode = "zero"); 81 | 82 | at::Tensor _prf_div_backward( 83 | const at::Tensor &grad_output, 84 | const at::Scalar &self, 85 | const at::Tensor &other, 86 | c10::string_view mode = "zero"); 87 | 88 | at::Tensor _prf_ldiv_backward( 89 | const at::Tensor &grad_output, 90 | const at::Scalar &self, 91 | const at::Tensor &other, 92 | c10::string_view mode = "zero"); 93 | 94 | std::tuple _prf_div__backward( 95 | const at::Tensor &grad_output, 96 | const at::Tensor &self, 97 | const at::Tensor &other, 98 | c10::string_view mode = "zero"); 99 | 100 | std::tuple _prf_ldiv__backward( 101 | const at::Tensor &grad_output, 102 | const at::Tensor &self, 103 | const at::Tensor &other, 104 | c10::string_view mode = "zero"); 105 | 106 | at::Tensor _prf_div__backward( 107 | const at::Tensor &grad_output, 108 | const at::Tensor &self, 109 | const at::Scalar &other, 110 | c10::string_view mode = "zero"); 111 | 112 | at::Tensor _prf_ldiv__backward( 113 | const at::Tensor &grad_output, 114 | const at::Tensor &self, 115 | const at::Scalar &other, 116 | c10::string_view mode = "zero"); 117 | } 118 | } 119 | } 120 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/snr.cpp: -------------------------------------------------------------------------------- 1 | #include "snr.h" 2 | 3 | #include 4 | 5 | namespace torchpairwise { 6 | namespace ops { 7 | at::Tensor _snr( 8 | const at::Tensor &x1, 9 | const at::Tensor &x2) { 10 | static auto op = c10::Dispatcher::singleton() 11 | .findSchemaOrThrow("torchpairwise::_snr", "") 12 | .typed(); 13 | return op.call(x1, x2); 14 | } 15 | 16 | namespace detail { 17 | std::tuple __snr_backward( 18 | const at::Tensor &grad, 19 | const at::Tensor &x1, 20 | const at::Tensor &x2) { 21 | static auto op = 22 | c10::Dispatcher::singleton() 23 | .findSchemaOrThrow("torchpairwise::__snr_backward", "") 24 | .typed(); 25 | return op.call(grad, x1, x2); 26 | } 27 | } 28 | 29 | TORCH_LIBRARY_FRAGMENT(torchpairwise, m) { 30 | m.def(TORCH_SELECTIVE_SCHEMA( 31 | "torchpairwise::_snr(Tensor x1, Tensor x2) -> Tensor") 32 | ); 33 | m.def(TORCH_SELECTIVE_SCHEMA( 34 | "torchpairwise::__snr_backward(Tensor grad, Tensor x1, Tensor x2) -> (Tensor grad_x1, Tensor grad_x2)") 35 | ); 36 | } 37 | } // namespace ops 38 | } // namespace torchpairwise 39 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/snr.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace torchpairwise { 6 | namespace ops { 7 | at::Tensor _snr( 8 | const at::Tensor &x1, 9 | const at::Tensor &x2); 10 | 11 | namespace detail { 12 | std::tuple __snr_backward( 13 | const at::Tensor &grad, 14 | const at::Tensor &x1, 15 | const at::Tensor &x2); 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/sqjensenshannon.cpp: -------------------------------------------------------------------------------- 1 | #include "sqjensenshannon.h" 2 | 3 | #include 4 
| 5 | namespace torchpairwise { 6 | namespace ops { 7 | at::Tensor _sqjensenshannon( 8 | const at::Tensor &x1, 9 | const at::Tensor &x2, 10 | c10::optional base) { 11 | static auto op = c10::Dispatcher::singleton() 12 | .findSchemaOrThrow("torchpairwise::_sqjensenshannon", "") 13 | .typed(); 14 | return op.call(x1, x2, base); 15 | } 16 | 17 | namespace detail { 18 | std::tuple __sqjensenshannon_backward( 19 | const at::Tensor &grad, 20 | const at::Tensor &x1, 21 | const at::Tensor &x2, 22 | c10::optional base) { 23 | static auto op = 24 | c10::Dispatcher::singleton() 25 | .findSchemaOrThrow("torchpairwise::__sqjensenshannon_backward", "") 26 | .typed(); 27 | return op.call(grad, x1, x2, base); 28 | } 29 | } 30 | 31 | TORCH_LIBRARY_FRAGMENT(torchpairwise, m) { 32 | m.def(TORCH_SELECTIVE_SCHEMA( 33 | "torchpairwise::_sqjensenshannon(Tensor x1, Tensor x2, float? base=None) -> Tensor") 34 | ); 35 | m.def(TORCH_SELECTIVE_SCHEMA( 36 | "torchpairwise::__sqjensenshannon_backward(Tensor grad, Tensor x1, Tensor x2, float? base) -> (Tensor grad_x1, Tensor grad_x2)") 37 | ); 38 | } 39 | } // namespace ops 40 | } // namespace torchpairwise 41 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/sqjensenshannon.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | namespace torchpairwise { 6 | namespace ops { 7 | at::Tensor _sqjensenshannon( 8 | const at::Tensor &x1, 9 | const at::Tensor &x2, 10 | c10::optional base = c10::nullopt); 11 | 12 | namespace detail { 13 | std::tuple __sqjensenshannon_backward( 14 | const at::Tensor &grad, 15 | const at::Tensor &x1, 16 | const at::Tensor &x2, 17 | c10::optional base); 18 | } 19 | } 20 | } 21 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/sqmahalanobis.cpp: -------------------------------------------------------------------------------- 1 | #include "sqmahalanobis.h" 2 | 3 | #include 4 | 5 | namespace torchpairwise { 6 | namespace ops { 7 | at::Tensor _sqmahalanobis( 8 | const at::Tensor &x1, 9 | const at::Tensor &x2, 10 | const at::Tensor &VI) { 11 | static auto op = c10::Dispatcher::singleton() 12 | .findSchemaOrThrow("torchpairwise::_sqmahalanobis", "") 13 | .typed(); 14 | return op.call(x1, x2, VI); 15 | } 16 | 17 | namespace detail { 18 | std::tuple __sqmahalanobis_backward( 19 | const at::Tensor &grad, 20 | const at::Tensor &x1, 21 | const at::Tensor &x2, 22 | const at::Tensor &VI) { 23 | static auto op = 24 | c10::Dispatcher::singleton() 25 | .findSchemaOrThrow("torchpairwise::__sqmahalanobis_backward", "") 26 | .typed(); 27 | return op.call(grad, x1, x2, VI); 28 | } 29 | } 30 | 31 | TORCH_LIBRARY_FRAGMENT(torchpairwise, m) { 32 | m.def(TORCH_SELECTIVE_SCHEMA( 33 | "torchpairwise::_sqmahalanobis(Tensor x1, Tensor x2, Tensor VI) -> Tensor") 34 | ); 35 | m.def(TORCH_SELECTIVE_SCHEMA( 36 | "torchpairwise::__sqmahalanobis_backward(Tensor grad, Tensor x1, Tensor x2, Tensor VI) -> (Tensor grad_x1, Tensor grad_x2, Tensor grad_VI)") 37 | ); 38 | } 39 | } // namespace ops 40 | } // namespace torchpairwise 41 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/sqmahalanobis.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "../macros.h" 6 | 7 | namespace torchpairwise { 8 | namespace ops { 9 | at::Tensor _sqmahalanobis( 10 | const at::Tensor 
&x1, 11 | const at::Tensor &x2, 12 | const at::Tensor &VI); 13 | 14 | namespace detail { 15 | std::tuple __sqmahalanobis_backward( 16 | const at::Tensor &grad, 17 | const at::Tensor &x1, 18 | const at::Tensor &x2, 19 | const at::Tensor &VI); 20 | } 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/utils/dispatch.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | // bool switch 6 | #define TORCHPAIRWISE_DISPATCH_BOOL_NAME(NAME, VAL, ...) \ 7 | if (!(VAL)) { \ 8 | static const bool NAME = false; \ 9 | __VA_ARGS__(); \ 10 | } else { \ 11 | static const bool NAME = true; \ 12 | __VA_ARGS__(); \ 13 | } 14 | 15 | #define TORCHPAIRWISE_DISPATCH_BOOL(ARG1, ...) \ 16 | TORCHPAIRWISE_DISPATCH_BOOL_NAME(ARG1, ARG1, __VA_ARGS__) 17 | 18 | // scalar type 19 | #define AT_DISPATCH_CASE_BOOLEAN_TYPE(...) \ 20 | AT_DISPATCH_CASE(at::ScalarType::Bool, __VA_ARGS__) 21 | 22 | #define AT_DISPATCH_BOOLEAN_TYPE(TYPE, NAME, ...) \ 23 | AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_BOOLEAN_TYPE(__VA_ARGS__)) 24 | 25 | // index type 26 | #define TORCHPAIRWISE_DISPATCH_INDEX_TYPE_CPU(N_KERNELS, ...) \ 27 | using index_t = int64_t; \ 28 | __VA_ARGS__(); \ 29 | 30 | #define TORCHPAIRWISE_DISPATCH_INDEX_TYPE_CUDA(N_KERNELS, ...) \ 31 | if (((int64_t)N_KERNELS) > (int64_t(1) << 31)) { \ 32 | using index_t = int64_t; \ 33 | __VA_ARGS__(); \ 34 | } \ 35 | else { \ 36 | using index_t = int; \ 37 | __VA_ARGS__(); \ 38 | } 39 | 40 | #define TORCHPAIRWISE_DISPATCH_INDEX_TYPE_DEVICE(N_KERNELS, DEVICE, ...) \ 41 | C10_CONCATENATE(TORCHPAIRWISE_DISPATCH_INDEX_TYPE_, DEVICE)(N_KERNELS, __VA_ARGS__) 42 | 43 | #define TORCHPAIRWISE_DISPATCH_INDEX_TYPE(N_KERNELS, ...) \ 44 | if (((int64_t)N_KERNELS) > (int64_t(1) << 31)) { \ 45 | using index_t = int64_t; \ 46 | __VA_ARGS__(); \ 47 | } \ 48 | else { \ 49 | using index_t = int; \ 50 | __VA_ARGS__(); \ 51 | } 52 | 53 | namespace torchpairwise { 54 | namespace ops { 55 | template 56 | struct index_type { 57 | using type = int; 58 | }; 59 | 60 | template<> 61 | struct index_type { 62 | using type = int64_t; 63 | }; 64 | 65 | template<> 66 | struct index_type { 67 | using type = int; 68 | }; 69 | 70 | template 71 | using index_type_t = typename index_type::type; 72 | 73 | inline at::ScalarType get_index_type(int64_t n_kernels) { 74 | if ((n_kernels) > (int64_t(1) << 31)) { // 2^31; a plain 1 << 31 overflows int 75 | return at::ScalarType::Long; 76 | } else { 77 | return at::ScalarType::Int; 78 | } 79 | } 80 | } 81 | } 82 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/utils/scalar_type_utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | namespace c10 { 8 | template 9 | struct is_signed : std::is_signed { 10 | }; 11 | 12 | template 13 | struct is_unsigned : std::is_unsigned { 14 | }; 15 | 16 | template<> 17 | struct is_signed> : std::bool_constant { 18 | }; 19 | 20 | template<> 21 | struct is_unsigned> : std::bool_constant { 22 | }; 23 | 24 | template<> 25 | struct is_signed> : std::bool_constant { 26 | }; 27 | 28 | template<> 29 | struct is_unsigned> : std::bool_constant { 30 | }; 31 | 32 | template<> 33 | struct is_signed> : std::bool_constant { 34 | }; 35 | 36 | template<> 37 | struct is_unsigned> : std::bool_constant { 38 | }; 39 | 40 | template<> 41 | struct is_signed> : std::bool_constant { 42 | }; 43 | 44 | template<> 45 | struct is_unsigned> : std::bool_constant { 46 | }; 47 | 48 | template<> 49 | struct is_signed> : std::bool_constant { 50 | }; 51 | 52 | template<> 53 | struct is_unsigned> : std::bool_constant { 54 | }; 55 | 56 | template<> 57 | struct is_signed> : std::bool_constant { 58 | }; 59 | 60 | template<> 61 | struct is_unsigned> : std::bool_constant { 62 | }; 63 | 64 | template<> 65 | struct is_signed> : std::bool_constant { 66 | }; 67 | 68 | template<> 69 | struct is_unsigned> : std::bool_constant { 70 | }; 71 | 72 | #define HALF_MIN at::Half(0xFBFF, at::Half::from_bits()) 73 | #define HALF_MAX at::Half(0x7BFF, at::Half::from_bits()) 74 | #define HALF_LB at::Half(0xFC00, at::Half::from_bits()) 75 | #define HALF_UB at::Half(0x7C00, at::Half::from_bits()) 76 | #define HALF_EPS at::Half(0x1400, at::Half::from_bits()) 77 | #define HALF_QNAN at::Half(0x7E00, at::Half::from_bits()) 78 | #define HALF_SNAN at::Half(0x7E00, at::Half::from_bits()) 79 | 80 | #define BFLOAT16_MIN at::BFloat16(0xFF7F, at::BFloat16::from_bits()) 81 | #define BFLOAT16_MAX at::BFloat16(0x7F7F, at::BFloat16::from_bits()) 82 | #define BFLOAT16_LB at::BFloat16(0xFF80, at::BFloat16::from_bits()) 83 | #define BFLOAT16_UB at::BFloat16(0x7F80, at::BFloat16::from_bits()) 84 | #define BFLOAT16_EPS at::BFloat16(0x3C00, at::BFloat16::from_bits()) // machine epsilon = 2^-7 85 | #define BFLOAT16_QNAN at::BFloat16(0x7FC0, at::BFloat16::from_bits()) 86 | #define BFLOAT16_SNAN at::BFloat16(0x7FC0, at::BFloat16::from_bits()) 87 | 88 | #define AT_FORWARD_INTEGRAL_LIMITS(_, cpp_type, scalar_type) \ 89 | _(cpp_type, scalar_type, \ 90 | std::numeric_limits::min(), \ 91 | std::numeric_limits::max(), \ 92 | std::numeric_limits::min(), \ 93 | std::numeric_limits::max(), \ 94 | std::numeric_limits::epsilon(), \ 95 | 0, 0) 96 | 97 | #define
AT_FORWARD_FLOATING_LIMITS(_, cpp_type, scalar_type) \ 98 | _(cpp_type, scalar_type, \ 99 | std::numeric_limits::min(), \ 100 | std::numeric_limits::max(), \ 101 | -std::numeric_limits::infinity(), \ 102 | std::numeric_limits::infinity(), \ 103 | std::numeric_limits::epsilon(), \ 104 | std::numeric_limits::quiet_NaN(), \ 105 | std::numeric_limits::signaling_NaN()) 106 | 107 | #define AT_FORALL_TYPES(_) \ 108 | AT_FORWARD_INTEGRAL_LIMITS(_, uint8_t, Byte) /* 0 */ \ 109 | AT_FORWARD_INTEGRAL_LIMITS(_, int8_t, Char) /* 1 */ \ 110 | AT_FORWARD_INTEGRAL_LIMITS(_, int16_t, Short) /* 2 */\ 111 | AT_FORWARD_INTEGRAL_LIMITS(_, int, Int) /* 3 */ \ 112 | AT_FORWARD_INTEGRAL_LIMITS(_, int64_t, Long) /* 4 */ \ 113 | _(at::Half, Half, HALF_MIN, HALF_MAX, HALF_LB, HALF_UB, HALF_EPS, HALF_QNAN, HALF_SNAN) /* 5 */ \ 114 | AT_FORWARD_FLOATING_LIMITS(_, float, Float) /* 6 */ \ 115 | AT_FORWARD_FLOATING_LIMITS(_, double, Double) /* 7 */\ 116 | AT_FORWARD_INTEGRAL_LIMITS(_, bool, Bool) /* 8 */ \ 117 | _(at::BFloat16, BFloat16, BFLOAT16_MIN, BFLOAT16_MAX, BFLOAT16_LB, BFLOAT16_UB, BFLOAT16_EPS, BFLOAT16_QNAN, BFLOAT16_SNAN) /* 9 */ 118 | 119 | template 120 | struct ScalarTypeLimits; 121 | 122 | template 123 | struct CPPTypeLimits; 124 | 125 | #define SPECIALIZE_ScalarTypeLimits(cpp_type, scalar_type, min_value, max_value, lb_value, ub_value, eps_value, qnan_value, snan_value) \ 126 | template <> \ 127 | struct ScalarTypeLimits { \ 128 | using scalar_t = cpp_type; \ 129 | C10_NODISCARD static constexpr cpp_type(min)() noexcept { \ 130 | return min_value; \ 131 | } \ 132 | C10_NODISCARD static constexpr cpp_type(max)() noexcept { \ 133 | return max_value; \ 134 | } \ 135 | C10_NODISCARD static constexpr cpp_type(lower_bound)() noexcept { \ 136 | return lb_value; \ 137 | } \ 138 | C10_NODISCARD static constexpr cpp_type(upper_bound)() noexcept { \ 139 | return ub_value; \ 140 | } \ 141 | C10_NODISCARD static constexpr cpp_type(epsilon)() noexcept { \ 142 | return eps_value; \ 143 | } \ 144 | C10_NODISCARD static constexpr cpp_type(quiet_nan)() noexcept { \ 145 | return qnan_value; \ 146 | } \ 147 | C10_NODISCARD static constexpr cpp_type(signaling_nan)() noexcept { \ 148 | return snan_value; \ 149 | } \ 150 | }; 151 | 152 | #define SPECIALIZE_CPPTypeLimits(cpp_type, _, min_value, max_value, lb_value, ub_value, eps_value, qnan_value, snan_value) \ 153 | template <> \ 154 | struct CPPTypeLimits { \ 155 | C10_NODISCARD static constexpr cpp_type(min)() noexcept { \ 156 | return min_value; \ 157 | } \ 158 | C10_NODISCARD static constexpr cpp_type(max)() noexcept { \ 159 | return max_value; \ 160 | } \ 161 | C10_NODISCARD static constexpr cpp_type(lower_bound)() noexcept { \ 162 | return lb_value; \ 163 | } \ 164 | C10_NODISCARD static constexpr cpp_type(upper_bound)() noexcept { \ 165 | return ub_value; \ 166 | } \ 167 | C10_NODISCARD static constexpr cpp_type(epsilon)() noexcept { \ 168 | return eps_value; \ 169 | } \ 170 | C10_NODISCARD static constexpr cpp_type(quiet_nan)() noexcept { \ 171 | return qnan_value; \ 172 | } \ 173 | C10_NODISCARD static constexpr cpp_type(signaling_nan)() noexcept { \ 174 | return snan_value; \ 175 | } \ 176 | }; 177 | 178 | AT_FORALL_TYPES(SPECIALIZE_ScalarTypeLimits) 179 | 180 | AT_FORALL_TYPES(SPECIALIZE_CPPTypeLimits) 181 | 182 | #undef AT_FORWARD_INTEGRAL_LIMITS 183 | #undef AT_FORWARD_FLOATING_LIMITS 184 | #undef AT_FORALL_TYPES 185 | #undef SPECIALIZE_ScalarTypeLimits 186 | #undef SPECIALIZE_CPPTypeLimits 187 | } 188 | 
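The hard-coded half and bfloat16 bit patterns above are easy to sanity-check from Python with a bit-cast; a small sketch, assuming PyTorch's Tensor.view(dtype) reinterpretation between same-width dtypes:

import torch

def from_bits(bits: int, dtype: torch.dtype) -> float:
    # reinterpret a 16-bit pattern as the given 16-bit floating dtype
    return torch.tensor([bits], dtype=torch.int16).view(dtype).item()

# float16: 0x7BFF is the largest finite value, 0x1400 the machine epsilon (2**-10)
assert from_bits(0x7BFF, torch.float16) == torch.finfo(torch.float16).max  # 65504.0
assert from_bits(0x1400, torch.float16) == torch.finfo(torch.float16).eps  # 0.0009765625

# bfloat16: 0x7F7F is the largest finite value, 0x3C00 the machine epsilon (2**-7)
assert from_bits(0x7F7F, torch.bfloat16) == torch.finfo(torch.bfloat16).max  # ~3.39e38
assert from_bits(0x3C00, torch.bfloat16) == torch.finfo(torch.bfloat16).eps  # 0.0078125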
-------------------------------------------------------------------------------- /torchpairwise/csrc/ops/wminkowski.cpp: -------------------------------------------------------------------------------- 1 | #include "wminkowski.h" 2 | 3 | #include 4 | 5 | namespace torchpairwise { 6 | namespace ops { 7 | at::Tensor _wminkowski( 8 | const at::Tensor &x1, 9 | const at::Tensor &x2, 10 | const at::Tensor &w, 11 | double p) { 12 | static auto op = c10::Dispatcher::singleton() 13 | .findSchemaOrThrow("torchpairwise::_wminkowski", "") 14 | .typed(); 15 | return op.call(x1, x2, w, p); 16 | } 17 | 18 | namespace detail { 19 | std::tuple __wminkowski_backward( 20 | const at::Tensor &grad, 21 | const at::Tensor &x1, 22 | const at::Tensor &x2, 23 | const at::Tensor &w, 24 | double p) { 25 | static auto op = 26 | c10::Dispatcher::singleton() 27 | .findSchemaOrThrow("torchpairwise::__wminkowski_backward", "") 28 | .typed(); 29 | return op.call(grad, x1, x2, w, p); 30 | } 31 | } 32 | 33 | TORCH_LIBRARY_FRAGMENT(torchpairwise, m) { 34 | m.def(TORCH_SELECTIVE_SCHEMA( 35 | "torchpairwise::_wminkowski(Tensor x1, Tensor x2, Tensor w, float p=2) -> Tensor") 36 | ); 37 | m.def(TORCH_SELECTIVE_SCHEMA( 38 | "torchpairwise::__wminkowski_backward(Tensor grad, Tensor x1, Tensor x2, Tensor w, float p) -> (Tensor grad_x1, Tensor grad_x2, Tensor grad_w)") 39 | ); 40 | } 41 | } // namespace ops 42 | } // namespace torchpairwise 43 | -------------------------------------------------------------------------------- /torchpairwise/csrc/ops/wminkowski.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | #include "../macros.h" 6 | 7 | namespace torchpairwise { 8 | namespace ops { 9 | TORCHPAIRWISE_API at::Tensor _wminkowski( 10 | const at::Tensor &x1, 11 | const at::Tensor &x2, 12 | const at::Tensor &w, 13 | double p = 2); 14 | 15 | namespace detail { 16 | std::tuple __wminkowski_backward( 17 | const at::Tensor &grad, 18 | const at::Tensor &x1, 19 | const at::Tensor &x2, 20 | const at::Tensor &w, 21 | double p); 22 | } 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /torchpairwise/csrc/torchpairwise.cpp: -------------------------------------------------------------------------------- 1 | #ifdef USE_PYTHON 2 | #include 3 | #endif // USE_PYTHON 4 | 5 | #include 6 | 7 | #include "torchpairwise.h" 8 | 9 | #ifdef WITH_CUDA 10 | 11 | #include 12 | 13 | #endif 14 | 15 | // If we are in a Windows environment, we need to define 16 | // initialization functions for the _C extension 17 | #ifdef _WIN32 18 | #ifdef USE_PYTHON 19 | PyMODINIT_FUNC PyInit__C(void) { 20 | // No need to do anything. 21 | // extension.py will run on load 22 | return nullptr; 23 | } 24 | #endif // USE_PYTHON 25 | #endif // _WIN32 26 | 27 | namespace torchpairwise { 28 | int64_t cuda_version() { 29 | #ifdef WITH_CUDA 30 | return CUDA_VERSION; 31 | #else 32 | return -1; 33 | #endif 34 | } 35 | 36 | std::string cuda_arch_flags() { 37 | #ifdef WITH_CUDA 38 | #ifdef CUDA_ARCH_FLAGS 39 | static const char *flags = C10_STRINGIZE(CUDA_ARCH_FLAGS); 40 | return flags; 41 | #elif defined(TORCH_CUDA_ARCH_LIST) 42 | static const char *flags = C10_STRINGIZE(TORCH_CUDA_ARCH_LIST); 43 | return flags; 44 | #else 45 | // TODO: this is just a work around.
46 | return std::to_string(CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR); 47 | #endif 48 | #else 49 | return {}; 50 | #endif 51 | } 52 | 53 | TORCH_LIBRARY_FRAGMENT(torchpairwise, m) { 54 | m.def("_cuda_version", &cuda_version); 55 | m.def("_cuda_arch_flags", &cuda_arch_flags); 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /torchpairwise/csrc/torchpairwise.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "ops/ops.h" 4 | #include "macros.h" 5 | 6 | namespace torchpairwise { 7 | TORCHPAIRWISE_API int64_t cuda_version(); 8 | TORCHPAIRWISE_API std::string cuda_arch_flags(); 9 | } 10 | -------------------------------------------------------------------------------- /torchpairwise/extension.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/pytorch/vision/blob/main/torchvision/extension.py 3 | """ 4 | import ctypes 5 | import importlib 6 | import os 7 | import sys 8 | from typing import List 9 | from warnings import warn 10 | 11 | import torch 12 | from torch._ops import _OpNamespace 13 | 14 | extension_namespace = os.path.basename(os.path.dirname(__file__)) 15 | 16 | 17 | def _get_extension_path(lib_name): 18 | lib_dir = os.path.dirname(__file__) 19 | if os.name == "nt": 20 | # Register the main library location on the default DLL path 21 | import ctypes 22 | import sys 23 | 24 | kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) 25 | with_load_library_flags = hasattr(kernel32, "AddDllDirectory") 26 | prev_error_mode = kernel32.SetErrorMode(0x0001) 27 | 28 | if with_load_library_flags: 29 | kernel32.AddDllDirectory.restype = ctypes.c_void_p 30 | 31 | if sys.version_info >= (3, 8): 32 | os.add_dll_directory(lib_dir) 33 | elif with_load_library_flags: 34 | res = kernel32.AddDllDirectory(lib_dir) 35 | if res is None: 36 | err = ctypes.WinError(ctypes.get_last_error()) 37 | err.strerror += f" Error adding \"{lib_dir}\" to the DLL directories." 38 | raise err 39 | 40 | kernel32.SetErrorMode(prev_error_mode) 41 | 42 | loader_details = (importlib.machinery.ExtensionFileLoader, importlib.machinery.EXTENSION_SUFFIXES) 43 | 44 | extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) 45 | ext_specs = extfinder.find_spec(lib_name) 46 | if ext_specs is None: 47 | raise ImportError 48 | 49 | return ext_specs.origin 50 | 51 | 52 | _HAS_OPS = False 53 | 54 | 55 | def _has_ops(): 56 | return False 57 | 58 | 59 | try: 60 | # On Windows Python-3.8.x has `os.add_dll_directory` call, 61 | # which is called to configure dll search path. 
62 | # To find cuda related dlls we need to make sure the 63 | # conda environment/bin path is configured. Please take a look: 64 | # https://stackoverflow.com/questions/59330863/cant-import-dll-module-in-python 65 | # Please note: if some path can't be added using add_dll_directory we simply ignore this path 66 | if os.name == "nt" and (3, 8) <= sys.version_info < (3, 9): 67 | env_path = os.environ["PATH"] 68 | path_arr = env_path.split(";") 69 | for path in path_arr: 70 | if os.path.exists(path): 71 | try: 72 | os.add_dll_directory(path) # type: ignore[attr-defined] 73 | except Exception: 74 | pass 75 | 76 | lib_path = _get_extension_path("_C") 77 | torch.ops.load_library(lib_path) 78 | _HAS_OPS = True 79 | 80 | 81 | def _has_ops(): # noqa: F811 82 | return True 83 | 84 | except (ImportError, OSError): 85 | pass 86 | finally: 87 | _ops = _OpNamespace(extension_namespace) 88 | 89 | 90 | def _assert_has_ops(): 91 | if not _has_ops(): 92 | raise RuntimeError( 93 | "Couldn't load custom C++ ops. Recompile C++ extension with:\n" 94 | "\tpython setup.py build_ext --inplace" 95 | ) 96 | 97 | 98 | def _check_cuda_version(minor=True): 99 | """ 100 | Make sure that CUDA versions match between the PyTorch install and the C++ extension install. 101 | 102 | Args: 103 | minor (bool): If ``False``, ignore minor version difference. 104 | Defaults to ``True``. 105 | """ 106 | if not _HAS_OPS: 107 | return -1 108 | from torch.version import cuda as torch_version_cuda 109 | 110 | _version = _ops._cuda_version() 111 | if _version != -1 and torch_version_cuda is not None: 112 | ext_version = str(_version) 113 | if int(ext_version) < 10000: 114 | ext_major = int(ext_version[0]) 115 | ext_minor = int(ext_version[2]) 116 | else: 117 | ext_major = int(ext_version[0:2]) 118 | ext_minor = int(ext_version[3]) 119 | t_version = torch_version_cuda.split(".") 120 | t_major = int(t_version[0]) 121 | t_minor = int(t_version[1]) 122 | if t_major != ext_major or (minor and t_minor != ext_minor): 123 | raise RuntimeError( 124 | "Detected that PyTorch and Extension were compiled with different CUDA versions. " 125 | f"PyTorch has CUDA Version={t_major}.{t_minor} and " 126 | f"Extension has CUDA Version={ext_major}.{ext_minor}. " 127 | "Please reinstall the Extension that matches your PyTorch install." 128 | ) 129 | elif t_minor != ext_minor: 130 | warn( 131 | "Detected that PyTorch and Extension have a minor version mismatch. " 132 | f"PyTorch has CUDA Version={t_major}.{t_minor} and " 133 | f"Extension has CUDA Version={ext_major}.{ext_minor}. " 134 | "Most likely this shouldn't be a problem."
135 | ) 136 | return _version 137 | 138 | 139 | def _load_library(lib_name): 140 | lib_path = _get_extension_path(lib_name) 141 | # On Windows Python-3.8+ has `os.add_dll_directory` call, 142 | # which is called from _get_extension_path to configure dll search path. 143 | # Condition below adds a workaround for older versions by 144 | # explicitly calling `LoadLibraryExW` with the following flags: 145 | # - LOAD_LIBRARY_SEARCH_DEFAULT_DIRS (0x1000) 146 | # - LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR (0x100) 147 | if os.name == "nt" and sys.version_info < (3, 8): 148 | _kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) 149 | if hasattr(_kernel32, "LoadLibraryExW"): 150 | _kernel32.LoadLibraryExW(lib_path, None, 0x00001100) 151 | else: 152 | warn("LoadLibraryExW is missing in kernel32.dll") 153 | 154 | torch.ops.load_library(lib_path) 155 | 156 | 157 | _check_cuda_version(False) 158 | 159 | 160 | ################### 161 | # Exposed functions 162 | ################### 163 | def has_ops() -> bool: 164 | r""" 165 | Check if the C++ extension was compiled successfully. 166 | """ 167 | return _HAS_OPS 168 | 169 | 170 | def cuda_version() -> int: 171 | r""" 172 | Get the CUDA version the C++ extension was compiled with. 173 | """ 174 | if not _HAS_OPS: 175 | return -1 176 | return _ops._cuda_version() 177 | 178 | 179 | def with_cuda() -> bool: 180 | r""" 181 | Check if the C++ extension was compiled with CUDA. 182 | """ 183 | return cuda_version() != -1 184 | 185 | 186 | def cuda_arch_list() -> List[str]: 187 | r""" 188 | Returns the list of CUDA architectures this library was compiled for. 189 | """ 190 | if not _HAS_OPS: 191 | return [] 192 | return _ops._cuda_arch_flags().split() 193 | -------------------------------------------------------------------------------- /torchpairwise/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .cpdist import * 2 | from .metrics import * 3 | -------------------------------------------------------------------------------- /torchpairwise/ops/cpdist.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | cdist = torch.ops.torchpairwise.cdist 4 | pdist = torch.ops.torchpairwise.pdist 5 | -------------------------------------------------------------------------------- /torchpairwise/ops/cpdist.pyi: -------------------------------------------------------------------------------- 1 | from torch import Tensor 2 | 3 | 4 | def cdist(x1: Tensor, 5 | x2: Tensor, 6 | metric: str = 'minkowski', 7 | **kwargs): ... 8 | 9 | 10 | def pdist(input: Tensor, 11 | metric: str = 'minkowski', 12 | **kwargs): ...
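The cdist/pdist stubs select a metric by name and forward metric-specific parameters (w, V, VI, p, base, shuffle, generator in the extra-args schema used by the neighbors ops above) as keyword arguments. A hypothetical sketch; the metric names mirror the functions exported from pairwise.py below:

import torch
from torchpairwise.ops import cdist, pdist

x1 = torch.randn(10, 4, dtype=torch.float64)
x2 = torch.randn(12, 4, dtype=torch.float64)

d_l1 = cdist(x1, x2, metric="cityblock")           # (10, 12) pairwise L1 distances
d_mink = cdist(x1, x2, metric="minkowski", p=3.0)  # (10, 12) Minkowski distances
d_self = pdist(x1, metric="minkowski")             # distances of x1 against itself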
13 | -------------------------------------------------------------------------------- /torchpairwise/ops/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from .neighbors import * 2 | from .pairwise import * 3 | -------------------------------------------------------------------------------- /torchpairwise/ops/metrics/neighbors.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | k_neighbors_mask = torch.ops.torchpairwise.k_neighbors_mask 4 | radius_neighbors_mask = torch.ops.torchpairwise.radius_neighbors_mask 5 | -------------------------------------------------------------------------------- /torchpairwise/ops/metrics/neighbors.pyi: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torch import Tensor 4 | 5 | 6 | # noinspection PyPep8Naming 7 | def k_neighbors_mask(x1: Tensor, 8 | x2: Optional[Tensor] = None, 9 | k: int = 1, 10 | metric: str = "euclidean", 11 | **kwargs) -> Tensor: ... 12 | 13 | 14 | # noinspection PyPep8Naming 15 | def radius_neighbors_mask(x1: Tensor, 16 | x2: Optional[Tensor] = None, 17 | epsilon: float = 0., 18 | metric: str = "euclidean", 19 | **kwargs) -> Tensor: ... 20 | -------------------------------------------------------------------------------- /torchpairwise/ops/metrics/pairwise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # sklearn 4 | euclidean_distances = torch.ops.torchpairwise.euclidean_distances 5 | haversine_distances = torch.ops.torchpairwise.haversine_distances 6 | manhattan_distances = torch.ops.torchpairwise.manhattan_distances 7 | cosine_distances = torch.ops.torchpairwise.cosine_distances 8 | linear_kernel = torch.ops.torchpairwise.linear_kernel 9 | polynomial_kernel = torch.ops.torchpairwise.polynomial_kernel 10 | sigmoid_kernel = torch.ops.torchpairwise.sigmoid_kernel 11 | rbf_kernel = torch.ops.torchpairwise.rbf_kernel 12 | laplacian_kernel = torch.ops.torchpairwise.laplacian_kernel 13 | cosine_similarity = torch.ops.torchpairwise.cosine_similarity 14 | additive_chi2_kernel = torch.ops.torchpairwise.additive_chi2_kernel 15 | chi2_kernel = torch.ops.torchpairwise.chi2_kernel 16 | 17 | # scipy 18 | directed_hausdorff_distances = torch.ops.torchpairwise.directed_hausdorff_distances 19 | minkowski_distances = torch.ops.torchpairwise.minkowski_distances 20 | wminkowski_distances = torch.ops.torchpairwise.wminkowski_distances 21 | sqeuclidean_distances = torch.ops.torchpairwise.sqeuclidean_distances 22 | correlation_distances = torch.ops.torchpairwise.correlation_distances 23 | hamming_distances = torch.ops.torchpairwise.hamming_distances 24 | jaccard_distances = torch.ops.torchpairwise.jaccard_distances 25 | kulsinski_distances = torch.ops.torchpairwise.kulsinski_distances 26 | kulczynski1_distances = torch.ops.torchpairwise.kulczynski1_distances 27 | seuclidean_distances = torch.ops.torchpairwise.seuclidean_distances 28 | cityblock_distances = torch.ops.torchpairwise.cityblock_distances 29 | mahalanobis_distances = torch.ops.torchpairwise.mahalanobis_distances 30 | chebyshev_distances = torch.ops.torchpairwise.chebyshev_distances 31 | braycurtis_distances = torch.ops.torchpairwise.braycurtis_distances 32 | canberra_distances = torch.ops.torchpairwise.canberra_distances 33 | jensenshannon_distances = torch.ops.torchpairwise.jensenshannon_distances 34 | yule_distances =
torch.ops.torchpairwise.yule_distances 35 | dice_distances = torch.ops.torchpairwise.dice_distances 36 | rogerstanimoto_distances = torch.ops.torchpairwise.rogerstanimoto_distances 37 | russellrao_distances = torch.ops.torchpairwise.russellrao_distances 38 | sokalmichener_distances = torch.ops.torchpairwise.sokalmichener_distances 39 | sokalsneath_distances = torch.ops.torchpairwise.sokalsneath_distances 40 | 41 | # others 42 | snr_distances = torch.ops.torchpairwise.snr_distances 43 | 44 | # aliases 45 | l1_distances = torch.ops.torchpairwise.l1_distances 46 | l2_distances = torch.ops.torchpairwise.l2_distances 47 | lp_distances = torch.ops.torchpairwise.lp_distances 48 | -------------------------------------------------------------------------------- /torchpairwise/ops/metrics/pairwise.pyi: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | from torch import Tensor, Generator 4 | from torch.types import Number 5 | 6 | 7 | # ~~~~~ sklearn ~~~~~ 8 | # noinspection PyPep8Naming 9 | def euclidean_distances(x1: Tensor, 10 | x2: Optional[Tensor] = None) -> Tensor: ... 11 | 12 | 13 | # noinspection PyPep8Naming 14 | def haversine_distances(x1: Tensor, 15 | x2: Optional[Tensor] = None) -> Tensor: ... 16 | 17 | 18 | # noinspection PyPep8Naming 19 | def manhattan_distances(x1: Tensor, 20 | x2: Optional[Tensor] = None) -> Tensor: ... 21 | 22 | 23 | # noinspection PyPep8Naming 24 | def cosine_distances(x1: Tensor, 25 | x2: Optional[Tensor] = None) -> Tensor: ... 26 | 27 | 28 | # noinspection PyPep8Naming 29 | def linear_kernel(x1: Tensor, 30 | x2: Optional[Tensor] = None) -> Tensor: ... 31 | 32 | 33 | # noinspection PyPep8Naming 34 | def polynomial_kernel(x1: Tensor, 35 | x2: Optional[Tensor] = None, 36 | degree: int = 3, 37 | gamma: Optional[float] = None, 38 | coef0: float = 1.) -> Tensor: ... 39 | 40 | 41 | # noinspection PyPep8Naming 42 | def sigmoid_kernel(x1: Tensor, 43 | x2: Optional[Tensor] = None, 44 | gamma: Optional[float] = None, 45 | coef0: float = 1.) -> Tensor: ... 46 | 47 | 48 | # noinspection PyPep8Naming 49 | def rbf_kernel(x1: Tensor, 50 | x2: Optional[Tensor] = None, 51 | gamma: Optional[float] = None) -> Tensor: ... 52 | 53 | 54 | # noinspection PyPep8Naming 55 | def laplacian_kernel(x1: Tensor, 56 | x2: Optional[Tensor] = None, 57 | gamma: Optional[float] = None) -> Tensor: ... 58 | 59 | 60 | # noinspection PyPep8Naming 61 | def cosine_similarity(x1: Tensor, 62 | x2: Optional[Tensor] = None) -> Tensor: ... 63 | 64 | 65 | # noinspection PyPep8Naming 66 | def additive_chi2_kernel(x1: Tensor, 67 | x2: Optional[Tensor] = None) -> Tensor: ... 68 | 69 | 70 | # noinspection PyPep8Naming 71 | def chi2_kernel(x1: Tensor, 72 | x2: Optional[Tensor] = None, 73 | gamma: float = 1.) -> Tensor: ... 74 | 75 | 76 | # ~~~~~ scipy ~~~~~ 77 | # noinspection PyPep8Naming 78 | def directed_hausdorff_distances(x1: Tensor, 79 | x2: Optional[Tensor] = None, 80 | *, 81 | shuffle: bool = False, 82 | generator: Optional[Generator] = None) -> Tensor: ... 83 | 84 | 85 | # noinspection PyPep8Naming 86 | def minkowski_distances(x1: Tensor, 87 | x2: Optional[Tensor] = None, 88 | p: float = 2) -> Tensor: ... 89 | 90 | 91 | # noinspection PyPep8Naming 92 | def wminkowski_distances(x1: Tensor, 93 | x2: Optional[Tensor] = None, 94 | p: float = 2, 95 | w: Optional[Tensor] = None) -> Tensor: ... 96 | 97 | 98 | # noinspection PyPep8Naming 99 | def sqeuclidean_distances(x1: Tensor, 100 | x2: Optional[Tensor] = None) -> Tensor: ... 
101 | 102 | 103 | # noinspection PyPep8Naming 104 | def correlation_distances(x1: Tensor, 105 | x2: Optional[Tensor] = None) -> Tensor: ... 106 | 107 | 108 | # noinspection PyPep8Naming 109 | def hamming_distances(x1: Tensor, 110 | x2: Optional[Tensor] = None) -> Tensor: ... 111 | 112 | 113 | # noinspection PyPep8Naming 114 | def jaccard_distances(x1: Tensor, 115 | x2: Optional[Tensor] = None) -> Tensor: ... 116 | 117 | 118 | # noinspection PyPep8Naming 119 | def kulsinski_distances(x1: Tensor, 120 | x2: Optional[Tensor] = None) -> Tensor: ... 121 | 122 | 123 | # noinspection PyPep8Naming 124 | def kulczynski1_distances(x1: Tensor, 125 | x2: Optional[Tensor] = None) -> Tensor: ... 126 | 127 | 128 | # noinspection PyPep8Naming 129 | def seuclidean_distances(x1: Tensor, 130 | x2: Optional[Tensor] = None, 131 | V: Optional[Tensor] = None) -> Tensor: ... 132 | 133 | 134 | # noinspection PyPep8Naming 135 | def cityblock_distances(x1: Tensor, 136 | x2: Optional[Tensor] = None) -> Tensor: ... 137 | 138 | 139 | # noinspection PyPep8Naming 140 | def mahalanobis_distances(x1: Tensor, 141 | x2: Optional[Tensor] = None, 142 | VI: Optional[Tensor] = None) -> Tensor: ... 143 | 144 | 145 | # noinspection PyPep8Naming 146 | def chebyshev_distances(x1: Tensor, 147 | x2: Optional[Tensor] = None) -> Tensor: ... 148 | 149 | 150 | # noinspection PyPep8Naming 151 | def braycurtis_distances(x1: Tensor, 152 | x2: Optional[Tensor] = None) -> Tensor: ... 153 | 154 | 155 | # noinspection PyPep8Naming 156 | def canberra_distances(x1: Tensor, 157 | x2: Optional[Tensor] = None) -> Tensor: ... 158 | 159 | 160 | # noinspection PyPep8Naming 161 | def jensenshannon_distances(x1: Tensor, 162 | x2: Optional[Tensor] = None, 163 | base: Optional[float] = None, 164 | dim: int = -1, 165 | keepdim: bool = False) -> Tensor: ... 166 | 167 | 168 | # noinspection PyPep8Naming 169 | def yule_distances(x1: Tensor, 170 | x2: Optional[Tensor] = None) -> Tensor: ... 171 | 172 | 173 | # noinspection PyPep8Naming 174 | def dice_distances(x1: Tensor, 175 | x2: Optional[Tensor] = None) -> Tensor: ... 176 | 177 | 178 | # noinspection PyPep8Naming 179 | def rogerstanimoto_distances(x1: Tensor, 180 | x2: Optional[Tensor] = None) -> Tensor: ... 181 | 182 | 183 | # noinspection PyPep8Naming 184 | def russellrao_distances(x1: Tensor, 185 | x2: Optional[Tensor] = None) -> Tensor: ... 186 | 187 | 188 | # noinspection PyPep8Naming 189 | def sokalmichener_distances(x1: Tensor, 190 | x2: Optional[Tensor] = None) -> Tensor: ... 191 | 192 | 193 | # noinspection PyPep8Naming 194 | def sokalsneath_distances(x1: Tensor, 195 | x2: Optional[Tensor] = None) -> Tensor: ... 196 | 197 | 198 | # ~~~~~ others ~~~~~ 199 | # noinspection PyPep8Naming 200 | def snr_distances(x1: Tensor, 201 | x2: Optional[Tensor] = None, 202 | correction: Number = 1) -> Tensor: ... 203 | 204 | 205 | # ~~~~~ aliases ~~~~~ 206 | def l1_distances(x1: Tensor, 207 | x2: Optional[Tensor] = None) -> Tensor: ... 208 | 209 | 210 | def l2_distances(x1: Tensor, 211 | x2: Optional[Tensor] = None) -> Tensor: ... 212 | 213 | 214 | def lp_distances(x1: Tensor, 215 | x2: Optional[Tensor] = None, 216 | p: float = 2) -> Tensor: ... 217 | --------------------------------------------------------------------------------
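Taken together, these stubs map the public surface one-to-one onto sklearn.metrics.pairwise and scipy.spatial.distance. A closing usage sketch (illustrative shapes; importing the package is assumed to load the compiled extension via extension.py):

import torch
from torchpairwise.ops import euclidean_distances, rbf_kernel, haversine_distances

x1 = torch.randn(6, 3, dtype=torch.float64)
x2 = torch.randn(4, 3, dtype=torch.float64)

print(euclidean_distances(x1, x2).shape)    # torch.Size([6, 4])
print(euclidean_distances(x1).shape)        # x2 defaults to x1 -> torch.Size([6, 6])
print(rbf_kernel(x1, x2, gamma=0.5).shape)  # torch.Size([6, 4])

# haversine expects (latitude, longitude) pairs in radians, as in sklearn
coords1 = torch.rand(5, 2, dtype=torch.float64)
coords2 = torch.rand(7, 2, dtype=torch.float64)
print(haversine_distances(coords1, coords2).shape)  # torch.Size([5, 7])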