├── .python-version ├── docs ├── .gitignore ├── Makefile ├── make.bat ├── readme.md ├── reference │ ├── profile.rst │ ├── debug.rst │ ├── interface.rst │ ├── grove.rst │ ├── prepcovars.rst │ ├── mcmcstep.rst │ ├── mcmcloop.rst │ ├── index.rst │ └── jaxext.rst ├── guide │ ├── index.rst │ ├── installation.rst │ └── quickstart.rst ├── index.rst ├── _static │ └── custom.css ├── pkglist.md ├── development.rst └── conf.py ├── .Rprofile ├── src └── bartz │ ├── _version.py │ ├── jaxext │ ├── scipy │ │ ├── __init__.py │ │ ├── stats.py │ │ └── special.py │ ├── _autobatch.py │ └── __init__.py │ ├── __init__.py │ └── _profiler.py ├── config ├── ipython │ ├── .gitignore │ └── profile_default │ │ ├── startup │ │ └── startup.ipy │ │ └── ipython_config.py └── refs-for-asv.py ├── .asv ├── .gitignore └── results │ ├── gattocrucco-m1 │ ├── machine.json │ ├── b145bd73-virtualenv-py3.13.json │ ├── 485be21f-virtualenv-py3.13.json │ ├── c0940a3a-virtualenv-py3.13.json │ ├── d549d06f-virtualenv-py3.13.json │ └── 1119f48f-virtualenv-py3.13.json │ └── benchmarks.json ├── _site ├── .gitignore └── index.html ├── renv ├── .gitignore └── settings.json ├── .gitlint ├── .github ├── dependabot.yml └── workflows │ └── release.yml ├── .gitignore ├── LICENSE ├── benchmarks ├── __init__.py └── rmse.py ├── tests ├── __init__.py ├── rbartpackages │ ├── __init__.py │ ├── bartMachine.py │ ├── dbarts.py │ ├── BART.py │ ├── BART3.py │ └── _base.py ├── conftest.py ├── test_mcmcstep.py ├── test_mcmcloop.py ├── util.py ├── test_meta.py ├── test_prepcovars.py └── test_debug.py ├── .pre-commit-config.yaml ├── README.md ├── Makefile ├── asv.conf.json └── pyproject.toml /.python-version: -------------------------------------------------------------------------------- 1 | 3.14 2 | -------------------------------------------------------------------------------- /docs/.gitignore: -------------------------------------------------------------------------------- 1 | _build/ 2 | 
-------------------------------------------------------------------------------- /.Rprofile: -------------------------------------------------------------------------------- 1 | source("renv/activate.R") 2 | -------------------------------------------------------------------------------- /src/bartz/_version.py: -------------------------------------------------------------------------------- 1 | __version__ = '0.7.0' 2 | -------------------------------------------------------------------------------- /config/ipython/.gitignore: -------------------------------------------------------------------------------- 1 | history.sqlite 2 | db/dhist 3 | -------------------------------------------------------------------------------- /.asv/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | !results/ 4 | !results/** 5 | -------------------------------------------------------------------------------- /_site/.gitignore: -------------------------------------------------------------------------------- 1 | docs/ 2 | docs-dev/ 3 | coverage/ 4 | benchmarks/ 5 | -------------------------------------------------------------------------------- /renv/.gitignore: -------------------------------------------------------------------------------- 1 | library/ 2 | local/ 3 | cellar/ 4 | lock/ 5 | python/ 6 | sandbox/ 7 | staging/ 8 | -------------------------------------------------------------------------------- /.gitlint: -------------------------------------------------------------------------------- 1 | [general] 2 | ignore = body-is-missing 3 | 4 | [title-max-length] 5 | line-length = 50 6 | 7 | [body-max-line-length] 8 | line-length = 72 9 | -------------------------------------------------------------------------------- /.asv/results/gattocrucco-m1/machine.json: -------------------------------------------------------------------------------- 1 | { 2 | "arch": "arm64", 3 | "cpu": "Apple M1 Pro", 4 | "machine": 
"gattocrucco-m1", 5 | "num_cpu": "8", 6 | "os": "Darwin 24.3.0", 7 | "ram": "17179869184", 8 | "version": 1 9 | } -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | # Please see the documentation for all configuration options: 2 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file 3 | 4 | version: 2 5 | updates: 6 | - package-ecosystem: "github-actions" 7 | directory: "/" 8 | schedule: 9 | interval: "monthly" 10 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__ 3 | dist/ 4 | .coverage* 5 | .venv/ 6 | 7 | # jax tracer 8 | **/????_??_??_??_??_??/ALL_HOSTS.op_stats.pb 9 | **/????_??_??_??_??_??/cache_version.txt 10 | **/????_??_??_??_??_??/*.SSTABLE 11 | **/????_??_??_??_??_??/*.trace.json.gz 12 | **/????_??_??_??_??_??/*.xplane.pb 13 | **/????_??_??_??_??_??/*.hlo_proto.pb 14 | **/????_??_??_??_??_??/.cached_tools.json 15 | 16 | # python profiler 17 | *.prof 18 | -------------------------------------------------------------------------------- /renv/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "bioconductor.version": null, 3 | "external.libraries": [], 4 | "ignored.packages": [], 5 | "package.dependency.fields": [ 6 | "Imports", 7 | "Depends", 8 | "LinkingTo" 9 | ], 10 | "ppm.enabled": null, 11 | "ppm.ignored.urls": [], 12 | "r.version": null, 13 | "snapshot.type": "implicit", 14 | "use.cache": true, 15 | "vcs.ignore.cellar": true, 16 | "vcs.ignore.library": true, 17 | "vcs.ignore.local": true, 18 | "vcs.manage.ignores": true 19 | } 20 | -------------------------------------------------------------------------------- 
/config/ipython/profile_default/startup/startup.ipy: -------------------------------------------------------------------------------- 1 | with open(__file__) as f: 2 | print("Startup commands:") 3 | for i, line in enumerate(f): 4 | if i >= 7: 5 | print(f">>> {line.rstrip('\n')}") 6 | print() 7 | 8 | %matplotlib 9 | %load_ext autoreload 10 | %autoreload 2 11 | %load_ext snakeviz 12 | 13 | import appnope 14 | appnope.nope() 15 | 16 | from functools import partial 17 | 18 | import jax 19 | from jax import numpy as jnp 20 | import numpy as np 21 | import equinox as eqx 22 | import jaxtyping as jt 23 | 24 | import bartz 25 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | $(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /_site/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | bartz 5 | 6 | 7 | 8 | 9 |

bartz

10 | 11 |

12 | A JAX implementation of BART (Bayesian Additive Regression Trees). 13 |

21 |

22 | 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.http://sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024-2025 The Bartz Contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /docs/readme.md: -------------------------------------------------------------------------------- 1 | 26 | 27 | ```{include} ../README.md 28 | ``` 29 | -------------------------------------------------------------------------------- /benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | # bartz/benchmarks/__init__.py 2 | # 3 | # Copyright (c) 2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | """Benchmarking code run by asv.""" 26 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # bartz/tests/__init__.py 2 | # 3 | # Copyright (c) 2024-2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
24 | 25 | """Unit tests to be run with pytest.""" 26 | -------------------------------------------------------------------------------- /config/ipython/profile_default/ipython_config.py: -------------------------------------------------------------------------------- 1 | # bartz/config/ipython/profile_default/ipython_config.py 2 | # 3 | # Copyright (c) 2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | -------------------------------------------------------------------------------- /tests/rbartpackages/__init__.py: -------------------------------------------------------------------------------- 1 | # bartz/tests/rbartpackages/__init__.py 2 | # 3 | # Copyright (c) 2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 
6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | """Wrappers of R BART packages.""" 26 | -------------------------------------------------------------------------------- /src/bartz/jaxext/scipy/__init__.py: -------------------------------------------------------------------------------- 1 | # bartz/src/bartz/jaxext/scipy/__init__.py 2 | # 3 | # Copyright (c) 2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 
6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | """Mockup of the :external:py:mod:`scipy` module.""" 26 | -------------------------------------------------------------------------------- /docs/reference/profile.rst: -------------------------------------------------------------------------------- 1 | .. bartz/docs/reference/profile.rst 2 | .. 3 | .. Copyright (c) 2025, The Bartz Contributors 4 | .. 5 | .. This file is part of bartz. 6 | .. 7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy 8 | .. of this software and associated documentation files (the "Software"), to deal 9 | .. in the Software without restriction, including without limitation the rights 10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | .. copies of the Software, and to permit persons to whom the Software is 12 | .. 
furnished to do so, subject to the following conditions: 13 | .. 14 | .. The above copyright notice and this permission notice shall be included in all 15 | .. copies or substantial portions of the Software. 16 | .. 17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | .. SOFTWARE. 24 | 25 | Profiling 26 | --------- 27 | 28 | .. autofunction:: bartz.profile_mode 29 | -------------------------------------------------------------------------------- /docs/reference/debug.rst: -------------------------------------------------------------------------------- 1 | .. bartz/docs/reference/debug.rst 2 | .. 3 | .. Copyright (c) 2025, The Bartz Contributors 4 | .. 5 | .. This file is part of bartz. 6 | .. 7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy 8 | .. of this software and associated documentation files (the "Software"), to deal 9 | .. in the Software without restriction, including without limitation the rights 10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | .. copies of the Software, and to permit persons to whom the Software is 12 | .. furnished to do so, subject to the following conditions: 13 | .. 14 | .. The above copyright notice and this permission notice shall be included in all 15 | .. copies or substantial portions of the Software. 16 | .. 17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | .. 
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | .. SOFTWARE. 24 | 25 | Debugging 26 | --------- 27 | 28 | .. automodule:: bartz.debug 29 | :members: 30 | -------------------------------------------------------------------------------- /docs/reference/interface.rst: -------------------------------------------------------------------------------- 1 | .. bartz/docs/reference/interface.rst 2 | .. 3 | .. Copyright (c) 2024-2025, The Bartz Contributors 4 | .. 5 | .. This file is part of bartz. 6 | .. 7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy 8 | .. of this software and associated documentation files (the "Software"), to deal 9 | .. in the Software without restriction, including without limitation the rights 10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | .. copies of the Software, and to permit persons to whom the Software is 12 | .. furnished to do so, subject to the following conditions: 13 | .. 14 | .. The above copyright notice and this permission notice shall be included in all 15 | .. copies or substantial portions of the Software. 16 | .. 17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | .. SOFTWARE. 24 | 25 | Interface 26 | ========= 27 | 28 | .. 
automodule:: bartz.BART 29 | :members: 30 | -------------------------------------------------------------------------------- /docs/reference/grove.rst: -------------------------------------------------------------------------------- 1 | .. bartz/docs/reference/grove.rst 2 | .. 3 | .. Copyright (c) 2024-2025, The Bartz Contributors 4 | .. 5 | .. This file is part of bartz. 6 | .. 7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy 8 | .. of this software and associated documentation files (the "Software"), to deal 9 | .. in the Software without restriction, including without limitation the rights 10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | .. copies of the Software, and to permit persons to whom the Software is 12 | .. furnished to do so, subject to the following conditions: 13 | .. 14 | .. The above copyright notice and this permission notice shall be included in all 15 | .. copies or substantial portions of the Software. 16 | .. 17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | .. SOFTWARE. 24 | 25 | Tree manipulation 26 | ================= 27 | 28 | .. automodule:: bartz.grove 29 | :members: 30 | -------------------------------------------------------------------------------- /docs/reference/prepcovars.rst: -------------------------------------------------------------------------------- 1 | .. bartz/docs/reference/prepcovars.rst 2 | .. 3 | .. Copyright (c) 2024-2025, The Bartz Contributors 4 | .. 5 | .. This file is part of bartz. 6 | .. 7 | .. 
Permission is hereby granted, free of charge, to any person obtaining a copy 8 | .. of this software and associated documentation files (the "Software"), to deal 9 | .. in the Software without restriction, including without limitation the rights 10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | .. copies of the Software, and to permit persons to whom the Software is 12 | .. furnished to do so, subject to the following conditions: 13 | .. 14 | .. The above copyright notice and this permission notice shall be included in all 15 | .. copies or substantial portions of the Software. 16 | .. 17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | .. SOFTWARE. 24 | 25 | Data processing 26 | =============== 27 | 28 | .. automodule:: bartz.prepcovars 29 | :members: 30 | -------------------------------------------------------------------------------- /docs/reference/mcmcstep.rst: -------------------------------------------------------------------------------- 1 | .. bartz/docs/reference/mcmcstep.rst 2 | .. 3 | .. Copyright (c) 2024-2025, The Bartz Contributors 4 | .. 5 | .. This file is part of bartz. 6 | .. 7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy 8 | .. of this software and associated documentation files (the "Software"), to deal 9 | .. in the Software without restriction, including without limitation the rights 10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | .. 
copies of the Software, and to permit persons to whom the Software is 12 | .. furnished to do so, subject to the following conditions: 13 | .. 14 | .. The above copyright notice and this permission notice shall be included in all 15 | .. copies or substantial portions of the Software. 16 | .. 17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | .. SOFTWARE. 24 | 25 | MCMC setup and step 26 | =================== 27 | 28 | .. automodule:: bartz.mcmcstep 29 | :members: 30 | -------------------------------------------------------------------------------- /docs/guide/index.rst: -------------------------------------------------------------------------------- 1 | .. bartz/docs/guide/index.rst 2 | .. 3 | .. Copyright (c) 2024-2025, The Bartz Contributors 4 | .. 5 | .. This file is part of bartz. 6 | .. 7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy 8 | .. of this software and associated documentation files (the "Software"), to deal 9 | .. in the Software without restriction, including without limitation the rights 10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | .. copies of the Software, and to permit persons to whom the Software is 12 | .. furnished to do so, subject to the following conditions: 13 | .. 14 | .. The above copyright notice and this permission notice shall be included in all 15 | .. copies or substantial portions of the Software. 16 | .. 17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | .. 
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | .. SOFTWARE. 24 | 25 | Guide 26 | ===== 27 | 28 | .. toctree:: 29 | :maxdepth: 1 30 | 31 | installation.rst 32 | quickstart.rst 33 | -------------------------------------------------------------------------------- /docs/reference/mcmcloop.rst: -------------------------------------------------------------------------------- 1 | .. bartz/docs/reference/mcmcloop.rst 2 | .. 3 | .. Copyright (c) 2024-2025, The Bartz Contributors 4 | .. 5 | .. This file is part of bartz. 6 | .. 7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy 8 | .. of this software and associated documentation files (the "Software"), to deal 9 | .. in the Software without restriction, including without limitation the rights 10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | .. copies of the Software, and to permit persons to whom the Software is 12 | .. furnished to do so, subject to the following conditions: 13 | .. 14 | .. The above copyright notice and this permission notice shall be included in all 15 | .. copies or substantial portions of the Software. 16 | .. 17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | .. 
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | .. SOFTWARE. 24 | 25 | MCMC loop 26 | ========= 27 | 28 | .. automodule:: bartz.mcmcloop 29 | :members: 30 | :special-members: __call__ 31 | -------------------------------------------------------------------------------- /docs/reference/index.rst: -------------------------------------------------------------------------------- 1 | .. bartz/docs/reference/index.rst 2 | .. 3 | .. Copyright (c) 2024-2025, The Bartz Contributors 4 | .. 5 | .. This file is part of bartz. 6 | .. 7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy 8 | .. of this software and associated documentation files (the "Software"), to deal 9 | .. in the Software without restriction, including without limitation the rights 10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | .. copies of the Software, and to permit persons to whom the Software is 12 | .. furnished to do so, subject to the following conditions: 13 | .. 14 | .. The above copyright notice and this permission notice shall be included in all 15 | .. copies or substantial portions of the Software. 16 | .. 17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | .. SOFTWARE. 24 | 25 | Reference 26 | ========= 27 | 28 | .. 
toctree:: 29 | :maxdepth: 1 30 | 31 | interface.rst 32 | grove.rst 33 | mcmcstep.rst 34 | mcmcloop.rst 35 | prepcovars.rst 36 | jaxext.rst 37 | debug.rst 38 | profile.rst 39 | -------------------------------------------------------------------------------- /src/bartz/__init__.py: -------------------------------------------------------------------------------- 1 | # bartz/src/bartz/__init__.py 2 | # 3 | # Copyright (c) 2024-2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | """ 26 | Super-fast BART (Bayesian Additive Regression Trees) in Python. 
27 | 28 | See the manual at https://gattocrucco.github.io/bartz/docs 29 | """ 30 | 31 | from bartz import BART, grove, jaxext, mcmcloop, mcmcstep, prepcovars # noqa: F401 32 | from bartz._profiler import profile_mode # noqa: F401 33 | from bartz._version import __version__ # noqa: F401 34 | -------------------------------------------------------------------------------- /src/bartz/jaxext/scipy/stats.py: -------------------------------------------------------------------------------- 1 | # bartz/src/bartz/jaxext/scipy/stats.py 2 | # 3 | # Copyright (c) 2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
class invgamma:
    """Inverse-gamma distribution with shape `a` and unit scale, InvGamma(a, 1)."""

    @staticmethod
    def ppf(q, a):
        """
        Percent point function, i.e., the inverse of the CDF.

        Parameters
        ----------
        q
            Probability at which the quantile is evaluated.
        a
            Shape parameter of the distribution.

        Returns
        -------
        The value ``x`` such that ``P(X <= x) = q`` for ``X ~ InvGamma(a, 1)``.
        """
        # If G ~ Gamma(a, 1) then X = 1/G, so P(X <= x) = P(G >= 1/x), which
        # is the regularized upper incomplete gamma function evaluated at 1/x.
        # Inverting it yields the reciprocal of the quantile.
        reciprocal_quantile = gammainccinv(a, q)
        return 1 / reciprocal_quantile
toctree:: 39 | :maxdepth: 2 40 | 41 | guide/index.rst 42 | reference/index.rst 43 | 44 | .. toctree:: 45 | :maxdepth: 1 46 | 47 | changelog.md 48 | development.rst 49 | pkglist.md 50 | 51 | * :ref:`genindex` 52 | * :ref:`search` 53 | -------------------------------------------------------------------------------- /docs/reference/jaxext.rst: -------------------------------------------------------------------------------- 1 | .. bartz/docs/jaxext.rst 2 | .. 3 | .. Copyright (c) 2024-2025, The Bartz Contributors 4 | .. 5 | .. This file is part of bartz. 6 | .. 7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy 8 | .. of this software and associated documentation files (the "Software"), to deal 9 | .. in the Software without restriction, including without limitation the rights 10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | .. copies of the Software, and to permit persons to whom the Software is 12 | .. furnished to do so, subject to the following conditions: 13 | .. 14 | .. The above copyright notice and this permission notice shall be included in all 15 | .. copies or substantial portions of the Software. 16 | .. 17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | .. SOFTWARE. 24 | 25 | JAX extensions 26 | ============== 27 | 28 | bartz.jaxext 29 | ------------ 30 | 31 | .. automodule:: bartz.jaxext 32 | :members: 33 | 34 | .. autofunction:: bartz.jaxext.autobatch 35 | 36 | bartz.jaxext.scipy.special 37 | -------------------------- 38 | 39 | .. 
automodule:: bartz.jaxext.scipy.special 40 | :members: 41 | 42 | bartz.jaxext.scipy.stats 43 | ------------------------ 44 | 45 | .. automodule:: bartz.jaxext.scipy.stats 46 | :members: 47 | -------------------------------------------------------------------------------- /.github/workflows/release.yml: -------------------------------------------------------------------------------- 1 | # bartz/workflows/release.yml 2 | # 3 | # Copyright (c) 2024-2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
# Smoke test of the published package: after a release is published (or on
# manual dispatch), install bartz from PyPI and check that it imports.

name: release
permissions:
  contents: read

on:
  release:
    types: [published]
  workflow_dispatch:

jobs:
  tests:
    runs-on: ubuntu-latest
    steps:
      # setup-python reads `.python-version` from the workspace, so the
      # repository has to be checked out before that step runs; without the
      # checkout the version file does not exist and the step fails.
      - name: Checkout repository
        uses: actions/checkout@v5
      - name: Set up Python
        uses: actions/setup-python@v6
        with:
          python-version-file: ".python-version"
      - name: Update pip
        run: python -m pip install --upgrade pip
      - name: Install software
        run: python -m pip install bartz
      - name: Try to import it
        run: python -c 'import bartz;print(bartz.__version__)'
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | .. SOFTWARE. 24 | 25 | Installation 26 | ============ 27 | 28 | Install and set up Python. There are various ways to do it; my favorite one is to use `uv <https://docs.astral.sh/uv/>`_. Then: 29 | 30 | .. code-block:: sh 31 | 32 | pip install bartz 33 | 34 | To install the latest development version, do instead 35 | 36 | .. code-block:: sh 37 | 38 | pip install git+https://github.com/Gattocrucco/bartz.git 39 | 40 | To install a specific commit, do 41 | 42 | .. code-block:: sh 43 | 44 | pip install git+https://github.com/Gattocrucco/bartz.git@<commit hash> 45 | 46 | To use on GPU on a system that doesn't provide `jax` pre-installed, read how to install jax `in its manual <https://docs.jax.dev/en/latest/installation.html>`_. 47 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | 4 | exclude: '^\.asv/results/.+\.json$' 5 | default_stages: [pre-commit] 6 | default_install_hook_types: [pre-commit, commit-msg] 7 | 8 | repos: 9 | - repo: https://github.com/pre-commit/pre-commit-hooks 10 | rev: v6.0.0 11 | hooks: 12 | - id: check-added-large-files 13 | - id: check-ast 14 | - id: check-case-conflict 15 | - id: check-docstring-first 16 | - id: check-executables-have-shebangs 17 | - id: check-illegal-windows-names 18 | - id: check-merge-conflict 19 | - id: check-shebang-scripts-are-executable 20 | - id: check-symlinks 21 | - id: check-toml 22 | - id: check-yaml 23 | - id: destroyed-symlinks 24 | - id: detect-private-key 25 | - id: end-of-file-fixer 26 | - id: forbid-submodules 27 | - id: mixed-line-ending 28 | args: [--fix=lf] 29 | - id: name-tests-test 30 | args: [--pytest-test-first] 31 | exclude: tests/(rbartpackages/.+\.py|util\.py)$ 32 | - id: trailing-whitespace 33 | exclude: '^renv/activate\.R$' # because renv edits it
automatically, with trailing whitespace 34 | - repo: https://github.com/sbrunner/hooks 35 | rev: 1.6.1 36 | hooks: 37 | - id: copyright 38 | - id: copyright-required 39 | files: \.(py|rst)$ 40 | exclude: src/bartz/_version\.py$ 41 | - repo: https://github.com/astral-sh/ruff-pre-commit 42 | rev: v0.14.3 43 | hooks: 44 | - id: ruff-check 45 | args: [--fix] 46 | - id: ruff-format 47 | - repo: https://github.com/jsh9/pydoclint 48 | rev: 0.7.6 49 | hooks: 50 | - id: pydoclint 51 | - repo: https://github.com/abravalheri/validate-pyproject 52 | rev: v0.24.1 53 | hooks: 54 | - id: validate-pyproject 55 | # Optional extra validations from SchemaStore: 56 | additional_dependencies: ["validate-pyproject-schema-store[all]"] 57 | - repo: https://github.com/jorisroovers/gitlint 58 | rev: v0.19.1 59 | hooks: 60 | - id: gitlint # this is already defined on stage commit-msg 61 | -------------------------------------------------------------------------------- /tests/rbartpackages/bartMachine.py: -------------------------------------------------------------------------------- 1 | # bartz/tests/rbartpackages/bartMachine.py 2 | # 3 | # Copyright (c) 2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 
class bartMachine(RObjectBase):  # noqa: D101, because the doc is pulled from R
    # Qualified name of the R function wrapped by this class
    _rfuncname = 'bartMachine::bartMachine'

    def __init__(self, *args, num_cores=None, megabytes=5000, **kw):
        # `megabytes` becomes the `-Xmx` value of the `java.parameters` R
        # option and `num_cores`, if given, is forwarded to
        # `set_bart_machine_num_cores`; everything else goes to the base class.
        # NOTE(review): the option is presumably set before loading the
        # namespace because the JVM reads it at startup; keep this order.
        evaluate = robjects.r
        evaluate(f'options(java.parameters = "-Xmx{megabytes:d}m")')
        evaluate('loadNamespace("bartMachine")')
        if num_cores is not None:
            ncores = int(num_cores)
            evaluate(f'bartMachine::set_bart_machine_num_cores({ncores})')
        super().__init__(*args, **kw)

    # The methods below have no Python body: `rmethod` supplies the behavior.

    @rmethod
    def predict(self, *args, **kw): ...

    @rmethod
    def get_posterior(self, *args, **kw): ...

    @rmethod
    def get_sigsqs(self, *args, **kw): ...
9 | 10 | BART is a nonparametric Bayesian regression technique. Given training predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function. 11 | 12 | This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It is also good on CPU. Most other implementations of BART are for R, and run on CPU only. 13 | 14 | On CPU, bartz runs at the speed of dbarts (the fastest implementation I know of) if n > 20,000, but using 1/20 of the memory. On GPU, the speed premium depends on sample size; it is convenient over CPU only for n > 10,000. The maximum speedup is currently 200x, on an Nvidia A100 and with at least 2,000,000 observations. 15 | 16 | [This Colab notebook](https://colab.research.google.com/github/Gattocrucco/bartz/blob/main/docs/examples/basic_simdata.ipynb) runs bartz with n = 100,000 observations, p = 1000 predictors, 10,000 trees, for 1000 MCMC iterations, in 6 minutes. 17 | 18 | ## Links 19 | 20 | - [Documentation (latest release)](https://gattocrucco.github.io/bartz/docs) 21 | - [Documentation (development version)](https://gattocrucco.github.io/bartz/docs-dev) 22 | - [Repository](https://github.com/Gattocrucco/bartz) 23 | - [Code coverage](https://gattocrucco.github.io/bartz/coverage) 24 | - [Benchmarks](https://gattocrucco.github.io/bartz/benchmarks) 25 | - [List of BART packages](https://gattocrucco.github.io/bartz/docs-dev/pkglist.html) 26 | 27 | ## Citing bartz 28 | 29 | Article: Petrillo (2024), "Very fast Bayesian Additive Regression Trees on GPU", [arXiv:2410.23244](https://arxiv.org/abs/2410.23244). 30 | 31 | To cite the software directly, including the specific version, use [zenodo](https://doi.org/10.5281/zenodo.13931477). 
32 | -------------------------------------------------------------------------------- /docs/guide/quickstart.rst: -------------------------------------------------------------------------------- 1 | .. bartz/docs/quickstart.rst 2 | .. 3 | .. Copyright (c) 2024-2025, The Bartz Contributors 4 | .. 5 | .. This file is part of bartz. 6 | .. 7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy 8 | .. of this software and associated documentation files (the "Software"), to deal 9 | .. in the Software without restriction, including without limitation the rights 10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | .. copies of the Software, and to permit persons to whom the Software is 12 | .. furnished to do so, subject to the following conditions: 13 | .. 14 | .. The above copyright notice and this permission notice shall be included in all 15 | .. copies or substantial portions of the Software. 16 | .. 17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | .. SOFTWARE. 24 | 25 | Quickstart 26 | ========== 27 | 28 | Basics 29 | ------ 30 | 31 | Import and use the `bartz.BART.gbart` class: 32 | 33 | .. code-block:: python 34 | 35 | from bartz.BART import gbart 36 | bart = gbart(X, y, ...) 37 | y_pred = bart.predict(X_test) 38 | 39 | The interface hews to the R package `BART <https://cran.r-project.org/package=BART>`_, with a few differences explained in the documentation of `bartz.BART.gbart`. 40 | 41 | JAX 42 | --- 43 | 44 | `bartz` is implemented using `jax`, a Google library for machine learning.
It allows running the code on GPU or TPU and doing various other things. 45 | 46 | For basic usage, JAX is just an alternative implementation of `numpy`. The arrays returned by `~bartz.BART.gbart` are "jax arrays" instead of "numpy arrays", but there is no perceived difference in their functionality. If you pass numpy arrays to `bartz`, they will be converted automatically. You don't have to deal with `jax` in any way. 47 | 48 | For advanced usage, refer to the `jax documentation <https://docs.jax.dev>`_. 49 | 50 | Advanced 51 | -------- 52 | 53 | `bartz` exposes the various functions that implement the MCMC of BART. You can use those yourself to try to make your own variant of BART. See the rest of the documentation for reference; the main entry points are `bartz.mcmcstep.init` and `bartz.mcmcloop.run_mcmc`. Using the internals is the only way to change the device used by each step of the algorithm, which is useful for pre-processing data on CPU and moving only the state of the MCMC to GPU when the data preprocessing step does not fit in the GPU memory. 54 | -------------------------------------------------------------------------------- /docs/_static/custom.css: -------------------------------------------------------------------------------- 1 | /* bartz/docs/_static/custom.css 2 | * 3 | * Copyright (c) 2024-2025, The Bartz Contributors 4 | * 5 | * This file is part of bartz. 6 | * 7 | * Permission is hereby granted, free of charge, to any person obtaining a copy 8 | * of this software and associated documentation files (the "Software"), to deal 9 | * in the Software without restriction, including without limitation the rights 10 | * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | * copies of the Software, and to permit persons to whom the Software is 12 | * furnished to do so, subject to the following conditions: 13 | * 14 | * The above copyright notice and this permission notice shall be included in all 15 | * copies or substantial portions of the Software.
16 | * 17 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | * SOFTWARE. 24 | */ 25 | 26 | dl.py.method, dl.py.function { 27 | margin-top: 2em; 28 | margin-bottom: 2em; 29 | } 30 | 31 | dl.py.property { 32 | margin-top: 1em; 33 | } 34 | 35 | dl.py.class, dl.py.function { 36 | margin-top: 2.5em; 37 | margin-bottom: 2.5em; 38 | } 39 | 40 | h2 + dl.py.class, h2 + dl.py.function { 41 | margin-top: 1em; 42 | } 43 | 44 | /* space between parameter/attributes lists */ 45 | dl.field-list > dd { 46 | margin-bottom: 1em; 47 | } 48 | 49 | /* no additional space after last list in the block */ 50 | dl.field-list > dd:last-child { 51 | margin-bottom: 0em; 52 | } 53 | 54 | /* space between paragraphs in multi-paragraph param descriptions */ 55 | ul.simple > li > p:not(:first-child) { 56 | margin-top: 0.5em; 57 | } 58 | 59 | /* no space between param name and first paragraph in multi-paragraph param descriptions */ 60 | ul.simple > li > p:nth-child(2) { 61 | margin-top: 0em; 62 | } 63 | 64 | /* space between parameters */ 65 | ul.simple > li:not(:last-child) { 66 | margin-bottom: 0.5em; 67 | } 68 | 69 | /* highlight types that originate from type hints rather than the docstring */ 70 | span.sphinx_autodoc_typehints-type > code { 71 | background-color: #ee9; 72 | } 73 | 74 | /* viewcode extension */ 75 | 76 | span.linenos { 77 | padding-right: 2ex; 78 | } 79 | 80 | div.viewcode-block:target { 81 | margin: 0; 82 | padding-bottom: 2em; 83 | } 84 | 85 | :not(.notranslate) > div.highlight pre { 86 | padding-left: 1ex; 87 | padding-right: 1ex; 88 | 
font-size: 0.73em; 89 | background: #f4f4f4; 90 | margin-bottom: 2em; 91 | } 92 | -------------------------------------------------------------------------------- /config/refs-for-asv.py: -------------------------------------------------------------------------------- 1 | # bartz/config/refs-for-asv.py 2 | # 3 | # Copyright (c) 2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | """ 26 | Print a list of git refs for ASV benchmarking. 27 | 28 | This script outputs: 29 | 1. All tags on the default branch with commit dates after CUTOFF_DATE 30 | 2. 
def main():
    """Print the git refs that ASV should benchmark, one per line.

    The output lists every tag that is reachable from the default branch and
    whose commit date is not earlier than `CUTOFF_DATE`, ordered by commit
    date, followed by the name of the default branch.
    """
    repo = Repo('.')

    # Ask git which branch the remote advertises as its HEAD. The short form
    # of the symbolic ref is 'origin/<branch>'; strip only the remote prefix,
    # because the branch name may itself contain slashes (e.g. 'release/1.x').
    remote_head = repo.git.symbolic_ref('refs/remotes/origin/HEAD', short=True)
    default_branch_name = remote_head.removeprefix('origin/')

    # Tip commit of the ref named after the default branch, resolved once
    # instead of at every iteration of the loop below
    default_branch_tip = repo.refs[default_branch_name].commit

    # Collect (commit date, tag name) pairs so they can be sorted by date
    tags_to_include = []
    for tag in repo.tags:
        commit = tag.commit

        # Skip tags that are not in the history of the default branch
        if not repo.is_ancestor(commit, default_branch_tip):
            continue

        commit_date = datetime.datetime.fromtimestamp(
            commit.committed_date, tz=datetime.timezone.utc
        )
        if commit_date >= CUTOFF_DATE:
            tags_to_include.append((commit_date, tag.name))

    # Oldest first; ties on the date are broken by tag name
    tags_to_include.sort()

    for _, tag_name in tags_to_include:
        print(tag_name)

    # The default branch goes last
    print(default_branch_name)
"python": "3.13"}, "python": "3.13", "requirements": {}, "env_vars": {}, "result_columns": ["result", "params", "version", "started_at", "duration", "stats_ci_99_a", "stats_ci_99_b", "stats_q_25", "stats_q_75", "stats_number", "stats_repeat", "samples", "profile"], "results": {"speed.TimeGbart.time_gbart": [[0.8069052919163369, NaN, NaN, NaN, 0.03726235398789868, NaN, NaN, NaN, 2.4787650835351087, NaN, NaN, NaN, 0.13813462504185736, NaN, NaN, NaN], [["0", "10"], ["'cold'", "'warm'"], ["1", "2", "8", "32"]], "127118a7023d5587f25f10a56bbccd86f8200efd5a4c22fe60fa246df36599a4", 1761995899604, 57.82, [0.7853, null, null, null, 0.035602, null, null, null, 2.4398, null, null, null, 0.1356, null, null, null], [0.87869, null, null, null, 0.038062, null, null, null, 2.5118, null, null, null, 0.14207, null, null, null], [0.79807, null, null, null, 0.037027, null, null, null, 2.464, null, null, null, 0.13772, null, null, null], [0.81622, null, null, null, 0.037874, null, null, null, 2.4884, null, null, null, 0.14041, null, null, null], [1, null, null, null, 1, null, null, null, 1, null, null, null, 1, null, null, null], [10, null, null, null, 10, null, null, null, 8, null, null, null, 10, null, null, null]], "speed.TimeRunMcmc.time_run_mcmc": [[NaN, NaN, 1.604135291534476, 1.4034559369902126, 0.01205670804483816, 0.0002243124763481319, 1.6102928330074064, 0.02777252095984295], [["'compile'", "'run'"], ["0", "10"], ["'cold'", "'warm'"]], "38ac9d8c1419fab7715c0dd625229b7a900c5dece4f6bc91a8ab8074e9c9281c", 1761995928502, 85.303, [null, null, 1.5351, 1.344, 0.010912, 0.00015521, 1.547, 0.027445], [null, null, 1.8783, 1.5496, 0.012672, 0.00027988, 1.7484, 0.028061], [null, null, 1.5735, 1.3705, 0.011628, 0.00017985, 1.5839, 0.027565], [null, null, 1.678, 1.4848, 0.01225, 0.00026531, 1.649, 0.027994], [null, null, 1, 1, 1, 1, 1, 1], [null, null, 10, 8, 10, 10, 10, 10]], "speed.TimeStep.time_step": [[0.946801999467425, NaN, NaN, NaN, 1.0379889375763014, 1.1937205209978856, 
0.0029921979876235127, NaN, NaN, NaN, 0.003127937510726042, 0.006610083481064066], [["'compile'", "'run'"], ["'plain'", "'binary'", "'weights'", "'sparse'", "'vmap-1'", "'vmap-2'"]], "fed013a2153c474f5396625e7a319060a4a7d94e98d92dbd9b48f275c7a082e7", 1761995971144, 126.78, [0.87452, null, null, null, 0.98037, 1.1273, 0.0029009, null, null, null, 0.0030316, 0.0064166], [0.96662, null, null, null, 1.2071, 1.3243, 0.0034885, null, null, null, 0.003271, 0.0067285], [0.9111, null, null, null, 1.0082, 1.1611, 0.0029625, null, null, null, 0.0030783, 0.0065423], [0.95784, null, null, null, 1.0891, 1.2815, 0.0030141, null, null, null, 0.0031771, 0.0066871], [1, null, null, null, 1, 1, 4, null, null, null, 4, 2], [10, null, null, null, 10, 8, 10, null, null, null, 10, 10]], "rmse.EvalGbart.track_rmse": [[0.5344812870025635], [], "afd40ad3255f218a76e6833332dc91afa0d19ac0f6daf1b7b9c75664c4586d28", 1761995893086, 6.5178]}, "durations": {}, "version": 2} -------------------------------------------------------------------------------- /.asv/results/gattocrucco-m1/485be21f-virtualenv-py3.13.json: -------------------------------------------------------------------------------- 1 | {"commit_hash": "485be21f4d702417930e51a0e4a169b287fedb7a", "env_name": "virtualenv-py3.13", "date": 1747384424000, "params": {"arch": "arm64", "cpu": "Apple M1 Pro", "machine": "gattocrucco-m1", "num_cpu": "8", "os": "Darwin 24.3.0", "ram": "17179869184", "python": "3.13"}, "python": "3.13", "requirements": {}, "env_vars": {}, "result_columns": ["result", "params", "version", "started_at", "duration", "stats_ci_99_a", "stats_ci_99_b", "stats_q_25", "stats_q_75", "stats_number", "stats_repeat", "samples", "profile"], "results": {"speed.TimeGbart.time_gbart": [[0.9466408959706314, NaN, NaN, NaN, 0.0372482294915244, NaN, NaN, NaN, 2.0026256454875693, NaN, NaN, NaN, 0.1389597289962694, NaN, NaN, NaN], [["0", "10"], ["'cold'", "'warm'"], ["1", "2", "8", "32"]], 
"127118a7023d5587f25f10a56bbccd86f8200efd5a4c22fe60fa246df36599a4", 1761996203795, 68.798, [0.90027, null, null, null, 0.036742, null, null, null, 1.9701, null, null, null, 0.13684, null, null, null], [1.0987, null, null, null, 0.042018, null, null, null, 2.0889, null, null, null, 0.14329, null, null, null], [0.91486, null, null, null, 0.036953, null, null, null, 1.9925, null, null, null, 0.13821, null, null, null], [0.97631, null, null, null, 0.037703, null, null, null, 2.0392, null, null, null, 0.14018, null, null, null], [1, null, null, null, 1, null, null, null, 1, null, null, null, 1, null, null, null], [10, null, null, null, 10, null, null, null, 10, null, null, null, 10, null, null, null]], "speed.TimeRunMcmc.time_run_mcmc": [[NaN, NaN, 1.150307229545433, 0.9872167290304787, 0.13351675000740215, 0.00015162501949816942, 1.1659881874802522, 0.027849478588905185], [["'compile'", "'run'"], ["0", "10"], ["'cold'", "'warm'"]], "38ac9d8c1419fab7715c0dd625229b7a900c5dece4f6bc91a8ab8074e9c9281c", 1761996233129, 69.481, [null, null, 1.1156, 0.9476, 0.12715, 0.0001245, 1.1263, 0.027543], [null, null, 1.354, 1.0314, 0.1361, 0.00024, 1.2283, 0.028174], [null, null, 1.1261, 0.98495, 0.1327, 0.00013392, 1.1431, 0.02769], [null, null, 1.1632, 1.0095, 0.13516, 0.00017698, 1.192, 0.028044], [null, null, 1, 1, 1, 1, 1, 1], [null, null, 10, 10, 10, 10, 10, 10]], "speed.TimeStep.time_step": [[0.889989978983067, NaN, 0.9213577499613166, NaN, 0.9991666045389138, 1.2441695625311695, 0.0030455156229436398, NaN, 0.0050226875076380875, NaN, 0.0031668073788750917, 0.0066105417790822685], [["'compile'", "'run'"], ["'plain'", "'binary'", "'weights'", "'sparse'", "'vmap-1'", "'vmap-2'"]], "fed013a2153c474f5396625e7a319060a4a7d94e98d92dbd9b48f275c7a082e7", 1761996267849, 166.61, [0.88136, null, 0.89148, null, 0.95987, 1.1576, 0.0028791, null, 0.0047176, null, 0.0030783, 0.0064644], [0.94919, null, 0.98912, null, 1.0999, 1.4406, 0.0033612, null, 0.01948, null, 0.0038439, 0.0068285], 
[0.88286, null, 0.9116, null, 0.98365, 1.1791, 0.0029699, null, 0.0047705, null, 0.0031384, 0.0065015], [0.91493, null, 0.95307, null, 1.0086, 1.3196, 0.0031029, null, 0.0053578, null, 0.0033833, 0.0067922], [1, null, 1, null, 1, 1, 4, null, 3, null, 4, 2], [10, null, 10, null, 10, 8, 10, null, 10, null, 10, 10]], "rmse.EvalGbart.track_rmse": [[0.5344812870025635], [], "afd40ad3255f218a76e6833332dc91afa0d19ac0f6daf1b7b9c75664c4586d28", 1761996197843, 5.9512]}, "durations": {"": 9.19114875793457}, "version": 2} -------------------------------------------------------------------------------- /.asv/results/gattocrucco-m1/c0940a3a-virtualenv-py3.13.json: -------------------------------------------------------------------------------- 1 | {"commit_hash": "c0940a3a05ac16fc8412ac99180dff2501950dc2", "env_name": "virtualenv-py3.13", "date": 1748555083000, "params": {"arch": "arm64", "cpu": "Apple M1 Pro", "machine": "gattocrucco-m1", "num_cpu": "8", "os": "Darwin 24.3.0", "ram": "17179869184", "python": "3.13"}, "python": "3.13", "requirements": {}, "env_vars": {}, "result_columns": ["result", "params", "version", "started_at", "duration", "stats_ci_99_a", "stats_ci_99_b", "stats_q_25", "stats_q_75", "stats_number", "stats_repeat", "samples", "profile"], "results": {"speed.TimeGbart.time_gbart": [[0.7986744999652728, NaN, NaN, NaN, 0.03784627094864845, NaN, NaN, NaN, 3.5406936869840138, NaN, NaN, NaN, 0.9474997499492019, NaN, NaN, NaN], [["0", "10"], ["'cold'", "'warm'"], ["1", "2", "8", "32"]], "127118a7023d5587f25f10a56bbccd86f8200efd5a4c22fe60fa246df36599a4", 1761996542831, 86.547, [0.75583, null, null, null, 0.035992, null, null, null, 3.1683, null, null, null, 0.93938, null, null, null], [0.99608, null, null, null, 0.038397, null, null, null, 3.8884, null, null, null, 0.9503, null, null, null], [0.78607, null, null, null, 0.037178, null, null, null, 3.3809, null, null, null, 0.94124, null, null, null], [0.8184, null, null, null, 0.038229, null, null, null, 3.6931, 
null, null, null, 0.94927, null, null, null], [1, null, null, null, 1, null, null, null, 1, null, null, null, 1, null, null, null], [10, null, null, null, 10, null, null, null, 6, null, null, null, 10, null, null, null]], "speed.TimeRunMcmc.time_run_mcmc": [[NaN, NaN, 1.160155457968358, 1.0007016459712759, NaN, NaN, 1.3899157289997675, 0.029487416497431695], [["'compile'", "'run'"], ["0", "10"], ["'cold'", "'warm'"]], "38ac9d8c1419fab7715c0dd625229b7a900c5dece4f6bc91a8ab8074e9c9281c", 1761996580465, 61.236, [null, null, 1.1115, 0.94854, null, null, 1.2218, 0.028881], [null, null, 1.2112, 1.1713, null, null, 1.72, 0.030835], [null, null, 1.152, 0.96688, null, null, 1.2693, 0.029031], [null, null, 1.1734, 1.0377, null, null, 1.4687, 0.029823], [null, null, 1, 1, null, null, 1, 1], [null, null, 10, 10, null, null, 10, 10]], "speed.TimeStep.time_step": [[0.8653726250049658, 0.7353072710102424, 0.8938072085147724, NaN, 0.9353522704914212, 1.1077213124954142, 0.002947359360405244, 0.003221333507099189, 0.004795805492904037, NaN, 0.003138265630695969, 0.0064666144899092615], [["'compile'", "'run'"], ["'plain'", "'binary'", "'weights'", "'sparse'", "'vmap-1'", "'vmap-2'"]], "fed013a2153c474f5396625e7a319060a4a7d94e98d92dbd9b48f275c7a082e7", 1761996612279, 183.0, [0.81426, 0.69315, 0.84356, null, 0.87336, 1.0136, 0.0029027, 0.0030542, 0.0047239, null, 0.0030811, 0.0063696], [0.96683, 0.79986, 0.96876, null, 1.1369, 1.1911, 0.0030285, 0.0033125, 0.0049198, null, 0.003279, 0.0066911], [0.83593, 0.69796, 0.85921, null, 0.90863, 1.0732, 0.0029294, 0.0031727, 0.0047741, null, 0.003114, 0.0063988], [0.88469, 0.75935, 0.92098, null, 1.0198, 1.1492, 0.0029909, 0.0032789, 0.0048458, null, 0.0031694, 0.0065193], [1, 1, 1, null, 1, 1, 4, 4, 3, null, 4, 2], [10, 10, 10, null, 10, 10, 10, 10, 10, null, 10, 10]], "rmse.EvalGbart.track_rmse": [[0.525512158870697], [], "afd40ad3255f218a76e6833332dc91afa0d19ac0f6daf1b7b9c75664c4586d28", 1761996536239, 6.5916]}, "durations": {"": 
10.419498920440674}, "version": 2} -------------------------------------------------------------------------------- /.asv/results/gattocrucco-m1/d549d06f-virtualenv-py3.13.json: -------------------------------------------------------------------------------- 1 | {"commit_hash": "d549d06f29cfea142857d4682c1689a0a9b3806e", "env_name": "virtualenv-py3.13", "date": 1751913844000, "params": {"arch": "arm64", "cpu": "Apple M1 Pro", "machine": "gattocrucco-m1", "num_cpu": "8", "os": "Darwin 24.3.0", "ram": "17179869184", "python": "3.13"}, "python": "3.13", "requirements": {}, "env_vars": {}, "result_columns": ["result", "params", "version", "started_at", "duration", "stats_ci_99_a", "stats_ci_99_b", "stats_q_25", "stats_q_75", "stats_number", "stats_repeat", "samples", "profile"], "results": {"speed.TimeGbart.time_gbart": [[0.8703406045096926, NaN, NaN, NaN, 0.03905785398092121, NaN, NaN, NaN, 2.04708585399203, NaN, NaN, NaN, 1.0372965830028988, NaN, NaN, NaN], [["0", "10"], ["'cold'", "'warm'"], ["1", "2", "8", "32"]], "127118a7023d5587f25f10a56bbccd86f8200efd5a4c22fe60fa246df36599a4", 1761996916550, 87.455, [0.82967, null, null, null, 0.037165, null, null, null, 2.0258, null, null, null, 1.0274, null, null, null], [0.92584, null, null, null, 0.040358, null, null, null, 2.089, null, null, null, 1.055, null, null, null], [0.84177, null, null, null, 0.03859, null, null, null, 2.0352, null, null, null, 1.0339, null, null, null], [0.87421, null, null, null, 0.039264, null, null, null, 2.0633, null, null, null, 1.0433, null, null, null], [1, null, null, null, 1, null, null, null, 1, null, null, null, 1, null, null, null], [10, null, null, null, 10, null, null, null, 10, null, null, null, 10, null, null, null]], "speed.TimeRunMcmc.time_run_mcmc": [[NaN, NaN, 1.119391416956205, 0.9509509579511359, 0.14862164546502754, 0.0016324580064974725, 1.1865209790412337, 0.029436582990456372], [["'compile'", "'run'"], ["0", "10"], ["'cold'", "'warm'"]], 
"38ac9d8c1419fab7715c0dd625229b7a900c5dece4f6bc91a8ab8074e9c9281c", 1761996955172, 67.026, [null, null, 1.0468, 0.91029, 0.1451, 0.0014815, 1.1656, 0.028909], [null, null, 1.1435, 1.0121, 0.19122, 0.0017985, 1.3533, 0.030238], [null, null, 1.075, 0.92686, 0.14663, 0.0015223, 1.1743, 0.029287], [null, null, 1.1319, 0.9702, 0.15451, 0.0016916, 1.2108, 0.03003], [null, null, 1, 1, 1, 1, 1, 1], [null, null, 10, 10, 10, 10, 10, 10]], "speed.TimeStep.time_step": [[0.8716923129977658, 0.7712800209410489, 0.8728721039951779, 1.2229048334993422, 0.9160569999949075, 1.0702567920088768, 0.003013729117810726, 0.003291833374532871, 0.004813388998930653, 0.003396270753000863, 0.003162864624755457, 0.006512770749395713], [["'compile'", "'run'"], ["'plain'", "'binary'", "'weights'", "'sparse'", "'vmap-1'", "'vmap-2'"]], "fed013a2153c474f5396625e7a319060a4a7d94e98d92dbd9b48f275c7a082e7", 1761996988663, 225.56, [0.82677, 0.73139, 0.82608, 1.1591, 0.89152, 0.99325, 0.0029709, 0.0031492, 0.0047708, 0.0033145, 0.0031189, 0.0063524], [0.94984, 0.82489, 0.91989, 1.3356, 1.0005, 1.1078, 0.0030658, 0.0033573, 0.0049056, 0.0034595, 0.0032906, 0.0067557], [0.8507, 0.75394, 0.84239, 1.1941, 0.90712, 1.0246, 0.0029931, 0.0032374, 0.0047891, 0.0033697, 0.0031508, 0.0064062], [0.90333, 0.80636, 0.91378, 1.2453, 0.92358, 1.0892, 0.0030351, 0.003323, 0.0048326, 0.0034266, 0.0031988, 0.0066069], [1, 1, 1, 1, 1, 1, 4, 4, 3, 4, 4, 2], [10, 10, 10, 8, 10, 10, 10, 10, 10, 10, 10, 10]], "rmse.EvalGbart.track_rmse": [[0.5359137058258057], [], "afd40ad3255f218a76e6833332dc91afa0d19ac0f6daf1b7b9c75664c4586d28", 1761996907185, 9.3649]}, "durations": {"": 8.990105867385864}, "version": 2} -------------------------------------------------------------------------------- /.asv/results/gattocrucco-m1/1119f48f-virtualenv-py3.13.json: -------------------------------------------------------------------------------- 1 | {"commit_hash": "1119f48ff708d49d1a31965ffa0fe540a23cab1d", "env_name": "virtualenv-py3.13", 
"date": 1761993911000, "params": {"arch": "arm64", "cpu": "Apple M1 Pro", "machine": "gattocrucco-m1", "num_cpu": "8", "os": "Darwin 24.3.0", "ram": "17179869184", "python": "3.13"}, "python": "3.13", "requirements": {}, "env_vars": {}, "result_columns": ["result", "params", "version", "started_at", "duration", "stats_ci_99_a", "stats_ci_99_b", "stats_q_25", "stats_q_75", "stats_number", "stats_repeat", "samples", "profile"], "results": {"speed.TimeGbart.time_gbart": [[0.8992934999987483, NaN, NaN, NaN, 0.04088618699461222, NaN, NaN, NaN, 2.178266270959284, NaN, NaN, NaN, 0.1406069994554855, 0.2530354790505953, 0.8698047290090472, 3.028478687047027], [["0", "10"], ["'cold'", "'warm'"], ["1", "2", "8", "32"]], "127118a7023d5587f25f10a56bbccd86f8200efd5a4c22fe60fa246df36599a4", 1761997347469, 150.87, [0.86069, null, null, null, 0.038242, null, null, null, 2.0901, null, null, null, 0.13891, 0.25126, 0.85568, 2.9536], [0.93961, null, null, null, 0.043125, null, null, null, 2.2448, null, null, null, 0.14589, 0.29274, 0.90863, 3.2281], [0.86776, null, null, null, 0.040236, null, null, null, 2.131, null, null, null, 0.13976, 0.25276, 0.86375, 3.0193], [0.93173, null, null, null, 0.041769, null, null, null, 2.2277, null, null, null, 0.14129, 0.25456, 0.88325, 3.107], [1, null, null, null, 1, null, null, null, 1, null, null, null, 1, 1, 1, 1], [10, null, null, null, 10, null, null, null, 10, null, null, null, 10, 10, 10, 6]], "speed.TimeRunMcmc.time_run_mcmc": [[NaN, NaN, 1.1052386040100828, 0.9534334164927714, 0.1504268335411325, 0.0014830209547653794, 1.2058203124906868, 0.029219582967925817], [["'compile'", "'run'"], ["0", "10"], ["'cold'", "'warm'"]], "38ac9d8c1419fab7715c0dd625229b7a900c5dece4f6bc91a8ab8074e9c9281c", 1761997418190, 67.958, [null, null, 1.0629, 0.90019, 0.1416, 0.0014253, 1.1318, 0.028844], [null, null, 1.1791, 1.0855, 0.15775, 0.0015505, 1.3001, 0.038026], [null, null, 1.087, 0.93143, 0.14874, 0.0014329, 1.1858, 0.029009], [null, null, 1.1291, 1.0156, 
0.15259, 0.0015008, 1.2335, 0.02945], [null, null, 1, 1, 1, 1, 1, 1], [null, null, 10, 10, 10, 10, 10, 10]], "speed.TimeStep.time_step": [[0.8965204374981113, 0.7574879370513372, 0.8828994379728101, 1.1899915629765019, 0.9233364164829254, 1.1276098960079253, 0.0030094896210357547, 0.003220703118131496, 0.004806166669974724, 0.0033490677597001195, 0.003208968642866239, 0.006511489511467516], [["'compile'", "'run'"], ["'plain'", "'binary'", "'weights'", "'sparse'", "'vmap-1'", "'vmap-2'"]], "fed013a2153c474f5396625e7a319060a4a7d94e98d92dbd9b48f275c7a082e7", 1761997452786, 227.79, [0.86485, 0.7257, 0.8345, 1.1405, 0.89542, 1.0732, 0.0029289, 0.0031292, 0.0047317, 0.0032789, 0.0030704, 0.0064343], [0.95944, 0.80832, 0.92182, 1.277, 0.99602, 1.2108, 0.0030537, 0.0034741, 0.0048411, 0.0034172, 0.0032448, 0.0067581], [0.87863, 0.7509, 0.85363, 1.1441, 0.90718, 1.0893, 0.0030013, 0.0031874, 0.0047664, 0.0033363, 0.0031489, 0.0064767], [0.92631, 0.77443, 0.8909, 1.2212, 0.94361, 1.167, 0.0030189, 0.00327, 0.0048233, 0.0033708, 0.0032221, 0.0065548], [1, 1, 1, 1, 1, 1, 4, 4, 3, 4, 4, 2], [10, 10, 10, 8, 10, 10, 10, 10, 10, 10, 10, 10]], "rmse.EvalGbart.track_rmse": [[0.5359137058258057], [], "afd40ad3255f218a76e6833332dc91afa0d19ac0f6daf1b7b9c75664c4586d28", 1761997338166, 9.303]}, "durations": {"": 8.945157051086426}, "version": 2} -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # bartz/tests/conftest.py 2 | # 3 | # Copyright (c) 2024-2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 
6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | """Pytest configuration.""" 26 | 27 | from contextlib import nullcontext 28 | from re import fullmatch 29 | 30 | import jax 31 | import numpy as np 32 | import pytest 33 | 34 | from bartz.jaxext import get_default_device, split 35 | 36 | jax.config.update('jax_debug_key_reuse', True) 37 | jax.config.update('jax_debug_nans', True) 38 | jax.config.update('jax_debug_infs', True) 39 | jax.config.update('jax_legacy_prng_key', 'error') 40 | 41 | 42 | @pytest.fixture 43 | def keys(request) -> split: 44 | """ 45 | Return a deterministic per-test-case list of jax random keys. 46 | 47 | To use a key, do `keys.pop()`. If consumed this way, this list of keys can 48 | be safely used by multiple fixtures involved in the test case. 
49 | """ 50 | nodeid = request.node.nodeid 51 | # exclude xdist_group suffixes because they are active only under xdist 52 | match = fullmatch(r'(.+?\.py::.+?(\[.+?\])?)(@.+)?', nodeid) 53 | nodeid = match.group(1) 54 | seed = np.array([nodeid], np.bytes_).view(np.uint8) 55 | rng = np.random.default_rng(seed) 56 | seed = np.array(rng.bytes(4)).view(np.uint32) 57 | key = jax.random.key(seed) 58 | return split(key, 128) 59 | 60 | 61 | def pytest_addoption(parser: pytest.Parser) -> None: 62 | """Add custom command line options.""" 63 | parser.addoption( 64 | '--platform', 65 | choices=['cpu', 'gpu', 'auto'], 66 | default='auto', 67 | help='JAX platform to use: cpu, gpu, or auto (default: auto)', 68 | ) 69 | 70 | 71 | def pytest_sessionstart(session: pytest.Session) -> None: 72 | """Configure and print the jax device.""" 73 | # Get the platform option 74 | platform = session.config.getoption('--platform') 75 | 76 | # Set the default JAX device if not auto 77 | if platform != 'auto': 78 | current_platform = get_default_device().platform 79 | if current_platform != platform: 80 | jax.config.update('jax_default_device', jax.devices(platform)[0]) 81 | assert get_default_device().platform == platform 82 | 83 | # Get the capture manager plugin 84 | capman = session.config.pluginmanager.get_plugin('capturemanager') 85 | 86 | # Suspend capturing temporarily 87 | if capman: 88 | ctx = capman.global_and_fixture_disabled() 89 | else: 90 | ctx = nullcontext() 91 | 92 | with ctx: 93 | device_kind = get_default_device().device_kind 94 | print(f'jax default device: {device_kind}') 95 | -------------------------------------------------------------------------------- /docs/pkglist.md: -------------------------------------------------------------------------------- 1 | 26 | 27 | # Other BART packages 28 | 29 | - [stochtree](https://github.com/StochasticTree/stochtree) C++ library with R and Python interfaces 30 | - [bnptools](https://github.com/rsparapa/bnptools) Feature-rich R packages 
for BART and some variants 31 | - [vdorie](https://github.com/vdorie)'s repositories (dbarts, bartCause, stan4bart) 32 | - [bartMachine](https://github.com/kapelner/bartMachine) R package, supports missing predictors imputation 33 | - [SoftBART](https://github.com/theodds/SoftBART) R package with a smooth version of BART 34 | - [jaredsmurray](https://github.com/jaredsmurray)'s repositories (bcf, monbart) 35 | - [skdeshpande91](https://github.com/skdeshpande91)'s repositories (flexBART, flexBCF, VCBART) 36 | - [JingyuHe](https://github.com/JingyuHe)'s repositories (XBART & related) 37 | - [BayesTree](https://cran.r-project.org/package=BayesTree) R package, original BART implementation 38 | - [OpenBT](https://bitbucket.org/mpratola/openbt) Heteroskedastic BART, rotate & perturb proposals, C++ library with R interface 39 | - [OpenBT](https://github.com/jcyannotty/OpenBT) fork of the above 40 | - [lsqfitgp](https://github.com/Gattocrucco/lsqfitgp) Infinite trees limit of BART 41 | - [mBART](https://github.com/remcc/mBART_shlib) 42 | - [SequentialBART](https://github.com/mjdaniels/SequentialBART) 43 | - [sparseBART](https://github.com/cspanbauer/sparseBART) 44 | - [pymc-bart](https://github.com/pymc-devs/pymc-bart) BART within PyMC 45 | - [semibart](https://github.com/zeldow/semibart) 46 | - [ebprado](https://github.com/ebprado)'s repositories 47 | - [EoghanONeill](https://github.com/EoghanONeill)'s repositories 48 | - [MateusMaiaDS](https://github.com/MateusMaiaDS)'s repositories (subart, gpbart) 49 | - [nchenderson](https://github.com/nchenderson)'s repositories 50 | - [bartpy](https://github.com/JakeColtman/bartpy) 51 | - [BayesTreePrior](https://github.com/AlexiaJM/BayesTreePrior) Sample the prior of BART 52 | - [BayesTree.jl](https://github.com/mathcg/BayesTree.jl) 53 | - [longbet](https://github.com/google/longbet) 54 | - [richael008](https://github.com/richael008)'s repositories 55 | - [bartMan](https://github.com/AlanInglis/bartMan) R package, posterior analysis 
and diagnostics 56 | - [drbart](https://github.com/vittorioorlandi/drbart) Density regression with BART 57 | - [BART-BMA](https://github.com/BelindaHernandez/BART-BMA) (superseded by [bartBMAnew](https://github.com/EoghanONeill/bartBMAnew)) 58 | - [XBCF](https://github.com/socket778/XBCF) (superseded by StochasticTree) 59 | - [mpibart](https://matthewpratola.com/mpibart) Old parallel implementation 60 | - [HE-BART](https://github.com/brunaw/HE-BART) 61 | - [pgbart](https://github.com/balajiln/pgbart) Particle Gibbs BART 62 | - [skewBART](https://github.com/Seungha-Um/skewBART) multivariate skewed responses 63 | -------------------------------------------------------------------------------- /tests/test_mcmcstep.py: -------------------------------------------------------------------------------- 1 | # bartz/tests/test_mcmcstep.py 2 | # 3 | # Copyright (c) 2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | """Test `bartz.mcmcstep`.""" 26 | 27 | from jax import numpy as jnp 28 | from jax import vmap 29 | from jax.random import bernoulli, clone, permutation, randint 30 | from jaxtyping import Array, Bool, Int32, Key 31 | from scipy import stats 32 | 33 | from bartz.jaxext import split 34 | from bartz.mcmcstep import randint_masked 35 | 36 | 37 | def vmap_randint_masked( 38 | key: Key[Array, ''], mask: Bool[Array, ' n'], size: int 39 | ) -> Int32[Array, '* n']: 40 | """Vectorized version of `randint_masked`.""" 41 | vrm = vmap(randint_masked, in_axes=(0, None)) 42 | keys = split(key, 1) 43 | return vrm(keys.pop(size), mask) 44 | 45 | 46 | class TestRandintMasked: 47 | """Test `mcmcstep.randint_masked`.""" 48 | 49 | def test_all_false(self, keys): 50 | """Check what happens when no value is allowed.""" 51 | for size in range(1, 10): 52 | u = randint_masked(keys.pop(), jnp.zeros(size, bool)) 53 | assert u == size 54 | 55 | def test_all_true(self, keys): 56 | """Check it's equivalent to `randint` when all values are allowed.""" 57 | key = keys.pop() 58 | size = 10_000 59 | u1 = randint_masked(key, jnp.ones(size, bool)) 60 | u2 = randint(clone(key), (), 0, size) 61 | assert u1 == u2 62 | 63 | def test_no_disallowed_values(self, keys): 64 | """Check disallowed values are never selected.""" 65 | key = keys.pop() 66 | for _ in range(100): 67 | keys = split(key, 3) 68 | mask = bernoulli(keys.pop(), 0.5, (10,)) 69 | if not jnp.any(mask): # pragma: no cover, rarely happens 70 | continue 71 | u = randint_masked(keys.pop(), mask) 72 | assert 0 <= u < mask.size 73 | assert mask[u] 74 | key = keys.pop() 75 | 76 | def test_correct_distribution(self, keys): 77 | """Check the distribution of 
values is uniform.""" 78 | # create mask 79 | num_allowed = 10 80 | mask = jnp.zeros(2 * num_allowed, bool) 81 | mask = mask.at[:num_allowed].set(True) 82 | indices = jnp.arange(mask.size) 83 | indices = permutation(keys.pop(), indices) 84 | mask = mask[indices] 85 | 86 | # sample values 87 | n = 10_000 88 | u: Int32[Array, '{n}'] = vmap_randint_masked(keys.pop(), mask, n) 89 | u = indices[u] 90 | assert jnp.all(u < num_allowed) 91 | 92 | # check that the distribution is uniform 93 | # likelihood ratio test for multinomial with free p vs. constant p 94 | k = jnp.bincount(u, length=num_allowed) 95 | llr = jnp.sum(jnp.where(k, k * jnp.log(k / n * num_allowed), 0)) 96 | lamda = 2 * llr 97 | pvalue = stats.chi2.sf(lamda, num_allowed - 1) 98 | assert pvalue > 0.1 99 | -------------------------------------------------------------------------------- /tests/test_mcmcloop.py: -------------------------------------------------------------------------------- 1 | # bartz/tests/test_mcmcloop.py 2 | # 3 | # Copyright (c) 2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | """Test `bartz.mcmcloop`.""" 26 | 27 | from functools import partial 28 | 29 | from equinox import filter_jit 30 | from jax import numpy as jnp 31 | from jax import vmap 32 | from jax.tree import map_with_path 33 | from jax.tree_util import tree_map 34 | from jaxtyping import Array, Float32, UInt8 35 | from numpy.testing import assert_array_equal 36 | 37 | from bartz import mcmcloop, mcmcstep 38 | 39 | 40 | def gen_data( 41 | p: int, n: int 42 | ) -> tuple[UInt8[Array, 'p n'], Float32[Array, ' n'], UInt8[Array, ' p']]: 43 | """Generate pretty nonsensical data.""" 44 | X = jnp.arange(p * n, dtype=jnp.uint8).reshape(p, n) 45 | X = vmap(jnp.roll)(X, jnp.arange(p)) 46 | max_split = jnp.full(p, 255, jnp.uint8) 47 | y = jnp.cos(jnp.linspace(0, 2 * jnp.pi / 32 * n, n)) 48 | return X, y, max_split 49 | 50 | 51 | def make_p_nonterminal(maxdepth: int) -> Float32[Array, ' {maxdepth}-1']: 52 | """Prepare the p_nonterminal argument to `mcmcstep.init`.""" 53 | depth = jnp.arange(maxdepth - 1) 54 | base = 0.95 55 | power = 2 56 | return base / (1 + depth).astype(float) ** power 57 | 58 | 59 | @filter_jit 60 | def init(p: int, n: int, ntree: int, **kwargs): 61 | """Simplified version of `bartz.mcmcstep.init` with data pre-filled.""" 62 | X, y, max_split = gen_data(p, n) 63 | return mcmcstep.init( 64 | X=X, 65 | y=y, 66 | max_split=max_split, 67 | num_trees=ntree, 68 | p_nonterminal=make_p_nonterminal(6), 69 | sigma_mu2=1.0, 70 | sigma2_alpha=1, 71 | sigma2_beta=1, 72 | min_points_per_decision_node=10, 73 | filter_splitless_vars=False, 74 | **kwargs, 75 | ) 76 | 77 | 78 | class TestRunMcmc: 79 | """Test `mcmcloop.run_mcmc`.""" 80 | 81 | def test_final_state_overflow(self, keys): 82 | 
"""Check that the final state is the one in the trace even if there's overflow.""" 83 | initial_state = init(10, 100, 20) 84 | final_state, _, main_trace = mcmcloop.run_mcmc( 85 | keys.pop(), initial_state, 10, inner_loop_length=9 86 | ) 87 | 88 | assert_array_equal(final_state.forest.leaf_tree, main_trace.leaf_tree[-1]) 89 | assert_array_equal(final_state.forest.var_tree, main_trace.var_tree[-1]) 90 | assert_array_equal(final_state.forest.split_tree, main_trace.split_tree[-1]) 91 | assert_array_equal(final_state.sigma2, main_trace.sigma2[-1]) 92 | 93 | def test_zero_iterations(self, keys): 94 | """Check there's no error if the loop does not run.""" 95 | initial_state = init(10, 100, 20) 96 | final_state, burnin_trace, main_trace = mcmcloop.run_mcmc( 97 | keys.pop(), initial_state, 0, n_burn=0 98 | ) 99 | 100 | tree_map(partial(assert_array_equal, strict=True), initial_state, final_state) 101 | 102 | def assert_empty_trace(path, x): # noqa: ARG001, for debugging 103 | assert x.shape[0] == 0 104 | 105 | map_with_path(assert_empty_trace, burnin_trace) 106 | map_with_path(assert_empty_trace, main_trace) 107 | -------------------------------------------------------------------------------- /benchmarks/rmse.py: -------------------------------------------------------------------------------- 1 | # bartz/benchmarks/rmse.py 2 | # 3 | # Copyright (c) 2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 
6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | """Measure the predictive performance on test sets.""" 26 | 27 | from contextlib import redirect_stdout 28 | from dataclasses import dataclass 29 | from functools import partial 30 | from io import StringIO 31 | 32 | from jax import jit, random, vmap 33 | from jax import numpy as jnp 34 | 35 | try: 36 | from bartz.BART import gbart 37 | except ImportError: 38 | from bartz import BART as gbart 39 | 40 | 41 | @partial(jit, static_argnums=(1, 2)) 42 | def simulate_data(key, n: int, p: int, max_interactions): 43 | """Simulate data for regression. 44 | 45 | This uses data-based standardization, so you have to generate train & 46 | test at once. 
47 | """ 48 | # split random key 49 | keys = list(random.split(key, 4)) 50 | 51 | # generate matrices 52 | X = random.uniform(keys.pop(), (p, n)) 53 | beta = random.normal(keys.pop(), (p,)) 54 | A = random.normal(keys.pop(), (p, p)) 55 | error = random.normal(keys.pop(), (n,)) 56 | 57 | # make A banded to limit the number of interactions 58 | num_nonzero = 1 + (max_interactions - 1) // 2 59 | num_nonzero = jnp.clip(num_nonzero, 0, p) 60 | interaction_pattern = jnp.arange(p) < num_nonzero 61 | multi_roll = vmap(jnp.roll, in_axes=(None, 0)) 62 | nonzero = multi_roll(interaction_pattern, jnp.arange(p)) 63 | A *= nonzero 64 | 65 | # compute terms 66 | linear = beta @ X 67 | quadratic = jnp.einsum('ai,bi,ab->i', X, X, A) 68 | 69 | # equalize the terms 70 | mu = linear / jnp.std(linear) + quadratic / jnp.std(quadratic) 71 | mu /= jnp.std(mu) # because linear and quadratic are correlated 72 | 73 | return X, mu, error 74 | 75 | 76 | @dataclass(frozen=True) 77 | class Data: 78 | """Data for regression.""" 79 | 80 | X_train: jnp.ndarray 81 | mu_train: jnp.ndarray 82 | error_train: jnp.ndarray 83 | X_test: jnp.ndarray 84 | mu_test: jnp.ndarray 85 | error_test: jnp.ndarray 86 | 87 | @property 88 | def y_train(self): 89 | """Return the training targets.""" 90 | return self.mu_train + self.error_train 91 | 92 | @property 93 | def y_test(self): 94 | """Return the test targets.""" 95 | return self.mu_test + self.error_test 96 | 97 | 98 | def make_data(key, n_train: int, n_test: int, p: int) -> Data: 99 | """Simulate data and split in train-test set.""" 100 | X, mu, error = simulate_data(key, n_train + n_test, p, 5) 101 | return Data( 102 | X[:, :n_train], 103 | mu[:n_train], 104 | error[:n_train], 105 | X[:, n_train:], 106 | mu[n_train:], 107 | error[n_train:], 108 | ) 109 | 110 | 111 | class EvalGbart: 112 | """Out-of-sample evaluation of gbart.""" 113 | 114 | timeout = 30.0 115 | unit = 'latent_sdev' 116 | 117 | def track_rmse(self) -> float: 118 | """Return the RMSE for 
predictions on a test set.""" 119 | key = random.key(2025_06_26_21_02) 120 | data = make_data(key, 100, 1000, 20) 121 | with redirect_stdout(StringIO()): 122 | bart = gbart( 123 | data.X_train, 124 | data.y_train, 125 | x_test=data.X_test, 126 | nskip=1000, 127 | ndpost=1000, 128 | seed=key, 129 | ) 130 | return jnp.sqrt(jnp.mean(jnp.square(bart.yhat_test_mean - data.mu_test))).item() 131 | -------------------------------------------------------------------------------- /tests/util.py: -------------------------------------------------------------------------------- 1 | # bartz/tests/util.py 2 | # 3 | # Copyright (c) 2024-2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
def assert_close_matrices(
    actual: ArrayLike,
    desired: ArrayLike,
    *,
    rtol: float = 0.0,
    atol: float = 0.0,
    tozero: bool = False,
):
    """
    Check if two matrices are similar.

    Parameters
    ----------
    actual
    desired
        The two matrices to be compared. Must be scalars, vectors, or 2d arrays.
        Scalars and vectors are interpreted as 1x1 and Nx1 matrices, but the two
        arrays must have the same shape beforehand.
    rtol
    atol
        Relative and absolute tolerances for the comparison. The closeness
        condition is:

            ||actual - desired|| <= atol + rtol * ||desired||,

        where the norm is the matrix 2-norm, i.e., the maximum (in absolute
        value) singular value.
    tozero
        If True, use the following condition instead:

            ||actual|| <= atol + rtol * ||desired||

        So `actual` is compared to zero, and `desired` is only used as a
        reference to set the threshold.

    Raises
    ------
    ValueError
        If the two matrices have different shapes.
    AssertionError
        If the closeness condition is not satisfied.
    """
    actual = np.asarray(actual)
    desired = np.asarray(desired)
    if actual.shape != desired.shape:
        msg = f'{actual.shape=} != {desired.shape=}'
        raise ValueError(msg)
    if actual.size == 0:
        # nothing to compare, and the 2-norm is not defined for empty matrices
        return

    # the norm routine does not accept 0d arrays
    actual = np.atleast_1d(actual)
    desired = np.atleast_1d(desired)

    # pick the array whose norm is tested, with labels for the error message
    if tozero:
        tested, expr, ref = actual, 'actual', 'zero'
    else:
        tested, expr, ref = actual - desired, 'actual - desired', 'desired'

    dnorm = linalg.norm(desired, 2)
    adnorm = linalg.norm(tested, 2)
    # the ratio is only informative, avoid dividing by a zero reference norm
    ratio = adnorm / dnorm if dnorm else np.nan

    msg = f"""\
matrices actual and {ref} are not close in 2-norm
matrix shape: {desired.shape}
norm(desired) = {dnorm:.2g}
norm({expr}) = {adnorm:.2g}  (atol = {atol:.2g})
ratio = {ratio:.2g}  (rtol = {rtol:.2g})"""

    assert adnorm <= atol + rtol * dnorm, msg
class bart(RObjectBase):
    """

    Python interface to dbarts::bart.

    In R, `splitprobs` can be a named numeric vector; from Python, pass that
    form as a dictionary mapping names to values.

    """

    _rfuncname = 'dbarts::bart'
    _split_probs = 'splitprobs'

    def __init__(self, *args, **kw):
        probs = kw.get(self._split_probs)
        if isinstance(probs, dict):
            # turn the dictionary into an R numeric vector with names
            vector = robjects.FloatVector(list(probs.values()))
            set_names = robjects.r('setNames')
            kw[self._split_probs] = set_names(vector, list(probs.keys()))

        super().__init__(*args, **kw)

    @rmethod
    def predict(self, *args, **kw): ...

    @rmethod
    def extract(self, *args, **kw): ...

    @rmethod
    def fitted(self, *args, **kw): ...
125 | 126 | @rmethod 127 | def setData(self, *args, **kw): ... 128 | 129 | @rmethod 130 | def setResponse(self, *args, **kw): ... 131 | 132 | @rmethod 133 | def setOffset(self, *args, **kw): ... 134 | 135 | @rmethod 136 | def setSigma(self, *args, **kw): ... 137 | 138 | @rmethod 139 | def setPredictor(self, *args, **kw): ... 140 | 141 | @rmethod 142 | def setTestPredictor(self, *args, **kw): ... 143 | 144 | @rmethod 145 | def setTestPredictorAndOffset(self, *args, **kw): ... 146 | 147 | @rmethod 148 | def setTestOffset(self, *args, **kw): ... 149 | 150 | @rmethod 151 | def printTrees(self, *args, **kw): ... 152 | 153 | @rmethod 154 | def plotTree(self, *args, **kw): ... 155 | 156 | 157 | class dbartsControl(RObjectBase): 158 | _rfuncname = 'dbarts::dbartsControl' 159 | -------------------------------------------------------------------------------- /tests/test_meta.py: -------------------------------------------------------------------------------- 1 | # bartz/tests/test_meta.py 2 | # 3 | # Copyright (c) 2024-2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | """Test properties of pytest itself or other utilities.""" 26 | 27 | from functools import partial 28 | 29 | import jax 30 | import pytest 31 | from jax import jit, random 32 | from jax import numpy as jnp 33 | from jax.errors import KeyReuseError 34 | 35 | 36 | @pytest.fixture 37 | def keys1(keys): 38 | """Pass-through the `keys` fixture.""" 39 | return keys 40 | 41 | 42 | @pytest.fixture 43 | def keys2(keys): 44 | """Pass-through the `keys` fixture.""" 45 | return keys 46 | 47 | 48 | def test_random_keys_do_not_depend_on_fixture(keys1, keys2): 49 | """Check that the `keys` fixture is per-test-case, not per-fixture.""" 50 | assert keys1 is keys2 51 | 52 | 53 | def test_number_of_random_keys(keys): 54 | """Check the fixed number of available keys. 55 | 56 | This is here just as reference for the `test_random_keys_are_consumed` test 57 | below. 
class TestJaxNoCopyBehavior:
    """Probe in which situations jax copies an array buffer or reuses it."""

    def test_unconditional_buffer_donation(self):
        """Check that a donated buffer is reused even for a small array."""
        # nan-debug mode makes jax create some copies apparently
        with jax.debug_nans(False):
            donated = jnp.arange(100)
            address_in = donated.unsafe_buffer_pointer()

            def noop(x):
                return x

            passthrough = jit(noop, donate_argnums=(0,))
            result = passthrough(donated)
            address_out = result.unsafe_buffer_pointer()

            # same memory on the way out...
            assert address_in == address_out
            # ...and the input array is no longer usable
            with pytest.raises(RuntimeError, match=r'delete'):
                donated[0]

    def test_jnp_array_copy_no_jit(self):
        """Check that jnp.array copies its input when called eagerly."""
        source = jnp.arange(100)
        source_address = source.unsafe_buffer_pointer()

        duplicate = jnp.array(source)
        duplicate_address = duplicate.unsafe_buffer_pointer()

        assert duplicate_address != source_address

    def test_jnp_array_no_copy_jit(self):
        """Check that jnp.array is a no-op on a donated buffer under jit."""
        # nan-debug mode makes jax create some copies apparently
        with jax.debug_nans(False):
            donated = jnp.arange(100)
            # read the address before the call, while the input is still alive
            address_in = donated.unsafe_buffer_pointer()

            def array(x):
                return jnp.array(x)

            result = jit(array, donate_argnums=(0,))(donated)
            address_out = result.unsafe_buffer_pointer()

            assert address_out == address_in
24 | 25 | """Wrapper for the R package BART.""" 26 | 27 | from typing import NamedTuple, TypedDict, cast 28 | 29 | import numpy as np 30 | from jaxtyping import AbstractDtype, Bool, Float64, Int32 31 | from numpy import ndarray 32 | from rpy2.rlike.container import NamedList 33 | 34 | from tests.rbartpackages._base import RObjectBase, rmethod 35 | 36 | 37 | class TreeDraws(TypedDict): 38 | """Type of the `treedraws` attribute of `mc_gbart`.""" 39 | 40 | cutpoints: dict[int | str, Float64[ndarray, ' numcut[i]']] 41 | trees: str 42 | 43 | 44 | class String(AbstractDtype): 45 | """Represent a `numpy.str_` data dtype.""" 46 | 47 | dtypes = r' 0): 104 | self.rm_const -= 1 105 | else: 106 | msg = 'failed to parse rm.const because indices change sign' 107 | raise ValueError(msg) 108 | 109 | if self.sigma_mean is not None: 110 | self.sigma_mean = self.sigma_mean.item() 111 | 112 | r_treedraws = cast(NamedList, self.treedraws) 113 | cutpoints: NamedList = r_treedraws.getbyname('cutpoints') 114 | self.treedraws = { 115 | 'cutpoints': { 116 | i if it.name is None else it.name.item(): it.value 117 | for i, it in enumerate(cutpoints.items()) 118 | }, 119 | 'trees': r_treedraws.getbyname('trees').item(), 120 | } 121 | 122 | @rmethod 123 | def predict( 124 | self, newdata: Float64[ndarray, 'm p'], *args, **kwargs 125 | ) -> Float64[ndarray, 'ndpost/mc_cores m']: 126 | """Compute predictions.""" 127 | ... 
128 | 129 | 130 | class bartModelMatrix(RObjectBase): # noqa: D101 because the R doc is added automatically 131 | _rfuncname = 'BART::bartModelMatrix' 132 | 133 | 134 | class gbart(mc_gbart): # noqa: D101 because the R doc is added automatically 135 | _rfuncname = 'BART::gbart' 136 | 137 | sigma: Float64[ndarray, ' nskip+ndpost'] | None = None 138 | -------------------------------------------------------------------------------- /tests/test_prepcovars.py: -------------------------------------------------------------------------------- 1 | # bartz/tests/test_prepcovars.py 2 | # 3 | # Copyright (c) 2024-2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
class TestQuantilizer:
    """Test `prepcovars.quantilized_splits_from_matrix`."""

    # the two fill values are the largest representable float32 and int32
    @pytest.mark.parametrize(
        'fill_value', [jnp.finfo(jnp.float32).max, jnp.iinfo(jnp.int32).max]
    )
    def test_splits_fill(self, fill_value):
        """Check how predictors with less unique values are right-padded."""
        with debug_infs(not jnp.isinf(fill_value)):
            fill_value = jnp.array(fill_value)
            # the rows have 2, 3 and 4 distinct values respectively; the dtype
            # follows the fill value, so both float and int inputs are covered
            x = jnp.array([[1, 1, 3, 3], [1, 3, 3, 5], [1, 3, 5, 7]], fill_value.dtype)
            splits, _ = quantilized_splits_from_matrix(x, 100)
            # one row of splits per row of x, padded on the right
            expected_splits = [[2, fill_value, fill_value], [2, 4, fill_value], [2, 4, 6]]
            assert_array_equal(splits, expected_splits)

    def test_max_splits(self):
        """Check that the number of splits per predictor is counted correctly."""
        # the rows have 1, 2, 3 and 4 distinct values, in this order
        x = jnp.array([[1, 1, 1, 1], [4, 4, 1, 1], [2, 1, 3, 2], [1, 4, 2, 3]])
        _, max_split = quantilized_splits_from_matrix(x, 100)
        assert_array_equal(max_split, jnp.arange(4))

    def test_integer_splits_overflow(self):
        """Check that the splits are computed correctly at the limit of overflow."""
        # the two values are -2**31 and 2**31 - 2; the expected split, -1, is
        # their midpoint
        x = jnp.array([[-(2**31), 2**31 - 2]])
        splits, _ = quantilized_splits_from_matrix(x, 100)
        expected_splits = [[-1]]
        assert_array_equal(splits, expected_splits)

    @pytest.mark.parametrize('dtype', [int, float])
    def test_splits_type(self, dtype):
        """Check that the input type is preserved."""
        x = jnp.arange(10, dtype=dtype)[None, :]
        splits, _ = quantilized_splits_from_matrix(x, 100)
        assert splits.dtype == x.dtype

    def test_splits_length(self):
        """Check that the correct number of splits is returned in corner cases."""
        # a single row with 10 distinct values
        x = jnp.linspace(0, 1, 10)[None, :]

        # the second argument is varied below and above the number of values
        short_splits, _ = quantilized_splits_from_matrix(x, 2)
        assert short_splits.shape == (1, 1)

        long_splits, _ = quantilized_splits_from_matrix(x, 100)
        assert long_splits.shape == (1, 9)

        just_right_splits, _ = quantilized_splits_from_matrix(x, 10)
        assert just_right_splits.shape == (1, 9)

        no_splits, _ = quantilized_splits_from_matrix(x, 1)
        assert no_splits.shape == (1, 0)

    def test_round_trip(self):
        """Check that `bin_predictors` is the ~inverse of `quantilized_splits_from_matrix`."""
        # with the values 0, ..., 9 the bin indices must coincide with the values
        x = jnp.arange(10)[None, :]
        splits, _ = quantilized_splits_from_matrix(x, 100)
        b = bin_predictors(x, splits)
        assert_array_equal(x, b)

    def test_one_value(self):
        """Check there's only 1 bin (0 splits) if there is 1 datapoint."""
        # shape (10, 1): ten rows with a single value each
        x = jnp.arange(10)[:, None]
        _, max_split = quantilized_splits_from_matrix(x, 100)
        assert_array_equal(max_split, jnp.full(len(x), 0))

    def test_zero_values(self):
        """Check what happens when no binning is possible."""
        # a row without any value must be rejected
        x = jnp.empty((1, 0))
        with pytest.raises(ValueError, match='at least 1'):
            quantilized_splits_from_matrix(x, 100)

    def test_zero_bins(self):
        """Check what happens when no binning is possible."""
        # 0 as second argument must be rejected
        x = jnp.arange(10)[None, :]
        with pytest.raises(ValueError, match='at least 1'):
            quantilized_splits_from_matrix(x, 0)
-------------------------------------------------------------------------------- /docs/development.rst: -------------------------------------------------------------------------------- 1 | .. bartz/docs/development.rst 2 | .. 3 | .. Copyright (c) 2024-2025, The Bartz Contributors 4 | .. 5 | .. This file is part of bartz. 6 | .. 7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy 8 | .. of this software and associated documentation files (the "Software"), to deal 9 | .. in the Software without restriction, including without limitation the rights 10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | .. copies of the Software, and to permit persons to whom the Software is 12 | .. furnished to do so, subject to the following conditions: 13 | .. 14 | .. The above copyright notice and this permission notice shall be included in all 15 | .. copies or substantial portions of the Software. 16 | .. 17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | .. SOFTWARE. 24 | 25 | Development 26 | =========== 27 | 28 | Initial setup 29 | ------------- 30 | 31 | `Fork `_ the repository on Github, then clone the fork: 32 | 33 | .. code-block:: shell 34 | 35 | git clone git@github.com:YourGithubUserName/bartz.git 36 | cd bartz 37 | 38 | Install `R `_ and `uv `_ (for example, with `Homebrew `_ do :literal:`brew install r uv`). Then run 39 | 40 | .. code-block:: shell 41 | 42 | make setup 43 | 44 | to set up the Python and R environments. 
45 | 46 | The Python environment is managed by uv. To run commands that involve the Python installation, do :literal:`uv run <command>`. For example, to start an IPython shell, do :literal:`uv run ipython`. Alternatively, do :literal:`source .venv/bin/activate` to activate the virtual environment in the current shell. 47 | 48 | The R environment is automatically active when you use :literal:`R` in the project directory. 49 | 50 | Pre-defined commands 51 | -------------------- 52 | 53 | Development commands are defined in a makefile. Run :literal:`make` without arguments to list the targets. 54 | 55 | Documentation 56 | ------------- 57 | 58 | To build the documentation for the current working copy, do 59 | 60 | .. code-block:: shell 61 | 62 | make docs 63 | 64 | To build the documentation for the latest release tag, do 65 | 66 | .. code-block:: shell 67 | 68 | make docs-latest 69 | 70 | To debug the documentation build, do 71 | 72 | .. code-block:: shell 73 | 74 | make docs SPHINXOPTS='--fresh-env --pdb' 75 | 76 | Benchmarks 77 | ---------- 78 | 79 | The benchmarks are managed with `asv <https://asv.readthedocs.io/>`_. The basic asv workflow is: 80 | 81 | .. code-block:: shell 82 | 83 | uv run asv run # run and save benchmarks on main branch 84 | uv run asv publish # create html report 85 | uv run asv preview # start a local server to view the report 86 | 87 | :literal:`asv run` writes the results into files saved in :literal:`./benchmarks`. These files are tracked by git; consider deliberately not committing all results generated while developing. 88 | 89 | There are a few make targets for common asv commands. The most useful command during development is 90 | 91 | .. code-block:: shell 92 | 93 | make asv-quick ARGS='--bench <pattern>' 94 | 95 | This runs only the benchmarks whose name matches <pattern>, only once, within the working copy and current Python environment. 96 | 97 | Profiling 98 | --------- 99 | 100 | Use the `JAX profiling utilities <https://docs.jax.dev/en/latest/profiling.html>`_ to profile `bartz`.
By default the MCMC loop is compiled all at once, which makes it quite opaque to profiling. There are two ways to understand what's going on inside in more detail: 1) inspect the individual operations and use intuition to work out which piece of code they correspond to, 2) turn on bartz's profile mode. Basic workflow: 101 | 102 | .. code-block:: python 103 | 104 | from jax.profiler import trace, ProfileOptions 105 | from bartz.BART import gbart 106 | from bartz import profile_mode 107 | 108 | traceopt = ProfileOptions() 109 | 110 | # this setting makes Python function calls show up in the trace 111 | traceopt.python_tracer_level = 1 112 | 113 | # on cpu, this makes the trace detailed enough to understand what's going on 114 | # even within compiled functions 115 | traceopt.host_tracer_level = 2 116 | 117 | with trace('./trace_results', profiler_options=traceopt), profile_mode(True): 118 | bart = gbart(...) 119 | 120 | On the first run, the trace will show compilation operations, while subsequent runs (within the same Python shell) will be warmed up. Start an xprof server to visualize the results: 121 | 122 | .. code-block:: shell 123 | 124 | $ uvx --python 3.13 xprof ./trace_results 125 | [...] 126 | XProf at http://localhost:8791/ (Press CTRL+C to quit) 127 | 128 | Open the provided URL in a browser. In the sidebar, select the tool "Trace Viewer". 129 | 130 | In "profile mode", the MCMC loop is split into a few chunks that are compiled separately, which makes it possible to see at a glance how much time each phase of the MCMC cycle takes. This causes some overhead, so the timings are not equivalent to the normal mode ones. In one specific example on CPU, bartz was 20% slower in profile mode with one chain, and 2x slower with multiple chains.
131 | -------------------------------------------------------------------------------- /tests/rbartpackages/BART3.py: -------------------------------------------------------------------------------- 1 | # bartz/tests/rbartpackages/BART3.py 2 | # 3 | # Copyright (c) 2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
24 | 25 | """Wrapper for the R package BART3.""" 26 | 27 | from typing import NamedTuple, TypedDict 28 | 29 | import numpy as np 30 | from jaxtyping import AbstractDtype, Float64, Int32 31 | from numpy import ndarray 32 | 33 | from tests.rbartpackages._base import RObjectBase, rmethod 34 | 35 | 36 | class TreeDraws(TypedDict): 37 | """Type of the `treedraws` attribute of `mc_gbart`.""" 38 | 39 | cutpoints: dict[int | str, Float64[ndarray, ' numcut[i]']] 40 | trees: str 41 | 42 | 43 | class String(AbstractDtype): 44 | """Represent a `numpy.str_` data dtype.""" 45 | 46 | dtypes = r' 0): 111 | self.rm_const -= 1 112 | else: 113 | msg = 'failed to parse rm.const because indices change sign' 114 | raise ValueError(msg) 115 | 116 | if self.sigest is not None: 117 | self.sigest = self.sigest.item() 118 | if self.sigma_mean is not None: 119 | self.sigma_mean = self.sigma_mean.item() 120 | 121 | if hasattr(self.treedraws, 'getbyname'): 122 | # it's a NamedList 123 | self.treedraws = { 124 | 'cutpoints': { 125 | i if it.name is None else it.name.item(): it.value 126 | for i, it in enumerate( 127 | self.treedraws.getbyname('cutpoints').items() 128 | ) 129 | }, 130 | 'trees': self.treedraws.getbyname('trees').item(), 131 | } 132 | else: 133 | # it's an OrdDict 134 | self.treedraws = { 135 | 'cutpoints': { 136 | i if k is None else k.item(): v 137 | for i, (k, v) in enumerate(self.treedraws['cutpoints'].items()) 138 | }, 139 | 'trees': self.treedraws['trees'].item(), 140 | } 141 | 142 | @rmethod 143 | def predict( 144 | self, newdata: Float64[ndarray, 'm p'], *args, **kwargs 145 | ) -> Float64[ndarray, 'ndpost m']: 146 | """Compute predictions.""" 147 | ... 
148 | 149 | 150 | class bartModelMatrix(RObjectBase): # noqa: D101 because the R doc is added automatically 151 | _rfuncname = 'BART3::bartModelMatrix' 152 | 153 | 154 | class gbart(mc_gbart): # noqa: D101 because the R doc is added automatically 155 | _rfuncname = 'BART3::gbart' 156 | 157 | sigma: Float64[ndarray, ' nskip+ndpost'] | None = None 158 | -------------------------------------------------------------------------------- /tests/test_debug.py: -------------------------------------------------------------------------------- 1 | # bartz/tests/test_debug.py 2 | # 3 | # Copyright (c) 2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 

"""Test `bartz.debug`."""

from collections import namedtuple

import pytest
from equinox import tree_at
from jax import numpy as jnp
from jax import random
from scipy import stats
from scipy.stats import ks_1samp

from bartz.debug import check_trace, format_tree, sample_prior
from bartz.jaxext import minimal_unsigned_dtype
from bartz.mcmcloop import TreesTrace


def manual_tree(
    leaf: list[list[float]], var: list[list[int]], split: list[list[int]]
) -> TreesTrace:
    """Facilitate the hardcoded definition of tree heaps.

    Parameters
    ----------
    leaf
        Leaf values, one inner list per tree level; level ``i`` must have
        ``2**i`` entries.
    var
        Decision variables, one inner list per level, with one level less than
        `leaf`.
    split
        Split points, with the same layout as `var`.

    Returns
    -------
    A `TreesTrace` whose arrays are the levels concatenated after a leading zero.
    """
    assert len(leaf) == len(var) + 1 == len(split) + 1

    def check_powers_of_2(seq: list[list]):
        """Check that the list at position `i` in `seq` has length `2**i`."""
        return all(len(x) == 2**i for i, x in enumerate(seq))

    # The result of the check must be asserted, otherwise a malformed level
    # goes through silently and produces a misaligned heap.
    assert check_powers_of_2(leaf), 'level i of `leaf` must have 2**i entries'
    assert check_powers_of_2(var), 'level i of `var` must have 2**i entries'
    assert check_powers_of_2(split), 'level i of `split` must have 2**i entries'

    tree = TreesTrace(
        jnp.concatenate([jnp.zeros(1), *map(jnp.array, leaf)]),
        jnp.concatenate([jnp.zeros(1, int), *map(jnp.array, var)]),
        jnp.concatenate([jnp.zeros(1, int), *map(jnp.array, split)]),
    )
    assert tree.leaf_tree.dtype == jnp.float32
    assert tree.var_tree.dtype == jnp.int32
    assert tree.split_tree.dtype == jnp.int32
    return tree


def test_format_tree():
    """Check the output of `format_tree` on a single example."""
    tree = manual_tree(
        [[1.0], [2.0, 3.0], [4.0, 5.0, 6.0, 7.0]], [[4], [1, 2]], [[15], [0, 3]]
    )
    s = format_tree(tree)
    print(s)
    # NOTE(review): the run of spaces inside the last two lines of this literal
    # was reconstructed (children aligned under the parent's corner); confirm
    # it against the actual output of `format_tree`.
    ref_s = """\
1 ┐x4 < 15
2 ├── 2.0
3 └──┐x2 < 3
6    ├──╢6.0
7    └──╢7.0"""
    assert s == ref_s


class TestSamplePrior:
    """Test `debug.sample_prior`."""

    Args = namedtuple(
        'Args',
        ['key', 'trace_length', 'num_trees', 'max_split', 'p_nonterminal', 'sigma_mu'],
    )

    @pytest.fixture
    def args(self, keys):
        """Prepare arguments for `sample_prior`."""
        # config
        trace_length = 1000
        num_trees = 200
        maxdepth = 6
        alpha = 0.95
        beta = 2
        max_split = 5

        # prepare arguments
        d = jnp.arange(maxdepth - 1)
        p_nonterminal = alpha / (1 + d).astype(float) ** beta
        p = maxdepth - 1
        max_split = jnp.full(p, jnp.array(max_split, minimal_unsigned_dtype(max_split)))
        sigma_mu = 1 / jnp.sqrt(num_trees)

        return self.Args(
            keys.pop(), trace_length, num_trees, max_split, p_nonterminal, sigma_mu
        )

    def test_valid_trees(self, args: Args):
        """Check all sampled trees are valid."""
        trees = sample_prior(*args)
        batch_shape = (args.trace_length, args.num_trees)
        heap_size = 2 ** (args.p_nonterminal.size + 1)
        assert trees.leaf_tree.shape == (*batch_shape, heap_size)
        assert trees.var_tree.shape == (*batch_shape, heap_size // 2)
        assert trees.split_tree.shape == (*batch_shape, heap_size // 2)
        bad = check_trace(trees, args.max_split)
        num_bad = jnp.count_nonzero(bad).item()
        assert num_bad == 0

    def test_max_depth(self, keys, args: Args):
        """Check that trees stop growing when p_nonterminal = 0."""
        for max_depth in range(args.p_nonterminal.size + 1):
            # probability 1 of splitting down to `max_depth`, 0 below
            p_nonterminal = jnp.zeros_like(args.p_nonterminal)
            p_nonterminal = p_nonterminal.at[:max_depth].set(1.0)
            args = tree_at(lambda args: args.p_nonterminal, args, p_nonterminal)
            args = tree_at(lambda args: args.key, args, keys.pop())
            trees = sample_prior(*args)
            assert jnp.all(trees.split_tree[:, :, 1 : 2**max_depth])
            assert not jnp.any(trees.split_tree[:, :, 2**max_depth :])

    def test_forest_sdev(self, keys, args: Args):
        """Check that the sum of trees is standard Normal."""
        trees = sample_prior(*args)
        # pick one random heap slot per (iteration, tree)
        leaf_indices = random.randint(
            keys.pop(), trees.leaf_tree.shape[:2], 0, trees.leaf_tree.shape[-1]
        )
        batch_indices = jnp.ogrid[
            : trees.leaf_tree.shape[0], : trees.leaf_tree.shape[1]
        ]
        leaves = trees.leaf_tree[(*batch_indices, leaf_indices)]
        sum_of_trees = jnp.sum(leaves, axis=1)

        test = ks_1samp(sum_of_trees, stats.norm.cdf)
        assert test.pvalue > 0.1

    def test_trees_differ(self, args: Args):
        """Check that trees are different across iterations."""
        trees = sample_prior(*args)
        for attr in ('leaf_tree', 'var_tree', 'split_tree'):
            heap = getattr(trees, attr)
            diff_trace = jnp.diff(heap, axis=0)
            diff_forest = jnp.diff(heap, axis=1)
            assert jnp.any(diff_trace)
            assert jnp.any(diff_forest)

# --------------------------------------------------------------------------------
# /tests/rbartpackages/_base.py:
# --------------------------------------------------------------------------------
# bartz/tests/rbartpackages/_base.py
#
# Copyright (c) 2024-2025, The Bartz Contributors
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from collections.abc import Callable
from functools import wraps
from re import fullmatch, match

import numpy as np
from rpy2 import robjects
from rpy2.robjects import BoolVector, conversion, numpy2ri
from rpy2.robjects.help import Package
from rpy2.robjects.methods import RS4

# Each optional converter below starts out as an empty `Converter` and is only
# replaced or populated when the corresponding import succeeds, so the names
# are always defined even when the optional package is missing.

# converter for pandas
pandas_converter = conversion.Converter('pandas')
try:
    from rpy2.robjects import pandas2ri
except ImportError:
    pass
else:
    pandas_converter = pandas2ri.converter

# converter for polars
polars_converter = conversion.Converter('polars')
try:
    import polars
    from rpy2.robjects import pandas2ri
except ImportError:
    pass
else:

    def polars_to_r(df):
        # go through pandas: convert the polars object to pandas, then let
        # pandas2ri convert it to R
        df = df.to_pandas()
        return pandas2ri.py2rpy(df)

    polars_converter.py2rpy.register(polars.DataFrame, polars_to_r)
    polars_converter.py2rpy.register(polars.Series, polars_to_r)

# converter for jax
jax_converter = conversion.Converter('jax')
try:
    import jax
except ImportError:
    pass
else:

    def jax_to_r(x):
        # go through numpy, then reuse the numpy -> R conversion
        x = np.asarray(x)
        if x.ndim == 0:
            # indexing a 0-d array with () yields a numpy scalar
            x = x[()]
        return numpy2ri.py2rpy(x)

    jax_converter.py2rpy.register(jax.Array, jax_to_r)

# converter for numpy
numpy_converter = numpy2ri.converter


# converter for BoolVector (why isn't it in the numpy converter?)
def bool_vector_to_python(x):
    """Convert an R `BoolVector` to a numpy boolean array."""
    return np.array(x, bool)


bool_vector_converter = conversion.Converter('bool_vector')
bool_vector_converter.rpy2py.register(BoolVector, bool_vector_to_python)


# converter for python dictionaries
dict_converter = conversion.Converter('dict')


def dict_to_r(x):
    """Convert a Python dictionary to an R `ListVector`."""
    return robjects.ListVector(x)


dict_converter.py2rpy.register(dict, dict_to_r)

# starts with a letter, or with a dot not followed by a digit
R_IDENTIFIER = r'(?:[a-zA-Z]|\.(?![0-9]))[a-zA-Z0-9._]*'


class RObjectBase:
    """
    Base class for Python wrappers of R object creators.

    Subclasses should define the class attribute `_rfuncname`, and declare
    stub methods decorated with `rmethod`.

    _rfuncname : str
        An R function in the format ``'<library>::<function>'``. The function is
        called with the initialization arguments, converted to R objects, and is
        expected to return an R object. The attributes of the R object are
        converted to equivalent Python values and set as attributes of the
        Python object. The R object itself is assigned to the member `_robject`.
    """

    _converter = (
        robjects.default_converter
        + pandas_converter
        + polars_converter
        + numpy_converter
        + bool_vector_converter
        + jax_converter
        + dict_converter
    )
    _convctx = conversion.localconverter(_converter)

    def _py2r(self, x):
        """Convert a Python value to R; wrappers are passed as their R object."""
        if isinstance(x, __class__):
            return x._robject
        with self._convctx:
            return self._converter.py2rpy(x)

    def _r2py(self, x):
        """Convert an R value to Python."""
        with self._convctx:
            return self._converter.rpy2py(x)

    def _args2r(self, args):
        """Convert positional arguments to R."""
        return tuple(map(self._py2r, args))

    def _kw2r(self, kw):
        """Convert the values of keyword arguments to R."""
        return {key: self._py2r(value) for key, value in kw.items()}

    _rfuncname: str = NotImplemented

    @property
    def _library(self) -> str:
        """Parse `_rfuncname` to get the library. Also checks `_rfuncname` is valid."""
        pattern = rf'^({R_IDENTIFIER})::({R_IDENTIFIER})$'
        m = match(pattern, self._rfuncname)
        if m is None:
            msg = f'Invalid _rfuncname: {self._rfuncname}.'
            raise ValueError(msg)
        return m.group(1)

    def __init__(self, *args, **kw):
        # `_library` validates `_rfuncname` before it is interpolated in R code
        robjects.r(f'loadNamespace("{self._library}")')
        func = robjects.r(self._rfuncname)
        obj = func(*self._args2r(args), **self._kw2r(kw))
        self._robject = obj
        if hasattr(obj, 'items'):
            for s, v in obj.items():
                # dots are not valid in Python attribute names
                setattr(self, s.replace('.', '_'), self._r2py(v))

    def __init_subclass__(cls, **kw):
        """Automatically add R documentation to subclasses."""
        # forward the keyword arguments, as required for cooperative
        # subclass initialization; without this they are silently swallowed
        super().__init_subclass__(**kw)
        library, name = cls._rfuncname.split('::')
        page = Package(library).fetch(name)
        if cls.__doc__ is None:
            cls.__doc__ = ''
        cls.__doc__ += 'R documentation:\n' + page.to_docstring()


def rmethod(meth: Callable, *, rname: str | None = None) -> Callable:
    """Automatically implement a method using the corresponding R method.

    Parameters
    ----------
    meth
        A method in a subclass of `RObjectBase`.
    rname
        The name of the method in R. If not specified, use the name of `meth`.

    Returns
    -------
    methimpl
        An implementation of the method that calls the R method. The original
        implementation of meth is completely discarded.

    Raises
    ------
    ValueError
        Raised when the returned method is called on a non-S4 object and
        `rname` is not a valid R identifier.

    Examples
    --------
    >>> class MyRObject(RObjectBase):
    ...     _rfuncname = 'mypackage::myfunction'
    ...     @partial(rmethod, rname='my.method')
    ...     def my_method(self, arg1: int, arg2: str):
    ...         ...
    """
    if rname is None:
        rname = meth.__name__

    # I can't automatically add a docstring to the method because the R class
    # can be determined at runtime

    @wraps(meth)
    def impl(self, *args, **kw):
        if isinstance(self._robject, RS4):
            # S4 object: get the method with the `$` operator
            func = robjects.r['$'](self._robject, rname)
            out = func(*self._args2r(args), **self._kw2r(kw))

        else:
            # the name is interpolated into R code below, so validate it first
            if not fullmatch(R_IDENTIFIER, rname):
                msg = f'Invalid R method name: {rname}'
                raise ValueError(msg)
            rclass = self._robject.rclass[0]
            func = robjects.r(
                f'getS3method("{rname}", "{rclass}", envir = asNamespace("{self._library}"))'
            )
            out = func(self._robject, *self._args2r(args), **self._kw2r(kw))

        return self._r2py(out)

    return impl

# --------------------------------------------------------------------------------
# /docs/conf.py:
# --------------------------------------------------------------------------------
# bartz/docs/conf.py
#
# Copyright (c) 2024-2025, The Bartz Contributors
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

import datetime
import inspect
import os
import pathlib
import sys
from functools import cached_property

import git
from packaging import version as pkgversion

# -- Doc variant -------------------------------------------------------------

repo = git.Repo(search_parent_directories=True)

# The variant is selected with the environment variable BARTZ_DOC_VARIANT:
# 'dev' (default) documents the current checkout, 'latest' checks out and
# documents the highest final release tag. Any other value raises KeyError.
variant = os.environ.get('BARTZ_DOC_VARIANT', 'dev')

if variant == 'dev':
    commit = repo.head.commit.hexsha
    uncommitted_stuff = repo.is_dirty()
    # short commit hash, with a trailing '+' if the working tree is dirty
    version = f'{commit[:7]}{"+" if uncommitted_stuff else ""}'

elif variant == 'latest':
    # list git tags
    tags = [t.name for t in repo.tags]
    print(f'git tags: {tags}')

    # find final versions in tags
    versions = []
    for t in tags:
        try:
            v = pkgversion.parse(t)
        except pkgversion.InvalidVersion:
            continue
        if v.is_prerelease or v.is_devrelease:
            continue
        versions.append((v, t))
    print(f'tags for releases: {versions}')

    # find latest versions
    versions.sort(key=lambda x: x[0])
    # NOTE(review): raises IndexError if no tag parses as a final release
    version, tag = versions[-1]

    # check it out and check it matches the version in the package
    repo.git.checkout(tag)
    import bartz

    assert pkgversion.parse(bartz.__version__) == version

    version = str(version)
    uncommitted_stuff = False

else:
    raise KeyError(variant)

# imported only after the optional checkout above
import bartz

# -- Project information -----------------------------------------------------

project = f'bartz {version}'
author = 'The Bartz Contributors'

# copyright years: '2024', or '2024-<current year>' from 2025 onwards
now = datetime.datetime.now(tz=datetime.timezone.utc)
year = '2024'
if now.year > int(year):
    year += '-' + str(now.year)
copyright = year + ', ' + author  # noqa: A001, because sphinx uses this variable

release = version

# -- General configuration --------------------------------------------------- 101 | 102 | # Add any Sphinx extension module names here, as strings. They can be 103 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 104 | # ones. 105 | extensions = [ 106 | 'sphinx.ext.napoleon', 107 | 'sphinx.ext.autodoc', 108 | 'sphinx_autodoc_typehints', # (!) keep after napoleon 109 | 'sphinx.ext.mathjax', 110 | 'sphinx.ext.intersphinx', # link to other documentations automatically 111 | 'myst_parser', # markdown support 112 | ] 113 | 114 | # decide whether to use viewcode or linkcode extension 115 | ext = 'viewcode' # copy source code in static website 116 | if not uncommitted_stuff: 117 | commit = repo.head.commit.hexsha 118 | branches = repo.git.branch('--remotes', '--contains', commit) 119 | commit_on_github = bool(branches.strip()) 120 | if commit_on_github: 121 | ext = 'linkcode' # links to code on github 122 | extensions.append(f'sphinx.ext.{ext}') 123 | 124 | myst_enable_extensions = [ 125 | # "amsmath", 126 | 'dollarmath' 127 | ] 128 | 129 | # Add any paths that contain templates here, relative to this directory. 130 | # templates_path = ['_templates'] # noqa: ERA001 131 | 132 | # List of patterns, relative to source directory, that match files and 133 | # directories to ignore when looking for source files. 134 | # This pattern also affects html_static_path and html_extra_path. 135 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 136 | 137 | 138 | # -- Options for HTML output ------------------------------------------------- 139 | 140 | # The theme to use for HTML and HTML Help pages. See the documentation for 141 | # a list of builtin themes. 
#
html_theme = 'alabaster'

html_title = f'{project} documentation'

html_theme_options = dict(
    description='Super-fast BART (Bayesian Additive Regression Trees) in Python',
    fixed_sidebar=True,
    github_button=True,
    github_type='star',
    github_repo='bartz',
    github_user='Gattocrucco',
    show_relbars=True,
)

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

master_doc = 'index'

# -- Other options -------------------------------------------------

default_role = 'py:obj'

# autodoc
autoclass_content = 'both'  # concatenate the class and __init__ docstrings
# default arguments are printed as in source instead of being evaluated
autodoc_preserve_defaults = True
autodoc_default_options = {'member-order': 'bysource'}

# autodoc-typehints
typehints_use_rtype = False
typehints_document_rtype = True
always_use_bars_union = True
typehints_defaults = 'comma'

# napoleon
napoleon_google_docstring = False
napoleon_use_ivar = True
napoleon_use_rtype = False

# intersphinx
intersphinx_mapping = dict(
    scipy=('https://docs.scipy.org/doc/scipy', None),
    numpy=('https://numpy.org/doc/stable', None),
    jax=('https://docs.jax.dev/en/latest', None),
)

# viewcode
viewcode_line_numbers = True


def linkcode_resolve(domain, info):
    """
    Determine the URL corresponding to Python object, for extension linkcode.

    Adapted from scipy/doc/release/conf.py.

    The URL points to the source lines of the object on GitHub, at the commit
    stored in the module-level variable `commit`. Every precondition is
    enforced with an assertion, so an object that can not be resolved stops
    the build instead of being skipped.
    """
    assert domain == 'py'

    modname = info['module']
    assert modname.startswith('bartz')
    fullname = info['fullname']

    # the module must have been imported already
    submod = sys.modules.get(modname)
    assert submod

    # walk the dotted name down from the module to the object
    obj = submod
    for part in fullname.split('.'):
        obj = getattr(obj, part)

    # get to the plain function: a cached_property keeps it in `.func`, and
    # decorators that set `__wrapped__` are undone by `inspect.unwrap`
    if isinstance(obj, cached_property):
        obj = obj.func
    obj = inspect.unwrap(obj)

    fn = inspect.getsourcefile(obj)
    assert fn

    source, lineno = inspect.getsourcelines(obj)
    assert lineno
    linespec = f'#L{lineno}-L{lineno + len(source) - 1}'

    prefix = 'https://github.com/Gattocrucco/bartz/blob'
    # path of the source file relative to the installed `bartz` package
    root = pathlib.Path(bartz.__file__).parent
    path = pathlib.Path(fn).relative_to(root).as_posix()
    return f'{prefix}/{commit}/src/bartz/{path}{linespec}'

# --------------------------------------------------------------------------------
# /src/bartz/jaxext/_autobatch.py:
# --------------------------------------------------------------------------------
# bartz/src/bartz/jaxext/_autobatch.py
#
# Copyright (c) 2025, The Bartz Contributors
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | """Implementation of `autobatch`.""" 26 | 27 | import math 28 | from collections.abc import Callable 29 | from functools import wraps 30 | from warnings import warn 31 | 32 | from jax import eval_shape, jit 33 | from jax import numpy as jnp 34 | from jax.lax import scan 35 | from jax.tree import flatten as tree_flatten 36 | from jax.tree import map as tree_map 37 | from jax.tree import reduce as tree_reduce 38 | from jaxtyping import PyTree 39 | 40 | 41 | def expand_axes(axes, tree): 42 | """Expand `axes` such that they match the pytreedef of `tree`.""" 43 | 44 | def expand_axis(axis, subtree): 45 | return tree_map(lambda _: axis, subtree) 46 | 47 | return tree_map(expand_axis, axes, tree, is_leaf=lambda x: x is None) 48 | 49 | 50 | def check_no_nones(axes, tree): 51 | def check_not_none(_, axis): 52 | assert axis is not None 53 | 54 | tree_map(check_not_none, tree, axes) 55 | 56 | 57 | def extract_size(axes, tree): 58 | def get_size(x, axis): 59 | if axis is None: 60 | return None 61 | else: 62 | return x.shape[axis] 63 | 64 | sizes = tree_map(get_size, tree, axes) 65 | sizes, _ = tree_flatten(sizes) 66 | assert all(s == sizes[0] for s in sizes) 67 | return sizes[0] 68 | 69 | 70 | def sum_nbytes(tree): 71 | def nbytes(x): 72 | return math.prod(x.shape) * x.dtype.itemsize 73 | 74 | return tree_reduce(lambda size, x: size + nbytes(x), tree, 0) 75 | 76 | 77 | def next_divisor_small(dividend, min_divisor): 78 | for divisor in range(min_divisor, 
int(math.sqrt(dividend)) + 1): 79 | if dividend % divisor == 0: 80 | return divisor 81 | return dividend 82 | 83 | 84 | def next_divisor_large(dividend, min_divisor): 85 | max_inv_divisor = dividend // min_divisor 86 | for inv_divisor in range(max_inv_divisor, 0, -1): 87 | if dividend % inv_divisor == 0: 88 | return dividend // inv_divisor 89 | return dividend 90 | 91 | 92 | def next_divisor(dividend, min_divisor): 93 | if dividend == 0: 94 | return min_divisor 95 | if min_divisor * min_divisor <= dividend: 96 | return next_divisor_small(dividend, min_divisor) 97 | return next_divisor_large(dividend, min_divisor) 98 | 99 | 100 | def pull_nonbatched(axes, tree): 101 | def pull_nonbatched(x, axis): 102 | if axis is None: 103 | return None 104 | else: 105 | return x 106 | 107 | return tree_map(pull_nonbatched, tree, axes), tree 108 | 109 | 110 | def push_nonbatched(axes, tree, original_tree): 111 | def push_nonbatched(original_x, x, axis): 112 | if axis is None: 113 | return original_x 114 | else: 115 | return x 116 | 117 | return tree_map(push_nonbatched, original_tree, tree, axes) 118 | 119 | 120 | def move_axes_out(axes, tree): 121 | def move_axis_out(x, axis): 122 | return jnp.moveaxis(x, axis, 0) 123 | 124 | return tree_map(move_axis_out, tree, axes) 125 | 126 | 127 | def move_axes_in(axes, tree): 128 | def move_axis_in(x, axis): 129 | return jnp.moveaxis(x, 0, axis) 130 | 131 | return tree_map(move_axis_in, tree, axes) 132 | 133 | 134 | def batch(tree, nbatches): 135 | def batch(x): 136 | return x.reshape(nbatches, x.shape[0] // nbatches, *x.shape[1:]) 137 | 138 | return tree_map(batch, tree) 139 | 140 | 141 | def unbatch(tree): 142 | def unbatch(x): 143 | return x.reshape(x.shape[0] * x.shape[1], *x.shape[2:]) 144 | 145 | return tree_map(unbatch, tree) 146 | 147 | 148 | def check_same(tree1, tree2): 149 | def check_same(x1, x2): 150 | assert x1.shape == x2.shape 151 | assert x1.dtype == x2.dtype 152 | 153 | tree_map(check_same, tree1, tree2) 154 | 155 | 156 | 
def autobatch(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None] = 0,
    out_axes: PyTree[int] = 0,
    return_nbatches: bool = False,
) -> Callable:
    """
    Batch a function such that each batch is smaller than a threshold.

    Parameters
    ----------
    func
        A jittable function with positional arguments only, with inputs and
        outputs pytrees of arrays.
    max_io_nbytes
        The maximum number of input + output bytes in each batch (excluding
        unbatched arguments.)
    in_axes
        A tree matching (a prefix of) the structure of the function input,
        indicating along which axes each array should be batched. A `None` axis
        indicates to not batch an argument.
    out_axes
        The same for outputs (but non-batching is not allowed).
    return_nbatches
        If True, the number of batches is returned as a second output.

    Returns
    -------
    A function with the same signature as `func`, save for the return value if `return_nbatches`.
    """
    # keep the user-provided axes under different names, because `in_axes` and
    # `out_axes` are rebound to their expanded versions inside `batched_func`
    initial_in_axes = in_axes
    initial_out_axes = out_axes

    @jit
    @wraps(func)
    def batched_func(*args):
        # shapes and dtypes of the output, without running `func`
        example_result = eval_shape(func, *args)

        in_axes = expand_axes(initial_in_axes, args)
        out_axes = expand_axes(initial_out_axes, example_result)
        check_no_nones(out_axes, example_result)

        # common length of the batched axes of inputs and outputs
        size = extract_size((in_axes, out_axes), (args, example_result))

        # from here on `args` holds None in place of the unbatched arguments,
        # so they are excluded from the byte count and from the scan
        args, nonbatched_args = pull_nonbatched(in_axes, args)

        total_nbytes = sum_nbytes((args, example_result))
        # ceil(total_nbytes / max_io_nbytes)
        min_nbatches = total_nbytes // max_io_nbytes + bool(
            total_nbytes % max_io_nbytes
        )
        min_nbatches = max(1, min_nbatches)
        # the number of batches must divide the length of the batched axis
        nbatches = next_divisor(size, min_nbatches)
        assert 1 <= nbatches <= max(1, size)
        assert size % nbatches == 0
        assert total_nbytes % nbatches == 0

        batch_nbytes = total_nbytes // nbatches
        if batch_nbytes > max_io_nbytes:
            # the limit is exceeded even with one item per batch
            assert size == nbatches
            msg = f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}'
            warn(msg)

        def loop(_, args):
            # restore the original layout of one batch, put back the unbatched
            # arguments, call `func`, and bring the batched axis of the result
            # to the front to let `scan` stack along it
            args = move_axes_in(in_axes, args)
            args = push_nonbatched(in_axes, args, nonbatched_args)
            result = func(*args)
            result = move_axes_out(out_axes, result)
            return None, result

        args = move_axes_out(in_axes, args)
        args = batch(args, nbatches)
        _, result = scan(loop, None, args)
        result = unbatch(result)
        result = move_axes_in(out_axes, result)

        # the batched result must look like the unbatched one
        check_same(example_result, result)

        if return_nbatches:
            return result, nbatches
        return result

    return batched_func

# --------------------------------------------------------------------------------
# /src/bartz/jaxext/__init__.py:
# --------------------------------------------------------------------------------
# bartz/src/bartz/jaxext/__init__.py
#
# Copyright (c) 2024-2025, The Bartz Contributors
#
#
# This file is part of bartz.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Additions to jax."""

import math
from collections.abc import Sequence
from functools import partial

import jax
from jax import Device, ensure_compile_time_eval, jit, random
from jax import numpy as jnp
from jax.lax import scan
from jax.scipy.special import ndtr
from jaxtyping import Array, Bool, Float32, Key, Scalar, Shaped

from bartz.jaxext._autobatch import autobatch  # noqa: F401
from bartz.jaxext.scipy.special import ndtri


def vmap_nodoc(fun, *args, **kw):
    """
    Acts like `jax.vmap` but preserves the docstring of the function unchanged.

    This is useful if the docstring already takes into account that the
    arguments have additional axes due to vmap.
    """
    doc = fun.__doc__
    fun = jax.vmap(fun, *args, **kw)
    fun.__doc__ = doc
    return fun


def minimal_unsigned_dtype(value):
    """Return the smallest unsigned integer dtype that can represent `value`."""
    if value < 2**8:
        return jnp.uint8
    if value < 2**16:
        return jnp.uint16
    if value < 2**32:
        return jnp.uint32
    return jnp.uint64


@partial(jax.jit, static_argnums=(1,))
def unique(
    x: Shaped[Array, ' _'], size: int, fill_value: Scalar
) -> tuple[Shaped[Array, ' {size}'], int]:
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x
        The input array.
    size
        The length of the output.
    fill_value
        The value to fill the output with if `size` is greater than the number
        of unique values in `x`.

    Returns
    -------
    out : Shaped[Array, '{size}']
        The unique values in `x`, sorted, and right-padded with `fill_value`.
        If there are more than `size` unique values, only the first `size` are
        kept.
    actual_length : int
        The number of used values in `out`.
    """
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0
    x = jnp.sort(x)

    def loop(carry, x):
        i_out, last, out = carry
        # advance the output position only when the value changes
        i_out = jnp.where(x == last, i_out, i_out + 1)
        # positions past the end of `out` are ignored
        out = out.at[i_out].set(x, mode='drop')
        return (i_out, x, out), None

    carry = 0, x[0], jnp.full(size, fill_value, x.dtype)
    # Scan the whole sorted array, not just its first `size` entries: with
    # duplicates in the input, the first `size` entries may contain fewer
    # unique values than the ones that fit in the output.
    (last_index, _, out), _ = scan(loop, carry, x)
    return out, jnp.minimum(last_index + 1, size)


class split:
    """
    Split a key into `num` keys.

    Parameters
    ----------
    key
        The key to split.
    num
        The number of keys to split into.
    """

    _keys: tuple[Key[Array, ''], ...]
    _num_used: int

    def __init__(self, key: Key[Array, ''], num: int = 2):
        self._keys = _split_unpack(key, num)
        self._num_used = 0

    def __len__(self):
        # number of keys not popped yet
        return len(self._keys) - self._num_used

    def pop(self, shape: int | tuple[int, ...] = ()) -> Key[Array, '*']:
        """
        Pop one or more keys from the list.

        Parameters
        ----------
        shape
            The shape of the keys to pop. If empty (default), a single key is
            popped and returned. If not empty, the popped key is split and
            reshaped to the target shape.

        Returns
        -------
        The popped keys as a jax array with the requested shape.

        Raises
        ------
        IndexError
            If the list is empty.
        """
        if len(self) == 0:
            msg = 'No keys left to pop'
            raise IndexError(msg)
        if not isinstance(shape, tuple):
            shape = (shape,)
        key = self._keys[self._num_used]
        self._num_used += 1
        if shape:
            key = _split_shaped(key, shape)
        return key


@partial(jit, static_argnums=(1,))
def _split_unpack(key: Key[Array, ''], num: int) -> tuple[Key[Array, ''], ...]:
    """Split `key` into a tuple of `num` keys."""
    keys = random.split(key, num)
    return tuple(keys)


@partial(jit, static_argnums=(1,))
def _split_shaped(key: Key[Array, ''], shape: tuple[int, ...]) -> Key[Array, '*']:
    """Split `key` into an array of keys with the given shape."""
    num = math.prod(shape)
    keys = random.split(key, num)
    return keys.reshape(shape)


def truncated_normal_onesided(
    key: Key[Array, ''],
    shape: Sequence[int],
    upper: Bool[Array, '*'],
    bound: Float32[Array, '*'],
    *,
    clip: bool = True,
) -> Float32[Array, '*']:
    """
    Sample from a one-sided truncated standard normal distribution.

    Parameters
    ----------
    key
        JAX random key.
    shape
        Shape of output array, broadcasted with other inputs.
    upper
        True for (-∞, bound], False for [bound, ∞).
    bound
        The truncation boundary.
    clip
        Whether to clip the truncated uniform samples to (0, 1) before
        transforming them to truncated normal. Intended for debugging purposes.

    Returns
    -------
    Array of samples from the truncated normal distribution.
    """
    # Pseudocode:
    # | if upper:
    # |     if bound < 0:
    # |         ndtri(uniform(0, ndtr(bound))) =
    # |         ndtri(ndtr(bound) * u)
    # |     if bound > 0:
    # |         -ndtri(uniform(ndtr(-bound), 1)) =
    # |         -ndtri(ndtr(-bound) + ndtr(bound) * (1 - u))
    # | if not upper:
    # |     if bound < 0:
    # |         ndtri(uniform(ndtr(bound), 1)) =
    # |         ndtri(ndtr(bound) + ndtr(-bound) * (1 - u))
    # |     if bound > 0:
    # |         -ndtri(uniform(0, ndtr(-bound))) =
    # |         -ndtri(ndtr(-bound) * u)
    shape = jnp.broadcast_shapes(shape, upper.shape, bound.shape)
    bound_pos = bound > 0
    ndtr_bound = ndtr(bound)
    ndtr_neg_bound = ndtr(-bound)
    scale = jnp.where(upper, ndtr_bound, ndtr_neg_bound)
    shift = jnp.where(upper, ndtr_neg_bound, ndtr_bound)
    u = random.uniform(key, shape)
    left_u = scale * (1 - u)  # ~ uniform in (0, ndtr(±bound)]
    right_u = shift + scale * u  # ~ uniform in [ndtr(∓bound), 1)
    truncated_u = jnp.where(upper ^ bound_pos, left_u, right_u)
    if clip:
        # on gpu the accuracy is lower and sometimes u can reach the boundaries
        zero = jnp.zeros((), truncated_u.dtype)
        one = jnp.ones((), truncated_u.dtype)
        truncated_u = jnp.clip(
            truncated_u, jnp.nextafter(zero, one), jnp.nextafter(one, zero)
        )
    truncated_norm = ndtri(truncated_u)
    return jnp.where(bound_pos, -truncated_norm, truncated_norm)


def get_default_device() -> Device:
    """Get the current default JAX device."""
    with ensure_compile_time_eval():
        return jnp.zeros(()).device
-------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | # bartz/Makefile 2 | # 3 | # Copyright (c) 2024-2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | # Makefile for running tests and for preparing and uploading a release.
26 | 27 | COVERAGE_SUFFIX = 28 | OLD_PYTHON = $(shell uv run --group=ci python -c 'from tests.util import get_old_python_str; print(get_old_python_str())') 29 | OLD_DATE = 2025-05-15 30 | 31 | .PHONY: all 32 | all: 33 | @echo "Available targets:" 34 | @echo "- setup: create R and Python environments for development" 35 | @echo "- tests: run unit tests, saving coverage information" 36 | @echo "- tests-old: run unit tests with oldest supported python and dependencies" 37 | @echo '- tests-gpu: variant of `tests` that works on gpu' 38 | @echo "- docs: build html documentation" 39 | @echo "- docs-latest: build html documentation for latest release" 40 | @echo "- covreport: build html coverage report" 41 | @echo "- covcheck: check coverage is above some thresholds" 42 | @echo "- release: packages the python module, invokes tests and docs first" 43 | @echo "- upload: upload release to PyPI" 44 | @echo "- upload-test: upload release to TestPyPI" 45 | @echo "- asv-run: run benchmarks on all unbenchmarked tagged releases and main" 46 | @echo "- asv-publish: create html benchmark report" 47 | @echo "- asv-preview: create html report and start server" 48 | @echo "- asv-main: run benchmarks on main branch" 49 | @echo "- asv-quick: run quick benchmarks on current code, no saving" 50 | @echo "- ipython: start an ipython shell with stuff pre-imported" 51 | @echo "- ipython-old: start an ipython shell with oldest supported python and dependencies" 52 | @echo 53 | @echo "Release workflow:" 54 | @echo "- $$ uv version --bump major|minor|patch" 55 | @echo "- describe release in docs/changelog.md" 56 | @echo "- $$ make release (repeat until it goes smoothly)" 57 | @echo "- push and check CI completes (if it doesn't, go to previous step)" 58 | @echo "- $$ make upload" 59 | @echo "- publish github release (updates zenodo automatically)" 60 | @echo "- if the online docs are not up-to-date, press 'run workflow' on https://github.com/Gattocrucco/bartz/actions/workflows/tests.yml, and try to 
understand why 'make upload' didn't do it" 61 | 62 | 63 | .PHONY: setup 64 | setup: 65 | Rscript -e "renv::restore()" 66 | uv run --all-groups pre-commit install 67 | @CUDA_VERSION=$$(nvidia-smi 2>/dev/null | grep -o 'CUDA Version: [0-9]*' | cut -d' ' -f3); \ 68 | if [ "$$CUDA_VERSION" = "12" ]; then \ 69 | echo "Detected CUDA 12, installing jax[cuda12]"; \ 70 | uv pip install "jax[cuda12]"; \ 71 | elif [ "$$CUDA_VERSION" = "13" ]; then \ 72 | echo "Detected CUDA 13, installing jax[cuda13]"; \ 73 | uv pip install "jax[cuda13]"; \ 74 | else \ 75 | echo "No CUDA detected"; \ 76 | fi 77 | 78 | 79 | ################# TESTS ################# 80 | 81 | TESTS_VARS = COVERAGE_FILE=.coverage.tests$(COVERAGE_SUFFIX) 82 | TESTS_COMMAND = python -m pytest --cov --cov-context=test --numprocesses=2 --dist=worksteal 83 | 84 | UV_RUN_CI = uv run --group=ci 85 | UV_OPTS_OLD = --python=$(OLD_PYTHON) --resolution=lowest-direct --exclude-newer=$(OLD_DATE) 86 | UV_VARS_OLD = UV_PROJECT_ENVIRONMENT=.venv-old 87 | UV_RUN_CI_OLD = $(UV_VARS_OLD) $(UV_RUN_CI) $(UV_OPTS_OLD) 88 | 89 | .PHONY: tests 90 | tests: 91 | $(TESTS_VARS) $(UV_RUN_CI) $(TESTS_COMMAND) $(ARGS) 92 | 93 | .PHONY: tests-old 94 | tests-old: 95 | $(TESTS_VARS) $(UV_RUN_CI_OLD) $(TESTS_COMMAND) $(ARGS) 96 | 97 | .PHONY: tests-gpu 98 | tests-gpu: 99 | nvidia-smi 100 | XLA_PYTHON_CLIENT_MEM_FRACTION=.20 $(TESTS_VARS) $(UV_RUN_CI) $(TESTS_COMMAND) --platform=gpu --numprocesses=3 $(ARGS) 101 | 102 | ################# DOCS ################# 103 | 104 | .PHONY: docs 105 | docs: 106 | $(UV_RUN_CI) make -C docs html 107 | test ! -d _site/docs-dev || rm -r _site/docs-dev 108 | mv docs/_build/html _site/docs-dev 109 | @echo 110 | @echo "Now open _site/index.html" 111 | 112 | .PHONY: docs-latest 113 | docs-latest: 114 | BARTZ_DOC_VARIANT=latest $(UV_RUN_CI) make -C docs html 115 | git switch - || git switch main 116 | test ! 
-d _site/docs || rm -r _site/docs 117 | mv docs/_build/html _site/docs 118 | @echo 119 | @echo "Now open _site/index.html" 120 | 121 | .PHONY: covreport 122 | covreport: 123 | $(UV_RUN_CI) coverage combine --keep 124 | $(UV_RUN_CI) coverage html --include='src/*' 125 | @echo 126 | @echo "Now open _site/coverage/index.html" 127 | 128 | .PHONY: covcheck 129 | covcheck: 130 | $(UV_RUN_CI) coverage combine --keep 131 | $(UV_RUN_CI) coverage report --include='tests/**/test_*.py' 132 | $(UV_RUN_CI) coverage report --include='src/*' 133 | $(UV_RUN_CI) coverage report --include='tests/**/test_*.py' --fail-under=99 --format=total 134 | $(UV_RUN_CI) coverage report --include='src/*' --fail-under=90 --format=total 135 | 136 | ################# RELEASE ################# 137 | 138 | .PHONY: update-deps 139 | update-deps: 140 | test ! -d .venv || rm -r .venv 141 | uv lock --upgrade 142 | 143 | .PHONY: copy-version 144 | copy-version: src/bartz/_version.py 145 | src/bartz/_version.py: pyproject.toml 146 | uv run --group=ci python -c 'from tests.util import update_version; update_version()' 147 | 148 | .PHONY: check-committed 149 | check-committed: 150 | git diff --quiet 151 | git diff --quiet --staged 152 | 153 | .PHONY: release 154 | release: update-deps copy-version check-committed 155 | @$(MAKE) tests 156 | @$(MAKE) tests-old 157 | @$(MAKE) docs 158 | test ! 
-d dist || rm -r dist 159 | uv build 160 | 161 | .PHONY: version-tag 162 | version-tag: copy-version check-committed 163 | git fetch --tags 164 | git tag v$(shell uv run python -c 'import bartz; print(bartz.__version__)') 165 | git push --tags 166 | 167 | .PHONY: upload 168 | upload: version-tag 169 | @echo "Enter PyPI token:" 170 | @read -s UV_PUBLISH_TOKEN && \ 171 | export UV_PUBLISH_TOKEN="$$UV_PUBLISH_TOKEN" && \ 172 | uv publish 173 | @VERSION=$$(uv run python -c 'import bartz; print(bartz.__version__)') && \ 174 | echo "Try to install bartz $$VERSION from PyPI" && \ 175 | uv tool run --with="bartz==$$VERSION" python -c 'import bartz; print(bartz.__version__)' 176 | 177 | .PHONY: upload-test 178 | upload-test: check-committed 179 | @echo "Enter TestPyPI token:" 180 | @read -s UV_PUBLISH_TOKEN && \ 181 | export UV_PUBLISH_TOKEN="$$UV_PUBLISH_TOKEN" && \ 182 | uv publish --check-url=https://test.pypi.org/simple/ --publish-url=https://test.pypi.org/legacy/ 183 | @VERSION=$$(uv run --group=ci python -c 'from tests.util import get_version; print(get_version())') && \ 184 | echo "Try to install bartz $$VERSION from TestPyPI" && \ 185 | uv tool run --index=https://test.pypi.org/simple/ --index-strategy=unsafe-best-match --with="bartz==$$VERSION" python -c 'import bartz; print(bartz.__version__)' 186 | 187 | 188 | ################# BENCHMARKS ################# 189 | 190 | ASV = $(UV_RUN_CI) python -m asv 191 | 192 | .PHONY: asv-run 193 | asv-run: 194 | $(UV_RUN_CI) python config/refs-for-asv.py | $(ASV) run --skip-existing --show-stderr HASHFILE:- $(ARGS) 195 | 196 | .PHONY: asv-publish 197 | asv-publish: 198 | $(ASV) publish $(ARGS) 199 | 200 | .PHONY: asv-preview 201 | asv-preview: asv-publish 202 | $(ASV) preview $(ARGS) 203 | 204 | .PHONY: asv-main 205 | asv-main: 206 | $(ASV) run --show-stderr main^! 
$(ARGS) 207 | 208 | .PHONY: asv-quick 209 | asv-quick: 210 | $(ASV) run --python=same --quick --dry-run --show-stderr $(ARGS) 211 | 212 | 213 | ################# IPYTHON SHELL ################# 214 | 215 | .PHONY: ipython 216 | ipython: 217 | IPYTHONDIR=config/ipython uv run --all-groups python -m IPython $(ARGS) 218 | 219 | .PHONY: ipython-old 220 | ipython-old: 221 | IPYTHONDIR=config/ipython $(UV_VARS_OLD) uv run --all-groups $(UV_OPTS_OLD) python -m IPython $(ARGS) 222 | -------------------------------------------------------------------------------- /src/bartz/_profiler.py: -------------------------------------------------------------------------------- 1 | # bartz/src/bartz/_profiler.py 2 | # 3 | # Copyright (c) 2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 

"""Module with utilities related to profiling bartz."""

from collections.abc import Callable, Iterator
from contextlib import contextmanager
from functools import wraps
from typing import Any, TypeVar

from jax import block_until_ready, debug, jit
from jax.lax import cond, scan
from jax.profiler import TraceAnnotation
from jaxtyping import Array, Bool

# Module-level switch read by all the `*_if_(not_)profiling` helpers below;
# change it through `set_profile_mode` or the `profile_mode` context manager.
PROFILE_MODE: bool = False

T = TypeVar('T')
Carry = TypeVar('Carry')


def get_profile_mode() -> bool:
    """Return the current profile mode status.

    Returns
    -------
    True if profile mode is enabled, False otherwise.
    """
    return PROFILE_MODE


def set_profile_mode(value: bool, /) -> None:
    """Set the profile mode status.

    Parameters
    ----------
    value
        If True, enable profile mode. If False, disable it.
    """
    global PROFILE_MODE  # noqa: PLW0603
    PROFILE_MODE = value


@contextmanager
def profile_mode(value: bool, /) -> Iterator[None]:
    """Context manager to temporarily set profile mode.

    Parameters
    ----------
    value
        Profile mode value to set within the context.

    Examples
    --------
    >>> with profile_mode(True):
    ...     # Code runs with profile mode enabled
    ...     pass

    Notes
    -----
    In profiling mode, the MCMC loop is not compiled into a single function, but
    instead compiled in smaller pieces that are instrumented to show up in the
    jax tracer and Python profiling statistics. Search for function names
    starting with 'jab' (see `jit_and_block_if_profiling`).

    Jax tracing is not enabled by this context manager and if used must be
    handled separately by the user; this context manager only makes sure that
    the execution flow will be more interpretable in the traces if the tracer is
    used.
    """
    old_value = get_profile_mode()
    set_profile_mode(value)
    try:
        yield
    finally:
        # restore the previous mode even if the body raised
        set_profile_mode(old_value)


def jit_and_block_if_profiling(
    func: Callable[..., T], block_before: bool = False, **kwargs
) -> Callable[..., T]:
    """Apply JIT compilation and block if profiling is enabled.

    When profile mode is off, the function runs without JIT. When profile mode
    is on, the function is JIT compiled and blocks outputs to ensure proper
    timing.

    Parameters
    ----------
    func
        Function to wrap.
    block_before
        If True block inputs before passing them to the JIT-compiled function.
        This ensures that any pending computations are completed before entering
        the JIT-compiled function. This phase is not included in the trace
        event.
    **kwargs
        Additional arguments to pass to `jax.jit`.

    Returns
    -------
    Wrapped function.

    Notes
    -----
    Under profiling mode, the function invocation is handled such that a custom
    jax trace event with name `jab[<func_name>]` is created, where
    `<func_name>` is `func.__name__`. The statistics on the actual Python
    function will be off, while the function `jab_inner_wrapper` represents the
    actual execution time.
    """
    jitted_func = jit(func, **kwargs)

    event_name = f'jab[{func.__name__}]'

    # this wrapper is meant to measure the time spent executing the function
    def jab_inner_wrapper(*args, **kwargs) -> T:
        with TraceAnnotation(event_name):
            result = jitted_func(*args, **kwargs)
            return block_until_ready(result)

    @wraps(func)
    def jab_outer_wrapper(*args: Any, **kwargs: Any) -> T:
        # the mode is checked at call time, so one wrapped function follows
        # later changes of the profile mode
        if get_profile_mode():
            if block_before:
                args, kwargs = block_until_ready((args, kwargs))
            return jab_inner_wrapper(*args, **kwargs)
        else:
            return func(*args, **kwargs)

    return jab_outer_wrapper


def jit_if_not_profiling(func: Callable[..., T], *args, **kwargs) -> Callable[..., T]:
    """Apply JIT compilation only when not profiling.

    When profile mode is off, the function is JIT compiled. When profile mode is
    on, the function runs as-is.

    Parameters
    ----------
    func
        Function to wrap.
    *args
    **kwargs
        Additional arguments to pass to `jax.jit`.

    Returns
    -------
    Wrapped function.
    """
    jitted_func = jit(func, *args, **kwargs)

    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> T:
        # the mode is checked on every call, not when wrapping
        if get_profile_mode():
            return func(*args, **kwargs)
        else:
            return jitted_func(*args, **kwargs)

    return wrapper


def scan_if_not_profiling(
    f: Callable[[Carry, None], tuple[Carry, None]],
    init: Carry,
    xs: None,
    length: int,
    /,
) -> tuple[Carry, None]:
    """Restricted replacement for `jax.lax.scan` that uses a Python loop when profiling.

    Parameters
    ----------
    f
        Scan body function with signature (carry, None) -> (carry, None).
    init
        Initial carry value.
    xs
        Input values to scan over (not supported).
    length
        Integer specifying the number of loop iterations.

    Returns
    -------
    Tuple of (final_carry, None) (stacked outputs not supported).
    """
    assert xs is None
    if get_profile_mode():
        carry = init
        for _i in range(length):
            # the second output of `f` is discarded, nothing is stacked
            carry, _ = f(carry, None)
        return carry, None

    else:
        return scan(f, init, None, length)


def cond_if_not_profiling(
    pred: bool | Bool[Array, ''],
    true_fun: Callable[..., T],
    false_fun: Callable[..., T],
    /,
    *operands,
) -> T:
    """Restricted replacement for `jax.lax.cond` that uses a Python if when profiling.

    Parameters
    ----------
    pred
        Boolean predicate to choose which function to execute.
    true_fun
        Function to execute if `pred` is True.
    false_fun
        Function to execute if `pred` is False.
    *operands
        Arguments passed to `true_fun` and `false_fun`.

    Returns
    -------
    Result of either `true_fun(*operands)` or `false_fun(*operands)`.
    """
    if get_profile_mode():
        if pred:
            return true_fun(*operands)
        else:
            return false_fun(*operands)
    else:
        return cond(pred, true_fun, false_fun, *operands)


def callback_if_not_profiling(
    callback: Callable[..., None], *args: Any, ordered: bool = False, **kwargs: Any
) -> None:
    """Restricted replacement for `jax.debug.callback` that calls the callback directly in profiling mode.

    Parameters
    ----------
    callback
        The function to invoke.
    *args
    **kwargs
        Arguments passed to `callback`.
    ordered
        Forwarded to `jax.debug.callback` when not profiling; ignored in
        profiling mode, where the callback is invoked immediately.
    """
    if get_profile_mode():
        callback(*args, **kwargs)
    else:
        debug.callback(callback, *args, ordered=ordered, **kwargs)
--------------------------------------------------------------------------------
/asv.conf.json:
--------------------------------------------------------------------------------
1 | {
2 | // The version of the config file format. Do not change, unless
3 | // you know what you are doing.
4 | "version": 1, 5 | 6 | // The name of the project being benchmarked 7 | "project": "bartz", 8 | 9 | // The project's homepage 10 | "project_url": "https://github.com/Gattocrucco/bartz", 11 | 12 | // The URL or local path of the source code repository for the 13 | // project being benchmarked 14 | "repo": ".", 15 | 16 | // The Python project's subdirectory in your repo. If missing or 17 | // the empty string, the project is assumed to be located at the root 18 | // of the repository. 19 | // "repo_subdir": "", 20 | 21 | // Customizable commands for building the project. 22 | // See asv.conf.json documentation. 23 | "build_command": [ 24 | "python -m build --wheel -o {build_cache_dir} {build_dir}" 25 | ], 26 | // To build the package using setuptools and a setup.py file, uncomment the following lines 27 | // "build_command": [ 28 | // "python setup.py build", 29 | // "python -mpip wheel -w {build_cache_dir} {build_dir}" 30 | // ], 31 | 32 | // Customizable commands for installing and uninstalling the project. 33 | // See asv.conf.json documentation. 34 | // "install_command": ["in-dir={env_dir} python -mpip install {wheel_file}"], 35 | // "uninstall_command": ["return-code=any python -mpip uninstall -y {project}"], 36 | 37 | // List of branches to benchmark. If not provided, defaults to "main" 38 | // (for git) or "default" (for mercurial). 39 | "branches": [ 40 | "main", 41 | ], 42 | 43 | // The DVCS being used. If not set, it will be automatically 44 | // determined from "repo" by looking at the protocol in the URL 45 | // (if remote), or by looking for special directories, such as 46 | // ".git" (if local). 47 | // "dvcs": "git", 48 | 49 | // The tool to use to create environments. May be "conda", 50 | // "virtualenv", "mamba" (above 3.8) 51 | // or other value depending on the plugins in use. 52 | // If missing or the empty string, the tool will be automatically 53 | // determined by looking for tools on the PATH environment 54 | // variable. 
55 | "environment_type": "virtualenv", 56 | 57 | // timeout in seconds for installing any dependencies in environment 58 | // defaults to 10 min 59 | //"install_timeout": 600, 60 | 61 | // the base URL to show a commit for the project. 62 | "show_commit_url": "https://github.com/Gattocrucco/bartz/commit/", 63 | 64 | // The Pythons you'd like to test against. If not provided, defaults 65 | // to the current version of Python used to run `asv`. 66 | // "pythons": ["3.8", "3.12"], 67 | 68 | // The list of conda channel names to be searched for benchmark 69 | // dependency packages in the specified order 70 | // "conda_channels": ["conda-forge", "defaults"], 71 | 72 | // A conda environment file that is used for environment creation. 73 | // "conda_environment_file": "environment.yml", 74 | 75 | // The matrix of dependencies to test. Each key of the "req" 76 | // requirements dictionary is the name of a package (in PyPI) and 77 | // the values are version numbers. An empty list or empty string 78 | // indicates to just test against the default (latest) 79 | // version. null indicates that the package is to not be 80 | // installed. If the package to be tested is only available from 81 | // PyPi, and the 'environment_type' is conda, then you can preface 82 | // the package name by 'pip+', and the package will be installed 83 | // via pip (with all the conda available packages installed first, 84 | // followed by the pip installed packages). 85 | // 86 | // The ``@env`` and ``@env_nobuild`` keys contain the matrix of 87 | // environment variables to pass to build and benchmark commands. 88 | // An environment will be created for every combination of the 89 | // cartesian product of the "@env" variables in this matrix. 90 | // Variables in "@env_nobuild" will be passed to every environment 91 | // during the benchmark phase, but will not trigger creation of 92 | // new environments. 
A value of ``null`` means that the variable 93 | // will not be set for the current combination. 94 | // 95 | // "matrix": { 96 | // "req": { 97 | // "numpy": ["1.6", "1.7"], 98 | // "six": ["", null], // test with and without six installed 99 | // "pip+emcee": [""] // emcee is only available for install with pip. 100 | // }, 101 | // "env": {"ENV_VAR_1": ["val1", "val2"]}, 102 | // "env_nobuild": {"ENV_VAR_2": ["val3", null]}, 103 | // }, 104 | 105 | // Combinations of libraries/python versions can be excluded/included 106 | // from the set to test. Each entry is a dictionary containing additional 107 | // key-value pairs to include/exclude. 108 | // 109 | // An exclude entry excludes entries where all values match. The 110 | // values are regexps that should match the whole string. 111 | // 112 | // An include entry adds an environment. Only the packages listed 113 | // are installed. The 'python' key is required. The exclude rules 114 | // do not apply to includes. 115 | // 116 | // In addition to package names, the following keys are available: 117 | // 118 | // - python 119 | // Python version, as in the *pythons* variable above. 120 | // - environment_type 121 | // Environment type, as above. 122 | // - sys_platform 123 | // Platform, as in sys.platform. Possible values for the common 124 | // cases: 'linux2', 'win32', 'cygwin', 'darwin'. 
125 | // - req 126 | // Required packages 127 | // - env 128 | // Environment variables 129 | // - env_nobuild 130 | // Non-build environment variables 131 | // 132 | // "exclude": [ 133 | // {"python": "3.2", "sys_platform": "win32"}, // skip py3.2 on windows 134 | // {"environment_type": "conda", "req": {"six": null}}, // don't run without six on conda 135 | // {"env": {"ENV_VAR_1": "val2"}}, // skip val2 for ENV_VAR_1 136 | // ], 137 | // 138 | // "include": [ 139 | // // additional env for python3.12 140 | // {"python": "3.12", "req": {"numpy": "1.26"}, "env_nobuild": {"FOO": "123"}}, 141 | // // additional env if run on windows+conda 142 | // {"platform": "win32", "environment_type": "conda", "python": "3.12", "req": {"libpython": ""}}, 143 | // ], 144 | 145 | // The directory (relative to the current directory) that benchmarks are 146 | // stored in. If not provided, defaults to "benchmarks" 147 | // "benchmark_dir": "benchmarks", 148 | 149 | // The directory (relative to the current directory) to cache the Python 150 | // environments in. If not provided, defaults to "env" 151 | "env_dir": ".asv/env", 152 | 153 | // The directory (relative to the current directory) that raw benchmark 154 | // results are stored in. If not provided, defaults to "results". 155 | "results_dir": ".asv/results", 156 | 157 | // The directory (relative to the current directory) that the html tree 158 | // should be written to. If not provided, defaults to "html". 159 | "html_dir": "_site/benchmarks", 160 | 161 | // The number of characters to retain in the commit hashes. 162 | // "hash_length": 8, 163 | 164 | // `asv` will cache results of the recent builds in each 165 | // environment, making them faster to install next time. This is 166 | // the number of builds to keep, per environment. 167 | // "build_cache_size": 2, 168 | 169 | // The commits after which the regression search in `asv publish` 170 | // should start looking for regressions. 
Dictionary whose keys are 171 | // regexps matching to benchmark names, and values corresponding to 172 | // the commit (exclusive) after which to start looking for 173 | // regressions. The default is to start from the first commit 174 | // with results. If the commit is `null`, regression detection is 175 | // skipped for the matching benchmark. 176 | // 177 | // "regressions_first_commits": { 178 | // "some_benchmark": "352cdf", // Consider regressions only after this commit 179 | // "another_benchmark": null, // Skip regression detection altogether 180 | // }, 181 | 182 | // The thresholds for relative change in results, after which `asv 183 | // publish` starts reporting regressions. Dictionary of the same 184 | // form as in ``regressions_first_commits``, with values 185 | // indicating the thresholds. If multiple entries match, the 186 | // maximum is taken. If no entry matches, the default is 5%. 187 | // 188 | // "regressions_thresholds": { 189 | // "some_benchmark": 0.01, // Threshold of 1% 190 | // "another_benchmark": 0.5, // Threshold of 50% 191 | // }, 192 | } 193 | -------------------------------------------------------------------------------- /src/bartz/jaxext/scipy/special.py: -------------------------------------------------------------------------------- 1 | # bartz/src/bartz/jaxext/scipy/special.py 2 | # 3 | # Copyright (c) 2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 
6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
24 | 25 | """Mockup of the :external:py:mod:`scipy.special` module.""" 26 | 27 | from functools import wraps 28 | 29 | from jax import ShapeDtypeStruct, jit, pure_callback 30 | from jax import numpy as jnp 31 | from scipy.special import gammainccinv as scipy_gammainccinv 32 | 33 | 34 | def _float_type(*args): 35 | """Determine the jax floating point result type given operands/types.""" 36 | t = jnp.result_type(*args) 37 | return jnp.sin(jnp.empty(0, t)).dtype 38 | 39 | 40 | def _castto(func, dtype): 41 | @wraps(func) 42 | def newfunc(*args, **kw): 43 | return func(*args, **kw).astype(dtype) 44 | 45 | return newfunc 46 | 47 | 48 | @jit 49 | def gammainccinv(a, y): 50 | """Survival function inverse of the Gamma(a, 1) distribution.""" 51 | shape = jnp.broadcast_shapes(a.shape, y.shape) 52 | dtype = _float_type(a.dtype, y.dtype) 53 | dummy = ShapeDtypeStruct(shape, dtype) 54 | ufunc = _castto(scipy_gammainccinv, dtype) 55 | return pure_callback(ufunc, dummy, a, y, vmap_method='expand_dims') 56 | 57 | 58 | ################# COPIED AND ADAPTED FROM JAX ################## 59 | # Copyright 2018 The JAX Authors. 60 | # 61 | # Licensed under the Apache License, Version 2.0 (the "License"); 62 | # you may not use this file except in compliance with the License. 63 | # You may obtain a copy of the License at 64 | # 65 | # https://www.apache.org/licenses/LICENSE-2.0 66 | # 67 | # Unless required by applicable law or agreed to in writing, software 68 | # distributed under the License is distributed on an "AS IS" BASIS, 69 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 70 | # See the License for the specific language governing permissions and 71 | # limitations under the License. 72 | 73 | import numpy as np 74 | from jax import debug_infs, lax 75 | 76 | 77 | def ndtri(p): 78 | """Compute the inverse of the CDF of the Normal distribution function. 79 | 80 | This is a patch of `jax.scipy.special.ndtri`. 
81 | """ 82 | dtype = lax.dtype(p) 83 | if dtype not in (jnp.float32, jnp.float64): 84 | msg = f'x.dtype={dtype} is not supported, see docstring for supported types.' 85 | raise TypeError(msg) 86 | return _ndtri(p) 87 | 88 | 89 | def _ndtri(p): 90 | # Constants used in piece-wise rational approximations. Taken from the cephes 91 | # library: 92 | # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html 93 | p0 = list( 94 | reversed( 95 | [ 96 | -5.99633501014107895267e1, 97 | 9.80010754185999661536e1, 98 | -5.66762857469070293439e1, 99 | 1.39312609387279679503e1, 100 | -1.23916583867381258016e0, 101 | ] 102 | ) 103 | ) 104 | q0 = list( 105 | reversed( 106 | [ 107 | 1.0, 108 | 1.95448858338141759834e0, 109 | 4.67627912898881538453e0, 110 | 8.63602421390890590575e1, 111 | -2.25462687854119370527e2, 112 | 2.00260212380060660359e2, 113 | -8.20372256168333339912e1, 114 | 1.59056225126211695515e1, 115 | -1.18331621121330003142e0, 116 | ] 117 | ) 118 | ) 119 | p1 = list( 120 | reversed( 121 | [ 122 | 4.05544892305962419923e0, 123 | 3.15251094599893866154e1, 124 | 5.71628192246421288162e1, 125 | 4.40805073893200834700e1, 126 | 1.46849561928858024014e1, 127 | 2.18663306850790267539e0, 128 | -1.40256079171354495875e-1, 129 | -3.50424626827848203418e-2, 130 | -8.57456785154685413611e-4, 131 | ] 132 | ) 133 | ) 134 | q1 = list( 135 | reversed( 136 | [ 137 | 1.0, 138 | 1.57799883256466749731e1, 139 | 4.53907635128879210584e1, 140 | 4.13172038254672030440e1, 141 | 1.50425385692907503408e1, 142 | 2.50464946208309415979e0, 143 | -1.42182922854787788574e-1, 144 | -3.80806407691578277194e-2, 145 | -9.33259480895457427372e-4, 146 | ] 147 | ) 148 | ) 149 | p2 = list( 150 | reversed( 151 | [ 152 | 3.23774891776946035970e0, 153 | 6.91522889068984211695e0, 154 | 3.93881025292474443415e0, 155 | 1.33303460815807542389e0, 156 | 2.01485389549179081538e-1, 157 | 1.23716634817820021358e-2, 158 | 3.01581553508235416007e-4, 159 | 2.65806974686737550832e-6, 160 | 
6.23974539184983293730e-9, 161 | ] 162 | ) 163 | ) 164 | q2 = list( 165 | reversed( 166 | [ 167 | 1.0, 168 | 6.02427039364742014255e0, 169 | 3.67983563856160859403e0, 170 | 1.37702099489081330271e0, 171 | 2.16236993594496635890e-1, 172 | 1.34204006088543189037e-2, 173 | 3.28014464682127739104e-4, 174 | 2.89247864745380683936e-6, 175 | 6.79019408009981274425e-9, 176 | ] 177 | ) 178 | ) 179 | 180 | dtype = lax.dtype(p).type 181 | shape = jnp.shape(p) 182 | 183 | def _create_polynomial(var, coeffs): 184 | """Compute n_th order polynomial via Horner's method.""" 185 | coeffs = np.array(coeffs, dtype) 186 | if not coeffs.size: 187 | return jnp.zeros_like(var) 188 | return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var 189 | 190 | maybe_complement_p = jnp.where(p > dtype(-np.expm1(-2.0)), dtype(1.0) - p, p) 191 | # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs 192 | # later on. The result from the computation when p == 0 is not used so any 193 | # number that doesn't result in NaNs is fine. 194 | sanitized_mcp = jnp.where( 195 | maybe_complement_p == dtype(0.0), 196 | jnp.full(shape, dtype(0.5)), 197 | maybe_complement_p, 198 | ) 199 | 200 | # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2). 201 | w = sanitized_mcp - dtype(0.5) 202 | ww = lax.square(w) 203 | x_for_big_p = w + w * ww * (_create_polynomial(ww, p0) / _create_polynomial(ww, q0)) 204 | x_for_big_p *= -dtype(np.sqrt(2.0 * np.pi)) 205 | 206 | # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z), 207 | # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different 208 | # arrays based on whether p < exp(-32). 
209 | z = lax.sqrt(dtype(-2.0) * lax.log(sanitized_mcp)) 210 | first_term = z - lax.log(z) / z 211 | second_term_small_p = ( 212 | _create_polynomial(dtype(1.0) / z, p2) 213 | / _create_polynomial(dtype(1.0) / z, q2) 214 | / z 215 | ) 216 | second_term_otherwise = ( 217 | _create_polynomial(dtype(1.0) / z, p1) 218 | / _create_polynomial(dtype(1.0) / z, q1) 219 | / z 220 | ) 221 | x_for_small_p = first_term - second_term_small_p 222 | x_otherwise = first_term - second_term_otherwise 223 | 224 | x = jnp.where( 225 | sanitized_mcp > dtype(np.exp(-2.0)), 226 | x_for_big_p, 227 | jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise), 228 | ) 229 | 230 | x = jnp.where(p > dtype(1.0 - np.exp(-2.0)), x, -x) 231 | with debug_infs(False): 232 | infinity = jnp.full(shape, dtype(np.inf)) 233 | neg_infinity = -infinity 234 | return jnp.where( 235 | p == dtype(0.0), neg_infinity, jnp.where(p == dtype(1.0), infinity, x) 236 | ) 237 | 238 | 239 | ################################################################ 240 | -------------------------------------------------------------------------------- /.asv/results/benchmarks.json: -------------------------------------------------------------------------------- 1 | { 2 | "rmse.EvalGbart.track_rmse": { 3 | "code": "class EvalGbart:\n def track_rmse(self) -> float:\n \"\"\"Return the RMSE for predictions on a test set.\"\"\"\n key = random.key(2025_06_26_21_02)\n data = make_data(key, 100, 1000, 20)\n with redirect_stdout(StringIO()):\n bart = gbart(\n data.X_train,\n data.y_train,\n x_test=data.X_test,\n nskip=1000,\n ndpost=1000,\n seed=key,\n )\n return jnp.sqrt(jnp.mean(jnp.square(bart.yhat_test_mean - data.mu_test))).item()", 4 | "name": "rmse.EvalGbart.track_rmse", 5 | "param_names": [], 6 | "params": [], 7 | "timeout": 30.0, 8 | "type": "track", 9 | "unit": "latent_sdev", 10 | "version": "afd40ad3255f218a76e6833332dc91afa0d19ac0f6daf1b7b9c75664c4586d28" 11 | }, 12 | "speed.TimeGbart.time_gbart": { 13 | "code": "class TimeGbart:\n 
def time_gbart(self, *_):\n \"\"\"Time instantiating the class.\"\"\"\n with redirect_stdout(StringIO()):\n bart = gbart(**self.kw)\n block_until_ready((bart._mcmc_state, bart._main_trace))\n\n def setup(self, niters: int, cache: Cache, nchains: int):\n \"\"\"Prepare the arguments and run once to warm-up.\"\"\"\n # check support for multiple chains\n if (niters == 0 or cache == 'cold') and nchains > 1:\n msg = 'skip multi-chain with 0 iterations or cold cache'\n raise NotImplementedError(msg)\n \n sig = signature(gbart)\n support_multichain = 'mc_cores' in sig.parameters\n if nchains != 1 and not support_multichain:\n msg = 'multi-chain not supported'\n raise NotImplementedError(msg)\n \n # random seed\n key = random.key(2025_06_24_14_55)\n keys = list(random.split(key, 3))\n \n # generate simulated data\n sigma = 0.1\n T = 2\n X = random.uniform(keys.pop(), (P, N), float, -2, 2)\n f = lambda X: jnp.sum(jnp.cos(2 * jnp.pi / T * X), axis=0)\n y = f(X) + sigma * random.normal(keys.pop(), (N,))\n \n # arguments\n self.kw = dict(\n x_train=X,\n y_train=y,\n nskip=niters // 2,\n ndpost=(niters - niters // 2) * nchains,\n seed=keys.pop(),\n )\n if support_multichain:\n self.kw.update(mc_cores=nchains)\n \n # decide how much to cold-start\n match cache:\n case 'cold':\n clear_caches()\n case 'warm':\n self.time_gbart()", 14 | "min_run_count": 2, 15 | "name": "speed.TimeGbart.time_gbart", 16 | "number": 1, 17 | "param_names": [ 18 | "niters", 19 | "cache", 20 | "nchains" 21 | ], 22 | "params": [ 23 | [ 24 | "0", 25 | "10" 26 | ], 27 | [ 28 | "'cold'", 29 | "'warm'" 30 | ], 31 | [ 32 | "1", 33 | "2", 34 | "8", 35 | "32" 36 | ] 37 | ], 38 | "repeat": 0, 39 | "rounds": 2, 40 | "sample_time": 0.01, 41 | "timeout": 30.0, 42 | "type": "time", 43 | "unit": "seconds", 44 | "version": "127118a7023d5587f25f10a56bbccd86f8200efd5a4c22fe60fa246df36599a4", 45 | "warmup_time": 0.0 46 | }, 47 | "speed.TimeRunMcmc.time_run_mcmc": { 48 | "code": "class TimeRunMcmc:\n 
@skip_for_params(list(product(['compile'], [0], params[2])))\n def time_run_mcmc(self, mode: Mode, *_):\n \"\"\"Time running or compiling the function.\"\"\"\n match mode:\n case 'compile':\n # re-wrap and jit the function in the benchmark case because otherwise\n # the compiled function gets cached even if I call `compile` explicitly\n @partial(\n jit, static_argnames=('n_save', 'n_skip', 'n_burn', 'callback')\n )\n def f(**kw):\n return run_mcmc(**kw)\n \n f.lower(**self.kw).compile()\n \n case 'run':\n block_until_ready(run_mcmc(**self.kw))\n\n def setup(self, mode: Mode, niters: int, cache: Cache):\n \"\"\"Prepare the arguments, compile the function, and run to warm-up.\"\"\"\n self.kw = dict(\n key=random.key(2025_04_25_15_57),\n bart=simple_init(P, N, NTREE),\n n_save=niters // 2,\n n_burn=niters // 2,\n n_skip=1,\n callback=lambda **_: None,\n )\n \n # adapt arguments for old versions\n sig = signature(run_mcmc)\n if 'callback' not in sig.parameters:\n self.kw.pop('callback')\n \n # catch bug and skip if found\n try:\n array_kw = {k: v for k, v in self.kw.items() if isinstance(v, jnp.ndarray)}\n nonarray_kw = {\n k: v for k, v in self.kw.items() if not isinstance(v, jnp.ndarray)\n }\n partial_run_mcmc = partial(run_mcmc, **nonarray_kw)\n eval_shape(partial_run_mcmc, **array_kw)\n except ZeroDivisionError:\n if niters:\n raise\n else:\n msg = 'skipping due to division by zero bug with zero iterations'\n raise NotImplementedError(msg) from None\n \n # decide how much to cold-start\n match cache:\n case 'cold':\n clear_caches()\n case 'warm':\n # prepare copies of the args because of buffer donation\n key = jnp.copy(self.kw['key'])\n bart = tree_map(jnp.copy, self.kw['bart'])\n self.time_run_mcmc(mode)\n # put copies in place of donated buffers\n self.kw.update(key=key, bart=bart)", 49 | "min_run_count": 2, 50 | "name": "speed.TimeRunMcmc.time_run_mcmc", 51 | "number": 1, 52 | "param_names": [ 53 | "mode", 54 | "niters", 55 | "cache" 56 | ], 57 | "params": [ 58 
| [ 59 | "'compile'", 60 | "'run'" 61 | ], 62 | [ 63 | "0", 64 | "10" 65 | ], 66 | [ 67 | "'cold'", 68 | "'warm'" 69 | ] 70 | ], 71 | "repeat": 0, 72 | "rounds": 2, 73 | "sample_time": 0.01, 74 | "timeout": 30.0, 75 | "type": "time", 76 | "unit": "seconds", 77 | "version": "38ac9d8c1419fab7715c0dd625229b7a900c5dece4f6bc91a8ab8074e9c9281c", 78 | "warmup_time": 0.0 79 | }, 80 | "speed.TimeStep.time_step": { 81 | "code": "class TimeStep:\n def time_step(self, mode: Mode, _):\n \"\"\"Time running compiling `step` or running it a few times.\"\"\"\n match mode:\n case 'compile':\n \n @jit\n def f(*args):\n return self.func(*args)\n \n f.lower(*self.args).compile()\n \n case 'run':\n block_until_ready(self.compiled_func(*self.args))\n\n def setup(self, mode: Mode, kind: Kind):\n \"\"\"Create an initial MCMC state and random seed, compile & warm-up.\"\"\"\n key = random.key(2025_06_24_12_07)\n if kind.startswith('vmap-'):\n length = int(kind.split('-')[1])\n keys = list(random.split(key, (2, length)))\n else:\n keys = list(random.split(key))\n \n self.args = (keys, simple_init(P, N, NTREE, kind))\n \n def func(keys, bart):\n bart = step(key=keys.pop(), bart=bart)\n if kind == 'sparse':\n bart = mcmcstep.step_sparse(keys.pop(), bart)\n return bart\n \n if kind.startswith('vmap-'):\n axes = vmap_axes_for_state(self.args[1])\n func = vmap(func, in_axes=(0, axes), out_axes=axes)\n \n self.func = func\n self.compiled_func = jit(func).lower(*self.args).compile()\n if mode == 'run':\n block_until_ready(self.compiled_func(*self.args))", 82 | "min_run_count": 2, 83 | "name": "speed.TimeStep.time_step", 84 | "number": 0, 85 | "param_names": [ 86 | "mode", 87 | "kind" 88 | ], 89 | "params": [ 90 | [ 91 | "'compile'", 92 | "'run'" 93 | ], 94 | [ 95 | "'plain'", 96 | "'binary'", 97 | "'weights'", 98 | "'sparse'", 99 | "'vmap-1'", 100 | "'vmap-2'" 101 | ] 102 | ], 103 | "repeat": 0, 104 | "rounds": 2, 105 | "sample_time": 0.01, 106 | "timeout": 30.0, 107 | "type": "time", 108 | "unit": 
"seconds", 109 | "version": "fed013a2153c474f5396625e7a319060a4a7d94e98d92dbd9b48f275c7a082e7", 110 | "warmup_time": -1 111 | }, 112 | "version": 2 113 | } -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | # bartz/pyproject.toml 2 | # 3 | # Copyright (c) 2024-2025, The Bartz Contributors 4 | # 5 | # This file is part of bartz. 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
24 | 25 | [build-system] 26 | requires = ["uv_build>=0.9.5,<0.10.0"] 27 | build-backend = "uv_build" 28 | 29 | [project] 30 | name = "bartz" 31 | version = "0.7.0" 32 | description = "Super-fast BART (Bayesian Additive Regression Trees) in Python" 33 | authors = [{ name = "Giacomo Petrillo", email = "info@giacomopetrillo.com" }] 34 | license = "MIT" 35 | readme = "README.md" 36 | requires-python = ">=3.10" 37 | dependencies = [ 38 | "equinox>=0.12.2", 39 | "jax>=0.5.3", 40 | "jaxtyping>=0.3.2", 41 | "numpy>=1.25.2", 42 | "scipy>=1.11.4", 43 | ] 44 | 45 | [project.urls] 46 | Homepage = "https://github.com/Gattocrucco/bartz" 47 | Documentation = "https://gattocrucco.github.io/bartz/docs-dev" 48 | Issues = "https://github.com/Gattocrucco/bartz/issues" 49 | 50 | [dependency-groups] 51 | only-local = [ 52 | "appnope>=0.1.4", 53 | "ipython>=8.36.0", 54 | "matplotlib>=3.10.3", 55 | "matplotlib-label-lines>=0.8.1", 56 | "pre-commit>=4.2.0", 57 | "scikit-learn>=1.6.1", 58 | "snakeviz>=2.2.2", 59 | "virtualenv>=20.31.2", 60 | "xgboost>=3.0.0", 61 | ] 62 | ci = [ 63 | "asv>=0.6.4", 64 | "flaky>=3.8.1", 65 | "gitpython>=3.1.43", 66 | "myst-parser>=4.0.1", 67 | "packaging>=25.0", 68 | "polars[pandas,pyarrow]>=1.29.0", 69 | "pytest>=8.3.5", 70 | "pytest-cov>=6.1.1", 71 | "pytest-timeout>=2.4.0", 72 | "pytest-timer[termcolor]>=1.0.0", 73 | "pytest-xdist>=3.6.1", 74 | "rpy2>=3.5.17", 75 | "sphinx>=8.1.3", 76 | "sphinx-autodoc-typehints>=3.0.1", 77 | "tomli>=2.2.1", 78 | ] 79 | 80 | [tool.pytest.ini_options] 81 | cache_dir = "config/pytest_cache" 82 | testpaths = ["tests"] 83 | addopts = [ 84 | "-r xXfE", 85 | "--pdbcls=IPython.terminal.debugger:TerminalPdb", 86 | "--durations=3", 87 | "--verbose", 88 | "--import-mode=importlib", 89 | ] 90 | filterwarnings = [ 91 | "ignore:unclosed database:ResourceWarning", 92 | ] 93 | timeout = 512 94 | timeout_method = "thread" # when jax hangs, signals do not work 95 | 96 | [tool.coverage.run] 97 | branch = true 98 | source_pkgs = ["bartz", 
"tests"] 99 | 100 | [tool.coverage.report] 101 | show_missing = true 102 | 103 | [tool.coverage.html] 104 | show_contexts = true 105 | directory = "_site/coverage" 106 | 107 | [tool.coverage.paths] 108 | # the first path in each list must be the source directory in the machine that's 109 | # generating the coverage report 110 | 111 | github = [ 112 | '/home/runner/work/bartz/bartz/src/bartz/', 113 | '/Users/runner/work/bartz/bartz/src/bartz/', 114 | 'D:\a\bartz\bartz\src\bartz\', 115 | '/Library/Frameworks/Python.framework/Versions/*/lib/python*/site-packages/bartz/', 116 | '/Users/runner/hostedtoolcache/Python/*/*/lib/python*/site-packages/bartz/', 117 | '/opt/hostedtoolcache/Python/*/*/lib/python*/site-packages/bartz/', 118 | 'C:\hostedtoolcache\windows\Python\*\*\Lib\site-packages\bartz\', 119 | ] 120 | 121 | local = [ 122 | 'src/bartz/', 123 | '/home/runner/work/bartz/bartz/src/bartz/', 124 | '/Users/runner/work/bartz/bartz/src/bartz/', 125 | 'D:\a\bartz\bartz\src\bartz\', 126 | '/Library/Frameworks/Python.framework/Versions/*/lib/python*/site-packages/bartz/', 127 | '/Users/runner/hostedtoolcache/Python/*/*/lib/python*/site-packages/bartz/', 128 | '/opt/hostedtoolcache/Python/*/*/lib/python*/site-packages/bartz/', 129 | 'C:\hostedtoolcache\windows\Python\*\*\Lib\site-packages\bartz\', 130 | ] 131 | 132 | [tool.ruff] 133 | exclude = [".asv", "*.ipynb"] 134 | cache-dir = "config/ruff_cache" 135 | 136 | [tool.ruff.format] 137 | quote-style = "single" 138 | skip-magic-trailing-comma = true 139 | 140 | [tool.ruff.lint.isort] 141 | split-on-trailing-comma = false 142 | 143 | [tool.ruff.lint] 144 | select = [ 145 | "ERA", # eradicate 146 | "S", # flake8-bandit 147 | "BLE", # flake8-blind-except 148 | "B", # bugbear 149 | "A", # flake8-builtins 150 | "C4", # flake8-comprehensions 151 | "CPY", # flake8-copyright 152 | "DTZ", # flake8-datetimez 153 | "T10", # flake8-debugger 154 | "EM", # flake8-errmsg 155 | "EXE", # flake8-executable 156 | "FIX", # flake8-fixme 157 | 
"ISC", # flake8-implicit-str-concat 158 | "INP", # flake8-no-pep420 159 | "PIE", # flake8-pie 160 | "T20", # flake8-print 161 | "PT", # flake8-pytest-style 162 | "RSE", # flake8-raise 163 | "RET", # flake8-return 164 | "SLF", # flake8-self 165 | "SIM", # flake8-simplify 166 | "TID", # flake8-tidy-imports 167 | "ARG", # flake8-unused-arguments 168 | "PTH", # flake8-use-pathlib 169 | "FLY", # flynt 170 | "I", # isort 171 | "C90", # mccabe 172 | "NPY", # NumPy-specific rules 173 | "PERF", # Perflint 174 | "W", # pycodestyle Warning 175 | "F", # pyflakes 176 | "D", # pydocstyle 177 | "PGH", # pygrep-hooks 178 | "PLC", # Pylint Convention 179 | "PLE", # Pylint Error 180 | "PLR", # Pylint Refactor 181 | "PLW", # Pyling Warning 182 | "UP", # pyupgrade 183 | "FURB", # refurb 184 | "RUF", # Ruff-specific rules 185 | "TRY", # tryceratops 186 | ] 187 | ignore = [ 188 | "B028", # warn with stacklevel = 2 189 | "C408", # Unnecessary `dict()` call (rewrite as a literal), it's too convenient for kwargs 190 | "D105", # Missing docstring in magic method 191 | "F722", # Syntax error in forward annotation. I ignore this because jaxtyping uses strings for shapes instead of for deferred annotations. 192 | "PIE790", # Unnecessary ... or pass. Ignored because sometimes I use ... as sentinel to tell the rest of ruff and pyright that an implementation is a stub. 193 | "PLR0913", # Too many arguments in function definition. Maybe I should do something about this? 194 | "PLR2004", # Magic value used in comparison, consider replacing `*` with a constant variable 195 | "RET505", # Unnecessary `{branch}` after `return` statement. I ignore this because I like to keep branches for readability. 196 | "RET506", # Unnecessary `else` after `raise` statement. I ignore this because I like to keep branches for readability. 197 | "S101", # Use of `assert` detected. Too annoying. 
198 | "SIM108", # SIM108 Use ternary operator `*` instead of `if`-`else`-block, I find blocks more readable 199 | "UP037", # Remove quotes from type annotation. Ignore because jaxtyping. 200 | ] 201 | 202 | [tool.ruff.lint.per-file-ignores] 203 | "{config/*,docs/*}" = [ 204 | "D100", # Missing docstring in public module 205 | "D101", # Missing docstring in public class 206 | "D102", # Missing docstring in public method 207 | "D103", # Missing docstring in public function 208 | "D104", # Missing docstring in public package 209 | "INP001", # File * is part of an implicit namespace package. Add an `__init__.py`. 210 | ] 211 | "src/bartz/_version.py" = [ 212 | "CPY001", # Missing copyright notice at top of file 213 | ] 214 | "{config/*,docs/*,tests/*}" = [ 215 | "T201", # `print` found 216 | ] 217 | "{tests/*,benchmarks/*}" = [ 218 | "SLF001", # Private member accessed: `*` 219 | "TID253", # `{module}` is banned at the module level 220 | ] 221 | "docs/conf.py" = [ 222 | "S607", # Starting a process with a partial executable path. Ignored because for a build script it makes more sense to use PATH. 223 | ] 224 | 225 | [tool.ruff.lint.pydocstyle] 226 | convention = "numpy" 227 | 228 | [tool.ruff.lint.flake8-copyright] 229 | min-file-size = 1 230 | 231 | [tool.ruff.lint.flake8-tidy-imports] 232 | banned-module-level-imports = ["bartz.debug"] 233 | ban-relative-imports = "all" 234 | 235 | [tool.pydoclint] 236 | arg-type-hints-in-signature = true 237 | arg-type-hints-in-docstring = false 238 | check-return-types = false 239 | check-yield-types = false 240 | treat-property-methods-as-class-attributes = true 241 | check-style-mismatch = true 242 | show-filenames-in-every-violation-message = true 243 | check-class-attributes = false 244 | # do not check class attributes because in dataclasses I document them as 245 | # init parameters because they are duplicated in the html docs otherwise. 246 | --------------------------------------------------------------------------------