├── .bumpversion.cfg ├── .github ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md └── workflows │ └── tests.yml ├── .gitignore ├── .readthedocs.yml ├── LICENSE ├── MANIFEST.in ├── README.md ├── docs └── source │ ├── conf.py │ ├── details.rst │ ├── img │ └── small_graph.svg │ ├── index.rst │ ├── installation.rst │ └── usage.rst ├── pyproject.toml ├── setup.cfg ├── src └── torch_ppr │ ├── __init__.py │ ├── api.py │ ├── py.typed │ ├── utils.py │ └── version.py ├── tests ├── __init__.py ├── test_api.py ├── test_utils.py └── test_version.py └── tox.ini /.bumpversion.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.0.9-dev 3 | commit = True 4 | tag = False 5 | parse = (?P\d+)\.(?P\d+)\.(?P\d+)(?:-(?P[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(?P[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))? 6 | serialize = 7 | {major}.{minor}.{patch}-{release}+{build} 8 | {major}.{minor}.{patch}+{build} 9 | {major}.{minor}.{patch}-{release} 10 | {major}.{minor}.{patch} 11 | 12 | [bumpversion:part:release] 13 | optional_value = production 14 | first_value = dev 15 | values = 16 | dev 17 | production 18 | 19 | [bumpverion:part:build] 20 | values = [0-9A-Za-z-]+ 21 | 22 | [bumpversion:file:setup.cfg] 23 | search = version = {current_version} 24 | replace = version = {new_version} 25 | 26 | [bumpversion:file:docs/source/conf.py] 27 | search = release = "{current_version}" 28 | replace = release = "{new_version}" 29 | 30 | [bumpversion:file:src/torch_ppr/version.py] 31 | search = VERSION = "{current_version}" 32 | replace = VERSION = "{new_version}" 33 | -------------------------------------------------------------------------------- /.github/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience 
for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, caste, color, religion, or sexual 10 | identity and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the overall 26 | community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or advances of 31 | any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email address, 35 | without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 
45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | max.berrendorf@gmail.com. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series of 86 | actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. 
No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or permanent 93 | ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within the 113 | community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.1, available at 119 | [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1]. 120 | 121 | Community Impact Guidelines were inspired by 122 | [Mozilla's code of conduct enforcement ladder][Mozilla CoC]. 123 | 124 | For answers to common questions about this code of conduct, see the FAQ at 125 | [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at 126 | [https://www.contributor-covenant.org/translations][translations]. 
127 | 128 | [homepage]: https://www.contributor-covenant.org 129 | [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html 130 | [Mozilla CoC]: https://github.com/mozilla/diversity 131 | [FAQ]: https://www.contributor-covenant.org/faq 132 | [translations]: https://www.contributor-covenant.org/translations 133 | -------------------------------------------------------------------------------- /.github/CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing 2 | 3 | Contributions to this repository are welcomed and encouraged. 4 | 5 | ## Code Contribution 6 | 7 | This project uses the [GitHub Flow](https://guides.github.com/introduction/flow) 8 | model for code contributions. Follow these steps: 9 | 10 | 1. [Create a fork](https://help.github.com/articles/fork-a-repo) of the upstream 11 | repository at [`mberr/torch-ppr`](https://github.com/mberr/torch-ppr) 12 | on your GitHub account (or in one of your organizations) 13 | 2. [Clone your fork](https://docs.github.com/en/repositories/creating-and-managing-repositories/cloning-a-repository) 14 | with `git clone https://github.com//torch-ppr.git` 15 | 3. Make and commit changes to your fork with `git commit` 16 | 4. Push changes to your fork with `git push` 17 | 5. Repeat steps 3 and 4 as needed 18 | 6. Submit a pull request back to the upstream repository 19 | 20 | ### Merge Model 21 | 22 | This repository uses [squash merges](https://docs.github.com/en/github/collaborating-with-pull-requests/incorporating-changes-from-a-pull-request/about-pull-request-merges#squash-and-merge-your-pull-request-commits) 23 | to group all related commits in a given pull request into a single commit upon 24 | acceptance and merge into the main branch. This has several benefits: 25 | 26 | 1. Keeps the commit history on the main branch focused on high-level narrative 27 | 2. 
Enables people to make lots of small commits without worrying about muddying 28 | up the commit history 29 | 3. Commits correspond 1-to-1 with pull requests 30 | 31 | ### Code Style 32 | 33 | This project encourages the use of optional static typing. It 34 | uses [`mypy`](http://mypy-lang.org/) as a type checker 35 | and [`sphinx_autodoc_typehints`](https://github.com/agronholm/sphinx-autodoc-typehints) 36 | to automatically generate documentation based on type hints. You can check if 37 | your code passes `mypy` with `tox -e mypy`. 38 | 39 | This project uses [`black`](https://github.com/psf/black) to automatically 40 | enforce a consistent code style. You can apply `black` and other pre-configured 41 | linters with `tox -e lint`. 42 | 43 | This project uses [`flake8`](https://flake8.pycqa.org) and several plugins for 44 | additional checks of documentation style, security issues, good variable 45 | nomenclature, and more ( 46 | see [`tox.ini`](tox.ini) for a list of flake8 plugins). You can check if your 47 | code passes `flake8` with `tox -e flake8`. 48 | 49 | Each of these checks are run on each commit using GitHub Actions as a continuous 50 | integration service. Passing all of them is required for accepting a 51 | contribution. If you're unsure how to address the feedback from one of these 52 | tools, please say so either in the description of your pull request or in a 53 | comment, and we will help you. 54 | 55 | ### Logging 56 | 57 | Python's builtin `print()` should not be used (except when writing to files), 58 | it's checked by the 59 | [`flake8-print`](https://github.com/jbkahn/flake8-print) plugin to `flake8`. If 60 | you're in a command line setting or `main()` function for a module, you can use 61 | `click.echo()`. Otherwise, you can use the builtin `logging` module by adding 62 | `logger = logging.getLogger(__name__)` below the imports at the top of your 63 | file. 
64 | 65 | ### Documentation 66 | 67 | All public functions (i.e., not starting with an underscore `_`) must be 68 | documented using the [sphinx documentation format](https://sphinx-rtd-tutorial.readthedocs.io/en/latest/docstrings.html#the-sphinx-docstring-format). 69 | The [`darglint`](https://github.com/terrencepreilly/darglint) plugin to `flake8` 70 | reports on functions that are not fully documented. 71 | 72 | This project uses [`sphinx`](https://www.sphinx-doc.org) to automatically build 73 | documentation into a narrative structure. You can check that the documentation 74 | builds properly in an isolated environment with `tox -e docs-test` and actually 75 | build it locally with `tox -e docs`. 76 | 77 | ### Testing 78 | 79 | Functions in this repository should be unit tested. These can either be written 80 | using the `unittest` framework in the `tests/` directory or as embedded 81 | doctests. You can check that the unit tests pass with `tox -e py` and that the 82 | doctests pass with `tox -e doctests`. These tests are required to pass for 83 | accepting a contribution. 84 | 85 | ### Syncing your fork 86 | 87 | If other code is updated before your contribution gets merged, you might need to 88 | resolve conflicts against the main branch. After cloning, you should add the 89 | upstream repository with 90 | 91 | ```shell 92 | $ git remote add mberr https://github.com/mberr/torch-ppr.git 93 | ``` 94 | 95 | Then, you can merge upstream code into your branch. You can also use the GitHub 96 | UI to do this by following [this tutorial](https://docs.github.com/en/github/collaborating-with-pull-requests/working-with-forks/syncing-a-fork). 97 | 98 | ### Python Version Compatibility 99 | 100 | This project aims to support all versions of Python that have not passed their 101 | end-of-life dates. After end-of-life, the version will be removed from the Trove 102 | qualifiers in the [`setup.cfg`](setup.cfg) and from the GitHub Actions testing 103 | configuration. 
104 | 105 | See https://endoflife.date/python for a timeline of Python release and 106 | end-of-life dates. 107 | -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | lint: 11 | name: Lint 12 | runs-on: ubuntu-latest 13 | strategy: 14 | matrix: 15 | python-version: [ "3.8", "3.9", "3.10" ] 16 | steps: 17 | - uses: actions/checkout@v2 18 | - name: Set up Python ${{ matrix.python-version }} 19 | uses: actions/setup-python@v2 20 | with: 21 | python-version: ${{ matrix.python-version }} 22 | - name: Install dependencies 23 | run: pip install tox 24 | - name: Check manifest 25 | run: tox -e manifest 26 | - name: Check code quality with flake8 27 | run: tox -e flake8 28 | - name: Check package metadata with Pyroma 29 | run: tox -e pyroma 30 | - name: Check static typing with MyPy 31 | run: tox -e mypy 32 | docs: 33 | name: Documentation 34 | runs-on: ubuntu-latest 35 | strategy: 36 | matrix: 37 | python-version: [ "3.8", "3.9", "3.10" ] 38 | steps: 39 | - uses: actions/checkout@v2 40 | - name: Set up Python ${{ matrix.python-version }} 41 | uses: actions/setup-python@v2 42 | with: 43 | python-version: ${{ matrix.python-version }} 44 | - name: Install dependencies 45 | run: pip install tox 46 | - name: Check RST conformity with doc8 47 | run: tox -e doc8 48 | - name: Check docstring coverage 49 | run: tox -e docstr-coverage 50 | - name: Check documentation build with Sphinx 51 | run: tox -e docs-test 52 | tests: 53 | name: Tests 54 | strategy: 55 | matrix: 56 | os: [ "ubuntu-latest" ] 57 | python-version: [ "3.8", "3.9", "3.10" ] 58 | torch-version: [ "torch-1.11", "torch-1.12", "torch-1.13" ] 59 | runs-on: ${{ matrix.os }} 60 | steps: 61 | - uses: actions/checkout@v2 62 | - name: Set up Python ${{ 
matrix.python-version }} 63 | uses: actions/setup-python@v2 64 | with: 65 | python-version: ${{ matrix.python-version }} 66 | - name: Install dependencies 67 | run: pip install tox 68 | - name: Test with pytest and generate coverage file 69 | run: 70 | tox -e py-${{ matrix.torch-version }} 71 | - name: Upload coverage report to codecov 72 | uses: codecov/codecov-action@v1 73 | if: success() 74 | with: 75 | file: coverage.xml 76 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/macos,linux,windows,python,jupyternotebooks,jetbrains,pycharm,vim,emacs,visualstudiocode,visualstudio 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=macos,linux,windows,python,jupyternotebooks,jetbrains,pycharm,vim,emacs,visualstudiocode,visualstudio 3 | 4 | ### Emacs ### 5 | # -*- mode: gitignore; -*- 6 | *~ 7 | \#*\# 8 | /.emacs.desktop 9 | /.emacs.desktop.lock 10 | *.elc 11 | auto-save-list 12 | tramp 13 | .\#* 14 | 15 | # Org-mode 16 | .org-id-locations 17 | *_archive 18 | 19 | # flymake-mode 20 | *_flymake.* 21 | 22 | # eshell files 23 | /eshell/history 24 | /eshell/lastdir 25 | 26 | # elpa packages 27 | /elpa/ 28 | 29 | # reftex files 30 | *.rel 31 | 32 | # AUCTeX auto folder 33 | /auto/ 34 | 35 | # cask packages 36 | .cask/ 37 | dist/ 38 | 39 | # Flycheck 40 | flycheck_*.el 41 | 42 | # server auth directory 43 | /server/ 44 | 45 | # projectiles files 46 | .projectile 47 | 48 | # directory configuration 49 | .dir-locals.el 50 | 51 | # network security 52 | /network-security.data 53 | 54 | 55 | ### JetBrains ### 56 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 57 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 58 | 59 | # User-specific stuff 60 | .idea/**/workspace.xml 61 | 
.idea/**/tasks.xml 62 | .idea/**/usage.statistics.xml 63 | .idea/**/dictionaries 64 | .idea/**/shelf 65 | 66 | # AWS User-specific 67 | .idea/**/aws.xml 68 | 69 | # Generated files 70 | .idea/**/contentModel.xml 71 | 72 | # Sensitive or high-churn files 73 | .idea/**/dataSources/ 74 | .idea/**/dataSources.ids 75 | .idea/**/dataSources.local.xml 76 | .idea/**/sqlDataSources.xml 77 | .idea/**/dynamic.xml 78 | .idea/**/uiDesigner.xml 79 | .idea/**/dbnavigator.xml 80 | 81 | # Gradle 82 | .idea/**/gradle.xml 83 | .idea/**/libraries 84 | 85 | # Gradle and Maven with auto-import 86 | # When using Gradle or Maven with auto-import, you should exclude module files, 87 | # since they will be recreated, and may cause churn. Uncomment if using 88 | # auto-import. 89 | # .idea/artifacts 90 | # .idea/compiler.xml 91 | # .idea/jarRepositories.xml 92 | # .idea/modules.xml 93 | # .idea/*.iml 94 | # .idea/modules 95 | # *.iml 96 | # *.ipr 97 | 98 | # CMake 99 | cmake-build-*/ 100 | 101 | # Mongo Explorer plugin 102 | .idea/**/mongoSettings.xml 103 | 104 | # File-based project format 105 | *.iws 106 | 107 | # IntelliJ 108 | out/ 109 | 110 | # mpeltonen/sbt-idea plugin 111 | .idea_modules/ 112 | 113 | # JIRA plugin 114 | atlassian-ide-plugin.xml 115 | 116 | # Cursive Clojure plugin 117 | .idea/replstate.xml 118 | 119 | # SonarLint plugin 120 | .idea/sonarlint/ 121 | 122 | # Crashlytics plugin (for Android Studio and IntelliJ) 123 | com_crashlytics_export_strings.xml 124 | crashlytics.properties 125 | crashlytics-build.properties 126 | fabric.properties 127 | 128 | # Editor-based Rest Client 129 | .idea/httpRequests 130 | 131 | # Android studio 3.1+ serialized cache file 132 | .idea/caches/build_file_checksums.ser 133 | 134 | ### JetBrains Patch ### 135 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 136 | 137 | # *.iml 138 | # modules.xml 139 | # .idea/misc.xml 140 | # *.ipr 141 | 142 | # Sonarlint plugin 143 | # 
https://plugins.jetbrains.com/plugin/7973-sonarlint 144 | .idea/**/sonarlint/ 145 | 146 | # SonarQube Plugin 147 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 148 | .idea/**/sonarIssues.xml 149 | 150 | # Markdown Navigator plugin 151 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 152 | .idea/**/markdown-navigator.xml 153 | .idea/**/markdown-navigator-enh.xml 154 | .idea/**/markdown-navigator/ 155 | 156 | # Cache file creation bug 157 | # See https://youtrack.jetbrains.com/issue/JBR-2257 158 | .idea/$CACHE_FILE$ 159 | 160 | # CodeStream plugin 161 | # https://plugins.jetbrains.com/plugin/12206-codestream 162 | .idea/codestream.xml 163 | 164 | ### JupyterNotebooks ### 165 | # gitignore template for Jupyter Notebooks 166 | # website: http://jupyter.org/ 167 | 168 | .ipynb_checkpoints 169 | */.ipynb_checkpoints/* 170 | 171 | # IPython 172 | profile_default/ 173 | ipython_config.py 174 | 175 | # Remove previous ipynb_checkpoints 176 | # git rm -r .ipynb_checkpoints/ 177 | 178 | ### Linux ### 179 | 180 | # temporary files which can be created if a process still has a handle open of a deleted file 181 | .fuse_hidden* 182 | 183 | # KDE directory preferences 184 | .directory 185 | 186 | # Linux trash folder which might appear on any partition or disk 187 | .Trash-* 188 | 189 | # .nfs files are created when an open file is removed but is still being accessed 190 | .nfs* 191 | 192 | ### macOS ### 193 | # General 194 | .DS_Store 195 | .AppleDouble 196 | .LSOverride 197 | 198 | # Icon must end with two \r 199 | Icon 200 | 201 | 202 | # Thumbnails 203 | ._* 204 | 205 | # Files that might appear in the root of a volume 206 | .DocumentRevisions-V100 207 | .fseventsd 208 | .Spotlight-V100 209 | .TemporaryItems 210 | .Trashes 211 | .VolumeIcon.icns 212 | .com.apple.timemachine.donotpresent 213 | 214 | # Directories potentially created on remote AFP share 215 | .AppleDB 216 | .AppleDesktop 217 | Network Trash Folder 218 | 
Temporary Items 219 | .apdisk 220 | 221 | ### PyCharm ### 222 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 223 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 224 | 225 | # User-specific stuff 226 | 227 | # AWS User-specific 228 | 229 | # Generated files 230 | 231 | # Sensitive or high-churn files 232 | 233 | # Gradle 234 | 235 | # Gradle and Maven with auto-import 236 | # When using Gradle or Maven with auto-import, you should exclude module files, 237 | # since they will be recreated, and may cause churn. Uncomment if using 238 | # auto-import. 239 | # .idea/artifacts 240 | # .idea/compiler.xml 241 | # .idea/jarRepositories.xml 242 | # .idea/modules.xml 243 | # .idea/*.iml 244 | # .idea/modules 245 | # *.iml 246 | # *.ipr 247 | 248 | # CMake 249 | 250 | # Mongo Explorer plugin 251 | 252 | # File-based project format 253 | 254 | # IntelliJ 255 | 256 | # mpeltonen/sbt-idea plugin 257 | 258 | # JIRA plugin 259 | 260 | # Cursive Clojure plugin 261 | 262 | # SonarLint plugin 263 | 264 | # Crashlytics plugin (for Android Studio and IntelliJ) 265 | 266 | # Editor-based Rest Client 267 | 268 | # Android studio 3.1+ serialized cache file 269 | 270 | ### PyCharm Patch ### 271 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 272 | 273 | # *.iml 274 | # modules.xml 275 | # .idea/misc.xml 276 | # *.ipr 277 | 278 | # Sonarlint plugin 279 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 280 | 281 | # SonarQube Plugin 282 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 283 | 284 | # Markdown Navigator plugin 285 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 286 | 287 | # Cache file creation bug 288 | # See https://youtrack.jetbrains.com/issue/JBR-2257 289 | 290 | # CodeStream plugin 291 | # https://plugins.jetbrains.com/plugin/12206-codestream 292 | 293 | ### Python ### 
294 | # Byte-compiled / optimized / DLL files 295 | __pycache__/ 296 | *.py[cod] 297 | *$py.class 298 | 299 | # C extensions 300 | *.so 301 | 302 | # Distribution / packaging 303 | .Python 304 | build/ 305 | develop-eggs/ 306 | downloads/ 307 | eggs/ 308 | .eggs/ 309 | lib/ 310 | lib64/ 311 | parts/ 312 | sdist/ 313 | var/ 314 | wheels/ 315 | share/python-wheels/ 316 | *.egg-info/ 317 | .installed.cfg 318 | *.egg 319 | MANIFEST 320 | 321 | # PyInstaller 322 | # Usually these files are written by a python script from a template 323 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 324 | *.manifest 325 | *.spec 326 | 327 | # Installer logs 328 | pip-log.txt 329 | pip-delete-this-directory.txt 330 | 331 | # Unit test / coverage reports 332 | htmlcov/ 333 | .tox/ 334 | .nox/ 335 | .coverage 336 | .coverage.* 337 | .cache 338 | nosetests.xml 339 | coverage.xml 340 | *.cover 341 | *.py,cover 342 | .hypothesis/ 343 | .pytest_cache/ 344 | cover/ 345 | 346 | # Translations 347 | *.mo 348 | *.pot 349 | 350 | # Django stuff: 351 | *.log 352 | local_settings.py 353 | db.sqlite3 354 | db.sqlite3-journal 355 | 356 | # Flask stuff: 357 | instance/ 358 | .webassets-cache 359 | 360 | # Scrapy stuff: 361 | .scrapy 362 | 363 | # Sphinx documentation 364 | docs/_build/ 365 | 366 | # PyBuilder 367 | .pybuilder/ 368 | target/ 369 | 370 | # Jupyter Notebook 371 | 372 | # IPython 373 | 374 | # pyenv 375 | # For a library or package, you might want to ignore these files since the code is 376 | # intended to run in multiple environments; otherwise, check them in: 377 | # .python-version 378 | 379 | # pipenv 380 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 381 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 382 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 383 | # install all needed dependencies. 
384 | #Pipfile.lock 385 | 386 | # poetry 387 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 388 | # This is especially recommended for binary packages to ensure reproducibility, and is more 389 | # commonly ignored for libraries. 390 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 391 | #poetry.lock 392 | 393 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 394 | __pypackages__/ 395 | 396 | # Celery stuff 397 | celerybeat-schedule 398 | celerybeat.pid 399 | 400 | # SageMath parsed files 401 | *.sage.py 402 | 403 | # Environments 404 | .env 405 | .venv 406 | env/ 407 | venv/ 408 | ENV/ 409 | env.bak/ 410 | venv.bak/ 411 | 412 | # Spyder project settings 413 | .spyderproject 414 | .spyproject 415 | 416 | # Rope project settings 417 | .ropeproject 418 | 419 | # mkdocs documentation 420 | /site 421 | 422 | # mypy 423 | .mypy_cache/ 424 | .dmypy.json 425 | dmypy.json 426 | 427 | # Pyre type checker 428 | .pyre/ 429 | 430 | # pytype static type analyzer 431 | .pytype/ 432 | 433 | # Cython debug symbols 434 | cython_debug/ 435 | 436 | # PyCharm 437 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 438 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 439 | # and can be added to the global gitignore or merged into this file. For a more nuclear 440 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
441 | #.idea/ 442 | 443 | ### Vim ### 444 | # Swap 445 | [._]*.s[a-v][a-z] 446 | !*.svg # comment out if you don't need vector files 447 | [._]*.sw[a-p] 448 | [._]s[a-rt-v][a-z] 449 | [._]ss[a-gi-z] 450 | [._]sw[a-p] 451 | 452 | # Session 453 | Session.vim 454 | Sessionx.vim 455 | 456 | # Temporary 457 | .netrwhist 458 | # Auto-generated tag files 459 | tags 460 | # Persistent undo 461 | [._]*.un~ 462 | 463 | ### VisualStudioCode ### 464 | .vscode/* 465 | !.vscode/settings.json 466 | !.vscode/tasks.json 467 | !.vscode/launch.json 468 | !.vscode/extensions.json 469 | !.vscode/*.code-snippets 470 | 471 | # Local History for Visual Studio Code 472 | .history/ 473 | 474 | # Built Visual Studio Code Extensions 475 | *.vsix 476 | 477 | ### VisualStudioCode Patch ### 478 | # Ignore all local history of files 479 | .history 480 | .ionide 481 | 482 | # Support for Project snippet scope 483 | 484 | ### Windows ### 485 | # Windows thumbnail cache files 486 | Thumbs.db 487 | Thumbs.db:encryptable 488 | ehthumbs.db 489 | ehthumbs_vista.db 490 | 491 | # Dump file 492 | *.stackdump 493 | 494 | # Folder config file 495 | [Dd]esktop.ini 496 | 497 | # Recycle Bin used on file shares 498 | $RECYCLE.BIN/ 499 | 500 | # Windows Installer files 501 | *.cab 502 | *.msi 503 | *.msix 504 | *.msm 505 | *.msp 506 | 507 | # Windows shortcuts 508 | *.lnk 509 | 510 | ### VisualStudio ### 511 | ## Ignore Visual Studio temporary files, build results, and 512 | ## files generated by popular Visual Studio add-ons. 
513 | ## 514 | ## Get latest from https://github.com/github/gitignore/blob/main/VisualStudio.gitignore 515 | 516 | # User-specific files 517 | *.rsuser 518 | *.suo 519 | *.user 520 | *.userosscache 521 | *.sln.docstates 522 | 523 | # User-specific files (MonoDevelop/Xamarin Studio) 524 | *.userprefs 525 | 526 | # Mono auto generated files 527 | mono_crash.* 528 | 529 | # Build results 530 | [Dd]ebug/ 531 | [Dd]ebugPublic/ 532 | [Rr]elease/ 533 | [Rr]eleases/ 534 | x64/ 535 | x86/ 536 | [Ww][Ii][Nn]32/ 537 | [Aa][Rr][Mm]/ 538 | [Aa][Rr][Mm]64/ 539 | bld/ 540 | [Bb]in/ 541 | [Oo]bj/ 542 | [Ll]og/ 543 | [Ll]ogs/ 544 | 545 | # Visual Studio 2015/2017 cache/options directory 546 | .vs/ 547 | # Uncomment if you have tasks that create the project's static files in wwwroot 548 | #wwwroot/ 549 | 550 | # Visual Studio 2017 auto generated files 551 | Generated\ Files/ 552 | 553 | # MSTest test Results 554 | [Tt]est[Rr]esult*/ 555 | [Bb]uild[Ll]og.* 556 | 557 | # NUnit 558 | *.VisualState.xml 559 | TestResult.xml 560 | nunit-*.xml 561 | 562 | # Build Results of an ATL Project 563 | [Dd]ebugPS/ 564 | [Rr]eleasePS/ 565 | dlldata.c 566 | 567 | # Benchmark Results 568 | BenchmarkDotNet.Artifacts/ 569 | 570 | # .NET Core 571 | project.lock.json 572 | project.fragment.lock.json 573 | artifacts/ 574 | 575 | # ASP.NET Scaffolding 576 | ScaffoldingReadMe.txt 577 | 578 | # StyleCop 579 | StyleCopReport.xml 580 | 581 | # Files built by Visual Studio 582 | *_i.c 583 | *_p.c 584 | *_h.h 585 | *.ilk 586 | *.meta 587 | *.obj 588 | *.iobj 589 | *.pch 590 | *.pdb 591 | *.ipdb 592 | *.pgc 593 | *.pgd 594 | *.rsp 595 | *.sbr 596 | *.tlb 597 | *.tli 598 | *.tlh 599 | *.tmp 600 | *.tmp_proj 601 | *_wpftmp.csproj 602 | *.tlog 603 | *.vspscc 604 | *.vssscc 605 | .builds 606 | *.pidb 607 | *.svclog 608 | *.scc 609 | 610 | # Chutzpah Test files 611 | _Chutzpah* 612 | 613 | # Visual C++ cache files 614 | ipch/ 615 | *.aps 616 | *.ncb 617 | *.opendb 618 | *.opensdf 619 | *.sdf 620 | *.cachefile 621 | 
*.VC.db 622 | *.VC.VC.opendb 623 | 624 | # Visual Studio profiler 625 | *.psess 626 | *.vsp 627 | *.vspx 628 | *.sap 629 | 630 | # Visual Studio Trace Files 631 | *.e2e 632 | 633 | # TFS 2012 Local Workspace 634 | $tf/ 635 | 636 | # Guidance Automation Toolkit 637 | *.gpState 638 | 639 | # ReSharper is a .NET coding add-in 640 | _ReSharper*/ 641 | *.[Rr]e[Ss]harper 642 | *.DotSettings.user 643 | 644 | # TeamCity is a build add-in 645 | _TeamCity* 646 | 647 | # DotCover is a Code Coverage Tool 648 | *.dotCover 649 | 650 | # AxoCover is a Code Coverage Tool 651 | .axoCover/* 652 | !.axoCover/settings.json 653 | 654 | # Coverlet is a free, cross platform Code Coverage Tool 655 | coverage*.json 656 | coverage*.xml 657 | coverage*.info 658 | 659 | # Visual Studio code coverage results 660 | *.coverage 661 | *.coveragexml 662 | 663 | # NCrunch 664 | _NCrunch_* 665 | .*crunch*.local.xml 666 | nCrunchTemp_* 667 | 668 | # MightyMoose 669 | *.mm.* 670 | AutoTest.Net/ 671 | 672 | # Web workbench (sass) 673 | .sass-cache/ 674 | 675 | # Installshield output folder 676 | [Ee]xpress/ 677 | 678 | # DocProject is a documentation generator add-in 679 | DocProject/buildhelp/ 680 | DocProject/Help/*.HxT 681 | DocProject/Help/*.HxC 682 | DocProject/Help/*.hhc 683 | DocProject/Help/*.hhk 684 | DocProject/Help/*.hhp 685 | DocProject/Help/Html2 686 | DocProject/Help/html 687 | 688 | # Click-Once directory 689 | publish/ 690 | 691 | # Publish Web Output 692 | *.[Pp]ublish.xml 693 | *.azurePubxml 694 | # Note: Comment the next line if you want to checkin your web deploy settings, 695 | # but database connection strings (with potential passwords) will be unencrypted 696 | *.pubxml 697 | *.publishproj 698 | 699 | # Microsoft Azure Web App publish settings. 
Comment the next line if you want to 700 | # checkin your Azure Web App publish settings, but sensitive information contained 701 | # in these scripts will be unencrypted 702 | PublishScripts/ 703 | 704 | # NuGet Packages 705 | *.nupkg 706 | # NuGet Symbol Packages 707 | *.snupkg 708 | # The packages folder can be ignored because of Package Restore 709 | **/[Pp]ackages/* 710 | # except build/, which is used as an MSBuild target. 711 | !**/[Pp]ackages/build/ 712 | # Uncomment if necessary however generally it will be regenerated when needed 713 | #!**/[Pp]ackages/repositories.config 714 | # NuGet v3's project.json files produces more ignorable files 715 | *.nuget.props 716 | *.nuget.targets 717 | 718 | # Microsoft Azure Build Output 719 | csx/ 720 | *.build.csdef 721 | 722 | # Microsoft Azure Emulator 723 | ecf/ 724 | rcf/ 725 | 726 | # Windows Store app package directories and files 727 | AppPackages/ 728 | BundleArtifacts/ 729 | Package.StoreAssociation.xml 730 | _pkginfo.txt 731 | *.appx 732 | *.appxbundle 733 | *.appxupload 734 | 735 | # Visual Studio cache files 736 | # files ending in .cache can be ignored 737 | *.[Cc]ache 738 | # but keep track of directories ending in .cache 739 | !?*.[Cc]ache/ 740 | 741 | # Others 742 | ClientBin/ 743 | ~$* 744 | *.dbmdl 745 | *.dbproj.schemaview 746 | *.jfm 747 | *.pfx 748 | *.publishsettings 749 | orleans.codegen.cs 750 | 751 | # Including strong name files can present a security risk 752 | # (https://github.com/github/gitignore/pull/2483#issue-259490424) 753 | #*.snk 754 | 755 | # Since there are multiple workflows, uncomment next line to ignore bower_components 756 | # (https://github.com/github/gitignore/pull/1529#issuecomment-104372622) 757 | #bower_components/ 758 | 759 | # RIA/Silverlight projects 760 | Generated_Code/ 761 | 762 | # Backup & report files from converting an old project file 763 | # to a newer Visual Studio version. 
Backup files are not needed, 764 | # because we have git ;-) 765 | _UpgradeReport_Files/ 766 | Backup*/ 767 | UpgradeLog*.XML 768 | UpgradeLog*.htm 769 | ServiceFabricBackup/ 770 | *.rptproj.bak 771 | 772 | # SQL Server files 773 | *.mdf 774 | *.ldf 775 | *.ndf 776 | 777 | # Business Intelligence projects 778 | *.rdl.data 779 | *.bim.layout 780 | *.bim_*.settings 781 | *.rptproj.rsuser 782 | *- [Bb]ackup.rdl 783 | *- [Bb]ackup ([0-9]).rdl 784 | *- [Bb]ackup ([0-9][0-9]).rdl 785 | 786 | # Microsoft Fakes 787 | FakesAssemblies/ 788 | 789 | # GhostDoc plugin setting file 790 | *.GhostDoc.xml 791 | 792 | # Node.js Tools for Visual Studio 793 | .ntvs_analysis.dat 794 | node_modules/ 795 | 796 | # Visual Studio 6 build log 797 | *.plg 798 | 799 | # Visual Studio 6 workspace options file 800 | *.opt 801 | 802 | # Visual Studio 6 auto-generated workspace file (contains which files were open etc.) 803 | *.vbw 804 | 805 | # Visual Studio 6 auto-generated project file (contains which files were open etc.) 
806 | *.vbp 807 | 808 | # Visual Studio 6 workspace and project file (working project files containing files to include in project) 809 | *.dsw 810 | *.dsp 811 | 812 | # Visual Studio 6 technical files 813 | 814 | # Visual Studio LightSwitch build output 815 | **/*.HTMLClient/GeneratedArtifacts 816 | **/*.DesktopClient/GeneratedArtifacts 817 | **/*.DesktopClient/ModelManifest.xml 818 | **/*.Server/GeneratedArtifacts 819 | **/*.Server/ModelManifest.xml 820 | _Pvt_Extensions 821 | 822 | # Paket dependency manager 823 | .paket/paket.exe 824 | paket-files/ 825 | 826 | # FAKE - F# Make 827 | .fake/ 828 | 829 | # CodeRush personal settings 830 | .cr/personal 831 | 832 | # Python Tools for Visual Studio (PTVS) 833 | *.pyc 834 | 835 | # Cake - Uncomment if you are using it 836 | # tools/** 837 | # !tools/packages.config 838 | 839 | # Tabs Studio 840 | *.tss 841 | 842 | # Telerik's JustMock configuration file 843 | *.jmconfig 844 | 845 | # BizTalk build output 846 | *.btp.cs 847 | *.btm.cs 848 | *.odx.cs 849 | *.xsd.cs 850 | 851 | # OpenCover UI analysis results 852 | OpenCover/ 853 | 854 | # Azure Stream Analytics local run output 855 | ASALocalRun/ 856 | 857 | # MSBuild Binary and Structured Log 858 | *.binlog 859 | 860 | # NVidia Nsight GPU debugger configuration file 861 | *.nvuser 862 | 863 | # MFractors (Xamarin productivity tool) working folder 864 | .mfractor/ 865 | 866 | # Local History for Visual Studio 867 | .localhistory/ 868 | 869 | # Visual Studio History (VSHistory) files 870 | .vshistory/ 871 | 872 | # BeatPulse healthcheck temp database 873 | healthchecksdb 874 | 875 | # Backup folder for Package Reference Convert tool in Visual Studio 2017 876 | MigrationBackup/ 877 | 878 | # Ionide (cross platform F# VS Code tools) working folder 879 | .ionide/ 880 | 881 | # Fody - auto-generated XML schema 882 | FodyWeavers.xsd 883 | 884 | # VS Code files for those working on multiple tools 885 | *.code-workspace 886 | 887 | # Local History for Visual Studio Code 888 | 
889 | # Windows Installer files from build outputs 890 | 891 | # JetBrains Rider 892 | *.sln.iml 893 | 894 | ### VisualStudio Patch ### 895 | # Additional files built by Visual Studio 896 | 897 | # End of https://www.toptal.com/developers/gitignore/api/macos,linux,windows,python,jupyternotebooks,jetbrains,pycharm,vim,emacs,visualstudiocode,visualstudio 898 | 899 | # VS Code 900 | .vscode 901 | 902 | scratch/ 903 | 904 | # virtual env 905 | venv -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | # See: https://docs.readthedocs.io/en/latest/config-file/v2.html 2 | 3 | version: 2 4 | 5 | build: 6 | image: latest 7 | 8 | python: 9 | version: "3.8" 10 | install: 11 | - method: pip 12 | path: . 13 | extra_requirements: 14 | - docs 15 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Max Berrendorf 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | graft src 2 | graft tests 3 | prune scripts 4 | prune notebooks 5 | prune tests/.pytest_cache 6 | 7 | recursive-include docs/source *.py 8 | recursive-include docs/source *.rst 9 | recursive-include docs/source *.png 10 | recursive-include docs/source *.svg 11 | 12 | global-exclude *.py[cod] __pycache__ *.so *.dylib .DS_Store *.gpickle 13 | 14 | include README.md LICENSE 15 | exclude tox.ini .flake8 .bumpversion.cfg .readthedocs.yml 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 6 | 7 |

8 | torch-ppr 9 |

10 | 11 |

12 | 13 | Tests 14 | 15 | 16 | PyPI 17 | 18 | 19 | PyPI - Python Version 20 | 21 | 22 | PyPI - License 23 | 24 | 25 | Documentation Status 26 | 27 | 28 | Codecov status 29 | 30 | 31 | Cookiecutter template from @cthoyt 32 | 33 | 34 | Code style: black 35 | 36 | 37 | Contributor Covenant 38 | 39 |

40 | 41 | This package allows calculating page-rank and personalized page-rank via power iteration with PyTorch, 42 | which also supports calculation on GPU (or other accelerators). 43 | 44 | ## 💪 Getting Started 45 | 46 | As a simple example, consider this simple graph with five nodes. 47 |

48 | 49 |

50 | 51 | Its edge list is given as 52 | ```python-console 53 | >>> import torch 54 | >>> edge_index = torch.as_tensor(data=[(0, 1), (1, 2), (1, 3), (2, 4)]).t() 55 | ``` 56 | 57 | We can use 58 | ```python-console 59 | >>> from torch_ppr import page_rank 60 | >>> page_rank(edge_index=edge_index) 61 | tensor([0.1269, 0.3694, 0.2486, 0.1269, 0.1281]) 62 | ``` 63 | to calculate the page rank, i.e., a measure of global importance. 64 | We notice that the central node receives the largest importance score, 65 | while all other nodes have lower importance. Moreover, the two 66 | indistinguishable nodes `0` and `3` receive the same page rank. 67 | 68 | We can also calculate *personalized* page rank which measures importance 69 | from the perspective of a single node. 70 | For instance, for node `2`, we have 71 | ```python-console 72 | >>> from torch_ppr import personalized_page_rank 73 | >>> personalized_page_rank(edge_index=edge_index, indices=[2]) 74 | tensor([[0.1103, 0.3484, 0.2922, 0.1103, 0.1388]]) 75 | ``` 76 | Thus, the most important node is the central node `1`, nodes `0` and `3` receive 77 | the same importance value which is below the value of the direct neighbor `4`. 78 | 79 | By the virtue of using PyTorch, the code seamlessly works on GPUs, too, and 80 | supports auto-grad differentiation. Moreover, the calculation of personalized 81 | page rank supports automatic batch size optimization via 82 | [`torch_max_mem`](https://github.com/mberr/torch-max-mem). 83 | 84 | ## 🚀 Installation 85 | 86 | The most recent release can be installed from 87 | [PyPI](https://pypi.org/project/torch_ppr/) with: 88 | 89 | ```bash 90 | $ pip install torch_ppr 91 | ``` 92 | 93 | The most recent code and data can be installed directly from GitHub with: 94 | 95 | ```bash 96 | $ pip install git+https://github.com/mberr/torch-ppr.git 97 | ``` 98 | 99 | ## 👐 Contributing 100 | 101 | Contributions, whether filing an issue, making a pull request, or forking, are appreciated. 
See 102 | [CONTRIBUTING.md](https://github.com/mberr/torch-ppr/blob/master/.github/CONTRIBUTING.md) for more information on getting involved. 103 | 104 | ## 👋 Attribution 105 | 106 | ### ⚖️ License 107 | 108 | The code in this package is licensed under the MIT License. 109 | 110 | 115 | 116 | 124 | 125 | 134 | 135 | ### 🍪 Cookiecutter 136 | 137 | This package was created with [@audreyfeldroy](https://github.com/audreyfeldroy)'s 138 | [cookiecutter](https://github.com/cookiecutter/cookiecutter) package using [@cthoyt](https://github.com/cthoyt)'s 139 | [cookiecutter-snekpack](https://github.com/cthoyt/cookiecutter-snekpack) template. 140 | 141 | ## 🛠️ For Developers 142 | 143 |
144 | See developer instructions 145 | 146 | 147 | The final section of the README is for if you want to get involved by making a code contribution. 148 | 149 | ### Development Installation 150 | 151 | To install in development mode, use the following: 152 | 153 | ```bash 154 | $ git clone git+https://github.com/mberr/torch-ppr.git 155 | $ cd torch-ppr 156 | $ pip install -e . 157 | ``` 158 | 159 | ### 🥼 Testing 160 | 161 | After cloning the repository and installing `tox` with `pip install tox`, the unit tests in the `tests/` folder can be 162 | run reproducibly with: 163 | 164 | ```shell 165 | $ tox 166 | ``` 167 | 168 | Additionally, these tests are automatically re-run with each commit in a [GitHub Action](https://github.com/mberr/torch-ppr/actions?query=workflow%3ATests). 169 | 170 | ### 📖 Building the Documentation 171 | 172 | The documentation can be built locally using the following: 173 | 174 | ```shell 175 | $ git clone git+https://github.com/mberr/torch-ppr.git 176 | $ cd torch-ppr 177 | $ tox -e docs 178 | $ open docs/build/html/index.html 179 | ``` 180 | 181 | The documentation automatically installs the package as well as the `docs` 182 | extra specified in the [`setup.cfg`](setup.cfg). `sphinx` plugins 183 | like `texext` can be added there. Additionally, they need to be added to the 184 | `extensions` list in [`docs/source/conf.py`](docs/source/conf.py). 185 | 186 | ### 📦 Making a Release 187 | 188 | After installing the package in development mode and installing 189 | `tox` with `pip install tox`, the commands for making a new release are contained within the `finish` environment 190 | in `tox.ini`. Run the following from the shell: 191 | 192 | ```shell 193 | $ tox -e finish 194 | ``` 195 | 196 | This script does the following: 197 | 198 | 1. 
Uses [Bump2Version](https://github.com/c4urself/bump2version) to switch the version number in the `setup.cfg`, 199 | `src/torch_ppr/version.py`, and [`docs/source/conf.py`](docs/source/conf.py) to not have the `-dev` suffix 200 | 2. Packages the code in both a tar archive and a wheel using [`build`](https://github.com/pypa/build) 201 | 3. Uploads to PyPI using [`twine`](https://github.com/pypa/twine). Be sure to have a `.pypirc` file configured to avoid the need for manual input at this 202 | step 203 | 4. Push to GitHub. You'll need to make a release going with the commit where the version was bumped. 204 | 5. Bump the version to the next patch. If you made big changes and want to bump the version by minor, you can 205 | use `tox -e bumpversion minor` after. 206 |
207 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # 3 | # Configuration file for the Sphinx documentation builder. 4 | # 5 | # This file does only contain a selection of the most common options. For a 6 | # full list see the documentation: 7 | # http://www.sphinx-doc.org/en/master/config 8 | 9 | # -- Path setup -------------------------------------------------------------- 10 | 11 | # If extensions (or modules to document with autodoc) are in another directory, 12 | # add these directories to sys.path here. If the directory is relative to the 13 | # documentation root, use os.path.abspath to make it absolute, like shown here. 14 | # 15 | 16 | import os 17 | import re 18 | import sys 19 | from datetime import date 20 | 21 | sys.path.insert(0, os.path.abspath("../../src")) 22 | 23 | # -- Project information ----------------------------------------------------- 24 | 25 | project = "torch_ppr" 26 | copyright = f"{date.today().year}, Max Berrendorf" 27 | author = "Max Berrendorf" 28 | 29 | # The full version, including alpha/beta/rc tags. 30 | release = "0.0.9-dev" 31 | 32 | # The short X.Y version. 33 | parsed_version = re.match( 34 | "(?P\d+)\.(?P\d+)\.(?P\d+)(?:-(?P[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(?P[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?", 35 | release, 36 | ) 37 | version = parsed_version.expand("\g.\g.\g") 38 | 39 | if parsed_version.group("release"): 40 | tags.add("prerelease") 41 | 42 | # -- General configuration --------------------------------------------------- 43 | 44 | # If your documentation needs a minimal Sphinx version, state it here. 45 | # 46 | # needs_sphinx = '1.0' 47 | 48 | # If true, the current module name will be prepended to all description 49 | # unit titles (such as .. function::). 
50 | add_module_names = False 51 | 52 | # A list of prefixes that are ignored when creating the module index. (new in Sphinx 0.6) 53 | modindex_common_prefix = ["torch_ppr."] 54 | 55 | # Add any Sphinx extension module names here, as strings. They can be 56 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 57 | # ones. 58 | extensions = [ 59 | "sphinx.ext.autosummary", 60 | "sphinx.ext.autodoc", 61 | "sphinx.ext.coverage", 62 | "sphinx.ext.intersphinx", 63 | "sphinx.ext.todo", 64 | "sphinx.ext.mathjax", 65 | "sphinx.ext.viewcode", 66 | "sphinx_autodoc_typehints", 67 | "sphinx_automodapi.automodapi", 68 | "sphinx_automodapi.smart_resolver", 69 | # 'texext', 70 | ] 71 | 72 | 73 | # generate autosummary pages 74 | autosummary_generate = True 75 | 76 | # Add any paths that contain templates here, relative to this directory. 77 | templates_path = ["_templates"] 78 | 79 | # The suffix(es) of source filenames. 80 | # You can specify multiple suffix as a list of string: 81 | # 82 | # source_suffix = ['.rst', '.md'] 83 | source_suffix = ".rst" 84 | 85 | # The master toctree document. 86 | master_doc = "index" 87 | 88 | # The language for content autogenerated by Sphinx. Refer to documentation 89 | # for a list of supported languages. 90 | # 91 | # This is also used if you do content translation via gettext catalogs. 92 | # Usually you set "language" from the command line for these cases. 93 | language = "en" 94 | 95 | # List of patterns, relative to source directory, that match files and 96 | # directories to ignore when looking for source files. 97 | # This pattern also affects html_static_path and html_extra_path. 98 | exclude_patterns = [] 99 | 100 | # The name of the Pygments (syntax highlighting) style to use. 101 | pygments_style = "sphinx" 102 | 103 | # -- Options for HTML output ------------------------------------------------- 104 | 105 | # The theme to use for HTML and HTML Help pages. 
See the documentation for 106 | # a list of builtin themes. 107 | # 108 | html_theme = "sphinx_rtd_theme" 109 | 110 | # Theme options are theme-specific and customize the look and feel of a theme 111 | # further. For a list of options available for each theme, see the 112 | # documentation. 113 | # 114 | # html_theme_options = {} 115 | 116 | # Add any paths that contain custom static files (such as style sheets) here, 117 | # relative to this directory. They are copied after the builtin static files, 118 | # so a file named "default.css" will overwrite the builtin "default.css". 119 | # html_static_path = ['_static'] 120 | 121 | # Custom sidebar templates, must be a dictionary that maps document names 122 | # to template names. 123 | # 124 | # The default sidebars (for documents that don't match any pattern) are 125 | # defined by theme itself. Builtin themes are using these templates by 126 | # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', 127 | # 'searchbox.html']``. 128 | # 129 | # html_sidebars = {} 130 | 131 | # The name of an image file (relative to this directory) to place at the top 132 | # of the sidebar. 133 | # 134 | if os.path.exists("logo.png"): 135 | html_logo = "logo.png" 136 | 137 | # -- Options for HTMLHelp output --------------------------------------------- 138 | 139 | # Output file base name for HTML help builder. 140 | htmlhelp_basename = "torch-pprdoc" 141 | 142 | # -- Options for LaTeX output ------------------------------------------------ 143 | 144 | # latex_elements = { 145 | # The paper size ('letterpaper' or 'a4paper'). 146 | # 147 | # 'papersize': 'letterpaper', 148 | # 149 | # The font size ('10pt', '11pt' or '12pt'). 150 | # 151 | # 'pointsize': '10pt', 152 | # 153 | # Additional stuff for the LaTeX preamble. 154 | # 155 | # 'preamble': '', 156 | # 157 | # Latex figure (float) alignment 158 | # 159 | # 'figure_align': 'htbp', 160 | # } 161 | 162 | # Grouping the document tree into LaTeX files. 
List of tuples 163 | # (source start file, target name, title, 164 | # author, documentclass [howto, manual, or own class]). 165 | # latex_documents = [ 166 | # ( 167 | # master_doc, 168 | # 'torch_ppr.tex', 169 | # 'torch-ppr Documentation', 170 | # author, 171 | # 'manual', 172 | # ), 173 | # ] 174 | 175 | # -- Options for manual page output ------------------------------------------ 176 | 177 | # One entry per manual page. List of tuples 178 | # (source start file, name, description, authors, manual section). 179 | man_pages = [ 180 | ( 181 | master_doc, 182 | "torch_ppr", 183 | "torch-ppr Documentation", 184 | [author], 185 | 1, 186 | ), 187 | ] 188 | 189 | # -- Options for Texinfo output ---------------------------------------------- 190 | 191 | # Grouping the document tree into Texinfo files. List of tuples 192 | # (source start file, target name, title, author, 193 | # dir menu entry, description, category) 194 | texinfo_documents = [ 195 | ( 196 | master_doc, 197 | "torch_ppr", 198 | "torch-ppr Documentation", 199 | author, 200 | "Max Berrendorf", 201 | "(Personalized) Page-Rank computation using PyTorch", 202 | "Miscellaneous", 203 | ), 204 | ] 205 | 206 | # -- Options for Epub output ------------------------------------------------- 207 | 208 | # Bibliographic Dublin Core info. 209 | # epub_title = project 210 | 211 | # The unique identifier of the text. This can be a ISBN number 212 | # or the project homepage. 213 | # 214 | # epub_identifier = '' 215 | 216 | # A unique identification for the text. 217 | # 218 | # epub_uid = '' 219 | 220 | # A list of files that should not be packed into the epub file. 221 | # epub_exclude_files = ['search.html'] 222 | 223 | # -- Extension configuration ------------------------------------------------- 224 | 225 | # -- Options for intersphinx extension --------------------------------------- 226 | 227 | # Example configuration for intersphinx: refer to the Python standard library. 
228 | intersphinx_mapping = { 229 | "https://docs.python.org/3/": None, 230 | "torch": ("https://pytorch.org/docs/stable", None), 231 | "torch_max_mem": ("https://torch-max-mem.readthedocs.io/en/stable/", None), 232 | } 233 | 234 | autoclass_content = "both" 235 | 236 | # Don't sort alphabetically, explained at: 237 | # https://stackoverflow.com/questions/37209921/python-how-not-to-sort-sphinx-output-in-alphabetical-order 238 | autodoc_member_order = "bysource" 239 | -------------------------------------------------------------------------------- /docs/source/details.rst: -------------------------------------------------------------------------------- 1 | Utilities 2 | ========= 3 | .. automodule:: torch_ppr.utils 4 | :members: 5 | -------------------------------------------------------------------------------- /docs/source/img/small_graph.svg: -------------------------------------------------------------------------------- 1 | 2 | 4 | 5 | 6 | 7 | 8 | 9 | 2022-05-06T17:12:54.941966 10 | image/svg+xml 11 | 12 | 13 | Matplotlib v3.5.2, https://matplotlib.org/ 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 31 | 32 | 33 | 34 | 37 | 40 | 43 | 46 | 47 | 48 | 49 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | torch-ppr 2 | ========= 3 | 4 | `torch-ppr` is a package to calculate (personalized) page ranks with PyTorch. 5 | 6 | As a simple example, consider this simple graph with five nodes. 7 | 8 | .. 
image:: img/small_graph.svg 9 | :width: 400 10 | :alt: Simple Example Graph 11 | 12 | Its edge list is given as 13 | 14 | .. code-block:: python 15 | 16 | edge_index = torch.as_tensor(data=[(0, 1), (1, 2), (1, 3), (2, 4)]).t() 17 | 18 | We can use 19 | 20 | .. code-block:: python 21 | 22 | >>> page_rank(edge_index=edge_index) 23 | tensor([0.1269, 0.3694, 0.2486, 0.1269, 0.1281]) 24 | 25 | 26 | to calculate the page rank, i.e., a measure of global importance. 27 | We notice that the central node receives the largest importance score, 28 | while all other nodes have equal importance. 29 | 30 | We can also calculate *personalized* page rank which measures importance 31 | from the perspective of a single node. 32 | For instance, for node `2`, we have 33 | 34 | .. code-block:: python 35 | 36 | >>> personalized_page_rank(edge_index=edge_index, indices=[2]) 37 | tensor([[0.1103, 0.3484, 0.2922, 0.1103, 0.1388]]) 38 | 39 | 40 | By the virtue of using PyTorch, the code seamlessly works on GPUs, too, and 41 | supports auto-grad differentiation. Moreover, the calculation of personalized 42 | page rank supports automatic batch size optimization via 43 | `torch_max_mem `_. 44 | 45 | Table of Contents 46 | ----------------- 47 | .. toctree:: 48 | :maxdepth: 2 49 | :caption: Getting Started 50 | :name: start 51 | 52 | installation 53 | usage 54 | details 55 | 56 | 57 | Indices and Tables 58 | ------------------ 59 | * :ref:`genindex` 60 | * :ref:`modindex` 61 | * :ref:`search` 62 | -------------------------------------------------------------------------------- /docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | Installation 2 | ============ 3 | The most recent release can be installed from 4 | `PyPI `_ with: 5 | 6 | .. code-block:: shell 7 | 8 | $ pip install torch_ppr 9 | 10 | The most recent code and data can be installed directly from GitHub with: 11 | 12 | .. 
code-block:: shell 13 | 14 | $ pip install git+https://github.com/mberr/torch-ppr.git 15 | 16 | To install in development mode, use the following: 17 | 18 | .. code-block:: shell 19 | 20 | $ git clone git+https://github.com/mberr/torch-ppr.git 21 | $ cd torch-ppr 22 | $ pip install -e . 23 | -------------------------------------------------------------------------------- /docs/source/usage.rst: -------------------------------------------------------------------------------- 1 | Usage 2 | ===== 3 | .. automodule:: torch_ppr.api 4 | :members: 5 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # See https://setuptools.readthedocs.io/en/latest/build_meta.html 2 | [build-system] 3 | requires = ["setuptools", "wheel"] 4 | build-backend = "setuptools.build_meta:__legacy__" 5 | 6 | [tool.black] 7 | line-length = 100 8 | target-version = ["py38", "py39", "py310"] 9 | 10 | [tool.isort] 11 | profile = "black" 12 | multi_line_output = 3 13 | line_length = 100 14 | include_trailing_comma = true 15 | reverse_relative = true 16 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | ########################## 2 | # Setup.py Configuration # 3 | ########################## 4 | [metadata] 5 | name = torch_ppr 6 | version = 0.0.9-dev 7 | description = (Personalized) Page-Rank computation using PyTorch 8 | long_description = file: README.md 9 | long_description_content_type = text/markdown 10 | 11 | # URLs associated with the project 12 | url = https://github.com/mberr/torch-ppr 13 | download_url = https://github.com/mberr/torch-ppr/releases 14 | project_urls = 15 | Bug Tracker = https://github.com/mberr/torch-ppr/issues 16 | Source Code = https://github.com/mberr/torch-ppr 17 | 18 | # Author information 19 | author = Max 
Berrendorf 20 | author_email = max.berrendorf@gmail.com 21 | maintainer = Max Berrendorf 22 | maintainer_email = max.berrendorf@gmail.com 23 | 24 | # License Information 25 | license = MIT 26 | license_file = LICENSE 27 | 28 | # Search tags 29 | classifiers = 30 | Development Status :: 1 - Planning 31 | Environment :: Console 32 | Intended Audience :: Developers 33 | License :: OSI Approved :: MIT License 34 | Operating System :: OS Independent 35 | Framework :: Pytest 36 | Framework :: tox 37 | Framework :: Sphinx 38 | Programming Language :: Python 39 | Programming Language :: Python :: 3.8 40 | Programming Language :: Python :: 3.9 41 | Programming Language :: Python :: 3.10 42 | Programming Language :: Python :: 3 :: Only 43 | # TODO add your topics from the Trove controlled vocabulary (see https://pypi.org/classifiers) 44 | keywords = 45 | snekpack 46 | cookiecutter 47 | # TODO add your own free-text keywords 48 | 49 | [options] 50 | install_requires = 51 | # Missing itertools from the standard library you didn't know you needed 52 | more_itertools 53 | # Use progress bars excessively 54 | tqdm 55 | torch>=1.11 56 | # for automatic batch size selection 57 | torch_max_mem 58 | 59 | # Random options 60 | zip_safe = false 61 | include_package_data = True 62 | python_requires = >=3.8 63 | 64 | # Where is my code 65 | packages = find: 66 | package_dir = 67 | = src 68 | 69 | [options.packages.find] 70 | where = src 71 | 72 | [options.extras_require] 73 | tests = 74 | pytest 75 | coverage 76 | docs = 77 | sphinx 78 | sphinx-rtd-theme 79 | 80 | sphinx-autodoc-typehints 81 | sphinx_automodapi 82 | # To include LaTeX comments easily in your docs. 
83 | # If you uncomment this, don't forget to do the same in docs/conf.py 84 | # texext 85 | 86 | 87 | 88 | ###################### 89 | # Doc8 Configuration # 90 | # (doc8.ini) # 91 | ###################### 92 | [doc8] 93 | max-line-length = 120 94 | 95 | ########################## 96 | # Coverage Configuration # 97 | # (.coveragerc) # 98 | ########################## 99 | [coverage:run] 100 | branch = True 101 | source = torch_ppr 102 | omit = 103 | tests/* 104 | docs/* 105 | 106 | [coverage:paths] 107 | source = 108 | src/torch_ppr 109 | .tox/*/lib/python*/site-packages/torch_ppr 110 | 111 | [coverage:report] 112 | show_missing = True 113 | exclude_lines = 114 | pragma: no cover 115 | raise NotImplementedError 116 | if __name__ == __main__: 117 | if TYPE_CHECKING: 118 | def __str__ 119 | def __repr__ 120 | 121 | ########################## 122 | # Darglint Configuration # 123 | ########################## 124 | [darglint] 125 | docstring_style = sphinx 126 | strictness = short 127 | 128 | ######################### 129 | # Flake8 Configuration # 130 | # (.flake8) # 131 | ######################### 132 | [flake8] 133 | ignore = 134 | # Line break before binary operator (flake8 is wrong) 135 | W503 136 | # whitespace before ':' 137 | #E203 138 | exclude = 139 | .tox, 140 | .git, 141 | __pycache__, 142 | docs/source/conf.py, 143 | build, 144 | dist, 145 | tests/fixtures/*, 146 | *.pyc, 147 | *.egg-info, 148 | .cache, 149 | .eggs, 150 | data 151 | max-line-length = 120 152 | max-complexity = 20 153 | import-order-style = pycharm 154 | application-import-names = 155 | torch_ppr 156 | tests 157 | per-file-ignores = 158 | # assertions with pytest 159 | tests/*:S101 160 | -------------------------------------------------------------------------------- /src/torch_ppr/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """(Personalized) Page-Rank computation using PyTorch.""" 4 | 5 | from .api import 
page_rank, personalized_page_rank 6 | 7 | __all__ = [ 8 | "page_rank", 9 | "personalized_page_rank", 10 | ] 11 | -------------------------------------------------------------------------------- /src/torch_ppr/api.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """The public API.""" 4 | 5 | import logging 6 | from typing import Optional 7 | 8 | import torch 9 | 10 | from .utils import ( 11 | DeviceHint, 12 | batched_personalized_page_rank, 13 | power_iteration, 14 | prepare_page_rank_adjacency, 15 | prepare_x0, 16 | resolve_device, 17 | validate_adjacency, 18 | validate_x, 19 | ) 20 | 21 | __all__ = [ 22 | "page_rank", 23 | "personalized_page_rank", 24 | ] 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | 29 | def page_rank( 30 | *, 31 | adj: Optional[torch.Tensor] = None, 32 | edge_index: Optional[torch.LongTensor] = None, 33 | num_nodes: Optional[int] = None, 34 | add_identity: bool = False, 35 | max_iter: int = 1_000, 36 | alpha: float = 0.05, 37 | epsilon: float = 1.0e-04, 38 | x0: Optional[torch.Tensor] = None, 39 | use_tqdm: bool = False, 40 | device: DeviceHint = None, 41 | ) -> torch.Tensor: 42 | """ 43 | Compute page rank by power iteration. 44 | 45 | :param adj: 46 | the adjacency matrix, cf. :func:`torch_ppr.utils.prepare_page_rank_adjacency`. Preferred over ``edge_index``. 47 | :param edge_index: shape: ``(2, m)`` 48 | the edge index of the graph, i.e, the edge list. cf. :func:`torch_ppr.utils.prepare_page_rank_adjacency` 49 | :param num_nodes: 50 | the number of nodes used to determine the shape of the adjacency matrix. 51 | If ``None``, and ``adj`` is not already provided, it is inferred from ``edge_index``. 52 | :param add_identity: 53 | whether to add an identity matrix to ``A`` to ensure that each node has a degree of at least one. 
54 | 55 | :param max_iter: ``max_iter > 0`` 56 | the maximum number of iterations 57 | :param alpha: ``0 < alpha < 1`` 58 | the smoothing value / teleport probability 59 | :param epsilon: ``epsilon > 0`` 60 | a (small) constant to check for convergence 61 | :param x0: shape: ``(n,)`` 62 | the initial value for ``x``. If ``None``, set to a constant $1/n$ vector, 63 | cf. :func:`torch_ppr.utils.prepare_x0`. Otherwise, the tensor is checked for being valid using 64 | :func:`torch_ppr.utils.validate_x`. 65 | :param use_tqdm: 66 | whether to use a tqdm progress bar 67 | :param device: 68 | the device to use, or a hint thereof 69 | 70 | 71 | :return: shape: ``(n,)`` or ``(batch_size, n)`` 72 | the page-rank vector, i.e., a score between 0 and 1 for each node. 73 | """ 74 | # normalize inputs 75 | adj = prepare_page_rank_adjacency( 76 | adj=adj, edge_index=edge_index, num_nodes=num_nodes, add_identity=add_identity 77 | ) 78 | validate_adjacency(adj=adj) 79 | 80 | x0 = prepare_x0(x0=x0, n=adj.shape[0]) 81 | 82 | # input normalization 83 | validate_x(x=x0, n=adj.shape[0]) 84 | 85 | # power iteration 86 | x = power_iteration( 87 | adj=adj, 88 | x0=x0, 89 | alpha=alpha, 90 | max_iter=max_iter, 91 | use_tqdm=use_tqdm, 92 | epsilon=epsilon, 93 | device=device, 94 | ) 95 | if x.ndim < 2: 96 | return x 97 | return x.t() 98 | 99 | 100 | def personalized_page_rank( 101 | *, 102 | adj: Optional[torch.Tensor] = None, 103 | edge_index: Optional[torch.LongTensor] = None, 104 | add_identity: bool = False, 105 | num_nodes: Optional[int] = None, 106 | indices: Optional[torch.Tensor] = None, 107 | device: DeviceHint = None, 108 | batch_size: Optional[int] = None, 109 | **kwargs, 110 | ) -> torch.Tensor: 111 | """ 112 | Personalized Page-Rank (PPR) computation. 113 | 114 | .. note:: 115 | this method supports automatic memory optimization / batch size selection using :mod:`torch_max_mem`. 116 | 117 | :param adj: shape: ``(n, n)`` 118 | the adjacency matrix, cf. 
:func:`torch_ppr.utils.prepare_page_rank_adjacency` 119 | :param edge_index: shape: ``(2, m)`` 120 | the edge index, cf. :func:`torch_ppr.utils.prepare_page_rank_adjacency` 121 | :param num_nodes: 122 | the number of nodes used to determine the shape of the adjacency matrix. 123 | If ``None``, and ``adj`` is not already provided, it is inferred from ``edge_index``. 124 | :param add_identity: 125 | whether to add an identity matrix to ``A`` to ensure that each node has a degree of at least one. 126 | 127 | :param indices: shape: ``(k,)`` 128 | the node indices for which to calculate the PPR. Defaults to all nodes. 129 | :param device: 130 | the device to use 131 | :param batch_size: ``batch_size > 0`` 132 | the batch size. Defaults to the number of indices. It will be reduced if necessary. 133 | :param kwargs: 134 | additional keyword-based parameters passed to :func:`torch_ppr.utils.batched_personalized_page_rank` 135 | 136 | :return: shape: ``(k, n)`` 137 | the PPR vectors for each node index 138 | 139 | The following shows an example where a custom adjacency matrix is provided. For illustrative purposes, we randomly 140 | generate one: 141 | 142 | >>> import torch 143 | >>> adj = (torch.rand(300, 300)*10).round().to_sparse() 144 | 145 | Next, we need to ensure that the matrix is row-normalized, i.e., the individual rows sum to 1. 
Here, we re-use a 146 | utility method provided by the library: 147 | 148 | >>> from torch_ppr.utils import sparse_normalize 149 | >>> adj_normalized = sparse_normalize(adj, dim=0) 150 | 151 | Finally, we can use this matrix to calculate the personalized page rank for some nodes 152 | 153 | >>> from torch_ppr import personalized_page_rank 154 | >>> indices = torch.as_tensor([1, 2], dtype=torch.long) 155 | >>> ppr = personalized_page_rank(adj=adj_normalized, indices=indices) 156 | """ 157 | # resolve device first 158 | device = resolve_device(device=device) 159 | # prepare adjacency and indices only once 160 | adj = prepare_page_rank_adjacency( 161 | adj=adj, edge_index=edge_index, num_nodes=num_nodes, add_identity=add_identity 162 | ).to(device=device) 163 | validate_adjacency(adj=adj) 164 | 165 | if indices is None: 166 | indices = torch.arange(adj.shape[0], device=device) 167 | else: 168 | indices = torch.as_tensor(indices, dtype=torch.long, device=device) 169 | # normalize inputs 170 | batch_size = batch_size or len(indices) 171 | return batched_personalized_page_rank( 172 | adj=adj, indices=indices, device=device, batch_size=batch_size, **kwargs 173 | ).t() 174 | -------------------------------------------------------------------------------- /src/torch_ppr/py.typed: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/torch_ppr/utils.py: -------------------------------------------------------------------------------- 1 | """Utility functions.""" 2 | import logging 3 | from typing import Any, Collection, Mapping, Optional, Union 4 | 5 | import torch 6 | from torch.nn import functional 7 | from torch_max_mem import MemoryUtilizationMaximizer 8 | from tqdm.auto import tqdm 9 | 10 | __all__ = [ 11 | "DeviceHint", 12 | "resolve_device", 13 | "prepare_num_nodes", 14 | "edge_index_to_sparse_matrix", 15 | "prepare_page_rank_adjacency", 16 | 
logger = logging.getLogger(__name__)

#: a device hint: ``None`` (auto-select), a device name string, or an explicit device
DeviceHint = Union[None, str, torch.device]


def resolve_device(device: DeviceHint = None) -> torch.device:
    """
    Resolve the device to use.

    :param device:
        the device hint; an explicit :class:`torch.device` is passed through unchanged,
        a string is converted, and ``None`` selects CUDA when available, else CPU.

    :return:
        the resolved device
    """
    # pass-through torch.device
    if isinstance(device, torch.device):
        return device
    if device is None:
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
    device = torch.device(device=device)
    logger.info(f"Resolved device={device}")
    return device


def prepare_num_nodes(edge_index: torch.Tensor, num_nodes: Optional[int] = None) -> int:
    """
    Prepare the number of nodes.

    If an explicit number is given, this number will be used. Otherwise, infers the number of nodes as the maximum id
    in the edge index.

    :param edge_index: shape: ``(2, m)``
        the edge index
    :param num_nodes:
        the number of nodes. If ``None``, it is inferred from ``edge_index``.

    :return:
        the number of nodes
    """
    if num_nodes is not None:
        return num_nodes

    # node ids are assumed to be consecutive and zero-based
    num_nodes = edge_index.max().item() + 1
    logger.info(f"Inferred num_nodes={num_nodes}")
    return num_nodes


def edge_index_to_sparse_matrix(
    edge_index: torch.LongTensor, num_nodes: Optional[int] = None
) -> torch.Tensor:
    """
    Convert an edge index to a sparse matrix.

    Uses the edge index for non-zero entries, and fills in ``1`` as entries.

    :param edge_index: shape: ``(2, m)``
        the edge index
    :param num_nodes:
        the number of nodes used to determine the shape of the matrix.
        If ``None``, it is inferred from ``edge_index``.

    :return: shape: ``(n, n)``
        the adjacency matrix as a sparse tensor, cf. :func:`torch.sparse_coo_tensor`.
    """
    num_nodes = prepare_num_nodes(edge_index=edge_index, num_nodes=num_nodes)
    return torch.sparse_coo_tensor(
        indices=edge_index,
        values=torch.ones_like(edge_index[0], dtype=torch.get_default_dtype()),
        size=(num_nodes, num_nodes),
    )


def validate_adjacency(adj: torch.Tensor, n: Optional[int] = None, rtol: float = 1.0e-04):
    """
    Validate the page-rank adjacency matrix.

    In particular, the method checks that

    - the dtype is a floating dtype
    - the shape is ``(n, n)``
    - the column-sum is ``1`` (all-zero columns, i.e., isolated nodes, are permitted with a warning)

    :param adj: shape: ``(n, n)``
        the adjacency matrix
    :param n:
        the number of nodes
    :param rtol:
        the tolerance for checking the sum is close to 1.0

    :raises ValueError:
        if the adjacency matrix is invalid
    """
    # check dtype
    if not torch.is_floating_point(adj):
        if adj.shape[0] == 2 and adj.shape[1] != 2:
            logger.warning(
                "The passed adjacency matrix looks like an edge_index; did you pass it for the wrong parameter?"
            )
        raise ValueError(
            f"Invalid adjacency matrix data type: {adj.dtype}, should be a floating dtype."
        )

    # check shape
    if n is None:
        n = adj.shape[0]
    if adj.shape != (n, n):
        raise ValueError(f"Invalid adjacency matrix shape: {adj.shape}. expected: {(n, n)}")

    # check value range; only possible for COO, where the values are directly accessible
    if adj.is_sparse and not adj.is_sparse_csr:
        adj = adj.coalesce()
        values = adj.values()
        if (values < 0.0).any() or (values > 1.0).any():
            raise ValueError(
                f"Invalid values outside of [0, 1]: min={values.min().item()}, max={values.max().item()}"
            )

    # check column-sum
    if adj.is_sparse and not adj.is_sparse_csr:
        adj_sum = torch.sparse.sum(adj, dim=0).to_dense()
    else:
        # hotfix until torch.sparse.sum is implemented for this layout;
        # fix: create the all-ones vector with matching dtype/device so the check also
        # works for CUDA tensors and non-default floating dtypes
        adj_sum = adj.t() @ torch.ones(adj.shape[0], dtype=adj.dtype, device=adj.device)
    exp_sum = torch.ones_like(adj_sum)
    mask = adj_sum == 0
    if mask.any():
        logger.warning(f"Adjacency contains {mask.sum().item()} isolated nodes.")
        exp_sum[mask] = 0.0
    if not torch.allclose(adj_sum, exp_sum, rtol=rtol):
        raise ValueError(
            f"Invalid column sum: {adj_sum} (min: {adj_sum.min().item()}, max: {adj_sum.max().item()}). "
            f"Expected 1.0 with a relative tolerance of {rtol}.",
        )


def sparse_diagonal(values: torch.Tensor) -> torch.Tensor:
    """Create a sparse diagonal matrix with the given values.

    :param values: shape: ``(n,)``
        the values

    :return: shape: ``(n, n)``
        a sparse diagonal matrix
    """
    return torch.sparse_coo_tensor(
        indices=torch.arange(values.shape[0], device=values.device).unsqueeze(dim=0).repeat(2, 1),
        values=values,
    )
179 | 180 | :param matrix: 181 | the sparse matrix 182 | :param dim: 183 | the dimension along which to normalize, either 0 for rows or 1 for columns 184 | 185 | :return: 186 | the normalized sparse matrix 187 | """ 188 | # calculate row/column sum 189 | row_or_column_sum = ( 190 | torch.sparse.sum(matrix, dim=dim).to_dense().clamp_min(min=torch.finfo(matrix.dtype).eps) 191 | ) 192 | # invert and create diagonal matrix 193 | scaling_matrix = sparse_diagonal(values=torch.reciprocal(row_or_column_sum)) 194 | # multiply matrix 195 | if dim == 0: 196 | args = (matrix, scaling_matrix) 197 | else: 198 | args = (scaling_matrix, matrix) 199 | # note: we do not pass by keyword due to instable API 200 | return torch.sparse.mm(*args) 201 | 202 | 203 | def prepare_page_rank_adjacency( 204 | adj: Optional[torch.Tensor] = None, 205 | edge_index: Optional[torch.LongTensor] = None, 206 | num_nodes: Optional[int] = None, 207 | add_identity: bool = False, 208 | ) -> torch.Tensor: 209 | """ 210 | Prepare the page-rank adjacency matrix. 211 | 212 | If no explicit adjacency is given, the methods first creates an adjacency matrix from the edge index, 213 | cf. :func:`edge_index_to_sparse_matrix`. Next, the matrix is symmetrized as 214 | 215 | .. math:: 216 | A := A + A^T 217 | 218 | Finally, the matrix is normalized such that the columns sum to one. 219 | 220 | :param adj: shape: ``(n, n)`` 221 | the adjacency matrix 222 | :param edge_index: shape: ``(2, m)`` 223 | the edge index 224 | :param num_nodes: 225 | the number of nodes used to determine the shape of the adjacency matrix. 226 | If ``None``, and ``adj`` is not already provided, it is inferred from ``edge_index``. 227 | :param add_identity: 228 | whether to add an identity matrix to ``A`` to ensure that each node has a degree of at least one. 
229 | 230 | :raises ValueError: 231 | if neither is provided, or the adjacency matrix is invalid 232 | 233 | :return: shape: ``(n, n)`` 234 | the symmetric, normalized, and sparse adjacency matrix 235 | """ 236 | if adj is not None: 237 | return adj 238 | 239 | if edge_index is None: 240 | raise ValueError("Must provide at least one of `adj` and `edge_index`.") 241 | 242 | # convert to sparse matrix, shape: (n, n) 243 | adj = edge_index_to_sparse_matrix(edge_index=edge_index, num_nodes=num_nodes) 244 | # symmetrize 245 | adj = adj + adj.t() 246 | # add identity matrix if requested 247 | if add_identity: 248 | adj = adj + sparse_diagonal(torch.ones(adj.shape[0], dtype=adj.dtype, device=adj.device)) 249 | 250 | # adjacency normalization: normalize to row-sum = 1 251 | return sparse_normalize(matrix=adj, dim=0) 252 | 253 | 254 | def validate_x(x: torch.Tensor, n: Optional[int] = None) -> None: 255 | """ 256 | Validate a (batched) page-rank vector. 257 | 258 | In particular, the method checks that 259 | 260 | - the tensor dimension is ``(n,)`` or ``(n, batch_size)`` 261 | - all entries are between ``0`` and ``1`` 262 | - the entries sum to ``1`` (along the first dimension) 263 | 264 | :param x: 265 | the initial value. 266 | :param n: 267 | the number of nodes. 268 | 269 | :raises ValueError: 270 | if the input is invalid. 271 | """ 272 | if x.ndim > 2 or (n is not None and x.shape[0] != n): 273 | raise ValueError(f"Invalid shape: {x.shape}") 274 | 275 | if (x < 0.0).any() or (x > 1.0).any(): 276 | raise ValueError( 277 | f"Encountered values outside of [0, 1]. min={x.min().item()}, max={x.max().item()}" 278 | ) 279 | 280 | x_sum = x.sum(dim=0) 281 | if not torch.allclose(x_sum, torch.ones_like(x_sum)): 282 | raise ValueError(f"The entries do not sum to 1. 
def prepare_x0(
    x0: Optional[torch.Tensor] = None,
    indices: Optional[Collection[int]] = None,
    n: Optional[int] = None,
) -> torch.Tensor:
    """
    Construct the initial page-rank vector (or batch of vectors).

    Precedence: an explicitly given ``x0`` is returned unchanged; otherwise, when
    ``indices`` is given, a ``(n, len(indices))`` one-hot matrix with a single 1 per
    column is built; otherwise, a uniform ``1/n`` vector of shape ``(n,)``.

    :param x0:
        an explicit start value, passed through unmodified
    :param indices:
        the non-zero (one-hot) indices, one per column
    :param n:
        the number of nodes

    :raises ValueError:
        if neither ``x0`` nor ``n`` are provided

    :return: shape: ``(n,)`` or ``(n, batch_size)``
        the initial value ``x``
    """
    # precedence 1: explicit start value
    if x0 is not None:
        return x0
    if n is None:
        raise ValueError("If x0 is not provided, n must be given.")
    # precedence 3: uniform distribution over all nodes
    if indices is None:
        return torch.full(size=(n,), fill_value=1.0 / n)
    # precedence 2: one column per requested index, one-hot
    num_columns = len(indices)
    one_hot = torch.zeros(n, num_columns)
    one_hot[indices, torch.arange(num_columns, device=one_hot.device)] = 1.0
    return one_hot
def power_iteration(
    adj: torch.Tensor,
    x0: torch.Tensor,
    alpha: float = 0.05,
    max_iter: int = 1_000,
    use_tqdm: bool = False,
    epsilon: float = 1.0e-04,
    device: DeviceHint = None,
) -> torch.Tensor:
    r"""
    Perform the power iteration.

    .. math::
        \mathbf{x}^{(i+1)} = (1 - \alpha) \cdot \mathbf{A} \mathbf{x}^{(i)} + \alpha \mathbf{x}^{(0)}

    :param adj: shape: ``(n, n)``
        the (sparse) adjacency matrix
    :param x0: shape: ``(n,)``, or ``(n, batch_size)``
        the initial value for ``x``.
    :param alpha: ``0 < alpha < 1``
        the smoothing value / teleport probability
    :param max_iter: ``0 < max_iter``
        the maximum number of iterations
    :param use_tqdm:
        whether to use a tqdm progress bar
    :param epsilon: ``epsilon > 0``
        a (small) constant to check for convergence
    :param device:
        the device to use, or a hint thereof, cf. :func:`resolve_device`

    :return: shape: ``(n,)`` or ``(n, batch_size)``
        the ``x`` value after convergence (or maximum number of iterations).
    """
    # normalize device
    device = resolve_device(device=device)
    # send tensors to device
    adj = adj.to(device=device)
    x0 = x0.to(device=device)
    # add a trailing batch dimension for uniform handling; removed again before returning
    no_batch = x0.ndim < 2
    if no_batch:
        x0 = x0.unsqueeze(dim=-1)
    # power iteration
    x_old = x = x0
    beta = 1.0 - alpha
    progress = tqdm(range(max_iter), unit_scale=True, leave=False, disable=not use_tqdm)
    for i in progress:
        # calculate x = (1 - alpha) * A.dot(x) + alpha * x0
        # addmm computes beta * input + alpha * (mat1 @ mat2); here input=x0, mat1=adj, mat2=x
        x = torch.sparse.addmm(x0, adj, x, beta=alpha, alpha=beta)
        # note: while the adjacency matrix should already be row-sum normalized,
        # we additionally normalize x to avoid accumulating errors due to loss of precision
        x = functional.normalize(x, dim=0, p=1)
        # calculate the per-batch difference, shape: (batch_size,)
        # fix: use the documented ``dim`` keyword instead of the numpy-style ``axis`` alias
        diff = torch.linalg.norm(x - x_old, ord=float("+inf"), dim=0)
        mask = diff > epsilon
        if use_tqdm:
            progress.set_postfix(
                max_diff=diff.max().item(), converged=1.0 - mask.float().mean().item()
            )
        if not mask.any():
            logger.debug(f"Converged after {i} iterations up to {epsilon}.")
            break
        x_old = x
    else:  # for/else, cf. https://book.pythontips.com/en/latest/for_-_else.html
        logger.warning(f"No convergence after {max_iter} iterations with epsilon={epsilon}.")
    if no_batch:
        x = x.squeeze(dim=-1)
    return x


def _ppr_hasher(kwargs: Mapping[str, Any]) -> int:
    """Hash a batched-PPR call for memory-size bucketing.

    Assumption: batched PPR memory consumption only depends on the matrix ``A``,
    in particular, its shape and its number of nonzero elements.
    """
    adj: torch.Tensor = kwargs.get("adj")
    return hash((adj.shape[0], getattr(adj, "nnz", adj.numel())))


#: memory-utilization maximizer used to automatically shrink the PPR batch size on OOM
ppr_maximizer = MemoryUtilizationMaximizer(hasher=_ppr_hasher)


@ppr_maximizer
def batched_personalized_page_rank(
    adj: torch.Tensor,
    indices: torch.Tensor,
    batch_size: int,
    **kwargs,
) -> torch.Tensor:
    """
    Batch-wise PPR computation with automatic memory optimization.

    :param adj: shape: ``(n, n)``
        the adjacency matrix.
    :param indices: shape: ``k``
        the indices for which to compute PPR
    :param batch_size: ``batch_size > 0``
        the batch size. Will be reduced if necessary
    :param kwargs:
        additional keyword-based parameters passed to :func:`power_iteration`

    :return: shape: ``(n, k)``
        the PPR vectors for each node index
    """
    # one one-hot start vector per index in the chunk; results are concatenated column-wise
    return torch.cat(
        [
            power_iteration(adj=adj, x0=prepare_x0(indices=indices_batch, n=adj.shape[0]), **kwargs)
            for indices_batch in torch.split(indices, batch_size)
        ],
        dim=1,
    )
import os
from subprocess import DEVNULL, CalledProcessError, check_output  # noqa: S404

__all__ = [
    "VERSION",
    "get_version",
    "get_git_hash",
]

#: the package version; kept in sync via bumpversion (see .bumpversion.cfg)
VERSION = "0.0.9-dev"


def get_git_hash() -> str:
    """Get the :mod:`torch_ppr` git hash.

    :return:
        the first 8 characters of the current commit hash, or ``UNHASHED`` when it
        cannot be determined (e.g., not a git checkout, or ``git`` is not installed).
    """
    try:
        ret = check_output(  # noqa: S603,S607
            ["git", "rev-parse", "HEAD"],
            cwd=os.path.dirname(__file__),
            stderr=DEVNULL,
        )
    except (CalledProcessError, OSError):
        # fix: also handle a missing git executable (FileNotFoundError and friends)
        # instead of only a failing git invocation
        return "UNHASHED"
    return ret.strip().decode("utf-8")[:8]


def get_version(with_git_hash: bool = False) -> str:
    """Get the :mod:`torch_ppr` version string, including a git hash.

    :param with_git_hash:
        whether to append the (abbreviated) current git commit hash

    :return:
        the version string
    """
    return f"{VERSION}-{get_git_hash()}" if with_git_hash else VERSION


if __name__ == "__main__":
    print(get_version(with_git_hash=True))  # noqa:T201
torch.arange(self.num_nodes).unsqueeze(0).repeat(2, 1), 26 | ], 27 | dim=-1, 28 | ) 29 | self.adj = utils.prepare_page_rank_adjacency(edge_index=self.edge_index) 30 | 31 | def test_page_rank_edge_index(self): 32 | """Test Page Rank calculation for an adjacency given as edge list.""" 33 | page_rank(edge_index=self.edge_index) 34 | 35 | def test_page_rank_adj(self): 36 | """Test Page Rank calculation.""" 37 | page_rank(adj=self.adj) 38 | 39 | def test_personalized_page_rank_edge_index(self): 40 | """Test Personalized Page Rank calculation for an adjacency given as edge list.""" 41 | personalized_page_rank(edge_index=self.edge_index) 42 | 43 | def test_personalized_page_rank_adj(self): 44 | """Test Personalized Page Rank calculation.""" 45 | personalized_page_rank(adj=self.adj) 46 | 47 | def test_page_rank_manual(self): 48 | """Test Page Rank calculation on a simple manually created example.""" 49 | # A - B - C 50 | # | 51 | # D 52 | edge_index = torch.as_tensor(data=[(0, 1), (1, 2), (1, 3)]).t() 53 | x = page_rank(edge_index=edge_index) 54 | # verify that central node has the largest PR value 55 | assert x.argmax() == 1 56 | 57 | def test_page_rank_isolated_vertices(self): 58 | """Test Page-Rank with isolated vertices.""" 59 | # create isolated node, ID=0 60 | edge_index = self.edge_index + 1 61 | x = page_rank(edge_index=edge_index, add_identity=True) 62 | # isolated node has only one self-loop -> no change in mass to initial mass 63 | self.assertAlmostEqual(x[0].item(), 1 / (self.num_nodes + 1)) 64 | # verify that other nodes are unaffected 65 | x2 = page_rank(edge_index=self.edge_index) 66 | # rescale 67 | x2 = x2 * (self.num_nodes / (self.num_nodes + 1)) 68 | assert torch.allclose(x2, x[1:], atol=1.0e-02) 69 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | """Tests for the API.""" 2 | import unittest 3 | from typing import 
Counter, Optional, Tuple 4 | 5 | import pytest 6 | import torch 7 | from torch.nn import functional 8 | 9 | from torch_ppr import utils 10 | 11 | 12 | def test_resolve_device(): 13 | """Test for resolving devices.""" 14 | for hint, device in ( 15 | (None, torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")), 16 | ("cpu", torch.device("cpu")), 17 | (torch.device("cpu"), torch.device("cpu")), 18 | ): 19 | assert device == utils.resolve_device(device=hint) 20 | 21 | 22 | class UtilsTest(unittest.TestCase): 23 | """Test utilities.""" 24 | 25 | num_nodes: int = 7 26 | num_edges: int = 33 27 | 28 | def setUp(self) -> None: 29 | """Prepare data.""" 30 | # fix seed for reproducible tests 31 | torch.manual_seed(seed=42) 32 | self.edge_index = torch.cat( 33 | [ 34 | torch.randint(self.num_nodes, size=(2, self.num_edges - self.num_nodes)), 35 | # ensure connectivity 36 | torch.arange(self.num_nodes).unsqueeze(0).repeat(2, 1), 37 | ], 38 | dim=-1, 39 | ) 40 | target_indices = self.edge_index[1].tolist() 41 | counts = Counter(target_indices) 42 | values = torch.as_tensor([1.0 / counts[i] for i in target_indices]) 43 | self.adj = torch.sparse_coo_tensor(indices=self.edge_index, values=values) 44 | 45 | def _verify_adjacency(self, adj: torch.Tensor): 46 | assert torch.is_tensor(adj) 47 | assert adj.shape == (self.num_nodes, self.num_nodes) 48 | 49 | def test_prepare_num_nodes(self): 50 | """Test inferring the number of nodes from an edge index.""" 51 | for num_nodes in (None, self.num_nodes): 52 | assert ( 53 | utils.prepare_num_nodes(edge_index=self.edge_index, num_nodes=num_nodes) 54 | == self.num_nodes 55 | ) 56 | 57 | def test_edge_index_to_sparse_matrix(self): 58 | """Test conversion of edge indices to sparse matrices.""" 59 | for num_nodes_ in (self.num_nodes, None): 60 | adj = utils.edge_index_to_sparse_matrix( 61 | edge_index=self.edge_index, 62 | num_nodes=num_nodes_, 63 | ) 64 | assert adj.shape == (self.num_nodes, self.num_nodes) 65 | 66 | def 
test_validate_adjacancy(self): 67 | """Test adjacency validation.""" 68 | adj = utils.prepare_page_rank_adjacency(edge_index=self.edge_index) 69 | # plain validation with shape inference 70 | utils.validate_adjacency(adj=adj) 71 | # plain validation with explicit shape 72 | utils.validate_adjacency(adj=adj, n=self.num_nodes) 73 | # validation with CSR matrix 74 | utils.validate_adjacency(adj=adj.to_sparse_csr()) 75 | # test error raising 76 | for adj in ( 77 | # an edge_index instead of adj 78 | self.edge_index, 79 | # wrong shape 80 | torch.sparse_coo_tensor( 81 | indices=torch.empty(2, 0, dtype=torch.long), 82 | values=torch.empty(0), 83 | size=(2, 3), 84 | ), 85 | # wrong value range 86 | torch.sparse_coo_tensor( 87 | indices=self.edge_index, 88 | values=torch.full(size=(self.num_edges,), fill_value=2.0), 89 | size=(self.num_nodes, self.num_nodes), 90 | ), 91 | # wrong sum 92 | torch.sparse_coo_tensor( 93 | indices=self.edge_index, 94 | values=torch.ones(self.num_edges), 95 | size=(self.num_nodes, self.num_nodes), 96 | ), 97 | ): 98 | with self.assertRaises(ValueError): 99 | utils.validate_adjacency(adj=adj) 100 | 101 | def test_prepare_page_rank_adjacency(self): 102 | """Test adjacency preparation.""" 103 | for (adj, edge_index, add_identity) in ( 104 | # from edge index 105 | (None, self.edge_index, False), 106 | # passing through adjacency matrix 107 | (self.adj, None, False), 108 | (self.adj, self.edge_index, False), 109 | # add identity 110 | (None, self.edge_index, True), 111 | ): 112 | adj2 = utils.prepare_page_rank_adjacency( 113 | adj=adj, edge_index=edge_index, add_identity=add_identity 114 | ) 115 | utils.validate_adjacency(adj=adj2, n=self.num_nodes) 116 | if adj is not None: 117 | assert adj is adj2 118 | 119 | def _valid_x0(self, size: Optional[Tuple[int, ...]] = None) -> torch.Tensor: 120 | """Generate a valid x0.""" 121 | size = size or (self.num_nodes,) 122 | return functional.normalize(torch.rand(size=size), p=1, dim=0) 123 | 124 | def 
test_validate_x(self): 125 | """Test page-rank vector validation.""" 126 | # valid single 127 | x0_valid = self._valid_x0() 128 | utils.validate_x(x=x0_valid, n=self.num_nodes) 129 | # valid batch 130 | x0_valid_batch = self._valid_x0(size=(self.num_nodes, 12)) 131 | utils.validate_x(x=x0_valid_batch, n=self.num_nodes) 132 | # invalid shape, wrong dim 133 | with self.assertRaises(ValueError): 134 | utils.validate_x(x=x0_valid, n=self.num_nodes + 1) 135 | # invalid shape, too many dim 136 | with self.assertRaises(ValueError): 137 | utils.validate_x(x=x0_valid_batch[..., None], n=self.num_nodes) 138 | # too large value 139 | for value in (-1.0, 2.0): 140 | with self.assertRaises(ValueError): 141 | x0_invalid = x0_valid.clone() 142 | x0_invalid[0] = value 143 | utils.validate_x(x=x0_invalid, n=self.num_nodes) 144 | 145 | def test_prepare_x0(self): 146 | """Test x0 preparation.""" 147 | for x0, indices in ( 148 | # x0 pass-through 149 | (self._valid_x0(), None), 150 | (self._valid_x0(size=(self.num_nodes, 12)), None), 151 | (self._valid_x0(), [1, 2]), 152 | # indices 153 | (None, [1, 2]), 154 | # only n 155 | (None, None), 156 | ): 157 | x0 = utils.prepare_x0(x0, indices, n=self.num_nodes) 158 | utils.validate_x(x0, n=self.num_nodes) 159 | 160 | def test_power_iteration(self): 161 | """Test power-iteration.""" 162 | adj = utils.prepare_page_rank_adjacency(edge_index=self.edge_index) 163 | for x0 in (self._valid_x0(), self._valid_x0(size=(self.num_nodes, 12))): 164 | x = utils.power_iteration(adj=adj, x0=x0, max_iter=5) 165 | utils.validate_x(x=x, n=self.num_nodes) 166 | 167 | def test_batched_personalized_page_rank(self): 168 | """Test batched PPR calculation.""" 169 | x = utils.batched_personalized_page_rank( 170 | adj=self.adj, indices=torch.arange(self.num_nodes), batch_size=self.num_nodes // 3 171 | ) 172 | utils.validate_x(x) 173 | 174 | 175 | @pytest.mark.parametrize("n", [8, 16]) 176 | def test_sparse_diagonal(n: int): 177 | """Test for sparse diagonal matrix 
creation.""" 178 | values = torch.rand(n) 179 | matrix = utils.sparse_diagonal(values=values) 180 | assert torch.is_tensor(matrix) 181 | assert matrix.shape == (n, n) 182 | assert matrix.is_sparse 183 | assert torch.allclose(matrix.to_dense(), torch.diag(values)) 184 | 185 | 186 | @pytest.mark.parametrize("seed", [21, 42, 63]) 187 | def test_sparse_normalize(seed: int): 188 | """Test for sparse matrix normalization.""" 189 | generator = torch.manual_seed(seed=seed) 190 | n_rows, n_cols = torch.randint(10, 20, size=(2,), generator=generator) 191 | matrix = torch.rand(size=(n_rows, n_cols), generator=generator) 192 | # make sparse 193 | matrix[matrix < 0.5] = 0 194 | matrix = matrix.to_sparse() 195 | # normalize 196 | for dim in (0, 1): 197 | matrix_norm = utils.sparse_normalize(matrix=matrix, dim=dim) 198 | sparse_sum = torch.sparse.sum(matrix_norm, dim=dim) 199 | assert torch.allclose(sparse_sum.values(), torch.ones_like(sparse_sum.values())) 200 | -------------------------------------------------------------------------------- /tests/test_version.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | """Trivial version test.""" 4 | 5 | import unittest 6 | 7 | from torch_ppr.version import get_version 8 | 9 | 10 | class TestVersion(unittest.TestCase): 11 | """Trivially test a version.""" 12 | 13 | def test_version_type(self): 14 | """Test the version is a string. 15 | 16 | This is only meant to be an example test. 17 | """ 18 | version = get_version(with_git_hash=True) 19 | self.assertIsInstance(version, str) 20 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | # Tox (http://tox.testrun.org/) is a tool for running tests 2 | # in multiple virtualenvs. This configuration file will run the 3 | # test suite on all supported python versions. 
To use it, "pip install tox" 4 | # and then run "tox" from this directory. 5 | 6 | [tox] 7 | # To use a PEP 517 build-backend you are required to configure tox to use an isolated_build: 8 | # https://tox.readthedocs.io/en/latest/example/package.html 9 | isolated_build = True 10 | 11 | # These environments are run in order if you just use `tox`: 12 | envlist = 13 | # always keep coverage-clean first 14 | # coverage-clean 15 | # code linters/stylers 16 | lint 17 | manifest 18 | pyroma 19 | flake8 20 | mypy 21 | # documentation linters/checkers 22 | doc8 23 | docstr-coverage 24 | docs-test 25 | # the actual tests 26 | py-torch-{1.11,1.12,1.13} 27 | # always keep coverage-report last 28 | # coverage-report 29 | 30 | [testenv] 31 | # Runs on the "tests" directory by default, or passes the positional 32 | # arguments from `tox -e py ... 33 | commands = 34 | coverage run -p -m pytest --durations=20 {posargs:tests} 35 | coverage combine 36 | coverage xml 37 | setenv = 38 | PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cpu 39 | deps = 40 | torch-1.11: torch~=1.11.0 41 | torch-1.12: torch~=1.12.0 42 | torch-1.13: torch~=1.13.0 43 | extras = 44 | # See the [options.extras_require] entry in setup.cfg for "tests" 45 | tests 46 | 47 | [testenv:coverage-clean] 48 | deps = coverage 49 | skip_install = true 50 | commands = coverage erase 51 | 52 | [testenv:lint] 53 | deps = 54 | black 55 | isort 56 | skip_install = true 57 | commands = 58 | black src/ tests/ 59 | isort src/ tests/ 60 | description = Run linters. 61 | 62 | [testenv:manifest] 63 | deps = check-manifest 64 | skip_install = true 65 | commands = check-manifest 66 | description = Check that the MANIFEST.in is written properly and give feedback on how to fix it. 
67 | 68 | [testenv:flake8] 69 | skip_install = true 70 | deps = 71 | darglint 72 | flake8 73 | flake8-black 74 | flake8-bandit 75 | flake8-bugbear 76 | flake8-colors 77 | flake8-docstrings 78 | flake8-isort 79 | flake8-print 80 | pep8-naming 81 | pydocstyle 82 | commands = 83 | flake8 src/ tests/ 84 | description = Run the flake8 tool with several plugins (bandit, docstrings, import order, pep8 naming). See https://cthoyt.com/2020/04/25/how-to-code-with-me-flake8.html for more information. 85 | 86 | [testenv:pyroma] 87 | deps = 88 | pygments 89 | pyroma 90 | skip_install = true 91 | commands = pyroma --min=10 . 92 | description = Run the pyroma tool to check the package friendliness of the project. 93 | 94 | [testenv:mypy] 95 | deps = mypy 96 | skip_install = true 97 | commands = mypy --install-types --non-interactive --ignore-missing-imports src/ 98 | description = Run the mypy tool to check static typing on the project. 99 | 100 | [testenv:doc8] 101 | skip_install = true 102 | deps = 103 | sphinx 104 | doc8 105 | commands = 106 | doc8 docs/source/ 107 | description = Run the doc8 tool to check the style of the RST files in the project docs. 108 | 109 | [testenv:docstr-coverage] 110 | skip_install = true 111 | deps = 112 | docstr-coverage 113 | commands = 114 | docstr-coverage src/ tests/ --skip-private --skip-magic 115 | description = Run the docstr-coverage tool to check documentation coverage 116 | 117 | [testenv:docs] 118 | description = Build the documentation locally. 119 | extras = 120 | # See the [options.extras_require] entry in setup.cfg for "docs" 121 | docs 122 | # You might need to add additional extras if your documentation covers it 123 | commands = 124 | python -m sphinx -W -b html -d docs/build/doctrees docs/source docs/build/html 125 | 126 | [testenv:docs-test] 127 | description = Test building the documentation in an isolated environment. 
128 | changedir = docs 129 | extras = 130 | {[testenv:docs]extras} 131 | commands = 132 | mkdir -p {envtmpdir} 133 | cp -r source {envtmpdir}/source 134 | python -m sphinx -W -b html -d {envtmpdir}/build/doctrees {envtmpdir}/source {envtmpdir}/build/html 135 | python -m sphinx -W -b coverage -d {envtmpdir}/build/doctrees {envtmpdir}/source {envtmpdir}/build/coverage 136 | cat {envtmpdir}/build/coverage/c.txt 137 | cat {envtmpdir}/build/coverage/python.txt 138 | whitelist_externals = 139 | /bin/cp 140 | /bin/cat 141 | /bin/mkdir 142 | # for compatibility on GitHub actions 143 | /usr/bin/cp 144 | /usr/bin/cat 145 | /usr/bin/mkdir 146 | 147 | [testenv:coverage-report] 148 | deps = coverage 149 | skip_install = true 150 | commands = 151 | coverage combine 152 | coverage report 153 | 154 | #################### 155 | # Deployment tools # 156 | #################### 157 | 158 | [testenv:bumpversion] 159 | commands = bumpversion {posargs} 160 | skip_install = true 161 | passenv = HOME 162 | deps = 163 | bumpversion 164 | 165 | [testenv:build] 166 | skip_install = true 167 | deps = 168 | wheel 169 | build 170 | commands = 171 | python -m build --sdist --wheel --no-isolation 172 | 173 | [testenv:release] 174 | description = Release the code to PyPI so users can pip install it 175 | skip_install = true 176 | deps = 177 | {[testenv:build]deps} 178 | twine >= 1.5.0 179 | commands = 180 | {[testenv:build]commands} 181 | twine upload --skip-existing dist/* 182 | 183 | [testenv:testrelease] 184 | description = Release the code to the test PyPI site 185 | skip_install = true 186 | deps = 187 | {[testenv:build]deps} 188 | twine >= 1.5.0 189 | commands = 190 | {[testenv:build]commands} 191 | twine upload --skip-existing --repository-url https://test.pypi.org/simple/ dist/* 192 | 193 | [testenv:finish] 194 | skip_install = true 195 | passenv = 196 | HOME 197 | TWINE_USERNAME 198 | TWINE_PASSWORD 199 | deps = 200 | {[testenv:release]deps} 201 | bump2version 202 | commands = 203 | 
bump2version release --tag 204 | {[testenv:release]commands} 205 | git push --tags 206 | bump2version patch 207 | git push 208 | whitelist_externals = 209 | /usr/bin/git 210 | --------------------------------------------------------------------------------