├── .coveragerc ├── .github └── workflows │ ├── building.yml │ ├── cuda │ ├── cu101-Linux-env.sh │ ├── cu101-Linux.sh │ ├── cu101-Windows-env.sh │ ├── cu101-Windows.sh │ ├── cu102-Linux-env.sh │ ├── cu102-Linux.sh │ ├── cu102-Windows-env.sh │ ├── cu102-Windows.sh │ ├── cu111-Linux-env.sh │ ├── cu111-Linux.sh │ ├── cu111-Windows-env.sh │ ├── cu111-Windows.sh │ ├── cu113-Linux-env.sh │ ├── cu113-Linux.sh │ ├── cu113-Windows-env.sh │ ├── cu113-Windows.sh │ ├── cu115-Linux-env.sh │ ├── cu115-Linux.sh │ ├── cu115-Windows-env.sh │ ├── cu115-Windows.sh │ ├── cu116-Linux-env.sh │ ├── cu116-Linux.sh │ ├── cu116-Windows-env.sh │ ├── cu116-Windows.sh │ ├── cu117-Linux-env.sh │ ├── cu117-Linux.sh │ ├── cu117-Windows-env.sh │ ├── cu117-Windows.sh │ ├── cu118-Linux-env.sh │ ├── cu118-Linux.sh │ ├── cu118-Windows-env.sh │ ├── cu118-Windows.sh │ ├── cu121-Linux-env.sh │ ├── cu121-Linux.sh │ ├── cu121-Windows-env.sh │ ├── cu121-Windows.sh │ ├── cu124-Linux-env.sh │ ├── cu124-Linux.sh │ ├── cu124-Windows-env.sh │ ├── cu124-Windows.sh │ ├── cu126-Linux-env.sh │ ├── cu126-Linux.sh │ ├── cu126-Windows-env.sh │ ├── cu126-Windows.sh │ ├── cu128-Linux-env.sh │ ├── cu128-Linux.sh │ ├── cu128-Windows-env.sh │ └── cu128-Windows.sh │ ├── linting.yml │ ├── stale.yml │ └── testing.yml ├── .gitignore ├── CMakeLists.txt ├── LICENSE ├── MANIFEST.in ├── README.md ├── benchmark ├── .gitignore ├── gather.py └── scatter_segment.py ├── cmake └── TorchScatterConfig.cmake.in ├── csrc ├── cpu │ ├── index_info.h │ ├── reducer.h │ ├── scatter_cpu.cpp │ ├── scatter_cpu.h │ ├── segment_coo_cpu.cpp │ ├── segment_coo_cpu.h │ ├── segment_csr_cpu.cpp │ ├── segment_csr_cpu.h │ └── utils.h ├── cuda │ ├── atomics.cuh │ ├── index_info.cuh │ ├── reducer.cuh │ ├── scatter_cuda.cu │ ├── scatter_cuda.h │ ├── segment_coo_cuda.cu │ ├── segment_coo_cuda.h │ ├── segment_csr_cuda.cu │ ├── segment_csr_cuda.h │ └── utils.cuh ├── extensions.h ├── macros.h ├── scatter.cpp ├── scatter.h ├── segment_coo.cpp ├── segment_csr.cpp ├── 
utils.h └── version.cpp ├── docs ├── .nojekyll ├── Makefile ├── index.html ├── requirements.txt └── source │ ├── _figures │ ├── add.svg │ ├── add.tex │ ├── build.sh │ ├── div.svg │ ├── div.tex │ ├── max.svg │ ├── max.tex │ ├── mean.svg │ ├── mean.tex │ ├── min.svg │ ├── min.tex │ ├── mul.svg │ ├── mul.tex │ ├── segment_coo.svg │ ├── segment_coo.tex │ ├── std.svg │ ├── std.tex │ ├── sub.svg │ ├── sub.tex │ └── template.tex │ ├── conf.py │ ├── functions │ ├── composite.rst │ ├── scatter.rst │ ├── segment_coo.rst │ └── segment_csr.rst │ └── index.rst ├── pyproject.toml ├── readthedocs.yml ├── setup.cfg ├── setup.py ├── test ├── composite │ ├── test_logsumexp.py │ ├── test_softmax.py │ └── test_std.py ├── test_broadcasting.py ├── test_gather.py ├── test_multi_gpu.py ├── test_scatter.py ├── test_segment.py └── test_zero_tensors.py └── torch_scatter ├── __init__.py ├── composite ├── __init__.py ├── logsumexp.py ├── softmax.py └── std.py ├── placeholder.py ├── scatter.py ├── segment_coo.py ├── segment_csr.py ├── testing.py └── utils.py /.coveragerc: -------------------------------------------------------------------------------- 1 | [run] 2 | source=torch_scatter 3 | omit=torch_scatter/placeholder.py 4 | [report] 5 | exclude_lines = 6 | pragma: no cover 7 | torch.jit.script 8 | raise 9 | except 10 | -------------------------------------------------------------------------------- /.github/workflows/building.yml: -------------------------------------------------------------------------------- 1 | name: Building Wheels 2 | 3 | on: [workflow_dispatch] 4 | 5 | jobs: 6 | 7 | wheel: 8 | runs-on: ${{ matrix.os }} 9 | 10 | strategy: 11 | fail-fast: false 12 | matrix: 13 | os: [ubuntu-22.04, macos-14, windows-2019, ubuntu-22.04-arm] 14 | python-version: ['3.9', '3.10', '3.11', '3.12', '3.13'] 15 | torch-version: [2.7.0] # [2.6.0] 16 | cuda-version: ['cpu', 'cu118', 'cu124', 'cu126', 'cu128'] 17 | exclude: 18 | - torch-version: 2.6.0 19 | cuda-version: 'cu128' 20 | - torch-version: 
2.7.0 21 | cuda-version: 'cu124' 22 | - os: macos-14 23 | cuda-version: 'cu118' 24 | - os: macos-14 25 | cuda-version: 'cu124' 26 | - os: macos-14 27 | cuda-version: 'cu126' 28 | - os: macos-14 29 | cuda-version: 'cu128' 30 | - os: ubuntu-22.04-arm 31 | cuda-version: 'cu118' 32 | - os: ubuntu-22.04-arm 33 | cuda-version: 'cu124' 34 | - os: ubuntu-22.04-arm 35 | cuda-version: 'cu126' 36 | - os: ubuntu-22.04-arm 37 | cuda-version: 'cu128' 38 | 39 | steps: 40 | - uses: actions/checkout@v2 41 | - name: Set up Python ${{ matrix.python-version }} 42 | uses: actions/setup-python@v2 43 | with: 44 | python-version: ${{ matrix.python-version }} 45 | 46 | - name: Upgrade pip 47 | run: | 48 | pip install --upgrade setuptools 49 | pip install wheel 50 | 51 | - name: Free Disk Space (Ubuntu) 52 | if: ${{ runner.os == 'Linux' }} 53 | uses: jlumbroso/free-disk-space@main 54 | 55 | - name: Install CUDA ${{ matrix.cuda-version }} 56 | if: ${{ matrix.cuda-version != 'cpu' }} 57 | run: | 58 | bash .github/workflows/cuda/${{ matrix.cuda-version }}-${{ runner.os }}.sh 59 | 60 | - name: Install PyTorch ${{ matrix.torch-version }}+${{ matrix.cuda-version }} 61 | run: | 62 | pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/${{ matrix.cuda-version }} 63 | python -c "import torch; print('PyTorch:', torch.__version__)" 64 | python -c "import torch; print('CUDA:', torch.version.cuda)" 65 | 66 | - name: Patch PyTorch static constexpr on Windows 67 | if: ${{ runner.os == 'Windows' }} 68 | run: | 69 | Torch_DIR=`python -c 'import os; import torch; print(os.path.dirname(torch.__file__))'` 70 | sed -i '31,38c\ 71 | TORCH_API void lazy_init_num_threads();' ${Torch_DIR}/include/ATen/Parallel.h 72 | shell: bash 73 | 74 | - name: Set version 75 | if: ${{ runner.os != 'macOS' }} 76 | run: | 77 | VERSION=`sed -n "s/^__version__ = '\(.*\)'/\1/p" torch_scatter/__init__.py` 78 | TORCH_VERSION=`echo "pt${{ matrix.torch-version }}" | sed "s/..$//" | sed 
"s/\.//g"` 79 | CUDA_VERSION=`echo ${{ matrix.cuda-version }}` 80 | echo "New version name: $VERSION+$TORCH_VERSION$CUDA_VERSION" 81 | sed -i "s/$VERSION/$VERSION+$TORCH_VERSION$CUDA_VERSION/" setup.py 82 | sed -i "s/$VERSION/$VERSION+$TORCH_VERSION$CUDA_VERSION/" torch_scatter/__init__.py 83 | shell: 84 | bash 85 | 86 | - name: Build wheel for CPU 87 | if: ${{ matrix.cuda-version == 'cpu' }} 88 | run: | 89 | FORCE_ONLY_CPU=1 python setup.py bdist_wheel --dist-dir=dist 90 | shell: 91 | bash 92 | 93 | - name: Build wheel for GPU 94 | if: ${{ matrix.cuda-version != 'cpu' }} 95 | run: | 96 | source .github/workflows/cuda/${{ matrix.cuda-version }}-${{ runner.os }}-env.sh 97 | FORCE_CUDA=1 python setup.py bdist_wheel --dist-dir=dist 98 | shell: 99 | bash 100 | 101 | - name: Configure AWS 102 | uses: aws-actions/configure-aws-credentials@v1 103 | with: 104 | aws-access-key-id: ${{ secrets.AWS_ACCESS_KEY_ID }} 105 | aws-secret-access-key: ${{ secrets.AWS_SECRET_ACCESS_KEY }} 106 | aws-region: us-west-1 107 | 108 | - name: Upload wheel 109 | run: | 110 | aws s3 sync dist s3://data.pyg.org/whl/torch-${{ matrix.torch-version }}+${{ matrix.cuda-version }} --grants read=uri=http://acs.amazonaws.com/groups/global/AllUsers 111 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu101-Linux-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/usr/local/cuda-10.1 4 | LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 5 | PATH=${CUDA_HOME}/bin:${PATH} 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" 9 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu101-Linux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OS=ubuntu1804 4 | 5 | wget -nv 
https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin 6 | sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600 7 | wget -nv https://developer.download.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda-repo-${OS}-10-1-local-10.1.243-418.87.00_1.0-1_amd64.deb 8 | sudo dpkg -i cuda-repo-${OS}-10-1-local-10.1.243-418.87.00_1.0-1_amd64.deb 9 | sudo apt-key add /var/cuda-repo-10-1-local-10.1.243-418.87.00/7fa2af80.pub 10 | 11 | sudo apt-get -qq update 12 | sudo apt install cuda-nvcc-10-1 cuda-libraries-dev-10-1 13 | sudo apt clean 14 | 15 | rm -f https://developer.download.nvidia.com/compute/cuda/10.1/Prod/local_installers/cuda-repo-${OS}-10-1-local-10.1.243-418.87.00_1.0-1_amd64.deb 16 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu101-Windows-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v10.1 4 | PATH=${CUDA_HOME}/bin:$PATH 5 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" 9 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu101-Windows.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Install NVIDIA drivers, see: 4 | # https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102 5 | curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip" 6 | 7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32" 7 | 8 | export CUDA_SHORT=10.1 9 | export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}/Prod/local_installers/ 10 | export 
CUDA_FILE=cuda_${CUDA_SHORT}.243_426.00_win10.exe 11 | 12 | # Install CUDA: 13 | curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}" 14 | echo "" 15 | echo "Installing from ${CUDA_FILE}..." 16 | PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow" 17 | echo "Done!" 18 | rm -f "${CUDA_FILE}" 19 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu102-Linux-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/usr/local/cuda-10.2 4 | LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 5 | PATH=${CUDA_HOME}/bin:${PATH} 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" 9 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu102-Linux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OS=ubuntu1804 4 | 5 | wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin 6 | sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600 7 | wget -nv https://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda-repo-${OS}-10-2-local-10.2.89-440.33.01_1.0-1_amd64.deb 8 | sudo dpkg -i cuda-repo-${OS}-10-2-local-10.2.89-440.33.01_1.0-1_amd64.deb 9 | sudo apt-key add /var/cuda-repo-10-2-local-10.2.89-440.33.01/7fa2af80.pub 10 | 11 | sudo apt-get -qq update 12 | sudo apt install cuda-nvcc-10-2 cuda-libraries-dev-10-2 13 | sudo apt clean 14 | 15 | rm -f 
https://developer.download.nvidia.com/compute/cuda/10.2/Prod/local_installers/cuda-repo-${OS}-10-2-local-10.2.89-440.33.01_1.0-1_amd64.deb 16 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu102-Windows-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v10.2 4 | PATH=${CUDA_HOME}/bin:$PATH 5 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5" 9 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu102-Windows.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Install NVIDIA drivers, see: 4 | # https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102 5 | curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip" 6 | 7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32" 7 | 8 | export CUDA_SHORT=10.2 9 | export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}/Prod/local_installers 10 | export CUDA_FILE=cuda_${CUDA_SHORT}.89_441.22_win10.exe 11 | 12 | # Install CUDA: 13 | curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}" 14 | echo "" 15 | echo "Installing from ${CUDA_FILE}..." 
16 | PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow" 17 | echo "Done!" 18 | rm -f "${CUDA_FILE}" 19 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu111-Linux-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/usr/local/cuda-11.1 4 | LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 5 | PATH=${CUDA_HOME}/bin:${PATH} 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" 9 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu111-Linux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OS=ubuntu1804 4 | 5 | wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin 6 | sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600 7 | wget -nv https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda-repo-${OS}-11-1-local_11.1.1-455.32.00-1_amd64.deb 8 | sudo dpkg -i cuda-repo-${OS}-11-1-local_11.1.1-455.32.00-1_amd64.deb 9 | sudo apt-key add /var/cuda-repo-${OS}-11-1-local/7fa2af80.pub 10 | 11 | sudo apt-get -qq update 12 | sudo apt install cuda-nvcc-11-1 cuda-libraries-dev-11-1 13 | sudo apt clean 14 | 15 | rm -f https://developer.download.nvidia.com/compute/cuda/11.1.1/local_installers/cuda-repo-${OS}-11-1-local_11.1.1-455.32.00-1_amd64.deb 16 | -------------------------------------------------------------------------------- 
/.github/workflows/cuda/cu111-Windows-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.1 4 | PATH=${CUDA_HOME}/bin:$PATH 5 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="6.0+PTX" 9 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu111-Windows.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Install NVIDIA drivers, see: 4 | # https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102 5 | curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip" 6 | 7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32" 7 | 8 | export CUDA_SHORT=11.1 9 | export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.1/local_installers 10 | export CUDA_FILE=cuda_${CUDA_SHORT}.1_456.81_win10.exe 11 | 12 | # Install CUDA: 13 | curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}" 14 | echo "" 15 | echo "Installing from ${CUDA_FILE}..." 16 | PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow" 17 | echo "Done!" 
18 | rm -f "${CUDA_FILE}" 19 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu113-Linux-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/usr/local/cuda-11.3 4 | LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 5 | PATH=${CUDA_HOME}/bin:${PATH} 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" 9 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu113-Linux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OS=ubuntu1804 4 | 5 | wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin 6 | sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600 7 | wget -nv https://developer.download.nvidia.com/compute/cuda/11.3.0/local_installers/cuda-repo-${OS}-11-3-local_11.3.0-465.19.01-1_amd64.deb 8 | sudo dpkg -i cuda-repo-${OS}-11-3-local_11.3.0-465.19.01-1_amd64.deb 9 | sudo apt-key add /var/cuda-repo-${OS}-11-3-local/7fa2af80.pub 10 | 11 | sudo apt-get -qq update 12 | sudo apt install cuda-nvcc-11-3 cuda-libraries-dev-11-3 13 | sudo apt clean 14 | 15 | rm -f https://developer.download.nvidia.com/compute/cuda/11.3.0/local_installers/cuda-repo-${OS}-11-3-local_11.3.0-465.19.01-1_amd64.deb 16 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu113-Windows-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.3 4 | PATH=${CUDA_HOME}/bin:$PATH 5 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="6.0+PTX" 9 | 
-------------------------------------------------------------------------------- /.github/workflows/cuda/cu113-Windows.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Install NVIDIA drivers, see: 4 | # https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102 5 | curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip" 6 | 7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32" 7 | 8 | export CUDA_SHORT=11.3 9 | export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.0/local_installers 10 | export CUDA_FILE=cuda_${CUDA_SHORT}.0_465.89_win10.exe 11 | 12 | # Install CUDA: 13 | curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}" 14 | echo "" 15 | echo "Installing from ${CUDA_FILE}..." 16 | PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} thrust_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow" 17 | echo "Done!" 
18 | rm -f "${CUDA_FILE}" 19 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu115-Linux-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/usr/local/cuda-11.5 4 | LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 5 | PATH=${CUDA_HOME}/bin:${PATH} 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" 9 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu115-Linux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OS=ubuntu1804 4 | 5 | wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin 6 | sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600 7 | wget -nv https://developer.download.nvidia.com/compute/cuda/11.5.2/local_installers/cuda-repo-${OS}-11-5-local_11.5.2-495.29.05-1_amd64.deb 8 | sudo dpkg -i cuda-repo-${OS}-11-5-local_11.5.2-495.29.05-1_amd64.deb 9 | sudo apt-key add /var/cuda-repo-${OS}-11-5-local/7fa2af80.pub 10 | 11 | sudo apt-get -qq update 12 | sudo apt install cuda-nvcc-11-5 cuda-libraries-dev-11-5 13 | sudo apt clean 14 | 15 | rm -f https://developer.download.nvidia.com/compute/cuda/11.5.2/local_installers/cuda-repo-${OS}-11-5-local_11.5.2-495.29.05-1_amd64.deb 16 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu115-Windows-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.3 4 | PATH=${CUDA_HOME}/bin:$PATH 5 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="6.0+PTX" 9 | 
-------------------------------------------------------------------------------- /.github/workflows/cuda/cu115-Windows.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # TODO We currently use CUDA 11.3 to build CUDA 11.5 Windows wheels 4 | 5 | # Install NVIDIA drivers, see: 6 | # https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102 7 | curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip" 8 | 7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32" 9 | 10 | export CUDA_SHORT=11.3 11 | export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.0/local_installers 12 | export CUDA_FILE=cuda_${CUDA_SHORT}.0_465.89_win10.exe 13 | 14 | # Install CUDA: 15 | curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}" 16 | echo "" 17 | echo "Installing from ${CUDA_FILE}..." 18 | PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} thrust_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow" 19 | echo "Done!" 
20 | rm -f "${CUDA_FILE}" 21 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu116-Linux-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/usr/local/cuda-11.6 4 | LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 5 | PATH=${CUDA_HOME}/bin:${PATH} 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" 9 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu116-Linux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OS=ubuntu1804 4 | 5 | wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin 6 | sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600 7 | wget -nv https://developer.download.nvidia.com/compute/cuda/11.6.2/local_installers/cuda-repo-${OS}-11-6-local_11.6.2-510.47.03-1_amd64.deb 8 | sudo dpkg -i cuda-repo-${OS}-11-6-local_11.6.2-510.47.03-1_amd64.deb 9 | sudo apt-key add /var/cuda-repo-${OS}-11-6-local/7fa2af80.pub 10 | 11 | sudo apt-get -qq update 12 | sudo apt install cuda-nvcc-11-6 cuda-libraries-dev-11-6 13 | sudo apt clean 14 | 15 | rm -f https://developer.download.nvidia.com/compute/cuda/11.5.2/local_installers/cuda-repo-${OS}-11-6-local_11.6.2-510.47.03-1_amd64.deb 16 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu116-Windows-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.3 4 | PATH=${CUDA_HOME}/bin:$PATH 5 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="6.0+PTX" 9 | 
-------------------------------------------------------------------------------- /.github/workflows/cuda/cu116-Windows.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # TODO We currently use CUDA 11.3 to build CUDA 11.6 Windows wheels 4 | 5 | # Install NVIDIA drivers, see: 6 | # https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102 7 | curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip" 8 | 7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32" 9 | 10 | export CUDA_SHORT=11.3 11 | export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.0/local_installers 12 | export CUDA_FILE=cuda_${CUDA_SHORT}.0_465.89_win10.exe 13 | 14 | # Install CUDA: 15 | curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}" 16 | echo "" 17 | echo "Installing from ${CUDA_FILE}..." 18 | PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} thrust_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow" 19 | echo "Done!" 
20 | rm -f "${CUDA_FILE}" 21 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu117-Linux-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/usr/local/cuda-11.7 4 | LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 5 | PATH=${CUDA_HOME}/bin:${PATH} 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" 9 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu117-Linux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OS=ubuntu2004 4 | 5 | wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin 6 | sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600 7 | wget -nv https://developer.download.nvidia.com/compute/cuda/11.7.1/local_installers/cuda-repo-${OS}-11-7-local_11.7.1-515.65.01-1_amd64.deb 8 | sudo dpkg -i cuda-repo-${OS}-11-7-local_11.7.1-515.65.01-1_amd64.deb 9 | sudo cp /var/cuda-repo-${OS}-11-7-local/cuda-*-keyring.gpg /usr/share/keyrings/ 10 | 11 | sudo apt-get -qq update 12 | sudo apt install cuda-nvcc-11-7 cuda-libraries-dev-11-7 13 | sudo apt clean 14 | 15 | rm -f https://developer.download.nvidia.com/compute/cuda/11.7.1/local_installers/cuda-repo-${OS}-11-7-local_11.7.1-515.65.01-1_amd64.deb 16 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu117-Windows-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.3 4 | PATH=${CUDA_HOME}/bin:$PATH 5 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="6.0+PTX" 9 | 
-------------------------------------------------------------------------------- /.github/workflows/cuda/cu117-Windows.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # TODO We currently use CUDA 11.3 to build CUDA 11.7 Windows wheels 4 | 5 | # Install NVIDIA drivers, see: 6 | # https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102 7 | curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip" 8 | 7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32" 9 | 10 | export CUDA_SHORT=11.3 11 | export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.0/local_installers 12 | export CUDA_FILE=cuda_${CUDA_SHORT}.0_465.89_win10.exe 13 | 14 | # Install CUDA: 15 | curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}" 16 | echo "" 17 | echo "Installing from ${CUDA_FILE}..." 18 | PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} thrust_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow" 19 | echo "Done!" 
20 | rm -f "${CUDA_FILE}" 21 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu118-Linux-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/usr/local/cuda-11.8 4 | LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 5 | PATH=${CUDA_HOME}/bin:${PATH} 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6" 9 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu118-Linux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OS=ubuntu2204 4 | 5 | wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin 6 | sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600 7 | wget -nv https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda-repo-${OS}-11-8-local_11.8.0-520.61.05-1_amd64.deb 8 | sudo dpkg -i cuda-repo-${OS}-11-8-local_11.8.0-520.61.05-1_amd64.deb 9 | sudo cp /var/cuda-repo-${OS}-11-8-local/cuda-*-keyring.gpg /usr/share/keyrings/ 10 | 11 | sudo apt-get -qq update 12 | sudo apt install cuda-nvcc-11-8 cuda-libraries-dev-11-8 13 | sudo apt clean 14 | 15 | rm -f https://developer.download.nvidia.com/compute/cuda/11.8.0/local_installers/cuda-repo-${OS}-11-8-local_11.8.0-520.61.05-1_amd64.deb 16 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu118-Windows-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v11.8 4 | PATH=${CUDA_HOME}/bin:$PATH 5 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="6.0+PTX" 9 | 
-------------------------------------------------------------------------------- /.github/workflows/cuda/cu118-Windows.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Install NVIDIA drivers, see: 4 | # https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102 5 | curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip" 6 | 7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32" 7 | 8 | export CUDA_SHORT=11.8 9 | export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.0/local_installers 10 | export CUDA_FILE=cuda_${CUDA_SHORT}.0_522.06_windows.exe 11 | 12 | # Install CUDA: 13 | curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}" 14 | echo "" 15 | echo "Installing from ${CUDA_FILE}..." 16 | PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} thrust_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow" 17 | echo "Done!" 
18 | rm -f "${CUDA_FILE}" 19 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu121-Linux-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/usr/local/cuda-12.1 4 | LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 5 | PATH=${CUDA_HOME}/bin:${PATH} 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="5.0+PTX;6.0;7.0;7.5;8.0;8.6;9.0" 9 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu121-Linux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OS=ubuntu2004 4 | 5 | wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin 6 | sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600 7 | wget -nv https://developer.download.nvidia.com/compute/cuda/12.1.1/local_installers/cuda-repo-${OS}-12-1-local_12.1.1-530.30.02-1_amd64.deb 8 | sudo dpkg -i cuda-repo-${OS}-12-1-local_12.1.1-530.30.02-1_amd64.deb 9 | sudo cp /var/cuda-repo-${OS}-12-1-local/cuda-*-keyring.gpg /usr/share/keyrings/ 10 | 11 | sudo apt-get -qq update 12 | sudo apt install cuda-nvcc-12-1 cuda-libraries-dev-12-1 13 | sudo apt clean 14 | 15 | rm -f cuda-repo-${OS}-12-1-local_12.1.1-530.30.02-1_amd64.deb 16 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu121-Windows-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v12.1 4 | PATH=${CUDA_HOME}/bin:$PATH 5 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="6.0+PTX" 9 | 
-------------------------------------------------------------------------------- /.github/workflows/cuda/cu121-Windows.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Install NVIDIA drivers, see: 4 | # https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102 5 | curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip" 6 | 7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32" 7 | 8 | export CUDA_SHORT=12.1 9 | export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.1/local_installers 10 | export CUDA_FILE=cuda_${CUDA_SHORT}.1_531.14_windows.exe 11 | 12 | # Install CUDA: 13 | curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}" 14 | echo "" 15 | echo "Installing from ${CUDA_FILE}..." 16 | PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} thrust_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow" 17 | echo "Done!" 
18 | rm -f "${CUDA_FILE}" 19 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu124-Linux-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/usr/local/cuda-12.4 4 | LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 5 | PATH=${CUDA_HOME}/bin:${PATH} 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="5.0+PTX;6.0;7.0;7.5;8.0;8.6;9.0" 9 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu124-Linux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OS=ubuntu2204 4 | 5 | wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin 6 | sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600 7 | wget -nv https://developer.download.nvidia.com/compute/cuda/12.4.1/local_installers/cuda-repo-${OS}-12-4-local_12.4.1-550.54.15-1_amd64.deb 8 | sudo dpkg -i cuda-repo-${OS}-12-4-local_12.4.1-550.54.15-1_amd64.deb 9 | sudo cp /var/cuda-repo-${OS}-12-4-local/cuda-*-keyring.gpg /usr/share/keyrings/ 10 | 11 | sudo apt-get -qq update 12 | sudo apt install cuda-nvcc-12-4 cuda-libraries-dev-12-4 13 | sudo apt clean 14 | 15 | rm -f cuda-repo-${OS}-12-4-local_12.4.1-550.54.15-1_amd64.deb 16 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu124-Windows-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v12.4 4 | PATH=${CUDA_HOME}/bin:$PATH 5 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="6.0+PTX" 9 | 
-------------------------------------------------------------------------------- /.github/workflows/cuda/cu124-Windows.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Install NVIDIA drivers, see: 4 | # https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102 5 | curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip" 6 | 7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32" 7 | 8 | export CUDA_SHORT=12.4 9 | export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.1/local_installers 10 | export CUDA_FILE=cuda_${CUDA_SHORT}.1_551.78_windows.exe 11 | 12 | # Install CUDA: 13 | curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}" 14 | echo "" 15 | echo "Installing from ${CUDA_FILE}..." 16 | PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} thrust_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow" 17 | echo "Done!" 
18 | rm -f "${CUDA_FILE}" 19 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu126-Linux-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/usr/local/cuda-12.6 4 | LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 5 | PATH=${CUDA_HOME}/bin:${PATH} 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="5.0+PTX;6.0;7.0;7.5;8.0;8.6;9.0" 9 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu126-Linux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OS=ubuntu2204 4 | 5 | wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin 6 | sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600 7 | wget -nv https://developer.download.nvidia.com/compute/cuda/12.6.0/local_installers/cuda-repo-${OS}-12-6-local_12.6.0-560.28.03-1_amd64.deb 8 | sudo dpkg -i cuda-repo-${OS}-12-6-local_12.6.0-560.28.03-1_amd64.deb 9 | sudo cp /var/cuda-repo-${OS}-12-6-local/cuda-*-keyring.gpg /usr/share/keyrings/ 10 | 11 | sudo apt-get -qq update 12 | sudo apt install cuda-nvcc-12-6 cuda-libraries-dev-12-6 13 | sudo apt clean 14 | 15 | rm -f cuda-repo-${OS}-12-6-local_12.6.0-560.28.03-1_amd64.deb 16 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu126-Windows-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v12.6 4 | PATH=${CUDA_HOME}/bin:$PATH 5 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="6.0+PTX" 9 | 
-------------------------------------------------------------------------------- /.github/workflows/cuda/cu126-Windows.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Install NVIDIA drivers, see: 4 | # https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102 5 | curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip" 6 | 7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32" 7 | 8 | export CUDA_SHORT=12.6 9 | export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.0/local_installers 10 | export CUDA_FILE=cuda_${CUDA_SHORT}.0_560.76_windows.exe 11 | 12 | # Install CUDA: 13 | curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}" 14 | echo "" 15 | echo "Installing from ${CUDA_FILE}..." 16 | PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} thrust_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow" 17 | echo "Done!" 
18 | rm -f "${CUDA_FILE}" 19 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu128-Linux-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/usr/local/cuda-12.8 4 | LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} 5 | PATH=${CUDA_HOME}/bin:${PATH} 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="5.0+PTX;6.0;7.0;7.5;8.0;8.6;9.0" 9 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu128-Linux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | OS=ubuntu2204 4 | 5 | wget -nv https://developer.download.nvidia.com/compute/cuda/repos/${OS}/x86_64/cuda-${OS}.pin 6 | sudo mv cuda-${OS}.pin /etc/apt/preferences.d/cuda-repository-pin-600 7 | wget -nv https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda-repo-${OS}-12-8-local_12.8.0-570.86.10-1_amd64.deb 8 | 9 | sudo dpkg -i cuda-repo-${OS}-12-8-local_12.8.0-570.86.10-1_amd64.deb 10 | sudo cp /var/cuda-repo-${OS}-12-8-local/cuda-*-keyring.gpg /usr/share/keyrings/ 11 | 12 | sudo apt-get -qq update 13 | sudo apt install cuda-nvcc-12-8 cuda-libraries-dev-12-8 14 | sudo apt clean 15 | 16 | rm -f cuda-repo-${OS}-12-8-local_12.8.0-570.86.10-1_amd64.deb 17 | -------------------------------------------------------------------------------- /.github/workflows/cuda/cu128-Windows-env.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_HOME=/c/Program\ Files/NVIDIA\ GPU\ Computing\ Toolkit/CUDA/v12.8 4 | PATH=${CUDA_HOME}/bin:$PATH 5 | PATH=/c/Program\ Files\ \(x86\)/Microsoft\ Visual\ Studio/2017/BuildTools/MSBuild/15.0/Bin:$PATH 6 | 7 | export FORCE_CUDA=1 8 | export TORCH_CUDA_ARCH_LIST="6.0+PTX" 9 | 
-------------------------------------------------------------------------------- /.github/workflows/cuda/cu128-Windows.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Install NVIDIA drivers, see: 4 | # https://github.com/pytorch/vision/blob/master/packaging/windows/internal/cuda_install.bat#L99-L102 5 | curl -k -L "https://drive.google.com/u/0/uc?id=1injUyo3lnarMgWyRcXqKg4UGnN0ysmuq&export=download" --output "/tmp/gpu_driver_dlls.zip" 6 | 7z x "/tmp/gpu_driver_dlls.zip" -o"/c/Windows/System32" 7 | 8 | export CUDA_SHORT=12.8 9 | export CUDA_URL=https://developer.download.nvidia.com/compute/cuda/${CUDA_SHORT}.0/local_installers 10 | export CUDA_FILE=cuda_${CUDA_SHORT}.0_571.96_windows.exe 11 | 12 | # Install CUDA: 13 | curl -k -L "${CUDA_URL}/${CUDA_FILE}" --output "${CUDA_FILE}" 14 | echo "" 15 | echo "Installing from ${CUDA_FILE}..." 16 | PowerShell -Command "Start-Process -FilePath \"${CUDA_FILE}\" -ArgumentList \"-s nvcc_${CUDA_SHORT} cuobjdump_${CUDA_SHORT} nvprune_${CUDA_SHORT} cupti_${CUDA_SHORT} cublas_dev_${CUDA_SHORT} cudart_${CUDA_SHORT} cufft_dev_${CUDA_SHORT} curand_dev_${CUDA_SHORT} cusolver_dev_${CUDA_SHORT} cusparse_dev_${CUDA_SHORT} thrust_${CUDA_SHORT} npp_dev_${CUDA_SHORT} nvrtc_dev_${CUDA_SHORT} nvml_dev_${CUDA_SHORT}\" -Wait -NoNewWindow" 17 | echo "Done!" 
18 | rm -f "${CUDA_FILE}" 19 | -------------------------------------------------------------------------------- /.github/workflows/linting.yml: -------------------------------------------------------------------------------- 1 | name: Linting 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | 9 | jobs: 10 | 11 | flake8: 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v2 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: 3.8 21 | 22 | - name: Install dependencies 23 | run: | 24 | pip install flake8 25 | 26 | - name: Run linting 27 | run: | 28 | flake8 . 29 | -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | name: "Close stale issues and PRs" 2 | 3 | on: 4 | schedule: 5 | # Every day at 00:00 6 | - cron: "0 0 * * *" 7 | workflow_dispatch: 8 | 9 | jobs: 10 | stale: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/stale@v4.0.0 15 | with: 16 | stale-issue-message: 'This issue had no activity for **6 months**. It will be closed in **2 weeks** unless there is some new activity. Is this issue already resolved?' 17 | stale-issue-label: 'stale' 18 | exempt-issue-labels: 'bug,enhancement,good first issue' 19 | stale-pr-message: 'This pull request had no activity for **6 months**. It will be closed in **2 weeks** unless there is some new activity.' 
20 | stale-pr-label: 'stale' 21 | days-before-stale: 180 22 | days-before-close: 14 23 | operations-per-run: 200 24 | -------------------------------------------------------------------------------- /.github/workflows/testing.yml: -------------------------------------------------------------------------------- 1 | name: Testing 2 | 3 | on: 4 | push: 5 | branches: 6 | - master 7 | pull_request: 8 | 9 | jobs: 10 | 11 | pytest: 12 | runs-on: ${{ matrix.os }} 13 | 14 | strategy: 15 | fail-fast: false 16 | matrix: 17 | os: [ubuntu-latest, windows-latest] 18 | python-version: [3.9] 19 | torch-version: [2.6.0, 2.7.0] 20 | 21 | steps: 22 | - uses: actions/checkout@v2 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v2 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | 28 | - name: Install PyTorch ${{ matrix.torch-version }} 29 | run: | 30 | pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu 31 | 32 | - name: Patch PyTorch static constexpr on Windows 33 | if: ${{ runner.os == 'Windows' }} 34 | run: | 35 | Torch_DIR=`python -c 'import os; import torch; print(os.path.dirname(torch.__file__))'` 36 | sed -i '31,38c\ 37 | TORCH_API void lazy_init_num_threads();' ${Torch_DIR}/include/ATen/Parallel.h 38 | shell: bash 39 | 40 | - name: Install main package 41 | run: | 42 | python setup.py develop 43 | 44 | - name: Run test-suite 45 | run: | 46 | pip install pytest pytest-cov 47 | pytest --cov --cov-report=xml 48 | 49 | - name: Upload coverage 50 | uses: codecov/codecov-action@v4 51 | if: success() 52 | with: 53 | fail_ci_if_error: false 54 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | build/ 3 | dist/ 4 | .cache/ 5 | .eggs/ 6 | *.egg-info/ 7 | .coverage 8 | *.so 9 | *.aux 10 | *.log 11 | *.pdf 12 | *.hip 13 | *_hip.cpp 14 | hip 
15 | -------------------------------------------------------------------------------- /CMakeLists.txt: -------------------------------------------------------------------------------- 1 | cmake_minimum_required(VERSION 3.0) 2 | project(torchscatter) 3 | set(CMAKE_CXX_STANDARD 17) 4 | set(CMAKE_CXX_STANDARD_REQUIRED) 5 | set(TORCHSCATTER_VERSION 2.1.2) 6 | 7 | option(WITH_CUDA "Enable CUDA support" OFF) 8 | option(WITH_PYTHON "Link to Python when building" ON) 9 | 10 | if(WITH_CUDA) 11 | enable_language(CUDA) 12 | add_definitions(-D__CUDA_NO_HALF_OPERATORS__) 13 | add_definitions(-DWITH_CUDA) 14 | set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") 15 | endif() 16 | 17 | if (WITH_PYTHON) 18 | add_definitions(-DWITH_PYTHON) 19 | find_package(Python3 COMPONENTS Development) 20 | endif() 21 | find_package(Torch REQUIRED) 22 | 23 | file(GLOB HEADERS csrc/*.h) 24 | file(GLOB OPERATOR_SOURCES csrc/cpu/*.h csrc/cpu/*.cpp csrc/*.cpp) 25 | if(WITH_CUDA) 26 | file(GLOB OPERATOR_SOURCES ${OPERATOR_SOURCES} csrc/cuda/*.h csrc/cuda/*.cu) 27 | endif() 28 | 29 | add_library(${PROJECT_NAME} SHARED ${OPERATOR_SOURCES}) 30 | target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES}) 31 | if (WITH_PYTHON) 32 | target_link_libraries(${PROJECT_NAME} PRIVATE Python3::Python) 33 | endif() 34 | set_target_properties(${PROJECT_NAME} PROPERTIES EXPORT_NAME TorchScatter) 35 | 36 | target_include_directories(${PROJECT_NAME} INTERFACE 37 | "$" 38 | $) 39 | 40 | include(GNUInstallDirs) 41 | include(CMakePackageConfigHelpers) 42 | 43 | set(TORCHSCATTER_CMAKECONFIG_INSTALL_DIR "share/cmake/TorchScatter" CACHE STRING "install path for TorchScatterConfig.cmake") 44 | 45 | configure_package_config_file(cmake/TorchScatterConfig.cmake.in 46 | "${CMAKE_CURRENT_BINARY_DIR}/TorchScatterConfig.cmake" 47 | INSTALL_DESTINATION ${TORCHSCATTER_CMAKECONFIG_INSTALL_DIR}) 48 | 49 | write_basic_package_version_file(${CMAKE_CURRENT_BINARY_DIR}/TorchScatterConfigVersion.cmake 50 | VERSION 
${TORCHSCATTER_VERSION} 51 | COMPATIBILITY AnyNewerVersion) 52 | 53 | install(FILES ${CMAKE_CURRENT_BINARY_DIR}/TorchScatterConfig.cmake 54 | ${CMAKE_CURRENT_BINARY_DIR}/TorchScatterConfigVersion.cmake 55 | DESTINATION ${TORCHSCATTER_CMAKECONFIG_INSTALL_DIR}) 56 | 57 | install(TARGETS ${PROJECT_NAME} 58 | EXPORT TorchScatterTargets 59 | LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} 60 | ) 61 | 62 | install(EXPORT TorchScatterTargets 63 | NAMESPACE TorchScatter:: 64 | DESTINATION ${TORCHSCATTER_CMAKECONFIG_INSTALL_DIR}) 65 | 66 | install(FILES ${HEADERS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}) 67 | install(FILES 68 | csrc/cpu/scatter_cpu.h 69 | csrc/cpu/segment_coo_cpu.h 70 | csrc/cpu/segment_csr_cpu.h 71 | DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/cpu) 72 | if(WITH_CUDA) 73 | install(FILES 74 | csrc/cuda/scatter_cuda.h 75 | csrc/cuda/segment_coo_cuda.h 76 | csrc/cuda/segment_csr_cuda.h 77 | DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/cuda) 78 | endif() 79 | 80 | if(WITH_CUDA) 81 | set_property(TARGET torch_cuda PROPERTY INTERFACE_COMPILE_OPTIONS "") 82 | set_property(TARGET torch_cpu PROPERTY INTERFACE_COMPILE_OPTIONS "") 83 | endif() 84 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Matthias Fey 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy 4 | of this software and associated documentation files (the "Software"), to deal 5 | in the Software without restriction, including without limitation the rights 6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 | copies of the Software, and to permit persons to whom the Software is 8 | furnished to do so, subject to the following conditions: 9 | 10 | The above copyright notice and this permission notice shall be included in 11 | all copies or substantial portions of the 
Software. 12 | 13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 | THE SOFTWARE. 20 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | 4 | recursive-exclude test * 5 | recursive-include csrc * 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [pypi-image]: https://badge.fury.io/py/torch-scatter.svg 2 | [pypi-url]: https://pypi.python.org/pypi/torch-scatter 3 | [testing-image]: https://github.com/rusty1s/pytorch_scatter/actions/workflows/testing.yml/badge.svg 4 | [testing-url]: https://github.com/rusty1s/pytorch_scatter/actions/workflows/testing.yml 5 | [linting-image]: https://github.com/rusty1s/pytorch_scatter/actions/workflows/linting.yml/badge.svg 6 | [linting-url]: https://github.com/rusty1s/pytorch_scatter/actions/workflows/linting.yml 7 | [docs-image]: https://readthedocs.org/projects/pytorch-scatter/badge/?version=latest 8 | [docs-url]: https://pytorch-scatter.readthedocs.io/en/latest/?badge=latest 9 | [coverage-image]: https://codecov.io/gh/rusty1s/pytorch_scatter/branch/master/graph/badge.svg 10 | [coverage-url]: https://codecov.io/github/rusty1s/pytorch_scatter?branch=master 11 | 12 | # PyTorch Scatter 13 | 14 | [![PyPI Version][pypi-image]][pypi-url] 15 | [![Testing Status][testing-image]][testing-url] 16 | [![Linting 
Status][linting-image]][linting-url] 17 | [![Docs Status][docs-image]][docs-url] 18 | [![Code Coverage][coverage-image]][coverage-url] 19 | 20 |

21 | 22 |

23 | 24 | -------------------------------------------------------------------------------- 25 | 26 | **[Documentation](https://pytorch-scatter.readthedocs.io)** 27 | 28 | This package consists of a small extension library of highly optimized sparse update (scatter and segment) operations for the use in [PyTorch](http://pytorch.org/), which are missing in the main package. 29 | Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor. 30 | Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to these requirements. 31 | 32 | The package consists of the following operations with reduction types `"sum"|"mean"|"min"|"max"`: 33 | 34 | * [**scatter**](https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html) based on arbitrary indices 35 | * [**segment_coo**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_coo.html) based on sorted indices 36 | * [**segment_csr**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_csr.html) based on compressed indices via pointers 37 | 38 | In addition, we provide the following **composite functions** which make use of `scatter_*` operations under the hood: `scatter_std`, `scatter_logsumexp`, `scatter_softmax` and `scatter_log_softmax`. 39 | 40 | All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable. 41 | 42 | ## Installation 43 | 44 | ### Binaries 45 | 46 | We provide pip wheels for all major OS/PyTorch/CUDA combinations, see [here](https://data.pyg.org/whl). 
47 | 48 | #### PyTorch 2.7 49 | 50 | To install the binaries for PyTorch 2.7.0, simply run 51 | 52 | ``` 53 | pip install torch-scatter -f https://data.pyg.org/whl/torch-2.7.0+${CUDA}.html 54 | ``` 55 | 56 | where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu126`, or `cu128` depending on your PyTorch installation. 57 | 58 | | | `cpu` | `cu118` | `cu126` | `cu128` | 59 | |-------------|-------|---------|---------|---------| 60 | | **Linux** | ✅ | ✅ | ✅ | ✅ | 61 | | **Windows** | ✅ | ✅ | ✅ | ✅ | 62 | | **macOS** | ✅ | | | | 63 | 64 | #### PyTorch 2.6 65 | 66 | To install the binaries for PyTorch 2.6.0, simply run 67 | 68 | ``` 69 | pip install torch-scatter -f https://data.pyg.org/whl/torch-2.6.0+${CUDA}.html 70 | ``` 71 | 72 | where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu124`, or `cu126` depending on your PyTorch installation. 73 | 74 | | | `cpu` | `cu118` | `cu124` | `cu126` | 75 | |-------------|-------|---------|---------|---------| 76 | | **Linux** | ✅ | ✅ | ✅ | ✅ | 77 | | **Windows** | ✅ | ✅ | ✅ | ✅ | 78 | | **macOS** | ✅ | | | | 79 | 80 | **Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, PyTorch 2.2.0/2.2.1/2.2.2, PyTorch 2.3.0/2.3.1, PyTorch 2.4.0/2.4.1, and PyTorch 2.5.0/2.5.1 (following the same procedure). 81 | For older versions, you need to explicitly specify the latest supported version number or install via `pip install --no-index` in order to prevent a manual installation from source. 82 | You can look up the latest supported version number [here](https://data.pyg.org/whl). 
83 | 84 | ### From source 85 | 86 | Ensure that at least PyTorch 1.4.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*: 87 | 88 | ``` 89 | $ python -c "import torch; print(torch.__version__)" 90 | >>> 1.4.0 91 | 92 | $ echo $PATH 93 | >>> /usr/local/cuda/bin:... 94 | 95 | $ echo $CPATH 96 | >>> /usr/local/cuda/include:... 97 | ``` 98 | 99 | Then run: 100 | 101 | ``` 102 | pip install torch-scatter 103 | ``` 104 | 105 | When running in a docker container without NVIDIA driver, PyTorch needs to evaluate the compute capabilities and may fail. 106 | In this case, ensure that the compute capabilities are set via `TORCH_CUDA_ARCH_LIST`, *e.g.*: 107 | 108 | ``` 109 | export TORCH_CUDA_ARCH_LIST = "6.0 6.1 7.2+PTX 7.5+PTX" 110 | ``` 111 | 112 | ## Example 113 | 114 | ```py 115 | import torch 116 | from torch_scatter import scatter_max 117 | 118 | src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]) 119 | index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]) 120 | 121 | out, argmax = scatter_max(src, index, dim=-1) 122 | ``` 123 | 124 | ``` 125 | print(out) 126 | tensor([[0, 0, 4, 3, 2, 0], 127 | [2, 4, 3, 0, 0, 0]]) 128 | 129 | print(argmax) 130 | tensor([[5, 5, 3, 4, 0, 1] 131 | [1, 4, 3, 5, 5, 5]]) 132 | ``` 133 | 134 | ## Running tests 135 | 136 | ``` 137 | pytest 138 | ``` 139 | 140 | ## C++ API 141 | 142 | `torch-scatter` also offers a C++ API that contains C++ equivalent of python models. 143 | For this, we need to add `TorchLib` to the `-DCMAKE_PREFIX_PATH` (run `import torch; print(torch.utils.cmake_prefix_path)` to obtain it). 144 | 145 | ``` 146 | mkdir build 147 | cd build 148 | # Add -DWITH_CUDA=on support for CUDA support 149 | cmake -DCMAKE_PREFIX_PATH="..." .. 
150 | make 151 | make install 152 | ``` 153 | -------------------------------------------------------------------------------- /benchmark/.gitignore: -------------------------------------------------------------------------------- 1 | *.mat 2 | *.tmp 3 | -------------------------------------------------------------------------------- /benchmark/gather.py: -------------------------------------------------------------------------------- 1 | import time 2 | import itertools 3 | 4 | import argparse 5 | import torch 6 | from scipy.io import loadmat 7 | 8 | from torch_scatter import gather_coo, gather_csr 9 | 10 | from scatter_segment import short_rows, long_rows, download, bold 11 | 12 | 13 | @torch.no_grad() 14 | def correctness(dataset): 15 | group, name = dataset 16 | mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr() 17 | rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long) 18 | row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long) 19 | dim_size = rowptr.size(0) - 1 20 | 21 | for size in sizes[1:]: 22 | try: 23 | x = torch.randn((dim_size, size), device=args.device) 24 | x = x.squeeze(-1) if size == 1 else x 25 | 26 | out1 = x.index_select(0, row) 27 | out2 = gather_coo(x, row) 28 | out3 = gather_csr(x, rowptr) 29 | 30 | assert torch.allclose(out1, out2, atol=1e-4) 31 | assert torch.allclose(out1, out3, atol=1e-4) 32 | except RuntimeError as e: 33 | if 'out of memory' not in str(e): 34 | raise RuntimeError(e) 35 | torch.cuda.empty_cache() 36 | 37 | 38 | def time_func(func, x): 39 | try: 40 | if torch.cuda.is_available(): 41 | torch.cuda.synchronize() 42 | t = time.perf_counter() 43 | 44 | if not args.with_backward: 45 | with torch.no_grad(): 46 | for _ in range(iters): 47 | func(x) 48 | else: 49 | x = x.requires_grad_() 50 | for _ in range(iters): 51 | out = func(x) 52 | torch.autograd.grad(out, x, out, only_inputs=True) 53 | 54 | if torch.cuda.is_available(): 55 | torch.cuda.synchronize() 56 | return time.perf_counter() - t 57 | 
except RuntimeError as e: 58 | if 'out of memory' not in str(e): 59 | raise RuntimeError(e) 60 | torch.cuda.empty_cache() 61 | return float('inf') 62 | 63 | 64 | def timing(dataset): 65 | group, name = dataset 66 | mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr() 67 | rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long) 68 | row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long) 69 | dim_size = rowptr.size(0) - 1 70 | avg_row_len = row.size(0) / dim_size 71 | 72 | def select(x): 73 | return x.index_select(0, row) 74 | 75 | def gather(x): 76 | return x.gather(0, row.view(-1, 1).expand(-1, x.size(1))) 77 | 78 | def gat_coo(x): 79 | return gather_coo(x, row) 80 | 81 | def gat_csr(x): 82 | return gather_csr(x, rowptr) 83 | 84 | t1, t2, t3, t4 = [], [], [], [] 85 | for size in sizes: 86 | try: 87 | x = torch.randn((dim_size, size), device=args.device) 88 | 89 | t1 += [time_func(select, x)] 90 | t2 += [time_func(gather, x)] 91 | t3 += [time_func(gat_coo, x)] 92 | t4 += [time_func(gat_csr, x)] 93 | 94 | del x 95 | 96 | except RuntimeError as e: 97 | if 'out of memory' not in str(e): 98 | raise RuntimeError(e) 99 | torch.cuda.empty_cache() 100 | for t in (t1, t2, t3, t4): 101 | t.append(float('inf')) 102 | 103 | ts = torch.tensor([t1, t2, t3, t4]) 104 | winner = torch.zeros_like(ts, dtype=torch.bool) 105 | winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1 106 | winner = winner.tolist() 107 | 108 | name = f'{group}/{name}' 109 | print(f'{bold(name)} (avg row length: {avg_row_len:.2f}):') 110 | print('\t'.join([' '] + [f'{size:>5}' for size in sizes])) 111 | print('\t'.join([bold('SELECT ')] + 112 | [bold(f'{t:.5f}', f) for t, f in zip(t1, winner[0])])) 113 | print('\t'.join([bold('GAT ')] + 114 | [bold(f'{t:.5f}', f) for t, f in zip(t2, winner[1])])) 115 | print('\t'.join([bold('GAT_COO')] + 116 | [bold(f'{t:.5f}', f) for t, f in zip(t3, winner[2])])) 117 | print('\t'.join([bold('GAT_CSR')] + 118 | [bold(f'{t:.5f}', f) for t, f in 
zip(t4, winner[3])])) 119 | print() 120 | 121 | 122 | if __name__ == '__main__': 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument('--with_backward', action='store_true') 125 | parser.add_argument('--device', type=str, default='cuda') 126 | args = parser.parse_args() 127 | iters = 1 if args.device == 'cpu' else 20 128 | sizes = [1, 16, 32, 64, 128, 256, 512] 129 | sizes = sizes[:3] if args.device == 'cpu' else sizes 130 | 131 | for _ in range(10): # Warmup. 132 | torch.randn(100, 100, device=args.device).sum() 133 | for dataset in itertools.chain(short_rows, long_rows): 134 | download(dataset) 135 | correctness(dataset) 136 | timing(dataset) 137 | -------------------------------------------------------------------------------- /benchmark/scatter_segment.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os.path as osp 3 | import itertools 4 | 5 | import argparse 6 | import wget 7 | import torch 8 | from scipy.io import loadmat 9 | 10 | from torch_scatter import scatter, segment_coo, segment_csr 11 | 12 | short_rows = [ 13 | ('DIMACS10', 'citationCiteseer'), 14 | ('SNAP', 'web-Stanford'), 15 | ] 16 | long_rows = [ 17 | ('Janna', 'StocF-1465'), 18 | ('GHS_psdef', 'ldoor'), 19 | ] 20 | 21 | 22 | def download(dataset): 23 | url = 'https://sparse.tamu.edu/mat/{}/{}.mat' 24 | for group, name in itertools.chain(long_rows, short_rows): 25 | if not osp.exists(f'{name}.mat'): 26 | print(f'Downloading {group}/{name}:') 27 | wget.download(url.format(group, name)) 28 | print('') 29 | 30 | 31 | def bold(text, flag=True): 32 | return f'\033[1m{text}\033[0m' if flag else text 33 | 34 | 35 | @torch.no_grad() 36 | def correctness(dataset): 37 | group, name = dataset 38 | mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr() 39 | rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long) 40 | row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long) 41 | dim_size = rowptr.size(0) - 1 42 
| 43 | for size in sizes: 44 | try: 45 | x = torch.randn((row.size(0), size), device=args.device) 46 | x = x.squeeze(-1) if size == 1 else x 47 | 48 | out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='add') 49 | out2 = segment_coo(x, row, dim_size=dim_size, reduce='add') 50 | out3 = segment_csr(x, rowptr, reduce='add') 51 | 52 | assert torch.allclose(out1, out2, atol=1e-4) 53 | assert torch.allclose(out1, out3, atol=1e-4) 54 | 55 | out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='mean') 56 | out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean') 57 | out3 = segment_csr(x, rowptr, reduce='mean') 58 | 59 | assert torch.allclose(out1, out2, atol=1e-4) 60 | assert torch.allclose(out1, out3, atol=1e-4) 61 | 62 | out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='min') 63 | out2 = segment_coo(x, row, reduce='min') 64 | out3 = segment_csr(x, rowptr, reduce='min') 65 | 66 | assert torch.allclose(out1, out2, atol=1e-4) 67 | assert torch.allclose(out1, out3, atol=1e-4) 68 | 69 | out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='max') 70 | out2 = segment_coo(x, row, reduce='max') 71 | out3 = segment_csr(x, rowptr, reduce='max') 72 | 73 | assert torch.allclose(out1, out2, atol=1e-4) 74 | assert torch.allclose(out1, out3, atol=1e-4) 75 | 76 | except RuntimeError as e: 77 | if 'out of memory' not in str(e): 78 | raise RuntimeError(e) 79 | torch.cuda.empty_cache() 80 | 81 | 82 | def time_func(func, x): 83 | try: 84 | if torch.cuda.is_available(): 85 | torch.cuda.synchronize() 86 | t = time.perf_counter() 87 | 88 | if not args.with_backward: 89 | with torch.no_grad(): 90 | for _ in range(iters): 91 | func(x) 92 | else: 93 | x = x.requires_grad_() 94 | for _ in range(iters): 95 | out = func(x) 96 | out = out[0] if isinstance(out, tuple) else out 97 | torch.autograd.grad(out, x, out, only_inputs=True) 98 | 99 | if torch.cuda.is_available(): 100 | torch.cuda.synchronize() 101 | return time.perf_counter() - t 102 | except RuntimeError as e: 103 | if 
'out of memory' not in str(e): 104 | raise RuntimeError(e) 105 | torch.cuda.empty_cache() 106 | return float('inf') 107 | 108 | 109 | def timing(dataset): 110 | group, name = dataset 111 | mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr() 112 | rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long) 113 | row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long) 114 | row2 = row[torch.randperm(row.size(0))] 115 | dim_size = rowptr.size(0) - 1 116 | avg_row_len = row.size(0) / dim_size 117 | 118 | def sca1_row(x): 119 | out = x.new_zeros(dim_size, *x.size()[1:]) 120 | row_tmp = row.view(-1, 1).expand_as(x) if x.dim() > 1 else row 121 | return out.scatter_add_(0, row_tmp, x) 122 | 123 | def sca1_col(x): 124 | out = x.new_zeros(dim_size, *x.size()[1:]) 125 | row2_tmp = row2.view(-1, 1).expand_as(x) if x.dim() > 1 else row2 126 | return out.scatter_add_(0, row2_tmp, x) 127 | 128 | def sca2_row(x): 129 | return scatter(x, row, dim=0, dim_size=dim_size, reduce=args.reduce) 130 | 131 | def sca2_col(x): 132 | return scatter(x, row2, dim=0, dim_size=dim_size, reduce=args.reduce) 133 | 134 | def seg_coo(x): 135 | return segment_coo(x, row, reduce=args.reduce) 136 | 137 | def seg_csr(x): 138 | return segment_csr(x, rowptr, reduce=args.reduce) 139 | 140 | def dense1(x): 141 | return getattr(torch, args.reduce)(x, dim=-2) 142 | 143 | def dense2(x): 144 | return getattr(torch, args.reduce)(x, dim=-1) 145 | 146 | t1, t2, t3, t4, t5, t6, t7, t8 = [], [], [], [], [], [], [], [] 147 | 148 | for size in sizes: 149 | try: 150 | x = torch.randn((row.size(0), size), device=args.device) 151 | x = x.squeeze(-1) if size == 1 else x 152 | 153 | t1 += [time_func(sca1_row, x)] 154 | t2 += [time_func(sca1_col, x)] 155 | t3 += [time_func(sca2_row, x)] 156 | t4 += [time_func(sca2_col, x)] 157 | t5 += [time_func(seg_coo, x)] 158 | t6 += [time_func(seg_csr, x)] 159 | 160 | del x 161 | 162 | except RuntimeError as e: 163 | if 'out of memory' not in str(e): 164 | raise 
RuntimeError(e) 165 | torch.cuda.empty_cache() 166 | for t in (t1, t2, t3, t4, t5, t6): 167 | t.append(float('inf')) 168 | 169 | try: 170 | x = torch.randn((dim_size, int(avg_row_len + 1), size), 171 | device=args.device) 172 | 173 | t7 += [time_func(dense1, x)] 174 | x = x.view(dim_size, size, int(avg_row_len + 1)) 175 | t8 += [time_func(dense2, x)] 176 | 177 | del x 178 | 179 | except RuntimeError as e: 180 | if 'out of memory' not in str(e): 181 | raise RuntimeError(e) 182 | torch.cuda.empty_cache() 183 | for t in (t7, t8): 184 | t.append(float('inf')) 185 | 186 | ts = torch.tensor([t1, t2, t3, t4, t5, t6, t7, t8]) 187 | winner = torch.zeros_like(ts, dtype=torch.bool) 188 | winner[ts.argmin(dim=0), torch.arange(len(sizes))] = 1 189 | winner = winner.tolist() 190 | 191 | name = f'{group}/{name}' 192 | print(f'{bold(name)} (avg row length: {avg_row_len:.2f}):') 193 | print('\t'.join([' '] + [f'{size:>5}' for size in sizes])) 194 | print('\t'.join([bold('SCA1_ROW')] + 195 | [bold(f'{t:.5f}', f) for t, f in zip(t1, winner[0])])) 196 | print('\t'.join([bold('SCA1_COL')] + 197 | [bold(f'{t:.5f}', f) for t, f in zip(t2, winner[1])])) 198 | print('\t'.join([bold('SCA2_ROW')] + 199 | [bold(f'{t:.5f}', f) for t, f in zip(t3, winner[2])])) 200 | print('\t'.join([bold('SCA2_COL')] + 201 | [bold(f'{t:.5f}', f) for t, f in zip(t4, winner[3])])) 202 | print('\t'.join([bold('SEG_COO ')] + 203 | [bold(f'{t:.5f}', f) for t, f in zip(t5, winner[4])])) 204 | print('\t'.join([bold('SEG_CSR ')] + 205 | [bold(f'{t:.5f}', f) for t, f in zip(t6, winner[5])])) 206 | print('\t'.join([bold('DENSE1 ')] + 207 | [bold(f'{t:.5f}', f) for t, f in zip(t7, winner[6])])) 208 | print('\t'.join([bold('DENSE2 ')] + 209 | [bold(f'{t:.5f}', f) for t, f in zip(t8, winner[7])])) 210 | print() 211 | 212 | 213 | if __name__ == '__main__': 214 | parser = argparse.ArgumentParser() 215 | parser.add_argument('--reduce', type=str, required=True, 216 | choices=['sum', 'mean', 'min', 'max']) 217 | 
parser.add_argument('--with_backward', action='store_true') 218 | parser.add_argument('--device', type=str, default='cuda') 219 | args = parser.parse_args() 220 | iters = 1 if args.device == 'cpu' else 20 221 | sizes = [1, 16, 32, 64, 128, 256, 512] 222 | sizes = sizes[:3] if args.device == 'cpu' else sizes 223 | 224 | for _ in range(10): # Warmup. 225 | torch.randn(100, 100, device=args.device).sum() 226 | for dataset in itertools.chain(short_rows, long_rows): 227 | download(dataset) 228 | correctness(dataset) 229 | timing(dataset) 230 | -------------------------------------------------------------------------------- /cmake/TorchScatterConfig.cmake.in: -------------------------------------------------------------------------------- 1 | # TorchScatterConfig.cmake 2 | # -------------------- 3 | # 4 | # Exported targets:: Scatter 5 | # 6 | 7 | @PACKAGE_INIT@ 8 | 9 | set(PN TorchScatter) 10 | set(${PN}_INCLUDE_DIR "${PACKAGE_PREFIX_DIR}/@CMAKE_INSTALL_INCLUDEDIR@") 11 | set(${PN}_LIBRARY "") 12 | set(${PN}_DEFINITIONS USING_${PN}) 13 | 14 | check_required_components(${PN}) 15 | 16 | 17 | if(NOT (CMAKE_VERSION VERSION_LESS 3.0)) 18 | #----------------------------------------------------------------------------- 19 | # Don't include targets if this file is being picked up by another 20 | # project which has already built this as a subproject 21 | #----------------------------------------------------------------------------- 22 | if(NOT TARGET ${PN}::TorchScatter) 23 | include("${CMAKE_CURRENT_LIST_DIR}/${PN}Targets.cmake") 24 | 25 | if(NOT TARGET torch_library) 26 | find_package(Torch REQUIRED) 27 | endif() 28 | if(NOT TARGET Python3::Python) 29 | find_package(Python3 COMPONENTS Development) 30 | endif() 31 | target_link_libraries(TorchScatter::TorchScatter INTERFACE ${TORCH_LIBRARIES} Python3::Python) 32 | 33 | if(@WITH_CUDA@) 34 | target_compile_definitions(TorchScatter::TorchScatter INTERFACE WITH_CUDA) 35 | endif() 36 | 37 | endif() 38 | endif() 39 | 
-------------------------------------------------------------------------------- /csrc/cpu/index_info.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../extensions.h" 4 | 5 | #define MAX_TENSORINFO_DIMS 25 6 | 7 | template struct TensorInfo { 8 | TensorInfo(scalar_t *p, int dim, int sz[MAX_TENSORINFO_DIMS], 9 | int st[MAX_TENSORINFO_DIMS]) { 10 | data = p; 11 | dims = dim; 12 | AT_ASSERT(dims < MAX_TENSORINFO_DIMS); 13 | 14 | for (int i = 0; i < dim; ++i) { 15 | sizes[i] = sz[i]; 16 | strides[i] = st[i]; 17 | } 18 | } 19 | 20 | scalar_t *data; 21 | int dims; 22 | int sizes[MAX_TENSORINFO_DIMS]; 23 | int strides[MAX_TENSORINFO_DIMS]; 24 | }; 25 | 26 | template 27 | TensorInfo getTensorInfo(const torch::Tensor &tensor) { 28 | int sizes[MAX_TENSORINFO_DIMS]; 29 | int strides[MAX_TENSORINFO_DIMS]; 30 | 31 | int dims = tensor.dim(); 32 | for (int i = 0; i < dims; ++i) { 33 | sizes[i] = tensor.size(i); 34 | strides[i] = tensor.stride(i); 35 | } 36 | 37 | return TensorInfo(tensor.data_ptr(), dims, sizes, 38 | strides); 39 | } 40 | 41 | template struct IndexToOffset { 42 | static inline int get(int idx, const TensorInfo &info) { 43 | int offset = 0; 44 | for (int i = info.dims - 1; i >= 0; --i) { 45 | offset += (idx % info.sizes[i]) * info.strides[i]; 46 | idx /= info.sizes[i]; 47 | } 48 | return offset; 49 | } 50 | }; 51 | 52 | template struct IndexPtrToOffset { 53 | static inline int get(int idx, const TensorInfo &info) { 54 | int offset = idx % (info.sizes[info.dims - 1] - 1); 55 | offset *= info.strides[info.dims - 1]; 56 | idx /= info.sizes[info.dims - 1] - 1; 57 | for (int i = info.dims - 2; i >= 0; --i) { 58 | offset += (idx % info.sizes[i]) * info.strides[i]; 59 | idx /= info.sizes[i]; 60 | } 61 | return offset; 62 | } 63 | }; 64 | -------------------------------------------------------------------------------- /csrc/cpu/reducer.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | enum ReductionType { SUM, MEAN, MUL, DIV, MIN, MAX }; 7 | 8 | const std::map reduce2REDUCE = { 9 | {"sum", SUM}, {"mean", MEAN}, {"mul", MUL}, 10 | {"div", DIV}, {"min", MIN}, {"max", MAX}, 11 | }; 12 | 13 | #define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \ 14 | [&] { \ 15 | switch (reduce2REDUCE.at(reduce)) { \ 16 | case SUM: { \ 17 | static constexpr ReductionType REDUCE = SUM; \ 18 | return __VA_ARGS__(); \ 19 | } \ 20 | case MEAN: { \ 21 | static constexpr ReductionType REDUCE = MEAN; \ 22 | return __VA_ARGS__(); \ 23 | } \ 24 | case MUL: { \ 25 | static constexpr ReductionType REDUCE = MUL; \ 26 | return __VA_ARGS__(); \ 27 | } \ 28 | case DIV: { \ 29 | static constexpr ReductionType REDUCE = DIV; \ 30 | return __VA_ARGS__(); \ 31 | } \ 32 | case MIN: { \ 33 | static constexpr ReductionType REDUCE = MIN; \ 34 | return __VA_ARGS__(); \ 35 | } \ 36 | case MAX: { \ 37 | static constexpr ReductionType REDUCE = MAX; \ 38 | return __VA_ARGS__(); \ 39 | } \ 40 | } \ 41 | }() 42 | 43 | template struct Reducer { 44 | static inline scalar_t init() { 45 | if (REDUCE == MUL || REDUCE == DIV) 46 | return (scalar_t)1; 47 | else if (REDUCE == MIN) 48 | return std::numeric_limits::max(); 49 | else if (REDUCE == MAX) 50 | return std::numeric_limits::lowest(); 51 | else 52 | return (scalar_t)0; 53 | } 54 | 55 | static inline void update(scalar_t *val, scalar_t new_val, int64_t *arg, 56 | int64_t new_arg) { 57 | if (REDUCE == SUM || REDUCE == MEAN) 58 | *val = *val + new_val; 59 | else if (REDUCE == MUL) 60 | *val = *val * new_val; 61 | else if (REDUCE == DIV) 62 | *val = *val / new_val; 63 | else if ((REDUCE == MIN && new_val < *val) || 64 | (REDUCE == MAX && new_val > *val)) { 65 | *val = new_val; 66 | *arg = new_arg; 67 | } 68 | } 69 | 70 | static inline void write(scalar_t *address, scalar_t val, 71 | int64_t *arg_address, int64_t arg, int 
count) { 72 | if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV) 73 | *address = val; 74 | else if (REDUCE == MEAN) 75 | *address = val / (scalar_t)(count > 0 ? count : 1); 76 | else if (REDUCE == MIN || REDUCE == MAX) { 77 | if (count > 0) { 78 | *address = val; 79 | *arg_address = arg; 80 | } else 81 | *address = (scalar_t)0; 82 | } 83 | } 84 | }; 85 | -------------------------------------------------------------------------------- /csrc/cpu/scatter_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "scatter_cpu.h" 2 | 3 | #include "index_info.h" 4 | #include "reducer.h" 5 | #include "utils.h" 6 | 7 | std::tuple> 8 | scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim, 9 | std::optional optional_out, 10 | std::optional dim_size, std::string reduce) { 11 | CHECK_CPU(src); 12 | CHECK_CPU(index); 13 | if (optional_out.has_value()) 14 | CHECK_CPU(optional_out.value()); 15 | 16 | CHECK_INPUT(src.dim() == index.dim()); 17 | for (auto i = 0; i < index.dim() - 1; i++) 18 | CHECK_INPUT(src.size(i) >= index.size(i)); 19 | 20 | src = src.contiguous(); 21 | 22 | torch::Tensor out; 23 | if (optional_out.has_value()) { 24 | out = optional_out.value().contiguous(); 25 | for (auto i = 0; i < out.dim(); i++) 26 | if (i != dim) 27 | CHECK_INPUT(src.size(i) == out.size(i)); 28 | } else { 29 | auto sizes = src.sizes().vec(); 30 | if (dim_size.has_value()) 31 | sizes[dim] = dim_size.value(); 32 | else if (index.numel() == 0) 33 | sizes[dim] = 0; 34 | else 35 | sizes[dim] = 1 + *index.max().data_ptr(); 36 | out = torch::empty(sizes, src.options()); 37 | } 38 | 39 | std::optional arg_out = std::nullopt; 40 | int64_t *arg_out_data = nullptr; 41 | if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { 42 | arg_out = torch::full_like(out, src.size(dim), index.options()); 43 | arg_out_data = arg_out.value().data_ptr(); 44 | } 45 | 46 | if (src.numel() == 0) { 47 | if (!optional_out.has_value()) 48 | 
out.fill_(0); 49 | return std::make_tuple(out, arg_out); 50 | } 51 | 52 | auto B = 1; 53 | for (auto i = 0; i < dim; i++) 54 | B *= src.size(i); 55 | auto E = src.size(dim); 56 | auto K = src.numel() / (B * E); 57 | auto N = out.size(dim); 58 | 59 | auto index_info = getTensorInfo(index); 60 | AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "scatter_cpu", [&] { 61 | auto src_data = src.data_ptr(); 62 | auto out_data = out.data_ptr(); 63 | 64 | int64_t i, idx; 65 | AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { 66 | if (!optional_out.has_value()) 67 | out.fill_(Reducer::init()); 68 | 69 | for (auto b = 0; b < B; b++) { 70 | for (auto e = 0; e < E; e++) { 71 | for (auto k = 0; k < K; k++) { 72 | i = b * E * K + e * K + k; 73 | idx = index_info.data[IndexToOffset::get(i, index_info)]; 74 | Reducer::update( 75 | out_data + b * N * K + idx * K + k, src_data[i], 76 | arg_out_data + b * N * K + idx * K + k, e); 77 | } 78 | } 79 | } 80 | 81 | if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX)) 82 | out.masked_fill_(out == Reducer::init(), (scalar_t)0); 83 | }); 84 | }); 85 | 86 | return std::make_tuple(out, arg_out); 87 | } 88 | -------------------------------------------------------------------------------- /csrc/cpu/scatter_cpu.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../extensions.h" 4 | 5 | std::tuple> 6 | scatter_cpu(torch::Tensor src, torch::Tensor index, int64_t dim, 7 | std::optional optional_out, 8 | std::optional dim_size, std::string reduce); 9 | -------------------------------------------------------------------------------- /csrc/cpu/segment_coo_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "segment_coo_cpu.h" 2 | 3 | #include "index_info.h" 4 | #include "reducer.h" 5 | #include "utils.h" 6 | #include 7 | 8 | std::tuple> 9 | segment_coo_cpu(torch::Tensor src, torch::Tensor 
index, 10 | std::optional optional_out, 11 | std::optional dim_size, std::string reduce) { 12 | CHECK_CPU(src); 13 | CHECK_CPU(index); 14 | if (optional_out.has_value()) 15 | CHECK_CPU(optional_out.value()); 16 | 17 | CHECK_INPUT(src.dim() >= index.dim()); 18 | 19 | auto sizes = index.sizes().vec(); 20 | for (auto i = 0; i < index.dim(); i++) 21 | sizes[i] = src.size(i); 22 | index = index.expand(sizes); 23 | 24 | auto dim = index.dim() - 1; 25 | 26 | src = src.contiguous(); 27 | 28 | torch::Tensor out; 29 | if (optional_out.has_value()) { 30 | out = optional_out.value().contiguous(); 31 | for (auto i = 0; i < out.dim(); i++) 32 | if (i != dim) 33 | CHECK_INPUT(src.size(i) == out.size(i)); 34 | } else { 35 | sizes = src.sizes().vec(); 36 | if (dim_size.has_value()) 37 | sizes[dim] = dim_size.value(); 38 | else if (index.numel() == 0) 39 | sizes[dim] = 0; 40 | else { 41 | auto tmp = index.select(dim, index.size(dim) - 1); 42 | tmp = tmp.numel() > 1 ? tmp.max() : tmp; 43 | sizes[dim] = 1 + *tmp.data_ptr(); 44 | } 45 | out = torch::empty(sizes, src.options()); 46 | } 47 | 48 | std::optional arg_out = std::nullopt; 49 | int64_t *arg_out_data = nullptr; 50 | if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { 51 | arg_out = torch::full_like(out, src.size(dim), index.options()); 52 | arg_out_data = arg_out.value().data_ptr(); 53 | } else if (reduce2REDUCE.at(reduce) == MEAN) { 54 | auto sizes = index.sizes().vec(); 55 | sizes[dim] = out.size(dim); 56 | arg_out = torch::zeros(sizes, out.options()); 57 | } 58 | 59 | if (src.numel() == 0) { 60 | if (!optional_out.has_value()) 61 | out.fill_(0); 62 | return std::make_tuple(out, arg_out); 63 | } 64 | 65 | auto B = index.numel() / src.size(dim); 66 | auto E = src.size(dim); 67 | auto K = src.numel() / index.numel(); 68 | auto N = out.size(dim); 69 | 70 | auto index_info = getTensorInfo(index); 71 | auto stride = index_info.strides[index_info.dims - 1]; 72 | std::vector args(K); 73 | 
AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "segment_coo_cpu", [&] { 74 | using opmath_t = at::opmath_type; 75 | auto src_data = src.data_ptr(); 76 | auto out_data = out.data_ptr(); 77 | scalar_t *count_data = nullptr; 78 | 79 | std::vector vals(K); 80 | int64_t idx, next_idx, row_start; 81 | AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { 82 | if (!optional_out.has_value()) 83 | out.fill_(Reducer::init()); 84 | if (REDUCE == MEAN) 85 | count_data = arg_out.value().data_ptr(); 86 | 87 | for (auto b = 0; b < B; b++) { 88 | auto offset = IndexToOffset::get(b * E, index_info); 89 | idx = index_info.data[offset]; 90 | 91 | for (auto k = 0; k < K; k++) 92 | vals[k] = static_cast(out_data[b * N * K + k]); 93 | 94 | row_start = 0; 95 | for (auto e = 0; e < E; e++) { 96 | 97 | for (auto k = 0; k < K; k++) 98 | Reducer::update( 99 | &vals[k], static_cast(src_data[b * E * K + e * K + k]), &args[k], e); 100 | 101 | if (e == E - 1) { 102 | for (auto k = 0; k < K; k++) 103 | Reducer::write( 104 | out_data + b * N * K + idx * K + k, static_cast(vals[k]), 105 | arg_out_data + b * N * K + idx * K + k, args[k], 106 | e + 1 - row_start); 107 | if (REDUCE == MEAN) 108 | count_data[b * N + idx] = (scalar_t)(e + 1 - row_start); 109 | } else { 110 | next_idx = index_info.data[offset + (e + 1) * stride]; 111 | assert(idx <= next_idx); 112 | 113 | if (idx != next_idx) { 114 | for (auto k = 0; k < K; k++) { 115 | Reducer::write( 116 | out_data + b * N * K + idx * K + k, static_cast(vals[k]), 117 | arg_out_data + b * N * K + idx * K + k, args[k], 118 | e + 1 - row_start); 119 | 120 | vals[k] = static_cast(out_data[b * N * K + next_idx * K + k]); 121 | } 122 | if (REDUCE == MEAN) 123 | count_data[b * N + idx] = (scalar_t)(e + 1 - row_start); 124 | row_start = e + 1; 125 | } 126 | 127 | idx = next_idx; 128 | } 129 | } 130 | } 131 | if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX)) 132 | out.masked_fill_(out == Reducer::init(), 
(scalar_t)0); 133 | 134 | if (REDUCE == MEAN) 135 | arg_out.value().masked_fill_(arg_out.value() < (scalar_t)1, 136 | (scalar_t)1); 137 | }); 138 | }); 139 | 140 | return std::make_tuple(out, arg_out); 141 | } 142 | 143 | torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index, 144 | std::optional optional_out) { 145 | CHECK_CPU(src); 146 | CHECK_CPU(index); 147 | if (optional_out.has_value()) 148 | CHECK_CPU(optional_out.value()); 149 | 150 | CHECK_INPUT(src.dim() >= index.dim()); 151 | for (auto i = 0; i < index.dim() - 1; i++) 152 | CHECK_INPUT(src.size(i) == index.size(i)); 153 | 154 | auto dim = index.dim() - 1; 155 | 156 | src = src.contiguous(); 157 | 158 | torch::Tensor out; 159 | if (optional_out.has_value()) { 160 | out = optional_out.value().contiguous(); 161 | for (auto i = 0; i < src.dim(); i++) 162 | if (i != dim) 163 | CHECK_INPUT(src.size(i) == out.size(i)); 164 | } else { 165 | auto sizes = src.sizes().vec(); 166 | sizes[dim] = index.size(dim); 167 | out = torch::empty(sizes, src.options()); 168 | } 169 | 170 | if (src.numel() == 0) { 171 | if (!optional_out.has_value()) 172 | out.fill_(0); 173 | return out; 174 | } 175 | 176 | auto B = index.numel() / out.size(dim); 177 | auto E = index.size(dim); 178 | auto K = out.numel() / index.numel(); 179 | auto N = src.size(dim); 180 | 181 | auto index_info = getTensorInfo(index); 182 | auto stride = index_info.strides[index_info.dims - 1]; 183 | AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "gather_coo_cpu", [&] { 184 | auto src_data = src.data_ptr(); 185 | auto out_data = out.data_ptr(); 186 | 187 | std::vector vals(K); 188 | int64_t idx, next_idx; 189 | for (auto b = 0; b < B; b++) { 190 | auto offset = IndexToOffset::get(b * E, index_info); 191 | idx = index_info.data[offset]; 192 | 193 | for (auto k = 0; k < K; k++) 194 | vals[k] = src_data[b * N * K + idx * K + k]; 195 | 196 | for (auto e = 0; e < E; e++) { 197 | for (auto k = 0; k < K; k++) 
198 | out_data[b * E * K + e * K + k] = vals[k]; 199 | 200 | if (e < E - 1) { 201 | next_idx = index_info.data[offset + (e + 1) * stride]; 202 | CHECK_INPUT(idx <= next_idx); 203 | 204 | if (idx != next_idx) { 205 | idx = next_idx; 206 | for (auto k = 0; k < K; k++) 207 | vals[k] = src_data[b * N * K + idx * K + k]; 208 | } 209 | } 210 | } 211 | } 212 | }); 213 | 214 | return out; 215 | } 216 | -------------------------------------------------------------------------------- /csrc/cpu/segment_coo_cpu.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../extensions.h" 4 | 5 | std::tuple> 6 | segment_coo_cpu(torch::Tensor src, torch::Tensor index, 7 | std::optional optional_out, 8 | std::optional dim_size, std::string reduce); 9 | 10 | torch::Tensor gather_coo_cpu(torch::Tensor src, torch::Tensor index, 11 | std::optional optional_out); 12 | -------------------------------------------------------------------------------- /csrc/cpu/segment_csr_cpu.cpp: -------------------------------------------------------------------------------- 1 | #include "segment_csr_cpu.h" 2 | 3 | #include "index_info.h" 4 | #include "reducer.h" 5 | #include "utils.h" 6 | #include 7 | 8 | std::tuple> 9 | segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, 10 | std::optional optional_out, 11 | std::string reduce) { 12 | CHECK_CPU(src); 13 | CHECK_CPU(indptr); 14 | if (optional_out.has_value()) 15 | CHECK_CPU(optional_out.value()); 16 | 17 | CHECK_INPUT(src.dim() >= indptr.dim()); 18 | 19 | auto sizes = indptr.sizes().vec(); 20 | for (auto i = 0; i < indptr.dim() - 1; i++) 21 | sizes[i] = src.size(i); 22 | indptr = indptr.expand(sizes); 23 | 24 | auto dim = indptr.dim() - 1; 25 | 26 | src = src.contiguous(); 27 | 28 | torch::Tensor out; 29 | if (optional_out.has_value()) { 30 | out = optional_out.value().contiguous(); 31 | for (auto i = 0; i < out.dim(); i++) 32 | if (i != dim) 33 | CHECK_INPUT(src.size(i) == out.size(i)); 
34 | CHECK_INPUT(src.numel() == 0 || out.size(dim) == indptr.size(dim) - 1); 35 | } else { 36 | sizes = src.sizes().vec(); 37 | sizes[dim] = std::max(indptr.size(dim) - 1, 0); 38 | out = torch::empty(sizes, src.options()); 39 | } 40 | 41 | std::optional arg_out = std::nullopt; 42 | int64_t *arg_out_data = nullptr; 43 | if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { 44 | arg_out = torch::full(out.sizes(), src.size(dim), indptr.options()); 45 | arg_out_data = arg_out.value().data_ptr(); 46 | } 47 | 48 | if (src.numel() == 0) { 49 | if (!optional_out.has_value()) 50 | out.fill_(0); 51 | return std::make_tuple(out, arg_out); 52 | } 53 | 54 | auto N = out.size(dim) * (indptr.numel() / indptr.size(-1)); 55 | auto K = out.numel() / N; 56 | auto E = src.size(dim); 57 | 58 | auto indptr_info = getTensorInfo(indptr); 59 | auto stride = indptr_info.strides[indptr_info.dims - 1]; 60 | std::vector args(K); 61 | AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "segment_csr_cpu", [&] { 62 | using opmath_t = at::opmath_type; 63 | auto src_data = src.data_ptr(); 64 | auto out_data = out.data_ptr(); 65 | 66 | std::vector vals(K); 67 | int64_t row_start, row_end; 68 | AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { 69 | for (auto n = 0; n < N; n++) { 70 | auto offset = IndexPtrToOffset::get(n, indptr_info); 71 | row_start = indptr_info.data[offset]; 72 | row_end = indptr_info.data[offset + stride]; 73 | 74 | offset = (n / (indptr.size(-1) - 1)) * E * K; 75 | for (auto k = 0; k < K; k++) 76 | vals[k] = Reducer::init(); 77 | 78 | for (auto e = row_start; e < row_end; e++) 79 | for (auto k = 0; k < K; k++) 80 | Reducer::update( 81 | &vals[k], static_cast(src_data[offset + e * K + k]), &args[k], e); 82 | 83 | for (auto k = 0; k < K; k++) 84 | Reducer::write(out_data + n * K + k, static_cast(vals[k]), 85 | arg_out_data + n * K + k, args[k], 86 | row_end - row_start); 87 | } 88 | }); 89 | }); 90 | 91 | return 
std::make_tuple(out, arg_out); 92 | } 93 | 94 | torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr, 95 | std::optional optional_out) { 96 | CHECK_CPU(src); 97 | CHECK_CPU(indptr); 98 | if (optional_out.has_value()) 99 | CHECK_CPU(optional_out.value()); 100 | 101 | CHECK_INPUT(src.dim() >= indptr.dim()); 102 | 103 | auto sizes = indptr.sizes().vec(); 104 | for (auto i = 0; i < indptr.dim() - 1; i++) 105 | sizes[i] = src.size(i); 106 | indptr = indptr.expand(sizes); 107 | 108 | auto dim = indptr.dim() - 1; 109 | CHECK_INPUT(src.size(dim) == 0 || src.size(dim) == indptr.size(dim) - 1); 110 | 111 | src = src.contiguous(); 112 | 113 | torch::Tensor out; 114 | if (optional_out.has_value()) { 115 | out = optional_out.value().contiguous(); 116 | for (auto i = 0; i < out.dim(); i++) 117 | if (i != dim) 118 | CHECK_INPUT(src.size(i) == out.size(i)); 119 | } else { 120 | auto sizes = src.sizes().vec(); 121 | if (src.numel() > 0) 122 | sizes[dim] = *indptr.flatten()[-1].data_ptr(); 123 | else 124 | sizes[dim] = 0; 125 | out = torch::empty(sizes, src.options()); 126 | } 127 | 128 | if (src.numel() == 0) { 129 | if (!optional_out.has_value()) 130 | out.fill_(0); 131 | return out; 132 | } 133 | 134 | auto N = src.size(dim) * (indptr.numel() / indptr.size(-1)); 135 | auto K = src.numel() / N; 136 | auto E = out.size(dim); 137 | 138 | auto indptr_info = getTensorInfo(indptr); 139 | auto stride = indptr_info.strides[indptr_info.dims - 1]; 140 | AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "gather_csr_cpu", [&] { 141 | auto src_data = src.data_ptr(); 142 | auto out_data = out.data_ptr(); 143 | 144 | std::vector vals(K); 145 | int64_t row_start, row_end; 146 | for (auto n = 0; n < N; n++) { 147 | auto offset = IndexPtrToOffset::get(n, indptr_info); 148 | row_start = indptr_info.data[offset]; 149 | row_end = indptr_info.data[offset + stride]; 150 | 151 | for (auto k = 0; k < K; k++) 152 | vals[k] = src_data[n * K + k]; 
153 | 154 | offset = (n / (indptr.size(-1) - 1)) * E * K; 155 | for (auto e = row_start; e < row_end; e++) 156 | for (auto k = 0; k < K; k++) 157 | out_data[offset + e * K + k] = vals[k]; 158 | } 159 | }); 160 | 161 | return out; 162 | } 163 | -------------------------------------------------------------------------------- /csrc/cpu/segment_csr_cpu.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../extensions.h" 4 | 5 | std::tuple> 6 | segment_csr_cpu(torch::Tensor src, torch::Tensor indptr, 7 | std::optional optional_out, 8 | std::string reduce); 9 | 10 | torch::Tensor gather_csr_cpu(torch::Tensor src, torch::Tensor indptr, 11 | std::optional optional_out); 12 | -------------------------------------------------------------------------------- /csrc/cpu/utils.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../extensions.h" 4 | 5 | #define CHECK_CPU(x) AT_ASSERTM(x.device().is_cpu(), #x " must be CPU tensor") 6 | #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch") 7 | -------------------------------------------------------------------------------- /csrc/cuda/index_info.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | // We need our own `IndexToOffset` implementation since we do not want to 6 | // access the last element of the `indexptr`. 
7 | template struct IndexPtrToOffset { 8 | static inline __host__ __device__ int 9 | get(int idx, const at::cuda::detail::TensorInfo &info) { 10 | int offset = idx % (info.sizes[info.dims - 1] - 1); 11 | offset *= info.strides[info.dims - 1]; 12 | idx /= info.sizes[info.dims - 1] - 1; 13 | for (int i = info.dims - 2; i >= 0; --i) { 14 | offset += (idx % info.sizes[i]) * info.strides[i]; 15 | idx /= info.sizes[i]; 16 | } 17 | return offset; 18 | } 19 | }; 20 | -------------------------------------------------------------------------------- /csrc/cuda/reducer.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | #include "atomics.cuh" 7 | 8 | enum ReductionType { SUM, MEAN, MUL, DIV, MIN, MAX }; 9 | 10 | const std::map reduce2REDUCE = { 11 | {"sum", SUM}, {"mean", MEAN}, {"mul", MUL}, 12 | {"div", DIV}, {"min", MIN}, {"max", MAX}, 13 | }; 14 | 15 | #define AT_DISPATCH_REDUCTION_TYPES(reduce, ...) \ 16 | [&] { \ 17 | switch (reduce2REDUCE.at(reduce)) { \ 18 | case SUM: { \ 19 | static constexpr ReductionType REDUCE = SUM; \ 20 | return __VA_ARGS__(); \ 21 | } \ 22 | case MEAN: { \ 23 | static constexpr ReductionType REDUCE = MEAN; \ 24 | return __VA_ARGS__(); \ 25 | } \ 26 | case MUL: { \ 27 | static constexpr ReductionType REDUCE = MUL; \ 28 | return __VA_ARGS__(); \ 29 | } \ 30 | case DIV: { \ 31 | static constexpr ReductionType REDUCE = DIV; \ 32 | return __VA_ARGS__(); \ 33 | } \ 34 | case MIN: { \ 35 | static constexpr ReductionType REDUCE = MIN; \ 36 | return __VA_ARGS__(); \ 37 | } \ 38 | case MAX: { \ 39 | static constexpr ReductionType REDUCE = MAX; \ 40 | return __VA_ARGS__(); \ 41 | } \ 42 | } \ 43 | }() 44 | 45 | template struct Reducer { 46 | static inline __host__ __device__ scalar_t init() { 47 | if (REDUCE == MUL || REDUCE == DIV) 48 | return (scalar_t)1; 49 | else if (REDUCE == MIN) 50 | return std::numeric_limits::max(); 51 | else if (REDUCE == MAX) 52 | return 
std::numeric_limits::lowest(); 53 | else 54 | return (scalar_t)0; 55 | } 56 | 57 | static inline __host__ __device__ void update(scalar_t *val, 58 | scalar_t new_val) { 59 | if (REDUCE == SUM || REDUCE == MEAN) 60 | *val = *val + new_val; 61 | else if (REDUCE == MUL) 62 | *val = *val * new_val; 63 | else if (REDUCE == DIV) 64 | *val = *val / new_val; 65 | else if ((REDUCE == MIN && new_val < *val) || 66 | (REDUCE == MAX && new_val > *val)) { 67 | *val = new_val; 68 | } 69 | } 70 | 71 | static inline __host__ __device__ void update(scalar_t *val, scalar_t new_val, 72 | int64_t *arg, int64_t new_arg) { 73 | if (REDUCE == SUM || REDUCE == MEAN) 74 | *val = *val + new_val; 75 | else if (REDUCE == MUL) 76 | *val = *val * new_val; 77 | else if (REDUCE == DIV) 78 | *val = *val / new_val; 79 | else if ((REDUCE == MIN && new_val < *val) || 80 | (REDUCE == MAX && new_val > *val)) { 81 | *val = new_val; 82 | *arg = new_arg; 83 | } 84 | } 85 | 86 | static inline __host__ __device__ void write(scalar_t *address, scalar_t val, 87 | int64_t *arg_address, 88 | int64_t arg, int count) { 89 | if (REDUCE == SUM || REDUCE == MUL || REDUCE == DIV) 90 | *address = val; 91 | else if (REDUCE == MEAN) 92 | *address = val / (scalar_t)(count > 0 ? 
count : 1); 93 | else if (REDUCE == MIN || REDUCE == MAX) { 94 | if (count > 0) { 95 | *address = val; 96 | *arg_address = arg; 97 | } else 98 | *address = (scalar_t)0; 99 | } 100 | } 101 | 102 | static inline __device__ void atomic_write(scalar_t *address, scalar_t val) { 103 | if (REDUCE == SUM || REDUCE == MEAN) 104 | atomAdd(address, val); 105 | else if (REDUCE == MUL) 106 | atomMul(address, val); 107 | else if (REDUCE == DIV) 108 | atomDiv(address, val); 109 | else if (REDUCE == MIN) 110 | atomMin(address, val); 111 | else if (REDUCE == MAX) 112 | atomMax(address, val); 113 | } 114 | }; 115 | -------------------------------------------------------------------------------- /csrc/cuda/scatter_cuda.cu: -------------------------------------------------------------------------------- 1 | #include "scatter_cuda.h" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "reducer.cuh" 8 | #include "utils.cuh" 9 | 10 | #define THREADS 256 11 | #define BLOCKS(N) (N + THREADS - 1) / THREADS 12 | 13 | template 14 | __global__ void 15 | scatter_kernel(const scalar_t *src_data, 16 | const at::cuda::detail::TensorInfo index_info, 17 | scalar_t *out_data, int E, int K, int N, int numel) { 18 | 19 | int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; 20 | 21 | int b = thread_idx / (E * K); 22 | int k = thread_idx % K; 23 | 24 | if (thread_idx < numel) { 25 | int offset = at::cuda::detail::IndexToOffset::get( 26 | thread_idx, index_info); 27 | int64_t idx = index_info.data[offset]; 28 | 29 | Reducer::atomic_write(out_data + b * N * K + idx * K + k, 30 | src_data[thread_idx]); 31 | } 32 | } 33 | 34 | template 35 | __global__ void 36 | scatter_arg_kernel(const scalar_t *src_data, 37 | const at::cuda::detail::TensorInfo index_info, 38 | const scalar_t *out_data, int64_t *arg_out_data, int E, 39 | int K, int N, int numel) { 40 | 41 | int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; 42 | 43 | int b = thread_idx / (E * K); 44 | int e = (thread_idx / K) % E; 45 | int k 
= thread_idx % K; 46 | 47 | if (thread_idx < numel) { 48 | int offset = at::cuda::detail::IndexToOffset::get( 49 | thread_idx, index_info); 50 | int64_t idx = index_info.data[offset]; 51 | 52 | if (src_data[thread_idx] == out_data[b * N * K + idx * K + k]) { 53 | arg_out_data[b * N * K + idx * K + k] = e; 54 | } 55 | } 56 | } 57 | 58 | std::tuple> 59 | scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim, 60 | std::optional optional_out, 61 | std::optional dim_size, std::string reduce) { 62 | CHECK_CUDA(src); 63 | CHECK_CUDA(index); 64 | if (optional_out.has_value()) 65 | CHECK_CUDA(optional_out.value()); 66 | c10::cuda::MaybeSetDevice(src.get_device()); 67 | 68 | CHECK_INPUT(src.dim() == index.dim()); 69 | for (auto i = 0; i < index.dim() - 1; i++) 70 | CHECK_INPUT(src.size(i) >= index.size(i)); 71 | 72 | src = src.contiguous(); 73 | 74 | torch::Tensor out; 75 | if (optional_out.has_value()) { 76 | out = optional_out.value().contiguous(); 77 | for (auto i = 0; i < out.dim(); i++) 78 | if (i != dim) 79 | CHECK_INPUT(src.size(i) == out.size(i)); 80 | } else { 81 | auto sizes = src.sizes().vec(); 82 | if (dim_size.has_value()) 83 | sizes[dim] = dim_size.value(); 84 | else if (index.numel() == 0) 85 | sizes[dim] = 0; 86 | else { 87 | sizes[dim] = 1 + index.max().cpu().data_ptr()[0]; 88 | } 89 | out = torch::empty(sizes, src.options()); 90 | } 91 | 92 | std::optional arg_out = std::nullopt; 93 | int64_t *arg_out_data = nullptr; 94 | if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { 95 | arg_out = torch::full_like(out, src.size(dim), index.options()); 96 | arg_out_data = arg_out.value().data_ptr(); 97 | } 98 | 99 | if (src.numel() == 0) { 100 | if (!optional_out.has_value()) 101 | out.fill_(0); 102 | return std::make_tuple(out, arg_out); 103 | } 104 | 105 | auto B = 1; 106 | for (auto i = 0; i < dim; i++) 107 | B *= src.size(i); 108 | auto E = src.size(dim); 109 | auto K = src.numel() / (B * E); 110 | auto N = out.size(dim); 111 | 112 
| auto index_info = at::cuda::detail::getTensorInfo(index); 113 | auto stream = at::cuda::getCurrentCUDAStream(); 114 | AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] { 115 | auto src_data = src.data_ptr(); 116 | auto out_data = out.data_ptr(); 117 | 118 | AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { 119 | if (!optional_out.has_value()) 120 | out.fill_(Reducer::init()); 121 | 122 | scatter_kernel 123 | <<>>( 124 | src_data, index_info, out_data, E, K, N, src.numel()); 125 | 126 | if (!optional_out.has_value() && (REDUCE == MIN || REDUCE == MAX)) 127 | out.masked_fill_(out == Reducer::init(), (scalar_t)0); 128 | 129 | if (REDUCE == MIN || REDUCE == MAX) 130 | scatter_arg_kernel 131 | <<>>( 132 | src_data, index_info, out_data, arg_out_data, E, K, N, 133 | src.numel()); 134 | }); 135 | }); 136 | 137 | return std::make_tuple(out, arg_out); 138 | } 139 | -------------------------------------------------------------------------------- /csrc/cuda/scatter_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../extensions.h" 4 | 5 | std::tuple> 6 | scatter_cuda(torch::Tensor src, torch::Tensor index, int64_t dim, 7 | std::optional optional_out, 8 | std::optional dim_size, std::string reduce); 9 | -------------------------------------------------------------------------------- /csrc/cuda/segment_coo_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../extensions.h" 4 | 5 | std::tuple> 6 | segment_coo_cuda(torch::Tensor src, torch::Tensor index, 7 | std::optional optional_out, 8 | std::optional dim_size, std::string reduce); 9 | 10 | torch::Tensor gather_coo_cuda(torch::Tensor src, torch::Tensor index, 11 | std::optional optional_out); 12 | -------------------------------------------------------------------------------- /csrc/cuda/segment_csr_cuda.cu: 
-------------------------------------------------------------------------------- 1 | #include "segment_csr_cuda.h" 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | #include "index_info.cuh" 8 | #include "reducer.cuh" 9 | #include "utils.cuh" 10 | 11 | #define THREADS 256 12 | #define BLOCKS(TB, N) (TB * N + THREADS - 1) / THREADS 13 | #define FULL_MASK 0xffffffff 14 | 15 | template 16 | __global__ void 17 | segment_csr_kernel(const scalar_t *src_data, 18 | const at::cuda::detail::TensorInfo indptr_info, 19 | scalar_t *out_data, int64_t *arg_out_data, size_t N, 20 | size_t E) { 21 | 22 | // Each warp processes exactly `32/TB` rows and aggregates all row values 23 | // via a parallel reduction. 24 | 25 | int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; 26 | int row_idx = thread_idx / TB; 27 | int lane_idx = thread_idx & (TB - 1); 28 | 29 | if (row_idx < N) { 30 | int offset = IndexPtrToOffset::get(row_idx, indptr_info); 31 | int64_t row_start = __ldg(indptr_info.data + offset); 32 | int64_t row_end = __ldg(indptr_info.data + offset + 33 | indptr_info.strides[indptr_info.dims - 1]); 34 | 35 | scalar_t val = Reducer::init(); 36 | int64_t arg, arg_tmp; 37 | 38 | offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E; 39 | for (int64_t src_idx = row_start + lane_idx; src_idx < row_end; 40 | src_idx += TB) { 41 | Reducer::update(&val, src_data[offset + src_idx], &arg, 42 | src_idx); 43 | } 44 | 45 | #pragma unroll 46 | for (int i = TB / 2; i > 0; i /= 2) { 47 | // Parallel reduction inside a single warp. 
48 | if (REDUCE == MIN || REDUCE == MAX) 49 | arg_tmp = SHFL_DOWN_SYNC(FULL_MASK, arg, i); 50 | Reducer::update( 51 | &val, SHFL_DOWN_SYNC(FULL_MASK, val, i), &arg, arg_tmp); 52 | } 53 | 54 | if (lane_idx == 0) { 55 | Reducer::write(out_data + row_idx, val, 56 | arg_out_data + row_idx, arg, 57 | row_end - row_start); 58 | } 59 | } 60 | } 61 | 62 | template 63 | __global__ void segment_csr_broadcast_kernel( 64 | const scalar_t *src_data, 65 | const at::cuda::detail::TensorInfo indptr_info, 66 | scalar_t *out_data, int64_t *arg_out_data, size_t N, size_t K, size_t E) { 67 | 68 | // Each thread processes exactly one row. It turned out that is more 69 | // efficient than using shared memory due to avoiding synchronization 70 | // barriers. 71 | 72 | int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; 73 | int row_idx = thread_idx / K; 74 | int lane_idx = thread_idx % K; 75 | 76 | if (thread_idx < N * K) { 77 | int offset = IndexPtrToOffset::get(row_idx, indptr_info); 78 | int64_t row_start = __ldg(indptr_info.data + offset); 79 | int64_t row_end = __ldg(indptr_info.data + offset + 80 | indptr_info.strides[indptr_info.dims - 1]); 81 | 82 | scalar_t val = Reducer::init(); 83 | int64_t arg; 84 | 85 | offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K; 86 | for (int64_t src_idx = row_start; src_idx < row_end; src_idx++) { 87 | Reducer::update( 88 | &val, src_data[offset + K * src_idx + lane_idx], &arg, src_idx); 89 | } 90 | 91 | Reducer::write(out_data + thread_idx, val, 92 | arg_out_data + thread_idx, arg, 93 | row_end - row_start); 94 | } 95 | } 96 | 97 | std::tuple> 98 | segment_csr_cuda(torch::Tensor src, torch::Tensor indptr, 99 | std::optional optional_out, 100 | std::string reduce) { 101 | CHECK_CUDA(src); 102 | CHECK_CUDA(indptr); 103 | if (optional_out.has_value()) 104 | CHECK_CUDA(optional_out.value()); 105 | c10::cuda::MaybeSetDevice(src.get_device()); 106 | 107 | CHECK_INPUT(src.dim() >= indptr.dim()); 108 | 109 | auto sizes = 
indptr.sizes().vec(); 110 | for (auto i = 0; i < indptr.dim() - 1; i++) 111 | sizes[i] = src.size(i); 112 | indptr = indptr.expand(sizes); 113 | 114 | auto dim = indptr.dim() - 1; 115 | 116 | src = src.contiguous(); 117 | 118 | torch::Tensor out; 119 | if (optional_out.has_value()) { 120 | out = optional_out.value().contiguous(); 121 | for (int i = 0; i < out.dim(); i++) 122 | if (i != dim) 123 | CHECK_INPUT(src.size(i) == out.size(i)); 124 | CHECK_INPUT(src.numel() == 0 || out.size(dim) == indptr.size(dim) - 1); 125 | } else { 126 | sizes = src.sizes().vec(); 127 | sizes[dim] = std::max(indptr.size(dim) - 1, 0); 128 | out = torch::empty(sizes, src.options()); 129 | } 130 | 131 | std::optional arg_out = std::nullopt; 132 | int64_t *arg_out_data = nullptr; 133 | if (reduce2REDUCE.at(reduce) == MIN || reduce2REDUCE.at(reduce) == MAX) { 134 | arg_out = torch::full(out.sizes(), src.size(dim), indptr.options()); 135 | arg_out_data = arg_out.value().data_ptr(); 136 | } 137 | 138 | if (src.numel() == 0) { 139 | if (!optional_out.has_value()) 140 | out.fill_(0); 141 | return std::make_tuple(out, arg_out); 142 | } 143 | 144 | auto N = out.size(dim) * (indptr.numel() / indptr.size(-1)); 145 | auto K = out.numel() / N; 146 | auto E = src.size(dim); 147 | 148 | auto indptr_info = at::cuda::detail::getTensorInfo(indptr); 149 | auto stream = at::cuda::getCurrentCUDAStream(); 150 | AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] { 151 | auto src_data = src.data_ptr(); 152 | auto out_data = out.data_ptr(); 153 | 154 | AT_DISPATCH_REDUCTION_TYPES(reduce, [&] { 155 | if (K == 1) { 156 | segment_csr_kernel 157 | <<>>( 158 | src_data, indptr_info, out_data, arg_out_data, N, E); 159 | } else { 160 | segment_csr_broadcast_kernel 161 | <<>>( 162 | src_data, indptr_info, out_data, arg_out_data, N, K, E); 163 | } 164 | }); 165 | }); 166 | 167 | return std::make_tuple(out, arg_out); 168 | } 169 | 170 | template 171 | __global__ void 172 
| gather_csr_kernel(const scalar_t *src_data, 173 | const at::cuda::detail::TensorInfo indptr_info, 174 | scalar_t *out_data, size_t N, size_t E) { 175 | 176 | int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; 177 | int row_idx = thread_idx / TB; 178 | int lane_idx = thread_idx % TB; 179 | 180 | if (row_idx < N) { 181 | int offset = IndexPtrToOffset::get(row_idx, indptr_info); 182 | int row_start = __ldg(indptr_info.data + offset); 183 | int row_end = __ldg(indptr_info.data + offset + 184 | indptr_info.strides[indptr_info.dims - 1]); 185 | scalar_t val = __ldg(src_data + row_idx); 186 | 187 | offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E; 188 | for (int out_idx = row_start + lane_idx; out_idx < row_end; out_idx += TB) { 189 | out_data[offset + out_idx] = val; // "Mostly" coalesced. 190 | } 191 | } 192 | } 193 | 194 | template 195 | __global__ void gather_csr_broadcast_kernel( 196 | const scalar_t *src_data, 197 | const at::cuda::detail::TensorInfo indptr_info, 198 | scalar_t *out_data, size_t N, size_t K, size_t E) { 199 | 200 | int thread_idx = blockIdx.x * blockDim.x + threadIdx.x; 201 | int row_idx = thread_idx / K; 202 | int lane_idx = thread_idx % K; 203 | 204 | if (thread_idx < N * K) { 205 | int offset = IndexPtrToOffset::get(row_idx, indptr_info); 206 | int row_start = __ldg(indptr_info.data + offset); 207 | int row_end = __ldg(indptr_info.data + offset + 208 | indptr_info.strides[indptr_info.dims - 1]); 209 | 210 | scalar_t val = src_data[thread_idx]; // Coalesced. 211 | 212 | offset = (row_idx / (indptr_info.sizes[indptr_info.dims - 1] - 1)) * E * K; 213 | for (int out_idx = row_start; out_idx < row_end; out_idx++) { 214 | out_data[offset + K * out_idx + lane_idx] = val; // "Mostly" coalesced. 
215 | } 216 | } 217 | } 218 | 219 | torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr, 220 | std::optional optional_out) { 221 | CHECK_CUDA(src); 222 | CHECK_CUDA(indptr); 223 | if (optional_out.has_value()) 224 | CHECK_CUDA(optional_out.value()); 225 | c10::cuda::MaybeSetDevice(src.get_device()); 226 | 227 | CHECK_INPUT(src.dim() >= indptr.dim()); 228 | 229 | auto sizes = indptr.sizes().vec(); 230 | for (auto i = 0; i < indptr.dim() - 1; i++) 231 | sizes[i] = src.size(i); 232 | indptr = indptr.expand(sizes); 233 | 234 | auto dim = indptr.dim() - 1; 235 | CHECK_INPUT(src.size(dim) == 0 || src.size(dim) == indptr.size(dim) - 1); 236 | 237 | src = src.contiguous(); 238 | 239 | torch::Tensor out; 240 | if (optional_out.has_value()) { 241 | out = optional_out.value().contiguous(); 242 | for (auto i = 0; i < out.dim(); i++) 243 | if (i != dim) 244 | CHECK_INPUT(src.size(i) == out.size(i)); 245 | } else { 246 | auto sizes = src.sizes().vec(); 247 | if (src.numel() > 0) { 248 | sizes[dim] = indptr.flatten()[-1].cpu().data_ptr()[0]; 249 | } else { 250 | sizes[dim] = 0; 251 | } 252 | out = torch::empty(sizes, src.options()); 253 | } 254 | 255 | if (src.numel() == 0) { 256 | if (!optional_out.has_value()) 257 | out.fill_(0); 258 | return out; 259 | } 260 | 261 | auto N = src.size(dim) * (indptr.numel() / indptr.size(-1)); 262 | auto K = src.numel() / N; 263 | auto E = out.size(dim); 264 | 265 | auto indptr_info = at::cuda::detail::getTensorInfo(indptr); 266 | auto stream = at::cuda::getCurrentCUDAStream(); 267 | AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, src.scalar_type(), "_", [&] { 268 | auto src_data = src.data_ptr(); 269 | auto out_data = out.data_ptr(); 270 | 271 | if (K == 1) 272 | gather_csr_kernel<<>>( 273 | src_data, indptr_info, out_data, N, E); 274 | else 275 | gather_csr_broadcast_kernel 276 | <<>>(src_data, indptr_info, 277 | out_data, N, K, E); 278 | }); 279 | 280 | return out; 281 | } 282 | 
-------------------------------------------------------------------------------- /csrc/cuda/segment_csr_cuda.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../extensions.h" 4 | 5 | std::tuple> 6 | segment_csr_cuda(torch::Tensor src, torch::Tensor indptr, 7 | std::optional optional_out, 8 | std::string reduce); 9 | 10 | torch::Tensor gather_csr_cuda(torch::Tensor src, torch::Tensor indptr, 11 | std::optional optional_out); 12 | -------------------------------------------------------------------------------- /csrc/cuda/utils.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "../extensions.h" 4 | 5 | #define CHECK_CUDA(x) \ 6 | AT_ASSERTM(x.device().is_cuda(), #x " must be CUDA tensor") 7 | #define CHECK_INPUT(x) AT_ASSERTM(x, "Input mismatch") 8 | 9 | __device__ __inline__ at::Half __shfl_up_sync(const unsigned mask, 10 | const at::Half var, 11 | const unsigned int delta) { 12 | return __shfl_up_sync(mask, var.operator __half(), delta); 13 | } 14 | 15 | __device__ __inline__ at::Half __shfl_down_sync(const unsigned mask, 16 | const at::Half var, 17 | const unsigned int delta) { 18 | return __shfl_down_sync(mask, var.operator __half(), delta); 19 | } 20 | 21 | __device__ __inline__ at::Half __shfl_up(const at::Half var, 22 | const unsigned int delta) { 23 | return __shfl_up(var.operator __half(), delta); 24 | } 25 | 26 | __device__ __inline__ at::Half __shfl_down(const at::Half var, 27 | const unsigned int delta) { 28 | return __shfl_down(var.operator __half(), delta); 29 | } 30 | 31 | #ifdef USE_ROCM 32 | __device__ __inline__ at::Half __ldg(const at::Half* ptr) { 33 | return __ldg(reinterpret_cast(ptr)); 34 | } 35 | #define SHFL_UP_SYNC(mask, var, delta) __shfl_up(var, delta) 36 | #define SHFL_DOWN_SYNC(mask, var, delta) __shfl_down(var, delta) 37 | #else 38 | #define SHFL_UP_SYNC __shfl_up_sync 39 | #define SHFL_DOWN_SYNC 
__shfl_down_sync 40 | #endif 41 | -------------------------------------------------------------------------------- /csrc/extensions.h: -------------------------------------------------------------------------------- 1 | #include "macros.h" 2 | #include 3 | -------------------------------------------------------------------------------- /csrc/macros.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #ifdef _WIN32 4 | #if defined(torchscatter_EXPORTS) 5 | #define SCATTER_API __declspec(dllexport) 6 | #else 7 | #define SCATTER_API __declspec(dllimport) 8 | #endif 9 | #else 10 | #define SCATTER_API 11 | #endif 12 | 13 | #if (defined __cpp_inline_variables) || __cplusplus >= 201703L 14 | #define SCATTER_INLINE_VARIABLE inline 15 | #else 16 | #ifdef _MSC_VER 17 | #define SCATTER_INLINE_VARIABLE __declspec(selectany) 18 | #else 19 | #define SCATTER_INLINE_VARIABLE __attribute__((weak)) 20 | #endif 21 | #endif 22 | -------------------------------------------------------------------------------- /csrc/scatter.h: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include "extensions.h" 4 | 5 | namespace scatter { 6 | SCATTER_API int64_t cuda_version() noexcept; 7 | 8 | namespace detail { 9 | SCATTER_INLINE_VARIABLE int64_t _cuda_version = cuda_version(); 10 | } // namespace detail 11 | } // namespace scatter 12 | 13 | SCATTER_API torch::Tensor 14 | scatter_sum(torch::Tensor src, torch::Tensor index, int64_t dim, 15 | std::optional optional_out, 16 | std::optional dim_size); 17 | 18 | SCATTER_API torch::Tensor 19 | scatter_mul(torch::Tensor src, torch::Tensor index, int64_t dim, 20 | std::optional optional_out, 21 | std::optional dim_size); 22 | 23 | SCATTER_API torch::Tensor 24 | scatter_mean(torch::Tensor src, torch::Tensor index, int64_t dim, 25 | std::optional optional_out, 26 | std::optional dim_size); 27 | 28 | SCATTER_API std::tuple 29 | 
scatter_min(torch::Tensor src, torch::Tensor index, int64_t dim, 30 | std::optional optional_out, 31 | std::optional dim_size); 32 | 33 | SCATTER_API std::tuple 34 | scatter_max(torch::Tensor src, torch::Tensor index, int64_t dim, 35 | std::optional optional_out, 36 | std::optional dim_size); 37 | 38 | SCATTER_API torch::Tensor 39 | segment_sum_coo(torch::Tensor src, torch::Tensor index, 40 | std::optional optional_out, 41 | std::optional dim_size); 42 | 43 | SCATTER_API torch::Tensor 44 | segment_mean_coo(torch::Tensor src, torch::Tensor index, 45 | std::optional optional_out, 46 | std::optional dim_size); 47 | 48 | SCATTER_API std::tuple 49 | segment_min_coo(torch::Tensor src, torch::Tensor index, 50 | std::optional optional_out, 51 | std::optional dim_size); 52 | 53 | SCATTER_API std::tuple 54 | segment_max_coo(torch::Tensor src, torch::Tensor index, 55 | std::optional optional_out, 56 | std::optional dim_size); 57 | 58 | SCATTER_API torch::Tensor 59 | gather_coo(torch::Tensor src, torch::Tensor index, 60 | std::optional optional_out); 61 | 62 | SCATTER_API torch::Tensor 63 | segment_sum_csr(torch::Tensor src, torch::Tensor indptr, 64 | std::optional optional_out); 65 | 66 | SCATTER_API torch::Tensor 67 | segment_mean_csr(torch::Tensor src, torch::Tensor indptr, 68 | std::optional optional_out); 69 | 70 | SCATTER_API std::tuple 71 | segment_min_csr(torch::Tensor src, torch::Tensor indptr, 72 | std::optional optional_out); 73 | 74 | SCATTER_API std::tuple 75 | segment_max_csr(torch::Tensor src, torch::Tensor indptr, 76 | std::optional optional_out); 77 | 78 | SCATTER_API torch::Tensor 79 | gather_csr(torch::Tensor src, torch::Tensor indptr, 80 | std::optional optional_out); 81 | -------------------------------------------------------------------------------- /csrc/segment_coo.cpp: -------------------------------------------------------------------------------- 1 | #ifdef WITH_PYTHON 2 | #include 3 | #endif 4 | 5 | #include 6 | 7 | #include "cpu/segment_coo_cpu.h" 
8 | #include "macros.h" 9 | #include "utils.h" 10 | 11 | #ifdef WITH_CUDA 12 | #include "cuda/segment_coo_cuda.h" 13 | #endif 14 | 15 | #ifdef _WIN32 16 | #ifdef WITH_PYTHON 17 | #ifdef WITH_CUDA 18 | PyMODINIT_FUNC PyInit__segment_coo_cuda(void) { return NULL; } 19 | #else 20 | PyMODINIT_FUNC PyInit__segment_coo_cpu(void) { return NULL; } 21 | #endif 22 | #endif 23 | #endif 24 | 25 | std::tuple> 26 | segment_coo_fw(torch::Tensor src, torch::Tensor index, 27 | std::optional optional_out, 28 | std::optional dim_size, std::string reduce) { 29 | if (src.device().is_cuda()) { 30 | #ifdef WITH_CUDA 31 | return segment_coo_cuda(src, index, optional_out, dim_size, reduce); 32 | #else 33 | AT_ERROR("Not compiled with CUDA support"); 34 | #endif 35 | } else { 36 | return segment_coo_cpu(src, index, optional_out, dim_size, reduce); 37 | } 38 | } 39 | 40 | torch::Tensor gather_coo_fw(torch::Tensor src, torch::Tensor index, 41 | std::optional optional_out) { 42 | if (src.device().is_cuda()) { 43 | #ifdef WITH_CUDA 44 | return gather_coo_cuda(src, index, optional_out); 45 | #else 46 | AT_ERROR("Not compiled with CUDA support"); 47 | #endif 48 | } else { 49 | return gather_coo_cpu(src, index, optional_out); 50 | } 51 | } 52 | 53 | using torch::autograd::AutogradContext; 54 | using torch::autograd::Variable; 55 | using torch::autograd::variable_list; 56 | 57 | class SegmentSumCOO : public torch::autograd::Function { 58 | public: 59 | static variable_list forward(AutogradContext *ctx, Variable src, 60 | Variable index, 61 | std::optional optional_out, 62 | std::optional dim_size) { 63 | ctx->saved_data["src_shape"] = src.sizes(); 64 | auto result = segment_coo_fw(src, index, optional_out, dim_size, "sum"); 65 | auto out = std::get<0>(result); 66 | ctx->save_for_backward({index}); 67 | if (optional_out.has_value()) 68 | ctx->mark_dirty({optional_out.value()}); 69 | return {out}; 70 | } 71 | 72 | static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { 73 | 
auto grad_out = grad_outs[0]; 74 | auto saved = ctx->get_saved_variables(); 75 | auto index = saved[0]; 76 | auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList()); 77 | auto grad_in = torch::empty(src_shape, grad_out.options()); 78 | gather_coo_fw(grad_out, index, grad_in); 79 | return {grad_in, Variable(), Variable(), Variable()}; 80 | } 81 | }; 82 | 83 | class SegmentMeanCOO : public torch::autograd::Function { 84 | public: 85 | static variable_list forward(AutogradContext *ctx, Variable src, 86 | Variable index, 87 | std::optional optional_out, 88 | std::optional dim_size) { 89 | ctx->saved_data["src_shape"] = src.sizes(); 90 | auto result = segment_coo_fw(src, index, optional_out, dim_size, "mean"); 91 | auto out = std::get<0>(result); 92 | auto count = std::get<1>(result).value(); 93 | ctx->save_for_backward({index, count}); 94 | if (optional_out.has_value()) 95 | ctx->mark_dirty({optional_out.value()}); 96 | return {out}; 97 | } 98 | 99 | static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { 100 | auto grad_out = grad_outs[0]; 101 | auto saved = ctx->get_saved_variables(); 102 | auto index = saved[0]; 103 | auto count = saved[1]; 104 | auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList()); 105 | auto grad_in = torch::empty(src_shape, grad_out.options()); 106 | gather_coo_fw(grad_out, index, grad_in); 107 | count = gather_coo_fw(count, index, std::nullopt); 108 | for (auto i = 0; i < grad_out.dim() - index.dim(); i++) 109 | count = count.unsqueeze(-1); 110 | grad_in.true_divide_(count); 111 | return {grad_in, Variable(), Variable(), Variable()}; 112 | } 113 | }; 114 | 115 | class SegmentMinCOO : public torch::autograd::Function { 116 | public: 117 | static variable_list forward(AutogradContext *ctx, Variable src, 118 | Variable index, 119 | std::optional optional_out, 120 | std::optional dim_size) { 121 | ctx->saved_data["src_shape"] = src.sizes(); 122 | auto result = segment_coo_fw(src, index, optional_out, 
dim_size, "min"); 123 | auto out = std::get<0>(result); 124 | auto arg_out = std::get<1>(result).value(); 125 | ctx->save_for_backward({index, arg_out}); 126 | ctx->mark_non_differentiable({arg_out}); 127 | if (optional_out.has_value()) 128 | ctx->mark_dirty({optional_out.value()}); 129 | return {out, arg_out}; 130 | } 131 | 132 | static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { 133 | auto grad_out = grad_outs[0]; 134 | auto saved = ctx->get_saved_variables(); 135 | auto index = saved[0]; 136 | auto arg_out = saved[1]; 137 | auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList()); 138 | src_shape[index.dim() - 1] += 1; 139 | auto grad_in = torch::zeros(src_shape, grad_out.options()); 140 | grad_in.scatter_(index.dim() - 1, arg_out, grad_out); 141 | grad_in = 142 | grad_in.narrow(index.dim() - 1, 0, src_shape[index.dim() - 1] - 1); 143 | return {grad_in, Variable(), Variable(), Variable()}; 144 | } 145 | }; 146 | 147 | class SegmentMaxCOO : public torch::autograd::Function { 148 | public: 149 | static variable_list forward(AutogradContext *ctx, Variable src, 150 | Variable index, 151 | std::optional optional_out, 152 | std::optional dim_size) { 153 | ctx->saved_data["src_shape"] = src.sizes(); 154 | auto result = segment_coo_fw(src, index, optional_out, dim_size, "max"); 155 | auto out = std::get<0>(result); 156 | auto arg_out = std::get<1>(result).value(); 157 | ctx->save_for_backward({index, arg_out}); 158 | ctx->mark_non_differentiable({arg_out}); 159 | if (optional_out.has_value()) 160 | ctx->mark_dirty({optional_out.value()}); 161 | return {out, arg_out}; 162 | } 163 | 164 | static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { 165 | auto grad_out = grad_outs[0]; 166 | auto saved = ctx->get_saved_variables(); 167 | auto index = saved[0]; 168 | auto arg_out = saved[1]; 169 | auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList()); 170 | src_shape[index.dim() - 1] += 1; 171 | auto 
grad_in = torch::zeros(src_shape, grad_out.options()); 172 | grad_in.scatter_(index.dim() - 1, arg_out, grad_out); 173 | grad_in = 174 | grad_in.narrow(index.dim() - 1, 0, src_shape[index.dim() - 1] - 1); 175 | return {grad_in, Variable(), Variable(), Variable()}; 176 | } 177 | }; 178 | 179 | class GatherCOO : public torch::autograd::Function { 180 | public: 181 | static variable_list forward(AutogradContext *ctx, Variable src, 182 | Variable index, 183 | std::optional optional_out) { 184 | ctx->saved_data["src_shape"] = src.sizes(); 185 | auto out = gather_coo_fw(src, index, optional_out); 186 | ctx->save_for_backward({index}); 187 | if (optional_out.has_value()) 188 | ctx->mark_dirty({optional_out.value()}); 189 | return {out}; 190 | } 191 | 192 | static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { 193 | auto grad_out = grad_outs[0]; 194 | auto saved = ctx->get_saved_variables(); 195 | auto index = saved[0]; 196 | auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList()); 197 | 198 | auto grad_in = torch::zeros(src_shape, grad_out.options()); 199 | segment_coo_fw(grad_out, index, grad_in, std::nullopt, "sum"); 200 | return {grad_in, Variable(), Variable()}; 201 | } 202 | }; 203 | 204 | SCATTER_API torch::Tensor 205 | segment_sum_coo(torch::Tensor src, torch::Tensor index, 206 | std::optional optional_out, 207 | std::optional dim_size) { 208 | return SegmentSumCOO::apply(src, index, optional_out, dim_size)[0]; 209 | } 210 | 211 | SCATTER_API torch::Tensor 212 | segment_mean_coo(torch::Tensor src, torch::Tensor index, 213 | std::optional optional_out, 214 | std::optional dim_size) { 215 | return SegmentMeanCOO::apply(src, index, optional_out, dim_size)[0]; 216 | } 217 | 218 | SCATTER_API std::tuple 219 | segment_min_coo(torch::Tensor src, torch::Tensor index, 220 | std::optional optional_out, 221 | std::optional dim_size) { 222 | auto result = SegmentMinCOO::apply(src, index, optional_out, dim_size); 223 | return 
std::make_tuple(result[0], result[1]); 224 | } 225 | 226 | SCATTER_API std::tuple 227 | segment_max_coo(torch::Tensor src, torch::Tensor index, 228 | std::optional optional_out, 229 | std::optional dim_size) { 230 | auto result = SegmentMaxCOO::apply(src, index, optional_out, dim_size); 231 | return std::make_tuple(result[0], result[1]); 232 | } 233 | 234 | SCATTER_API torch::Tensor 235 | gather_coo(torch::Tensor src, torch::Tensor index, 236 | std::optional optional_out) { 237 | return GatherCOO::apply(src, index, optional_out)[0]; 238 | } 239 | 240 | static auto registry = 241 | torch::RegisterOperators() 242 | .op("torch_scatter::segment_sum_coo", &segment_sum_coo) 243 | .op("torch_scatter::segment_mean_coo", &segment_mean_coo) 244 | .op("torch_scatter::segment_min_coo", &segment_min_coo) 245 | .op("torch_scatter::segment_max_coo", &segment_max_coo) 246 | .op("torch_scatter::gather_coo", &gather_coo); 247 | -------------------------------------------------------------------------------- /csrc/segment_csr.cpp: -------------------------------------------------------------------------------- 1 | #ifdef WITH_PYTHON 2 | #include 3 | #endif 4 | 5 | #include 6 | 7 | #include "cpu/segment_csr_cpu.h" 8 | #include "macros.h" 9 | #include "utils.h" 10 | 11 | #ifdef WITH_CUDA 12 | #include "cuda/segment_csr_cuda.h" 13 | #endif 14 | 15 | #ifdef _WIN32 16 | #ifdef WITH_PYTHON 17 | #ifdef WITH_CUDA 18 | PyMODINIT_FUNC PyInit__segment_csr_cuda(void) { return NULL; } 19 | #else 20 | PyMODINIT_FUNC PyInit__segment_csr_cpu(void) { return NULL; } 21 | #endif 22 | #endif 23 | #endif 24 | 25 | std::tuple> 26 | segment_csr_fw(torch::Tensor src, torch::Tensor indptr, 27 | std::optional optional_out, 28 | std::string reduce) { 29 | if (src.device().is_cuda()) { 30 | #ifdef WITH_CUDA 31 | return segment_csr_cuda(src, indptr, optional_out, reduce); 32 | #else 33 | AT_ERROR("Not compiled with CUDA support"); 34 | #endif 35 | } else { 36 | return segment_csr_cpu(src, indptr, optional_out, 
reduce); 37 | } 38 | } 39 | 40 | torch::Tensor gather_csr_fw(torch::Tensor src, torch::Tensor indptr, 41 | std::optional optional_out) { 42 | if (src.device().is_cuda()) { 43 | #ifdef WITH_CUDA 44 | return gather_csr_cuda(src, indptr, optional_out); 45 | #else 46 | AT_ERROR("Not compiled with CUDA support"); 47 | #endif 48 | } else { 49 | return gather_csr_cpu(src, indptr, optional_out); 50 | } 51 | } 52 | 53 | using torch::autograd::AutogradContext; 54 | using torch::autograd::Variable; 55 | using torch::autograd::variable_list; 56 | 57 | class SegmentSumCSR : public torch::autograd::Function { 58 | public: 59 | static variable_list forward(AutogradContext *ctx, Variable src, 60 | Variable indptr, 61 | std::optional optional_out) { 62 | ctx->saved_data["src_shape"] = src.sizes(); 63 | auto out = std::get<0>(segment_csr_fw(src, indptr, optional_out, "sum")); 64 | ctx->save_for_backward({indptr}); 65 | if (optional_out.has_value()) 66 | ctx->mark_dirty({optional_out.value()}); 67 | return {out}; 68 | } 69 | 70 | static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { 71 | auto grad_out = grad_outs[0]; 72 | auto saved = ctx->get_saved_variables(); 73 | auto indptr = saved[0]; 74 | auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList()); 75 | auto grad_in = torch::empty(src_shape, grad_out.options()); 76 | gather_csr_fw(grad_out, indptr, grad_in); 77 | return {grad_in, Variable(), Variable()}; 78 | } 79 | }; 80 | 81 | class SegmentMeanCSR : public torch::autograd::Function { 82 | public: 83 | static variable_list forward(AutogradContext *ctx, Variable src, 84 | Variable indptr, 85 | std::optional optional_out) { 86 | ctx->saved_data["src_shape"] = src.sizes(); 87 | auto out = std::get<0>(segment_csr_fw(src, indptr, optional_out, "mean")); 88 | ctx->save_for_backward({indptr}); 89 | if (optional_out.has_value()) 90 | ctx->mark_dirty({optional_out.value()}); 91 | return {out}; 92 | } 93 | 94 | static variable_list 
backward(AutogradContext *ctx, variable_list grad_outs) { 95 | auto grad_out = grad_outs[0]; 96 | auto saved = ctx->get_saved_variables(); 97 | auto indptr = saved[0]; 98 | auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList()); 99 | auto grad_in = torch::empty(src_shape, grad_out.options()); 100 | if (grad_in.numel() > 0) { 101 | gather_csr_fw(grad_out, indptr, grad_in); 102 | auto indptr1 = indptr.narrow(-1, 0, indptr.size(-1) - 1); 103 | auto indptr2 = indptr.narrow(-1, 1, indptr.size(-1) - 1); 104 | auto count = (indptr2 - indptr1).to(grad_in.options()); 105 | count = gather_csr_fw(count, indptr, std::nullopt); 106 | for (auto i = 0; i < grad_out.dim() - indptr.dim(); i++) 107 | count = count.unsqueeze(-1); 108 | grad_in.true_divide_(count); 109 | } 110 | return {grad_in, Variable(), Variable()}; 111 | } 112 | }; 113 | 114 | class SegmentMinCSR : public torch::autograd::Function { 115 | public: 116 | static variable_list forward(AutogradContext *ctx, Variable src, 117 | Variable indptr, 118 | std::optional optional_out) { 119 | ctx->saved_data["src_shape"] = src.sizes(); 120 | auto result = segment_csr_fw(src, indptr, optional_out, "min"); 121 | auto out = std::get<0>(result); 122 | auto arg_out = std::get<1>(result).value(); 123 | ctx->save_for_backward({indptr, arg_out}); 124 | ctx->mark_non_differentiable({arg_out}); 125 | if (optional_out.has_value()) 126 | ctx->mark_dirty({optional_out.value()}); 127 | return {out, arg_out}; 128 | } 129 | 130 | static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { 131 | auto grad_out = grad_outs[0]; 132 | auto saved = ctx->get_saved_variables(); 133 | auto indptr = saved[0]; 134 | auto arg_out = saved[1]; 135 | auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList()); 136 | src_shape[indptr.dim() - 1] += 1; 137 | auto grad_in = torch::zeros(src_shape, grad_out.options()); 138 | grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out); 139 | grad_in = 140 | 
grad_in.narrow(indptr.dim() - 1, 0, src_shape[indptr.dim() - 1] - 1); 141 | return {grad_in, Variable(), Variable()}; 142 | } 143 | }; 144 | 145 | class SegmentMaxCSR : public torch::autograd::Function { 146 | public: 147 | static variable_list forward(AutogradContext *ctx, Variable src, 148 | Variable indptr, 149 | std::optional optional_out) { 150 | ctx->saved_data["src_shape"] = src.sizes(); 151 | auto result = segment_csr_fw(src, indptr, optional_out, "max"); 152 | auto out = std::get<0>(result); 153 | auto arg_out = std::get<1>(result).value(); 154 | ctx->save_for_backward({indptr, arg_out}); 155 | ctx->mark_non_differentiable({arg_out}); 156 | if (optional_out.has_value()) 157 | ctx->mark_dirty({optional_out.value()}); 158 | return {out, arg_out}; 159 | } 160 | 161 | static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { 162 | auto grad_out = grad_outs[0]; 163 | auto saved = ctx->get_saved_variables(); 164 | auto indptr = saved[0]; 165 | auto arg_out = saved[1]; 166 | auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList()); 167 | src_shape[indptr.dim() - 1] += 1; 168 | auto grad_in = torch::zeros(src_shape, grad_out.options()); 169 | grad_in.scatter_(indptr.dim() - 1, arg_out, grad_out); 170 | grad_in = 171 | grad_in.narrow(indptr.dim() - 1, 0, src_shape[indptr.dim() - 1] - 1); 172 | return {grad_in, Variable(), Variable()}; 173 | } 174 | }; 175 | 176 | class GatherCSR : public torch::autograd::Function { 177 | public: 178 | static variable_list forward(AutogradContext *ctx, Variable src, 179 | Variable indptr, 180 | std::optional optional_out) { 181 | ctx->saved_data["src_shape"] = src.sizes(); 182 | auto out = gather_csr_fw(src, indptr, optional_out); 183 | ctx->save_for_backward({indptr}); 184 | if (optional_out.has_value()) 185 | ctx->mark_dirty({optional_out.value()}); 186 | return {out}; 187 | } 188 | 189 | static variable_list backward(AutogradContext *ctx, variable_list grad_outs) { 190 | auto grad_out = 
grad_outs[0]; 191 | auto saved = ctx->get_saved_variables(); 192 | auto indptr = saved[0]; 193 | auto src_shape = list2vec(ctx->saved_data["src_shape"].toIntList()); 194 | 195 | auto grad_in = torch::empty(src_shape, grad_out.options()); 196 | segment_csr_fw(grad_out, indptr, grad_in, "sum"); 197 | return {grad_in, Variable(), Variable()}; 198 | } 199 | }; 200 | 201 | SCATTER_API torch::Tensor 202 | segment_sum_csr(torch::Tensor src, torch::Tensor indptr, 203 | std::optional optional_out) { 204 | return SegmentSumCSR::apply(src, indptr, optional_out)[0]; 205 | } 206 | 207 | SCATTER_API torch::Tensor 208 | segment_mean_csr(torch::Tensor src, torch::Tensor indptr, 209 | std::optional optional_out) { 210 | return SegmentMeanCSR::apply(src, indptr, optional_out)[0]; 211 | } 212 | 213 | SCATTER_API std::tuple 214 | segment_min_csr(torch::Tensor src, torch::Tensor indptr, 215 | std::optional optional_out) { 216 | auto result = SegmentMinCSR::apply(src, indptr, optional_out); 217 | return std::make_tuple(result[0], result[1]); 218 | } 219 | 220 | SCATTER_API std::tuple 221 | segment_max_csr(torch::Tensor src, torch::Tensor indptr, 222 | std::optional optional_out) { 223 | auto result = SegmentMaxCSR::apply(src, indptr, optional_out); 224 | return std::make_tuple(result[0], result[1]); 225 | } 226 | 227 | SCATTER_API torch::Tensor 228 | gather_csr(torch::Tensor src, torch::Tensor indptr, 229 | std::optional optional_out) { 230 | return GatherCSR::apply(src, indptr, optional_out)[0]; 231 | } 232 | 233 | static auto registry = 234 | torch::RegisterOperators() 235 | .op("torch_scatter::segment_sum_csr", &segment_sum_csr) 236 | .op("torch_scatter::segment_mean_csr", &segment_mean_csr) 237 | .op("torch_scatter::segment_min_csr", &segment_min_csr) 238 | .op("torch_scatter::segment_max_csr", &segment_max_csr) 239 | .op("torch_scatter::gather_csr", &gather_csr); 240 | -------------------------------------------------------------------------------- /csrc/utils.h: 
-------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | 6 | inline std::vector list2vec(const c10::List list) { 7 | std::vector result; 8 | result.reserve(list.size()); 9 | for (size_t i = 0; i < list.size(); i++) 10 | result.push_back(list[i]); 11 | return result; 12 | } 13 | -------------------------------------------------------------------------------- /csrc/version.cpp: -------------------------------------------------------------------------------- 1 | #ifdef WITH_PYTHON 2 | #include 3 | #endif 4 | 5 | #include 6 | #include "scatter.h" 7 | #include "macros.h" 8 | 9 | #ifdef WITH_CUDA 10 | #ifdef USE_ROCM 11 | #include 12 | #else 13 | #include 14 | #endif 15 | #endif 16 | 17 | #ifdef _WIN32 18 | #ifdef WITH_PYTHON 19 | #ifdef WITH_CUDA 20 | PyMODINIT_FUNC PyInit__version_cuda(void) { return NULL; } 21 | #else 22 | PyMODINIT_FUNC PyInit__version_cpu(void) { return NULL; } 23 | #endif 24 | #endif 25 | #endif 26 | 27 | namespace scatter { 28 | SCATTER_API int64_t cuda_version() noexcept { 29 | #ifdef WITH_CUDA 30 | #ifdef USE_ROCM 31 | return HIP_VERSION; 32 | #else 33 | return CUDA_VERSION; 34 | #endif 35 | #else 36 | return -1; 37 | #endif 38 | } 39 | } // namespace scatter 40 | 41 | static auto registry = torch::RegisterOperators().op( 42 | "torch_scatter::cuda_version", [] { return scatter::cuda_version(); }); 43 | -------------------------------------------------------------------------------- /docs/.nojekyll: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rusty1s/pytorch_scatter/38289bfa4dfd58961ef3cdb3c69ee70ce2bc8890/docs/.nojekyll -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | SPHINXBUILD := sphinx-build 2 | SPHINXPROJ := pytorch_scatter 3 | SOURCEDIR := source 4 | BUILDDIR := 
#!/bin/bash
# Render every TikZ figure to PDF with pdflatex, then convert it to SVG.

for name in add sub mul div mean max min std; do
  pdflatex "$name"
  pdf2svg "${name}.pdf" "${name}.svg"
done
\input{template} 9 | -------------------------------------------------------------------------------- /docs/source/_figures/max.tex: -------------------------------------------------------------------------------- 1 | \def\indices{{0, 0, 1, 0, 2, 2, 3, 3}} 2 | \def\inputs{{5, 1, 7, 2, 3, 2, 1, 3}} 3 | \def\outputs{{5, 7, 3, 3}} 4 | \def\colors{{"cyan", "orange", "olive", "magenta"}} 5 | \def\numberInputs{7} 6 | \def\numberOutputs{3} 7 | \def\operation{max} 8 | \input{template} 9 | -------------------------------------------------------------------------------- /docs/source/_figures/mean.tex: -------------------------------------------------------------------------------- 1 | \def\indices{{0, 0, 1, 0, 2, 2, 3, 3}} 2 | \def\inputs{{5, 1, 7, 2, 3, 2, 1, 3}} 3 | \def\outputs{{"$\frac{8}{3}$", "$\frac{7}{1}$", "$\frac{5}{2}$", "$\frac{4}{2}$"}} 4 | \def\colors{{"cyan", "orange", "olive", "magenta"}} 5 | \def\numberInputs{7} 6 | \def\numberOutputs{3} 7 | \def\operation{mean} 8 | \input{template} 9 | -------------------------------------------------------------------------------- /docs/source/_figures/min.tex: -------------------------------------------------------------------------------- 1 | \def\indices{{0, 0, 1, 0, 2, 2, 3, 3}} 2 | \def\inputs{{"-5", "-1", "-7", "-2", "-3", "-2", "-1", "-3"}} 3 | \def\outputs{{"-5", "-7", "-3", "-3"}} 4 | \def\colors{{"cyan", "orange", "olive", "magenta"}} 5 | \def\numberInputs{7} 6 | \def\numberOutputs{3} 7 | \def\operation{min} 8 | \input{template} 9 | -------------------------------------------------------------------------------- /docs/source/_figures/mul.tex: -------------------------------------------------------------------------------- 1 | \def\indices{{0, 0, 1, 0, 2, 2, 3, 3}} 2 | \def\inputs{{5, 1, 7, 2, 3, 2, 1, 3}} 3 | \def\outputs{{10, 7, 6, 3}} 4 | \def\colors{{"cyan", "orange", "olive", "magenta"}} 5 | \def\numberInputs{7} 6 | \def\numberOutputs{3} 7 | \def\operation{mul} 8 | \input{template} 9 | 
\documentclass[class=minimal]{standalone}

\usepackage{tikz}
\usetikzlibrary{shapes.geometric}
% FIX: \tiny is a predefined LaTeX font-size command, so \newcommand on it
% aborts with "Command \tiny already defined"; \renewcommand is required to
% override it with the scalebox wrapper used for the operation labels.
\renewcommand{\tiny}[1]{\scalebox{0.45}{#1}}

\begin{document}

\begin{tikzpicture}

\tikzstyle{title}=[text width=1.1cm, inner sep=0pt]
\tikzstyle{square}=[rectangle, draw, minimum width=0.5cm, minimum height=0.5cm, inner sep=0pt, fill opacity=0.5, text opacity=1]
\tikzstyle{op}=[ellipse, draw, inner sep=-1pt, minimum height=9pt, minimum width=12pt]
\tikzstyle{edge}=[->]
\tikzstyle{round}=[out=-90, in=90, looseness=0.85]

% Top rows: the index tensor and the color-coded input tensor.
\node[title] at (-0.8, 2.2) {index};
\node[title] at (-0.8, 1.5) {input};
\foreach \i in {0,...,\numberInputs} {
  \pgfmathparse{\indices[\i]}\let\idx\pgfmathresult
  \pgfmathparse{\inputs[\i]}\let\in\pgfmathresult
  \pgfmathparse{\colors[\idx]}\let\co\pgfmathresult
  \node[square] (index\i) at (\i * 0.5, 2.2) {\idx};
  \node[square, fill=\co] (input\i) at (\i * 0.5, 1.5) {\in};
  \draw[edge] (index\i) -- (input\i);
}

% Bottom row: one reduction node and output cell per segment.
\node[title] at (-0.8, 0.0) {out};
\foreach \i in {0,...,\numberOutputs} {
  \pgfmathparse{\outputs[\i]}\let\out\pgfmathresult
  \pgfmathparse{\colors[\i]}\let\co\pgfmathresult
  \def \x{(\numberInputs - \numberOutputs) * 0.25 + \i * 0.5}
  \node[op] (op\i) at ({\x}, 0.6) {\tiny{\operation}};
  \node[square, fill=\co] (output\i) at ({\x}, 0.0) {\out};
  \draw[edge] (op\i) -- (output\i);
}

% Edges routing each input into the reduction node of its index.
\foreach \i in {0,...,\numberInputs} {
  \pgfmathparse{\indices[\i]}\let\idx\pgfmathresult
  \draw[edge] (input\i) to[round] (op\idx);
}

\end{tikzpicture}

\end{document}
torch_scatter.__version__ 23 | release = torch_scatter.__version__ 24 | 25 | html_theme = 'sphinx_rtd_theme' 26 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 27 | 28 | doctest_default_flags = doctest.NORMALIZE_WHITESPACE 29 | intersphinx_mapping = {'python': ('https://docs.python.org/', None)} 30 | -------------------------------------------------------------------------------- /docs/source/functions/composite.rst: -------------------------------------------------------------------------------- 1 | composite 2 | ========= 3 | 4 | .. currentmodule:: torch_scatter.composite 5 | 6 | .. automodule:: torch_scatter.composite 7 | :members: 8 | :undoc-members: 9 | -------------------------------------------------------------------------------- /docs/source/functions/scatter.rst: -------------------------------------------------------------------------------- 1 | Scatter 2 | ======= 3 | 4 | .. automodule:: torch_scatter 5 | :noindex: 6 | 7 | .. autofunction:: scatter 8 | -------------------------------------------------------------------------------- /docs/source/functions/segment_coo.rst: -------------------------------------------------------------------------------- 1 | Segment COO 2 | =========== 3 | 4 | .. automodule:: torch_scatter 5 | :noindex: 6 | 7 | .. autofunction:: segment_coo 8 | -------------------------------------------------------------------------------- /docs/source/functions/segment_csr.rst: -------------------------------------------------------------------------------- 1 | Segment CSR 2 | =========== 3 | 4 | .. automodule:: torch_scatter 5 | :noindex: 6 | 7 | .. 
autofunction:: segment_csr 8 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/rusty1s/pytorch_scatter 2 | 3 | PyTorch Scatter Documentation 4 | ============================= 5 | 6 | This package consists of a small extension library of highly optimized sparse update (scatter and segment) operations for the use in `PyTorch `_, which are missing in the main package. 7 | Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor. 8 | Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to these requirements. 9 | 10 | All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable. 11 | 12 | .. toctree:: 13 | :glob: 14 | :maxdepth: 0 15 | :caption: Package reference 16 | 17 | functions/scatter 18 | functions/segment_coo 19 | functions/segment_csr 20 | functions/composite 21 | 22 | Indices and tables 23 | ================== 24 | 25 | * :ref:`genindex` 26 | * :ref:`modindex` 27 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "torch"] 3 | build-backend = "setuptools.build_meta" 4 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | build: 4 | image: latest 5 | 6 | python: 7 | version: 3.8 8 | system_packages: true 9 | install: 10 | - requirements: docs/requirements.txt 11 | - method: setuptools 12 | path: . 
13 | 14 | formats: [] 15 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | long_description=file: README.md 3 | long_description_content_type=text/markdown 4 | 5 | classifiers = 6 | Development Status :: 5 - Production/Stable 7 | License :: OSI Approved :: MIT License 8 | Programming Language :: Python 9 | Programming Language :: Python :: 3.8 10 | Programming Language :: Python :: 3.9 11 | Programming Language :: Python :: 3.10 12 | Programming Language :: Python :: 3.11 13 | Programming Language :: Python :: 3.12 14 | Programming Language :: Python :: 3.13 15 | Programming Language :: Python :: 3 :: Only 16 | 17 | [aliases] 18 | test = pytest 19 | 20 | [tool:pytest] 21 | addopts = --capture=no 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import os.path as osp 4 | import platform 5 | import sys 6 | from itertools import product 7 | 8 | import torch 9 | from setuptools import find_packages, setup 10 | from torch.__config__ import parallel_info 11 | from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, CppExtension, 12 | CUDAExtension) 13 | 14 | __version__ = '2.1.2' 15 | URL = 'https://github.com/rusty1s/pytorch_scatter' 16 | 17 | WITH_CUDA = False 18 | if torch.cuda.is_available(): 19 | WITH_CUDA = CUDA_HOME is not None or torch.version.hip 20 | suffices = ['cpu', 'cuda'] if WITH_CUDA else ['cpu'] 21 | if os.getenv('FORCE_CUDA', '0') == '1': 22 | suffices = ['cuda', 'cpu'] 23 | if os.getenv('FORCE_ONLY_CUDA', '0') == '1': 24 | suffices = ['cuda'] 25 | if os.getenv('FORCE_ONLY_CPU', '0') == '1': 26 | suffices = ['cpu'] 27 | 28 | BUILD_DOCS = os.getenv('BUILD_DOCS', '0') == '1' 29 | WITH_SYMBOLS = os.getenv('WITH_SYMBOLS', '0') == '1' 30 | 31 | 32 
def get_extensions():
    """Build the list of C++/CUDA extension modules to compile.

    One extension is created per top-level ``csrc/*.cpp`` file and per
    device suffix in the module-level ``suffices`` list (``'cpu'`` /
    ``'cuda'``), pulling in the matching ``csrc/cpu``/``csrc/cuda``
    implementation file when it exists.

    Returns:
        list: ``CppExtension``/``CUDAExtension`` instances for ``setup()``.
    """
    extensions = []

    extensions_dir = osp.join('csrc')
    main_files = glob.glob(osp.join(extensions_dir, '*.cpp'))
    # remove generated 'hip' files, in case of rebuilds
    main_files = [path for path in main_files if 'hip' not in path]

    for main, suffix in product(main_files, suffices):
        define_macros = [('WITH_PYTHON', None)]
        undef_macros = []

        if sys.platform == 'win32':
            define_macros += [('torchscatter_EXPORTS', None)]

        extra_compile_args = {'cxx': ['-O3']}
        if os.name != 'nt':  # Not on Windows:
            extra_compile_args['cxx'] += ['-Wno-sign-compare']
        extra_link_args = [] if WITH_SYMBOLS else ['-s']

        info = parallel_info()
        if ('backend: OpenMP' in info and 'OpenMP not found' not in info
                and sys.platform != 'darwin'):
            extra_compile_args['cxx'] += ['-DAT_PARALLEL_OPENMP']
            if sys.platform == 'win32':
                extra_compile_args['cxx'] += ['/openmp']
            else:
                extra_compile_args['cxx'] += ['-fopenmp']
        else:
            print('Compiling without OpenMP...')

        # Compile for mac arm64
        if sys.platform == 'darwin':
            extra_compile_args['cxx'] += ['-D_LIBCPP_DISABLE_AVAILABILITY']
            # BUG FIX: `platform.machine` is a function; the original compared
            # the function object itself to 'arm64', which is always False, so
            # the arm64 flags were never added. It must be called.
            if platform.machine() == 'arm64':
                extra_compile_args['cxx'] += ['-arch', 'arm64']
                extra_link_args += ['-arch', 'arm64']

        if suffix == 'cuda':
            define_macros += [('WITH_CUDA', None)]
            nvcc_flags = os.getenv('NVCC_FLAGS', '')
            nvcc_flags = [] if nvcc_flags == '' else nvcc_flags.split(' ')
            nvcc_flags += ['-O3']
            if torch.version.hip:
                # USE_ROCM was added to later versions of PyTorch.
                # Define here to support older PyTorch versions as well:
                define_macros += [('USE_ROCM', None)]
                undef_macros += ['__HIP_NO_HALF_CONVERSIONS__']
            else:
                nvcc_flags += ['--expt-relaxed-constexpr']
            extra_compile_args['nvcc'] = nvcc_flags

        # 'csrc/foo.cpp' -> extension name 'foo'.
        name = main.split(os.sep)[-1][:-4]
        sources = [main]

        path = osp.join(extensions_dir, 'cpu', f'{name}_cpu.cpp')
        if osp.exists(path):
            sources += [path]

        path = osp.join(extensions_dir, 'cuda', f'{name}_cuda.cu')
        if suffix == 'cuda' and osp.exists(path):
            sources += [path]

        Extension = CppExtension if suffix == 'cpu' else CUDAExtension
        extension = Extension(
            f'torch_scatter._{name}_{suffix}',
            sources,
            include_dirs=[extensions_dir],
            define_macros=define_macros,
            undef_macros=undef_macros,
            extra_compile_args=extra_compile_args,
            extra_link_args=extra_link_args,
        )
        extensions += [extension]

    return extensions
def test_logsumexp():
    """Compare scatter_logsumexp against torch.logsumexp per index group."""
    src = torch.tensor([
        0.5,
        0.5,
        0.0,
        -2.1,
        3.2,
        7.0,
        -1.0,
        -100.0,
    ])
    src.requires_grad_()
    index = torch.tensor([0, 0, 1, 1, 1, 2, 4, 4])
    group_sizes = [2, 3, 1, 0, 2]

    result = scatter_logsumexp(src, index)

    for group, value in zip(src.split(group_sizes), result.unbind()):
        if group.numel() == 0:
            # Index 3 never occurs, so its slot keeps the default of zero.
            assert value.item() == 0.0
        else:
            assert value.tolist() == torch.logsumexp(group, dim=0).tolist()

    result.backward(torch.randn_like(result))

    jit = torch.jit.script(scatter_logsumexp)
    assert jit(src, index).tolist() == result.tolist()
| out4 = torch.softmax(torch.tensor([-1, float('-inf')]), dim=-1) 16 | 17 | expected = torch.stack([ 18 | out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1] 19 | ], dim=0) 20 | 21 | assert torch.allclose(out, expected) 22 | 23 | out.backward(torch.randn_like(out)) 24 | 25 | jit = torch.jit.script(scatter_softmax) 26 | assert jit(src, index).tolist() == out.tolist() 27 | 28 | 29 | def test_log_softmax(): 30 | src = torch.tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')]) 31 | src.requires_grad_() 32 | index = torch.tensor([0, 1, 0, 1, 1, 2, 4, 4]) 33 | 34 | out = scatter_log_softmax(src, index) 35 | 36 | out0 = torch.log_softmax(torch.tensor([0.2, 0.2]), dim=-1) 37 | out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2]), dim=-1) 38 | out2 = torch.log_softmax(torch.tensor([7], dtype=torch.float), dim=-1) 39 | out4 = torch.log_softmax(torch.tensor([-1, float('-inf')]), dim=-1) 40 | 41 | expected = torch.stack([ 42 | out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1] 43 | ], dim=0) 44 | 45 | assert torch.allclose(out, expected) 46 | 47 | out.backward(torch.randn_like(out)) 48 | 49 | jit = torch.jit.script(scatter_log_softmax) 50 | assert jit(src, index).tolist() == out.tolist() 51 | -------------------------------------------------------------------------------- /test/composite/test_std.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch_scatter import scatter_std 3 | 4 | 5 | def test_std(): 6 | src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]], dtype=torch.float) 7 | src.requires_grad_() 8 | index = torch.tensor([[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.long) 9 | 10 | out = scatter_std(src, index, dim=-1, unbiased=True) 11 | std = src.std(dim=-1, unbiased=True)[0] 12 | expected = torch.tensor([[std, 0], [0, std]]) 13 | assert torch.allclose(out, expected) 14 | 15 | out.backward(torch.randn_like(out)) 16 | 17 | jit = torch.jit.script(scatter_std) 18 | 
@pytest.mark.parametrize('reduce,device', product(reductions, devices))
def test_broadcasting(reduce, device):
    """Scatter along dim=2 must broadcast `index` against `src` for every
    reduction and device, preserving the full output shape."""
    B, C, H, W = (4, 3, 8, 8)

    # 1-D index of length H, broadcast over the leading/trailing dims.
    src = torch.randn((B, C, H, W), device=device)
    index = torch.randint(0, H, (H, )).to(device, torch.long)
    out = scatter(src, index, dim=2, dim_size=H, reduce=reduce)
    assert out.size() == (B, C, H, W)

    # Index with an explicitly broadcastable shape (B, 1, H, W).
    src = torch.randn((B, C, H, W), device=device)
    index = torch.randint(0, H, (B, 1, H, W)).to(device, torch.long)
    out = scatter(src, index, dim=2, dim_size=H, reduce=reduce)
    assert out.size() == (B, C, H, W)

    # NOTE(review): this case is an exact duplicate of the first one —
    # possibly a different index shape was intended; confirm against history.
    src = torch.randn((B, C, H, W), device=device)
    index = torch.randint(0, H, (H, )).to(device, torch.long)
    out = scatter(src, index, dim=2, dim_size=H, reduce=reduce)
    assert out.size() == (B, C, H, W)
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_forward(test, dtype, device):
    """Both gather variants must reproduce the precomputed expected output."""
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test['expected'], dtype, device)

    # The CSR variant consumes compressed row pointers ...
    assert torch.all(gather_csr(src, indptr) == expected)
    # ... while the COO variant consumes the expanded index vector.
    assert torch.all(gather_coo(src, index) == expected)
@pytest.mark.parametrize('test,dtype,device', product(tests, dtypes, devices))
def test_non_contiguous(test, dtype, device):
    """Gather ops must handle non-contiguous input tensors correctly."""
    def as_non_contiguous(t):
        # Round-trip transpose keeps the values but breaks contiguity
        # for tensors with more than one dimension.
        if t.dim() > 1:
            t = t.transpose(0, 1).contiguous().transpose(0, 1)
        return t

    src = as_non_contiguous(tensor(test['src'], dtype, device))
    index = as_non_contiguous(tensor(test['index'], torch.long, device))
    indptr = as_non_contiguous(tensor(test['indptr'], torch.long, device))
    expected = tensor(test['expected'], dtype, device)

    assert torch.all(gather_csr(src, indptr) == expected)
    assert torch.all(gather_coo(src, index) == expected)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='No multiple GPUS')
@pytest.mark.parametrize('test,reduce,dtype', product(tests, reductions,
                                                      dtypes))
def test_forward(test, reduce, dtype):
    """All three reduction front-ends must agree on a non-default GPU."""
    device = torch.device('cuda:1')  # deliberately not the default device
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    dim = test['dim']
    expected = tensor(test[reduce], dtype, device)

    # scatter, segment_coo and segment_csr must all produce the fixture.
    results = [
        torch_scatter.scatter(src, index, dim, reduce=reduce),
        torch_scatter.segment_coo(src, index, reduce=reduce),
        torch_scatter.segment_csr(src, indptr, reduce=reduce),
    ]
    for out in results:
        assert torch.all(out == expected)
'arg_min': [[0, 0], [1, 1], [6, 6], [5, 5]], 35 | 'max': [[3, 4], [9, 10], [0, 0], [11, 12]], 36 | 'arg_max': [[2, 2], [4, 4], [6, 6], [5, 5]], 37 | }, 38 | { 39 | 'src': [[1, 5, 3, 7, 9, 11], [2, 4, 8, 6, 10, 12]], 40 | 'index': [[0, 1, 0, 1, 1, 3], [0, 0, 1, 0, 1, 2]], 41 | 'dim': 1, 42 | 'sum': [[4, 21, 0, 11], [12, 18, 12, 0]], 43 | 'add': [[4, 21, 0, 11], [12, 18, 12, 0]], 44 | 'mul': [[1 * 3, 5 * 7 * 9, 1, 11], [2 * 4 * 6, 8 * 10, 12, 1]], 45 | 'mean': [[2, 7, 0, 11], [4, 9, 12, 0]], 46 | 'min': [[1, 5, 0, 11], [2, 8, 12, 0]], 47 | 'arg_min': [[0, 1, 6, 5], [0, 2, 5, 6]], 48 | 'max': [[3, 9, 0, 11], [6, 10, 12, 0]], 49 | 'arg_max': [[2, 4, 6, 5], [3, 4, 5, 6]], 50 | }, 51 | { 52 | 'src': [[[1, 2], [5, 6], [3, 4]], [[10, 11], [7, 9], [12, 13]]], 53 | 'index': [[0, 1, 0], [2, 0, 2]], 54 | 'dim': 1, 55 | 'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], 56 | 'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], 57 | 'mul': [[[3, 8], [5, 6], [1, 1]], [[7, 9], [1, 1], [120, 11 * 13]]], 58 | 'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]], 59 | 'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]], 60 | 'arg_min': [[[0, 0], [1, 1], [3, 3]], [[1, 1], [3, 3], [0, 0]]], 61 | 'max': [[[3, 4], [5, 6], [0, 0]], [[7, 9], [0, 0], [12, 13]]], 62 | 'arg_max': [[[2, 2], [1, 1], [3, 3]], [[1, 1], [3, 3], [2, 2]]], 63 | }, 64 | { 65 | 'src': [[1, 3], [2, 4]], 66 | 'index': [[0, 0], [0, 0]], 67 | 'dim': 1, 68 | 'sum': [[4], [6]], 69 | 'add': [[4], [6]], 70 | 'mul': [[3], [8]], 71 | 'mean': [[2], [3]], 72 | 'min': [[1], [2]], 73 | 'arg_min': [[0], [0]], 74 | 'max': [[3], [4]], 75 | 'arg_max': [[1], [1]], 76 | }, 77 | { 78 | 'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]], 79 | 'index': [[0, 0], [0, 0]], 80 | 'dim': 1, 81 | 'sum': [[[4, 4]], [[6, 6]]], 82 | 'add': [[[4, 4]], [[6, 6]]], 83 | 'mul': [[[3, 3]], [[8, 8]]], 84 | 'mean': [[[2, 2]], [[3, 3]]], 85 | 'min': [[[1, 1]], [[2, 2]]], 86 | 'arg_min': [[[0, 0]], [[0, 0]]], 87 | 'max': 
[[[3, 3]], [[4, 4]]], 88 | 'arg_max': [[[1, 1]], [[1, 1]]], 89 | }, 90 | ] 91 | 92 | 93 | @pytest.mark.parametrize('test,reduce,dtype,device', 94 | product(tests, reductions, dtypes, devices)) 95 | def test_forward(test, reduce, dtype, device): 96 | src = tensor(test['src'], dtype, device) 97 | index = tensor(test['index'], torch.long, device) 98 | dim = test['dim'] 99 | expected = tensor(test[reduce], dtype, device) 100 | 101 | fn = getattr(torch_scatter, 'scatter_' + reduce) 102 | jit = torch.jit.script(fn) 103 | out1 = fn(src, index, dim) 104 | out2 = jit(src, index, dim) 105 | if isinstance(out1, tuple): 106 | out1, arg_out1 = out1 107 | out2, arg_out2 = out2 108 | arg_expected = tensor(test['arg_' + reduce], torch.long, device) 109 | assert torch.all(arg_out1 == arg_expected) 110 | assert arg_out1.tolist() == arg_out1.tolist() 111 | assert torch.all(out1 == expected) 112 | assert out1.tolist() == out2.tolist() 113 | 114 | 115 | @pytest.mark.parametrize('test,reduce,device', 116 | product(tests, reductions, devices)) 117 | def test_backward(test, reduce, device): 118 | src = tensor(test['src'], torch.double, device) 119 | src.requires_grad_() 120 | index = tensor(test['index'], torch.long, device) 121 | dim = test['dim'] 122 | 123 | assert gradcheck(torch_scatter.scatter, 124 | (src, index, dim, None, None, reduce)) 125 | 126 | 127 | @pytest.mark.parametrize('test,reduce,dtype,device', 128 | product(tests, reductions, dtypes, devices)) 129 | def test_out(test, reduce, dtype, device): 130 | src = tensor(test['src'], dtype, device) 131 | index = tensor(test['index'], torch.long, device) 132 | dim = test['dim'] 133 | expected = tensor(test[reduce], dtype, device) 134 | 135 | out = torch.full_like(expected, -2) 136 | 137 | getattr(torch_scatter, 'scatter_' + reduce)(src, index, dim, out) 138 | 139 | if reduce == 'sum' or reduce == 'add': 140 | expected = expected - 2 141 | elif reduce == 'mul': 142 | expected = out # We can not really test this here. 
@pytest.mark.parametrize('test,reduce,dtype,device',
                         product(tests, reductions, dtypes, devices))
def test_non_contiguous(test, reduce, dtype, device):
    """Scatter ops must handle non-contiguous input tensors correctly."""
    def as_non_contiguous(t):
        # Round-trip transpose keeps the values but breaks contiguity
        # for tensors with more than one dimension.
        if t.dim() > 1:
            t = t.transpose(0, 1).contiguous().transpose(0, 1)
        return t

    src = as_non_contiguous(tensor(test['src'], dtype, device))
    index = as_non_contiguous(tensor(test['index'], torch.long, device))
    dim = test['dim']
    expected = tensor(test[reduce], dtype, device)

    out = getattr(torch_scatter, 'scatter_' + reduce)(src, index, dim)
    if isinstance(out, tuple):
        # min/max additionally return the argmin/argmax indices.
        out, arg_out = out
        arg_expected = tensor(test['arg_' + reduce], torch.long, device)
        assert torch.all(arg_out == arg_expected)
    assert torch.all(out == expected)
6], 26 | 'sum': [[4, 6], [21, 24], [0, 0], [11, 12]], 27 | 'add': [[4, 6], [21, 24], [0, 0], [11, 12]], 28 | 'mean': [[2, 3], [7, 8], [0, 0], [11, 12]], 29 | 'min': [[1, 2], [5, 6], [0, 0], [11, 12]], 30 | 'arg_min': [[0, 0], [2, 2], [6, 6], [5, 5]], 31 | 'max': [[3, 4], [9, 10], [0, 0], [11, 12]], 32 | 'arg_max': [[1, 1], [4, 4], [6, 6], [5, 5]], 33 | }, 34 | { 35 | 'src': [[1, 3, 5, 7, 9, 11], [2, 4, 6, 8, 10, 12]], 36 | 'index': [[0, 0, 1, 1, 1, 3], [0, 0, 0, 1, 1, 2]], 37 | 'indptr': [[0, 2, 5, 5, 6], [0, 3, 5, 6, 6]], 38 | 'sum': [[4, 21, 0, 11], [12, 18, 12, 0]], 39 | 'add': [[4, 21, 0, 11], [12, 18, 12, 0]], 40 | 'mean': [[2, 7, 0, 11], [4, 9, 12, 0]], 41 | 'min': [[1, 5, 0, 11], [2, 8, 12, 0]], 42 | 'arg_min': [[0, 2, 6, 5], [0, 3, 5, 6]], 43 | 'max': [[3, 9, 0, 11], [6, 10, 12, 0]], 44 | 'arg_max': [[1, 4, 6, 5], [2, 4, 5, 6]], 45 | }, 46 | { 47 | 'src': [[[1, 2], [3, 4], [5, 6]], [[7, 9], [10, 11], [12, 13]]], 48 | 'index': [[0, 0, 1], [0, 2, 2]], 49 | 'indptr': [[0, 2, 3, 3], [0, 1, 1, 3]], 50 | 'sum': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], 51 | 'add': [[[4, 6], [5, 6], [0, 0]], [[7, 9], [0, 0], [22, 24]]], 52 | 'mean': [[[2, 3], [5, 6], [0, 0]], [[7, 9], [0, 0], [11, 12]]], 53 | 'min': [[[1, 2], [5, 6], [0, 0]], [[7, 9], [0, 0], [10, 11]]], 54 | 'arg_min': [[[0, 0], [2, 2], [3, 3]], [[0, 0], [3, 3], [1, 1]]], 55 | 'max': [[[3, 4], [5, 6], [0, 0]], [[7, 9], [0, 0], [12, 13]]], 56 | 'arg_max': [[[1, 1], [2, 2], [3, 3]], [[0, 0], [3, 3], [2, 2]]], 57 | }, 58 | { 59 | 'src': [[1, 3], [2, 4]], 60 | 'index': [[0, 0], [0, 0]], 61 | 'indptr': [[0, 2], [0, 2]], 62 | 'sum': [[4], [6]], 63 | 'add': [[4], [6]], 64 | 'mean': [[2], [3]], 65 | 'min': [[1], [2]], 66 | 'arg_min': [[0], [0]], 67 | 'max': [[3], [4]], 68 | 'arg_max': [[1], [1]], 69 | }, 70 | { 71 | 'src': [[[1, 1], [3, 3]], [[2, 2], [4, 4]]], 72 | 'index': [[0, 0], [0, 0]], 73 | 'indptr': [[0, 2], [0, 2]], 74 | 'sum': [[[4, 4]], [[6, 6]]], 75 | 'add': [[[4, 4]], [[6, 6]]], 76 | 'mean': 
@pytest.mark.parametrize('test,reduce,dtype,device',
                         product(tests, reductions, dtypes, devices))
def test_forward(test, reduce, dtype, device):
    """Eager and scripted segment ops (CSR and COO) must match the fixtures."""
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test[reduce], dtype, device)

    # Exercise the CSR variant (indptr) and the COO variant (index) with
    # identical checks: eager vs. TorchScript, values and arg-outputs.
    for suffix, arg in [('_csr', indptr), ('_coo', index)]:
        fn = getattr(torch_scatter, 'segment_' + reduce + suffix)
        jit = torch.jit.script(fn)
        out1 = fn(src, arg)
        out2 = jit(src, arg)
        if isinstance(out1, tuple):
            # min/max additionally return the argmin/argmax indices.
            out1, arg_out1 = out1
            out2, arg_out2 = out2
            arg_expected = tensor(test['arg_' + reduce], torch.long, device)
            assert torch.all(arg_out1 == arg_expected)
            assert arg_out1.tolist() == arg_out2.tolist()
        assert torch.all(out1 == expected)
        assert out1.tolist() == out2.tolist()
@pytest.mark.parametrize('test,reduce,dtype,device',
                         product(tests, reductions, dtypes, devices))
def test_out(test, reduce, dtype, device):
    """Segment ops must write into a caller-provided `out` tensor."""
    src = tensor(test['src'], dtype, device)
    index = tensor(test['index'], torch.long, device)
    indptr = tensor(test['indptr'], torch.long, device)
    expected = tensor(test[reduce], dtype, device)

    out = torch.full_like(expected, -2)

    # The CSR variant overwrites every output slot, so the -2 filler
    # must vanish completely.
    getattr(torch_scatter, 'segment_' + reduce + '_csr')(src, indptr, out)
    assert torch.all(out == expected)

    out.fill_(-2)

    # The COO variant accumulates into `out`, so adjust the expectation
    # per reduction to account for the pre-existing -2 entries.
    getattr(torch_scatter, 'segment_' + reduce + '_coo')(src, index, out)

    if reduce in ['sum', 'add']:
        expected = expected - 2
    elif reduce == 'mean':
        expected = out  # We can not really test this here.
    elif reduce == 'min':
        expected = expected.fill_(-2)
    elif reduce == 'max':
        expected[expected == 0] = -2
    else:
        raise ValueError

    assert torch.all(out == expected)
@pytest.mark.parametrize('reduce,dtype,device',
                         product(reductions, grad_dtypes, devices))
def test_zero_elements(reduce, dtype, device):
    """Every op must survive fully empty inputs, forward and backward."""
    x = torch.randn(0, 0, 0, 16, dtype=dtype, device=device,
                    requires_grad=True)
    index = tensor([], torch.long, device)
    indptr = tensor([], torch.long, device)

    def check(out):
        # Backward on an empty tensor must be a no-op rather than a crash.
        out.backward(torch.randn_like(out))
        assert out.size() == (0, 0, 0, 16)

    check(scatter(x, index, dim=0, dim_size=0, reduce=reduce))
    check(segment_coo(x, index, dim_size=0, reduce=reduce))
    check(gather_coo(x, index))
    check(segment_csr(x, indptr, reduce=reduce))
    check(gather_csr(x, indptr))
-------------------------------------------------------------------------------- 1 | import importlib 2 | import os 3 | import os.path as osp 4 | 5 | import torch 6 | 7 | __version__ = '2.1.2' 8 | 9 | for library in ['_version', '_scatter', '_segment_csr', '_segment_coo']: 10 | cuda_spec = importlib.machinery.PathFinder().find_spec( 11 | f'{library}_cuda', [osp.dirname(__file__)]) 12 | cpu_spec = importlib.machinery.PathFinder().find_spec( 13 | f'{library}_cpu', [osp.dirname(__file__)]) 14 | spec = cuda_spec or cpu_spec 15 | if spec is not None: 16 | torch.ops.load_library(spec.origin) 17 | elif os.getenv('BUILD_DOCS', '0') != '1': # pragma: no cover 18 | raise ImportError(f"Could not find module '{library}_cpu' in " 19 | f"{osp.dirname(__file__)}") 20 | else: # pragma: no cover 21 | from .placeholder import cuda_version_placeholder 22 | torch.ops.torch_scatter.cuda_version = cuda_version_placeholder 23 | 24 | from .placeholder import scatter_placeholder 25 | torch.ops.torch_scatter.scatter_mul = scatter_placeholder 26 | 27 | from .placeholder import scatter_arg_placeholder 28 | torch.ops.torch_scatter.scatter_min = scatter_arg_placeholder 29 | torch.ops.torch_scatter.scatter_max = scatter_arg_placeholder 30 | 31 | from .placeholder import (gather_csr_placeholder, 32 | segment_csr_arg_placeholder, 33 | segment_csr_placeholder) 34 | torch.ops.torch_scatter.segment_sum_csr = segment_csr_placeholder 35 | torch.ops.torch_scatter.segment_mean_csr = segment_csr_placeholder 36 | torch.ops.torch_scatter.segment_min_csr = segment_csr_arg_placeholder 37 | torch.ops.torch_scatter.segment_max_csr = segment_csr_arg_placeholder 38 | torch.ops.torch_scatter.gather_csr = gather_csr_placeholder 39 | 40 | from .placeholder import (gather_coo_placeholder, 41 | segment_coo_arg_placeholder, 42 | segment_coo_placeholder) 43 | torch.ops.torch_scatter.segment_sum_coo = segment_coo_placeholder 44 | torch.ops.torch_scatter.segment_mean_coo = segment_coo_placeholder 45 | 
torch.ops.torch_scatter.segment_min_coo = segment_coo_arg_placeholder 46 | torch.ops.torch_scatter.segment_max_coo = segment_coo_arg_placeholder 47 | torch.ops.torch_scatter.gather_coo = gather_coo_placeholder 48 | 49 | cuda_version = torch.ops.torch_scatter.cuda_version() 50 | is_not_hip = torch.version.hip is None 51 | is_cuda = torch.version.cuda is not None 52 | if is_not_hip and is_cuda and cuda_version != -1: # pragma: no cover 53 | if cuda_version < 10000: 54 | major, minor = int(str(cuda_version)[0]), int(str(cuda_version)[2]) 55 | else: 56 | major, minor = int(str(cuda_version)[0:2]), int(str(cuda_version)[3]) 57 | t_major, t_minor = [int(x) for x in torch.version.cuda.split('.')] 58 | 59 | if t_major != major: 60 | raise RuntimeError( 61 | f'Detected that PyTorch and torch_scatter were compiled with ' 62 | f'different CUDA versions. PyTorch has CUDA version ' 63 | f'{t_major}.{t_minor} and torch_scatter has CUDA version ' 64 | f'{major}.{minor}. Please reinstall the torch_scatter that ' 65 | f'matches your PyTorch install.') 66 | 67 | from .scatter import scatter_sum, scatter_add, scatter_mul # noqa 68 | from .scatter import scatter_mean, scatter_min, scatter_max, scatter # noqa 69 | from .segment_csr import segment_sum_csr, segment_add_csr # noqa 70 | from .segment_csr import segment_mean_csr, segment_min_csr # noqa 71 | from .segment_csr import segment_max_csr, segment_csr, gather_csr # noqa 72 | from .segment_coo import segment_sum_coo, segment_add_coo # noqa 73 | from .segment_coo import segment_mean_coo, segment_min_coo # noqa 74 | from .segment_coo import segment_max_coo, segment_coo, gather_coo # noqa 75 | from .composite import scatter_std, scatter_logsumexp # noqa 76 | from .composite import scatter_softmax, scatter_log_softmax # noqa 77 | 78 | __all__ = [ 79 | 'scatter_sum', 80 | 'scatter_add', 81 | 'scatter_mul', 82 | 'scatter_mean', 83 | 'scatter_min', 84 | 'scatter_max', 85 | 'scatter', 86 | 'segment_sum_csr', 87 | 'segment_add_csr', 88 | 
'segment_mean_csr', 89 | 'segment_min_csr', 90 | 'segment_max_csr', 91 | 'segment_csr', 92 | 'gather_csr', 93 | 'segment_sum_coo', 94 | 'segment_add_coo', 95 | 'segment_mean_coo', 96 | 'segment_min_coo', 97 | 'segment_max_coo', 98 | 'segment_coo', 99 | 'gather_coo', 100 | 'scatter_std', 101 | 'scatter_logsumexp', 102 | 'scatter_softmax', 103 | 'scatter_log_softmax', 104 | 'torch_scatter', 105 | '__version__', 106 | ] 107 | -------------------------------------------------------------------------------- /torch_scatter/composite/__init__.py: -------------------------------------------------------------------------------- 1 | from .std import scatter_std 2 | from .logsumexp import scatter_logsumexp 3 | from .softmax import scatter_log_softmax, scatter_softmax 4 | 5 | __all__ = [ 6 | 'scatter_std', 7 | 'scatter_logsumexp', 8 | 'scatter_softmax', 9 | 'scatter_log_softmax', 10 | ] 11 | -------------------------------------------------------------------------------- /torch_scatter/composite/logsumexp.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch_scatter import scatter_max, scatter_sum 5 | from torch_scatter.utils import broadcast 6 | 7 | 8 | def scatter_logsumexp( 9 | src: torch.Tensor, 10 | index: torch.Tensor, 11 | dim: int = -1, 12 | out: Optional[torch.Tensor] = None, 13 | dim_size: Optional[int] = None, 14 | eps: float = 1e-12, 15 | ) -> torch.Tensor: 16 | if not torch.is_floating_point(src): 17 | raise ValueError('`scatter_logsumexp` can only be computed over ' 18 | 'tensors with floating point data types.') 19 | 20 | index = broadcast(index, src, dim) 21 | 22 | if out is not None: 23 | dim_size = out.size(dim) 24 | else: 25 | if dim_size is None: 26 | dim_size = int(index.max()) + 1 27 | 28 | size = list(src.size()) 29 | size[dim] = dim_size 30 | max_value_per_index = torch.full( 31 | size, 32 | fill_value=float('-inf'), 33 | dtype=src.dtype, 34 | 
def scatter_softmax(src: torch.Tensor, index: torch.Tensor,
                    dim: int = -1,
                    dim_size: Optional[int] = None) -> torch.Tensor:
    """Numerically stable softmax of `src` within the groups given by `index`.

    Each element is exponentiated after subtracting its group's maximum,
    then divided by its group's sum of exponentials.
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_softmax` can only be computed over tensors '
                         'with floating point data types.')

    index = broadcast(index, src, dim)

    # Shift by the per-group maximum so exp() cannot overflow.
    group_max = scatter_max(src, index, dim=dim, dim_size=dim_size)[0]
    shifted = src - group_max.gather(dim, index)
    # In-place exp is safe here: `shifted` is a freshly allocated tensor.
    shifted_exp = shifted.exp_()

    # Normalize by the per-group sum of exponentials.
    group_sum = scatter_sum(shifted_exp, index, dim, dim_size=dim_size)
    denom = group_sum.gather(dim, index)

    return shifted_exp.div(denom)
def scatter_std(src: torch.Tensor, index: torch.Tensor, dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None,
                unbiased: bool = True) -> torch.Tensor:
    """Compute the standard deviation of `src` within the groups in `index`.

    :param src: The source tensor.
    :param index: The indices of elements to group; broadcast against `src`.
    :param dim: The axis along which to index. (default: :obj:`-1`)
    :param out: The destination tensor the squared deviations are summed into.
    :param dim_size: Output size at dimension `dim` when `out` is not given.
    :param unbiased: If :obj:`True`, apply Bessel's correction (n - 1).
    :rtype: :class:`Tensor`
    """

    if out is not None:
        dim_size = out.size(dim)

    # Normalize a negative axis to its positive equivalent.
    if dim < 0:
        dim = src.dim() + dim

    # `index` may have fewer dimensions than `src`; in that case count
    # occurrences along the last existing dimension of `index`.
    count_dim = dim
    if index.dim() <= dim:
        count_dim = index.dim() - 1

    # Number of elements contributing to each group.
    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, count_dim, dim_size=dim_size)

    # Per-group mean; clamp avoids division by zero for empty groups.
    index = broadcast(index, src, dim)
    tmp = scatter_sum(src, index, dim, dim_size=dim_size)
    count = broadcast(count, tmp, dim).clamp(1)
    mean = tmp.div(count)

    # Sum of squared deviations from each element's group mean.
    var = (src - mean.gather(dim, index))
    var = var * var
    out = scatter_sum(var, index, dim, out, dim_size)

    if unbiased:
        count = count.sub(1).clamp_(1)
    # The epsilon guards against division by zero before the square root.
    out = out.div(count + 1e-6).sqrt()

    return out
def scatter_sum(src: torch.Tensor,
                index: torch.Tensor,
                dim: int = -1,
                out: Optional[torch.Tensor] = None,
                dim_size: Optional[int] = None) -> torch.Tensor:
    """Sum all values from `src` into `out` at the indices given by `index`.

    :param src: The source tensor.
    :param index: The indices of elements to scatter; broadcast against
        `src` along every dimension except `dim`.
    :param dim: The axis along which to index. (default: :obj:`-1`)
    :param out: The destination tensor; accumulated into in-place if given.
    :param dim_size: Size of the output at dimension `dim` when `out` is not
        given; defaults to `index.max() + 1` (or 0 for an empty `index`).
    :rtype: :class:`Tensor`
    """
    index = broadcast(index, src, dim)
    if out is None:
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            # Empty index: produce an empty output instead of calling
            # `index.max()` on an empty tensor, which would raise.
            size[dim] = 0
        else:
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
    # Both the freshly created and the caller-provided `out` are accumulated
    # identically, so a single return suffices (the original duplicated this
    # statement across both branches).
    return out.scatter_add_(dim, index, src)
device=src.device) 59 | count = scatter_sum(ones, index, index_dim, None, dim_size) 60 | count[count < 1] = 1 61 | count = broadcast(count, out, dim) 62 | if out.is_floating_point(): 63 | out.true_divide_(count) 64 | else: 65 | out.div_(count, rounding_mode='floor') 66 | return out 67 | 68 | 69 | def scatter_min( 70 | src: torch.Tensor, 71 | index: torch.Tensor, 72 | dim: int = -1, 73 | out: Optional[torch.Tensor] = None, 74 | dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: 75 | return torch.ops.torch_scatter.scatter_min(src, index, dim, out, dim_size) 76 | 77 | 78 | def scatter_max( 79 | src: torch.Tensor, 80 | index: torch.Tensor, 81 | dim: int = -1, 82 | out: Optional[torch.Tensor] = None, 83 | dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: 84 | return torch.ops.torch_scatter.scatter_max(src, index, dim, out, dim_size) 85 | 86 | 87 | def scatter(src: torch.Tensor, 88 | index: torch.Tensor, 89 | dim: int = -1, 90 | out: Optional[torch.Tensor] = None, 91 | dim_size: Optional[int] = None, 92 | reduce: str = "sum") -> torch.Tensor: 93 | r""" 94 | | 95 | 96 | .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ 97 | master/docs/source/_figures/add.svg?sanitize=true 98 | :align: center 99 | :width: 400px 100 | 101 | | 102 | 103 | Reduces all values from the :attr:`src` tensor into :attr:`out` at the 104 | indices specified in the :attr:`index` tensor along a given axis 105 | :attr:`dim`. 106 | For each value in :attr:`src`, its output index is specified by its index 107 | in :attr:`src` for dimensions outside of :attr:`dim` and by the 108 | corresponding value in :attr:`index` for dimension :attr:`dim`. 109 | The applied reduction is defined via the :attr:`reduce` argument. 
110 | 111 | Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional 112 | tensors with size :math:`(x_0, ..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})` 113 | and :attr:`dim` = `i`, then :attr:`out` must be an :math:`n`-dimensional 114 | tensor with size :math:`(x_0, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})`. 115 | Moreover, the values of :attr:`index` must be between :math:`0` and 116 | :math:`y - 1`, although no specific ordering of indices is required. 117 | The :attr:`index` tensor supports broadcasting in case its dimensions do 118 | not match with :attr:`src`. 119 | 120 | For one-dimensional tensors with :obj:`reduce="sum"`, the operation 121 | computes 122 | 123 | .. math:: 124 | \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j 125 | 126 | where :math:`\sum_j` is over :math:`j` such that 127 | :math:`\mathrm{index}_j = i`. 128 | 129 | .. note:: 130 | 131 | This operation is implemented via atomic operations on the GPU and is 132 | therefore **non-deterministic** since the order of parallel operations 133 | to the same value is undetermined. 134 | For floating-point variables, this results in a source of variance in 135 | the result. 136 | 137 | :param src: The source tensor. 138 | :param index: The indices of elements to scatter. 139 | :param dim: The axis along which to index. (default: :obj:`-1`) 140 | :param out: The destination tensor. 141 | :param dim_size: If :attr:`out` is not given, automatically create output 142 | with size :attr:`dim_size` at dimension :attr:`dim`. 143 | If :attr:`dim_size` is not given, a minimal sized output tensor 144 | according to :obj:`index.max() + 1` is returned. 145 | :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mul"`, 146 | :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`) 147 | 148 | :rtype: :class:`Tensor` 149 | 150 | .. 
code-block:: python 151 | 152 | from torch_scatter import scatter 153 | 154 | src = torch.randn(10, 6, 64) 155 | index = torch.tensor([0, 1, 0, 1, 2, 1]) 156 | 157 | # Broadcasting in the first and last dim. 158 | out = scatter(src, index, dim=1, reduce="sum") 159 | 160 | print(out.size()) 161 | 162 | .. code-block:: 163 | 164 | torch.Size([10, 3, 64]) 165 | """ 166 | if reduce == 'sum' or reduce == 'add': 167 | return scatter_sum(src, index, dim, out, dim_size) 168 | if reduce == 'mul': 169 | return scatter_mul(src, index, dim, out, dim_size) 170 | elif reduce == 'mean': 171 | return scatter_mean(src, index, dim, out, dim_size) 172 | elif reduce == 'min': 173 | return scatter_min(src, index, dim, out, dim_size)[0] 174 | elif reduce == 'max': 175 | return scatter_max(src, index, dim, out, dim_size)[0] 176 | else: 177 | raise ValueError 178 | -------------------------------------------------------------------------------- /torch_scatter/segment_coo.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | 5 | 6 | def segment_sum_coo(src: torch.Tensor, index: torch.Tensor, 7 | out: Optional[torch.Tensor] = None, 8 | dim_size: Optional[int] = None) -> torch.Tensor: 9 | return torch.ops.torch_scatter.segment_sum_coo(src, index, out, dim_size) 10 | 11 | 12 | def segment_add_coo(src: torch.Tensor, index: torch.Tensor, 13 | out: Optional[torch.Tensor] = None, 14 | dim_size: Optional[int] = None) -> torch.Tensor: 15 | return torch.ops.torch_scatter.segment_sum_coo(src, index, out, dim_size) 16 | 17 | 18 | def segment_mean_coo(src: torch.Tensor, index: torch.Tensor, 19 | out: Optional[torch.Tensor] = None, 20 | dim_size: Optional[int] = None) -> torch.Tensor: 21 | return torch.ops.torch_scatter.segment_mean_coo(src, index, out, dim_size) 22 | 23 | 24 | def segment_min_coo( 25 | src: torch.Tensor, index: torch.Tensor, 26 | out: Optional[torch.Tensor] = None, 27 | dim_size: 
Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: 28 | return torch.ops.torch_scatter.segment_min_coo(src, index, out, dim_size) 29 | 30 | 31 | def segment_max_coo( 32 | src: torch.Tensor, index: torch.Tensor, 33 | out: Optional[torch.Tensor] = None, 34 | dim_size: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: 35 | return torch.ops.torch_scatter.segment_max_coo(src, index, out, dim_size) 36 | 37 | 38 | def segment_coo(src: torch.Tensor, index: torch.Tensor, 39 | out: Optional[torch.Tensor] = None, 40 | dim_size: Optional[int] = None, 41 | reduce: str = "sum") -> torch.Tensor: 42 | r""" 43 | | 44 | 45 | .. image:: https://raw.githubusercontent.com/rusty1s/pytorch_scatter/ 46 | master/docs/source/_figures/segment_coo.svg?sanitize=true 47 | :align: center 48 | :width: 400px 49 | 50 | | 51 | 52 | Reduces all values from the :attr:`src` tensor into :attr:`out` at the 53 | indices specified in the :attr:`index` tensor along the last dimension of 54 | :attr:`index`. 55 | For each value in :attr:`src`, its output index is specified by its index 56 | in :attr:`src` for dimensions outside of :obj:`index.dim() - 1` and by the 57 | corresponding value in :attr:`index` for dimension :obj:`index.dim() - 1`. 58 | The applied reduction is defined via the :attr:`reduce` argument. 59 | 60 | Formally, if :attr:`src` and :attr:`index` are :math:`n`-dimensional and 61 | :math:`m`-dimensional tensors with 62 | size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and 63 | :math:`(x_0, ..., x_{m-1}, x_m)`, respectively, then :attr:`out` must be an 64 | :math:`n`-dimensional tensor with size 65 | :math:`(x_0, ..., x_{m-1}, y, x_{m+1}, ..., x_{n-1})`. 66 | Moreover, the values of :attr:`index` must be between :math:`0` and 67 | :math:`y - 1` in ascending order. 68 | The :attr:`index` tensor supports broadcasting in case its dimensions do 69 | not match with :attr:`src`. 
70 | 71 | For one-dimensional tensors with :obj:`reduce="sum"`, the operation 72 | computes 73 | 74 | .. math:: 75 | \mathrm{out}_i = \mathrm{out}_i + \sum_j~\mathrm{src}_j 76 | 77 | where :math:`\sum_j` is over :math:`j` such that 78 | :math:`\mathrm{index}_j = i`. 79 | 80 | In contrast to :meth:`scatter`, this method expects values in :attr:`index` 81 | **to be sorted** along dimension :obj:`index.dim() - 1`. 82 | Due to the use of sorted indices, :meth:`segment_coo` is usually faster 83 | than the more general :meth:`scatter` operation. 84 | 85 | .. note:: 86 | 87 | This operation is implemented via atomic operations on the GPU and is 88 | therefore **non-deterministic** since the order of parallel operations 89 | to the same value is undetermined. 90 | For floating-point variables, this results in a source of variance in 91 | the result. 92 | 93 | :param src: The source tensor. 94 | :param index: The sorted indices of elements to segment. 95 | The number of dimensions of :attr:`index` needs to be less than or 96 | equal to :attr:`src`. 97 | :param out: The destination tensor. 98 | :param dim_size: If :attr:`out` is not given, automatically create output 99 | with size :attr:`dim_size` at dimension :obj:`index.dim() - 1`. 100 | If :attr:`dim_size` is not given, a minimal sized output tensor 101 | according to :obj:`index.max() + 1` is returned. 102 | :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mean"`, 103 | :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`) 104 | 105 | :rtype: :class:`Tensor` 106 | 107 | .. code-block:: python 108 | 109 | from torch_scatter import segment_coo 110 | 111 | src = torch.randn(10, 6, 64) 112 | index = torch.tensor([0, 0, 1, 1, 1, 2]) 113 | index = index.view(1, -1) # Broadcasting in the first and last dim. 114 | 115 | out = segment_coo(src, index, reduce="sum") 116 | 117 | print(out.size()) 118 | 119 | .. 
code-block:: 120 | 121 | torch.Size([10, 3, 64]) 122 | """ 123 | if reduce == 'sum' or reduce == 'add': 124 | return segment_sum_coo(src, index, out, dim_size) 125 | elif reduce == 'mean': 126 | return segment_mean_coo(src, index, out, dim_size) 127 | elif reduce == 'min': 128 | return segment_min_coo(src, index, out, dim_size)[0] 129 | elif reduce == 'max': 130 | return segment_max_coo(src, index, out, dim_size)[0] 131 | else: 132 | raise ValueError 133 | 134 | 135 | def gather_coo(src: torch.Tensor, index: torch.Tensor, 136 | out: Optional[torch.Tensor] = None) -> torch.Tensor: 137 | return torch.ops.torch_scatter.gather_coo(src, index, out) 138 | -------------------------------------------------------------------------------- /torch_scatter/segment_csr.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Tuple 2 | 3 | import torch 4 | 5 | 6 | def segment_sum_csr(src: torch.Tensor, indptr: torch.Tensor, 7 | out: Optional[torch.Tensor] = None) -> torch.Tensor: 8 | return torch.ops.torch_scatter.segment_sum_csr(src, indptr, out) 9 | 10 | 11 | def segment_add_csr(src: torch.Tensor, indptr: torch.Tensor, 12 | out: Optional[torch.Tensor] = None) -> torch.Tensor: 13 | return torch.ops.torch_scatter.segment_sum_csr(src, indptr, out) 14 | 15 | 16 | def segment_mean_csr(src: torch.Tensor, indptr: torch.Tensor, 17 | out: Optional[torch.Tensor] = None) -> torch.Tensor: 18 | return torch.ops.torch_scatter.segment_mean_csr(src, indptr, out) 19 | 20 | 21 | def segment_min_csr( 22 | src: torch.Tensor, indptr: torch.Tensor, 23 | out: Optional[torch.Tensor] = None 24 | ) -> Tuple[torch.Tensor, torch.Tensor]: 25 | return torch.ops.torch_scatter.segment_min_csr(src, indptr, out) 26 | 27 | 28 | def segment_max_csr( 29 | src: torch.Tensor, indptr: torch.Tensor, 30 | out: Optional[torch.Tensor] = None 31 | ) -> Tuple[torch.Tensor, torch.Tensor]: 32 | return torch.ops.torch_scatter.segment_max_csr(src, indptr, out) 33 | 

def segment_csr(src: torch.Tensor, indptr: torch.Tensor,
                out: Optional[torch.Tensor] = None,
                reduce: str = "sum") -> torch.Tensor:
    r"""
    Reduces all values from the :attr:`src` tensor into :attr:`out` within the
    ranges specified in the :attr:`indptr` tensor along the last dimension of
    :attr:`indptr`.
    For each value in :attr:`src`, its output index is specified by its index
    in :attr:`src` for dimensions outside of :obj:`indptr.dim() - 1` and by the
    corresponding range index in :attr:`indptr` for dimension
    :obj:`indptr.dim() - 1`.
    The applied reduction is defined via the :attr:`reduce` argument.

    Formally, if :attr:`src` and :attr:`indptr` are :math:`n`-dimensional and
    :math:`m`-dimensional tensors with
    size :math:`(x_0, ..., x_{m-1}, x_m, x_{m+1}, ..., x_{n-1})` and
    :math:`(x_0, ..., x_{m-2}, y)`, respectively, then :attr:`out` must be an
    :math:`n`-dimensional tensor with size
    :math:`(x_0, ..., x_{m-2}, y - 1, x_{m}, ..., x_{n-1})`.
    Moreover, the values of :attr:`indptr` must be between :math:`0` and
    :math:`x_m` in ascending order.
    The :attr:`indptr` tensor supports broadcasting in case its dimensions do
    not match with :attr:`src`.

    For one-dimensional tensors with :obj:`reduce="sum"`, the operation
    computes

    .. math::
        \mathrm{out}_i =
        \sum_{j = \mathrm{indptr}[i]}^{\mathrm{indptr}[i+1]-1}~\mathrm{src}_j.

    Due to the use of index pointers, :meth:`segment_csr` is the fastest
    method to apply for grouped reductions.

    .. note::

        In contrast to :meth:`scatter()` and :meth:`segment_coo`, this
        operation is **fully-deterministic**.

    :param src: The source tensor.
    :param indptr: The index pointers between elements to segment.
        The number of dimensions of :attr:`index` needs to be less than or
        equal to :attr:`src`.
    :param out: The destination tensor.
    :param reduce: The reduce operation (:obj:`"sum"`, :obj:`"mean"`,
        :obj:`"min"` or :obj:`"max"`). (default: :obj:`"sum"`)

    :rtype: :class:`Tensor`

    .. code-block:: python

        from torch_scatter import segment_csr

        src = torch.randn(10, 6, 64)
        indptr = torch.tensor([0, 2, 5, 6])
        indptr = indptr.view(1, -1)  # Broadcasting in the first and last dim.

        out = segment_csr(src, indptr, reduce="sum")

        print(out.size())

    .. code-block::

        torch.Size([10, 3, 64])
    """
    if reduce == 'sum' or reduce == 'add':
        return segment_sum_csr(src, indptr, out)
    elif reduce == 'mean':
        return segment_mean_csr(src, indptr, out)
    elif reduce == 'min':
        return segment_min_csr(src, indptr, out)[0]
    elif reduce == 'max':
        return segment_max_csr(src, indptr, out)[0]
    else:
        # Name the offending value instead of raising a bare ValueError.
        raise ValueError("Encountered invalid `reduce` argument: " + reduce)


def gather_csr(src: torch.Tensor, indptr: torch.Tensor,
               out: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Inverse of :meth:`segment_csr`: expand values along CSR ranges."""
    return torch.ops.torch_scatter.gather_csr(src, indptr, out)

# ---------------------------------------------------------------------------
# torch_scatter/testing.py
#
# Shared fixtures for the test-suite: reduce modes, dtype/device matrices,
# and a small tensor-construction helper.
# ---------------------------------------------------------------------------
from typing import Any

import torch

# Reduce modes exercised by the tests.
reductions = ['sum', 'add', 'mean', 'min', 'max']

# Dtypes covered by the forward tests; only float/double support autograd.
dtypes = [
    torch.half, torch.bfloat16, torch.float, torch.double, torch.int,
    torch.long
]
grad_dtypes = [torch.float, torch.double]

devices = [torch.device('cpu')]
if torch.cuda.is_available():
    devices += [torch.device('cuda:0')]


def tensor(x: Any, dtype: torch.dtype, device: torch.device):
    """Build a tensor of the given dtype/device from ``x``, passing ``None``
    through unchanged (used for optional ``out`` arguments in tests)."""
    return None if x is None else torch.tensor(x, device=device).to(dtype)

# ---------------------------------------------------------------------------
# torch_scatter/utils.py
-------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): 5 | if dim < 0: 6 | dim = other.dim() + dim 7 | if src.dim() == 1: 8 | for _ in range(0, dim): 9 | src = src.unsqueeze(0) 10 | for _ in range(src.dim(), other.dim()): 11 | src = src.unsqueeze(-1) 12 | src = src.expand(other.size()) 13 | return src 14 | --------------------------------------------------------------------------------