├── .python-version
├── docs
├── .gitignore
├── Makefile
├── make.bat
├── readme.md
├── reference
│ ├── profile.rst
│ ├── debug.rst
│ ├── interface.rst
│ ├── grove.rst
│ ├── prepcovars.rst
│ ├── mcmcstep.rst
│ ├── mcmcloop.rst
│ ├── index.rst
│ └── jaxext.rst
├── guide
│ ├── index.rst
│ ├── installation.rst
│ └── quickstart.rst
├── index.rst
├── _static
│ └── custom.css
├── pkglist.md
├── development.rst
└── conf.py
├── .Rprofile
├── src
└── bartz
│ ├── _version.py
│ ├── jaxext
│ ├── scipy
│ │ ├── __init__.py
│ │ ├── stats.py
│ │ └── special.py
│ ├── _autobatch.py
│ └── __init__.py
│ ├── __init__.py
│ └── _profiler.py
├── config
├── ipython
│ ├── .gitignore
│ └── profile_default
│ │ ├── startup
│ │ └── startup.ipy
│ │ └── ipython_config.py
└── refs-for-asv.py
├── .asv
├── .gitignore
└── results
│ ├── gattocrucco-m1
│ ├── machine.json
│ ├── b145bd73-virtualenv-py3.13.json
│ ├── 485be21f-virtualenv-py3.13.json
│ ├── c0940a3a-virtualenv-py3.13.json
│ ├── d549d06f-virtualenv-py3.13.json
│ └── 1119f48f-virtualenv-py3.13.json
│ └── benchmarks.json
├── _site
├── .gitignore
└── index.html
├── renv
├── .gitignore
└── settings.json
├── .gitlint
├── .github
├── dependabot.yml
└── workflows
│ └── release.yml
├── .gitignore
├── LICENSE
├── benchmarks
├── __init__.py
└── rmse.py
├── tests
├── __init__.py
├── rbartpackages
│ ├── __init__.py
│ ├── bartMachine.py
│ ├── dbarts.py
│ ├── BART.py
│ ├── BART3.py
│ └── _base.py
├── conftest.py
├── test_mcmcstep.py
├── test_mcmcloop.py
├── util.py
├── test_meta.py
├── test_prepcovars.py
└── test_debug.py
├── .pre-commit-config.yaml
├── README.md
├── Makefile
├── asv.conf.json
└── pyproject.toml
/.python-version:
--------------------------------------------------------------------------------
1 | 3.14
2 |
--------------------------------------------------------------------------------
/docs/.gitignore:
--------------------------------------------------------------------------------
1 | _build/
2 |
--------------------------------------------------------------------------------
/.Rprofile:
--------------------------------------------------------------------------------
1 | source("renv/activate.R")
2 |
--------------------------------------------------------------------------------
/src/bartz/_version.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.7.0'
2 |
--------------------------------------------------------------------------------
/config/ipython/.gitignore:
--------------------------------------------------------------------------------
1 | history.sqlite
2 | db/dhist
3 |
--------------------------------------------------------------------------------
/.asv/.gitignore:
--------------------------------------------------------------------------------
1 | *
2 | !.gitignore
3 | !results/
4 | !results/**
5 |
--------------------------------------------------------------------------------
/_site/.gitignore:
--------------------------------------------------------------------------------
1 | docs/
2 | docs-dev/
3 | coverage/
4 | benchmarks/
5 |
--------------------------------------------------------------------------------
/renv/.gitignore:
--------------------------------------------------------------------------------
1 | library/
2 | local/
3 | cellar/
4 | lock/
5 | python/
6 | sandbox/
7 | staging/
8 |
--------------------------------------------------------------------------------
/.gitlint:
--------------------------------------------------------------------------------
1 | [general]
2 | ignore = body-is-missing
3 |
4 | [title-max-length]
5 | line-length = 50
6 |
7 | [body-max-line-length]
8 | line-length = 72
9 |
--------------------------------------------------------------------------------
/.asv/results/gattocrucco-m1/machine.json:
--------------------------------------------------------------------------------
1 | {
2 | "arch": "arm64",
3 | "cpu": "Apple M1 Pro",
4 | "machine": "gattocrucco-m1",
5 | "num_cpu": "8",
6 | "os": "Darwin 24.3.0",
7 | "ram": "17179869184",
8 | "version": 1
9 | }
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | # Please see the documentation for all configuration options:
2 | # https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
3 |
4 | version: 2
5 | updates:
6 | - package-ecosystem: "github-actions"
7 | directory: "/"
8 | schedule:
9 | interval: "monthly"
10 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python
2 | __pycache__
3 | dist/
4 | .coverage*
5 | .venv/
6 |
7 | # jax tracer
8 | **/????_??_??_??_??_??/ALL_HOSTS.op_stats.pb
9 | **/????_??_??_??_??_??/cache_version.txt
10 | **/????_??_??_??_??_??/*.SSTABLE
11 | **/????_??_??_??_??_??/*.trace.json.gz
12 | **/????_??_??_??_??_??/*.xplane.pb
13 | **/????_??_??_??_??_??/*.hlo_proto.pb
14 | **/????_??_??_??_??_??/.cached_tools.json
15 |
16 | # python profiler
17 | *.prof
18 |
--------------------------------------------------------------------------------
/renv/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "bioconductor.version": null,
3 | "external.libraries": [],
4 | "ignored.packages": [],
5 | "package.dependency.fields": [
6 | "Imports",
7 | "Depends",
8 | "LinkingTo"
9 | ],
10 | "ppm.enabled": null,
11 | "ppm.ignored.urls": [],
12 | "r.version": null,
13 | "snapshot.type": "implicit",
14 | "use.cache": true,
15 | "vcs.ignore.cellar": true,
16 | "vcs.ignore.library": true,
17 | "vcs.ignore.local": true,
18 | "vcs.manage.ignores": true
19 | }
20 |
--------------------------------------------------------------------------------
/config/ipython/profile_default/startup/startup.ipy:
--------------------------------------------------------------------------------
1 | with open(__file__) as f:
2 | print("Startup commands:")
3 | for i, line in enumerate(f):
4 | if i >= 7:
5 | print(f">>> {line.rstrip('\n')}")
6 | print()
7 |
8 | %matplotlib
9 | %load_ext autoreload
10 | %autoreload 2
11 | %load_ext snakeviz
12 |
13 | import appnope
14 | appnope.nope()
15 |
16 | from functools import partial
17 |
18 | import jax
19 | from jax import numpy as jnp
20 | import numpy as np
21 | import equinox as eqx
22 | import jaxtyping as jt
23 |
24 | import bartz
25 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line, and also
5 | # from the environment for the first two.
6 | SPHINXOPTS ?=
7 | SPHINXBUILD ?= sphinx-build
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | $(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
21 |
--------------------------------------------------------------------------------
/_site/index.html:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
bartz
5 |
6 |
7 |
8 |
9 | bartz
10 |
11 |
12 | A JAX implementation of BART (Bayesian Additive Regression Trees).
13 |
21 |
22 |
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 |
13 | if "%1" == "" goto help
14 |
15 | %SPHINXBUILD% >NUL 2>NUL
16 | if errorlevel 9009 (
17 | echo.
18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
19 | echo.installed, then set the SPHINXBUILD environment variable to point
20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
21 | echo.may add the Sphinx directory to PATH.
22 | echo.
23 | echo.If you don't have Sphinx installed, grab it from
24 | echo.http://sphinx-doc.org/
25 | exit /b 1
26 | )
27 |
28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
29 | goto end
30 |
31 | :help
32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
33 |
34 | :end
35 | popd
36 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024-2025 The Bartz Contributors
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/docs/readme.md:
--------------------------------------------------------------------------------
1 |
26 |
27 | ```{include} ../README.md
28 | ```
29 |
--------------------------------------------------------------------------------
/benchmarks/__init__.py:
--------------------------------------------------------------------------------
1 | # bartz/benchmarks/__init__.py
2 | #
3 | # Copyright (c) 2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Benchmarking code run by asv."""
26 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # bartz/tests/__init__.py
2 | #
3 | # Copyright (c) 2024-2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Unit tests to be run with pytest."""
26 |
--------------------------------------------------------------------------------
/config/ipython/profile_default/ipython_config.py:
--------------------------------------------------------------------------------
1 | # bartz/config/ipython/profile_default/ipython_config.py
2 | #
3 | # Copyright (c) 2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
--------------------------------------------------------------------------------
/tests/rbartpackages/__init__.py:
--------------------------------------------------------------------------------
1 | # bartz/tests/rbartpackages/__init__.py
2 | #
3 | # Copyright (c) 2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Wrappers of R BART packages."""
26 |
--------------------------------------------------------------------------------
/src/bartz/jaxext/scipy/__init__.py:
--------------------------------------------------------------------------------
1 | # bartz/src/bartz/jaxext/scipy/__init__.py
2 | #
3 | # Copyright (c) 2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Mockup of the :external:py:mod:`scipy` module."""
26 |
--------------------------------------------------------------------------------
/docs/reference/profile.rst:
--------------------------------------------------------------------------------
1 | .. bartz/docs/reference/profile.rst
2 | ..
3 | .. Copyright (c) 2025, The Bartz Contributors
4 | ..
5 | .. This file is part of bartz.
6 | ..
7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy
8 | .. of this software and associated documentation files (the "Software"), to deal
9 | .. in the Software without restriction, including without limitation the rights
10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | .. copies of the Software, and to permit persons to whom the Software is
12 | .. furnished to do so, subject to the following conditions:
13 | ..
14 | .. The above copyright notice and this permission notice shall be included in all
15 | .. copies or substantial portions of the Software.
16 | ..
17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | .. SOFTWARE.
24 |
25 | Profiling
26 | ---------
27 |
28 | .. autofunction:: bartz.profile_mode
29 |
--------------------------------------------------------------------------------
/docs/reference/debug.rst:
--------------------------------------------------------------------------------
1 | .. bartz/docs/reference/debug.rst
2 | ..
3 | .. Copyright (c) 2025, The Bartz Contributors
4 | ..
5 | .. This file is part of bartz.
6 | ..
7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy
8 | .. of this software and associated documentation files (the "Software"), to deal
9 | .. in the Software without restriction, including without limitation the rights
10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | .. copies of the Software, and to permit persons to whom the Software is
12 | .. furnished to do so, subject to the following conditions:
13 | ..
14 | .. The above copyright notice and this permission notice shall be included in all
15 | .. copies or substantial portions of the Software.
16 | ..
17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | .. SOFTWARE.
24 |
25 | Debugging
26 | ---------
27 |
28 | .. automodule:: bartz.debug
29 | :members:
30 |
--------------------------------------------------------------------------------
/docs/reference/interface.rst:
--------------------------------------------------------------------------------
1 | .. bartz/docs/reference/interface.rst
2 | ..
3 | .. Copyright (c) 2024-2025, The Bartz Contributors
4 | ..
5 | .. This file is part of bartz.
6 | ..
7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy
8 | .. of this software and associated documentation files (the "Software"), to deal
9 | .. in the Software without restriction, including without limitation the rights
10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | .. copies of the Software, and to permit persons to whom the Software is
12 | .. furnished to do so, subject to the following conditions:
13 | ..
14 | .. The above copyright notice and this permission notice shall be included in all
15 | .. copies or substantial portions of the Software.
16 | ..
17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | .. SOFTWARE.
24 |
25 | Interface
26 | =========
27 |
28 | .. automodule:: bartz.BART
29 | :members:
30 |
--------------------------------------------------------------------------------
/docs/reference/grove.rst:
--------------------------------------------------------------------------------
1 | .. bartz/docs/reference/grove.rst
2 | ..
3 | .. Copyright (c) 2024-2025, The Bartz Contributors
4 | ..
5 | .. This file is part of bartz.
6 | ..
7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy
8 | .. of this software and associated documentation files (the "Software"), to deal
9 | .. in the Software without restriction, including without limitation the rights
10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | .. copies of the Software, and to permit persons to whom the Software is
12 | .. furnished to do so, subject to the following conditions:
13 | ..
14 | .. The above copyright notice and this permission notice shall be included in all
15 | .. copies or substantial portions of the Software.
16 | ..
17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | .. SOFTWARE.
24 |
25 | Tree manipulation
26 | =================
27 |
28 | .. automodule:: bartz.grove
29 | :members:
30 |
--------------------------------------------------------------------------------
/docs/reference/prepcovars.rst:
--------------------------------------------------------------------------------
1 | .. bartz/docs/reference/prepcovars.rst
2 | ..
3 | .. Copyright (c) 2024-2025, The Bartz Contributors
4 | ..
5 | .. This file is part of bartz.
6 | ..
7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy
8 | .. of this software and associated documentation files (the "Software"), to deal
9 | .. in the Software without restriction, including without limitation the rights
10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | .. copies of the Software, and to permit persons to whom the Software is
12 | .. furnished to do so, subject to the following conditions:
13 | ..
14 | .. The above copyright notice and this permission notice shall be included in all
15 | .. copies or substantial portions of the Software.
16 | ..
17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | .. SOFTWARE.
24 |
25 | Data processing
26 | ===============
27 |
28 | .. automodule:: bartz.prepcovars
29 | :members:
30 |
--------------------------------------------------------------------------------
/docs/reference/mcmcstep.rst:
--------------------------------------------------------------------------------
1 | .. bartz/docs/reference/mcmcstep.rst
2 | ..
3 | .. Copyright (c) 2024-2025, The Bartz Contributors
4 | ..
5 | .. This file is part of bartz.
6 | ..
7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy
8 | .. of this software and associated documentation files (the "Software"), to deal
9 | .. in the Software without restriction, including without limitation the rights
10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | .. copies of the Software, and to permit persons to whom the Software is
12 | .. furnished to do so, subject to the following conditions:
13 | ..
14 | .. The above copyright notice and this permission notice shall be included in all
15 | .. copies or substantial portions of the Software.
16 | ..
17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | .. SOFTWARE.
24 |
25 | MCMC setup and step
26 | ===================
27 |
28 | .. automodule:: bartz.mcmcstep
29 | :members:
30 |
--------------------------------------------------------------------------------
/docs/guide/index.rst:
--------------------------------------------------------------------------------
1 | .. bartz/docs/guide/index.rst
2 | ..
3 | .. Copyright (c) 2024-2025, The Bartz Contributors
4 | ..
5 | .. This file is part of bartz.
6 | ..
7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy
8 | .. of this software and associated documentation files (the "Software"), to deal
9 | .. in the Software without restriction, including without limitation the rights
10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | .. copies of the Software, and to permit persons to whom the Software is
12 | .. furnished to do so, subject to the following conditions:
13 | ..
14 | .. The above copyright notice and this permission notice shall be included in all
15 | .. copies or substantial portions of the Software.
16 | ..
17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | .. SOFTWARE.
24 |
25 | Guide
26 | =====
27 |
28 | .. toctree::
29 | :maxdepth: 1
30 |
31 | installation.rst
32 | quickstart.rst
33 |
--------------------------------------------------------------------------------
/docs/reference/mcmcloop.rst:
--------------------------------------------------------------------------------
1 | .. bartz/docs/reference/mcmcloop.rst
2 | ..
3 | .. Copyright (c) 2024-2025, The Bartz Contributors
4 | ..
5 | .. This file is part of bartz.
6 | ..
7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy
8 | .. of this software and associated documentation files (the "Software"), to deal
9 | .. in the Software without restriction, including without limitation the rights
10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | .. copies of the Software, and to permit persons to whom the Software is
12 | .. furnished to do so, subject to the following conditions:
13 | ..
14 | .. The above copyright notice and this permission notice shall be included in all
15 | .. copies or substantial portions of the Software.
16 | ..
17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | .. SOFTWARE.
24 |
25 | MCMC loop
26 | =========
27 |
28 | .. automodule:: bartz.mcmcloop
29 | :members:
30 | :special-members: __call__
31 |
--------------------------------------------------------------------------------
/docs/reference/index.rst:
--------------------------------------------------------------------------------
1 | .. bartz/docs/reference/index.rst
2 | ..
3 | .. Copyright (c) 2024-2025, The Bartz Contributors
4 | ..
5 | .. This file is part of bartz.
6 | ..
7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy
8 | .. of this software and associated documentation files (the "Software"), to deal
9 | .. in the Software without restriction, including without limitation the rights
10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | .. copies of the Software, and to permit persons to whom the Software is
12 | .. furnished to do so, subject to the following conditions:
13 | ..
14 | .. The above copyright notice and this permission notice shall be included in all
15 | .. copies or substantial portions of the Software.
16 | ..
17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | .. SOFTWARE.
24 |
25 | Reference
26 | =========
27 |
28 | .. toctree::
29 | :maxdepth: 1
30 |
31 | interface.rst
32 | grove.rst
33 | mcmcstep.rst
34 | mcmcloop.rst
35 | prepcovars.rst
36 | jaxext.rst
37 | debug.rst
38 | profile.rst
39 |
--------------------------------------------------------------------------------
/src/bartz/__init__.py:
--------------------------------------------------------------------------------
1 | # bartz/src/bartz/__init__.py
2 | #
3 | # Copyright (c) 2024-2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """
26 | Super-fast BART (Bayesian Additive Regression Trees) in Python.
27 |
28 | See the manual at https://gattocrucco.github.io/bartz/docs
29 | """
30 |
31 | from bartz import BART, grove, jaxext, mcmcloop, mcmcstep, prepcovars # noqa: F401
32 | from bartz._profiler import profile_mode # noqa: F401
33 | from bartz._version import __version__ # noqa: F401
34 |
--------------------------------------------------------------------------------
/src/bartz/jaxext/scipy/stats.py:
--------------------------------------------------------------------------------
1 | # bartz/src/bartz/jaxext/scipy/stats.py
2 | #
3 | # Copyright (c) 2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Mockup of the :external:py:mod:`scipy.stats` module."""
26 |
27 | from bartz.jaxext.scipy.special import gammainccinv
28 |
29 |
30 | class invgamma:
31 | """Inverse-gamma distribution with shape `a` and unit scale, InvGamma(a, 1)."""
32 |
33 | @staticmethod
34 | def ppf(q, a):
35 | """Percent point function (inverse CDF), computed as ``1 / gammainccinv(a, q)``."""
36 | return 1 / gammainccinv(a, q)
37 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. bartz/docs/index.rst
2 | ..
3 | .. Copyright (c) 2024-2025, The Bartz Contributors
4 | ..
5 | .. This file is part of bartz.
6 | ..
7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy
8 | .. of this software and associated documentation files (the "Software"), to deal
9 | .. in the Software without restriction, including without limitation the rights
10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | .. copies of the Software, and to permit persons to whom the Software is
12 | .. furnished to do so, subject to the following conditions:
13 | ..
14 | .. The above copyright notice and this permission notice shall be included in all
15 | .. copies or substantial portions of the Software.
16 | ..
17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | .. SOFTWARE.
24 |
25 | .. module:: bartz
26 |
27 | bartz
28 | =====
29 |
30 | Contents
31 | --------
32 |
33 | .. toctree::
34 | :maxdepth: 1
35 |
36 | readme.md
37 |
38 | .. toctree::
39 | :maxdepth: 2
40 |
41 | guide/index.rst
42 | reference/index.rst
43 |
44 | .. toctree::
45 | :maxdepth: 1
46 |
47 | changelog.md
48 | development.rst
49 | pkglist.md
50 |
51 | * :ref:`genindex`
52 | * :ref:`search`
53 |
--------------------------------------------------------------------------------
/docs/reference/jaxext.rst:
--------------------------------------------------------------------------------
1 | .. bartz/docs/reference/jaxext.rst
2 | ..
3 | .. Copyright (c) 2024-2025, The Bartz Contributors
4 | ..
5 | .. This file is part of bartz.
6 | ..
7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy
8 | .. of this software and associated documentation files (the "Software"), to deal
9 | .. in the Software without restriction, including without limitation the rights
10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | .. copies of the Software, and to permit persons to whom the Software is
12 | .. furnished to do so, subject to the following conditions:
13 | ..
14 | .. The above copyright notice and this permission notice shall be included in all
15 | .. copies or substantial portions of the Software.
16 | ..
17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | .. SOFTWARE.
24 |
25 | JAX extensions
26 | ==============
27 |
28 | bartz.jaxext
29 | ------------
30 |
31 | .. automodule:: bartz.jaxext
32 | :members:
33 |
34 | .. autofunction:: bartz.jaxext.autobatch
35 |
36 | bartz.jaxext.scipy.special
37 | --------------------------
38 |
39 | .. automodule:: bartz.jaxext.scipy.special
40 | :members:
41 |
42 | bartz.jaxext.scipy.stats
43 | ------------------------
44 |
45 | .. automodule:: bartz.jaxext.scipy.stats
46 | :members:
47 |
--------------------------------------------------------------------------------
/.github/workflows/release.yml:
--------------------------------------------------------------------------------
1 | # bartz/.github/workflows/release.yml
2 | #
3 | # Copyright (c) 2024-2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | name: release
26 | permissions:
27 |   contents: read
28 |
29 | on:
30 |   release:
31 |     types: [published]
32 |   workflow_dispatch:
33 |
34 | jobs:
35 |   tests:
36 |     runs-on: ubuntu-latest
37 |     steps:
38 |       # The repository must be checked out first: setup-python reads
39 |       # `.python-version` from the workspace and fails if the file is missing.
40 |       - name: Checkout repository
41 |         uses: actions/checkout@v5
42 |       - name: Set up Python
43 |         uses: actions/setup-python@v6
44 |         with:
45 |           python-version-file: ".python-version"
46 |       - name: Update pip
47 |         run: python -m pip install --upgrade pip
48 |       # Install the published package from PyPI, not the checked-out sources
49 |       - name: Install software
50 |         run: python -m pip install bartz
51 |       - name: Try to import it
52 |         run: python -c 'import bartz;print(bartz.__version__)'
53 |
--------------------------------------------------------------------------------
/docs/guide/installation.rst:
--------------------------------------------------------------------------------
1 | .. bartz/docs/guide/installation.rst
2 | ..
3 | .. Copyright (c) 2024-2025, The Bartz Contributors
4 | ..
5 | .. This file is part of bartz.
6 | ..
7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy
8 | .. of this software and associated documentation files (the "Software"), to deal
9 | .. in the Software without restriction, including without limitation the rights
10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | .. copies of the Software, and to permit persons to whom the Software is
12 | .. furnished to do so, subject to the following conditions:
13 | ..
14 | .. The above copyright notice and this permission notice shall be included in all
15 | .. copies or substantial portions of the Software.
16 | ..
17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | .. SOFTWARE.
24 |
25 | Installation
26 | ============
27 |
28 | Install and set up Python. There are various ways to do it; my favorite one is to use `uv <https://docs.astral.sh/uv/>`_. Then:
29 |
30 | .. code-block:: sh
31 |
32 | pip install bartz
33 |
34 | To install the latest development version, do instead
35 |
36 | .. code-block:: sh
37 |
38 | pip install git+https://github.com/Gattocrucco/bartz.git
39 |
40 | To install a specific commit, do
41 |
42 | .. code-block:: sh
43 |
44 | pip install git+https://github.com/Gattocrucco/bartz.git@<commit hash>
45 |
46 | To use on GPU on a system that doesn't provide `jax` pre-installed, read how to install jax `in its manual <https://docs.jax.dev/en/latest/installation.html>`_.
47 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 |
4 | exclude: '^\.asv/results/.+\.json$'
5 | default_stages: [pre-commit]
6 | default_install_hook_types: [pre-commit, commit-msg]
7 |
8 | repos:
9 | - repo: https://github.com/pre-commit/pre-commit-hooks
10 | rev: v6.0.0
11 | hooks:
12 | - id: check-added-large-files
13 | - id: check-ast
14 | - id: check-case-conflict
15 | - id: check-docstring-first
16 | - id: check-executables-have-shebangs
17 | - id: check-illegal-windows-names
18 | - id: check-merge-conflict
19 | - id: check-shebang-scripts-are-executable
20 | - id: check-symlinks
21 | - id: check-toml
22 | - id: check-yaml
23 | - id: destroyed-symlinks
24 | - id: detect-private-key
25 | - id: end-of-file-fixer
26 | - id: forbid-submodules
27 | - id: mixed-line-ending
28 | args: [--fix=lf]
29 | - id: name-tests-test
30 | args: [--pytest-test-first]
31 | exclude: tests/(rbartpackages/.+\.py|util\.py)$
32 | - id: trailing-whitespace
33 | exclude: '^renv/activate\.R$' # because renv edits it automatically, with trailing whitespace
34 | - repo: https://github.com/sbrunner/hooks
35 | rev: 1.6.1
36 | hooks:
37 | - id: copyright
38 | - id: copyright-required
39 | files: \.(py|rst)$
40 | exclude: src/bartz/_version\.py$
41 | - repo: https://github.com/astral-sh/ruff-pre-commit
42 | rev: v0.14.3
43 | hooks:
44 | - id: ruff-check
45 | args: [--fix]
46 | - id: ruff-format
47 | - repo: https://github.com/jsh9/pydoclint
48 | rev: 0.7.6
49 | hooks:
50 | - id: pydoclint
51 | - repo: https://github.com/abravalheri/validate-pyproject
52 | rev: v0.24.1
53 | hooks:
54 | - id: validate-pyproject
55 | # Optional extra validations from SchemaStore:
56 | additional_dependencies: ["validate-pyproject-schema-store[all]"]
57 | - repo: https://github.com/jorisroovers/gitlint
58 | rev: v0.19.1
59 | hooks:
60 | - id: gitlint # this is already defined on stage commit-msg
61 |
--------------------------------------------------------------------------------
/tests/rbartpackages/bartMachine.py:
--------------------------------------------------------------------------------
1 | # bartz/tests/rbartpackages/bartMachine.py
2 | #
3 | # Copyright (c) 2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Python wrapper of the R package bartMachine."""
26 |
27 | # ruff: noqa: D102
28 |
29 | from rpy2 import robjects
30 |
31 | from tests.rbartpackages._base import RObjectBase, rmethod
32 |
33 |
34 | class bartMachine(RObjectBase): # noqa: D101, because the doc is pulled from R
35 | # Fully qualified name of the wrapped R function.
36 | # NOTE(review): presumably consumed by RObjectBase to build the R object; confirm in _base.
37 | _rfuncname = 'bartMachine::bartMachine'
38 |
39 | def __init__(self, *args, num_cores=None, megabytes=5000, **kw):
40 | # Set the JVM maximum heap size (-Xmx, in megabytes) through the R option
41 | # `java.parameters`.
42 | # NOTE(review): presumably this has to precede loading the namespace, since
43 | # JVM parameters only take effect when the JVM starts; confirm.
44 | robjects.r(f'options(java.parameters = "-Xmx{megabytes:d}m")')
45 | robjects.r('loadNamespace("bartMachine")')
46 | # Leave bartMachine's core count untouched unless the caller asks for one
47 | if num_cores is not None:
48 | robjects.r(f'bartMachine::set_bart_machine_num_cores({int(num_cores)})')
49 | # All other arguments are forwarded unchanged to RObjectBase.__init__
50 | super().__init__(*args, **kw)
51 |
52 | # The methods below are stubs with empty bodies (`...`).
53 | # NOTE(review): presumably `rmethod` turns them into calls to the R function of the same name; confirm in _base.
54 | @rmethod
55 | def predict(self, *args, **kw): ...
56 |
57 | @rmethod
58 | def get_posterior(self, *args, **kw): ...
59 |
60 | @rmethod
61 | def get_sigsqs(self, *args, **kw): ...
62 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
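1 | [![PyPI](https://img.shields.io/pypi/v/bartz)](https://pypi.org/project/bartz/)
2 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.13931477.svg)](https://doi.org/10.5281/zenodo.13931477)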
3 |
4 | # BART vectoriZed
5 |
6 | An implementation of Bayesian Additive Regression Trees (BART) in JAX.
7 |
8 | If you don't know what BART is, but know XGBoost, consider BART as a sort of Bayesian XGBoost. bartz makes BART run as fast as XGBoost.
9 |
10 | BART is a nonparametric Bayesian regression technique. Given training predictors $X$ and responses $y$, BART finds a function to predict $y$ given $X$. The result of the inference is a sample of possible functions, representing the uncertainty over the determination of the function.
11 |
12 | This Python module provides an implementation of BART that runs on GPU, to process large datasets faster. It is also good on CPU. Most other implementations of BART are for R, and run on CPU only.
13 |
14 | On CPU, bartz runs at the speed of dbarts (the fastest implementation I know of) if n > 20,000, but using 1/20 of the memory. On GPU, the speedup depends on sample size; using a GPU pays off over CPU only for n > 10,000. The maximum speedup is currently 200x, on an Nvidia A100 and with at least 2,000,000 observations.
15 |
16 | [This Colab notebook](https://colab.research.google.com/github/Gattocrucco/bartz/blob/main/docs/examples/basic_simdata.ipynb) runs bartz with n = 100,000 observations, p = 1000 predictors, 10,000 trees, for 1000 MCMC iterations, in 6 minutes.
17 |
18 | ## Links
19 |
20 | - [Documentation (latest release)](https://gattocrucco.github.io/bartz/docs)
21 | - [Documentation (development version)](https://gattocrucco.github.io/bartz/docs-dev)
22 | - [Repository](https://github.com/Gattocrucco/bartz)
23 | - [Code coverage](https://gattocrucco.github.io/bartz/coverage)
24 | - [Benchmarks](https://gattocrucco.github.io/bartz/benchmarks)
25 | - [List of BART packages](https://gattocrucco.github.io/bartz/docs-dev/pkglist.html)
26 |
27 | ## Citing bartz
28 |
29 | Article: Petrillo (2024), "Very fast Bayesian Additive Regression Trees on GPU", [arXiv:2410.23244](https://arxiv.org/abs/2410.23244).
30 |
31 | To cite the software directly, including the specific version, use [zenodo](https://doi.org/10.5281/zenodo.13931477).
32 |
--------------------------------------------------------------------------------
/docs/guide/quickstart.rst:
--------------------------------------------------------------------------------
1 | .. bartz/docs/guide/quickstart.rst
2 | ..
3 | .. Copyright (c) 2024-2025, The Bartz Contributors
4 | ..
5 | .. This file is part of bartz.
6 | ..
7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy
8 | .. of this software and associated documentation files (the "Software"), to deal
9 | .. in the Software without restriction, including without limitation the rights
10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | .. copies of the Software, and to permit persons to whom the Software is
12 | .. furnished to do so, subject to the following conditions:
13 | ..
14 | .. The above copyright notice and this permission notice shall be included in all
15 | .. copies or substantial portions of the Software.
16 | ..
17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | .. SOFTWARE.
24 |
25 | Quickstart
26 | ==========
27 |
28 | Basics
29 | ------
30 |
31 | Import and use the `bartz.BART.gbart` class:
32 |
33 | .. code-block:: python
34 |
35 | from bartz.BART import gbart
36 | bart = gbart(X, y, ...)
37 | y_pred = bart.predict(X_test)
38 |
39 | The interface hews to the R package `BART <https://cran.r-project.org/package=BART>`_, with a few differences explained in the documentation of `bartz.BART.gbart`.
40 |
41 | JAX
42 | ---
43 |
44 | `bartz` is implemented using `jax`, a Google library for machine learning. It allows running the code on GPU or TPU, among various other things.
45 |
46 | For basic usage, JAX is just an alternative implementation of `numpy`. The arrays returned by `~bartz.BART.gbart` are "jax arrays" instead of "numpy arrays", but there is no perceived difference in their functionality. If you pass numpy arrays to `bartz`, they will be converted automatically. You don't have to deal with `jax` in any way.
47 |
48 | For advanced usage, refer to the `jax documentation <https://docs.jax.dev>`_.
49 |
50 | Advanced
51 | --------
52 |
53 | `bartz` exposes the various functions that implement the MCMC of BART. You can use those yourself to try to make your own variant of BART. See the rest of the documentation for reference; the main entry points are `bartz.mcmcstep.init` and `bartz.mcmcloop.run_mcmc`. Using the internals is the only way to change the device used by each step of the algorithm, which is useful to pre-process data on CPU and move to GPU only the state of the MCMC if the data preprocessing step does not fit in the GPU memory.
54 |
--------------------------------------------------------------------------------
/docs/_static/custom.css:
--------------------------------------------------------------------------------
1 | /* bartz/docs/_static/custom.css
2 | *
3 | * Copyright (c) 2024-2025, The Bartz Contributors
4 | *
5 | * This file is part of bartz.
6 | *
7 | * Permission is hereby granted, free of charge, to any person obtaining a copy
8 | * of this software and associated documentation files (the "Software"), to deal
9 | * in the Software without restriction, including without limitation the rights
10 | * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | * copies of the Software, and to permit persons to whom the Software is
12 | * furnished to do so, subject to the following conditions:
13 | *
14 | * The above copyright notice and this permission notice shall be included in all
15 | * copies or substantial portions of the Software.
16 | *
17 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | * SOFTWARE.
24 | */
25 |
26 | dl.py.method, dl.py.function {
27 | margin-top: 2em;
28 | margin-bottom: 2em;
29 | }
30 |
31 | dl.py.property {
32 | margin-top: 1em;
33 | }
34 |
35 | dl.py.class, dl.py.function {
36 | margin-top: 2.5em;
37 | margin-bottom: 2.5em;
38 | }
39 |
40 | h2 + dl.py.class, h2 + dl.py.function {
41 | margin-top: 1em;
42 | }
43 |
44 | /* space between parameter/attributes lists */
45 | dl.field-list > dd {
46 | margin-bottom: 1em;
47 | }
48 |
49 | /* no additional space after last list in the block */
50 | dl.field-list > dd:last-child {
51 | margin-bottom: 0em;
52 | }
53 |
54 | /* space between paragraphs in multi-paragraph param descriptions */
55 | ul.simple > li > p:not(:first-child) {
56 | margin-top: 0.5em;
57 | }
58 |
59 | /* no space between param name and first paragraph in multi-paragraph param descriptions */
60 | ul.simple > li > p:nth-child(2) {
61 | margin-top: 0em;
62 | }
63 |
64 | /* space between parameters */
65 | ul.simple > li:not(:last-child) {
66 | margin-bottom: 0.5em;
67 | }
68 |
69 | /* highlight types that originate from type hints rather than the docstring */
70 | span.sphinx_autodoc_typehints-type > code {
71 | background-color: #ee9;
72 | }
73 |
74 | /* viewcode extension */
75 |
76 | span.linenos {
77 | padding-right: 2ex;
78 | }
79 |
80 | div.viewcode-block:target {
81 | margin: 0;
82 | padding-bottom: 2em;
83 | }
84 |
85 | :not(.notranslate) > div.highlight pre {
86 | padding-left: 1ex;
87 | padding-right: 1ex;
88 | font-size: 0.73em;
89 | background: #f4f4f4;
90 | margin-bottom: 2em;
91 | }
92 |
--------------------------------------------------------------------------------
/config/refs-for-asv.py:
--------------------------------------------------------------------------------
1 | # bartz/config/refs-for-asv.py
2 | #
3 | # Copyright (c) 2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """
26 | Print a list of git refs for ASV benchmarking.
27 |
28 | This script outputs:
29 | 1. All tags on the default branch with commit dates on or after CUTOFF_DATE
30 | 2. The HEAD of the default branch
31 |
32 | The output format is one ref per line, suitable for piping to `asv run HASHFILE:-`
33 | """
34 |
35 | import datetime
36 |
37 | from git import Repo
38 |
39 | # Configuration
40 | CUTOFF_DATE = datetime.datetime(2025, 1, 1, tzinfo=datetime.timezone.utc)
41 |
42 |
43 | def main():
44 | repo = Repo('.')
45 |
46 | # Get the default branch name from git
47 | # This queries the symbolic-ref for the remote's HEAD
48 | default_branch_name = repo.git.symbolic_ref(
49 | 'refs/remotes/origin/HEAD', short=True
50 | ).split('/')[-1]
51 |
52 | # Get the default branch
53 | main_branch = repo.refs[default_branch_name]
54 |
55 | # Collect tags that are reachable from main and after cutoff date
56 | tags_to_include = []
57 |
58 | for tag in repo.tags:
59 | # Get the commit the tag points to
60 | commit = tag.commit
61 |
62 | # Check if this tag is reachable from main
63 | if not repo.is_ancestor(commit, main_branch.commit):
64 | continue
65 |
66 | # Check if commit date is after cutoff
67 | commit_date = datetime.datetime.fromtimestamp(
68 | commit.committed_date, tz=datetime.timezone.utc
69 | )
70 |
71 | if commit_date >= CUTOFF_DATE:
72 | tags_to_include.append((commit_date, tag.name))
73 |
74 | # Sort tags by commit date
75 | tags_to_include.sort()
76 |
77 | # Print tags
78 | for _, tag_name in tags_to_include:
79 | print(tag_name)
80 |
81 | # Print default branch ref
82 | print(default_branch_name)
83 |
84 |
85 | if __name__ == '__main__':
86 | main()
87 |
--------------------------------------------------------------------------------
/.asv/results/gattocrucco-m1/b145bd73-virtualenv-py3.13.json:
--------------------------------------------------------------------------------
1 | {"commit_hash": "b145bd73241934c9085d819c17f748601a5df1f6", "env_name": "virtualenv-py3.13", "date": 1745432679000, "params": {"arch": "arm64", "cpu": "Apple M1 Pro", "machine": "gattocrucco-m1", "num_cpu": "8", "os": "Darwin 24.3.0", "ram": "17179869184", "python": "3.13"}, "python": "3.13", "requirements": {}, "env_vars": {}, "result_columns": ["result", "params", "version", "started_at", "duration", "stats_ci_99_a", "stats_ci_99_b", "stats_q_25", "stats_q_75", "stats_number", "stats_repeat", "samples", "profile"], "results": {"speed.TimeGbart.time_gbart": [[0.8069052919163369, NaN, NaN, NaN, 0.03726235398789868, NaN, NaN, NaN, 2.4787650835351087, NaN, NaN, NaN, 0.13813462504185736, NaN, NaN, NaN], [["0", "10"], ["'cold'", "'warm'"], ["1", "2", "8", "32"]], "127118a7023d5587f25f10a56bbccd86f8200efd5a4c22fe60fa246df36599a4", 1761995899604, 57.82, [0.7853, null, null, null, 0.035602, null, null, null, 2.4398, null, null, null, 0.1356, null, null, null], [0.87869, null, null, null, 0.038062, null, null, null, 2.5118, null, null, null, 0.14207, null, null, null], [0.79807, null, null, null, 0.037027, null, null, null, 2.464, null, null, null, 0.13772, null, null, null], [0.81622, null, null, null, 0.037874, null, null, null, 2.4884, null, null, null, 0.14041, null, null, null], [1, null, null, null, 1, null, null, null, 1, null, null, null, 1, null, null, null], [10, null, null, null, 10, null, null, null, 8, null, null, null, 10, null, null, null]], "speed.TimeRunMcmc.time_run_mcmc": [[NaN, NaN, 1.604135291534476, 1.4034559369902126, 0.01205670804483816, 0.0002243124763481319, 1.6102928330074064, 0.02777252095984295], [["'compile'", "'run'"], ["0", "10"], ["'cold'", "'warm'"]], "38ac9d8c1419fab7715c0dd625229b7a900c5dece4f6bc91a8ab8074e9c9281c", 1761995928502, 85.303, [null, null, 1.5351, 1.344, 0.010912, 0.00015521, 1.547, 0.027445], [null, null, 1.8783, 1.5496, 0.012672, 0.00027988, 1.7484, 0.028061], [null, null, 1.5735, 1.3705, 0.011628, 0.00017985, 1.5839, 
0.027565], [null, null, 1.678, 1.4848, 0.01225, 0.00026531, 1.649, 0.027994], [null, null, 1, 1, 1, 1, 1, 1], [null, null, 10, 8, 10, 10, 10, 10]], "speed.TimeStep.time_step": [[0.946801999467425, NaN, NaN, NaN, 1.0379889375763014, 1.1937205209978856, 0.0029921979876235127, NaN, NaN, NaN, 0.003127937510726042, 0.006610083481064066], [["'compile'", "'run'"], ["'plain'", "'binary'", "'weights'", "'sparse'", "'vmap-1'", "'vmap-2'"]], "fed013a2153c474f5396625e7a319060a4a7d94e98d92dbd9b48f275c7a082e7", 1761995971144, 126.78, [0.87452, null, null, null, 0.98037, 1.1273, 0.0029009, null, null, null, 0.0030316, 0.0064166], [0.96662, null, null, null, 1.2071, 1.3243, 0.0034885, null, null, null, 0.003271, 0.0067285], [0.9111, null, null, null, 1.0082, 1.1611, 0.0029625, null, null, null, 0.0030783, 0.0065423], [0.95784, null, null, null, 1.0891, 1.2815, 0.0030141, null, null, null, 0.0031771, 0.0066871], [1, null, null, null, 1, 1, 4, null, null, null, 4, 2], [10, null, null, null, 10, 8, 10, null, null, null, 10, 10]], "rmse.EvalGbart.track_rmse": [[0.5344812870025635], [], "afd40ad3255f218a76e6833332dc91afa0d19ac0f6daf1b7b9c75664c4586d28", 1761995893086, 6.5178]}, "durations": {}, "version": 2}
--------------------------------------------------------------------------------
/.asv/results/gattocrucco-m1/485be21f-virtualenv-py3.13.json:
--------------------------------------------------------------------------------
1 | {"commit_hash": "485be21f4d702417930e51a0e4a169b287fedb7a", "env_name": "virtualenv-py3.13", "date": 1747384424000, "params": {"arch": "arm64", "cpu": "Apple M1 Pro", "machine": "gattocrucco-m1", "num_cpu": "8", "os": "Darwin 24.3.0", "ram": "17179869184", "python": "3.13"}, "python": "3.13", "requirements": {}, "env_vars": {}, "result_columns": ["result", "params", "version", "started_at", "duration", "stats_ci_99_a", "stats_ci_99_b", "stats_q_25", "stats_q_75", "stats_number", "stats_repeat", "samples", "profile"], "results": {"speed.TimeGbart.time_gbart": [[0.9466408959706314, NaN, NaN, NaN, 0.0372482294915244, NaN, NaN, NaN, 2.0026256454875693, NaN, NaN, NaN, 0.1389597289962694, NaN, NaN, NaN], [["0", "10"], ["'cold'", "'warm'"], ["1", "2", "8", "32"]], "127118a7023d5587f25f10a56bbccd86f8200efd5a4c22fe60fa246df36599a4", 1761996203795, 68.798, [0.90027, null, null, null, 0.036742, null, null, null, 1.9701, null, null, null, 0.13684, null, null, null], [1.0987, null, null, null, 0.042018, null, null, null, 2.0889, null, null, null, 0.14329, null, null, null], [0.91486, null, null, null, 0.036953, null, null, null, 1.9925, null, null, null, 0.13821, null, null, null], [0.97631, null, null, null, 0.037703, null, null, null, 2.0392, null, null, null, 0.14018, null, null, null], [1, null, null, null, 1, null, null, null, 1, null, null, null, 1, null, null, null], [10, null, null, null, 10, null, null, null, 10, null, null, null, 10, null, null, null]], "speed.TimeRunMcmc.time_run_mcmc": [[NaN, NaN, 1.150307229545433, 0.9872167290304787, 0.13351675000740215, 0.00015162501949816942, 1.1659881874802522, 0.027849478588905185], [["'compile'", "'run'"], ["0", "10"], ["'cold'", "'warm'"]], "38ac9d8c1419fab7715c0dd625229b7a900c5dece4f6bc91a8ab8074e9c9281c", 1761996233129, 69.481, [null, null, 1.1156, 0.9476, 0.12715, 0.0001245, 1.1263, 0.027543], [null, null, 1.354, 1.0314, 0.1361, 0.00024, 1.2283, 0.028174], [null, null, 1.1261, 0.98495, 0.1327, 0.00013392, 1.1431, 
0.02769], [null, null, 1.1632, 1.0095, 0.13516, 0.00017698, 1.192, 0.028044], [null, null, 1, 1, 1, 1, 1, 1], [null, null, 10, 10, 10, 10, 10, 10]], "speed.TimeStep.time_step": [[0.889989978983067, NaN, 0.9213577499613166, NaN, 0.9991666045389138, 1.2441695625311695, 0.0030455156229436398, NaN, 0.0050226875076380875, NaN, 0.0031668073788750917, 0.0066105417790822685], [["'compile'", "'run'"], ["'plain'", "'binary'", "'weights'", "'sparse'", "'vmap-1'", "'vmap-2'"]], "fed013a2153c474f5396625e7a319060a4a7d94e98d92dbd9b48f275c7a082e7", 1761996267849, 166.61, [0.88136, null, 0.89148, null, 0.95987, 1.1576, 0.0028791, null, 0.0047176, null, 0.0030783, 0.0064644], [0.94919, null, 0.98912, null, 1.0999, 1.4406, 0.0033612, null, 0.01948, null, 0.0038439, 0.0068285], [0.88286, null, 0.9116, null, 0.98365, 1.1791, 0.0029699, null, 0.0047705, null, 0.0031384, 0.0065015], [0.91493, null, 0.95307, null, 1.0086, 1.3196, 0.0031029, null, 0.0053578, null, 0.0033833, 0.0067922], [1, null, 1, null, 1, 1, 4, null, 3, null, 4, 2], [10, null, 10, null, 10, 8, 10, null, 10, null, 10, 10]], "rmse.EvalGbart.track_rmse": [[0.5344812870025635], [], "afd40ad3255f218a76e6833332dc91afa0d19ac0f6daf1b7b9c75664c4586d28", 1761996197843, 5.9512]}, "durations": {"": 9.19114875793457}, "version": 2}
--------------------------------------------------------------------------------
/.asv/results/gattocrucco-m1/c0940a3a-virtualenv-py3.13.json:
--------------------------------------------------------------------------------
1 | {"commit_hash": "c0940a3a05ac16fc8412ac99180dff2501950dc2", "env_name": "virtualenv-py3.13", "date": 1748555083000, "params": {"arch": "arm64", "cpu": "Apple M1 Pro", "machine": "gattocrucco-m1", "num_cpu": "8", "os": "Darwin 24.3.0", "ram": "17179869184", "python": "3.13"}, "python": "3.13", "requirements": {}, "env_vars": {}, "result_columns": ["result", "params", "version", "started_at", "duration", "stats_ci_99_a", "stats_ci_99_b", "stats_q_25", "stats_q_75", "stats_number", "stats_repeat", "samples", "profile"], "results": {"speed.TimeGbart.time_gbart": [[0.7986744999652728, NaN, NaN, NaN, 0.03784627094864845, NaN, NaN, NaN, 3.5406936869840138, NaN, NaN, NaN, 0.9474997499492019, NaN, NaN, NaN], [["0", "10"], ["'cold'", "'warm'"], ["1", "2", "8", "32"]], "127118a7023d5587f25f10a56bbccd86f8200efd5a4c22fe60fa246df36599a4", 1761996542831, 86.547, [0.75583, null, null, null, 0.035992, null, null, null, 3.1683, null, null, null, 0.93938, null, null, null], [0.99608, null, null, null, 0.038397, null, null, null, 3.8884, null, null, null, 0.9503, null, null, null], [0.78607, null, null, null, 0.037178, null, null, null, 3.3809, null, null, null, 0.94124, null, null, null], [0.8184, null, null, null, 0.038229, null, null, null, 3.6931, null, null, null, 0.94927, null, null, null], [1, null, null, null, 1, null, null, null, 1, null, null, null, 1, null, null, null], [10, null, null, null, 10, null, null, null, 6, null, null, null, 10, null, null, null]], "speed.TimeRunMcmc.time_run_mcmc": [[NaN, NaN, 1.160155457968358, 1.0007016459712759, NaN, NaN, 1.3899157289997675, 0.029487416497431695], [["'compile'", "'run'"], ["0", "10"], ["'cold'", "'warm'"]], "38ac9d8c1419fab7715c0dd625229b7a900c5dece4f6bc91a8ab8074e9c9281c", 1761996580465, 61.236, [null, null, 1.1115, 0.94854, null, null, 1.2218, 0.028881], [null, null, 1.2112, 1.1713, null, null, 1.72, 0.030835], [null, null, 1.152, 0.96688, null, null, 1.2693, 0.029031], [null, null, 1.1734, 1.0377, null, null, 1.4687, 
0.029823], [null, null, 1, 1, null, null, 1, 1], [null, null, 10, 10, null, null, 10, 10]], "speed.TimeStep.time_step": [[0.8653726250049658, 0.7353072710102424, 0.8938072085147724, NaN, 0.9353522704914212, 1.1077213124954142, 0.002947359360405244, 0.003221333507099189, 0.004795805492904037, NaN, 0.003138265630695969, 0.0064666144899092615], [["'compile'", "'run'"], ["'plain'", "'binary'", "'weights'", "'sparse'", "'vmap-1'", "'vmap-2'"]], "fed013a2153c474f5396625e7a319060a4a7d94e98d92dbd9b48f275c7a082e7", 1761996612279, 183.0, [0.81426, 0.69315, 0.84356, null, 0.87336, 1.0136, 0.0029027, 0.0030542, 0.0047239, null, 0.0030811, 0.0063696], [0.96683, 0.79986, 0.96876, null, 1.1369, 1.1911, 0.0030285, 0.0033125, 0.0049198, null, 0.003279, 0.0066911], [0.83593, 0.69796, 0.85921, null, 0.90863, 1.0732, 0.0029294, 0.0031727, 0.0047741, null, 0.003114, 0.0063988], [0.88469, 0.75935, 0.92098, null, 1.0198, 1.1492, 0.0029909, 0.0032789, 0.0048458, null, 0.0031694, 0.0065193], [1, 1, 1, null, 1, 1, 4, 4, 3, null, 4, 2], [10, 10, 10, null, 10, 10, 10, 10, 10, null, 10, 10]], "rmse.EvalGbart.track_rmse": [[0.525512158870697], [], "afd40ad3255f218a76e6833332dc91afa0d19ac0f6daf1b7b9c75664c4586d28", 1761996536239, 6.5916]}, "durations": {"": 10.419498920440674}, "version": 2}
--------------------------------------------------------------------------------
/.asv/results/gattocrucco-m1/d549d06f-virtualenv-py3.13.json:
--------------------------------------------------------------------------------
1 | {"commit_hash": "d549d06f29cfea142857d4682c1689a0a9b3806e", "env_name": "virtualenv-py3.13", "date": 1751913844000, "params": {"arch": "arm64", "cpu": "Apple M1 Pro", "machine": "gattocrucco-m1", "num_cpu": "8", "os": "Darwin 24.3.0", "ram": "17179869184", "python": "3.13"}, "python": "3.13", "requirements": {}, "env_vars": {}, "result_columns": ["result", "params", "version", "started_at", "duration", "stats_ci_99_a", "stats_ci_99_b", "stats_q_25", "stats_q_75", "stats_number", "stats_repeat", "samples", "profile"], "results": {"speed.TimeGbart.time_gbart": [[0.8703406045096926, NaN, NaN, NaN, 0.03905785398092121, NaN, NaN, NaN, 2.04708585399203, NaN, NaN, NaN, 1.0372965830028988, NaN, NaN, NaN], [["0", "10"], ["'cold'", "'warm'"], ["1", "2", "8", "32"]], "127118a7023d5587f25f10a56bbccd86f8200efd5a4c22fe60fa246df36599a4", 1761996916550, 87.455, [0.82967, null, null, null, 0.037165, null, null, null, 2.0258, null, null, null, 1.0274, null, null, null], [0.92584, null, null, null, 0.040358, null, null, null, 2.089, null, null, null, 1.055, null, null, null], [0.84177, null, null, null, 0.03859, null, null, null, 2.0352, null, null, null, 1.0339, null, null, null], [0.87421, null, null, null, 0.039264, null, null, null, 2.0633, null, null, null, 1.0433, null, null, null], [1, null, null, null, 1, null, null, null, 1, null, null, null, 1, null, null, null], [10, null, null, null, 10, null, null, null, 10, null, null, null, 10, null, null, null]], "speed.TimeRunMcmc.time_run_mcmc": [[NaN, NaN, 1.119391416956205, 0.9509509579511359, 0.14862164546502754, 0.0016324580064974725, 1.1865209790412337, 0.029436582990456372], [["'compile'", "'run'"], ["0", "10"], ["'cold'", "'warm'"]], "38ac9d8c1419fab7715c0dd625229b7a900c5dece4f6bc91a8ab8074e9c9281c", 1761996955172, 67.026, [null, null, 1.0468, 0.91029, 0.1451, 0.0014815, 1.1656, 0.028909], [null, null, 1.1435, 1.0121, 0.19122, 0.0017985, 1.3533, 0.030238], [null, null, 1.075, 0.92686, 0.14663, 0.0015223, 1.1743, 
0.029287], [null, null, 1.1319, 0.9702, 0.15451, 0.0016916, 1.2108, 0.03003], [null, null, 1, 1, 1, 1, 1, 1], [null, null, 10, 10, 10, 10, 10, 10]], "speed.TimeStep.time_step": [[0.8716923129977658, 0.7712800209410489, 0.8728721039951779, 1.2229048334993422, 0.9160569999949075, 1.0702567920088768, 0.003013729117810726, 0.003291833374532871, 0.004813388998930653, 0.003396270753000863, 0.003162864624755457, 0.006512770749395713], [["'compile'", "'run'"], ["'plain'", "'binary'", "'weights'", "'sparse'", "'vmap-1'", "'vmap-2'"]], "fed013a2153c474f5396625e7a319060a4a7d94e98d92dbd9b48f275c7a082e7", 1761996988663, 225.56, [0.82677, 0.73139, 0.82608, 1.1591, 0.89152, 0.99325, 0.0029709, 0.0031492, 0.0047708, 0.0033145, 0.0031189, 0.0063524], [0.94984, 0.82489, 0.91989, 1.3356, 1.0005, 1.1078, 0.0030658, 0.0033573, 0.0049056, 0.0034595, 0.0032906, 0.0067557], [0.8507, 0.75394, 0.84239, 1.1941, 0.90712, 1.0246, 0.0029931, 0.0032374, 0.0047891, 0.0033697, 0.0031508, 0.0064062], [0.90333, 0.80636, 0.91378, 1.2453, 0.92358, 1.0892, 0.0030351, 0.003323, 0.0048326, 0.0034266, 0.0031988, 0.0066069], [1, 1, 1, 1, 1, 1, 4, 4, 3, 4, 4, 2], [10, 10, 10, 8, 10, 10, 10, 10, 10, 10, 10, 10]], "rmse.EvalGbart.track_rmse": [[0.5359137058258057], [], "afd40ad3255f218a76e6833332dc91afa0d19ac0f6daf1b7b9c75664c4586d28", 1761996907185, 9.3649]}, "durations": {"": 8.990105867385864}, "version": 2}
--------------------------------------------------------------------------------
/.asv/results/gattocrucco-m1/1119f48f-virtualenv-py3.13.json:
--------------------------------------------------------------------------------
1 | {"commit_hash": "1119f48ff708d49d1a31965ffa0fe540a23cab1d", "env_name": "virtualenv-py3.13", "date": 1761993911000, "params": {"arch": "arm64", "cpu": "Apple M1 Pro", "machine": "gattocrucco-m1", "num_cpu": "8", "os": "Darwin 24.3.0", "ram": "17179869184", "python": "3.13"}, "python": "3.13", "requirements": {}, "env_vars": {}, "result_columns": ["result", "params", "version", "started_at", "duration", "stats_ci_99_a", "stats_ci_99_b", "stats_q_25", "stats_q_75", "stats_number", "stats_repeat", "samples", "profile"], "results": {"speed.TimeGbart.time_gbart": [[0.8992934999987483, NaN, NaN, NaN, 0.04088618699461222, NaN, NaN, NaN, 2.178266270959284, NaN, NaN, NaN, 0.1406069994554855, 0.2530354790505953, 0.8698047290090472, 3.028478687047027], [["0", "10"], ["'cold'", "'warm'"], ["1", "2", "8", "32"]], "127118a7023d5587f25f10a56bbccd86f8200efd5a4c22fe60fa246df36599a4", 1761997347469, 150.87, [0.86069, null, null, null, 0.038242, null, null, null, 2.0901, null, null, null, 0.13891, 0.25126, 0.85568, 2.9536], [0.93961, null, null, null, 0.043125, null, null, null, 2.2448, null, null, null, 0.14589, 0.29274, 0.90863, 3.2281], [0.86776, null, null, null, 0.040236, null, null, null, 2.131, null, null, null, 0.13976, 0.25276, 0.86375, 3.0193], [0.93173, null, null, null, 0.041769, null, null, null, 2.2277, null, null, null, 0.14129, 0.25456, 0.88325, 3.107], [1, null, null, null, 1, null, null, null, 1, null, null, null, 1, 1, 1, 1], [10, null, null, null, 10, null, null, null, 10, null, null, null, 10, 10, 10, 6]], "speed.TimeRunMcmc.time_run_mcmc": [[NaN, NaN, 1.1052386040100828, 0.9534334164927714, 0.1504268335411325, 0.0014830209547653794, 1.2058203124906868, 0.029219582967925817], [["'compile'", "'run'"], ["0", "10"], ["'cold'", "'warm'"]], "38ac9d8c1419fab7715c0dd625229b7a900c5dece4f6bc91a8ab8074e9c9281c", 1761997418190, 67.958, [null, null, 1.0629, 0.90019, 0.1416, 0.0014253, 1.1318, 0.028844], [null, null, 1.1791, 1.0855, 0.15775, 0.0015505, 1.3001, 0.038026], 
[null, null, 1.087, 0.93143, 0.14874, 0.0014329, 1.1858, 0.029009], [null, null, 1.1291, 1.0156, 0.15259, 0.0015008, 1.2335, 0.02945], [null, null, 1, 1, 1, 1, 1, 1], [null, null, 10, 10, 10, 10, 10, 10]], "speed.TimeStep.time_step": [[0.8965204374981113, 0.7574879370513372, 0.8828994379728101, 1.1899915629765019, 0.9233364164829254, 1.1276098960079253, 0.0030094896210357547, 0.003220703118131496, 0.004806166669974724, 0.0033490677597001195, 0.003208968642866239, 0.006511489511467516], [["'compile'", "'run'"], ["'plain'", "'binary'", "'weights'", "'sparse'", "'vmap-1'", "'vmap-2'"]], "fed013a2153c474f5396625e7a319060a4a7d94e98d92dbd9b48f275c7a082e7", 1761997452786, 227.79, [0.86485, 0.7257, 0.8345, 1.1405, 0.89542, 1.0732, 0.0029289, 0.0031292, 0.0047317, 0.0032789, 0.0030704, 0.0064343], [0.95944, 0.80832, 0.92182, 1.277, 0.99602, 1.2108, 0.0030537, 0.0034741, 0.0048411, 0.0034172, 0.0032448, 0.0067581], [0.87863, 0.7509, 0.85363, 1.1441, 0.90718, 1.0893, 0.0030013, 0.0031874, 0.0047664, 0.0033363, 0.0031489, 0.0064767], [0.92631, 0.77443, 0.8909, 1.2212, 0.94361, 1.167, 0.0030189, 0.00327, 0.0048233, 0.0033708, 0.0032221, 0.0065548], [1, 1, 1, 1, 1, 1, 4, 4, 3, 4, 4, 2], [10, 10, 10, 8, 10, 10, 10, 10, 10, 10, 10, 10]], "rmse.EvalGbart.track_rmse": [[0.5359137058258057], [], "afd40ad3255f218a76e6833332dc91afa0d19ac0f6daf1b7b9c75664c4586d28", 1761997338166, 9.303]}, "durations": {"": 8.945157051086426}, "version": 2}
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | # bartz/tests/conftest.py
2 | #
3 | # Copyright (c) 2024-2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Pytest configuration."""
26 |
27 | from contextlib import nullcontext
28 | from re import fullmatch
29 |
30 | import jax
31 | import numpy as np
32 | import pytest
33 |
34 | from bartz.jaxext import get_default_device, split
35 |
# Strict jax settings for the whole test session, applied in this order.
for _option, _value in (
    ('jax_debug_key_reuse', True),
    ('jax_debug_nans', True),
    ('jax_debug_infs', True),
    ('jax_legacy_prng_key', 'error'),
):
    jax.config.update(_option, _value)
del _option, _value  # keep the loop variables out of the module namespace
40 |
41 |
@pytest.fixture
def keys(request) -> split:
    """
    Provide a reproducible collection of jax random keys for one test case.

    The seed is derived from the pytest node id of the requesting test, so the
    keys are deterministic and specific to the test case. Take a key with
    `keys.pop()`; used this way, the same collection can be safely shared by
    all the fixtures involved in the test case.
    """
    # strip the xdist_group suffix ('@...'), it is present only under xdist
    pattern = r'(.+?\.py::.+?(\[.+?\])?)(@.+)?'
    base_nodeid = fullmatch(pattern, request.node.nodeid).group(1)

    # feed the node id bytes to a numpy generator, then draw 4 bytes from it
    # to get a 32 bit seed for jax
    nodeid_bytes = np.array([base_nodeid], np.bytes_).view(np.uint8)
    seed_bytes = np.random.default_rng(nodeid_bytes).bytes(4)
    seed = np.array(seed_bytes).view(np.uint32)

    return split(jax.random.key(seed), 128)
59 |
60 |
def pytest_addoption(parser: pytest.Parser) -> None:
    """Register the custom `--platform` command line option."""
    option_settings = dict(
        choices=['cpu', 'gpu', 'auto'],
        default='auto',
        help='JAX platform to use: cpu, gpu, or auto (default: auto)',
    )
    parser.addoption('--platform', **option_settings)
69 |
70 |
def pytest_sessionstart(session: pytest.Session) -> None:
    """Select the jax device requested with `--platform` and print it."""
    config = session.config

    # change the default device only if a specific platform was requested and
    # it is not already the current one
    requested = config.getoption('--platform')
    if requested != 'auto' and get_default_device().platform != requested:
        jax.config.update('jax_default_device', jax.devices(requested)[0])
        assert get_default_device().platform == requested

    # if the capture manager plugin is there, suspend capturing while
    # printing; otherwise print within a do-nothing context
    capture_manager = config.pluginmanager.get_plugin('capturemanager')
    uncaptured = (
        capture_manager.global_and_fixture_disabled()
        if capture_manager
        else nullcontext()
    )

    with uncaptured:
        print(f'jax default device: {get_default_device().device_kind}')
95 |
--------------------------------------------------------------------------------
/docs/pkglist.md:
--------------------------------------------------------------------------------
1 |
26 |
27 | # Other BART packages
28 |
29 | - [stochtree](https://github.com/StochasticTree/stochtree) C++ library with R and Python interfaces
30 | - [bnptools](https://github.com/rsparapa/bnptools) Feature-rich R packages for BART and some variants
31 | - [vdorie](https://github.com/vdorie)'s repositories (dbarts, bartCause, stan4bart)
32 | - [bartMachine](https://github.com/kapelner/bartMachine) R package, supports missing predictors imputation
33 | - [SoftBART](https://github.com/theodds/SoftBART) R package with a smooth version of BART
34 | - [jaredsmurray](https://github.com/jaredsmurray)'s repositories (bcf, monbart)
35 | - [skdeshpande91](https://github.com/skdeshpande91)'s repositories (flexBART, flexBCF, VCBART)
36 | - [JingyuHe](https://github.com/JingyuHe)'s repositories (XBART & related)
37 | - [BayesTree](https://cran.r-project.org/package=BayesTree) R package, original BART implementation
38 | - [OpenBT](https://bitbucket.org/mpratola/openbt) Heteroskedastic BART, rotate & perturb proposals, C++ library with R interface
39 | - [OpenBT](https://github.com/jcyannotty/OpenBT) fork of the above
40 | - [lsqfitgp](https://github.com/Gattocrucco/lsqfitgp) Infinite trees limit of BART
41 | - [mBART](https://github.com/remcc/mBART_shlib)
42 | - [SequentialBART](https://github.com/mjdaniels/SequentialBART)
43 | - [sparseBART](https://github.com/cspanbauer/sparseBART)
44 | - [pymc-bart](https://github.com/pymc-devs/pymc-bart) BART within PyMC
45 | - [semibart](https://github.com/zeldow/semibart)
46 | - [ebprado](https://github.com/ebprado)'s repositories
47 | - [EoghanONeill](https://github.com/EoghanONeill)'s repositories
48 | - [MateusMaiaDS](https://github.com/MateusMaiaDS)'s repositories (subart, gpbart)
49 | - [nchenderson](https://github.com/nchenderson)'s repositories
50 | - [bartpy](https://github.com/JakeColtman/bartpy)
51 | - [BayesTreePrior](https://github.com/AlexiaJM/BayesTreePrior) Sample the prior of BART
52 | - [BayesTree.jl](https://github.com/mathcg/BayesTree.jl)
53 | - [longbet](https://github.com/google/longbet)
54 | - [richael008](https://github.com/richael008)'s repositories
55 | - [bartMan](https://github.com/AlanInglis/bartMan) R package, posterior analysis and diagnostics
56 | - [drbart](https://github.com/vittorioorlandi/drbart) Density regression with BART
57 | - [BART-BMA](https://github.com/BelindaHernandez/BART-BMA) (superseded by [bartBMAnew](https://github.com/EoghanONeill/bartBMAnew))
58 | - [XBCF](https://github.com/socket778/XBCF) (superseded by StochasticTree)
59 | - [mpibart](https://matthewpratola.com/mpibart) Old parallel implementation
60 | - [HE-BART](https://github.com/brunaw/HE-BART)
61 | - [pgbart](https://github.com/balajiln/pgbart) Particle Gibbs BART
62 | - [skewBART](https://github.com/Seungha-Um/skewBART) multivariate skewed responses
63 |
--------------------------------------------------------------------------------
/tests/test_mcmcstep.py:
--------------------------------------------------------------------------------
1 | # bartz/tests/test_mcmcstep.py
2 | #
3 | # Copyright (c) 2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Test `bartz.mcmcstep`."""
26 |
27 | from jax import numpy as jnp
28 | from jax import vmap
29 | from jax.random import bernoulli, clone, permutation, randint
30 | from jaxtyping import Array, Bool, Int32, Key
31 | from scipy import stats
32 |
33 | from bartz.jaxext import split
34 | from bartz.mcmcstep import randint_masked
35 |
36 |
def vmap_randint_masked(
    key: Key[Array, ''], mask: Bool[Array, ' n'], size: int
) -> Int32[Array, '* n']:
    """Call `randint_masked` on a batch of keys with a single shared `mask`."""
    # `pop(size)` is expected to give a batch of `size` keys (TODO confirm
    # against `bartz.jaxext.split`); the keys are mapped over their leading
    # axis, while the mask is passed unchanged to each call
    batch_keys = split(key, 1).pop(size)
    batched = vmap(randint_masked, in_axes=(0, None))
    return batched(batch_keys, mask)
44 |
45 |
class TestRandintMasked:
    """Test `mcmcstep.randint_masked`."""

    def test_all_false(self, keys):
        """Check the mask size is returned when no value is allowed."""
        for size in range(1, 10):
            u = randint_masked(keys.pop(), jnp.zeros(size, bool))
            assert u == size

    def test_all_true(self, keys):
        """Check it's equivalent to `randint` when all values are allowed."""
        key = keys.pop()
        size = 10_000
        u1 = randint_masked(key, jnp.ones(size, bool))
        # clone the key because it is deliberately used a second time
        u2 = randint(clone(key), (), 0, size)
        assert u1 == u2

    def test_no_disallowed_values(self, keys):
        """Check disallowed values are never selected, over 100 random masks."""
        key = keys.pop()
        for _ in range(100):
            # Take all the subkeys upfront, in the order mask, draw, next
            # iteration. Advancing `key` here rather than at the end of the
            # body guarantees that skipping an all-False mask with `continue`
            # does not leave `key` stale, which would make all the remaining
            # iterations repeat the same split and the same mask.
            subkeys = split(key, 3)
            mask_key = subkeys.pop()
            draw_key = subkeys.pop()
            key = subkeys.pop()

            mask = bernoulli(mask_key, 0.5, (10,))
            if not jnp.any(mask):  # pragma: no cover, rarely happens
                continue
            u = randint_masked(draw_key, mask)
            assert 0 <= u < mask.size
            assert mask[u]

    def test_correct_distribution(self, keys):
        """Check the distribution of values is uniform."""
        # create a mask with half of the entries allowed, in random positions
        num_allowed = 10
        mask = jnp.zeros(2 * num_allowed, bool)
        mask = mask.at[:num_allowed].set(True)
        indices = jnp.arange(mask.size)
        indices = permutation(keys.pop(), indices)
        mask = mask[indices]

        # sample values
        n = 10_000
        u: Int32[Array, '{n}'] = vmap_randint_masked(keys.pop(), mask, n)
        # map the positions back to the pre-shuffling ones, where the allowed
        # entries are the first `num_allowed`
        u = indices[u]
        assert jnp.all(u < num_allowed)

        # check that the distribution is uniform
        # likelihood ratio test for multinomial with free p vs. constant p
        k = jnp.bincount(u, length=num_allowed)
        llr = jnp.sum(jnp.where(k, k * jnp.log(k / n * num_allowed), 0))
        lamda = 2 * llr
        pvalue = stats.chi2.sf(lamda, num_allowed - 1)
        assert pvalue > 0.1
99 |
--------------------------------------------------------------------------------
/tests/test_mcmcloop.py:
--------------------------------------------------------------------------------
1 | # bartz/tests/test_mcmcloop.py
2 | #
3 | # Copyright (c) 2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Test `bartz.mcmcloop`."""
26 |
27 | from functools import partial
28 |
29 | from equinox import filter_jit
30 | from jax import numpy as jnp
31 | from jax import vmap
32 | from jax.tree import map_with_path
33 | from jax.tree_util import tree_map
34 | from jaxtyping import Array, Float32, UInt8
35 | from numpy.testing import assert_array_equal
36 |
37 | from bartz import mcmcloop, mcmcstep
38 |
39 |
def gen_data(
    p: int, n: int
) -> tuple[UInt8[Array, 'p n'], Float32[Array, ' n'], UInt8[Array, ' p']]:
    """Generate deterministic, pretty nonsensical, predictors, response and max splits."""
    # predictors: a uint8 counter laid out row by row, with row i rolled by i
    counter = jnp.arange(p * n, dtype=jnp.uint8)
    shifts = jnp.arange(p)
    X = vmap(jnp.roll)(counter.reshape(p, n), shifts)

    # response: cosine evaluated on `n` evenly spaced points
    phase = jnp.linspace(0, 2 * jnp.pi / 32 * n, n)
    y = jnp.cos(phase)

    # the same maximum split, 255, for all predictors
    max_split = jnp.full(p, 255, jnp.uint8)

    return X, y, max_split
49 |
50 |
def make_p_nonterminal(maxdepth: int) -> Float32[Array, ' {maxdepth}-1']:
    """Compute the `p_nonterminal` argument of `mcmcstep.init` as 0.95 / (1 + depth)^2."""
    numerator, exponent = 0.95, 2
    one_plus_depth = 1 + jnp.arange(maxdepth - 1)
    return numerator / one_plus_depth.astype(float) ** exponent
57 |
58 |
@filter_jit
def init(p: int, n: int, ntree: int, **kwargs):
    """Call `bartz.mcmcstep.init` on the output of `gen_data` with preset settings.

    Additional keyword arguments are forwarded to `mcmcstep.init`.
    """
    X, y, max_split = gen_data(p, n)
    preset = dict(
        X=X,
        y=y,
        max_split=max_split,
        num_trees=ntree,
        p_nonterminal=make_p_nonterminal(6),
        sigma_mu2=1.0,
        sigma2_alpha=1,
        sigma2_beta=1,
        min_points_per_decision_node=10,
        filter_splitless_vars=False,
    )
    return mcmcstep.init(**preset, **kwargs)
76 |
77 |
class TestRunMcmc:
    """Test `mcmcloop.run_mcmc`."""

    def test_final_state_overflow(self, keys):
        """Check that the final state is the last one in the trace even if there's overflow."""
        start = init(10, 100, 20)
        end, _, trace = mcmcloop.run_mcmc(keys.pop(), start, 10, inner_loop_length=9)

        # compare the forest arrays first, then the error variance
        for name in ('leaf_tree', 'var_tree', 'split_tree'):
            assert_array_equal(getattr(end.forest, name), getattr(trace, name)[-1])
        assert_array_equal(end.sigma2, trace.sigma2[-1])

    def test_zero_iterations(self, keys):
        """Check there's no error if the loop does not run."""
        start = init(10, 100, 20)
        end, burnin_trace, main_trace = mcmcloop.run_mcmc(
            keys.pop(), start, 0, n_burn=0
        )

        # the state must come out unchanged, leaf by leaf
        check_equal = partial(assert_array_equal, strict=True)
        tree_map(check_equal, start, end)

        def assert_empty_trace(path, x):  # noqa: ARG001, for debugging
            assert x.shape[0] == 0

        for trace in (burnin_trace, main_trace):
            map_with_path(assert_empty_trace, trace)
107 |
--------------------------------------------------------------------------------
/benchmarks/rmse.py:
--------------------------------------------------------------------------------
1 | # bartz/benchmarks/rmse.py
2 | #
3 | # Copyright (c) 2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Measure the predictive performance on test sets."""
26 |
27 | from contextlib import redirect_stdout
28 | from dataclasses import dataclass
29 | from functools import partial
30 | from io import StringIO
31 |
32 | from jax import jit, random, vmap
33 | from jax import numpy as jnp
34 |
35 | try:
36 | from bartz.BART import gbart
37 | except ImportError:
38 | from bartz import BART as gbart
39 |
40 |
@partial(jit, static_argnums=(1, 2))
def simulate_data(key, n: int, p: int, max_interactions):
    """Simulate regression data: predictors, latent mean, and error term.

    The latent mean is standardized using the generated sample itself, so the
    train and test sets have to be simulated together in a single call.
    """
    # one subkey per random array, assigned from the last to the first
    error_key, A_key, beta_key, X_key = random.split(key, 4)
    X = random.uniform(X_key, (p, n))
    beta = random.normal(beta_key, (p,))
    A = random.normal(A_key, (p, p))
    error = random.normal(error_key, (n,))

    # keep only a band of A to limit the number of interactions: row i of the
    # pattern is the first row circularly shifted by i
    band_width = jnp.clip(1 + (max_interactions - 1) // 2, 0, p)
    first_row = jnp.arange(p) < band_width
    band = vmap(jnp.roll, in_axes=(None, 0))(first_row, jnp.arange(p))
    A = A * band

    # linear and quadratic terms in the predictors
    linear = beta @ X
    quadratic = jnp.einsum('ai,bi,ab->i', X, X, A)

    # put the two terms on the same scale, then rescale their sum because
    # they are correlated
    mu = linear / jnp.std(linear) + quadratic / jnp.std(quadratic)
    mu = mu / jnp.std(mu)

    return X, mu, error
74 |
75 |
@dataclass(frozen=True)
class Data:
    """Train and test sets for regression, with latent mean and error kept separate."""

    X_train: jnp.ndarray
    mu_train: jnp.ndarray
    error_train: jnp.ndarray
    X_test: jnp.ndarray
    mu_test: jnp.ndarray
    error_test: jnp.ndarray

    @property
    def y_train(self):
        """Training targets, i.e., `mu_train` plus `error_train`."""
        mean, noise = self.mu_train, self.error_train
        return mean + noise

    @property
    def y_test(self):
        """Test targets, i.e., `mu_test` plus `error_test`."""
        mean, noise = self.mu_test, self.error_test
        return mean + noise
96 |
97 |
def make_data(key, n_train: int, n_test: int, p: int) -> Data:
    """Simulate `n_train + n_test` data points at once and split them in train and test."""
    X, mu, error = simulate_data(key, n_train + n_test, p, 5)

    # the first `n_train` points are for training, the rest for testing
    train, test = slice(None, n_train), slice(n_train, None)
    return Data(
        X_train=X[:, train],
        mu_train=mu[train],
        error_train=error[train],
        X_test=X[:, test],
        mu_test=mu[test],
        error_test=error[test],
    )
109 |
110 |
class EvalGbart:
    """Out-of-sample evaluation of gbart."""

    timeout = 30.0
    unit = 'latent_sdev'

    def track_rmse(self) -> float:
        """Fit gbart on simulated data, return the test set RMSE w.r.t. the latent mean."""
        key = random.key(2025_06_26_21_02)
        data = make_data(key, 100, 1000, 20)
        settings = dict(x_test=data.X_test, nskip=1000, ndpost=1000, seed=key)

        # whatever is printed to stdout during the fit goes to a throwaway buffer
        with redirect_stdout(StringIO()):
            bart = gbart(data.X_train, data.y_train, **settings)

        residuals = bart.yhat_test_mean - data.mu_test
        mse = jnp.mean(jnp.square(residuals))
        return jnp.sqrt(mse).item()
131 |
--------------------------------------------------------------------------------
/tests/util.py:
--------------------------------------------------------------------------------
1 | # bartz/tests/util.py
2 | #
3 | # Copyright (c) 2024-2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Functions intended to be shared across the test suite."""
26 |
27 | from pathlib import Path
28 |
29 | import numpy as np
30 | import tomli
31 | from jaxtyping import ArrayLike
32 | from scipy import linalg
33 |
34 |
def assert_close_matrices(
    actual: ArrayLike,
    desired: ArrayLike,
    *,
    rtol: float = 0.0,
    atol: float = 0.0,
    tozero: bool = False,
):
    """
    Check if two matrices are similar.

    Parameters
    ----------
    actual
    desired
        The two matrices to be compared. Must be scalars, vectors, or 2d arrays.
        Scalars and vectors are interpreted as 1x1 and Nx1 matrices, but the two
        arrays must have the same shape beforehand.
    rtol
    atol
        Relative and absolute tolerances for the comparison. The closeness
        condition is:

            ||actual - desired|| <= atol + rtol * ||desired||,

        where the norm is the matrix 2-norm, i.e., the maximum (in absolute
        value) singular value.
    tozero
        If True, use the following condition instead:

            ||actual|| <= atol + rtol * ||desired||

        So `actual` is compared to zero, and `desired` is only used as a
        reference to set the threshold.

    Raises
    ------
    ValueError
        If the two matrices have different shapes.
    AssertionError
        If the closeness condition is not satisfied.
    """
    actual = np.asarray(actual)
    desired = np.asarray(desired)
    if actual.shape != desired.shape:
        msg = f'{actual.shape=} != {desired.shape=}'
        raise ValueError(msg)
    if actual.size == 0:
        # nothing to compare: empty arrays are considered close
        return

    # scalars become length-1 vectors, so they are treated as 1x1 matrices
    actual = np.atleast_1d(actual)
    desired = np.atleast_1d(desired)

    # `expr` and `ref` are only labels for the error message; the compared
    # array is computed explicitly instead of evaluating `expr` as code
    if tozero:
        expr = 'actual'
        ref = 'zero'
        compared = actual
    else:
        expr = 'actual - desired'
        ref = 'desired'
        compared = actual - desired

    dnorm = linalg.norm(desired, 2)
    adnorm = linalg.norm(compared, 2)
    # the ratio is only informative; avoid a division by zero
    ratio = adnorm / dnorm if dnorm else np.nan

    msg = f"""\
matrices actual and {ref} are not close in 2-norm
matrix shape: {desired.shape}
norm(desired) = {dnorm:.2g}
norm({expr}) = {adnorm:.2g} (atol = {atol:.2g})
ratio = {ratio:.2g} (rtol = {rtol:.2g})"""

    assert adnorm <= atol + rtol * dnorm, msg
103 |
104 |
def get_old_python_str() -> str:
    """Return the oldest supported Python version declared in pyproject.toml.

    The value of `project.requires-python` is returned with a leading `>=`
    stripped. The file is looked up relative to the current directory.
    """
    pyproject = Path('pyproject.toml')
    with pyproject.open('rb') as file:
        config = tomli.load(file)
    requirement = config['project']['requires-python']
    return requirement.removeprefix('>=')
109 |
110 |
def get_old_python_tuple() -> tuple[int, int]:
    """Return the oldest supported Python version as `(major, minor)`.

    The version string read from pyproject.toml must consist of exactly two
    dot-separated integers.
    """
    parts = get_old_python_str().split('.')
    # unpack before converting, so a wrong number of components fails first
    major_str, minor_str = parts
    return int(major_str), int(minor_str)
116 |
117 |
def get_version() -> str:
    """Return the bartz version declared under `project.version` in pyproject.toml.

    The file is looked up relative to the current directory.
    """
    pyproject = Path('pyproject.toml')
    with pyproject.open('rb') as file:
        config = tomli.load(file)
    return config['project']['version']
122 |
123 |
def update_version():
    """Overwrite `src/bartz/_version.py` with the version from pyproject.toml.

    The file ends up containing a single `__version__` assignment.
    """
    version_file = Path('src/bartz/_version.py')
    content = f'__version__ = {get_version()!r}\n'
    version_file.write_text(content)
128 |
--------------------------------------------------------------------------------
/tests/rbartpackages/dbarts.py:
--------------------------------------------------------------------------------
1 | # bartz/tests/rbartpackages/dbarts.py
2 | #
3 | # Copyright (c) 2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Python wrapper of the R package `dbarts`."""
26 |
27 | # ruff: noqa: D101, D102
28 |
29 | from rpy2 import robjects
30 |
31 | from tests.rbartpackages._base import RObjectBase, rmethod
32 |
33 |
class bart(RObjectBase):
    """

    Python interface to dbarts::bart.

    The named numeric vector form of the `splitprobs` parameter must be
    specified as a dictionary in Python.

    """

    _rfuncname = 'dbarts::bart'

    # name of the keyword argument that accepts a dict; subclasses override it
    _split_probs = 'splitprobs'

    def __init__(self, *args, **kw):
        probs = kw.get(self._split_probs)
        if isinstance(probs, dict):
            # turn the dict into an R named numeric vector
            names = list(probs.keys())
            vector = robjects.FloatVector(list(probs.values()))
            kw[self._split_probs] = robjects.r('setNames')(vector, names)

        super().__init__(*args, **kw)

    @rmethod
    def predict(self, *args, **kw): ...

    @rmethod
    def extract(self, *args, **kw): ...

    @rmethod
    def fitted(self, *args, **kw): ...
66 |
67 |
class bart2(bart):
    """

    Python interface to dbarts::bart2.

    The named numeric vector form of the `split_probs` parameter must be
    specified as a dictionary in Python.

    """

    _rfuncname = 'dbarts::bart2'
    _split_probs = 'split_probs'

    def __init__(self, formula, *args, **kw):
        # wrap the first argument into an R formula object before forwarding
        super().__init__(robjects.Formula(formula), *args, **kw)
84 |
85 |
class rbart_vi(bart2):
    """

    Python interface to dbarts::rbart_vi.

    The named numeric vector form of the `split_probs` parameter must be
    specified as a dictionary in Python.

    """

    # only the R function name changes; the constructor and the
    # `_split_probs` keyword name are inherited from `bart2`
    _rfuncname = 'dbarts::rbart_vi'
97 |
98 |
class dbarts(RObjectBase):
    # name of the wrapped R function
    _rfuncname = 'dbarts::dbarts'

    # The methods below are bodiless stubs decorated with `rmethod`; they
    # presumably dispatch to the R method with the same name -- see
    # `tests.rbartpackages._base` to confirm.

    @rmethod
    def run(self, *args, **kw): ...

    @rmethod
    def sampleTreesFromPrior(self, *args, **kw): ...

    @rmethod
    def sampleNodeParametersFromPrior(self, *args, **kw): ...

    @rmethod
    def copy(self, *args, **kw): ...

    @rmethod
    def show(self, *args, **kw): ...

    @rmethod
    def predict(self, *args, **kw): ...

    @rmethod
    def setControl(self, *args, **kw): ...

    @rmethod
    def setModel(self, *args, **kw): ...

    @rmethod
    def setData(self, *args, **kw): ...

    @rmethod
    def setResponse(self, *args, **kw): ...

    @rmethod
    def setOffset(self, *args, **kw): ...

    @rmethod
    def setSigma(self, *args, **kw): ...

    @rmethod
    def setPredictor(self, *args, **kw): ...

    @rmethod
    def setTestPredictor(self, *args, **kw): ...

    @rmethod
    def setTestPredictorAndOffset(self, *args, **kw): ...

    @rmethod
    def setTestOffset(self, *args, **kw): ...

    @rmethod
    def printTrees(self, *args, **kw): ...

    @rmethod
    def plotTree(self, *args, **kw): ...
155 |
156 |
class dbartsControl(RObjectBase):
    # name of the wrapped R function; no Python-side methods are declared
    _rfuncname = 'dbarts::dbartsControl'
159 |
--------------------------------------------------------------------------------
/tests/test_meta.py:
--------------------------------------------------------------------------------
1 | # bartz/tests/test_meta.py
2 | #
3 | # Copyright (c) 2024-2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Test properties of pytest itself or other utilities."""
26 |
27 | from functools import partial
28 |
29 | import jax
30 | import pytest
31 | from jax import jit, random
32 | from jax import numpy as jnp
33 | from jax.errors import KeyReuseError
34 |
35 |
@pytest.fixture
def keys1(keys):
    """Pass-through the `keys` fixture."""
    # first of two aliases, used to check that both receive the same object
    return keys
40 |
41 |
@pytest.fixture
def keys2(keys):
    """Pass-through the `keys` fixture."""
    # second of two aliases, used to check that both receive the same object
    return keys
46 |
47 |
def test_random_keys_do_not_depend_on_fixture(keys1, keys2):
    """Check that the `keys` fixture is per-test-case, not per-fixture."""
    # both pass-through fixtures must hand over the very same object
    assert keys1 is keys2
51 |
52 |
def test_number_of_random_keys(keys):
    """Check the fixed number of available keys.

    This is here just as reference for the `test_random_keys_are_consumed` test
    below.
    """
    # 128 is the count from which that test subtracts the consumed keys
    assert len(keys) == 128
60 |
61 |
@pytest.fixture
def consume_one_key(keys):  # noqa: D103
    # remove one key from the shared `keys` object and hand it over
    return keys.pop()
65 |
66 |
@pytest.fixture
def consume_another_key(keys):  # noqa: D103
    # remove one more key from the shared `keys` object and hand it over
    return keys.pop()
70 |
71 |
def test_random_keys_are_consumed(consume_one_key, consume_another_key, keys):  # noqa: ARG001
    """Check that the random keys in `keys` can't be used more than once."""
    # the two fixtures above popped one key each, so 2 keys are gone
    assert len(keys) == 126
75 |
76 |
def test_debug_key_reuse(keys):
    """Check that using a random key twice is detected by jax."""
    used_key = keys.pop()

    # the first use is legitimate
    random.uniform(used_key)

    # the second use of the same key must be flagged
    with pytest.raises(KeyReuseError):
        random.uniform(used_key)
83 |
84 |
def test_debug_key_reuse_within_jit(keys):
    """Check that key reuse is detected also inside a jitted function."""

    @jit
    def draw_twice(key):
        # deliberately consume the same key in two separate calls
        return random.uniform(key) + random.uniform(key)

    fresh_key = keys.pop()
    with pytest.raises(KeyReuseError):
        draw_twice(fresh_key)
94 |
95 |
class TestJaxNoCopyBehavior:
    """Check whether jax makes actual copies of arrays in various conditions."""

    def test_unconditional_buffer_donation(self):
        """Test jax donates buffers even if they are small."""
        # nan-debug mode makes jax create some copies apparently
        with jax.debug_nans(False):
            # check buffer donation works unconditionally
            x = jnp.arange(100)
            xp = x.unsafe_buffer_pointer()

            @partial(jit, donate_argnums=(0,))
            def noop(x):
                return x

            y = noop(x)
            yp = y.unsafe_buffer_pointer()

            # same pointer: the output lives in the buffer donated by the input
            assert xp == yp
            # accessing the donated input must fail with a "delete" error
            with pytest.raises(RuntimeError, match=r'delete'):
                x[0]

    def test_jnp_array_copy_no_jit(self):
        """Test jnp.array makes copies outside jitted functions."""
        y = jnp.arange(100)
        yp = y.unsafe_buffer_pointer()

        z = jnp.array(y)
        zp = z.unsafe_buffer_pointer()

        # different pointers: `z` has its own buffer
        assert zp != yp

    def test_jnp_array_no_copy_jit(self):
        """Check jnp.array does not make copies within jit."""
        # nan-debug mode makes jax create some copies apparently
        with jax.debug_nans(False):
            y = jnp.arange(100)
            yp = y.unsafe_buffer_pointer()

            # the argument is donated, so the output may reuse its buffer
            @partial(jit, donate_argnums=(0,))
            def array(x):
                return jnp.array(x)

            q = array(y)
            qp = q.unsafe_buffer_pointer()

            # same pointer: no copy was made inside the jitted function
            assert qp == yp
143 |
--------------------------------------------------------------------------------
/tests/rbartpackages/BART.py:
--------------------------------------------------------------------------------
1 | # bartz/tests/rbartpackages/BART.py
2 | #
3 | # Copyright (c) 2024-2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Wrapper for the R package BART."""
26 |
27 | from typing import NamedTuple, TypedDict, cast
28 |
29 | import numpy as np
30 | from jaxtyping import AbstractDtype, Bool, Float64, Int32
31 | from numpy import ndarray
32 | from rpy2.rlike.container import NamedList
33 |
34 | from tests.rbartpackages._base import RObjectBase, rmethod
35 |
36 |
class TreeDraws(TypedDict):
    """Type of the `treedraws` attribute of `mc_gbart`."""

    # one array of cutpoints per entry, keyed by the entry's name, or by its
    # integer position when it has no name
    cutpoints: dict[int | str, Float64[ndarray, ' numcut[i]']]
    # the trees, as a single string
    trees: str
42 |
43 |
44 | class String(AbstractDtype):
45 | """Represent a `numpy.str_` data dtype."""
46 |
47 | dtypes = r' 0):
104 | self.rm_const -= 1
105 | else:
106 | msg = 'failed to parse rm.const because indices change sign'
107 | raise ValueError(msg)
108 |
109 | if self.sigma_mean is not None:
110 | self.sigma_mean = self.sigma_mean.item()
111 |
112 | r_treedraws = cast(NamedList, self.treedraws)
113 | cutpoints: NamedList = r_treedraws.getbyname('cutpoints')
114 | self.treedraws = {
115 | 'cutpoints': {
116 | i if it.name is None else it.name.item(): it.value
117 | for i, it in enumerate(cutpoints.items())
118 | },
119 | 'trees': r_treedraws.getbyname('trees').item(),
120 | }
121 |
    # bodiless stub: the implementation is presumably supplied by the
    # `rmethod` decorator -- see `tests.rbartpackages._base` to confirm
    @rmethod
    def predict(
        self, newdata: Float64[ndarray, 'm p'], *args, **kwargs
    ) -> Float64[ndarray, 'ndpost/mc_cores m']:
        """Compute predictions."""
        ...
128 |
129 |
class bartModelMatrix(RObjectBase):  # noqa: D101 because the R doc is added automatically
    # name of the wrapped R function
    _rfuncname = 'BART::bartModelMatrix'
132 |
133 |
class gbart(mc_gbart):  # noqa: D101 because the R doc is added automatically
    # same interface as `mc_gbart`, bound to a different R function
    _rfuncname = 'BART::gbart'

    # annotated as one value per MCMC iteration (nskip + ndpost); None by default
    sigma: Float64[ndarray, ' nskip+ndpost'] | None = None
138 |
--------------------------------------------------------------------------------
/tests/test_prepcovars.py:
--------------------------------------------------------------------------------
1 | # bartz/tests/test_prepcovars.py
2 | #
3 | # Copyright (c) 2024-2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Test the `bartz.prepcovars` module."""
26 |
27 | import pytest
28 | from jax import debug_infs
29 | from jax import numpy as jnp
30 | from numpy.testing import assert_array_equal
31 |
32 | from bartz.prepcovars import bin_predictors, quantilized_splits_from_matrix
33 |
34 |
class TestQuantilizer:
    """Tests of `prepcovars.quantilized_splits_from_matrix`."""

    @pytest.mark.parametrize(
        'fill_value', [jnp.finfo(jnp.float32).max, jnp.iinfo(jnp.int32).max]
    )
    def test_splits_fill(self, fill_value):
        """Check the right-padding of predictors with fewer unique values."""
        # infs are checked for only if the fill value is not itself an inf
        with debug_infs(not jnp.isinf(fill_value)):
            fill = jnp.array(fill_value)
            x = jnp.array([[1, 1, 3, 3], [1, 3, 3, 5], [1, 3, 5, 7]], fill.dtype)
            splits, _ = quantilized_splits_from_matrix(x, 100)
            expected = [[2, fill, fill], [2, 4, fill], [2, 4, 6]]
            assert_array_equal(splits, expected)

    def test_max_splits(self):
        """Check the count of splits reported for each predictor."""
        x = jnp.array([[1, 1, 1, 1], [4, 4, 1, 1], [2, 1, 3, 2], [1, 4, 2, 3]])
        _, max_split = quantilized_splits_from_matrix(x, 100)
        # rows have 1, 2, 3 and 4 unique values, hence 0, 1, 2 and 3 splits
        assert_array_equal(max_split, jnp.arange(4))

    def test_integer_splits_overflow(self):
        """Check the splits are right when the midpoint sum would overflow."""
        x = jnp.array([[-(2**31), 2**31 - 2]])
        splits, _ = quantilized_splits_from_matrix(x, 100)
        assert_array_equal(splits, [[-1]])

    @pytest.mark.parametrize('dtype', [int, float])
    def test_splits_type(self, dtype):
        """Check the splits have the same dtype as the input."""
        x = jnp.arange(10, dtype=dtype)[None, :]
        splits, _ = quantilized_splits_from_matrix(x, 100)
        assert splits.dtype == x.dtype

    def test_splits_length(self):
        """Check the number of splits returned in corner cases."""
        x = jnp.linspace(0, 1, 10)[None, :]

        # (second argument, expected shape of the splits), in order:
        # short, long, just right, no splits
        cases = [(2, (1, 1)), (100, (1, 9)), (10, (1, 9)), (1, (1, 0))]
        for max_bins, expected_shape in cases:
            splits, _ = quantilized_splits_from_matrix(x, max_bins)
            assert splits.shape == expected_shape

    def test_round_trip(self):
        """Check `bin_predictors` ~inverts `quantilized_splits_from_matrix`."""
        x = jnp.arange(10)[None, :]
        splits, _ = quantilized_splits_from_matrix(x, 100)
        binned = bin_predictors(x, splits)
        assert_array_equal(x, binned)

    def test_one_value(self):
        """Check a single datapoint yields 1 bin, i.e., 0 splits."""
        x = jnp.arange(10)[:, None]
        _, max_split = quantilized_splits_from_matrix(x, 100)
        expected = jnp.full(len(x), 0)
        assert_array_equal(max_split, expected)

    def test_zero_values(self):
        """Check an input without datapoints is rejected."""
        x = jnp.empty((1, 0))
        with pytest.raises(ValueError, match='at least 1'):
            quantilized_splits_from_matrix(x, 100)

    def test_zero_bins(self):
        """Check a request for zero bins is rejected."""
        x = jnp.arange(10)[None, :]
        with pytest.raises(ValueError, match='at least 1'):
            quantilized_splits_from_matrix(x, 0)
110 |
111 |
def test_binner_left_boundary():
    """Check a value equal to the first split falls in the first bin."""
    splits = jnp.array([[1, 2, 3]])
    # 0 is below the first split, 1 sits exactly on it
    values = jnp.array([[0, 1]])

    bins = bin_predictors(values, splits)

    assert_array_equal(bins, [[0, 0]])
119 |
120 |
def test_binner_right_boundary():
    """Check a value equal to the last split falls in the next-to-last bin."""
    largest = 2**31 - 1
    splits = jnp.array([[1, 2, 3, largest]])
    values = jnp.array([[largest]])

    bins = bin_predictors(values, splits)

    assert_array_equal(bins, [[3]])
128 |
--------------------------------------------------------------------------------
/docs/development.rst:
--------------------------------------------------------------------------------
1 | .. bartz/docs/development.rst
2 | ..
3 | .. Copyright (c) 2024-2025, The Bartz Contributors
4 | ..
5 | .. This file is part of bartz.
6 | ..
7 | .. Permission is hereby granted, free of charge, to any person obtaining a copy
8 | .. of this software and associated documentation files (the "Software"), to deal
9 | .. in the Software without restriction, including without limitation the rights
10 | .. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | .. copies of the Software, and to permit persons to whom the Software is
12 | .. furnished to do so, subject to the following conditions:
13 | ..
14 | .. The above copyright notice and this permission notice shall be included in all
15 | .. copies or substantial portions of the Software.
16 | ..
17 | .. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | .. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | .. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | .. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | .. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | .. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | .. SOFTWARE.
24 |
25 | Development
26 | ===========
27 |
28 | Initial setup
29 | -------------
30 |
`Fork <https://github.com/bartz-org/bartz/fork>`_ the repository on GitHub, then clone the fork:
32 |
33 | .. code-block:: shell
34 |
35 | git clone git@github.com:YourGithubUserName/bartz.git
36 | cd bartz
37 |
Install `R <https://www.r-project.org>`_ and `uv <https://docs.astral.sh/uv/>`_ (for example, with `Homebrew <https://brew.sh>`_ do :literal:`brew install r uv`). Then run
39 |
40 | .. code-block:: shell
41 |
42 | make setup
43 |
44 | to set up the Python and R environments.
45 |
The Python environment is managed by uv. To run commands that involve the Python installation, do :literal:`uv run <command>`. For example, to start an IPython shell, do :literal:`uv run ipython`. Alternatively, do :literal:`source .venv/bin/activate` to activate the virtual environment in the current shell.
47 |
48 | The R environment is automatically active when you use :literal:`R` in the project directory.
49 |
50 | Pre-defined commands
51 | --------------------
52 |
53 | Development commands are defined in a makefile. Run :literal:`make` without arguments to list the targets.
54 |
55 | Documentation
56 | -------------
57 |
58 | To build the documentation for the current working copy, do
59 |
60 | .. code-block:: shell
61 |
62 | make docs
63 |
64 | To build the documentation for the latest release tag, do
65 |
66 | .. code-block:: shell
67 |
68 | make docs-latest
69 |
70 | To debug the documentation build, do
71 |
72 | .. code-block:: shell
73 |
74 | make docs SPHINXOPTS='--fresh-env --pdb'
75 |
76 | Benchmarks
77 | ----------
78 |
The benchmarks are managed with `asv <https://asv.readthedocs.io>`_. The basic asv workflow is:
80 |
81 | .. code-block:: shell
82 |
83 | uv run asv run # run and save benchmarks on main branch
84 | uv run asv publish # create html report
85 | uv run asv preview # start a local server to view the report
86 |
:literal:`asv run` writes the results into files saved in :literal:`./.asv/results`. These files are tracked by git; consider deliberately not committing all results generated while developing.
88 |
89 | There are a few make targets for common asv commands. The most useful command during development is
90 |
91 | .. code-block:: shell
92 |
   make asv-quick ARGS='--bench <pattern>'
94 |
This runs only the benchmarks whose name matches :literal:`<pattern>`, only once, within the working copy and the current Python environment.
96 |
97 | Profiling
98 | ---------
99 |
Use the `JAX profiling utilities <https://docs.jax.dev/en/latest/profiling.html>`_ to profile `bartz`. By default the MCMC loop is compiled all at once, which makes it quite opaque to profiling. There are two ways to understand what's going on inside in more detail: 1) inspect the individual operations and use intuition to understand which piece of code they correspond to, 2) turn on bartz's profile mode. Basic workflow:
101 |
102 | .. code-block:: python
103 |
104 | from jax.profiler import trace, ProfileOptions
105 | from bartz.BART import gbart
106 | from bartz import profile_mode
107 |
108 | traceopt = ProfileOptions()
109 |
110 | # this setting makes Python function calls show up in the trace
111 | traceopt.python_tracer_level = 1
112 |
113 | # on cpu, this makes the trace detailed enough to understand what's going on
114 | # even within compiled functions
115 | traceopt.host_tracer_level = 2
116 |
117 | with trace('./trace_results', profiler_options=traceopt), profile_mode(True):
118 | bart = gbart(...)
119 |
On the first run, the trace will show compilation operations, while subsequent runs (within the same Python shell) will be warmed-up. Start an xprof server to visualize the results:
121 |
122 | .. code-block:: shell
123 |
124 | $ uvx --python 3.13 xprof ./trace_results
125 | [...]
126 | XProf at http://localhost:8791/ (Press CTRL+C to quit)
127 |
128 | Open the provided URL in a browser. In the sidebar, select the tool "Trace Viewer".
129 |
In "profile mode", the MCMC loop is split into a few chunks that are compiled separately, which makes it possible to see at a glance how much time each phase of the MCMC cycle takes. This causes some overhead, so the timings are not equivalent to the normal mode ones. In one specific example on CPU, bartz was 20% slower in profile mode with one chain, and 2x slower with multiple chains.
131 |
--------------------------------------------------------------------------------
/tests/rbartpackages/BART3.py:
--------------------------------------------------------------------------------
1 | # bartz/tests/rbartpackages/BART3.py
2 | #
3 | # Copyright (c) 2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Wrapper for the R package BART3."""
26 |
27 | from typing import NamedTuple, TypedDict
28 |
29 | import numpy as np
30 | from jaxtyping import AbstractDtype, Float64, Int32
31 | from numpy import ndarray
32 |
33 | from tests.rbartpackages._base import RObjectBase, rmethod
34 |
35 |
class TreeDraws(TypedDict):
    """Type of the `treedraws` attribute of `mc_gbart`."""

    # one array of cutpoints per entry, keyed by the entry's name, or by its
    # integer position when it has no name
    cutpoints: dict[int | str, Float64[ndarray, ' numcut[i]']]
    # the trees, as a single string
    trees: str
41 |
42 |
43 | class String(AbstractDtype):
44 | """Represent a `numpy.str_` data dtype."""
45 |
46 | dtypes = r' 0):
111 | self.rm_const -= 1
112 | else:
113 | msg = 'failed to parse rm.const because indices change sign'
114 | raise ValueError(msg)
115 |
116 | if self.sigest is not None:
117 | self.sigest = self.sigest.item()
118 | if self.sigma_mean is not None:
119 | self.sigma_mean = self.sigma_mean.item()
120 |
121 | if hasattr(self.treedraws, 'getbyname'):
122 | # it's a NamedList
123 | self.treedraws = {
124 | 'cutpoints': {
125 | i if it.name is None else it.name.item(): it.value
126 | for i, it in enumerate(
127 | self.treedraws.getbyname('cutpoints').items()
128 | )
129 | },
130 | 'trees': self.treedraws.getbyname('trees').item(),
131 | }
132 | else:
133 | # it's an OrdDict
134 | self.treedraws = {
135 | 'cutpoints': {
136 | i if k is None else k.item(): v
137 | for i, (k, v) in enumerate(self.treedraws['cutpoints'].items())
138 | },
139 | 'trees': self.treedraws['trees'].item(),
140 | }
141 |
    # bodiless stub: the implementation is presumably supplied by the
    # `rmethod` decorator -- see `tests.rbartpackages._base` to confirm
    @rmethod
    def predict(
        self, newdata: Float64[ndarray, 'm p'], *args, **kwargs
    ) -> Float64[ndarray, 'ndpost m']:
        """Compute predictions."""
        ...
148 |
149 |
class bartModelMatrix(RObjectBase):  # noqa: D101 because the R doc is added automatically
    # name of the wrapped R function
    _rfuncname = 'BART3::bartModelMatrix'
152 |
153 |
class gbart(mc_gbart):  # noqa: D101 because the R doc is added automatically
    # same interface as `mc_gbart`, bound to a different R function
    _rfuncname = 'BART3::gbart'

    # annotated as one value per MCMC iteration (nskip + ndpost); None by default
    sigma: Float64[ndarray, ' nskip+ndpost'] | None = None
158 |
--------------------------------------------------------------------------------
/tests/test_debug.py:
--------------------------------------------------------------------------------
1 | # bartz/tests/test_debug.py
2 | #
3 | # Copyright (c) 2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Test `bartz.debug`."""
26 |
27 | from collections import namedtuple
28 |
29 | import pytest
30 | from equinox import tree_at
31 | from jax import numpy as jnp
32 | from jax import random
33 | from scipy import stats
34 | from scipy.stats import ks_1samp
35 |
36 | from bartz.debug import check_trace, format_tree, sample_prior
37 | from bartz.jaxext import minimal_unsigned_dtype
38 | from bartz.mcmcloop import TreesTrace
39 |
40 |
def manual_tree(
    leaf: list[list[float]], var: list[list[int]], split: list[list[int]]
) -> TreesTrace:
    """Facilitate the hardcoded definition of tree heaps.

    Parameters
    ----------
    leaf
        The leaf values, one list per tree level starting from the root; the
        list at level ``i`` must have ``2**i`` elements.
    var
        The decision variables, in the same layout as `leaf` but with one
        level less.
    split
        The decision cutpoints, in the same layout as `var`.

    Returns
    -------
    A `TreesTrace` whose arrays are the concatenation of the levels, preceded
    by a zero placeholder at index 0.

    Raises
    ------
    AssertionError
        If the number of levels or the level lengths are inconsistent, or if
        the resulting arrays do not have the expected dtypes.
    """
    assert len(leaf) == len(var) + 1 == len(split) + 1

    def check_powers_of_2(seq: list[list]):
        """Check if the lengths of the lists in `seq` are powers of 2."""
        return all(len(x) == 2**i for i, x in enumerate(seq))

    # the helper only returns a boolean: assert it, otherwise the check is a
    # no-op and malformed heaps go unnoticed
    assert check_powers_of_2(leaf)
    assert check_powers_of_2(var)
    assert check_powers_of_2(split)

    tree = TreesTrace(
        jnp.concatenate([jnp.zeros(1), *map(jnp.array, leaf)]),
        jnp.concatenate([jnp.zeros(1, int), *map(jnp.array, var)]),
        jnp.concatenate([jnp.zeros(1, int), *map(jnp.array, split)]),
    )
    assert tree.leaf_tree.dtype == jnp.float32
    assert tree.var_tree.dtype == jnp.int32
    assert tree.split_tree.dtype == jnp.int32
    return tree
64 |
65 |
def test_format_tree():
    """Check the output of `format_tree` on a single example."""
    # levels are listed from the root down; `var` and `split` have one level
    # less than `leaf`
    tree = manual_tree(
        [[1.0], [2.0, 3.0], [4.0, 5.0, 6.0, 7.0]], [[4], [1, 2]], [[15], [0, 3]]
    )
    s = format_tree(tree)
    # printed to ease the comparison with `ref_s` when the assertion fails
    print(s)
    ref_s = """\
1 ┐x4 < 15
2 ├── 2.0
3 └──┐x2 < 3
6 ├──╢6.0
7 └──╢7.0"""
    assert s == ref_s
80 |
81 |
class TestSamplePrior:
    """Tests for `debug.sample_prior`."""

    Args = namedtuple(
        'Args',
        ['key', 'trace_length', 'num_trees', 'max_split', 'p_nonterminal', 'sigma_mu'],
    )

    @pytest.fixture
    def args(self, keys):
        """Build the positional arguments passed to `sample_prior`."""
        # configuration
        trace_length = 1000
        num_trees = 200
        maxdepth = 6
        alpha = 0.95
        beta = 2
        num_splits = 5
        num_predictors = maxdepth - 1

        # probability of being nonterminal at each depth
        depths = jnp.arange(maxdepth - 1)
        p_nonterminal = alpha / (1 + depths).astype(float) ** beta

        # constant vector with one entry per predictor, in the smallest
        # unsigned dtype that fits the value
        split_dtype = minimal_unsigned_dtype(num_splits)
        max_split = jnp.full(num_predictors, jnp.array(num_splits, split_dtype))

        sigma_mu = 1 / jnp.sqrt(num_trees)

        return self.Args(
            key=keys.pop(),
            trace_length=trace_length,
            num_trees=num_trees,
            max_split=max_split,
            p_nonterminal=p_nonterminal,
            sigma_mu=sigma_mu,
        )

    def test_valid_trees(self, args: Args):
        """Check that the sampled trees have the right shapes and are valid."""
        trees = sample_prior(*args)

        heap_size = 2 ** (args.p_nonterminal.size + 1)
        leading = (args.trace_length, args.num_trees)
        assert trees.leaf_tree.shape == (*leading, heap_size)
        assert trees.var_tree.shape == (*leading, heap_size // 2)
        assert trees.split_tree.shape == (*leading, heap_size // 2)

        invalid = check_trace(trees, args.max_split)
        assert jnp.count_nonzero(invalid).item() == 0

    def test_max_depth(self, keys, args: Args):
        """Check that trees stop growing when p_nonterminal = 0."""
        num_levels = args.p_nonterminal.size
        for max_depth in range(num_levels + 1):
            # force growth down to `max_depth`, forbid it below
            forced = jnp.zeros_like(args.p_nonterminal).at[:max_depth].set(1.0)
            args = tree_at(lambda a: a.p_nonterminal, args, forced)
            args = tree_at(lambda a: a.key, args, keys.pop())

            trees = sample_prior(*args)
            first_empty = 2**max_depth
            assert jnp.all(trees.split_tree[:, :, 1:first_empty])
            assert not jnp.any(trees.split_tree[:, :, first_empty:])

    def test_forest_sdev(self, keys, args: Args):
        """Check that the sum of trees is standard Normal."""
        trees = sample_prior(*args)
        num_draws, num_trees, heap_size = trees.leaf_tree.shape

        # pick one random heap position in each tree of each draw
        picked = random.randint(keys.pop(), (num_draws, num_trees), 0, heap_size)
        draw_index, tree_index = jnp.ogrid[:num_draws, :num_trees]
        picked_leaves = trees.leaf_tree[draw_index, tree_index, picked]
        forest_values = jnp.sum(picked_leaves, axis=1)

        result = ks_1samp(forest_values, stats.norm.cdf)
        assert result.pvalue > 0.1

    def test_trees_differ(self, args: Args):
        """Check that trees are different across iterations."""
        trees = sample_prior(*args)
        for name in ('leaf_tree', 'var_tree', 'split_tree'):
            heap = getattr(trees, name)
            # differences along the trace axis, then along the forest axis
            assert jnp.any(jnp.diff(heap, axis=0))
            assert jnp.any(jnp.diff(heap, axis=1))
159 |
--------------------------------------------------------------------------------
/tests/rbartpackages/_base.py:
--------------------------------------------------------------------------------
1 | # bartz/tests/rbartpackages/_base.py
2 | #
3 | # Copyright (c) 2024-2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | from collections.abc import Callable
26 | from functools import wraps
27 | from re import fullmatch, match
28 |
29 | import numpy as np
30 | from rpy2 import robjects
31 | from rpy2.robjects import BoolVector, conversion, numpy2ri
32 | from rpy2.robjects.help import Package
33 | from rpy2.robjects.methods import RS4
34 |
# converter for pandas
# start with an empty converter, replaced below if pandas2ri can be imported
pandas_converter = conversion.Converter('pandas')
try:
    from rpy2.robjects import pandas2ri
except ImportError:
    pass
else:
    pandas_converter = pandas2ri.converter

# converter for polars
# left empty unless both polars and pandas2ri can be imported
polars_converter = conversion.Converter('polars')
try:
    import polars
    from rpy2.robjects import pandas2ri
except ImportError:
    pass
else:

    def polars_to_r(df):
        # go through pandas, then use the pandas -> R conversion
        df = df.to_pandas()
        return pandas2ri.py2rpy(df)

    polars_converter.py2rpy.register(polars.DataFrame, polars_to_r)
    polars_converter.py2rpy.register(polars.Series, polars_to_r)

# converter for jax
# left empty if jax can not be imported
jax_converter = conversion.Converter('jax')
try:
    import jax
except ImportError:
    pass
else:

    def jax_to_r(x):
        # go through numpy, then use the numpy -> R conversion
        x = np.asarray(x)
        if x.ndim == 0:
            # indexing a 0-d array with () yields a numpy scalar
            x = x[()]
        return numpy2ri.py2rpy(x)

    jax_converter.py2rpy.register(jax.Array, jax_to_r)

# converter for numpy
numpy_converter = numpy2ri.converter


# converter for BoolVector (why isn't it in the numpy converter?)
def bool_vector_to_python(x):
    return np.array(x, bool)


bool_vector_converter = conversion.Converter('bool_vector')
bool_vector_converter.rpy2py.register(BoolVector, bool_vector_to_python)


# converter for python dictionaries
dict_converter = conversion.Converter('dict')


def dict_to_r(x):
    return robjects.ListVector(x)


dict_converter.py2rpy.register(dict, dict_to_r)

# regular expression for the R names accepted in `_rfuncname` and `rmethod`:
# a letter, or a dot not followed by a digit, then letters, digits, dots and
# underscores
R_IDENTIFIER = r'(?:[a-zA-Z]|\.(?![0-9]))[a-zA-Z0-9._]*'
100 |
101 |
class RObjectBase:
    """
    Base class for Python wrappers of R objects creators.

    Subclasses should define the class attribute `_rfuncname`, and declare
    stub methods decorated with `rmethod`.

    _rfuncname : str
        An R function in the format ``'<library>::<function>'``. The function
        is called with the initialization arguments, converted to R objects,
        and is expected to return an R object. The attributes of the R object
        are converted to equivalent Python values and set as attributes of the
        Python object. The R object itself is assigned to the member
        `_robject`.
    """

    # conversion rules used for all arguments and results
    _converter = (
        robjects.default_converter
        + pandas_converter
        + polars_converter
        + numpy_converter
        + bool_vector_converter
        + jax_converter
        + dict_converter
    )
    _convctx = conversion.localconverter(_converter)

    def _py2r(self, x):
        """Convert a Python value to R; wrappers are replaced by their R object."""
        if isinstance(x, __class__):
            return x._robject
        with self._convctx:
            return self._converter.py2rpy(x)

    def _r2py(self, x):
        """Convert an R value to Python."""
        with self._convctx:
            return self._converter.rpy2py(x)

    def _args2r(self, args):
        """Convert positional arguments to R."""
        return tuple(map(self._py2r, args))

    def _kw2r(self, kw):
        """Convert the values of keyword arguments to R."""
        return {key: self._py2r(value) for key, value in kw.items()}

    _rfuncname: str = NotImplemented

    @property
    def _library(self) -> str:
        """Parse `_rfuncname` to get the library. Also checks `_rfuncname` is valid."""
        # fullmatch, like in `rmethod`: `match` with a trailing `$` would
        # accept a name followed by a newline
        pattern = rf'({R_IDENTIFIER})::({R_IDENTIFIER})'
        m = fullmatch(pattern, self._rfuncname)
        if m is None:
            msg = f'Invalid _rfuncname: {self._rfuncname}.'
            raise ValueError(msg)
        return m.group(1)

    def __init__(self, *args, **kw):
        """Call the R function and mirror the elements of its result as attributes."""
        robjects.r(f'loadNamespace("{self._library}")')
        func = robjects.r(self._rfuncname)
        obj = func(*self._args2r(args), **self._kw2r(kw))
        self._robject = obj
        if hasattr(obj, 'items'):
            for s, v in obj.items():
                # dots are not valid in Python attribute names
                setattr(self, s.replace('.', '_'), self._r2py(v))

    def __init_subclass__(cls, **kw):
        """Automatically add R documentation to subclasses."""
        # keep the hook cooperative with other base classes
        super().__init_subclass__(**kw)
        library, name = cls._rfuncname.split('::')
        page = Package(library).fetch(name)
        if cls.__doc__ is None:
            cls.__doc__ = ''
        cls.__doc__ += 'R documentation:\n' + page.to_docstring()
172 |
173 |
174 | def rmethod(meth: Callable, *, rname: str | None = None) -> Callable:
175 | """Automatically implement a method using the correspoding R method.
176 |
177 | Parameters
178 | ----------
179 | meth
180 | A method in a subclass of `RObjectBase`.
181 | rname
182 | The name of the method in R. If not specified, use the name of `meth`.
183 |
184 | Returns
185 | -------
186 | methimpl
187 | An implementation of the method that calls the R method. The original
188 | implementation of meth is completely discarded.
189 |
190 | Examples
191 | --------
192 | >>> class MyRObject(RObjectBase):
193 | ... _rfuncname = 'mypackage::myfunction'
194 | ... @partial(rmethod, rname='my.method')
195 | ... def my_method(self, arg1: int, arg2: str):
196 | ... ...
197 | """
198 | if rname is None:
199 | rname = meth.__name__
200 |
201 | # I can't automatically add a docstring to the method because the R class
202 | # can be determined at runtime
203 |
204 | @wraps(meth)
205 | def impl(self, *args, **kw):
206 | if isinstance(self._robject, RS4):
207 | func = robjects.r['$'](self._robject, rname)
208 | out = func(*self._args2r(args), **self._kw2r(kw))
209 |
210 | else:
211 | if not fullmatch(R_IDENTIFIER, rname):
212 | msg = f'Invalid R method name: {rname}'
213 | raise ValueError(msg)
214 | rclass = self._robject.rclass[0]
215 | func = robjects.r(
216 | f'getS3method("{rname}", "{rclass}", envir = asNamespace("{self._library}"))'
217 | )
218 | out = func(self._robject, *self._args2r(args), **self._kw2r(kw))
219 |
220 | return self._r2py(out)
221 |
222 | return impl
223 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | # bartz/docs/conf.py
2 | #
3 | # Copyright (c) 2024-2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | # Configuration file for the Sphinx documentation builder.
26 | #
27 | # This file only contains a selection of the most common options. For a full
28 | # list see the documentation:
29 | # https://www.sphinx-doc.org/en/master/usage/configuration.html
30 |
31 | import datetime
32 | import inspect
33 | import os
34 | import pathlib
35 | import sys
36 | from functools import cached_property
37 |
38 | import git
39 | from packaging import version as pkgversion
40 |
41 | # -- Doc variant -------------------------------------------------------------
42 |
# repository containing this file, looked up in the parent directories
repo = git.Repo(search_parent_directories=True)

# which documentation to build: 'dev' (the current checkout, default) or
# 'latest' (the latest final release found in the git tags)
variant = os.environ.get('BARTZ_DOC_VARIANT', 'dev')

if variant == 'dev':
    # the version is the abbreviated commit hash, followed by '+' if the
    # repository is dirty
    commit = repo.head.commit.hexsha
    uncommitted_stuff = repo.is_dirty()
    version = f'{commit[:7]}{"+" if uncommitted_stuff else ""}'

elif variant == 'latest':
    # list git tags
    tags = [t.name for t in repo.tags]
    print(f'git tags: {tags}')

    # find final versions in tags
    versions = []
    for t in tags:
        try:
            v = pkgversion.parse(t)
        except pkgversion.InvalidVersion:
            # not a version tag
            continue
        if v.is_prerelease or v.is_devrelease:
            continue
        versions.append((v, t))
    print(f'tags for releases: {versions}')

    # find latest versions
    versions.sort(key=lambda x: x[0])
    version, tag = versions[-1]

    # check it out and check it matches the version in the package
    # (the import must come after the checkout)
    repo.git.checkout(tag)
    import bartz

    assert pkgversion.parse(bartz.__version__) == version

    version = str(version)
    uncommitted_stuff = False

else:
    # unknown value of BARTZ_DOC_VARIANT
    raise KeyError(variant)

import bartz
86 |
# -- Project information -----------------------------------------------------

project = f'bartz {version}'
author = 'The Bartz Contributors'

# copyright years: '2024' alone, or '2024-<current year>' in later years
now = datetime.datetime.now(tz=datetime.timezone.utc)
year = '2024'
if now.year > int(year):
    year += '-' + str(now.year)
copyright = year + ', ' + author  # noqa: A001, because sphinx uses this variable

release = version
99 |
# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
    'sphinx.ext.napoleon',
    'sphinx.ext.autodoc',
    'sphinx_autodoc_typehints',  # (!) keep after napoleon
    'sphinx.ext.mathjax',
    'sphinx.ext.intersphinx',  # link to other documentations automatically
    'myst_parser',  # markdown support
]

# decide whether to use viewcode or linkcode extension
ext = 'viewcode'  # copy source code in static website
if not uncommitted_stuff:
    commit = repo.head.commit.hexsha
    # remote branches that contain the commit; a non-empty list is taken to
    # mean that the commit is available on github
    branches = repo.git.branch('--remotes', '--contains', commit)
    commit_on_github = bool(branches.strip())
    if commit_on_github:
        ext = 'linkcode'  # links to code on github
extensions.append(f'sphinx.ext.{ext}')

myst_enable_extensions = [
    # "amsmath",
    'dollarmath'
]

# Add any paths that contain templates here, relative to this directory.
# templates_path = ['_templates']  # noqa: ERA001

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
136 |
137 |
# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'alabaster'

html_title = f'{project} documentation'

# options passed to the theme selected above
html_theme_options = dict(
    description='Super-fast BART (Bayesian Additive Regression Trees) in Python',
    fixed_sidebar=True,
    github_button=True,
    github_type='star',
    github_repo='bartz',
    github_user='Gattocrucco',
    show_relbars=True,
)

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']

master_doc = 'index'

# -- Other options -------------------------------------------------

# role applied to text in single backticks
default_role = 'py:obj'

# autodoc
autoclass_content = 'both'  # concatenate the class and __init__ docstrings
# default arguments are printed as in source instead of being evaluated
autodoc_preserve_defaults = True
autodoc_default_options = {'member-order': 'bysource'}

# autodoc-typehints
typehints_use_rtype = False
typehints_document_rtype = True
always_use_bars_union = True
typehints_defaults = 'comma'

# napoleon
napoleon_google_docstring = False
napoleon_use_ivar = True
napoleon_use_rtype = False

# intersphinx
intersphinx_mapping = dict(
    scipy=('https://docs.scipy.org/doc/scipy', None),
    numpy=('https://numpy.org/doc/stable', None),
    jax=('https://docs.jax.dev/en/latest', None),
)

# viewcode
viewcode_line_numbers = True
194 |
195 |
def linkcode_resolve(domain, info):
    """
    Return the GitHub URL of the source of a Python object, for extension linkcode.

    The URL points to the file under ``src/bartz`` at the commit stored in the
    module-level variable `commit`, with a fragment selecting the lines of the
    object.

    Adapted from scipy/doc/release/conf.py.
    """
    assert domain == 'py'

    modname = info['module']
    assert modname.startswith('bartz')
    fullname = info['fullname']

    module = sys.modules.get(modname)
    assert module

    # follow the dotted name starting from the module
    target = module
    for attr in fullname.split('.'):
        target = getattr(target, attr)

    # reach the underlying function
    if isinstance(target, cached_property):
        target = target.func
    target = inspect.unwrap(target)

    source_file = inspect.getsourcefile(target)
    assert source_file

    source_lines, first_line = inspect.getsourcelines(target)
    assert first_line
    last_line = first_line + len(source_lines) - 1

    # path of the source file relative to the package directory
    package_root = pathlib.Path(bartz.__file__).parent
    relative_path = pathlib.Path(source_file).relative_to(package_root).as_posix()

    return (
        'https://github.com/Gattocrucco/bartz/blob'
        f'/{commit}/src/bartz/{relative_path}#L{first_line}-L{last_line}'
    )
230 |
--------------------------------------------------------------------------------
/src/bartz/jaxext/_autobatch.py:
--------------------------------------------------------------------------------
1 | # bartz/src/bartz/jaxext/_autobatch.py
2 | #
3 | # Copyright (c) 2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Implementation of `autobatch`."""
26 |
27 | import math
28 | from collections.abc import Callable
29 | from functools import wraps
30 | from warnings import warn
31 |
32 | from jax import eval_shape, jit
33 | from jax import numpy as jnp
34 | from jax.lax import scan
35 | from jax.tree import flatten as tree_flatten
36 | from jax.tree import map as tree_map
37 | from jax.tree import reduce as tree_reduce
38 | from jaxtyping import PyTree
39 |
40 |
def expand_axes(axes, tree):
    """Expand `axes` such that they match the pytreedef of `tree`.

    Each leaf of `axes` (`None` counts as a leaf) is repeated over all the
    leaves of the corresponding subtree of `tree`.
    """

    def is_none(node):
        return node is None

    def repeat_over(axis, subtree):
        return tree_map(lambda _leaf: axis, subtree)

    return tree_map(repeat_over, axes, tree, is_leaf=is_none)
48 |
49 |
def check_no_nones(axes, tree):
    """Assert that the axis paired with each leaf of `tree` is not `None`."""

    def reject_none(_leaf, axis):
        assert axis is not None

    tree_map(reject_none, tree, axes)
55 |
56 |
def extract_size(axes, tree):
    """Return the common length of the leaves of `tree` along their `axes`.

    Leaves paired with a `None` axis are ignored. All the other leaves must
    have the same length along their axis.
    """

    def axis_length(x, axis):
        return None if axis is None else x.shape[axis]

    # flattening drops the None entries of the non-batched leaves
    sizes, _ = tree_flatten(tree_map(axis_length, tree, axes))
    first = sizes[0]
    assert all(size == first for size in sizes)
    return first
68 |
69 |
def sum_nbytes(tree):
    """Return the total size in bytes of the leaves of `tree`.

    The size of a leaf is computed from its `shape` and `dtype`.
    """
    leaves, _ = tree_flatten(tree)
    return sum(math.prod(leaf.shape) * leaf.dtype.itemsize for leaf in leaves)
75 |
76 |
def next_divisor_small(dividend, min_divisor):
    """Return the smallest divisor of `dividend` not smaller than `min_divisor`.

    Meant for the case ``min_divisor**2 <= dividend``, where the search
    usually ends quickly among the small candidates.

    Parameters
    ----------
    dividend
        The integer to divide.
    min_divisor
        The minimum acceptable divisor, a positive integer.

    Returns
    -------
    The smallest divisor of `dividend` that is at least `min_divisor`, or
    `dividend` itself if there is none.
    """
    # integer square root, exact for any size of `dividend`
    root = math.isqrt(dividend)

    for divisor in range(min_divisor, root + 1):
        if dividend % divisor == 0:
            return divisor

    # No divisor up to the square root: the remaining divisors are paired
    # with cofactors not larger than the square root, and the largest
    # cofactor yields the smallest divisor. The cap keeps the divisor from
    # falling below `min_divisor`.
    max_cofactor = min(root, dividend // min_divisor)
    for cofactor in range(max_cofactor, 0, -1):
        if dividend % cofactor == 0:
            return dividend // cofactor

    return dividend
82 |
83 |
def next_divisor_large(dividend, min_divisor):
    """Return the smallest divisor of `dividend` not smaller than `min_divisor`.

    The search runs over the cofactors, from the largest one compatible with
    `min_divisor` down to 1. If there are no candidates, `dividend` is
    returned.
    """
    cofactor = dividend // min_divisor
    while cofactor > 0:
        if dividend % cofactor == 0:
            return dividend // cofactor
        cofactor -= 1
    return dividend
90 |
91 |
def next_divisor(dividend, min_divisor):
    """Return a divisor of `dividend` not smaller than `min_divisor`.

    If `dividend` is zero, `min_divisor` is returned. Otherwise the search is
    delegated to `next_divisor_small` or `next_divisor_large`, depending on
    whether `min_divisor` squared exceeds `dividend`.
    """
    if dividend == 0:
        return min_divisor
    if min_divisor * min_divisor > dividend:
        return next_divisor_large(dividend, min_divisor)
    return next_divisor_small(dividend, min_divisor)
98 |
99 |
def pull_nonbatched(axes, tree):
    """Replace with `None` the leaves of `tree` whose axis is `None`.

    Returns
    -------
    The masked tree, and the original `tree` unchanged.
    """

    def mask(x, axis):
        return None if axis is None else x

    batched_only = tree_map(mask, tree, axes)
    return batched_only, tree
108 |
109 |
def push_nonbatched(axes, tree, original_tree):
    """Fill the leaves of `tree` whose axis is `None` from `original_tree`.

    The other leaves are taken from `tree`.
    """

    def choose(original_x, x, axis):
        return original_x if axis is None else x

    return tree_map(choose, original_tree, tree, axes)
118 |
119 |
def move_axes_out(axes, tree):
    """Move the axis given by `axes` to the front of each leaf of `tree`."""
    return tree_map(lambda leaf, axis: jnp.moveaxis(leaf, axis, 0), tree, axes)
125 |
126 |
def move_axes_in(axes, tree):
    """Move the first axis of each leaf of `tree` to the position given by `axes`."""
    return tree_map(lambda leaf, axis: jnp.moveaxis(leaf, 0, axis), tree, axes)
132 |
133 |
def batch(tree, nbatches):
    """Split the first axis of each leaf of `tree` into `nbatches` batches.

    A leaf with shape ``(n, ...)`` is reshaped to
    ``(nbatches, n // nbatches, ...)``.
    """

    def split_first_axis(x):
        batch_size = x.shape[0] // nbatches
        return x.reshape((nbatches, batch_size) + tuple(x.shape[1:]))

    return tree_map(split_first_axis, tree)
139 |
140 |
def unbatch(tree):
    """Merge the first two axes of each leaf of `tree` into one."""

    def merge_first_axes(x):
        merged = x.shape[0] * x.shape[1]
        return x.reshape((merged,) + tuple(x.shape[2:]))

    return tree_map(merge_first_axes, tree)
146 |
147 |
def check_same(tree1, tree2):
    """Assert that corresponding leaves of the two trees share shape and dtype."""

    def compare(leaf1, leaf2):
        assert leaf1.shape == leaf2.shape
        assert leaf1.dtype == leaf2.dtype

    tree_map(compare, tree1, tree2)
154 |
155 |
def autobatch(
    func: Callable,
    max_io_nbytes: int,
    in_axes: PyTree[int | None] = 0,
    out_axes: PyTree[int] = 0,
    return_nbatches: bool = False,
) -> Callable:
    """
    Batch a function such that each batch is smaller than a threshold.

    Parameters
    ----------
    func
        A jittable function with positional arguments only, with inputs and
        outputs pytrees of arrays.
    max_io_nbytes
        The maximum number of input + output bytes in each batch (excluding
        unbatched arguments.)
    in_axes
        A tree matching (a prefix of) the structure of the function input,
        indicating along which axes each array should be batched. A `None` axis
        indicates to not batch an argument.
    out_axes
        The same for outputs (but non-batching is not allowed).
    return_nbatches
        If True, the number of batches is returned as a second output.

    Returns
    -------
    A function with the same signature as `func`, save for the return value if `return_nbatches`.
    """

    @jit
    @wraps(func)
    def batched_func(*args):
        # shapes and dtypes of the output, without running the function
        out_shapes = eval_shape(func, *args)

        # one axis per input and output array
        full_in_axes = expand_axes(in_axes, args)
        full_out_axes = expand_axes(out_axes, out_shapes)
        check_no_nones(full_out_axes, out_shapes)

        # common length of the batched axes
        size = extract_size((full_in_axes, full_out_axes), (args, out_shapes))

        # set aside the arguments that are not batched
        batched_args, all_args = pull_nonbatched(full_in_axes, args)

        # choose the number of batches: the minimum that respects the memory
        # limit, rounded to a divisor of the size
        total_nbytes = sum_nbytes((batched_args, out_shapes))
        quotient, remainder = divmod(total_nbytes, max_io_nbytes)
        min_nbatches = max(1, quotient + bool(remainder))
        nbatches = next_divisor(size, min_nbatches)
        assert 1 <= nbatches <= max(1, size)
        assert size % nbatches == 0
        assert total_nbytes % nbatches == 0

        batch_nbytes = total_nbytes // nbatches
        if batch_nbytes > max_io_nbytes:
            # the limit can be exceeded only with one item per batch
            assert size == nbatches
            warn(f'batch_nbytes = {batch_nbytes} > max_io_nbytes = {max_io_nbytes}')

        def run_batch(_, batch_args):
            # restore the original layout of the arguments, then call
            batch_args = move_axes_in(full_in_axes, batch_args)
            batch_args = push_nonbatched(full_in_axes, batch_args, all_args)
            batch_result = func(*batch_args)
            return None, move_axes_out(full_out_axes, batch_result)

        # loop over batches along the leading axis
        stacked_args = batch(move_axes_out(full_in_axes, batched_args), nbatches)
        _, stacked_result = scan(run_batch, None, stacked_args)
        result = move_axes_in(full_out_axes, unbatch(stacked_result))

        check_same(out_shapes, result)

        return (result, nbatches) if return_nbatches else result

    return batched_func
239 |
--------------------------------------------------------------------------------
/src/bartz/jaxext/__init__.py:
--------------------------------------------------------------------------------
1 | # bartz/src/bartz/jaxext/__init__.py
2 | #
3 | # Copyright (c) 2024-2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Additions to jax."""
26 |
27 | import math
28 | from collections.abc import Sequence
29 | from functools import partial
30 |
31 | import jax
32 | from jax import Device, ensure_compile_time_eval, jit, random
33 | from jax import numpy as jnp
34 | from jax.lax import scan
35 | from jax.scipy.special import ndtr
36 | from jaxtyping import Array, Bool, Float32, Key, Scalar, Shaped
37 |
38 | from bartz.jaxext._autobatch import autobatch # noqa: F401
39 | from bartz.jaxext.scipy.special import ndtri
40 |
41 |
def vmap_nodoc(fun, *args, **kw):
    """
    Wrap `jax.vmap` such that the docstring of the function is left unchanged.

    This is useful if the docstring already takes into account that the
    arguments have additional axes due to vmap.
    """
    original_doc = fun.__doc__
    vmapped = jax.vmap(fun, *args, **kw)
    vmapped.__doc__ = original_doc
    return vmapped
53 |
54 |
def minimal_unsigned_dtype(value):
    """Return the smallest unsigned integer dtype that can represent `value`."""
    candidates = ((8, jnp.uint8), (16, jnp.uint16), (32, jnp.uint32))
    for nbits, dtype in candidates:
        if value < 2**nbits:
            return dtype
    # anything not fitting in 32 bits
    return jnp.uint64
64 |
65 |
@partial(jax.jit, static_argnums=(1,))
def unique(
    x: Shaped[Array, ' _'], size: int, fill_value: Scalar
) -> tuple[Shaped[Array, ' {size}'], int]:
    """
    Restricted version of `jax.numpy.unique` that uses less memory.

    Parameters
    ----------
    x
        The input array.
    size
        The length of the output.
    fill_value
        The value to fill the output with if `size` is greater than the number
        of unique values in `x`.

    Returns
    -------
    out : Shaped[Array, '{size}']
        The unique values in `x`, sorted, and right-padded with `fill_value`.
        If there are more than `size` unique values, only the `size` smallest
        ones are kept.
    actual_length : int
        The number of used values in `out`.
    """
    if x.size == 0:
        return jnp.full(size, fill_value, x.dtype), 0
    if size == 0:
        return jnp.empty(0, x.dtype), 0
    x = jnp.sort(x)

    def loop(carry, x):
        i_out, last, out = carry
        # advance the output position each time the sorted value changes
        i_out = jnp.where(x == last, i_out, i_out + 1)
        # values past the end of the output are discarded
        out = out.at[i_out].set(x, mode='drop')
        return (i_out, x, out), None

    carry = 0, x[0], jnp.full(size, fill_value, x.dtype)
    # scan the whole array: stopping after `size` elements would miss the
    # unique values that come after duplicates
    (last_index, _, out), _ = scan(loop, carry, x)
    return out, jnp.minimum(last_index + 1, size)
105 |
106 |
class split:
    """
    Split a key into `num` keys, to be consumed one at a time with `pop`.

    Parameters
    ----------
    key
        The key to split.
    num
        The number of keys to split into.
    """

    # keys produced by the initial split, in the order they are handed out
    _keys: tuple[Key[Array, ''], ...]
    # how many entries of `_keys` have already been popped
    _num_used: int

    def __init__(self, key: Key[Array, ''], num: int = 2):
        self._num_used = 0
        self._keys = _split_unpack(key, num)

    def __len__(self):
        """Return the number of keys that have not been popped yet."""
        total = len(self._keys)
        return total - self._num_used

    def pop(self, shape: int | tuple[int, ...] = ()) -> Key[Array, '*']:
        """
        Take the next unused key, optionally expanding it to an array of keys.

        Parameters
        ----------
        shape
            The shape of the keys to return. With the default empty shape, the
            popped key itself is returned. Otherwise the popped key is split
            further and the resulting keys are arranged in the given shape.

        Returns
        -------
        The popped keys as a jax array with the requested shape.

        Raises
        ------
        IndexError
            If all the keys have already been popped.
        """
        if len(self) < 1:
            msg = 'No keys left to pop'
            raise IndexError(msg)

        # accept a bare integer as a 1d shape
        target_shape = shape if isinstance(shape, tuple) else (shape,)

        popped = self._keys[self._num_used]
        self._num_used += 1

        if not target_shape:
            return popped
        return _split_shaped(popped, target_shape)
159 |
160 |
@partial(jit, static_argnums=(1,))
def _split_unpack(key: Key[Array, ''], num: int) -> tuple[Key[Array, ''], ...]:
    """Split `key` into `num` keys and return them as a tuple of single keys."""
    # iterating over the array of keys yields the individual keys
    return tuple(random.split(key, num))
165 |
166 |
@partial(jit, static_argnums=(1,))
def _split_shaped(key: Key[Array, ''], shape: tuple[int, ...]) -> Key[Array, '*']:
    """Split `key` into as many keys as needed to fill an array of shape `shape`."""
    flat_keys = random.split(key, math.prod(shape))
    return flat_keys.reshape(shape)
172 |
173 |
def truncated_normal_onesided(
    key: Key[Array, ''],
    shape: Sequence[int],
    upper: Bool[Array, '*'],
    bound: Float32[Array, '*'],
    *,
    clip: bool = True,
) -> Float32[Array, '*']:
    """
    Sample from a standard normal distribution truncated on one side.

    Parameters
    ----------
    key
        JAX random key.
    shape
        Shape of output array, broadcasted with other inputs.
    upper
        True for (-∞, bound], False for [bound, ∞).
    bound
        The truncation boundary.
    clip
        Whether to clip the truncated uniform samples to (0, 1) before
        transforming them to truncated normal. Intended for debugging purposes.

    Returns
    -------
    Array of samples from the truncated normal distribution.
    """
    # The samples are obtained by inverse transform sampling. With
    # u ~ uniform[0, 1), the four cases handled below are:
    #
    #   upper,     bound <= 0:   ndtri(ndtr(bound) * (1 - u))
    #   upper,     bound > 0:   -ndtri(ndtr(-bound) + ndtr(bound) * u)
    #   not upper, bound <= 0:   ndtri(ndtr(bound) + ndtr(-bound) * u)
    #   not upper, bound > 0:   -ndtri(ndtr(-bound) * (1 - u))
    #
    # i.e., for positive bounds the mirrored problem is solved and the result
    # is negated.
    out_shape = jnp.broadcast_shapes(shape, upper.shape, bound.shape)
    u = random.uniform(key, out_shape)

    # probability mass of the standard normal on each side of the bound
    mass_below = ndtr(bound)
    mass_above = ndtr(-bound)

    # `scale` is the width of the allowed interval of probabilities, `shift`
    # is where that interval starts when it does not start from zero
    scale = jnp.where(upper, mass_below, mass_above)
    shift = jnp.where(upper, mass_above, mass_below)

    from_zero_u = scale * (1 - u)  # ~ uniform in (0, ndtr(±bound)]
    up_to_one_u = shift + scale * u  # ~ uniform in [ndtr(∓bound), 1)

    is_positive = bound > 0
    starts_from_zero = upper ^ is_positive
    prob = jnp.where(starts_from_zero, from_zero_u, up_to_one_u)

    if clip:
        # on gpu the accuracy is lower and sometimes u can reach the boundaries
        dtype = prob.dtype
        zero = jnp.zeros((), dtype)
        one = jnp.ones((), dtype)
        smallest = jnp.nextafter(zero, one)
        largest = jnp.nextafter(one, zero)
        prob = jnp.clip(prob, smallest, largest)

    sample = ndtri(prob)
    return jnp.where(is_positive, -sample, sample)
237 |
238 |
239 | def get_default_device() -> Device:
240 | """Get the current default JAX device."""
241 | with ensure_compile_time_eval():
242 | return jnp.zeros(()).device
243 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | # bartz/Makefile
2 | #
3 | # Copyright (c) 2024-2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | # Makefile for running tests and for preparing and uploading a release.
26 |
27 | COVERAGE_SUFFIX =
28 | OLD_PYTHON = $(shell uv run --group=ci python -c 'from tests.util import get_old_python_str; print(get_old_python_str())')
29 | OLD_DATE = 2025-05-15
30 |
31 | .PHONY: all
32 | all:
33 | @echo "Available targets:"
34 | @echo "- setup: create R and Python environments for development"
35 | @echo "- tests: run unit tests, saving coverage information"
36 | @echo "- tests-old: run unit tests with oldest supported python and dependencies"
37 | @echo '- tests-gpu: variant of `tests` that works on gpu'
38 | @echo "- docs: build html documentation"
39 | @echo "- docs-latest: build html documentation for latest release"
40 | @echo "- covreport: build html coverage report"
41 | @echo "- covcheck: check coverage is above some thresholds"
42 | @echo "- release: packages the python module, invokes tests and docs first"
43 | @echo "- upload: upload release to PyPI"
44 | @echo "- upload-test: upload release to TestPyPI"
45 | @echo "- asv-run: run benchmarks on all unbenchmarked tagged releases and main"
46 | @echo "- asv-publish: create html benchmark report"
47 | @echo "- asv-preview: create html report and start server"
48 | @echo "- asv-main: run benchmarks on main branch"
49 | @echo "- asv-quick: run quick benchmarks on current code, no saving"
50 | @echo "- ipython: start an ipython shell with stuff pre-imported"
51 | @echo "- ipython-old: start an ipython shell with oldest supported python and dependencies"
52 | @echo
53 | @echo "Release workflow:"
54 | @echo "- $$ uv version --bump major|minor|patch"
55 | @echo "- describe release in docs/changelog.md"
56 | @echo "- $$ make release (repeat until it goes smoothly)"
57 | @echo "- push and check CI completes (if it doesn't, go to previous step)"
58 | @echo "- $$ make upload"
59 | @echo "- publish github release (updates zenodo automatically)"
60 | @echo "- if the online docs are not up-to-date, press 'run workflow' on https://github.com/Gattocrucco/bartz/actions/workflows/tests.yml, and try to understand why 'make upload' didn't do it"
61 |
62 |
63 | .PHONY: setup
64 | setup:
65 | Rscript -e "renv::restore()"
66 | uv run --all-groups pre-commit install
67 | @CUDA_VERSION=$$(nvidia-smi 2>/dev/null | grep -o 'CUDA Version: [0-9]*' | cut -d' ' -f3); \
68 | if [ "$$CUDA_VERSION" = "12" ]; then \
69 | echo "Detected CUDA 12, installing jax[cuda12]"; \
70 | uv pip install "jax[cuda12]"; \
71 | elif [ "$$CUDA_VERSION" = "13" ]; then \
72 | echo "Detected CUDA 13, installing jax[cuda13]"; \
73 | uv pip install "jax[cuda13]"; \
74 | else \
75 | echo "No CUDA detected"; \
76 | fi
77 |
78 |
79 | ################# TESTS #################
80 |
81 | TESTS_VARS = COVERAGE_FILE=.coverage.tests$(COVERAGE_SUFFIX)
82 | TESTS_COMMAND = python -m pytest --cov --cov-context=test --numprocesses=2 --dist=worksteal
83 |
84 | UV_RUN_CI = uv run --group=ci
85 | UV_OPTS_OLD = --python=$(OLD_PYTHON) --resolution=lowest-direct --exclude-newer=$(OLD_DATE)
86 | UV_VARS_OLD = UV_PROJECT_ENVIRONMENT=.venv-old
87 | UV_RUN_CI_OLD = $(UV_VARS_OLD) $(UV_RUN_CI) $(UV_OPTS_OLD)
88 |
89 | .PHONY: tests
90 | tests:
91 | $(TESTS_VARS) $(UV_RUN_CI) $(TESTS_COMMAND) $(ARGS)
92 |
93 | .PHONY: tests-old
94 | tests-old:
95 | $(TESTS_VARS) $(UV_RUN_CI_OLD) $(TESTS_COMMAND) $(ARGS)
96 |
97 | .PHONY: tests-gpu
98 | tests-gpu:
99 | nvidia-smi
100 | XLA_PYTHON_CLIENT_MEM_FRACTION=.20 $(TESTS_VARS) $(UV_RUN_CI) $(TESTS_COMMAND) --platform=gpu --numprocesses=3 $(ARGS)
101 |
102 | ################# DOCS #################
103 |
104 | .PHONY: docs
105 | docs:
106 | $(UV_RUN_CI) make -C docs html
107 | test ! -d _site/docs-dev || rm -r _site/docs-dev
108 | mv docs/_build/html _site/docs-dev
109 | @echo
110 | @echo "Now open _site/index.html"
111 |
112 | .PHONY: docs-latest
113 | docs-latest:
114 | BARTZ_DOC_VARIANT=latest $(UV_RUN_CI) make -C docs html
115 | git switch - || git switch main
116 | test ! -d _site/docs || rm -r _site/docs
117 | mv docs/_build/html _site/docs
118 | @echo
119 | @echo "Now open _site/index.html"
120 |
121 | .PHONY: covreport
122 | covreport:
123 | $(UV_RUN_CI) coverage combine --keep
124 | $(UV_RUN_CI) coverage html --include='src/*'
125 | @echo
126 | @echo "Now open _site/coverage/index.html"
127 |
128 | .PHONY: covcheck
129 | covcheck:
130 | $(UV_RUN_CI) coverage combine --keep
131 | $(UV_RUN_CI) coverage report --include='tests/**/test_*.py'
132 | $(UV_RUN_CI) coverage report --include='src/*'
133 | $(UV_RUN_CI) coverage report --include='tests/**/test_*.py' --fail-under=99 --format=total
134 | $(UV_RUN_CI) coverage report --include='src/*' --fail-under=90 --format=total
135 |
136 | ################# RELEASE #################
137 |
138 | .PHONY: update-deps
139 | update-deps:
140 | test ! -d .venv || rm -r .venv
141 | uv lock --upgrade
142 |
143 | .PHONY: copy-version
144 | copy-version: src/bartz/_version.py
145 | src/bartz/_version.py: pyproject.toml
146 | uv run --group=ci python -c 'from tests.util import update_version; update_version()'
147 |
148 | .PHONY: check-committed
149 | check-committed:
150 | git diff --quiet
151 | git diff --quiet --staged
152 |
153 | .PHONY: release
154 | release: update-deps copy-version check-committed
155 | @$(MAKE) tests
156 | @$(MAKE) tests-old
157 | @$(MAKE) docs
158 | test ! -d dist || rm -r dist
159 | uv build
160 |
161 | .PHONY: version-tag
162 | version-tag: copy-version check-committed
163 | git fetch --tags
164 | git tag v$(shell uv run python -c 'import bartz; print(bartz.__version__)')
165 | git push --tags
166 |
167 | .PHONY: upload
168 | upload: version-tag
169 | @echo "Enter PyPI token:"
170 | @read -s UV_PUBLISH_TOKEN && \
171 | export UV_PUBLISH_TOKEN="$$UV_PUBLISH_TOKEN" && \
172 | uv publish
173 | @VERSION=$$(uv run python -c 'import bartz; print(bartz.__version__)') && \
174 | echo "Try to install bartz $$VERSION from PyPI" && \
175 | uv tool run --with="bartz==$$VERSION" python -c 'import bartz; print(bartz.__version__)'
176 |
177 | .PHONY: upload-test
178 | upload-test: check-committed
179 | @echo "Enter TestPyPI token:"
180 | @read -s UV_PUBLISH_TOKEN && \
181 | export UV_PUBLISH_TOKEN="$$UV_PUBLISH_TOKEN" && \
182 | uv publish --check-url=https://test.pypi.org/simple/ --publish-url=https://test.pypi.org/legacy/
183 | @VERSION=$$(uv run --group=ci python -c 'from tests.util import get_version; print(get_version())') && \
184 | echo "Try to install bartz $$VERSION from TestPyPI" && \
185 | uv tool run --index=https://test.pypi.org/simple/ --index-strategy=unsafe-best-match --with="bartz==$$VERSION" python -c 'import bartz; print(bartz.__version__)'
186 |
187 |
188 | ################# BENCHMARKS #################
189 |
190 | ASV = $(UV_RUN_CI) python -m asv
191 |
192 | .PHONY: asv-run
193 | asv-run:
194 | $(UV_RUN_CI) python config/refs-for-asv.py | $(ASV) run --skip-existing --show-stderr HASHFILE:- $(ARGS)
195 |
196 | .PHONY: asv-publish
197 | asv-publish:
198 | $(ASV) publish $(ARGS)
199 |
200 | .PHONY: asv-preview
201 | asv-preview: asv-publish
202 | $(ASV) preview $(ARGS)
203 |
204 | .PHONY: asv-main
205 | asv-main:
206 | $(ASV) run --show-stderr main^! $(ARGS)
207 |
208 | .PHONY: asv-quick
209 | asv-quick:
210 | $(ASV) run --python=same --quick --dry-run --show-stderr $(ARGS)
211 |
212 |
213 | ################# IPYTHON SHELL #################
214 |
215 | .PHONY: ipython
216 | ipython:
217 | IPYTHONDIR=config/ipython uv run --all-groups python -m IPython $(ARGS)
218 |
219 | .PHONY: ipython-old
220 | ipython-old:
221 | IPYTHONDIR=config/ipython $(UV_VARS_OLD) uv run --all-groups $(UV_OPTS_OLD) python -m IPython $(ARGS)
222 |
--------------------------------------------------------------------------------
/src/bartz/_profiler.py:
--------------------------------------------------------------------------------
1 | # bartz/src/bartz/_profiler.py
2 | #
3 | # Copyright (c) 2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Module with utilities related to profiling bartz."""
26 |
27 | from collections.abc import Callable, Iterator
28 | from contextlib import contextmanager
29 | from functools import wraps
30 | from typing import Any, TypeVar
31 |
32 | from jax import block_until_ready, debug, jit
33 | from jax.lax import cond, scan
34 | from jax.profiler import TraceAnnotation
35 | from jaxtyping import Array, Bool
36 |
37 | PROFILE_MODE: bool = False
38 |
39 | T = TypeVar('T')
40 | Carry = TypeVar('Carry')
41 |
42 |
def get_profile_mode() -> bool:
    """Tell whether profile mode is currently enabled.

    Returns
    -------
    The current value of the module-level flag `PROFILE_MODE`: True if profile
    mode is on, False if it is off.
    """
    return PROFILE_MODE
51 |
52 |
def set_profile_mode(value: bool, /) -> None:
    """Enable or disable profile mode.

    Parameters
    ----------
    value
        True to turn profile mode on, False to turn it off. The value is
        stored in the module-level flag `PROFILE_MODE`.
    """
    # the flag is module-level state, rebinding it requires `global`
    global PROFILE_MODE  # noqa: PLW0603
    PROFILE_MODE = value
63 |
64 |
@contextmanager
def profile_mode(value: bool, /) -> Iterator[None]:
    """Context manager that sets profile mode for the duration of a block.

    The profile mode that was active on entry is restored on exit, also if the
    block raises an exception.

    Parameters
    ----------
    value
        Profile mode value to set within the context.

    Examples
    --------
    >>> with profile_mode(True):
    ...     # Code runs with profile mode enabled
    ...     pass

    Notes
    -----
    In profiling mode, the MCMC loop is not compiled into a single function.
    It is instead compiled in smaller pieces, instrumented to show up in the
    jax tracer and in Python profiling statistics. Search for function names
    starting with 'jab' (see `jit_and_block_if_profiling`).

    This context manager does not enable jax tracing; if needed, the user has
    to set it up separately. The only effect of this context manager is to
    make the execution flow more interpretable in the traces if the tracer is
    used.
    """
    previous = get_profile_mode()
    set_profile_mode(value)
    try:
        yield
    finally:
        # put things back as they were, whatever happened in the block
        set_profile_mode(previous)
98 |
99 |
def jit_and_block_if_profiling(
    func: Callable[..., T], block_before: bool = False, **kwargs
) -> Callable[..., T]:
    """Apply JIT compilation and block if profiling is enabled.

    When profile mode is off, the function runs without JIT. When profile mode
    is on, the function is JIT compiled and blocks outputs to ensure proper
    timing.

    Parameters
    ----------
    func
        Function to wrap. Callables without a `__name__` attribute (e.g.,
        `functools.partial` objects) are accepted; the name of their type is
        used to label the trace event.
    block_before
        If True block inputs before passing them to the JIT-compiled function.
        This ensures that any pending computations are completed before entering
        the JIT-compiled function. This phase is not included in the trace
        event.
    **kwargs
        Additional arguments to pass to `jax.jit`.

    Returns
    -------
    Wrapped function.

    Notes
    -----
    Under profiling mode, the function invocation is handled such that a custom
    jax trace event with name `jab[<func_name>]` is created, where
    `<func_name>` is the name of `func`. The statistics on the actual Python
    function will be off, while the function `jab_inner_wrapper` represents the
    actual execution time.
    """
    jitted_func = jit(func, **kwargs)

    # not every callable has a `__name__`, fall back to the name of its type
    # instead of failing with AttributeError at wrapping time
    func_name = getattr(func, '__name__', type(func).__name__)
    event_name = f'jab[{func_name}]'

    # this wrapper is meant to measure the time spent executing the function
    def jab_inner_wrapper(*args, **kwargs) -> T:
        with TraceAnnotation(event_name):
            result = jitted_func(*args, **kwargs)
            # wait for the result inside the annotation, otherwise the event
            # would only cover the asynchronous dispatch
            return block_until_ready(result)

    @wraps(func)
    def jab_outer_wrapper(*args: Any, **kwargs: Any) -> T:
        if get_profile_mode():
            if block_before:
                # flush pending computations on the inputs outside the event
                args, kwargs = block_until_ready((args, kwargs))
            return jab_inner_wrapper(*args, **kwargs)
        else:
            return func(*args, **kwargs)

    return jab_outer_wrapper
152 |
153 |
def jit_if_not_profiling(func: Callable[..., T], *args, **kwargs) -> Callable[..., T]:
    """JIT-compile a function, except when running in profile mode.

    The choice is made at each call of the returned function: with profile
    mode off the JIT-compiled version is invoked, with profile mode on the
    original function is invoked directly.

    Parameters
    ----------
    func
        Function to wrap.
    *args
    **kwargs
        Additional arguments to pass to `jax.jit`.

    Returns
    -------
    Wrapped function.
    """
    compiled = jit(func, *args, **kwargs)

    @wraps(func)
    def wrapper(*call_args: Any, **call_kwargs: Any) -> T:
        target = func if get_profile_mode() else compiled
        return target(*call_args, **call_kwargs)

    return wrapper
182 |
183 |
def scan_if_not_profiling(
    f: Callable[[Carry, None], tuple[Carry, None]],
    init: Carry,
    xs: None,
    length: int,
    /,
) -> tuple[Carry, None]:
    """Restricted replacement for `jax.lax.scan` that uses a Python loop when profiling.

    Parameters
    ----------
    f
        Scan body function with signature (carry, None) -> (carry, None).
    init
        Initial carry value.
    xs
        Input values to scan over. Not supported, must be None.
    length
        Integer specifying the number of loop iterations.

    Returns
    -------
    Tuple of (final_carry, None) (stacked outputs not supported).
    """
    assert xs is None

    if not get_profile_mode():
        return scan(f, init, None, length)

    # profiling: run the iterations one by one from Python; whatever `f`
    # returns as second output is discarded
    carry = init
    for _step in range(length):
        carry, _ = f(carry, None)
    return carry, None
217 |
218 |
def cond_if_not_profiling(
    pred: bool | Bool[Array, ''],
    true_fun: Callable[..., T],
    false_fun: Callable[..., T],
    /,
    *operands,
) -> T:
    """Restricted replacement for `jax.lax.cond` that uses a Python if when profiling.

    Parameters
    ----------
    pred
        Boolean predicate to choose which function to execute.
    true_fun
        Function to execute if `pred` is True.
    false_fun
        Function to execute if `pred` is False.
    *operands
        Arguments passed to `true_fun` and `false_fun`.

    Returns
    -------
    Result of either `true_fun(*operands)` or `false_fun(*operands)`.
    """
    if not get_profile_mode():
        return cond(pred, true_fun, false_fun, *operands)

    # profiling: pick the branch in Python and call only that one
    branch = true_fun if pred else false_fun
    return branch(*operands)
250 |
251 |
def callback_if_not_profiling(
    callback: Callable[..., None], *args: Any, ordered: bool = False, **kwargs: Any
):
    """Restricted replacement for `jax.debug.callback` that calls the callback directly in profiling mode.

    Parameters
    ----------
    callback
        The function to call.
    *args
    **kwargs
        Arguments passed to `callback`.
    ordered
        Forwarded to `jax.debug.callback`. Not used in profiling mode, where
        the callback is invoked immediately.
    """
    if not get_profile_mode():
        debug.callback(callback, *args, ordered=ordered, **kwargs)
        return
    callback(*args, **kwargs)
260 |
--------------------------------------------------------------------------------
/asv.conf.json:
--------------------------------------------------------------------------------
1 | {
2 | // The version of the config file format. Do not change, unless
3 | // you know what you are doing.
4 | "version": 1,
5 |
6 | // The name of the project being benchmarked
7 | "project": "bartz",
8 |
9 | // The project's homepage
10 | "project_url": "https://github.com/Gattocrucco/bartz",
11 |
12 | // The URL or local path of the source code repository for the
13 | // project being benchmarked
14 | "repo": ".",
15 |
16 | // The Python project's subdirectory in your repo. If missing or
17 | // the empty string, the project is assumed to be located at the root
18 | // of the repository.
19 | // "repo_subdir": "",
20 |
21 | // Customizable commands for building the project.
22 | // See asv.conf.json documentation.
23 | "build_command": [
24 | "python -m build --wheel -o {build_cache_dir} {build_dir}"
25 | ],
26 | // To build the package using setuptools and a setup.py file, uncomment the following lines
27 | // "build_command": [
28 | // "python setup.py build",
29 | // "python -mpip wheel -w {build_cache_dir} {build_dir}"
30 | // ],
31 |
32 | // Customizable commands for installing and uninstalling the project.
33 | // See asv.conf.json documentation.
34 | // "install_command": ["in-dir={env_dir} python -mpip install {wheel_file}"],
35 | // "uninstall_command": ["return-code=any python -mpip uninstall -y {project}"],
36 |
37 | // List of branches to benchmark. If not provided, defaults to "main"
38 | // (for git) or "default" (for mercurial).
39 | "branches": [
40 | "main",
41 | ],
42 |
43 | // The DVCS being used. If not set, it will be automatically
44 | // determined from "repo" by looking at the protocol in the URL
45 | // (if remote), or by looking for special directories, such as
46 | // ".git" (if local).
47 | // "dvcs": "git",
48 |
49 | // The tool to use to create environments. May be "conda",
50 | // "virtualenv", "mamba" (above 3.8)
51 | // or other value depending on the plugins in use.
52 | // If missing or the empty string, the tool will be automatically
53 | // determined by looking for tools on the PATH environment
54 | // variable.
55 | "environment_type": "virtualenv",
56 |
57 | // timeout in seconds for installing any dependencies in environment
58 | // defaults to 10 min
59 | //"install_timeout": 600,
60 |
61 | // the base URL to show a commit for the project.
62 | "show_commit_url": "https://github.com/Gattocrucco/bartz/commit/",
63 |
64 | // The Pythons you'd like to test against. If not provided, defaults
65 | // to the current version of Python used to run `asv`.
66 | // "pythons": ["3.8", "3.12"],
67 |
68 | // The list of conda channel names to be searched for benchmark
69 | // dependency packages in the specified order
70 | // "conda_channels": ["conda-forge", "defaults"],
71 |
72 | // A conda environment file that is used for environment creation.
73 | // "conda_environment_file": "environment.yml",
74 |
75 | // The matrix of dependencies to test. Each key of the "req"
76 | // requirements dictionary is the name of a package (in PyPI) and
77 | // the values are version numbers. An empty list or empty string
78 | // indicates to just test against the default (latest)
79 | // version. null indicates that the package is to not be
80 | // installed. If the package to be tested is only available from
81 | // PyPI, and the 'environment_type' is conda, then you can preface
82 | // the package name by 'pip+', and the package will be installed
83 | // via pip (with all the conda available packages installed first,
84 | // followed by the pip installed packages).
85 | //
86 | // The ``@env`` and ``@env_nobuild`` keys contain the matrix of
87 | // environment variables to pass to build and benchmark commands.
88 | // An environment will be created for every combination of the
89 | // cartesian product of the "@env" variables in this matrix.
90 | // Variables in "@env_nobuild" will be passed to every environment
91 | // during the benchmark phase, but will not trigger creation of
92 | // new environments. A value of ``null`` means that the variable
93 | // will not be set for the current combination.
94 | //
95 | // "matrix": {
96 | // "req": {
97 | // "numpy": ["1.6", "1.7"],
98 | // "six": ["", null], // test with and without six installed
99 | // "pip+emcee": [""] // emcee is only available for install with pip.
100 | // },
101 | // "env": {"ENV_VAR_1": ["val1", "val2"]},
102 | // "env_nobuild": {"ENV_VAR_2": ["val3", null]},
103 | // },
104 |
105 | // Combinations of libraries/python versions can be excluded/included
106 | // from the set to test. Each entry is a dictionary containing additional
107 | // key-value pairs to include/exclude.
108 | //
109 | // An exclude entry excludes entries where all values match. The
110 | // values are regexps that should match the whole string.
111 | //
112 | // An include entry adds an environment. Only the packages listed
113 | // are installed. The 'python' key is required. The exclude rules
114 | // do not apply to includes.
115 | //
116 | // In addition to package names, the following keys are available:
117 | //
118 | // - python
119 | // Python version, as in the *pythons* variable above.
120 | // - environment_type
121 | // Environment type, as above.
122 | // - sys_platform
123 | // Platform, as in sys.platform. Possible values for the common
124 | // cases: 'linux2', 'win32', 'cygwin', 'darwin'.
125 | // - req
126 | // Required packages
127 | // - env
128 | // Environment variables
129 | // - env_nobuild
130 | // Non-build environment variables
131 | //
132 | // "exclude": [
133 | // {"python": "3.2", "sys_platform": "win32"}, // skip py3.2 on windows
134 | // {"environment_type": "conda", "req": {"six": null}}, // don't run without six on conda
135 | // {"env": {"ENV_VAR_1": "val2"}}, // skip val2 for ENV_VAR_1
136 | // ],
137 | //
138 | // "include": [
139 | // // additional env for python3.12
140 | // {"python": "3.12", "req": {"numpy": "1.26"}, "env_nobuild": {"FOO": "123"}},
141 | // // additional env if run on windows+conda
142 | // {"platform": "win32", "environment_type": "conda", "python": "3.12", "req": {"libpython": ""}},
143 | // ],
144 |
145 | // The directory (relative to the current directory) that benchmarks are
146 | // stored in. If not provided, defaults to "benchmarks"
147 | // "benchmark_dir": "benchmarks",
148 |
149 | // The directory (relative to the current directory) to cache the Python
150 | // environments in. If not provided, defaults to "env"
151 | "env_dir": ".asv/env",
152 |
153 | // The directory (relative to the current directory) that raw benchmark
154 | // results are stored in. If not provided, defaults to "results".
155 | "results_dir": ".asv/results",
156 |
157 | // The directory (relative to the current directory) that the html tree
158 | // should be written to. If not provided, defaults to "html".
159 | "html_dir": "_site/benchmarks",
160 |
161 | // The number of characters to retain in the commit hashes.
162 | // "hash_length": 8,
163 |
164 | // `asv` will cache results of the recent builds in each
165 | // environment, making them faster to install next time. This is
166 | // the number of builds to keep, per environment.
167 | // "build_cache_size": 2,
168 |
169 | // The commits after which the regression search in `asv publish`
170 | // should start looking for regressions. Dictionary whose keys are
171 | // regexps matching to benchmark names, and values corresponding to
172 | // the commit (exclusive) after which to start looking for
173 | // regressions. The default is to start from the first commit
174 | // with results. If the commit is `null`, regression detection is
175 | // skipped for the matching benchmark.
176 | //
177 | // "regressions_first_commits": {
178 | // "some_benchmark": "352cdf", // Consider regressions only after this commit
179 | // "another_benchmark": null, // Skip regression detection altogether
180 | // },
181 |
182 | // The thresholds for relative change in results, after which `asv
183 | // publish` starts reporting regressions. Dictionary of the same
184 | // form as in ``regressions_first_commits``, with values
185 | // indicating the thresholds. If multiple entries match, the
186 | // maximum is taken. If no entry matches, the default is 5%.
187 | //
188 | // "regressions_thresholds": {
189 | // "some_benchmark": 0.01, // Threshold of 1%
190 | // "another_benchmark": 0.5, // Threshold of 50%
191 | // },
192 | }
193 |
--------------------------------------------------------------------------------
/src/bartz/jaxext/scipy/special.py:
--------------------------------------------------------------------------------
1 | # bartz/src/bartz/jaxext/scipy/special.py
2 | #
3 | # Copyright (c) 2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | """Mockup of the :external:py:mod:`scipy.special` module."""
26 |
27 | from functools import wraps
28 |
29 | from jax import ShapeDtypeStruct, jit, pure_callback
30 | from jax import numpy as jnp
31 | from scipy.special import gammainccinv as scipy_gammainccinv
32 |
33 |
34 | def _float_type(*args):
35 | """Determine the jax floating point result type given operands/types."""
36 | t = jnp.result_type(*args)
37 | return jnp.sin(jnp.empty(0, t)).dtype
38 |
39 |
40 | def _castto(func, dtype):
41 | @wraps(func)
42 | def newfunc(*args, **kw):
43 | return func(*args, **kw).astype(dtype)
44 |
45 | return newfunc
46 |
47 |
@jit
def gammainccinv(a, y):
    """Survival function inverse of the Gamma(a, 1) distribution.

    The computation is delegated to `scipy.special.gammainccinv` through
    `jax.pure_callback`. The output has the broadcasted shape of `a` and `y`
    and the floating point dtype obtained by promoting their dtypes.
    """
    out_dtype = _float_type(a.dtype, y.dtype)
    out_shape = jnp.broadcast_shapes(a.shape, y.shape)
    # The callback result must match the declared dtype, hence the cast
    # wrapped around the scipy function.
    host_func = _castto(scipy_gammainccinv, out_dtype)
    result_spec = ShapeDtypeStruct(out_shape, out_dtype)
    return pure_callback(host_func, result_spec, a, y, vmap_method='expand_dims')
56 |
57 |
58 | ################# COPIED AND ADAPTED FROM JAX ##################
59 | # Copyright 2018 The JAX Authors.
60 | #
61 | # Licensed under the Apache License, Version 2.0 (the "License");
62 | # you may not use this file except in compliance with the License.
63 | # You may obtain a copy of the License at
64 | #
65 | # https://www.apache.org/licenses/LICENSE-2.0
66 | #
67 | # Unless required by applicable law or agreed to in writing, software
68 | # distributed under the License is distributed on an "AS IS" BASIS,
69 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
70 | # See the License for the specific language governing permissions and
71 | # limitations under the License.
72 |
73 | import numpy as np
74 | from jax import debug_infs, lax
75 |
76 |
def ndtri(p):
    """Compute the inverse of the CDF of the Normal distribution function.

    This is a patch of `jax.scipy.special.ndtri`.

    Parameters
    ----------
    p
        Probabilities at which the inverse CDF is evaluated. The dtype must
        be float32 or float64.

    Returns
    -------
    x
        The values computed by `_ndtri` for `p`.

    Raises
    ------
    TypeError
        If the dtype of `p` is neither float32 nor float64.
    """
    dtype = lax.dtype(p)
    if dtype not in (jnp.float32, jnp.float64):
        # Name the actual argument (`p`) so the message matches the signature.
        msg = f'p.dtype={dtype} is not supported, see docstring for supported types.'
        raise TypeError(msg)
    return _ndtri(p)
87 |
88 |
89 | def _ndtri(p):
90 | # Constants used in piece-wise rational approximations. Taken from the cephes
91 | # library:
92 | # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
93 | p0 = list(
94 | reversed(
95 | [
96 | -5.99633501014107895267e1,
97 | 9.80010754185999661536e1,
98 | -5.66762857469070293439e1,
99 | 1.39312609387279679503e1,
100 | -1.23916583867381258016e0,
101 | ]
102 | )
103 | )
104 | q0 = list(
105 | reversed(
106 | [
107 | 1.0,
108 | 1.95448858338141759834e0,
109 | 4.67627912898881538453e0,
110 | 8.63602421390890590575e1,
111 | -2.25462687854119370527e2,
112 | 2.00260212380060660359e2,
113 | -8.20372256168333339912e1,
114 | 1.59056225126211695515e1,
115 | -1.18331621121330003142e0,
116 | ]
117 | )
118 | )
119 | p1 = list(
120 | reversed(
121 | [
122 | 4.05544892305962419923e0,
123 | 3.15251094599893866154e1,
124 | 5.71628192246421288162e1,
125 | 4.40805073893200834700e1,
126 | 1.46849561928858024014e1,
127 | 2.18663306850790267539e0,
128 | -1.40256079171354495875e-1,
129 | -3.50424626827848203418e-2,
130 | -8.57456785154685413611e-4,
131 | ]
132 | )
133 | )
134 | q1 = list(
135 | reversed(
136 | [
137 | 1.0,
138 | 1.57799883256466749731e1,
139 | 4.53907635128879210584e1,
140 | 4.13172038254672030440e1,
141 | 1.50425385692907503408e1,
142 | 2.50464946208309415979e0,
143 | -1.42182922854787788574e-1,
144 | -3.80806407691578277194e-2,
145 | -9.33259480895457427372e-4,
146 | ]
147 | )
148 | )
149 | p2 = list(
150 | reversed(
151 | [
152 | 3.23774891776946035970e0,
153 | 6.91522889068984211695e0,
154 | 3.93881025292474443415e0,
155 | 1.33303460815807542389e0,
156 | 2.01485389549179081538e-1,
157 | 1.23716634817820021358e-2,
158 | 3.01581553508235416007e-4,
159 | 2.65806974686737550832e-6,
160 | 6.23974539184983293730e-9,
161 | ]
162 | )
163 | )
164 | q2 = list(
165 | reversed(
166 | [
167 | 1.0,
168 | 6.02427039364742014255e0,
169 | 3.67983563856160859403e0,
170 | 1.37702099489081330271e0,
171 | 2.16236993594496635890e-1,
172 | 1.34204006088543189037e-2,
173 | 3.28014464682127739104e-4,
174 | 2.89247864745380683936e-6,
175 | 6.79019408009981274425e-9,
176 | ]
177 | )
178 | )
179 |
180 | dtype = lax.dtype(p).type
181 | shape = jnp.shape(p)
182 |
183 | def _create_polynomial(var, coeffs):
184 | """Compute n_th order polynomial via Horner's method."""
185 | coeffs = np.array(coeffs, dtype)
186 | if not coeffs.size:
187 | return jnp.zeros_like(var)
188 | return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var
189 |
190 | maybe_complement_p = jnp.where(p > dtype(-np.expm1(-2.0)), dtype(1.0) - p, p)
191 | # Write in an arbitrary value in place of 0 for p since 0 will cause NaNs
192 | # later on. The result from the computation when p == 0 is not used so any
193 | # number that doesn't result in NaNs is fine.
194 | sanitized_mcp = jnp.where(
195 | maybe_complement_p == dtype(0.0),
196 | jnp.full(shape, dtype(0.5)),
197 | maybe_complement_p,
198 | )
199 |
200 | # Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2).
201 | w = sanitized_mcp - dtype(0.5)
202 | ww = lax.square(w)
203 | x_for_big_p = w + w * ww * (_create_polynomial(ww, p0) / _create_polynomial(ww, q0))
204 | x_for_big_p *= -dtype(np.sqrt(2.0 * np.pi))
205 |
206 | # Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z),
207 | # where z = sqrt(-2. * log(p)), and P/Q are chosen between two different
208 | # arrays based on whether p < exp(-32).
209 | z = lax.sqrt(dtype(-2.0) * lax.log(sanitized_mcp))
210 | first_term = z - lax.log(z) / z
211 | second_term_small_p = (
212 | _create_polynomial(dtype(1.0) / z, p2)
213 | / _create_polynomial(dtype(1.0) / z, q2)
214 | / z
215 | )
216 | second_term_otherwise = (
217 | _create_polynomial(dtype(1.0) / z, p1)
218 | / _create_polynomial(dtype(1.0) / z, q1)
219 | / z
220 | )
221 | x_for_small_p = first_term - second_term_small_p
222 | x_otherwise = first_term - second_term_otherwise
223 |
224 | x = jnp.where(
225 | sanitized_mcp > dtype(np.exp(-2.0)),
226 | x_for_big_p,
227 | jnp.where(z >= dtype(8.0), x_for_small_p, x_otherwise),
228 | )
229 |
230 | x = jnp.where(p > dtype(1.0 - np.exp(-2.0)), x, -x)
231 | with debug_infs(False):
232 | infinity = jnp.full(shape, dtype(np.inf))
233 | neg_infinity = -infinity
234 | return jnp.where(
235 | p == dtype(0.0), neg_infinity, jnp.where(p == dtype(1.0), infinity, x)
236 | )
237 |
238 |
239 | ################################################################
240 |
--------------------------------------------------------------------------------
/.asv/results/benchmarks.json:
--------------------------------------------------------------------------------
1 | {
2 | "rmse.EvalGbart.track_rmse": {
3 | "code": "class EvalGbart:\n def track_rmse(self) -> float:\n \"\"\"Return the RMSE for predictions on a test set.\"\"\"\n key = random.key(2025_06_26_21_02)\n data = make_data(key, 100, 1000, 20)\n with redirect_stdout(StringIO()):\n bart = gbart(\n data.X_train,\n data.y_train,\n x_test=data.X_test,\n nskip=1000,\n ndpost=1000,\n seed=key,\n )\n return jnp.sqrt(jnp.mean(jnp.square(bart.yhat_test_mean - data.mu_test))).item()",
4 | "name": "rmse.EvalGbart.track_rmse",
5 | "param_names": [],
6 | "params": [],
7 | "timeout": 30.0,
8 | "type": "track",
9 | "unit": "latent_sdev",
10 | "version": "afd40ad3255f218a76e6833332dc91afa0d19ac0f6daf1b7b9c75664c4586d28"
11 | },
12 | "speed.TimeGbart.time_gbart": {
13 | "code": "class TimeGbart:\n def time_gbart(self, *_):\n \"\"\"Time instantiating the class.\"\"\"\n with redirect_stdout(StringIO()):\n bart = gbart(**self.kw)\n block_until_ready((bart._mcmc_state, bart._main_trace))\n\n def setup(self, niters: int, cache: Cache, nchains: int):\n \"\"\"Prepare the arguments and run once to warm-up.\"\"\"\n # check support for multiple chains\n if (niters == 0 or cache == 'cold') and nchains > 1:\n msg = 'skip multi-chain with 0 iterations or cold cache'\n raise NotImplementedError(msg)\n \n sig = signature(gbart)\n support_multichain = 'mc_cores' in sig.parameters\n if nchains != 1 and not support_multichain:\n msg = 'multi-chain not supported'\n raise NotImplementedError(msg)\n \n # random seed\n key = random.key(2025_06_24_14_55)\n keys = list(random.split(key, 3))\n \n # generate simulated data\n sigma = 0.1\n T = 2\n X = random.uniform(keys.pop(), (P, N), float, -2, 2)\n f = lambda X: jnp.sum(jnp.cos(2 * jnp.pi / T * X), axis=0)\n y = f(X) + sigma * random.normal(keys.pop(), (N,))\n \n # arguments\n self.kw = dict(\n x_train=X,\n y_train=y,\n nskip=niters // 2,\n ndpost=(niters - niters // 2) * nchains,\n seed=keys.pop(),\n )\n if support_multichain:\n self.kw.update(mc_cores=nchains)\n \n # decide how much to cold-start\n match cache:\n case 'cold':\n clear_caches()\n case 'warm':\n self.time_gbart()",
14 | "min_run_count": 2,
15 | "name": "speed.TimeGbart.time_gbart",
16 | "number": 1,
17 | "param_names": [
18 | "niters",
19 | "cache",
20 | "nchains"
21 | ],
22 | "params": [
23 | [
24 | "0",
25 | "10"
26 | ],
27 | [
28 | "'cold'",
29 | "'warm'"
30 | ],
31 | [
32 | "1",
33 | "2",
34 | "8",
35 | "32"
36 | ]
37 | ],
38 | "repeat": 0,
39 | "rounds": 2,
40 | "sample_time": 0.01,
41 | "timeout": 30.0,
42 | "type": "time",
43 | "unit": "seconds",
44 | "version": "127118a7023d5587f25f10a56bbccd86f8200efd5a4c22fe60fa246df36599a4",
45 | "warmup_time": 0.0
46 | },
47 | "speed.TimeRunMcmc.time_run_mcmc": {
48 | "code": "class TimeRunMcmc:\n @skip_for_params(list(product(['compile'], [0], params[2])))\n def time_run_mcmc(self, mode: Mode, *_):\n \"\"\"Time running or compiling the function.\"\"\"\n match mode:\n case 'compile':\n # re-wrap and jit the function in the benchmark case because otherwise\n # the compiled function gets cached even if I call `compile` explicitly\n @partial(\n jit, static_argnames=('n_save', 'n_skip', 'n_burn', 'callback')\n )\n def f(**kw):\n return run_mcmc(**kw)\n \n f.lower(**self.kw).compile()\n \n case 'run':\n block_until_ready(run_mcmc(**self.kw))\n\n def setup(self, mode: Mode, niters: int, cache: Cache):\n \"\"\"Prepare the arguments, compile the function, and run to warm-up.\"\"\"\n self.kw = dict(\n key=random.key(2025_04_25_15_57),\n bart=simple_init(P, N, NTREE),\n n_save=niters // 2,\n n_burn=niters // 2,\n n_skip=1,\n callback=lambda **_: None,\n )\n \n # adapt arguments for old versions\n sig = signature(run_mcmc)\n if 'callback' not in sig.parameters:\n self.kw.pop('callback')\n \n # catch bug and skip if found\n try:\n array_kw = {k: v for k, v in self.kw.items() if isinstance(v, jnp.ndarray)}\n nonarray_kw = {\n k: v for k, v in self.kw.items() if not isinstance(v, jnp.ndarray)\n }\n partial_run_mcmc = partial(run_mcmc, **nonarray_kw)\n eval_shape(partial_run_mcmc, **array_kw)\n except ZeroDivisionError:\n if niters:\n raise\n else:\n msg = 'skipping due to division by zero bug with zero iterations'\n raise NotImplementedError(msg) from None\n \n # decide how much to cold-start\n match cache:\n case 'cold':\n clear_caches()\n case 'warm':\n # prepare copies of the args because of buffer donation\n key = jnp.copy(self.kw['key'])\n bart = tree_map(jnp.copy, self.kw['bart'])\n self.time_run_mcmc(mode)\n # put copies in place of donated buffers\n self.kw.update(key=key, bart=bart)",
49 | "min_run_count": 2,
50 | "name": "speed.TimeRunMcmc.time_run_mcmc",
51 | "number": 1,
52 | "param_names": [
53 | "mode",
54 | "niters",
55 | "cache"
56 | ],
57 | "params": [
58 | [
59 | "'compile'",
60 | "'run'"
61 | ],
62 | [
63 | "0",
64 | "10"
65 | ],
66 | [
67 | "'cold'",
68 | "'warm'"
69 | ]
70 | ],
71 | "repeat": 0,
72 | "rounds": 2,
73 | "sample_time": 0.01,
74 | "timeout": 30.0,
75 | "type": "time",
76 | "unit": "seconds",
77 | "version": "38ac9d8c1419fab7715c0dd625229b7a900c5dece4f6bc91a8ab8074e9c9281c",
78 | "warmup_time": 0.0
79 | },
80 | "speed.TimeStep.time_step": {
81 | "code": "class TimeStep:\n def time_step(self, mode: Mode, _):\n \"\"\"Time running compiling `step` or running it a few times.\"\"\"\n match mode:\n case 'compile':\n \n @jit\n def f(*args):\n return self.func(*args)\n \n f.lower(*self.args).compile()\n \n case 'run':\n block_until_ready(self.compiled_func(*self.args))\n\n def setup(self, mode: Mode, kind: Kind):\n \"\"\"Create an initial MCMC state and random seed, compile & warm-up.\"\"\"\n key = random.key(2025_06_24_12_07)\n if kind.startswith('vmap-'):\n length = int(kind.split('-')[1])\n keys = list(random.split(key, (2, length)))\n else:\n keys = list(random.split(key))\n \n self.args = (keys, simple_init(P, N, NTREE, kind))\n \n def func(keys, bart):\n bart = step(key=keys.pop(), bart=bart)\n if kind == 'sparse':\n bart = mcmcstep.step_sparse(keys.pop(), bart)\n return bart\n \n if kind.startswith('vmap-'):\n axes = vmap_axes_for_state(self.args[1])\n func = vmap(func, in_axes=(0, axes), out_axes=axes)\n \n self.func = func\n self.compiled_func = jit(func).lower(*self.args).compile()\n if mode == 'run':\n block_until_ready(self.compiled_func(*self.args))",
82 | "min_run_count": 2,
83 | "name": "speed.TimeStep.time_step",
84 | "number": 0,
85 | "param_names": [
86 | "mode",
87 | "kind"
88 | ],
89 | "params": [
90 | [
91 | "'compile'",
92 | "'run'"
93 | ],
94 | [
95 | "'plain'",
96 | "'binary'",
97 | "'weights'",
98 | "'sparse'",
99 | "'vmap-1'",
100 | "'vmap-2'"
101 | ]
102 | ],
103 | "repeat": 0,
104 | "rounds": 2,
105 | "sample_time": 0.01,
106 | "timeout": 30.0,
107 | "type": "time",
108 | "unit": "seconds",
109 | "version": "fed013a2153c474f5396625e7a319060a4a7d94e98d92dbd9b48f275c7a082e7",
110 | "warmup_time": -1
111 | },
112 | "version": 2
113 | }
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | # bartz/pyproject.toml
2 | #
3 | # Copyright (c) 2024-2025, The Bartz Contributors
4 | #
5 | # This file is part of bartz.
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 | [build-system]
26 | requires = ["uv_build>=0.9.5,<0.10.0"]
27 | build-backend = "uv_build"
28 |
29 | [project]
30 | name = "bartz"
31 | version = "0.7.0"
32 | description = "Super-fast BART (Bayesian Additive Regression Trees) in Python"
33 | authors = [{ name = "Giacomo Petrillo", email = "info@giacomopetrillo.com" }]
34 | license = "MIT"
35 | readme = "README.md"
36 | requires-python = ">=3.10"
37 | dependencies = [
38 | "equinox>=0.12.2",
39 | "jax>=0.5.3",
40 | "jaxtyping>=0.3.2",
41 | "numpy>=1.25.2",
42 | "scipy>=1.11.4",
43 | ]
44 |
45 | [project.urls]
46 | Homepage = "https://github.com/Gattocrucco/bartz"
47 | Documentation = "https://gattocrucco.github.io/bartz/docs-dev"
48 | Issues = "https://github.com/Gattocrucco/bartz/issues"
49 |
50 | [dependency-groups]
51 | only-local = [
52 | "appnope>=0.1.4",
53 | "ipython>=8.36.0",
54 | "matplotlib>=3.10.3",
55 | "matplotlib-label-lines>=0.8.1",
56 | "pre-commit>=4.2.0",
57 | "scikit-learn>=1.6.1",
58 | "snakeviz>=2.2.2",
59 | "virtualenv>=20.31.2",
60 | "xgboost>=3.0.0",
61 | ]
62 | ci = [
63 | "asv>=0.6.4",
64 | "flaky>=3.8.1",
65 | "gitpython>=3.1.43",
66 | "myst-parser>=4.0.1",
67 | "packaging>=25.0",
68 | "polars[pandas,pyarrow]>=1.29.0",
69 | "pytest>=8.3.5",
70 | "pytest-cov>=6.1.1",
71 | "pytest-timeout>=2.4.0",
72 | "pytest-timer[termcolor]>=1.0.0",
73 | "pytest-xdist>=3.6.1",
74 | "rpy2>=3.5.17",
75 | "sphinx>=8.1.3",
76 | "sphinx-autodoc-typehints>=3.0.1",
77 | "tomli>=2.2.1",
78 | ]
79 |
80 | [tool.pytest.ini_options]
81 | cache_dir = "config/pytest_cache"
82 | testpaths = ["tests"]
83 | addopts = [
84 | "-r xXfE",
85 | "--pdbcls=IPython.terminal.debugger:TerminalPdb",
86 | "--durations=3",
87 | "--verbose",
88 | "--import-mode=importlib",
89 | ]
90 | filterwarnings = [
91 | "ignore:unclosed database:ResourceWarning",
92 | ]
93 | timeout = 512
94 | timeout_method = "thread" # when jax hangs, signals do not work
95 |
96 | [tool.coverage.run]
97 | branch = true
98 | source_pkgs = ["bartz", "tests"]
99 |
100 | [tool.coverage.report]
101 | show_missing = true
102 |
103 | [tool.coverage.html]
104 | show_contexts = true
105 | directory = "_site/coverage"
106 |
107 | [tool.coverage.paths]
108 | # the first path in each list must be the source directory in the machine that's
109 | # generating the coverage report
110 |
111 | github = [
112 | '/home/runner/work/bartz/bartz/src/bartz/',
113 | '/Users/runner/work/bartz/bartz/src/bartz/',
114 | 'D:\a\bartz\bartz\src\bartz\',
115 | '/Library/Frameworks/Python.framework/Versions/*/lib/python*/site-packages/bartz/',
116 | '/Users/runner/hostedtoolcache/Python/*/*/lib/python*/site-packages/bartz/',
117 | '/opt/hostedtoolcache/Python/*/*/lib/python*/site-packages/bartz/',
118 | 'C:\hostedtoolcache\windows\Python\*\*\Lib\site-packages\bartz\',
119 | ]
120 |
121 | local = [
122 | 'src/bartz/',
123 | '/home/runner/work/bartz/bartz/src/bartz/',
124 | '/Users/runner/work/bartz/bartz/src/bartz/',
125 | 'D:\a\bartz\bartz\src\bartz\',
126 | '/Library/Frameworks/Python.framework/Versions/*/lib/python*/site-packages/bartz/',
127 | '/Users/runner/hostedtoolcache/Python/*/*/lib/python*/site-packages/bartz/',
128 | '/opt/hostedtoolcache/Python/*/*/lib/python*/site-packages/bartz/',
129 | 'C:\hostedtoolcache\windows\Python\*\*\Lib\site-packages\bartz\',
130 | ]
131 |
132 | [tool.ruff]
133 | exclude = [".asv", "*.ipynb"]
134 | cache-dir = "config/ruff_cache"
135 |
136 | [tool.ruff.format]
137 | quote-style = "single"
138 | skip-magic-trailing-comma = true
139 |
140 | [tool.ruff.lint.isort]
141 | split-on-trailing-comma = false
142 |
143 | [tool.ruff.lint]
144 | select = [
145 | "ERA", # eradicate
146 | "S", # flake8-bandit
147 | "BLE", # flake8-blind-except
148 | "B", # bugbear
149 | "A", # flake8-builtins
150 | "C4", # flake8-comprehensions
151 | "CPY", # flake8-copyright
152 | "DTZ", # flake8-datetimez
153 | "T10", # flake8-debugger
154 | "EM", # flake8-errmsg
155 | "EXE", # flake8-executable
156 | "FIX", # flake8-fixme
157 | "ISC", # flake8-implicit-str-concat
158 | "INP", # flake8-no-pep420
159 | "PIE", # flake8-pie
160 | "T20", # flake8-print
161 | "PT", # flake8-pytest-style
162 | "RSE", # flake8-raise
163 | "RET", # flake8-return
164 | "SLF", # flake8-self
165 | "SIM", # flake8-simplify
166 | "TID", # flake8-tidy-imports
167 | "ARG", # flake8-unused-arguments
168 | "PTH", # flake8-use-pathlib
169 | "FLY", # flynt
170 | "I", # isort
171 | "C90", # mccabe
172 | "NPY", # NumPy-specific rules
173 | "PERF", # Perflint
174 | "W", # pycodestyle Warning
175 | "F", # pyflakes
176 | "D", # pydocstyle
177 | "PGH", # pygrep-hooks
178 | "PLC", # Pylint Convention
179 | "PLE", # Pylint Error
180 | "PLR", # Pylint Refactor
181 |     "PLW", # Pylint Warning
182 | "UP", # pyupgrade
183 | "FURB", # refurb
184 | "RUF", # Ruff-specific rules
185 | "TRY", # tryceratops
186 | ]
187 | ignore = [
188 | "B028", # warn with stacklevel = 2
189 | "C408", # Unnecessary `dict()` call (rewrite as a literal), it's too convenient for kwargs
190 | "D105", # Missing docstring in magic method
191 | "F722", # Syntax error in forward annotation. I ignore this because jaxtyping uses strings for shapes instead of for deferred annotations.
192 | "PIE790", # Unnecessary ... or pass. Ignored because sometimes I use ... as sentinel to tell the rest of ruff and pyright that an implementation is a stub.
193 | "PLR0913", # Too many arguments in function definition. Maybe I should do something about this?
194 | "PLR2004", # Magic value used in comparison, consider replacing `*` with a constant variable
195 | "RET505", # Unnecessary `{branch}` after `return` statement. I ignore this because I like to keep branches for readability.
196 | "RET506", # Unnecessary `else` after `raise` statement. I ignore this because I like to keep branches for readability.
197 | "S101", # Use of `assert` detected. Too annoying.
198 | "SIM108", # SIM108 Use ternary operator `*` instead of `if`-`else`-block, I find blocks more readable
199 | "UP037", # Remove quotes from type annotation. Ignore because jaxtyping.
200 | ]
201 |
202 | [tool.ruff.lint.per-file-ignores]
203 | "{config/*,docs/*}" = [
204 | "D100", # Missing docstring in public module
205 | "D101", # Missing docstring in public class
206 | "D102", # Missing docstring in public method
207 | "D103", # Missing docstring in public function
208 | "D104", # Missing docstring in public package
209 | "INP001", # File * is part of an implicit namespace package. Add an `__init__.py`.
210 | ]
211 | "src/bartz/_version.py" = [
212 | "CPY001", # Missing copyright notice at top of file
213 | ]
214 | "{config/*,docs/*,tests/*}" = [
215 | "T201", # `print` found
216 | ]
217 | "{tests/*,benchmarks/*}" = [
218 | "SLF001", # Private member accessed: `*`
219 | "TID253", # `{module}` is banned at the module level
220 | ]
221 | "docs/conf.py" = [
222 | "S607", # Starting a process with a partial executable path. Ignored because for a build script it makes more sense to use PATH.
223 | ]
224 |
225 | [tool.ruff.lint.pydocstyle]
226 | convention = "numpy"
227 |
228 | [tool.ruff.lint.flake8-copyright]
229 | min-file-size = 1
230 |
231 | [tool.ruff.lint.flake8-tidy-imports]
232 | banned-module-level-imports = ["bartz.debug"]
233 | ban-relative-imports = "all"
234 |
235 | [tool.pydoclint]
236 | arg-type-hints-in-signature = true
237 | arg-type-hints-in-docstring = false
238 | check-return-types = false
239 | check-yield-types = false
240 | treat-property-methods-as-class-attributes = true
241 | check-style-mismatch = true
242 | show-filenames-in-every-violation-message = true
243 | check-class-attributes = false
244 | # do not check class attributes because in dataclasses I document them as
245 | # init parameters because they are duplicated in the html docs otherwise.
246 |
--------------------------------------------------------------------------------