├── pytensor
├── py.typed
├── bin
│ └── __init__.py
├── link
│ ├── c
│ │ ├── __init__.py
│ │ ├── exceptions.py
│ │ ├── c_code
│ │ │ └── pytensor_mod_helper.h
│ │ └── cvm.py
│ ├── mlx
│ │ ├── dispatch
│ │ │ ├── signal
│ │ │ │ ├── __init__.py
│ │ │ │ └── conv.py
│ │ │ ├── __init__.py
│ │ │ ├── extra_ops.py
│ │ │ ├── sort.py
│ │ │ ├── shape.py
│ │ │ └── blockwise.py
│ │ └── __init__.py
│ ├── numba
│ │ ├── dispatch
│ │ │ ├── linalg
│ │ │ │ ├── __init__.py
│ │ │ │ ├── solve
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── utils.py
│ │ │ │ │ └── norm.py
│ │ │ │ └── decomposition
│ │ │ │ │ └── __init__.py
│ │ │ ├── signal
│ │ │ │ └── __init__.py
│ │ │ ├── __init__.py
│ │ │ └── sort.py
│ │ ├── __init__.py
│ │ └── linker.py
│ ├── jax
│ │ ├── __init__.py
│ │ └── dispatch
│ │ │ ├── signal
│ │ │ ├── __init__.py
│ │ │ └── conv.py
│ │ │ ├── blas.py
│ │ │ ├── sort.py
│ │ │ ├── blockwise.py
│ │ │ ├── einsum.py
│ │ │ ├── __init__.py
│ │ │ ├── sparse.py
│ │ │ ├── pad.py
│ │ │ ├── math.py
│ │ │ └── nlinalg.py
│ ├── __init__.py
│ └── pytorch
│ │ └── dispatch
│ │ ├── math.py
│ │ ├── blas.py
│ │ ├── slinalg.py
│ │ ├── sort.py
│ │ ├── __init__.py
│ │ ├── blockwise.py
│ │ ├── extra_ops.py
│ │ ├── shape.py
│ │ └── nlinalg.py
├── misc
│ ├── __init__.py
│ ├── check_blas_many.sh
│ ├── may_share_memory.py
│ ├── ordered_set.py
│ ├── frozendict.py
│ └── pkl_utils.py
├── d3viz
│ ├── __init__.py
│ ├── css
│ │ ├── d3-context-menu.css
│ │ └── d3viz.css
│ └── js
│ │ └── d3-context-menu.js
├── scalar
│ ├── __init__.py
│ └── sharedvar.py
├── tensor
│ ├── _linalg
│ │ ├── __init__.py
│ │ └── solve
│ │ │ └── __init__.py
│ ├── linalg.py
│ ├── signal
│ │ └── __init__.py
│ ├── conv
│ │ └── __init__.py
│ ├── random
│ │ ├── __init__.py
│ │ ├── rewriting
│ │ │ └── __init__.py
│ │ └── var.py
│ ├── var.py
│ ├── exceptions.py
│ ├── rewriting
│ │ ├── __init__.py
│ │ └── einsum.py
│ ├── c_code
│ │ └── alt_blas_common.h
│ └── xlogx.py
├── graph
│ ├── rewriting
│ │ └── __init__.py
│ ├── __init__.py
│ └── null_type.py
├── typed_list
│ ├── __init__.py
│ └── rewriting.py
├── xtensor
│ ├── rewriting
│ │ ├── __init__.py
│ │ ├── math.py
│ │ └── utils.py
│ └── __init__.py
├── scan
│ └── scan_perform_ext.py
├── npy_2_compat.py
├── sparse
│ ├── utils.py
│ ├── sharedvar.py
│ └── __init__.py
├── compile
│ └── __init__.py
└── updates.py
├── tests
├── __init__.py
├── d3viz
│ ├── __init__.py
│ ├── models.py
│ └── test_d3viz.py
├── graph
│ ├── __init__.py
│ ├── rewriting
│ │ └── __init__.py
│ ├── test_utils.py
│ └── test_types.py
├── link
│ ├── __init__.py
│ ├── c
│ │ ├── __init__.py
│ │ └── c_code
│ │ │ └── test_cenum.h
│ ├── jax
│ │ ├── __init__.py
│ │ ├── signal
│ │ │ ├── __init__.py
│ │ │ └── test_conv.py
│ │ ├── test_sort.py
│ │ ├── test_einsum.py
│ │ ├── test_blas.py
│ │ ├── test_blockwise.py
│ │ └── test_math.py
│ ├── mlx
│ │ ├── __init__.py
│ │ ├── test_extra_ops.py
│ │ ├── test_sort.py
│ │ ├── test_math.py
│ │ └── test_nlinalg.py
│ ├── numba
│ │ ├── __init__.py
│ │ ├── linalg
│ │ │ ├── __init__.py
│ │ │ └── solve
│ │ │ │ └── __init__.py
│ │ └── test_sort.py
│ └── pytorch
│ │ ├── __init__.py
│ │ ├── conftest.py
│ │ ├── test_sort.py
│ │ ├── test_slinalg.py
│ │ ├── test_blas.py
│ │ ├── test_math.py
│ │ ├── test_blockwise.py
│ │ └── test_shape.py
├── misc
│ ├── __init__.py
│ └── test_pkl_utils.py
├── scan
│ └── __init__.py
├── compile
│ ├── __init__.py
│ └── function
│ │ └── __init__.py
├── scalar
│ ├── __init__.py
│ └── test_type.py
├── tensor
│ ├── __init__.py
│ ├── conv
│ │ └── __init__.py
│ ├── linalg
│ │ └── __init__.py
│ ├── random
│ │ ├── __init__.py
│ │ └── rewriting
│ │ │ └── __init__.py
│ ├── signal
│ │ └── __init__.py
│ ├── rewriting
│ │ ├── __init__.py
│ │ ├── test_ofg.py
│ │ └── test_einsum.py
│ ├── test_xlogx.py
│ └── test_type_other.py
├── typed_list
│ └── __init__.py
├── xtensor
│ ├── __init__.py
│ └── test_reduction.py
├── sparse
│ ├── __init__.py
│ ├── test_sharedvar.py
│ ├── test_linalg.py
│ └── test_utils.py
└── test_updates.py
├── LICENSE.txt
├── .gitattributes
├── doc
├── bcast.png
├── .templates
│ ├── PLACEHOLDER
│ ├── rendered_citation.html
│ ├── nb-badges.html
│ └── layout.html
├── images
│ ├── lstm.png
│ ├── PyTensor.png
│ ├── talk2010.gif
│ ├── talk2010.png
│ ├── Elman_srnn.png
│ ├── blocksparse.png
│ ├── PyTensor_logo.png
│ ├── lstm_memorycell.png
│ └── github.svg
├── tutorial
│ ├── apply.png
│ ├── bcast.png
│ ├── dlogistic.png
│ ├── logistic.png
│ ├── pics
│ │ ├── d3viz.png
│ │ ├── logreg_pydotprint_predict.png
│ │ ├── logreg_pydotprint_train.png
│ │ └── logreg_pydotprint_prediction.png
│ ├── symbolic_graphs.rst
│ ├── profiling_example.py
│ ├── adding_solution_1.py
│ ├── logistic.gp
│ ├── index.rst
│ └── profiling_example_out.prof
├── extending
│ ├── apply.png
│ ├── pics
│ │ ├── symbolic_graph_opt.png
│ │ └── symbolic_graph_unopt.png
│ ├── extending_faq.rst
│ ├── tips.rst
│ └── index.rst
├── library
│ ├── tensor
│ │ ├── functional.rst
│ │ ├── bcast.png
│ │ ├── plot_fft.png
│ │ ├── random
│ │ │ └── distributions.rst
│ │ ├── conv.rst
│ │ ├── basic_opt.rst
│ │ ├── optimize.rst
│ │ ├── utils.rst
│ │ ├── elemwise.rst
│ │ ├── extra_ops.rst
│ │ ├── math_opt.rst
│ │ ├── io.rst
│ │ ├── nlinalg.rst
│ │ ├── slinalg.rst
│ │ ├── index.rst
│ │ └── fft.rst
│ ├── d3viz
│ │ ├── examples
│ │ │ ├── mlp.png
│ │ │ ├── mlp2.pdf
│ │ │ ├── mlp2.png
│ │ │ └── d3viz
│ │ │ │ ├── css
│ │ │ │ ├── d3-context-menu.css
│ │ │ │ └── d3viz.css
│ │ │ │ └── js
│ │ │ │ └── d3-context-menu.js
│ │ └── index_files
│ │ │ ├── index_10_0.png
│ │ │ ├── index_11_0.png
│ │ │ ├── index_24_0.png
│ │ │ └── index_25_0.png
│ ├── xtensor
│ │ ├── linalg.md
│ │ ├── random.md
│ │ ├── math.md
│ │ ├── module_functions.md
│ │ └── type.md
│ ├── compile
│ │ ├── ops.rst
│ │ ├── profilemode.rst
│ │ ├── index.rst
│ │ └── opfromgraph.rst
│ ├── graph
│ │ ├── op.rst
│ │ ├── graph.rst
│ │ ├── replace.rst
│ │ ├── index.rst
│ │ ├── type.rst
│ │ ├── utils.rst
│ │ ├── features.rst
│ │ └── fgraph.rst
│ ├── scalar
│ │ └── index.rst
│ ├── misc
│ │ └── pkl_utils.rst
│ ├── sparse
│ │ └── sandbox.rst
│ ├── typed_list.rst
│ └── index.rst
├── robots.txt
├── blog.md
├── _thumbnails
│ └── autodiff
│ │ └── vector_jacobian_product.png
├── README.md
├── internal
│ ├── index.rst
│ └── how_to_release.rst
├── user_guide.rst
├── css.inc
├── environment.yml
├── install.rst
├── gallery
│ └── page_footer.md
├── core_development_guide.rst
├── acknowledgement.rst
├── generate_dtype_tensor_table.py
└── links.rst
├── CODE_OF_CONDUCT.md
├── .github
├── ISSUE_TEMPLATE
│ ├── config.yml
│ ├── developer.yml
│ ├── documentation.yml
│ └── feature-request.yml
├── dependabot.yml
├── release.yml
├── workflows
│ ├── mypy.yml
│ ├── slow-tests-issue.yml
│ └── zizmor.yml
├── PULL_REQUEST_TEMPLATE.md
├── .readthedocs.yaml
├── GOVERNANCE.md
├── MANIFEST.in
├── codecov.yml
├── .gitignore
├── scripts
└── mypy-failing.txt
├── setup.py
├── environment-osx-arm64.yml
├── environment.yml
├── conftest.py
├── .pre-commit-config.yaml
└── CONTRIBUTING.md
/pytensor/py.typed:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/d3viz/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/graph/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/link/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/misc/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/scan/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | doc/LICENSE.txt
--------------------------------------------------------------------------------
/pytensor/bin/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/pytensor/link/c/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/pytensor/misc/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/compile/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/link/c/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/link/jax/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/link/mlx/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/link/numba/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/scalar/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/tensor/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/typed_list/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/xtensor/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/graph/rewriting/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/link/jax/signal/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/link/pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/tensor/conv/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/tensor/linalg/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/tensor/random/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/tensor/signal/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/compile/function/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/link/numba/linalg/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/tensor/rewriting/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/link/numba/linalg/solve/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/tests/tensor/random/rewriting/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/pytensor/link/mlx/dispatch/signal/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/pytensor/link/numba/dispatch/linalg/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/.gitattributes:
--------------------------------------------------------------------------------
1 | pytensor/_version.py export-subst
2 | 
--------------------------------------------------------------------------------
/pytensor/link/numba/dispatch/linalg/solve/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/pytensor/link/numba/dispatch/linalg/decomposition/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/pytensor/d3viz/__init__.py:
--------------------------------------------------------------------------------
1 | from pytensor.d3viz.d3viz import d3viz, d3write
2 | 
--------------------------------------------------------------------------------
/pytensor/scalar/__init__.py:
--------------------------------------------------------------------------------
1 | from .basic import *
2 | from .math import *
3 | 
--------------------------------------------------------------------------------
/doc/bcast.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/bcast.png
--------------------------------------------------------------------------------
/pytensor/link/jax/__init__.py:
--------------------------------------------------------------------------------
1 | from pytensor.link.jax.linker import JAXLinker
2 | 
--------------------------------------------------------------------------------
/pytensor/link/mlx/__init__.py:
--------------------------------------------------------------------------------
1 | from pytensor.link.mlx.linker import MLXLinker
2 | 
--------------------------------------------------------------------------------
/pytensor/link/__init__.py:
--------------------------------------------------------------------------------
1 | from pytensor.link.pytorch.linker import PytorchLinker
2 | 
--------------------------------------------------------------------------------
/pytensor/link/numba/__init__.py:
--------------------------------------------------------------------------------
1 | from pytensor.link.numba.linker import NumbaLinker
2 | 
--------------------------------------------------------------------------------
/doc/.templates/PLACEHOLDER:
--------------------------------------------------------------------------------
1 | sphinx doesn't like it when this repertory isn't available
2 | 
--------------------------------------------------------------------------------
/doc/images/lstm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/images/lstm.png
--------------------------------------------------------------------------------
/doc/tutorial/apply.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/tutorial/apply.png
--------------------------------------------------------------------------------
/doc/tutorial/bcast.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/tutorial/bcast.png
--------------------------------------------------------------------------------
/pytensor/link/jax/dispatch/signal/__init__.py:
--------------------------------------------------------------------------------
1 | import pytensor.link.jax.dispatch.signal.conv
2 | 
--------------------------------------------------------------------------------
/doc/extending/apply.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/extending/apply.png
--------------------------------------------------------------------------------
/doc/images/PyTensor.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/images/PyTensor.png
--------------------------------------------------------------------------------
/doc/images/talk2010.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/images/talk2010.gif
--------------------------------------------------------------------------------
/doc/images/talk2010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/images/talk2010.png
--------------------------------------------------------------------------------
/pytensor/link/numba/dispatch/signal/__init__.py:
--------------------------------------------------------------------------------
1 | import pytensor.link.numba.dispatch.signal.conv
2 | 
--------------------------------------------------------------------------------
/pytensor/tensor/_linalg/__init__.py:
--------------------------------------------------------------------------------
1 | # Register rewrites
2 | import pytensor.tensor._linalg.solve
3 | 
--------------------------------------------------------------------------------
/doc/images/Elman_srnn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/images/Elman_srnn.png
--------------------------------------------------------------------------------
/doc/images/blocksparse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/images/blocksparse.png
--------------------------------------------------------------------------------
/doc/library/tensor/functional.rst:
--------------------------------------------------------------------------------
1 | .. automodule:: pytensor.tensor.functional
2 |     :members: vectorize
3 | 
--------------------------------------------------------------------------------
/doc/robots.txt:
--------------------------------------------------------------------------------
1 | User-agent: *
2 | 
3 | Sitemap: https://pytensor.readthedocs.io/en/latest/sitemap.xml
4 | 
--------------------------------------------------------------------------------
/doc/tutorial/dlogistic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/tutorial/dlogistic.png
--------------------------------------------------------------------------------
/doc/tutorial/logistic.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/tutorial/logistic.png
--------------------------------------------------------------------------------
/doc/tutorial/pics/d3viz.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/tutorial/pics/d3viz.png
--------------------------------------------------------------------------------
/doc/images/PyTensor_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/images/PyTensor_logo.png
--------------------------------------------------------------------------------
/doc/library/tensor/bcast.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/library/tensor/bcast.png
--------------------------------------------------------------------------------
/pytensor/tensor/linalg.py:
--------------------------------------------------------------------------------
1 | from pytensor.tensor.nlinalg import *
2 | from pytensor.tensor.slinalg import *
3 | 
--------------------------------------------------------------------------------
/doc/images/lstm_memorycell.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/images/lstm_memorycell.png
--------------------------------------------------------------------------------
/doc/library/tensor/plot_fft.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/library/tensor/plot_fft.png
--------------------------------------------------------------------------------
/doc/library/d3viz/examples/mlp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/library/d3viz/examples/mlp.png
--------------------------------------------------------------------------------
/doc/tutorial/symbolic_graphs.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 | 
3 | This page has been moved. Please refer to: :ref:`graphstructures`.
4 | 
--------------------------------------------------------------------------------
/doc/blog.md:
--------------------------------------------------------------------------------
1 | ---
2 | orphan: true
3 | ---
4 | 
5 | # Recent updates
6 | 
7 | 
8 | 
--------------------------------------------------------------------------------
/doc/library/d3viz/examples/mlp2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/library/d3viz/examples/mlp2.pdf
--------------------------------------------------------------------------------
/doc/library/d3viz/examples/mlp2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/library/d3viz/examples/mlp2.png
--------------------------------------------------------------------------------
/pytensor/tensor/_linalg/solve/__init__.py:
--------------------------------------------------------------------------------
1 | # Register rewrites in the database
2 | import pytensor.tensor._linalg.solve.rewriting
3 | 
--------------------------------------------------------------------------------
/doc/extending/pics/symbolic_graph_opt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/extending/pics/symbolic_graph_opt.png
--------------------------------------------------------------------------------
/pytensor/graph/rewriting/__init__.py:
--------------------------------------------------------------------------------
1 | from pytensor.graph.rewriting.utils import rewrite_graph
2 | 
3 | 
4 | all = ("rewrite_graph",)
5 | 
--------------------------------------------------------------------------------
/doc/extending/pics/symbolic_graph_unopt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/extending/pics/symbolic_graph_unopt.png
--------------------------------------------------------------------------------
/doc/library/d3viz/index_files/index_10_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/library/d3viz/index_files/index_10_0.png
--------------------------------------------------------------------------------
/doc/library/d3viz/index_files/index_11_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/library/d3viz/index_files/index_11_0.png
--------------------------------------------------------------------------------
/doc/library/d3viz/index_files/index_24_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/library/d3viz/index_files/index_24_0.png
--------------------------------------------------------------------------------
/doc/library/d3viz/index_files/index_25_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/library/d3viz/index_files/index_25_0.png
--------------------------------------------------------------------------------
/doc/tutorial/pics/logreg_pydotprint_predict.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/tutorial/pics/logreg_pydotprint_predict.png
--------------------------------------------------------------------------------
/doc/tutorial/pics/logreg_pydotprint_train.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/tutorial/pics/logreg_pydotprint_train.png
--------------------------------------------------------------------------------
/doc/_thumbnails/autodiff/vector_jacobian_product.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/_thumbnails/autodiff/vector_jacobian_product.png
--------------------------------------------------------------------------------
/doc/tutorial/pics/logreg_pydotprint_prediction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/pymc-devs/pytensor/HEAD/doc/tutorial/pics/logreg_pydotprint_prediction.png
--------------------------------------------------------------------------------
/pytensor/tensor/signal/__init__.py:
--------------------------------------------------------------------------------
1 | from pytensor.tensor.signal.conv import convolve1d, convolve2d
2 | 
3 | 
4 | __all__ = ("convolve1d", "convolve2d")
5 | 
--------------------------------------------------------------------------------
/pytensor/typed_list/__init__.py:
--------------------------------------------------------------------------------
1 | from pytensor.typed_list import rewriting
2 | from pytensor.typed_list.basic import *
3 | from pytensor.typed_list.type import TypedListType
4 | 
--------------------------------------------------------------------------------
/doc/library/xtensor/linalg.md:
--------------------------------------------------------------------------------
1 | (libdoc_xtensor_linalg)=
2 | # `xtensor.linalg` -- Linear algebra operations
3 | 
4 | ```{eval-rst}
5 | .. automodule:: pytensor.xtensor.linalg
6 |     :members:
7 | ```
8 | 
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # PyTensor Code of Conduct
2 | 
3 | As a library of the PyMC project, PyTensor follows the
4 | [PyMC code of conduct](https://github.com/pymc-devs/pymc/blob/main/CODE_OF_CONDUCT.md).
5 | 
--------------------------------------------------------------------------------
/doc/library/xtensor/random.md:
--------------------------------------------------------------------------------
1 | (libdoc_xtensor_random)=
2 | # `xtensor.random` Random number generator operations
3 | 
4 | ```{eval-rst}
5 | .. automodule:: pytensor.xtensor.random
6 |     :members:
7 | ```
8 | 
--------------------------------------------------------------------------------
/tests/link/c/c_code/test_cenum.h:
--------------------------------------------------------------------------------
1 | #ifndef PYTENSOR_TEST_CENUM
2 | #define PYTENSOR_TEST_CENUM
3 | 
4 | #define MILLION 1000000
5 | #define BILLION 1000000000
6 | #define TWO_BILLIONS 2000000000
7 | 
8 | #endif
9 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/config.yml:
--------------------------------------------------------------------------------
1 | blank_issues_enabled: false
2 | contact_links:
3 |   - name: PyMC Discourse
4 |     url: https://discourse.pymc.io/
5 |     about: Ask installation usage questions about PyTensor
6 | 
--------------------------------------------------------------------------------
/doc/library/xtensor/math.md:
--------------------------------------------------------------------------------
1 | (libdoc_xtensor_math)=
2 | # `xtensor.math` Mathematical operations
3 | 
4 | ```{eval-rst}
5 | .. automodule:: pytensor.xtensor.math
6 |     :members:
7 |     :exclude-members: XDot, dot
8 | ```
--------------------------------------------------------------------------------
/.readthedocs.yaml:
--------------------------------------------------------------------------------
1 | version: 2
2 | sphinx:
3 |   configuration: doc/conf.py
4 | conda:
5 |   environment: doc/environment.yml
6 | build:
7 |   os: "ubuntu-lts-latest"
8 |   tools:
9 |     python: "mambaforge-latest"
10 | 
--------------------------------------------------------------------------------
/GOVERNANCE.md:
--------------------------------------------------------------------------------
1 | # PyTensor Governance
2 | 
3 | As a library of the PyMC project, PyTensor governance and decision making
4 | is described in the [main PyMC governance doc](https://github.com/pymc-devs/pymc/blob/main/GOVERNANCE.md)
5 | 
--------------------------------------------------------------------------------
/doc/library/tensor/random/distributions.rst:
--------------------------------------------------------------------------------
1 | .. _libdoc_tensor_random_distributions:
2 | 
3 | Distributions
4 | =============
5 | 
6 | .. automodule:: pytensor.tensor.random.basic
7 |     :members:
8 |     :special-members: __call__
9 | 
--------------------------------------------------------------------------------
/doc/README.md:
--------------------------------------------------------------------------------
1 | # PyTensor Documentation
2 | Welcome to the PyTensor documentation. Instructions on how to contribute can be found [here](https://pytensor.readthedocs.io/en/latest/dev_start_guide.html#contributing-to-the-documentation)
3 | 
--------------------------------------------------------------------------------
/doc/internal/index.rst:
--------------------------------------------------------------------------------
1 | 
2 | .. _internal:
3 | 
4 | ======================
5 | Internal Documentation
6 | ======================
7 | 
8 | .. toctree::
9 |     :maxdepth: 2
10 | 
11 |     metadocumentation
12 |     how_to_release
13 | 
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/developer.yml:
--------------------------------------------------------------------------------
1 | name: Developer issue
2 | description: This template is for developers only!
3 | 
4 | body:
5 |   - type: textarea
6 |     attributes:
7 |       label: Description
8 |     validations:
9 |       required: true
10 | 
--------------------------------------------------------------------------------
/doc/library/compile/ops.rst:
--------------------------------------------------------------------------------
1 | ==================================================
2 | :mod:`ops` -- Some Common Ops and extra Ops stuff
3 | ==================================================
4 | 
5 | .. automodule:: pytensor.compile.ops
6 |     :members:
7 | 
--------------------------------------------------------------------------------
/doc/library/graph/op.rst:
--------------------------------------------------------------------------------
1 | .. _libdoc_graph_op:
2 | 
3 | ===========================================
4 | :mod:`op` -- Objects that define operations
5 | ===========================================
6 | 
7 | .. automodule:: pytensor.graph.op
8 |     :members:
9 | 
--------------------------------------------------------------------------------
/doc/library/scalar/index.rst:
--------------------------------------------------------------------------------
1 | 
2 | .. _libdoc_scalar:
3 | 
4 | ==============================================================
5 | :mod:`scalar` -- Symbolic Scalar Types, Ops [doc TODO]
6 | ==============================================================
7 | 
--------------------------------------------------------------------------------
/doc/library/xtensor/module_functions.md:
--------------------------------------------------------------------------------
1 | (libdoc_xtensor_module_function)=
2 | # `xtensor` -- Module level operations
3 | 
4 | ```{eval-rst}
5 | .. automodule:: pytensor.xtensor
6 |     :members: broadcast, concat, dot, full_like, ones_like, zeros_like
7 | ```
8 | 
--------------------------------------------------------------------------------
/doc/user_guide.rst:
--------------------------------------------------------------------------------
1 | 
2 | .. _user_guide:
3 | 
4 | ==========
5 | User Guide
6 | ==========
7 | 
8 | .. toctree::
9 |     :maxdepth: 1
10 | 
11 |     extending/index
12 |     optimizations
13 |     troubleshooting
14 |     glossary
15 |     links
16 |     acknowledgement
17 | 
--------------------------------------------------------------------------------
/pytensor/tensor/conv/__init__.py:
--------------------------------------------------------------------------------
1 | from .abstract_conv import (
2 |     bilinear_upsampling,
3 |     causal_conv1d,
4 |     conv2d,
5 |     conv2d_transpose,
6 |     conv3d,
7 |     frac_bilinear_upsampling,
8 |     separable_conv2d,
9 |     separable_conv3d,
10 | )
11 | 
--------------------------------------------------------------------------------
/doc/library/graph/graph.rst:
--------------------------------------------------------------------------------
1 | .. _libdoc_graph_graph:
2 | 
3 | ================================================
4 | :mod:`graph` -- Interface for the PyTensor graph
5 | ================================================
6 | 
7 | .. automodule:: pytensor.graph.basic
8 |     :members:
9 | 
--------------------------------------------------------------------------------
/.github/dependabot.yml:
--------------------------------------------------------------------------------
1 | version: 2
2 | updates:
3 |   # Maintain dependencies for GitHub Actions
4 |   - package-ecosystem: "github-actions"
5 |     directory: "/"
6 |     schedule:
7 |       interval: "weekly"
8 |     labels:
9 |       - "GitHub CI/CD"
10 |       - "no releasenotes"
11 | 
--------------------------------------------------------------------------------
/doc/library/graph/replace.rst:
--------------------------------------------------------------------------------
1 | .. _libdoc_graph_replace:
2 | 
3 | ==================================================
4 | :mod:`replace` -- High level graph transformations
5 | ==================================================
6 | 
7 | .. automodule:: pytensor.graph.replace
8 |     :members:
9 | 
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | global-include *.txt
2 | global-include *.c
3 | global-include *.cc
4 | global-include *.cu
5 | global-include *.cuh
6 | global-include *.cpp
7 | global-include *.h
8 | global-include *.hh
9 | global-include *.sh
10 | recursive-include doc *
11 | include pytensor/_version.py
12 | 
--------------------------------------------------------------------------------
/pytensor/xtensor/rewriting/__init__.py:
--------------------------------------------------------------------------------
1 | import pytensor.xtensor.rewriting.basic
2 | import pytensor.xtensor.rewriting.indexing
3 | import pytensor.xtensor.rewriting.math
4 | import pytensor.xtensor.rewriting.reduction
5 | import pytensor.xtensor.rewriting.shape
6 | import pytensor.xtensor.rewriting.vectorization
7 | 
--------------------------------------------------------------------------------
/pytensor/tensor/random/__init__.py:
--------------------------------------------------------------------------------
1 | # Initialize `RandomVariable` rewrites
2 | import pytensor.tensor.random.rewriting
3 | import pytensor.tensor.random.utils
4 | from pytensor.tensor.random.basic import *
5 | from pytensor.tensor.random.op import default_rng
6 | from pytensor.tensor.random.utils import RandomStream
7 | 
--------------------------------------------------------------------------------
/pytensor/tensor/var.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | from pytensor.tensor.variable import *  # noqa
4 | 
5 | 
6 | warnings.warn(
7 |     "The module 'pytensor.tensor.var' has been deprecated. "
8 |     "Use 'pytensor.tensor.variable' instead.",
9 |     category=DeprecationWarning,
10 |     stacklevel=2,
11 | )
12 | 
--------------------------------------------------------------------------------
/pytensor/link/pytorch/dispatch/math.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from pytensor.link.pytorch.dispatch import pytorch_funcify
4 | from pytensor.tensor.math import Dot
5 | 
6 | 
7 | @pytorch_funcify.register(Dot)
8 | def pytorch_funcify_Dot(op, **kwargs):
9 |     def dot(x, y):
10 |         return torch.matmul(x, y)
11 | 
12 |     return dot
13 | 
--------------------------------------------------------------------------------
/tests/sparse/__init__.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | 
3 | from pytensor.compile import get_default_mode
4 | from pytensor.link.numba import NumbaLinker
5 | 
6 | 
7 | if isinstance(get_default_mode().linker, NumbaLinker):
8 |     pytest.skip(
9 |         reason="Numba does not support Sparse Ops yet",
10 |         allow_module_level=True,
11 |     )
12 | 
--------------------------------------------------------------------------------
/doc/css.inc:
--------------------------------------------------------------------------------
1 | .. _css:
2 | 
3 | .. raw:: html
4 | 
5 | 
12 | 
13 | .. role:: black
14 | .. role:: blue
15 | .. role:: red
16 | .. role:: green
17 | .. role:: pink
18 | 
--------------------------------------------------------------------------------
/pytensor/tensor/random/rewriting/__init__.py:
--------------------------------------------------------------------------------
1 | # TODO: This is for backward-compatibility; remove when reasonable.
2 | from pytensor.tensor.random.rewriting.basic import *
3 | 
4 | 
5 | # isort: off
6 | 
7 | # Register Numba and JAX specializations
8 | import pytensor.tensor.random.rewriting.numba
9 | import pytensor.tensor.random.rewriting.jax
10 | 
11 | # isort: on
12 | 
--------------------------------------------------------------------------------
/doc/tutorial/profiling_example.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | import pytensor
4 | 
5 | x, y, z = pytensor.tensor.vectors("xyz")
6 | f = pytensor.function([x, y, z], [(x + y + z) * 2])
7 | xv = np.random.random((10,)).astype(pytensor.config.floatX)
8 | yv = np.random.random((10,)).astype(pytensor.config.floatX)
9 | zv = np.random.random((10,)).astype(pytensor.config.floatX)
10 | f(xv, yv, zv)
11 | 
--------------------------------------------------------------------------------
/doc/library/tensor/conv.rst:
--------------------------------------------------------------------------------
1 | =========================================
2 | :mod:`tensor.conv` -- Tensor Convolutions
3 | =========================================
4 | 
5 | .. module:: tensor.conv
6 |     :platform: Unix, Windows
7 |     :synopsis: Tensor Convolutions
8 | .. moduleauthor:: LISA, PyMC Developers, PyTensor Developers
9 | 
10 | .. automodule:: pytensor.tensor.conv
11 |     :members:
12 | 
--------------------------------------------------------------------------------
/doc/library/graph/index.rst:
--------------------------------------------------------------------------------
1 | 
2 | .. _libdoc_graph:
3 | 
4 | ========================================
5 | :mod:`graph` -- PyTensor Graph Internals
6 | ========================================
7 | 
8 | .. module:: graph
9 | 
10 | .. moduleauthor:: LISA
11 | 
12 | .. toctree::
13 |     :maxdepth: 1
14 | 
15 |     graph
16 |     fgraph
17 |     replace
18 |     features
19 |     op
20 |     type
21 |     utils
22 | 
--------------------------------------------------------------------------------
/tests/link/pytorch/conftest.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | 
4 | from pytensor import config
5 | from pytensor.tensor.type import matrix
6 | 
7 | 
8 | @pytest.fixture
9 | def matrix_test():
10 |     rng = np.random.default_rng(213234)
11 | 
12 |     M = rng.normal(size=(3, 3))
13 |     test_value = M.dot(M.T).astype(config.floatX)
14 | 
15 |     x = matrix("x")
16 |     return x, test_value
17 | 
--------------------------------------------------------------------------------
/doc/library/misc/pkl_utils.rst:
--------------------------------------------------------------------------------
1 | 
2 | .. _libdoc_misc:
3 | 
4 | ================================================
5 | :mod:`misc.pkl_utils` - Tools for serialization.
6 | ================================================
7 | 
8 | .. testsetup:: *
9 | 
10 |     from pytensor.misc.pkl_utils import *
11 | 
12 | .. autoclass:: pytensor.misc.pkl_utils.StripPickler
13 | 
14 | .. seealso::
15 | 
16 |     :ref:`tutorial_loadsave`
17 | 
--------------------------------------------------------------------------------
/doc/library/tensor/basic_opt.rst:
--------------------------------------------------------------------------------
1 | ================================================
2 | :mod:`tensor.rewriting.basic` -- Tensor Rewrites
3 | ================================================
4 | 
5 | .. module:: tensor.rewriting.basic
6 |     :platform: Unix, Windows
7 |     :synopsis: Tensor Rewrites
8 | .. moduleauthor:: LISA, PyMC Developers, PyTensor Developers
9 | 
10 | .. automodule:: pytensor.tensor.rewriting.basic
11 |     :members:
12 | 
--------------------------------------------------------------------------------
/doc/library/graph/type.rst:
--------------------------------------------------------------------------------
1 | .. _libdoc_graph_type:
2 | 
3 | ================================================
4 | :mod:`type` -- Interface for types of variables
5 | ================================================
6 | 
7 | ---------
8 | Reference
9 | ---------
10 | 
11 | .. automodule:: pytensor.graph.type
12 |     :platform: Unix, Windows
13 |     :synopsis: Interface for types of symbolic variables
14 |     :members:
15 | .. moduleauthor:: LISA
16 | 
--------------------------------------------------------------------------------
/doc/library/tensor/optimize.rst:
--------------------------------------------------------------------------------
1 | ========================================================
2 | :mod:`tensor.optimize` -- Symbolic Optimization Routines
3 | ========================================================
4 | 
5 | .. module:: tensor.conv
6 |     :platform: Unix, Windows
7 |     :synopsis: Symbolic Optimization Routines
8 | .. moduleauthor:: LISA, PyMC Developers, PyTensor Developers
9 | 
10 | .. automodule:: pytensor.tensor.optimize
11 |     :members:
12 | 
--------------------------------------------------------------------------------
/doc/tutorial/adding_solution_1.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # PyTensor tutorial
3 | # Solution to Exercise in section 'Baby Steps - Algebra'
4 | 
5 | 
6 | import pytensor
7 | a = pytensor.tensor.vector()  # declare variable
8 | b = pytensor.tensor.vector()  # declare variable
9 | out = a ** 2 + b ** 2 + 2 * a * b  # build symbolic expression
10 | f = pytensor.function([a, b], out)  # compile function
11 | print(f([1, 2], [4, 5]))  # prints [ 25.  49.]
12 | 
--------------------------------------------------------------------------------
/pytensor/link/c/exceptions.py:
--------------------------------------------------------------------------------
1 | from setuptools.errors import CompileError as BaseCompileError
2 | 
3 | 
4 | class MissingGXX(Exception):
5 |     """This error is raised when we try to generate c code, but g++ is not available."""
6 | 
7 | 
8 | class CompileError(BaseCompileError):  # pyright: ignore
9 |     """Custom `Exception` prints compilation errors with their original formatting."""
10 | 
11 |     def __str__(self):
12 |         return self.args[0]
13 | 
--------------------------------------------------------------------------------
/pytensor/tensor/exceptions.py:
--------------------------------------------------------------------------------
1 | class ShapeError(Exception):
2 |     """Raised when the shape cannot be computed."""
3 | 
4 | 
5 | class NotScalarConstantError(Exception):
6 |     """
7 |     Raised by get_underlying_scalar_constant_value if called on something that is
8 |     not a scalar constant.
9 |     """
10 | 
11 | 
12 | class AdvancedIndexingError(TypeError):
13 |     """
14 |     Raised when Subtensor is asked to perform advanced indexing.
15 | 
16 |     """
17 | 
--------------------------------------------------------------------------------
/pytensor/xtensor/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | 
3 | import pytensor.xtensor.rewriting
4 | from pytensor.xtensor import linalg, math, random
5 | from pytensor.xtensor.math import dot
6 | from pytensor.xtensor.shape import broadcast, concat, full_like, ones_like, zeros_like
7 | from pytensor.xtensor.type import (
8 |     as_xtensor,
9 |     xtensor,
10 |     xtensor_constant,
11 | )
12 | 
13 | 
14 | warnings.warn("xtensor module is experimental and full of bugs")
15 | 
--------------------------------------------------------------------------------
/tests/sparse/test_sharedvar.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy as sp
3 | 
4 | import pytensor
5 | from pytensor.sparse.sharedvar import SparseTensorSharedVariable
6 | 
7 | 
8 | def test_shared_basic():
9 |     x = pytensor.shared(
10 |         sp.sparse.csr_matrix(np.eye(100), dtype=np.float64), name="blah", borrow=True
11 |     )
12 | 
13 |     assert isinstance(x, SparseTensorSharedVariable)
14 |     assert x.format == "csr"
15 |     assert x.dtype == "float64"
16 | 
--------------------------------------------------------------------------------
/doc/library/tensor/utils.rst:
--------------------------------------------------------------------------------
1 | ===================================================================
2 | :mod:`tensor.utils` -- Tensor Utils
3 | ===================================================================
4 | 
5 | .. testsetup::
6 | 
7 |     from pytensor.tensor.utils import *
8 | 
9 | .. module:: tensor.utils
10 |     :platform: Unix, Windows
11 |     :synopsis: Tensor Utils
12 | .. moduleauthor:: LISA
13 | 
14 | .. automodule:: pytensor.tensor.utils
15 |     :members:
16 | 
--------------------------------------------------------------------------------
/pytensor/link/numba/dispatch/linalg/solve/utils.py:
--------------------------------------------------------------------------------
1 | from scipy import linalg
2 | 
3 | from pytensor.link.numba.dispatch import basic as numba_basic
4 | 
5 | 
6 | @numba_basic.numba_njit(inline="always")
7 | def _solve_check_input_shapes(A, B):
8 |     if A.shape[0] != B.shape[0]:
9 |         raise linalg.LinAlgError("Dimensions of A and B do not conform")
10 |     if A.shape[-2] != A.shape[-1]:
11 |         raise linalg.LinAlgError("Last 2 dimensions of A must be square")
12 | 
--------------------------------------------------------------------------------
/tests/graph/test_utils.py:
--------------------------------------------------------------------------------
1 | import pytensor
2 | from pytensor.tensor.type import vector
3 | 
4 | 
5 | def test_stack_trace():
6 |     with pytensor.config.change_flags(traceback__limit=1):
7 |         v = vector()
8 |         assert len(v.tag.trace) == 1
9 |         assert len(v.tag.trace[0]) == 1
10 | 
11 |     with pytensor.config.change_flags(traceback__limit=2):
12 |         v = vector()
13 |         assert len(v.tag.trace) == 1
14 |         assert len(v.tag.trace[0]) == 2
15 | 
--------------------------------------------------------------------------------
/doc/library/xtensor/type.md:
--------------------------------------------------------------------------------
1 | (libdoc_xtenor_type)=
2 | 
3 | # `xtensor.type` -- Types and Variables
4 | 
5 | ## XTensorVariable creation functions
6 | 
7 | ```{eval-rst}
8 | .. automodule:: pytensor.xtensor.type
9 |     :members: xtensor, xtensor_constant, as_xtensor
10 | 
11 | ```
12 | 
13 | ## XTensor Type and Variable classes
14 | 
15 | ```{eval-rst}
16 | .. automodule:: pytensor.xtensor.type
17 |     :members: XTensorType, XTensorVariable, XTensorConstant
18 | ```
19 | 
20 | 
21 | 
--------------------------------------------------------------------------------
/pytensor/link/pytorch/dispatch/blas.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | from pytensor.link.pytorch.dispatch import pytorch_funcify
4 | from pytensor.tensor.blas import BatchedDot
5 | 
6 | 
7 | @pytorch_funcify.register(BatchedDot)
8 | def pytorch_funcify_BatchedDot(op, **kwargs):
9 |     def batched_dot(a, b):
10 |         if a.shape[0] != b.shape[0]:
11 |             raise TypeError("Shapes must match in the 0-th dimension")
12 |         return torch.bmm(a, b)
13 | 
14 |     return batched_dot
15 | 
--------------------------------------------------------------------------------
/doc/library/tensor/elemwise.rst:
--------------------------------------------------------------------------------
1 | ===================================================================
2 | :mod:`tensor.elemwise` -- Tensor Elemwise
3 | ===================================================================
4 | 
5 | .. testsetup::
6 | 
7 |     from pytensor.tensor.elemwise import *
8 | 
9 | .. module:: tensor.elemwise
10 |     :platform: Unix, Windows
11 |     :synopsis: Tensor Elemwise
12 | .. moduleauthor:: LISA
13 | 
14 | .. automodule:: pytensor.tensor.elemwise
15 |     :members:
16 | 
--------------------------------------------------------------------------------
/pytensor/link/jax/dispatch/blas.py:
--------------------------------------------------------------------------------
1 | import jax.numpy as jnp
2 | 
3 | from pytensor.link.jax.dispatch import jax_funcify
4 | from pytensor.tensor.blas import BatchedDot
5 | 
6 | 
7 | @jax_funcify.register(BatchedDot)
8 | def jax_funcify_BatchedDot(op, **kwargs):
9 |     def batched_dot(a, b):
10 |         if a.shape[0] != b.shape[0]:
11 |             raise TypeError("Shapes must match along the first dimension of BatchedDot")
12 |         return jnp.matmul(a, b)
13 | 
14 |     return batched_dot
15 | 
--------------------------------------------------------------------------------
/doc/library/tensor/extra_ops.rst:
--------------------------------------------------------------------------------
1 | ===================================================================
2 | :mod:`tensor.extra_ops` -- Tensor Extra Ops
3 | ===================================================================
4 | 
5 | .. testsetup:: *
6 | 
7 |     from pytensor.tensor.extra_ops import *
8 | 
9 | .. module:: tensor.extra_ops
10 |     :platform: Unix, Windows
11 |     :synopsis: Tensor Extra Ops
12 | .. moduleauthor:: LISA
13 | 
14 | .. automodule:: pytensor.tensor.extra_ops
15 |     :members:
16 | 
--------------------------------------------------------------------------------
/doc/library/tensor/math_opt.rst:
--------------------------------------------------------------------------------
1 | ===================================================================
2 | :mod:`tensor.rewriting.math` -- Tensor Rewrites for Math Operations
3 | ===================================================================
4 | 
5 | .. module:: tensor.rewriting.math
6 |     :platform: Unix, Windows
7 |     :synopsis: Tensor Rewrites for Math Operations
8 | .. moduleauthor:: LISA, PyMC Developers, PyTensor Developers
9 | 
10 | .. automodule:: pytensor.tensor.rewriting.math
11 |     :members:
12 | 
--------------------------------------------------------------------------------
/doc/environment.yml:
--------------------------------------------------------------------------------
1 | name: pytensor-docs
2 | channels:
3 |   - conda-forge
4 |   - nodefaults
5 | dependencies:
6 |   - python=3.11
7 |   - gcc_linux-64
8 |   - gxx_linux-64
9 |   - numpy
10 |   - scipy
11 |   - six
12 |   - sphinx>=5.1.0,<6
13 |   - mock
14 |   - pillow
15 |   - pymc-sphinx-theme
16 |   - sphinx-copybutton
17 |   - sphinx-design
18 |   - sphinx-sitemap
19 |   - pygments
20 |   - pydot
21 |   - ipython
22 |   - myst-nb
23 |   - matplotlib
24 |   - watermark
25 |   - ablog
26 |   - pip
27 |   - pip:
28 |       - -e ..[jax]
29 | 
--------------------------------------------------------------------------------
/doc/library/compile/profilemode.rst:
--------------------------------------------------------------------------------
1 | :orphan:
2 | 
3 | .. _profilemode:
4 | 
5 | ==================================================
6 | :mod:`profilemode` -- profiling PyTensor functions
7 | ==================================================
8 | 
9 | 
10 | .. module:: pytensor.compile.profilemode
11 |     :platform: Unix, Windows
12 |     :synopsis: profiling PyTensor functions with ProfileMode
13 | .. moduleauthor:: LISA
14 | 
15 | Guide
16 | =====
17 | 
18 | .. note::
19 | 
20 |     ProfileMode is removed. Use :attr:`config.profile` instead.
21 | 
--------------------------------------------------------------------------------
/tests/link/jax/test_sort.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | 
4 | from pytensor.tensor import matrix
5 | from pytensor.tensor.sort import argsort, sort
6 | from tests.link.jax.test_basic import compare_jax_and_py
7 | 
8 | 
9 | @pytest.mark.parametrize("axis", [None, -1])
10 | @pytest.mark.parametrize("func", (sort, argsort))
11 | def test_sort(func, axis):
12 |     x = matrix("x", shape=(2, 2), dtype="float64")
13 |     out = func(x, axis=axis)
14 |     arr = np.array([[1.0, 4.0], [5.0, 2.0]])
15 |     compare_jax_and_py([x], [out], [arr])
16 | 
--------------------------------------------------------------------------------
/doc/library/graph/utils.rst:
--------------------------------------------------------------------------------
1 | .. _libdoc_graph_utils:
2 | 
3 | ==========================================================
4 | :mod:`utils` -- Utilities functions operating on the graph
5 | ==========================================================
6 | 
7 | .. testsetup:: *
8 | 
9 |     from pytensor.graph.utils import *
10 | 
11 | ---------
12 | Reference
13 | ---------
14 | 
15 | .. automodule:: pytensor.graph.utils
16 |     :platform: Unix, Windows
17 |     :synopsis: Utilities functions operating on the graph
18 |     :members:
19 | .. moduleauthor:: LISA
20 | 
--------------------------------------------------------------------------------
/tests/link/pytorch/test_sort.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | 
4 | from pytensor.tensor import matrix
5 | from pytensor.tensor.sort import argsort, sort
6 | from tests.link.pytorch.test_basic import compare_pytorch_and_py
7 | 
8 | 
9 | @pytest.mark.parametrize("func", (sort, argsort))
10 | @pytest.mark.parametrize("axis", [0, 1, None])
11 | def test_sort(func, axis):
12 |     x = matrix("x", shape=(2, 2), dtype="float64")
13 |     out = func(x, axis=axis)
14 |     arr = np.array([[1.0, 4.0], [5.0, 2.0]])
15 |     compare_pytorch_and_py([x], [out], [arr])
16 | 
--------------------------------------------------------------------------------
/tests/link/pytorch/test_slinalg.py:
--------------------------------------------------------------------------------
1 | import pytest
2 | 
3 | import pytensor
4 | from tests.link.pytorch.test_basic import compare_pytorch_and_py
5 | 
6 | 
7 | @pytest.mark.parametrize(
8 |     "mode",
9 |     (
10 |         "complete",
11 |         "reduced",
12 |         "r",
13 |         pytest.param("raw", marks=pytest.mark.xfail(raises=NotImplementedError)),
14 |     ),
15 | )
16 | def test_qr(mode, matrix_test):
17 |     x, test_value = matrix_test
18 |     outs = pytensor.tensor.slinalg.qr(x, mode=mode)
19 | 
20 |     compare_pytorch_and_py([x], outs, [test_value])
21 | 
--------------------------------------------------------------------------------
/pytensor/scan/scan_perform_ext.py:
--------------------------------------------------------------------------------
1 | """
2 | 
3 | To update the `Scan` Cython code you must
4 | - Update `scan_perform.pyx`
5 | - update the version value in this file and in `scan_perform.pyx`
6 | 
7 | """
8 | 
9 | from pytensor.scan.scan_perform import get_version, perform  # noqa: F401
10 | 
11 | 
12 | version = 0.326  # must match constant returned in function get_version()
13 | assert version == get_version(), (
14 |     "Invalid extension, check the installation process, "
15 |     "could be problem with .pyx file or Cython ext build process."
16 | ) 17 | del get_version 18 | -------------------------------------------------------------------------------- /doc/tutorial/logistic.gp: -------------------------------------------------------------------------------- 1 | set terminal svg font "Bitstream Vera Sans,10" size 300,200 2 | set output "logistic.svg" 3 | 4 | set range [-6:6] 5 | set xzeroaxis linetype -1 6 | set yzeroaxis linetype -1 7 | set xtics axis nomirror 8 | set ytics axis nomirror 0,0.5,1 9 | set key off 10 | set grid 11 | set border 1 12 | 13 | set samples 400 14 | 15 | plot 1/(1 + exp(-x)) with line linetype rgbcolor "blue" linewidth 2 16 | 17 | set ytics axis nomirror 0,0.25 18 | set output "dlogistic.svg" 19 | plot 1/(1 + exp(-x)) * (1 - 1/(1 + exp(-x))) with line linetype rgbcolor "blue" linewidth 2 20 | -------------------------------------------------------------------------------- /doc/library/compile/index.rst: -------------------------------------------------------------------------------- 1 | 2 | .. _libdoc_compile: 3 | 4 | ============================================================== 5 | :mod:`compile` -- Transforming Expression Graphs to Functions 6 | ============================================================== 7 | 8 | .. module:: compile 9 | :platform: Unix, Windows 10 | :synopsis: transforming expression graphs to functions 11 | .. moduleauthor:: LISA 12 | 13 | .. toctree:: 14 | :maxdepth: 1 15 | 16 | shared 17 | function 18 | io 19 | ops 20 | mode 21 | debugmode 22 | nanguardmode 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /pytensor/d3viz/css/d3-context-menu.css: -------------------------------------------------------------------------------- 1 | .d3-context-menu { 2 | position: absolute; 3 | display: none; 4 | background-color: #f2f2f2; 5 | border-radius: 4px; 6 | 7 | font-family: Arial, sans-serif; 8 | font-size: 14px; 9 | min-width: 50px; 10 | border: 1px solid #d4d4d4; 11 | 12 | z-index:1200; 13 | } 14 | 15 | .d3-context-menu ul { 16 | list-style-type: none; 17 | margin: 4px 0px; 18 | padding: 0px; 19 | cursor: default; 20 | } 21 | 22 | .d3-context-menu ul li { 23 | padding: 4px 16px; 24 | } 25 | 26 | .d3-context-menu ul li:hover { 27 | background-color: #4677f8; 28 | color: #fefefe; 29 | } 30 | -------------------------------------------------------------------------------- /doc/.templates/rendered_citation.html: -------------------------------------------------------------------------------- 1 | 2 | {% if pagename in ablog %} 3 | {% set post = ablog[pagename] %} 4 | {% for coll in post.author %} 5 | {% if coll|length %} 6 | {{ coll }} 7 | {% if loop.index < post.author | length %},{% endif %} 8 | {% else %} 9 | {{ coll }} 10 | {% if loop.index < post.author | length %},{% endif %} 11 | {% endif %} 12 | {% endfor %}. "{{ title.split(' — ')[0] }}". In: Pytensor Examples. Ed. by Pytensor Team. 13 | {% endif %} -------------------------------------------------------------------------------- /doc/library/tensor/io.rst: -------------------------------------------------------------------------------- 1 | =================================================================== 2 | :mod:`tensor.io` -- Tensor IO Ops 3 | =================================================================== 4 | 5 | .. module:: tensor.io 6 | :platform: Unix, Windows 7 | :synopsis: Tensor IO Ops 8 | .. 
moduleauthor:: LISA 9 | 10 | File operation 11 | ============== 12 | 13 | - Load from disk with the function :func:`load ` and its associated op :class:`LoadFromDisk ` 14 | 15 | Details 16 | ======= 17 | 18 | .. automodule:: pytensor.tensor.io 19 | :members: 20 | -------------------------------------------------------------------------------- /tests/link/jax/signal/test_conv.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from pytensor.tensor import dmatrix 5 | from pytensor.tensor.signal import convolve1d 6 | from tests.link.jax.test_basic import compare_jax_and_py 7 | 8 | 9 | @pytest.mark.parametrize("mode", ["full", "valid", "same"]) 10 | def test_convolve1d(mode): 11 | x = dmatrix("x") 12 | y = dmatrix("y") 13 | out = convolve1d(x[None], y[:, None], mode=mode) 14 | 15 | rng = np.random.default_rng() 16 | test_x = rng.normal(size=(3, 5)) 17 | test_y = rng.normal(size=(7, 11)) 18 | compare_jax_and_py([x, y], out, [test_x, test_y]) 19 | -------------------------------------------------------------------------------- /doc/library/d3viz/examples/d3viz/css/d3-context-menu.css: -------------------------------------------------------------------------------- 1 | .d3-context-menu { 2 | position: absolute; 3 | display: none; 4 | background-color: #f2f2f2; 5 | border-radius: 4px; 6 | 7 | font-family: Arial, sans-serif; 8 | font-size: 14px; 9 | min-width: 50px; 10 | border: 1px solid #d4d4d4; 11 | 12 | z-index:1200; 13 | } 14 | 15 | .d3-context-menu ul { 16 | list-style-type: none; 17 | margin: 4px 0px; 18 | padding: 0px; 19 | cursor: default; 20 | } 21 | 22 | .d3-context-menu ul li { 23 | padding: 4px 16px; 24 | } 25 | 26 | .d3-context-menu ul li:hover { 27 | background-color: #4677f8; 28 | color: #fefefe; 29 | } 30 | -------------------------------------------------------------------------------- /doc/library/sparse/sandbox.rst: -------------------------------------------------------------------------------- 1 | .. ../../../../pytensor/sparse/sandbox/sp.py 2 | .. ../../../../pytensor/sparse/basic/truedot.py 3 | 4 | .. _libdoc_sparse_sandbox: 5 | 6 | =================================================================== 7 | :mod:`sparse.sandbox` -- Sparse Op Sandbox 8 | =================================================================== 9 | 10 | .. module:: sparse.sandbox 11 | :platform: Unix, Windows 12 | :synopsis: Sparse Op Sandbox 13 | .. moduleauthor:: LISA 14 | 15 | API 16 | === 17 | 18 | .. automodule:: pytensor.sparse.sandbox.sp 19 | :members: 20 | .. automodule:: pytensor.sparse.sandbox.sp2 21 | :members: 22 | -------------------------------------------------------------------------------- /doc/library/tensor/nlinalg.rst: -------------------------------------------------------------------------------- 1 | .. ../../../../pytensor/sandbox/nlinalg.py 2 | 3 | .. _libdoc_linalg: 4 | 5 | =================================================================== 6 | :mod:`tensor.nlinalg` -- Linear Algebra Ops Using Numpy 7 | =================================================================== 8 | 9 | .. module:: tensor.nlinalg 10 | :platform: Unix, Windows 11 | :synopsis: Linear Algebra Ops Using Numpy 12 | .. moduleauthor:: LISA 13 | 14 | .. note:: 15 | 16 | This module is not imported by default. You need to import it to use it. 17 | 18 | API 19 | === 20 | 21 | .. 
automodule:: pytensor.tensor.nlinalg 22 | :members: 23 | -------------------------------------------------------------------------------- /doc/library/tensor/slinalg.rst: -------------------------------------------------------------------------------- 1 | .. ../../../../pytensor/sandbox/slinalg.py 2 | 3 | .. _libdoc_slinalg: 4 | 5 | =================================================================== 6 | :mod:`tensor.slinalg` -- Linear Algebra Ops Using Scipy 7 | =================================================================== 8 | 9 | .. module:: tensor.slinalg 10 | :platform: Unix, Windows 11 | :synopsis: Linear Algebra Ops Using Scipy 12 | .. moduleauthor:: LISA 13 | 14 | .. note:: 15 | 16 | This module is not imported by default. You need to import it to use it. 17 | 18 | API 19 | === 20 | 21 | .. automodule:: pytensor.tensor.slinalg 22 | :members: 23 | 24 | -------------------------------------------------------------------------------- /tests/tensor/rewriting/test_ofg.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import pytensor 4 | import pytensor.tensor as pt 5 | from pytensor import config 6 | from pytensor.compile.builders import OpFromGraph 7 | 8 | 9 | @pytest.mark.skipif( 10 | config.mode == "FAST_COMPILE", 11 | reason="Rewrite is not applied in FAST_COMPILE mode", 12 | ) 13 | def test_alloc_diag_inlined(): 14 | x = pt.tensor("x", shape=(None,)) 15 | 16 | z = pt.diag(x) 17 | assert isinstance(z.owner.op, OpFromGraph) 18 | 19 | f = pytensor.function([x], z) 20 | nodes = f.maker.fgraph.apply_nodes 21 | 22 | assert not any(isinstance(node.op, OpFromGraph) for node in nodes) 23 | -------------------------------------------------------------------------------- /pytensor/link/jax/dispatch/sort.py: -------------------------------------------------------------------------------- 1 | from jax import numpy as jnp 2 | 3 | from pytensor.link.jax.dispatch import jax_funcify 4 | from pytensor.tensor.sort import ArgSortOp, SortOp 5 | 6 | 7 | @jax_funcify.register(SortOp) 8 | def jax_funcify_Sort(op, **kwargs): 9 | stable = op.kind == "stable" 10 | 11 | def sort(arr, axis): 12 | return jnp.sort(arr, axis=axis, stable=stable) 13 | 14 | return sort 15 | 16 | 17 | @jax_funcify.register(ArgSortOp) 18 | def jax_funcify_ArgSort(op, **kwargs): 19 | stable = op.kind == "stable" 20 | 21 | def argsort(arr, axis): 22 | return jnp.argsort(arr, axis=axis, stable=stable) 23 | 24 | return argsort 25 | -------------------------------------------------------------------------------- /pytensor/link/pytorch/dispatch/slinalg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytensor.link.pytorch.dispatch import pytorch_funcify 4 | from pytensor.tensor.slinalg import QR 5 | 6 | 7 | @pytorch_funcify.register(QR) 8 | def pytorch_funcify_QR(op, **kwargs): 9 | mode = op.mode 10 | if mode == "raw": 11 | raise NotImplementedError("raw mode not implemented in PyTorch") 12 | elif mode == "full": 13 | mode = "complete" 14 | elif mode == "economic": 15 | mode = "reduced" 16 | 17 | def qr(x): 18 | Q, R = torch.linalg.qr(x, mode=mode) 19 | if mode == "r": 20 | return R 21 | return Q, R 22 | 23 | return qr 24 | -------------------------------------------------------------------------------- /pytensor/graph/__init__.py: -------------------------------------------------------------------------------- 1 | """Graph objects and manipulation functions.""" 2 | 3 | # isort: off 4 | from 
pytensor.graph.basic import ( 5 | Apply, 6 | Variable, 7 | Constant, 8 | clone, 9 | ) 10 | from pytensor.graph.traversal import ancestors, graph_inputs 11 | from pytensor.graph.replace import clone_replace, graph_replace, vectorize_graph 12 | from pytensor.graph.op import Op 13 | from pytensor.graph.type import Type 14 | from pytensor.graph.fg import FunctionGraph 15 | from pytensor.graph.rewriting.basic import node_rewriter, graph_rewriter 16 | from pytensor.graph.rewriting.utils import rewrite_graph 17 | from pytensor.graph.rewriting.db import RewriteDatabaseQuery 18 | 19 | # isort: on 20 | -------------------------------------------------------------------------------- /pytensor/misc/check_blas_many.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python misc/check_blas.py --print_only 4 | 5 | cat /proc/cpuinfo |grep "model name" |uniq 6 | cat /proc/cpuinfo |grep processor 7 | free 8 | uname -a 9 | 10 | TIME_PREFIX=time 11 | VAR=OMP_NUM_THREADS 12 | echo "numpy gemm take=" 13 | PYTENSOR_FLAGS=blas__ldflags= $TIME_PREFIX python misc/check_blas.py --quiet 14 | for i in 1 2 4 8 15 | do 16 | export $VAR=$i 17 | x=`$TIME_PREFIX python misc/check_blas.py --quiet` 18 | echo "pytensor gemm with $VAR=$i took: ${x}s" 19 | done 20 | 21 | #Fred to test distro numpy at LISA: PYTHONPATH=/u/bastienf/repos:/usr/lib64/python2.5/site-packages PYTENSOR_FLAGS=blas__ldflags= OMP_NUM_THREADS=8 time python misc/check_blas.py 22 | -------------------------------------------------------------------------------- /pytensor/link/pytorch/dispatch/sort.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytensor.link.pytorch.dispatch.basic import pytorch_funcify 4 | from pytensor.tensor.sort import ArgSortOp, SortOp 5 | 6 | 7 | @pytorch_funcify.register(SortOp) 8 | def pytorch_funcify_Sort(op, **kwargs): 9 | stable = op.kind == "stable" 10 | 11 | def sort(arr, axis): 12 | sorted, _ = torch.sort(arr, dim=axis, stable=stable) 13 | return sorted 14 | 15 | return sort 16 | 17 | 18 | @pytorch_funcify.register(ArgSortOp) 19 | def pytorch_funcify_ArgSort(op, **kwargs): 20 | stable = op.kind == "stable" 21 | 22 | def argsort(arr, axis): 23 | return torch.argsort(arr, dim=axis, stable=stable) 24 | 25 | return argsort 26 | -------------------------------------------------------------------------------- /pytensor/link/jax/dispatch/blockwise.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | from pytensor.link.jax.dispatch import jax_funcify 4 | from pytensor.tensor.blockwise import Blockwise 5 | 6 | 7 | @jax_funcify.register(Blockwise) 8 | def jax_funcify_Blockwise(op: Blockwise, node, **kwargs): 9 | signature = op.signature 10 | core_node = op._create_dummy_core_node( 11 | node.inputs, propagate_unbatched_core_inputs=True 12 | ) 13 | core_fn = jax_funcify(core_node.op, node=core_node, **kwargs) 14 | 15 | vect_fn = jnp.vectorize(core_fn, signature=signature) 16 | 17 | def blockwise_fn(*inputs): 18 | op._check_runtime_broadcast(node, inputs) 19 | return vect_fn(*inputs) 20 | 21 | return blockwise_fn 22 | -------------------------------------------------------------------------------- /pytensor/link/jax/dispatch/einsum.py: -------------------------------------------------------------------------------- 1 | import jax.numpy as jnp 2 | 3 | from pytensor.link.jax.dispatch import jax_funcify 4 | from pytensor.tensor.einsum import Einsum 5 | 6 
| 7 | @jax_funcify.register(Einsum) 8 | def jax_funcify_Einsum(op, **kwargs): 9 | """Dispatch einsum to JAX. 10 | 11 | This dispatch is triggered only when we couldn't optimize einsum at the PyTensor level. 12 | This happens when some of the dimension lengths are unknown. This is never a problem in JAX, 13 | as it always compiles a function per runtime input shape. 14 | """ 15 | subscripts = op.subscripts 16 | 17 | def einsum(*operands): 18 | return jnp.einsum(subscripts, *operands, optimize="optimal") 19 | 20 | return einsum 21 |
-------------------------------------------------------------------------------- /pytensor/link/mlx/dispatch/__init__.py: -------------------------------------------------------------------------------- 1 | # isort: off 2 | from pytensor.link.mlx.dispatch.basic import mlx_funcify, mlx_typify 3 | 4 | import pytensor.link.mlx.dispatch.math 5 | import pytensor.link.mlx.dispatch.basic 6 | import pytensor.link.mlx.dispatch.elemwise 7 | import pytensor.link.mlx.dispatch.shape 8 | import pytensor.link.mlx.dispatch.subtensor 9 | import pytensor.link.mlx.dispatch.core 10 | import pytensor.link.mlx.dispatch.signal 11 | import pytensor.link.mlx.dispatch.signal.conv 12 | import pytensor.link.mlx.dispatch.blockwise 13 | import pytensor.link.mlx.dispatch.extra_ops 14 | import pytensor.link.mlx.dispatch.sort 15 | import pytensor.link.mlx.dispatch.slinalg 16 | import pytensor.link.mlx.dispatch.nlinalg 17 | # isort: on 18 |
-------------------------------------------------------------------------------- /pytensor/link/pytorch/dispatch/__init__.py: -------------------------------------------------------------------------------- 1 | # isort: off 2 | from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify 3 | 4 | # Load dispatch specializations 5 | import pytensor.link.pytorch.dispatch.blas 6 | import pytensor.link.pytorch.dispatch.scalar 7 | import pytensor.link.pytorch.dispatch.elemwise 8 | import pytensor.link.pytorch.dispatch.math 9 | import pytensor.link.pytorch.dispatch.extra_ops 10 | import pytensor.link.pytorch.dispatch.nlinalg 11 | import pytensor.link.pytorch.dispatch.slinalg 12 | import pytensor.link.pytorch.dispatch.shape 13 | import pytensor.link.pytorch.dispatch.sort 14 | import pytensor.link.pytorch.dispatch.subtensor 15 | import pytensor.link.pytorch.dispatch.blockwise 16 | # isort: on 17 |
-------------------------------------------------------------------------------- /pytensor/link/c/c_code/pytensor_mod_helper.h: -------------------------------------------------------------------------------- 1 | #ifndef PYTENSOR_MOD_HELPER 2 | #define PYTENSOR_MOD_HELPER 3 | 4 | #include <Python.h> 5 | 6 | #ifndef _WIN32 7 | #define MOD_PUBLIC __attribute__((visibility ("default"))) 8 | #else 9 | /* MOD_PUBLIC is only used in PyMODINIT_FUNC, which is declared 10 | * and implemented in mod.cu/cpp, not in headers, so dllexport 11 | * is always correct.
*/ 12 | #define MOD_PUBLIC __declspec( dllexport ) 13 | #endif 14 | 15 | #ifdef __cplusplus 16 | #define PYTENSOR_EXTERN extern "C" 17 | #else 18 | #define PYTENSOR_EXTERN 19 | #endif 20 | 21 | /* We need to redefine PyMODINIT_FUNC to add MOD_PUBLIC in the middle */ 22 | #undef PyMODINIT_FUNC 23 | #define PyMODINIT_FUNC PYTENSOR_EXTERN MOD_PUBLIC PyObject * 24 | 25 | #endif 26 | -------------------------------------------------------------------------------- /tests/link/mlx/test_extra_ops.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from pytensor.configdefaults import config 5 | from pytensor.tensor import extra_ops as pt_extra_ops 6 | from pytensor.tensor.type import matrix 7 | from tests.link.mlx.test_basic import compare_mlx_and_py 8 | 9 | 10 | mx = pytest.importorskip("mlx.core") 11 | 12 | 13 | def test_extra_ops(): 14 | a = matrix("a") 15 | a_test = np.arange(6, dtype=config.floatX).reshape((3, 2)) 16 | 17 | out = pt_extra_ops.cumsum(a, axis=0) 18 | compare_mlx_and_py([a], [out], [a_test]) 19 | 20 | out = pt_extra_ops.cumprod(a, axis=1) 21 | compare_mlx_and_py([a], [out], [a_test]) 22 | 23 | out = pt_extra_ops.repeat(a, 3, axis=1) 24 | compare_mlx_and_py([a], [out], [a_test]) 25 | -------------------------------------------------------------------------------- /pytensor/npy_2_compat.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # function that replicates np.unique from numpy < 2.0 5 | def old_np_unique( 6 | arr, return_index=False, return_inverse=False, return_counts=False, axis=None 7 | ): 8 | """Replicate np.unique from numpy versions < 2.0""" 9 | if not return_inverse: 10 | return np.unique(arr, return_index, return_inverse, return_counts, axis) 11 | 12 | outs = list(np.unique(arr, return_index, return_inverse, return_counts, axis)) 13 | 14 | inv_idx = 2 if return_index else 1 15 | 16 | if axis is None: 17 | outs[inv_idx] = np.ravel(outs[inv_idx]) 18 | else: 19 | inv_shape = (arr.shape[axis],) 20 | outs[inv_idx] = outs[inv_idx].reshape(inv_shape) 21 | 22 | return tuple(outs) 23 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | require_ci_to_pass: yes 3 | 4 | coverage: 5 | precision: 2 6 | round: down 7 | range: "70...100" 8 | status: 9 | project: 10 | default: 11 | # basic 12 | target: auto 13 | threshold: 1% 14 | base: auto 15 | patch: 16 | default: 17 | # basic 18 | target: 100% 19 | threshold: 1% 20 | base: auto 21 | 22 | comment: 23 | layout: "reach, diff, flags, files" 24 | behavior: default 25 | require_changes: false # if true: only post the comment if coverage changes 26 | require_base: no # [yes :: must have a base report to post] 27 | require_head: yes # [yes :: must have a head report to post] 28 | branches: null # branch names that can post comment 29 | -------------------------------------------------------------------------------- /pytensor/tensor/rewriting/__init__.py: -------------------------------------------------------------------------------- 1 | import pytensor.tensor.rewriting.basic 2 | import pytensor.tensor.rewriting.blas 3 | import pytensor.tensor.rewriting.blas_c 4 | import pytensor.tensor.rewriting.blockwise 5 | import pytensor.tensor.rewriting.einsum 6 | import pytensor.tensor.rewriting.elemwise 7 | import pytensor.tensor.rewriting.extra_ops 8 | 
import pytensor.tensor.rewriting.jax 9 | import pytensor.tensor.rewriting.linalg 10 | import pytensor.tensor.rewriting.math 11 | import pytensor.tensor.rewriting.numba 12 | import pytensor.tensor.rewriting.ofg 13 | import pytensor.tensor.rewriting.shape 14 | import pytensor.tensor.rewriting.special 15 | import pytensor.tensor.rewriting.subtensor 16 | import pytensor.tensor.rewriting.subtensor_lift 17 | import pytensor.tensor.rewriting.uncanonicalize 18 |
-------------------------------------------------------------------------------- /tests/tensor/test_xlogx.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import pytensor 4 | from pytensor.tensor import as_tensor_variable 5 | from pytensor.tensor.xlogx import xlogx, xlogy0 6 | from tests import unittest_tools as utt 7 | 8 | 9 | def test_xlogx(): 10 | x = as_tensor_variable([1, 0]) 11 | y = xlogx(x) 12 | f = pytensor.function([], y) 13 | assert np.array_equal(f(), np.asarray([0, 0.0])) 14 | 15 | rng = np.random.default_rng(24982) 16 | utt.verify_grad(xlogx, [rng.random((3, 4))]) 17 | 18 | 19 | def test_xlogy0(): 20 | x = as_tensor_variable([1, 0]) 21 | y = as_tensor_variable([1, 0]) 22 | z = xlogy0(x, y) 23 | f = pytensor.function([], z) 24 | assert np.array_equal(f(), np.asarray([0, 0.0])) 25 | 26 | rng = np.random.default_rng(24982) 27 | utt.verify_grad(xlogy0, [rng.random((3, 4)), rng.random((3, 4))]) 28 |
-------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pkl 2 | _build 3 | __pycache__ 4 | .coverage 5 | *.linkinfo 6 | *.o 7 | *.c 8 | *.orig 9 | *.pyc 10 | *.pyo 11 | *.so 12 | *.sw? 13 | *~ 14 | *.aux 15 | *.log 16 | *.nav 17 | *.out 18 | *.snm 19 | *.toc 20 | *.vrb 21 | *.nbc 22 | *.nbi 23 | .noseids 24 | *.DS_Store 25 | *.bak 26 | *.egg-info/ 27 | \#*\# 28 | build 29 | compiled/*.cpp 30 | cutils_ext.cpp 31 | dist 32 | doc/.build/ 33 | doc/indexes/oplist.txt 34 | doc/indexes/typelist.txt 35 | html 36 | pdf 37 | setuptools-*.egg 38 | pytensor/generated_version.py 39 | pytensor/generated_version.py.out 40 | distribute-*.egg 41 | distribute-*.tar.gz 42 | PyTensor.suo 43 | .ipynb_checkpoints 44 | .pydevproject 45 | .ropeproject 46 | core 47 | .idea 48 | .vs 49 | .mypy_cache/ 50 | /htmlcov/ 51 | 52 | pytensor-venv/ 53 | /notebooks/Sandbox* 54 | .vscode/ 55 | testing-report.html 56 | coverage.xml 57 | .coverage.* 58 |
-------------------------------------------------------------------------------- /pytensor/sparse/utils.py: -------------------------------------------------------------------------------- 1 | from pytensor.utils import hash_from_code 2 | 3 | 4 | def hash_from_sparse(data): 5 | # We need to hash the shapes as hash_from_code only hashes 6 | # the data buffer. Otherwise, this will cause problems with shapes like: 7 | # (1, 0) and (2, 0) 8 | # We also need to add the dtype to make the distinction between 9 | # uint32 and int32 of zeros with the same shape. 10 | 11 | # Python hash is not strong, so use sha256 instead. To avoid having a too 12 | # long hash, it is called again on the concatenation of all parts. 13 | return hash_from_code( 14 | hash_from_code(data.data) 15 | + hash_from_code(data.indices) 16 | + hash_from_code(data.indptr) 17 | + hash_from_code(str(data.shape)) 18 | + hash_from_code(str(data.dtype)) 19 | + hash_from_code(data.format) 20 | ) 21 |
-------------------------------------------------------------------------------- /tests/link/mlx/test_sort.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from pytensor.tensor.sort import argsort, sort 5 | from pytensor.tensor.type import matrix 6 | from tests.link.mlx.test_basic import compare_mlx_and_py 7 | 8 | 9 | @pytest.mark.parametrize("axis", [None, -1]) 10 | @pytest.mark.parametrize("func", (sort, argsort)) 11 | def test_sort(func, axis): 12 | x = matrix("x", shape=(2, 2), dtype="float64") 13 | out = func(x, axis=axis) 14 | arr = np.array([[1.0, 4.0], [5.0, 2.0]]) 15 | compare_mlx_and_py([x], [out], [arr]) 16 | 17 | 18 | def test_sort_invalid_kind_warning(): 19 | x = matrix("x", shape=(2, 2), dtype="float64") 20 | z = sort(x, axis=-1, kind="mergesort") 21 | with pytest.warns(UserWarning, match="MLX sort does not support the kind argument"): 22 | z.eval({x: np.array([[3.0, 1.0], [2.0, 4.0]])}, mode="MLX") 23 |
-------------------------------------------------------------------------------- /doc/install.rst: -------------------------------------------------------------------------------- 1 | .. _install: 2 | 3 | Installing PyTensor 4 | =================== 5 | 6 | The latest release of PyTensor can be installed from PyPI using `pip`: 7 | 8 | .. code-block:: bash 9 | 10 | pip install pytensor 11 | 12 | 13 | Or via conda-forge: 14 | 15 | .. code-block:: bash 16 | 17 | conda install -c conda-forge pytensor 18 | 19 | 20 | The current development branch of PyTensor can be installed from GitHub using `pip`: 21 | 22 | 23 | .. code-block:: bash 24 | 25 | pip install git+https://github.com/pymc-devs/pytensor 26 | 27 | 28 | To use the Numba and JAX backends you will need to install these libraries in addition to PyTensor. Please refer to `Numba's installation instructions `__ and `JAX's installation instructions `__ respectively. 29 |
-------------------------------------------------------------------------------- /doc/library/graph/features.rst: -------------------------------------------------------------------------------- 1 | .. _libdoc_graph_features: 2 | 3 | ================================================ 4 | :mod:`features` -- [doc TODO] 5 | ================================================ 6 | 7 | .. module:: pytensor.graph.features 8 | :platform: Unix, Windows 9 | :synopsis: PyTensor Internals 10 | .. moduleauthor:: LISA 11 | 12 | Guide 13 | ===== 14 | 15 | .. class:: Bookkeeper(object) 16 | 17 | .. class:: History(object) 18 | 19 | .. method:: revert(fgraph, checkpoint) 20 | Reverts the graph to whatever it was at the provided 21 | checkpoint (undoes all replacements). A checkpoint at any 22 | given time can be obtained using self.checkpoint(). 23 | 24 | .. class:: Validator(object) 25 | 26 | .. class:: ReplaceValidate(History, Validator) 27 | 28 | .. method:: replace_validate(fgraph, var, new_var, reason=None) 29 |
-------------------------------------------------------------------------------- /scripts/mypy-failing.txt: -------------------------------------------------------------------------------- 1 | pytensor/compile/builders.py 2 | pytensor/compile/debugmode.py 3 | pytensor/compile/function/pfunc.py 4 | pytensor/compile/function/types.py 5 | pytensor/compile/mode.py 6 | pytensor/graph/rewriting/basic.py 7 | pytensor/ifelse.py 8 | pytensor/link/numba/dispatch/elemwise.py 9 | pytensor/link/numba/dispatch/scan.py 10 | pytensor/printing.py 11 | pytensor/raise_op.py 12 | pytensor/tensor/basic.py 13 | pytensor/tensor/blas_c.py 14 | pytensor/tensor/blas_headers.py 15 | pytensor/tensor/elemwise.py 16 | pytensor/tensor/extra_ops.py 17 | pytensor/tensor/math.py 18 | pytensor/tensor/optimize.py 19 | pytensor/tensor/random/basic.py 20 | pytensor/tensor/random/op.py 21 | pytensor/tensor/random/utils.py 22 | pytensor/tensor/rewriting/basic.py 23 | pytensor/tensor/type.py 24 | pytensor/tensor/type_other.py 25 | pytensor/tensor/variable.py 26 | pytensor/_version.py 27 |
-------------------------------------------------------------------------------- /doc/library/tensor/index.rst: -------------------------------------------------------------------------------- 1 | .. _libdoc_tensor: 2 | 3 | =============================================== 4 | :mod:`tensor` -- Tensor operations in PyTensor 5 | =============================================== 6 | 7 | .. module:: tensor 8 | 9 | PyTensor's strength is in expressing symbolic calculations involving tensors. 10 | 11 | PyTensor tries to emulate the numpy interface as much as possible in the tensor module. 12 | This means that once TensorVariables are created, it should be possible to define 13 | symbolic expressions using calls that look just like numpy calls, such as 14 | `pt.exp(x).transpose(0, 1)[:, None]` 15 | 16 | 17 | 18 | .. toctree:: 19 | :maxdepth: 1 20 | 21 | basic 22 | random/index 23 | utils 24 | elemwise 25 | extra_ops 26 | io 27 | slinalg 28 | nlinalg 29 | fft 30 | conv 31 | math_opt 32 | basic_opt 33 | functional 34 | optimize 35 |
-------------------------------------------------------------------------------- /tests/link/pytorch/test_blas.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from pytensor.configdefaults import config 5 | from pytensor.tensor import blas as pt_blas 6 | from pytensor.tensor.type import tensor3 7 | from tests.link.pytorch.test_basic import compare_pytorch_and_py 8 | 9 | 10 | def test_pytorch_BatchedDot(): 11 | # tensor3 .
tensor3 12 | a = tensor3("a") 13 | a_test = np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) 14 | b = tensor3("b") 15 | b_test = np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) 16 | out = pt_blas.BatchedDot()(a, b) 17 | 18 | pytensor_pytorch_fn, _ = compare_pytorch_and_py([a, b], [out], [a_test, b_test]) 19 | 20 | # A dimension mismatch should raise a TypeError for compatibility 21 | inputs = [a_test[:-1], b_test] 22 | with pytest.raises(TypeError): 23 | pytensor_pytorch_fn(*inputs) 24 | -------------------------------------------------------------------------------- /pytensor/typed_list/rewriting.py: -------------------------------------------------------------------------------- 1 | from pytensor.compile import optdb 2 | from pytensor.graph.rewriting.basic import WalkingGraphRewriter, node_rewriter 3 | from pytensor.typed_list.basic import Append, Extend, Insert, Remove, Reverse 4 | 5 | 6 | @node_rewriter([Append, Extend, Insert, Reverse, Remove], inplace=True) 7 | def typed_list_inplace_rewrite(fgraph, node): 8 | if ( 9 | isinstance(node.op, Append | Extend | Insert | Reverse | Remove) 10 | and not node.op.inplace 11 | ): 12 | new_op = node.op.__class__(inplace=True) 13 | new_node = new_op(*node.inputs) 14 | return [new_node] 15 | return False 16 | 17 | 18 | optdb.register( 19 | "typed_list_inplace_rewrite", 20 | WalkingGraphRewriter( 21 | typed_list_inplace_rewrite, failure_callback=WalkingGraphRewriter.warn_inplace 22 | ), 23 | "fast_run", 24 | "inplace", 25 | position=50.1, 26 | ) 27 | -------------------------------------------------------------------------------- /.github/release.yml: -------------------------------------------------------------------------------- 1 | # This file has been mostly taken verbatim from https://github.com/pymc-devs/pymc/blob/main/.github/release.yml 2 | # 3 | # This file contains configuration for the automatic generation of release notes in GitHub. 4 | # It's not perfect, but it makes it a little less laborious to write informative release notes. 
5 | # Also see https://docs.github.com/en/repositories/releasing-projects-on-github/automatically-generated-release-notes 6 | changelog: 7 | exclude: 8 | labels: 9 | - no releasenotes 10 | categories: 11 | - title: Major Changes 🛠 12 | labels: 13 | - major 14 | - title: New Features 🎉 15 | labels: 16 | - enhancement 17 | - feature request 18 | - title: Bugfixes 🐛 19 | labels: 20 | - bug 21 | - title: Documentation 📖 22 | labels: 23 | - docs 24 | - title: Maintenance 🔧 25 | labels: 26 | - "*" 27 | -------------------------------------------------------------------------------- /pytensor/link/numba/dispatch/__init__.py: -------------------------------------------------------------------------------- 1 | # isort: off 2 | from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify 3 | 4 | # Load dispatch specializations 5 | import pytensor.link.numba.dispatch.blockwise 6 | import pytensor.link.numba.dispatch.compile_ops 7 | import pytensor.link.numba.dispatch.elemwise 8 | import pytensor.link.numba.dispatch.extra_ops 9 | import pytensor.link.numba.dispatch.nlinalg 10 | import pytensor.link.numba.dispatch.random 11 | import pytensor.link.numba.dispatch.scan 12 | import pytensor.link.numba.dispatch.scalar 13 | import pytensor.link.numba.dispatch.shape 14 | import pytensor.link.numba.dispatch.signal 15 | import pytensor.link.numba.dispatch.slinalg 16 | import pytensor.link.numba.dispatch.sort 17 | import pytensor.link.numba.dispatch.sparse 18 | import pytensor.link.numba.dispatch.subtensor 19 | import pytensor.link.numba.dispatch.tensor_basic 20 | import pytensor.link.numba.dispatch.typed_list 21 | 22 | 23 | # isort: on 24 | -------------------------------------------------------------------------------- /pytensor/link/jax/dispatch/signal/conv.py: -------------------------------------------------------------------------------- 1 | import jax 2 | 3 | from pytensor.link.jax.dispatch import jax_funcify 4 | from pytensor.tensor.basic import get_underlying_scalar_constant_value 5 | from pytensor.tensor.exceptions import NotScalarConstantError 6 | from pytensor.tensor.signal.conv import Convolve1d 7 | 8 | 9 | @jax_funcify.register(Convolve1d) 10 | def jax_funcify_Convolve1d(op, node, **kwargs): 11 | _, _, full_mode = node.inputs 12 | try: 13 | full_mode = get_underlying_scalar_constant_value(full_mode) 14 | except NotScalarConstantError: 15 | raise NotImplementedError( 16 | "Cannot compile Convolve1D to jax without static mode" 17 | ) 18 | static_mode = "full" if full_mode else "valid" 19 | 20 | def conv1d(data, kernel, _runtime_full_mode): 21 | # _runtime_full_mode is not used, as we only support static mode 22 | return jax.numpy.convolve(data, kernel, mode=static_mode) 23 | 24 | return conv1d 25 | -------------------------------------------------------------------------------- /pytensor/sparse/sharedvar.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import scipy.sparse 4 | 5 | from pytensor.compile import shared_constructor 6 | from pytensor.sparse.variable import SparseTensorType, SparseVariable 7 | from pytensor.tensor.sharedvar import TensorSharedVariable 8 | 9 | 10 | class SparseTensorSharedVariable(TensorSharedVariable, SparseVariable): # type: ignore[misc] 11 | @property 12 | def format(self): 13 | return self.type.format 14 | 15 | 16 | @shared_constructor.register(scipy.sparse.spmatrix) 17 | def sparse_constructor( 18 | value, name=None, strict=False, allow_downcast=None, borrow=False, format=None 19 | ): 20 | if 
format is None: 21 | format = value.format 22 | 23 | type = SparseTensorType(format=format, dtype=value.dtype) 24 | 25 | if not borrow: 26 | value = copy.deepcopy(value) 27 | 28 | return SparseTensorSharedVariable( 29 | type=type, value=value, strict=strict, allow_downcast=allow_downcast, name=name 30 | ) 31 | -------------------------------------------------------------------------------- /pytensor/link/jax/dispatch/__init__.py: -------------------------------------------------------------------------------- 1 | # isort: off 2 | from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify 3 | 4 | # Load dispatch specializations 5 | import pytensor.link.jax.dispatch.blas 6 | import pytensor.link.jax.dispatch.blockwise 7 | import pytensor.link.jax.dispatch.einsum 8 | import pytensor.link.jax.dispatch.elemwise 9 | import pytensor.link.jax.dispatch.extra_ops 10 | import pytensor.link.jax.dispatch.pad 11 | import pytensor.link.jax.dispatch.math 12 | import pytensor.link.jax.dispatch.nlinalg 13 | import pytensor.link.jax.dispatch.random 14 | import pytensor.link.jax.dispatch.scalar 15 | import pytensor.link.jax.dispatch.scan 16 | import pytensor.link.jax.dispatch.shape 17 | import pytensor.link.jax.dispatch.signal 18 | import pytensor.link.jax.dispatch.slinalg 19 | import pytensor.link.jax.dispatch.sort 20 | import pytensor.link.jax.dispatch.sparse 21 | import pytensor.link.jax.dispatch.subtensor 22 | import pytensor.link.jax.dispatch.tensor_basic 23 | 24 | # isort: on 25 | -------------------------------------------------------------------------------- /pytensor/tensor/c_code/alt_blas_common.h: -------------------------------------------------------------------------------- 1 | /** C Implementation (with NumPy back-end) of BLAS functions used in PyTensor. 2 | * Used instead of BLAS when PyTensor flag ``blas__ldflags`` is empty. 3 | * This file contains some useful header code not templated. 4 | * File alt_blas_template.c currently contains template code for: 5 | * - [sd]gemm_ 6 | * - [sd]gemv_ 7 | * - [sd]dot_ 8 | **/ 9 | 10 | #define alt_fatal_error(message) { if (PyErr_Occurred()) PyErr_Print(); if(message != NULL) fprintf(stderr, message); exit(-1); } 11 | 12 | #define alt_trans_to_bool(trans) (*trans != 'N' && *trans != 'n') 13 | 14 | /**Template code for BLAS functions follows in file alt_blas_template.c 15 | * (as Python string to be used with old formatting). 16 | * PARAMETERS: 17 | * float_type: "float" or "double". 18 | * float_size: 4 for float32 (sgemm_), 8 for float64 (dgemm_). 19 | * npy_float: "NPY_FLOAT32" or "NPY_FLOAT64". 20 | * precision: "s" for single, "d" for double. 21 | * See blas_headers.py for current use.**/ 22 | -------------------------------------------------------------------------------- /doc/gallery/page_footer.md: -------------------------------------------------------------------------------- 1 | ## License notice 2 | All the notebooks in this example gallery are provided under a 3 | [3-Clause BSD License](https://github.com/pymc-devs/pytensor/blob/main/doc/LICENSE.txt) 4 | which allows modification, and redistribution for any 5 | use provided the copyright and license notices are preserved. 6 | 7 | ## Citing Pytensor Examples 8 | 9 | To cite this notebook, please use the suggested citation below. 10 | 11 | :::{important} 12 | Many notebooks are adapted from other sources: blogs, books... In such cases you should 13 | cite the original source as well. 14 | 15 | Also remember to cite the relevant libraries used by your code. 
16 | ::: 17 | 18 | Here is an example citation template in bibtex: 19 | 20 | {{ citation_code }} 21 | 22 | which once rendered could look like: 23 | 24 | 25 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | 4 | import numpy 5 | import versioneer 6 | from setuptools import Extension, setup 7 | from setuptools.dist import Distribution 8 | 9 | 10 | dist = Distribution() 11 | dist.parse_config_files() 12 | 13 | 14 | NAME: str = dist.get_name() # type: ignore 15 | 16 | # Check if building for Pyodide 17 | is_pyodide = os.getenv("PYODIDE", "0") == "1" 18 | 19 | if is_pyodide: 20 | # For pyodide we build a universal wheel that must be pure-python 21 | # so we must omit the cython-version of scan. 22 | ext_modules = [] 23 | else: 24 | ext_modules = [ 25 | Extension( 26 | name="pytensor.scan.scan_perform", 27 | sources=["pytensor/scan/scan_perform.pyx"], 28 | include_dirs=[numpy.get_include()], 29 | ), 30 | ] 31 | 32 | if __name__ == "__main__": 33 | setup( 34 | name=NAME, 35 | version=versioneer.get_version(), 36 | cmdclass=versioneer.get_cmdclass(), 37 | ext_modules=ext_modules, 38 | ) 39 | -------------------------------------------------------------------------------- /tests/sparse/test_linalg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from pytensor.configdefaults import config 5 | from pytensor.sparse.linalg import block_diag 6 | 7 | 8 | sp = pytest.importorskip("scipy", minversion="0.7.0") 9 | 10 | 11 | @pytest.mark.parametrize("format", ["csc", "csr"], ids=["csc", "csr"]) 12 | @pytest.mark.parametrize("sparse_input", [True, False], ids=["sparse", "dense"]) 13 | def test_block_diagonal(format, sparse_input): 14 | from scipy import sparse as sp_sparse 15 | 16 | f_array = sp_sparse.csr_matrix if sparse_input else np.array 17 | A = f_array([[1, 2], [3, 4]]).astype(config.floatX) 18 | B = f_array([[5, 6], [7, 8]]).astype(config.floatX) 19 | 20 | result = block_diag(A, B, format=format) 21 | assert result.owner.op._props_dict() == {"n_inputs": 2, "format": format} 22 | 23 | sp_result = sp_sparse.block_diag([A, B], format=format) 24 | 25 | assert isinstance(result.eval(), type(sp_result)) 26 | np.testing.assert_allclose(result.eval().toarray(), sp_result.toarray()) 27 | -------------------------------------------------------------------------------- /tests/link/jax/test_einsum.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | import pytensor.tensor as pt 5 | from tests.link.jax.test_basic import compare_jax_and_py 6 | 7 | 8 | jax = pytest.importorskip("jax") 9 | 10 | 11 | def test_jax_einsum(): 12 | subscripts = "ij, jk, kl -> il" 13 | x = np.random.rand(3, 5) 14 | y = np.random.rand(5, 2) 15 | z = np.random.rand(2, 4) 16 | 17 | shapes = { 18 | "x": (3, 5), 19 | "y": (5, 2), 20 | "z": (2, 4), 21 | } 22 | x_pt, y_pt, z_pt = (pt.tensor(name, shape=shape) for name, shape in shapes.items()) 23 | out = pt.einsum(subscripts, x_pt, y_pt, z_pt) 24 | compare_jax_and_py([x_pt, y_pt, z_pt], [out], [x, y, z]) 25 | 26 | 27 | def test_ellipsis_einsum(): 28 | subscripts = "...i,...i->..." 
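# "...i,...i->..." is a batched inner product: the ellipsis broadcasts the leading dimensions while the repeated "i" contracts the last axis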
29 | x = np.random.rand(2, 5) 30 | y = np.random.rand(2, 5) 31 | 32 | x_pt = pt.tensor("x", shape=x.shape) 33 | y_pt = pt.tensor("y", shape=y.shape) 34 | out = pt.einsum(subscripts, x_pt, y_pt) 35 | compare_jax_and_py([x_pt, y_pt], [out], [x, y]) 36 |
-------------------------------------------------------------------------------- /tests/misc/test_pkl_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from pathlib import Path 4 | from tempfile import mkdtemp 5 | 6 | from pytensor.misc.pkl_utils import StripPickler 7 | from pytensor.tensor.type import matrix 8 | 9 | 10 | # FIXME: this test looks weird 11 | class TestStripPickler: 12 | def setup_method(self): 13 | # Work in a temporary directory to avoid cluttering the repository 14 | self.origdir = Path.cwd() 15 | self.tmpdir = mkdtemp() 16 | os.chdir(self.tmpdir) 17 | 18 | def teardown_method(self): 19 | # Get back to the original dir, and delete the temporary one 20 | os.chdir(self.origdir) 21 | if self.tmpdir is not None: 22 | shutil.rmtree(self.tmpdir) 23 | 24 | def test_basic(self): 25 | with Path("test.pkl").open("wb"): 26 | m = matrix() 27 | dest_pkl = "my_test.pkl" 28 | with Path(dest_pkl).open("wb") as f: 29 | strip_pickler = StripPickler(f, protocol=-1) 30 | strip_pickler.dump(m) 31 |
-------------------------------------------------------------------------------- /doc/library/compile/opfromgraph.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. _opfromgraph: 4 | 5 | ============= 6 | `OpFromGraph` 7 | ============= 8 | 9 | This page describes :class:`pytensor.compile.builders.OpFromGraph 10 | <pytensor.compile.builders.OpFromGraph>`, an `Op` constructor that allows one to 11 | encapsulate a PyTensor graph in a single `Op`. 12 | 13 | This can be used to encapsulate some functionality in one block. It is 14 | useful for scaling PyTensor compilation of bigger graphs in which that 15 | encapsulated functionality is reused many times with different inputs. 16 | Due to this encapsulation, it can make PyTensor's compilation phase 17 | faster for graphs with many nodes. 18 | 19 | Using this for small graphs is not recommended as it disables 20 | rewrites between what is inside the encapsulation and outside of it. 21 | 22 | .. note:: 23 | 24 | This has not been widely used so far. If you have any 25 | questions/comments do not hesitate to contact us on the mailing list. 26 | 27 | 28 | 29 | ..
autoclass:: pytensor.compile.builders.OpFromGraph 30 | -------------------------------------------------------------------------------- /tests/link/pytorch/test_math.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pytensor.configdefaults import config 4 | from pytensor.tensor.type import matrix, scalar, vector 5 | from tests.link.pytorch.test_basic import compare_pytorch_and_py 6 | 7 | 8 | def test_pytorch_dot(): 9 | y = vector("y") 10 | y_test = np.r_[1.0, 2.0].astype(config.floatX) 11 | x = vector("x") 12 | x_test = np.r_[3.0, 4.0].astype(config.floatX) 13 | A = matrix("A") 14 | A_test = np.array([[6, 3], [3, 0]], dtype=config.floatX) 15 | alpha = scalar("alpha") 16 | alpha_test = np.array(3.0, dtype=config.floatX) 17 | beta = scalar("beta") 18 | beta_test = np.array(5.0, dtype=config.floatX) 19 | 20 | # 2D * 2D 21 | out = A.dot(A * alpha) + beta * A 22 | 23 | compare_pytorch_and_py([A, alpha, beta], [out], [A_test, alpha_test, beta_test]) 24 | 25 | # 1D * 2D and 1D * 1D 26 | out = y.dot(alpha * A).dot(x) + beta * y 27 | 28 | compare_pytorch_and_py( 29 | [y, x, A, alpha, beta], [out], [y_test, x_test, A_test, alpha_test, beta_test] 30 | ) 31 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.yml: -------------------------------------------------------------------------------- 1 | # This issue template was adapted from the NumPy project 2 | # under the BSD 3-Clause "New" or "Revised" License. 3 | # Copyright (c) 2005-2022, NumPy Developers. 4 | # All rights reserved. 5 | 6 | 7 | name: Documentation 8 | description: Report an issue related to the PyTensor documentation. 9 | title: "DOC: " 10 | labels: [docs] 11 | 12 | body: 13 | - type: textarea 14 | attributes: 15 | label: "Issue with current documentation:" 16 | description: > 17 | Please make sure to leave a reference to the document/code you're 18 | referring to. You can also check the development version of the 19 | documentation and see if this issue has already been addressed at 20 | https://pytensor.readthedocs.io/en/latest/. 21 | 22 | - type: textarea 23 | attributes: 24 | label: "Idea or request for content:" 25 | description: > 26 | Please describe as clearly as possible what topics you think are missing 27 | from the current documentation. 28 | -------------------------------------------------------------------------------- /.github/workflows/mypy.yml: -------------------------------------------------------------------------------- 1 | name: mypy 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | mypy: 10 | runs-on: ubuntu-latest 11 | defaults: 12 | run: 13 | shell: bash -leo pipefail {0} 14 | steps: 15 | - uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 16 | with: 17 | persist-credentials: false 18 | - uses: mamba-org/setup-micromamba@add3a49764cedee8ee24e82dfde87f5bc2914462 # v2.0.7 19 | with: 20 | micromamba-version: "1.5.10-0" # until https://github.com/mamba-org/setup-micromamba/issues/225 is resolved 21 | environment-file: environment.yml 22 | init-shell: bash 23 | cache-environment: true 24 | post-cleanup: "all" 25 | - name: Install pytensor and mypy dependencies 26 | run: | 27 | pip install -e . 
28 | python --version 29 | shell: micromamba-shell {0} 30 | - name: Run mypy 31 | run: | 32 | python ./scripts/run_mypy.py --verbose 33 | shell: micromamba-shell {0} 34 |
-------------------------------------------------------------------------------- /pytensor/misc/may_share_memory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Function to detect memory sharing for ndarray AND sparse types. 3 | The numpy version supports only ndarrays. 4 | """ 5 | 6 | import numpy as np 7 | 8 | from pytensor.tensor.type import TensorType 9 | 10 | 11 | try: 12 | import scipy.sparse 13 | 14 | from pytensor.sparse.basic import SparseTensorType 15 | 16 | def _is_sparse(a): 17 | return scipy.sparse.issparse(a) 18 | 19 | except ImportError: 20 | 21 | def _is_sparse(a): 22 | return False 23 | 24 | 25 | def may_share_memory(a, b, raise_other_type=True): 26 | a_ndarray = isinstance(a, np.ndarray) 27 | b_ndarray = isinstance(b, np.ndarray) 28 | if a_ndarray and b_ndarray: 29 | return TensorType.may_share_memory(a, b) 30 | 31 | a_sparse = _is_sparse(a) 32 | b_sparse = _is_sparse(b) 33 | if not ((a_ndarray or a_sparse) and (b_ndarray or b_sparse)): 34 | if raise_other_type: 35 | raise TypeError("may_share_memory supports only ndarray and scipy.sparse") 36 | return False 37 | 38 | return SparseTensorType.may_share_memory(a, b) 39 |
-------------------------------------------------------------------------------- /pytensor/link/mlx/dispatch/extra_ops.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | 3 | from pytensor.link.mlx.dispatch.basic import mlx_funcify 4 | from pytensor.tensor.extra_ops import CumOp, Repeat 5 | 6 | 7 | @mlx_funcify.register(CumOp) 8 | def mlx_funcify_CumOp(op, **kwargs): 9 | axis = op.axis 10 | mode = op.mode 11 | 12 | def cumop(x, axis=axis, mode=mode): 13 | match mode: 14 | case "add": 15 | return mx.cumsum(x, axis=axis) 16 | case "mul": 17 | return mx.cumprod(x, axis=axis) 18 | case _: 19 | raise NotImplementedError(f"CumOp mode {mode} not implemented in MLX") 20 | 21 | return cumop 22 | 23 | 24 | @mlx_funcify.register(Repeat) 25 | def mlx_funcify_Repeat(op, **kwargs): 26 | axis = op.axis 27 | 28 | def repeat(x, repeats, axis=axis): 29 | if not isinstance(repeats, int): 30 | raise NotImplementedError( 31 | "MLX repeat does not support sequence-valued repeat argument." 32 | ) 33 | return mx.repeat(x, repeats, axis=axis) 34 | 35 | return repeat 36 |
-------------------------------------------------------------------------------- /doc/library/graph/fgraph.rst: -------------------------------------------------------------------------------- 1 | 2 | .. _libdoc_graph_fgraph: 3 | 4 | ================================================ 5 | :mod:`fg` -- Graph Container [doc TODO] 6 | ================================================ 7 | 8 | .. module:: pytensor.graph.fg 9 | :platform: Unix, Windows 10 | :synopsis: PyTensor Internals 11 | .. moduleauthor:: LISA 12 | 13 | 14 | .. _fgraph: 15 | 16 | FunctionGraph 17 | ------------- 18 | 19 | .. autoclass:: pytensor.graph.fg.FunctionGraph 20 | :members: 21 | 22 | ***TODO*** 23 | 24 | .. note:: FunctionGraph(inputs, outputs) clones the inputs by 25 | default. To avoid this behavior, add the parameter 26 | clone=False. This is needed as we do not want cached constants 27 | in fgraph. 28 | 29 | .. _libdoc_graph_fgraphfeature: 30 | 31 | .. _fgraphfeature: 32 | 33 | FunctionGraph Features 34 | ---------------------- 35 | 36 | ..
autoclass:: pytensor.graph.features.Feature 37 | :members: 38 | 39 | .. _libdoc_graph_fgraphfeaturelist: 40 | 41 | FunctionGraph Feature List 42 | ^^^^^^^^^^^^^^^^^^^^^^^^^^ 43 | * ReplaceValidate 44 | * DestroyHandler 45 | -------------------------------------------------------------------------------- /environment-osx-arm64.yml: -------------------------------------------------------------------------------- 1 | # To use: 2 | # 3 | # $ conda env create -f environment.yml # `mamba` works too for this command 4 | # $ conda activate pytensor-dev 5 | # 6 | name: pytensor-dev 7 | channels: 8 | - conda-forge 9 | dependencies: 10 | - python>=3.11 11 | - compilers 12 | - numpy>=2.0.0 13 | - scipy>=1,<2 14 | - filelock>=3.15 15 | - etuples 16 | - logical-unification 17 | - miniKanren 18 | - cons 19 | - pydeprecate 20 | # Apple BLAS 21 | - libblas=*=*accelerate 22 | - numba>=0.57 23 | # For testing 24 | - coveralls 25 | - diff-cover 26 | - mypy 27 | - types-setuptools 28 | - pytest 29 | - pytest-cov 30 | - pytest-xdist 31 | - pytest-benchmark 32 | - pytest-mock 33 | - pytest-sphinx 34 | # For building docs 35 | - sphinx>=5.1.0,<6 36 | - sphinx_rtd_theme 37 | - pygments 38 | - pydot 39 | - ipython 40 | - pymc-sphinx-theme 41 | - sphinx-design 42 | # code style 43 | - ruff 44 | # developer tools 45 | - pandas # required to run mypy script 46 | - pre-commit 47 | - packaging 48 | # optional 49 | - cython 50 | - graphviz 51 | - pydot 52 | -------------------------------------------------------------------------------- /tests/link/mlx/test_math.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | import pytensor 5 | import pytensor.tensor as pt 6 | from pytensor.tensor.math import Argmax, Max 7 | from tests.link.mlx.test_basic import compare_mlx_and_py 8 | 9 | 10 | mx = pytest.importorskip("mlx.core") 11 | 12 | 13 | def test_dot(): 14 | x = pt.matrix("x") 15 | y = pt.matrix("y") 16 | 17 | out = x.dot(y) 18 | fn = pytensor.function([x, y], out, mode="MLX") 19 | 20 | seed = sum(map(ord, "test_mlx_dot")) 21 | rng = np.random.default_rng(seed) 22 | 23 | test_x = rng.normal(size=(3, 2)) 24 | test_y = rng.normal(size=(2, 4)) 25 | 26 | actual = fn(test_x, test_y) 27 | assert isinstance(actual, mx.array) 28 | expected = np.dot(test_x, test_y) 29 | np.testing.assert_allclose(actual, expected, rtol=1e-6) 30 | 31 | 32 | def test_mlx_max_and_argmax(): 33 | # Test that a single output of a multi-output `Op` can be used as input to 34 | # another `Op` 35 | x = pt.dvector() 36 | mx = Max([0])(x) 37 | amx = Argmax([0])(x) 38 | out = mx * amx 39 | compare_mlx_and_py([x], [out], [np.r_[1, 2]]) 40 | -------------------------------------------------------------------------------- /.github/workflows/slow-tests-issue.yml: -------------------------------------------------------------------------------- 1 | # Taken from https://github.com/pymc-labs/pymc-marketing/tree/main/.github/workflows/slow-tests-issue.yml 2 | # See the scripts in the `scripts/slowest_tests` directory for more information 3 | --- 4 | name: Slow Tests Issue Body 5 | 6 | on: 7 | workflow_dispatch: 8 | schedule: 9 | - cron: "0 */6 * * *" 10 | 11 | permissions: 12 | issues: write 13 | 14 | jobs: 15 | update-comment: 16 | runs-on: ubuntu-latest 17 | steps: 18 | - name: Install ZSH 19 | run: sudo apt-get update && sudo apt-get install -y zsh 20 | - name: Checkout code 21 | uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 22 | - name: Set up Python 23 | uses: 
actions/setup-python@e797f83bcb11b83ae66e0230d6156d7c80228e7c # v6.0.0 24 | with: 25 | python-version: "3.11" 26 | - name: Trigger the script 27 | continue-on-error: true 28 | working-directory: scripts/slowest_tests 29 | shell: zsh {0} 30 | run: source update-slowest-times-issue.sh 31 | env: 32 | GITHUB_TOKEN: ${{ github.token }} 33 | -------------------------------------------------------------------------------- /pytensor/link/mlx/dispatch/sort.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import mlx.core as mx 4 | 5 | from pytensor.link.mlx.dispatch.basic import mlx_funcify 6 | from pytensor.tensor.sort import ArgSortOp, SortOp 7 | 8 | 9 | @mlx_funcify.register(SortOp) 10 | def mlx_funcify_Sort(op, **kwargs): 11 | kind = op.kind 12 | if kind != "quicksort": 13 | warnings.warn( 14 | message=f"MLX sort does not support the kind argument (got kind={kind}). The argument will be " 15 | f"ignored.", 16 | category=UserWarning, 17 | ) 18 | 19 | def sort(x, axis): 20 | return mx.sort(x, axis=axis) 21 | 22 | return sort 23 | 24 | 25 | @mlx_funcify.register(ArgSortOp) 26 | def mlx_funcify_ArgSort(op, **kwargs): 27 | kind = op.kind 28 | if kind != "quicksort": 29 | warnings.warn( 30 | message=f"MLX argsort does not support the kind argument (got kind={kind}). The argument will be " 31 | f"ignored.", 32 | category=UserWarning, 33 | ) 34 | 35 | def argsort(x, axis): 36 | return mx.argsort(x, axis=axis) 37 | 38 | return argsort 39 | -------------------------------------------------------------------------------- /doc/core_development_guide.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | Core Development Guide 4 | ======================= 5 | 6 | The documentation of the core components of PyTensor is still a work in 7 | progress. For now this is a list of bits and pieces on the subject, 8 | some of them might be outdated though: 9 | 10 | 11 | * :ref:`pytensor_type` -- Tutorial for writing a new type in PyTensor. It 12 | introduces the basics concerning PyTensor datatypes. 13 | 14 | * :ref:`pytensor_ctype` -- Tutorial on how to make your type C-friendly. 15 | 16 | * :ref:`views_and_inplace` -- This is somewhere between extending PyTensor and 17 | describing how PyTensor works internally; it talks about views and inplace 18 | operations. 19 | 20 | * :ref:`graph_rewriting` -- Tutorial on how graph rewriting works in PyTensor. 21 | 22 | * :ref:`pipeline` -- Describes the steps of compiling an PyTensor Function. 23 | 24 | * :ref:`graphstructures` -- Describes the symbolic graphs generated by 25 | :mod:`pytensor.scan`. 26 | 27 | * :ref:`unittest` -- Tutorial on how to use unittest in testing PyTensor. 28 | 29 | * :ref:`libdoc_sparse` -- Description of the ``sparse`` type in PyTensor. 
30 | -------------------------------------------------------------------------------- /.github/workflows/zizmor.yml: -------------------------------------------------------------------------------- 1 | # https://github.com/woodruffw/zizmor 2 | name: zizmor GHA analysis 3 | 4 | on: 5 | push: 6 | branches: ["main"] 7 | pull_request: 8 | branches: ["**"] 9 | 10 | jobs: 11 | zizmor: 12 | name: zizmor latest via PyPI 13 | runs-on: ubuntu-latest 14 | permissions: 15 | security-events: write 16 | steps: 17 | - name: Checkout repository 18 | uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0 19 | with: 20 | persist-credentials: false 21 | 22 | - uses: hynek/setup-cached-uv@757bedc3f972eb7227a1aa657651f15a8527c817 # v2.3.0 23 | 24 | - name: Run zizmor 🌈 25 | run: uvx zizmor --format sarif . > results.sarif 26 | env: 27 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 28 | 29 | - name: Upload SARIF file 30 | uses: github/codeql-action/upload-sarif@v3 31 | with: 32 | # Path to SARIF file relative to the root of the repository 33 | sarif_file: results.sarif 34 | # Optional category for the results 35 | # Used to differentiate multiple results for one commit 36 | category: zizmor 37 | -------------------------------------------------------------------------------- /pytensor/link/c/cvm.py: -------------------------------------------------------------------------------- 1 | from pytensor.configdefaults import config 2 | from pytensor.link.c.exceptions import MissingGXX 3 | from pytensor.link.vm import VM 4 | 5 | 6 | try: 7 | # If cxx is explicitly set to an empty string, we do not want to import 8 | # either lazy-linker C code or lazy-linker compiled C code from the cache. 9 | if not config.cxx: 10 | raise MissingGXX( 11 | "lazylinker will not be imported if pytensor.config.cxx is not set." 12 | ) 13 | from pytensor.link.c.lazylinker_c import CLazyLinker 14 | 15 | class CVM(CLazyLinker, VM): 16 | def __init__(self, fgraph, *args, **kwargs): 17 | self.fgraph = fgraph 18 | # skip VM.__init__ 19 | CLazyLinker.__init__(self, *args, **kwargs) 20 | 21 | except ImportError: 22 | pass 23 | except (OSError, MissingGXX): 24 | # OSError happens when g++ is not installed. In that case, we 25 | # already changed the default linker to something other than CVM. 26 | # Currently this is the py linker. 27 | # Here we assert that the default linker is not cvm. 
28 | if config._config_var_dict["linker"].default.startswith("cvm"): 29 | raise 30 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | # To use: 2 | # 3 | # $ conda env create -f environment.yml  # `mamba` works too for this command 4 | # $ conda activate pytensor-dev 5 | # 6 | name: pytensor-dev 7 | channels: 8 | - conda-forge 9 | dependencies: 10 | - python>=3.11 11 | - compilers 12 | - numpy>=2.0.0 13 | - scipy>=1,<2 14 | - filelock>=3.15 15 | - etuples 16 | - logical-unification 17 | - miniKanren 18 | - cons 19 | - pydeprecate 20 | # Intel BLAS 21 | - mkl 22 | - mkl-service 23 | - libblas=*=*mkl 24 | - numba>=0.57 25 | # For testing 26 | - coveralls 27 | - diff-cover 28 | - mypy 29 | - types-setuptools 30 | - pytest 31 | - pytest-cov 32 | - pytest-xdist 33 | - pytest-benchmark 34 | - pytest-mock 35 | - pytest-sphinx 36 | # For building docs 37 | - sphinx>=5.1.0,<6 38 | - sphinx_rtd_theme 39 | - pygments 40 | - pydot 41 | - ipython 42 | - pymc-sphinx-theme 43 | - sphinx-design 44 | - myst-nb 45 | - matplotlib 46 | - watermark 47 | 48 | # code style 49 | - ruff 50 | # developer tools 51 | - pandas # required to run mypy script 52 | - pre-commit 53 | - packaging 54 | # optional 55 | - cython 56 | - graphviz 57 | - pydot 58 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | 6 | # Using pytest_plugins causes `tests/link/c/test_cmodule.py::test_cache_versioning` to fail 7 | # pytest_plugins = ["tests.fixtures"] 8 | 9 | 10 | def pytest_sessionstart(session): 11 | os.environ["PYTENSOR_FLAGS"] = ",".join( 12 | [ 13 | os.environ.setdefault("PYTENSOR_FLAGS", ""), 14 | "warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,cmodule__warn_no_version=True", 15 | ] 16 | ) 17 | os.environ["NUMBA_BOUNDSCHECK"] = "1" 18 | 19 | 20 | def pytest_addoption(parser): 21 | parser.addoption( 22 | "--runslow", action="store_true", default=False, help="run slow tests" 23 | ) 24 | 25 | 26 | def pytest_configure(config): 27 | config.addinivalue_line("markers", "slow: mark test as slow to run") 28 | 29 | 30 | def pytest_collection_modifyitems(config, items): 31 | if config.getoption("--runslow"): 32 | # --runslow given in cli: do not skip slow tests 33 | return 34 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 35 | for item in items: 36 | if "slow" in item.keywords: 37 | item.add_marker(skip_slow) 38 | -------------------------------------------------------------------------------- /doc/internal/how_to_release.rst: -------------------------------------------------------------------------------- 1 | .. _how_to_release: 2 | 3 | ================================================== 4 | How to make a release 5 | ================================================== 6 | 7 | Update the version number 8 | ========================= 9 | 10 | ``PyTensor/doc/conf.py`` should be updated in the following ways: 11 | 12 | * Change the upper copyright year to the current year if necessary. 13 | 14 | Update the year in the ``PyTensor/LICENSE.txt`` file too, if necessary. 15 | 16 | Update the code and the documentation for the pytensor flags 17 | ``warn__ignore_bug_before`` to accept the new version. You must modify the 18 | files ``pytensor/configdefaults.py`` and ``doc/library/config.txt``. 
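As a quick check after updating the flag, it can be exercised the same way this repository's ``conftest.py`` does, through the ``PYTENSOR_FLAGS`` environment variable. The snippet below is only an illustrative sketch; ``all`` is used as a placeholder value here, and a newly released version string should be accepted in the same way::

    import os

    # The flags are parsed when pytensor is first imported, so set them beforehand.
    os.environ["PYTENSOR_FLAGS"] = "warn__ignore_bug_before=all"

    import pytensor

    print(pytensor.config.warn__ignore_bug_before)  # -> "all"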
19 | 20 | Tag the release 21 | =============== 22 | 23 | You will need to commit the previous changes, tag the resulting version, and 24 | push that into the upstream/official repository. After that, create a new release 25 | via GitHub Releases on the repository's page. The release tag must start with 26 | ``rel-`` in order to be recognized by the CI release process. 27 | 28 | This will trigger a build and upload of the PyPI and Conda packages. 29 | 30 | The documentation will be automatically regenerated as well. 31 | -------------------------------------------------------------------------------- /doc/acknowledgement.rst: -------------------------------------------------------------------------------- 1 | .. _acknowledgement: 2 | 3 | 4 | Acknowledgements 5 | ================ 6 | 7 | .. note:: 8 | 9 | This page is under construction. We are missing sources. 10 | 11 | 12 | * The developers of `NumPy `_. Theano is based on its ndarray object and uses much of its implementation. 13 | * The developers of `SciPy `_. Our sparse matrix support uses their sparse matrix objects. We also reuse other parts. 14 | * The developers of `Theano `_ 15 | * All `PyTensor contributors `_. 16 | * All Theano users who have given us feedback. 17 | * Our random number generator implementation on CPU and GPU uses the MRG31k3p algorithm that is described in: 18 | 19 | P. L'Ecuyer and R. Touzin, `Fast Combined Multiple Recursive Generators with Multipliers of the form a = +/- 2^d +/- 2^e `_, Proceedings of the 2000 Winter Simulation Conference, Dec. 2000, 683--689. 20 | 21 | We were authorized by Pierre L'Ecuyer to copy/modify his Java implementation in the `SSJ `_ software and to relicense it under BSD 3-Clause in Theano. 22 | -------------------------------------------------------------------------------- /pytensor/link/mlx/dispatch/shape.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | 3 | from pytensor.link.mlx.dispatch.basic import mlx_funcify 4 | from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape 5 | 6 | 7 | @mlx_funcify.register(Shape) 8 | def mlx_funcify_Shape(op, **kwargs): 9 | def shape(x): 10 | return mx.array(x.shape, dtype=mx.int64) 11 | 12 | return shape 13 | 14 | 15 | @mlx_funcify.register(SpecifyShape) 16 | def mlx_funcify_SpecifyShape(op, node, **kwargs): 17 | def specifyshape(x, *shape): 18 | assert x.ndim == len(shape) 19 | for actual, expected in zip(x.shape, shape, strict=True): 20 | if expected is None: 21 | continue 22 | if actual != expected: 23 | raise ValueError(f"Invalid shape: Expected {shape} but got {x.shape}") 24 | return x 25 | 26 | return specifyshape 27 | 28 | 29 | @mlx_funcify.register(Shape_i) 30 | def mlx_funcify_Shape_i(op, node, **kwargs): 31 | def shape_i(x): 32 | return x.shape[op.i] 33 | 34 | return shape_i 35 | 36 | 37 | @mlx_funcify.register(Reshape) 38 | def mlx_funcify_Reshape(op, **kwargs): 39 | def reshape(x, shp): 40 | return mx.reshape(x, shp) 41 | 42 | return reshape 43 | -------------------------------------------------------------------------------- /pytensor/tensor/random/var.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import numpy as np 4 | 5 | from pytensor.compile.sharedvalue import SharedVariable, shared_constructor 6 | from pytensor.tensor.random.type import random_generator_type 7 | 8 | 9 | class RandomGeneratorSharedVariable(SharedVariable): 10 | def __str__(self): 11 | return self.name or 
f"RNG({self.container!r})" 12 | 13 | 14 | @shared_constructor.register(np.random.RandomState) 15 | @shared_constructor.register(np.random.Generator) 16 | def randomgen_constructor( 17 | value, name=None, strict=False, allow_downcast=None, borrow=False 18 | ): 19 | r"""`SharedVariable` constructor for NumPy's `Generator` and/or `RandomState`.""" 20 | if isinstance(value, np.random.RandomState): 21 | raise TypeError( 22 | "`np.RandomState` is no longer supported in PyTensor. Use `np.random.Generator` instead." 23 | ) 24 | 25 | rng_sv_type = RandomGeneratorSharedVariable 26 | rng_type = random_generator_type 27 | 28 | if not borrow: 29 | value = copy.deepcopy(value) 30 | 31 | return rng_sv_type( 32 | type=rng_type, 33 | value=value, 34 | strict=strict, 35 | allow_downcast=allow_downcast, 36 | name=name, 37 | ) 38 | -------------------------------------------------------------------------------- /pytensor/link/mlx/dispatch/signal/conv.py: -------------------------------------------------------------------------------- 1 | import mlx.core as mx 2 | 3 | from pytensor.link.mlx.dispatch import mlx_funcify, mlx_typify 4 | from pytensor.tensor.basic import get_underlying_scalar_constant_value 5 | from pytensor.tensor.exceptions import NotScalarConstantError 6 | from pytensor.tensor.signal.conv import Convolve1d 7 | 8 | 9 | @mlx_funcify.register(Convolve1d) 10 | def mlx_funcify_Convolve1d(op, node, **kwargs): 11 | _, _, full_mode_var = node.inputs 12 | 13 | try: 14 | full_mode = bool(get_underlying_scalar_constant_value(full_mode_var)) 15 | runtime_mode_static = True 16 | except NotScalarConstantError: 17 | full_mode = True 18 | runtime_mode_static = False 19 | 20 | def conv1d(raw_data, raw_kernel, runtime_full_mode): 21 | data = mlx_typify(raw_data, dtype=None) 22 | kernel = mlx_typify(raw_kernel, dtype=None) 23 | 24 | if runtime_mode_static: 25 | runtime_mode = full_mode 26 | else: 27 | runtime_full_mode = mx.array(runtime_full_mode) 28 | runtime_mode = bool(runtime_full_mode.reshape(-1)[0]) 29 | 30 | mode = "full" if runtime_mode else "valid" 31 | return mx.convolve(data, kernel, mode=mode) 32 | 33 | return conv1d 34 | -------------------------------------------------------------------------------- /pytensor/link/pytorch/dispatch/blockwise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytensor.graph import FunctionGraph 4 | from pytensor.link.pytorch.dispatch import pytorch_funcify 5 | from pytensor.tensor.blockwise import Blockwise 6 | 7 | 8 | @pytorch_funcify.register(Blockwise) 9 | def funcify_Blockwise(op: Blockwise, node, *args, **kwargs): 10 | batched_dims = op.batch_ndim(node) 11 | core_node = op._create_dummy_core_node(node.inputs) 12 | core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs) 13 | inner_func = pytorch_funcify( 14 | core_fgraph, squeeze_output=len(node.outputs) == 1, **kwargs 15 | ) 16 | 17 | for _ in range(batched_dims): 18 | inner_func = torch.vmap(inner_func) 19 | 20 | def batcher(*inputs): 21 | op._check_runtime_broadcast(node, inputs) 22 | # broadcast on batched_dims 23 | all_batched_dims = tuple(t.shape[:batched_dims] for t in inputs) 24 | batched_shape = torch.broadcast_shapes(*all_batched_dims) 25 | broadcast_inputs = [ 26 | torch.broadcast_to(i, batched_shape + i.shape[batched_dims:]) 27 | for i in inputs 28 | ] 29 | res = inner_func(*broadcast_inputs) 30 | return res 31 | 32 | return batcher 33 | 
-------------------------------------------------------------------------------- /doc/extending/extending_faq.rst: -------------------------------------------------------------------------------- 1 | 2 | .. _extend_faq: 3 | 4 | =========================================== 5 | Extending PyTensor: FAQ and Troubleshooting 6 | =========================================== 7 | 8 | I wrote a new `Op`\/`Type`, and weird stuff is happening... 9 | ----------------------------------------------------------- 10 | 11 | First, check the :ref:`op_contract` and the :ref:`type_contract` 12 | and make sure you're following the rules. 13 | Then try running your program in :ref:`using_debugmode`. `DebugMode` might catch 14 | something that you're not seeing. 15 | 16 | 17 | I wrote a new rewrite, but it's not getting used... 18 | --------------------------------------------------- 19 | 20 | Remember that you have to register rewrites with the :ref:`optdb` 21 | for them to get used by the normal modes like FAST_COMPILE, FAST_RUN, 22 | and `DebugMode`. 23 | 24 | 25 | I wrote a new rewrite, and it changed my results even though I'm pretty sure it is correct. 26 | ------------------------------------------------------------------------------------------- 27 | 28 | First, check the :ref:`op_contract` and make sure you're following the rules. 29 | Then try running your program in :ref:`using_debugmode`. `DebugMode` might 30 | catch something that you're not seeing. 31 | -------------------------------------------------------------------------------- /pytensor/link/numba/linker.py: -------------------------------------------------------------------------------- 1 | from pytensor.link.basic import JITLinker 2 | 3 | 4 | class NumbaLinker(JITLinker): 5 | required_rewrites = ( 6 | "minimum_compile", 7 | "numba", 8 | ) # TODO: Distinguish between optional "numba" and "minimum_compile_numba" 9 | incompatible_rewrites = ( 10 | "cxx", 11 | "BlasOpt", 12 | "local_careduce_fusion", 13 | "scan_save_mem_prealloc", 14 | ) 15 | 16 | """A `Linker` that JIT-compiles NumPy-based operations using Numba.""" 17 | 18 | def fgraph_convert(self, fgraph, **kwargs): 19 | # Import numba_njit_and_cache lazily (as numba is an optional dependency) 20 | # This is what triggers the registering of the dispatches as well 21 | from pytensor.link.numba.dispatch.basic import numba_funcify_ensure_cache 22 | 23 | return numba_funcify_ensure_cache(fgraph, **kwargs) 24 | 25 | def jit_compile(self, fn_and_cache): 26 | from pytensor.link.numba.dispatch.basic import numba_njit 27 | 28 | fn, cache_key = fn_and_cache 29 | return numba_njit(fn.py_func, final_function=True, cache=cache_key is not None) 30 | 31 | def create_thunk_inputs(self, storage_map): 32 | return [storage_map[n] for n in self.fgraph.inputs] 33 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | ci: 2 | autofix_prs: false 3 | 4 | exclude: | 5 | (?x)^( 6 | versioneer\.py| 7 | pytensor/_version\.py| 8 | doc/.*| 9 | )$ 10 | repos: 11 | - repo: https://github.com/pre-commit/pre-commit-hooks 12 | rev: v6.0.0 13 | hooks: 14 | - id: debug-statements 15 | exclude: | 16 | (?x)^( 17 | pytensor/breakpoint\.py| 18 | pytensor/graph/op\.py| 19 | pytensor/compile/nanguardmode\.py| 20 | pytensor/graph/rewriting/basic\.py| 21 | pytensor/tensor/variable\.py| 22 | )$ 23 | - id: check-merge-conflict 24 | - repo: https://github.com/sphinx-contrib/sphinx-lint 25 | rev: v1.0.0 
26 | hooks: 27 | - id: sphinx-lint 28 | args: ["-i", ".pixi", "."] 29 | - repo: https://github.com/astral-sh/ruff-pre-commit 30 | rev: v0.14.0 31 | hooks: 32 | - id: ruff-check 33 | types_or: [python, pyi, jupyter] 34 | args: ["--fix", "--output-format=full"] 35 | - id: ruff-format 36 | types_or: [python, pyi, jupyter] 37 | -------------------------------------------------------------------------------- /pytensor/graph/null_type.py: -------------------------------------------------------------------------------- 1 | from pytensor.graph.type import Type 2 | 3 | 4 | class NullType(Type): 5 | """ 6 | A type that allows no values. 7 | 8 | Used to represent expressions 9 | that are undefined, either because they do not exist mathematically 10 | or because the code to generate the expression has not been 11 | implemented yet. 12 | 13 | Parameters 14 | ---------- 15 | why_null : str 16 | A string explaining why this variable can't take on any values. 17 | 18 | """ 19 | 20 | def __init__(self, why_null="(no explanation given)"): 21 | self.why_null = why_null 22 | 23 | def filter(self, data, strict=False, allow_downcast=None): 24 | raise ValueError("No values may be assigned to a NullType") 25 | 26 | def filter_variable(self, other, allow_convert=True): 27 | raise ValueError("No values may be assigned to a NullType") 28 | 29 | def values_eq(self, a, b, force_same_dtype=True): 30 | raise ValueError("NullType has no values to compare") 31 | 32 | def __eq__(self, other): 33 | return type(self) is type(other) 34 | 35 | def __hash__(self): 36 | return hash(type(self)) 37 | 38 | def __str__(self): 39 | return "NullType" 40 | 41 | 42 | null_type = NullType() 43 | -------------------------------------------------------------------------------- /pytensor/sparse/__init__.py: -------------------------------------------------------------------------------- 1 | from pytensor.sparse import rewriting, sharedvar 2 | from pytensor.sparse.basic import * 3 | from pytensor.sparse.math import * 4 | from pytensor.sparse.sharedvar import sparse_constructor as shared 5 | from pytensor.sparse.type import SparseTensorType, _is_sparse 6 | 7 | 8 | def sparse_grad(var): 9 | """This function returns a new variable whose gradient will be 10 | stored in a sparse format instead of a dense one. 11 | 12 | Currently, only variables created by AdvancedSubtensor1 are supported, 13 | i.e. a_tensor_var[an_int_vector]. 14 | 15 | .. 
versionadded:: 0.6rc4 16 | """ 17 | from pytensor.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1 18 | 19 | if not ( 20 | var.owner and isinstance(var.owner.op, AdvancedSubtensor | AdvancedSubtensor1) 21 | ): 22 | raise TypeError( 23 | "Sparse gradient is only implemented for AdvancedSubtensor and AdvancedSubtensor1" 24 | ) 25 | 26 | x = var.owner.inputs[0] 27 | indices = var.owner.inputs[1:] 28 | 29 | if len(indices) > 1: 30 | raise TypeError( 31 | "Sparse gradient is only implemented for single advanced indexing" 32 | ) 33 | 34 | ret = AdvancedSubtensor1(sparse_grad=True)(x, indices[0]) 35 | return ret 36 | -------------------------------------------------------------------------------- /tests/link/jax/test_blas.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from pytensor.compile.function import function 5 | from pytensor.compile.mode import Mode 6 | from pytensor.configdefaults import config 7 | from pytensor.graph.rewriting.db import RewriteDatabaseQuery 8 | from pytensor.link.jax import JAXLinker 9 | from pytensor.tensor import blas as pt_blas 10 | from pytensor.tensor.type import tensor3 11 | from tests.link.jax.test_basic import compare_jax_and_py 12 | 13 | 14 | def test_jax_BatchedDot(): 15 | # tensor3 . tensor3 16 | a = tensor3("a") 17 | a_test_value = ( 18 | np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3)) 19 | ) 20 | b = tensor3("b") 21 | b_test_value = ( 22 | np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2)) 23 | ) 24 | out = pt_blas.BatchedDot()(a, b) 25 | compare_jax_and_py([a, b], [out], [a_test_value, b_test_value]) 26 | 27 | # A dimension mismatch should raise a TypeError for compatibility 28 | inputs = [a_test_value[:-1], b_test_value] 29 | opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"]) 30 | jax_mode = Mode(JAXLinker(), opts) 31 | pytensor_jax_fn = function([a, b], [out], mode=jax_mode) 32 | with pytest.raises(TypeError): 33 | pytensor_jax_fn(*inputs) 34 | -------------------------------------------------------------------------------- /pytensor/misc/ordered_set.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Iterable, Iterator, MutableSet 2 | from typing import Any 3 | 4 | 5 | class OrderedSet(MutableSet): 6 | values: dict[Any, None] 7 | 8 | def __init__(self, iterable: Iterable | None = None) -> None: 9 | if iterable is None: 10 | self.values = {} 11 | else: 12 | self.values = dict.fromkeys(iterable) 13 | 14 | def __contains__(self, value) -> bool: 15 | return value in self.values 16 | 17 | def __iter__(self) -> Iterator: 18 | yield from self.values 19 | 20 | def __len__(self) -> int: 21 | return len(self.values) 22 | 23 | def add(self, value) -> None: 24 | self.values[value] = None 25 | 26 | def discard(self, value) -> None: 27 | if value in self.values: 28 | del self.values[value] 29 | 30 | def copy(self) -> "OrderedSet": 31 | return OrderedSet(self) 32 | 33 | def update(self, other: Iterable) -> None: 34 | for value in other: 35 | self.add(value) 36 | 37 | def union(self, other: Iterable) -> "OrderedSet": 38 | new_set = self.copy() 39 | new_set.update(other) 40 | return new_set 41 | 42 | def difference_update(self, other: Iterable) -> None: 43 | for value in other: 44 | self.discard(value) 45 | -------------------------------------------------------------------------------- /doc/.templates/nb-badges.html: 
-------------------------------------------------------------------------------- 1 | {% if pagename in ablog %} 2 | 3 | 4 | {% set gh_basepath = github_user + '/' + github_repo + '/blob/' + github_version + '/' %} 5 | {% set encoded_base = github_user + '%252F' + github_repo %} 6 | {% set gh_binder = github_user + '/' + github_repo + '/' + github_version %} 7 | {% set doc_path_aux = doc_path | trim('/') %} 8 | {% set file_path = doc_path_aux + '/' + pagename + ".ipynb" %} 9 | {% set encoded_path = file_path | replace("/", "%252F") %} 10 | 11 | 12 |
13 | 14 | 15 | View On GitHub 16 | 17 | 18 | Open In Binder 19 | 20 | 21 | Open In Colab 23 | 
24 | {% endif %} -------------------------------------------------------------------------------- /pytensor/misc/frozendict.py: -------------------------------------------------------------------------------- 1 | # License : https://github.com/slezica/python-frozendict/blob/master/LICENSE.txt 2 | 3 | 4 | import functools 5 | import operator 6 | from collections.abc import Mapping 7 | 8 | 9 | class frozendict(Mapping): 10 | """ 11 | An immutable wrapper around dictionaries that implements the complete :py:class:`collections.abc.Mapping` 12 | interface. It can be used as a drop-in replacement for dictionaries where immutability and ordering are desired. 13 | """ 14 | 15 | dict_cls = dict 16 | 17 | def __init__(self, *args, **kwargs): 18 | self._dict = self.dict_cls(*args, **kwargs) 19 | self._hash = None 20 | 21 | def __getitem__(self, key): 22 | return self._dict[key] 23 | 24 | def __contains__(self, key): 25 | return key in self._dict 26 | 27 | def copy(self, **add_or_replace): 28 | return self.__class__(self, **add_or_replace) 29 | 30 | def __iter__(self): 31 | return iter(self._dict) 32 | 33 | def __len__(self): 34 | return len(self._dict) 35 | 36 | def __repr__(self): 37 | return f"<{self.__class__.__name__} {self._dict!r}>" 38 | 39 | def __hash__(self): 40 | if self._hash is None: 41 | hashes = map(hash, self.items()) 42 | self._hash = functools.reduce(operator.xor, hashes, 0) 43 | 44 | return self._hash 45 | -------------------------------------------------------------------------------- /pytensor/d3viz/js/d3-context-menu.js: -------------------------------------------------------------------------------- 1 | d3.contextMenu = function (menu, openCallback) { 2 | 3 | // create the div element that will hold the context menu 4 | d3.selectAll('.d3-context-menu').data([1]) 5 | .enter() 6 | .append('div') 7 | .attr('class', 'd3-context-menu'); 8 | 9 | // close menu 10 | d3.select('body').on('click.d3-context-menu', function() { 11 | d3.select('.d3-context-menu').style('display', 'none'); 12 | }); 13 | 14 | // this gets executed when a contextmenu event occurs 15 | return function(data, index) { 16 | var elm = this; 17 | 18 | d3.selectAll('.d3-context-menu').html(''); 19 | var list = d3.selectAll('.d3-context-menu').append('ul'); 20 | list.selectAll('li').data(menu).enter() 21 | .append('li') 22 | .html(function(d) { 23 | return d.title; 24 | }) 25 | .on('click', function(d, i) { 26 | d.action(elm, data, index); 27 | d3.select('.d3-context-menu').style('display', 'none'); 28 | }); 29 | 30 | // the openCallback allows an action to fire before the menu is displayed 31 | // an example usage would be closing a tooltip 32 | if (openCallback) openCallback(data, index); 33 | 34 | // display context menu 35 | d3.select('.d3-context-menu') 36 | .style('left', (d3.event.pageX - 2) + 'px') 37 | .style('top', (d3.event.pageY - 2) + 'px') 38 | .style('display', 'block'); 39 | 40 | d3.event.preventDefault(); 41 | }; 42 | }; 43 | -------------------------------------------------------------------------------- /doc/library/d3viz/examples/d3viz/js/d3-context-menu.js: -------------------------------------------------------------------------------- 1 | d3.contextMenu = function (menu, openCallback) { 2 | 3 | // create the div element that will hold the context menu 4 | d3.selectAll('.d3-context-menu').data([1]) 5 | .enter() 6 | .append('div') 7 | .attr('class', 'd3-context-menu'); 8 | 9 | // close menu 10 | d3.select('body').on('click.d3-context-menu', function() { 11 | 
d3.select('.d3-context-menu').style('display', 'none'); 12 | }); 13 | 14 | // this gets executed when a contextmenu event occurs 15 | return function(data, index) { 16 | var elm = this; 17 | 18 | d3.selectAll('.d3-context-menu').html(''); 19 | var list = d3.selectAll('.d3-context-menu').append('ul'); 20 | list.selectAll('li').data(menu).enter() 21 | .append('li') 22 | .html(function(d) { 23 | return d.title; 24 | }) 25 | .on('click', function(d, i) { 26 | d.action(elm, data, index); 27 | d3.select('.d3-context-menu').style('display', 'none'); 28 | }); 29 | 30 | // the openCallback allows an action to fire before the menu is displayed 31 | // an example usage would be closing a tooltip 32 | if (openCallback) openCallback(data, index); 33 | 34 | // display context menu 35 | d3.select('.d3-context-menu') 36 | .style('left', (d3.event.pageX - 2) + 'px') 37 | .style('top', (d3.event.pageY - 2) + 'px') 38 | .style('display', 'block'); 39 | 40 | d3.event.preventDefault(); 41 | }; 42 | }; 43 | -------------------------------------------------------------------------------- /doc/tutorial/index.rst: -------------------------------------------------------------------------------- 1 | 2 | .. _tutorial: 3 | 4 | ======== 5 | Tutorial 6 | ======== 7 | 8 | Let us start an interactive session (e.g. with ``python`` or ``ipython``) and import PyTensor. 9 | 10 | >>> from pytensor import * 11 | 12 | Several of the symbols you will need to use are in the ``tensor`` subpackage 13 | of PyTensor. Let us import that subpackage under a handy name like 14 | ``pt`` (the tutorials will frequently use this convention). 15 | 16 | >>> import pytensor.tensor as pt 17 | 18 | If that succeeded, you are ready for the tutorial; otherwise, check your 19 | installation (see :ref:`install`). 20 | 21 | Throughout the tutorial, bear in mind that there is a :ref:`glossary` as well 22 | as *index* and *modules* links in the upper-right corner of each page to help 23 | you out. 24 | 25 | Basics 26 | ------ 27 | 28 | .. toctree:: 29 | 30 | adding 31 | examples 32 | gradients 33 | conditions 34 | loop 35 | shape_info 36 | broadcasting 37 | 38 | Advanced 39 | -------- 40 | 41 | .. toctree:: 42 | 43 | sparse 44 | prng 45 | 46 | Advanced configuration and debugging 47 | ------------------------------------ 48 | 49 | .. toctree:: 50 | 51 | modes 52 | printing_drawing 53 | debug_faq 54 | nan_tutorial 55 | profiling 56 | 57 | Further reading 58 | --------------- 59 | 60 | .. 
toctree:: 61 | 62 | loading_and_saving 63 | aliasing 64 | multi_cores 65 | faq_tutorial 66 | -------------------------------------------------------------------------------- /doc/generate_dtype_tensor_table.py: -------------------------------------------------------------------------------- 1 | letters = [ 2 | ('b', 'int8'), 3 | ('w', 'int16'), 4 | ('i', 'int32'), 5 | ('l', 'int64'), 6 | ('d', 'float64'), 7 | ('f', 'float32'), 8 | ('c', 'complex64'), 9 | ('z', 'complex128') ] 10 | 11 | shapes = [ 12 | ('scalar', ()), 13 | ('vector', (False,)), 14 | ('row', (True, False)), 15 | ('col', (False, True)), 16 | ('matrix', (False,False)), 17 | ('tensor3', (False,False,False)), 18 | ('tensor4', (False,False,False,False)), 19 | ('tensor5', (False,False,False,False,False)), 20 | ('tensor6', (False,) * 6), 21 | ('tensor7', (False,) * 7),] 22 | 23 | hdr = '============ =========== ==== ================ ===================================' 24 | print(hdr) 25 | print('Constructor dtype ndim shape broadcastable') 26 | print(hdr) 27 | for letter in letters: 28 | for shape in shapes: 29 | suff = ',)' if len(shape[1])==1 else ')' 30 | s = '(' + ','.join('1' if b else '?' for b in shape[1]) + suff 31 | if len(shape[1]) < 6 or len(set(shape[1])) > 1: 32 | broadcastable_str = str(shape[1]) 33 | else: 34 | broadcastable_str = f'({shape[1][0]},) * {len(shape[1])}' 35 | print('%s%-10s %-10s %-4s %-15s %-20s' %( 36 | letter[0], shape[0], letter[1], len(shape[1]), s, broadcastable_str 37 | )) 38 | print(hdr) 39 | -------------------------------------------------------------------------------- /tests/tensor/rewriting/test_einsum.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | from pytensor.graph import ancestors, rewrite_graph 4 | from pytensor.tensor import einsum, specify_shape, tensor 5 | from pytensor.tensor.einsum import Einsum 6 | 7 | 8 | specialize_rewrite = partial(rewrite_graph, include=("specialize",), clone=True) 9 | 10 | 11 | def test_einsum_optimization(): 12 | a = tensor("a", shape=(None, None)) 13 | b = tensor("b", shape=(None, None)) 14 | c = tensor("c", shape=(None, None)) 15 | 16 | dynamic_shape_einsum = einsum("ij,ij,jk->ik", a, b, c) 17 | assert not dynamic_shape_einsum.owner.op.optimized 18 | 19 | rewritten_out = specialize_rewrite(dynamic_shape_einsum) 20 | assert isinstance(rewritten_out.owner.op, Einsum) 21 | 22 | a = specify_shape(a, (2, 3)) 23 | b = specify_shape(b, (2, 3)) 24 | c = specify_shape(c, (3, 5)) 25 | 26 | static_shape_einsum = dynamic_shape_einsum.owner.clone_with_new_inputs( 27 | [a, b, c] 28 | ).default_output() 29 | assert not static_shape_einsum.owner.op.optimized 30 | 31 | rewritten_out = specialize_rewrite(static_shape_einsum) 32 | # Einsum was inlined because it was optimized 33 | assert not isinstance(rewritten_out.owner.op, Einsum) 34 | # Sanity check that it's not buried in the graph 35 | assert not any( 36 | isinstance(var.owner.op, Einsum) 37 | for var in ancestors([rewritten_out]) 38 | if var.owner 39 | ) 40 | -------------------------------------------------------------------------------- /tests/link/jax/test_blockwise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | from pytensor import config 5 | from pytensor.tensor import tensor 6 | from pytensor.tensor.blockwise import Blockwise 7 | from pytensor.tensor.math import Dot, matmul 8 | from tests.link.jax.test_basic import compare_jax_and_py 9 | from 
tests.tensor.test_blockwise import check_blockwise_runtime_broadcasting 10 | 11 | 12 | jax = pytest.importorskip("jax") 13 | 14 | 15 | def test_runtime_broadcasting(): 16 | check_blockwise_runtime_broadcasting("JAX") 17 | 18 | 19 | # Equivalent blockwise to matmul but with dumb signature 20 | odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)") 21 | 22 | 23 | @pytest.mark.parametrize("matmul_op", (matmul, odd_matmul)) 24 | def test_matmul(matmul_op): 25 | rng = np.random.default_rng(14) 26 | a = tensor("a", shape=(2, 3, 5)) 27 | b = tensor("b", shape=(2, 5, 3)) 28 | test_values = [ 29 | rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (a, b) 30 | ] 31 | 32 | out = matmul_op(a, b) 33 | assert isinstance(out.owner.op, Blockwise) 34 | fn, _ = compare_jax_and_py([a, b], [out], test_values) 35 | 36 | # Check we are not adding any unnecessary stuff 37 | jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values)) 38 | jaxpr = jaxpr.replace("name=jax_funcified_fgraph", "name=matmul") 39 | expected_jaxpr = str(jax.make_jaxpr(jax.jit(jax.numpy.matmul))(*test_values)) 40 | assert jaxpr == expected_jaxpr 41 | -------------------------------------------------------------------------------- /doc/library/tensor/fft.rst: -------------------------------------------------------------------------------- 1 | .. _libdoc_tensor_fft: 2 | 3 | ============================================== 4 | :mod:`tensor.fft` -- Fast Fourier Transforms 5 | ============================================== 6 | 7 | Performs Fast Fourier Transforms (FFT). 8 | 9 | FFT gradients are implemented as the opposite Fourier transform of the output gradients. 10 | 11 | .. warning :: 12 | The real and imaginary parts of the Fourier domain arrays are stored as a pair of float 13 | arrays, emulating complex. Since pytensor has limited support for complex 14 | number operations, care must be taken to manually implement operations such as gradients. 15 | 16 | .. automodule:: pytensor.tensor.fft 17 | :members: rfft, irfft 18 | 19 | For example, the code below performs the real input FFT of a box function, 20 | which is a sinc function. The absolute value is plotted, since the phase 21 | oscillates due to the box function being shifted to the middle of the array. 22 | 23 | .. testcode:: 24 | 25 | import numpy as np 26 | import pytensor 27 | import pytensor.tensor as pt 28 | from pytensor.tensor import fft 29 | 30 | x = pt.matrix('x', dtype='float64') 31 | 32 | rfft = fft.rfft(x, norm='ortho') 33 | f_rfft = pytensor.function([x], rfft) 34 | 35 | N = 1024 36 | box = np.zeros((1, N), dtype='float64') 37 | box[:, N//2-10: N//2+10] = 1 38 | 39 | out = f_rfft(box) 40 | c_out = np.asarray(out[0, :, 0] + 1j*out[0, :, 1]) 41 | abs_out = abs(c_out) 42 | 43 | .. 
image:: plot_fft.png 44 | -------------------------------------------------------------------------------- /pytensor/link/pytorch/dispatch/extra_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from pytensor.link.pytorch.dispatch.basic import pytorch_funcify 4 | from pytensor.tensor.extra_ops import CumOp, Repeat, Unique 5 | 6 | 7 | @pytorch_funcify.register(CumOp) 8 | def pytorch_funcify_Cumop(op, **kwargs): 9 | axis = op.axis 10 | mode = op.mode 11 | 12 | def cumop(x): 13 | if mode == "add": 14 | return torch.cumsum(x, dim=axis) 15 | else: 16 | return torch.cumprod(x, dim=axis) 17 | 18 | return cumop 19 | 20 | 21 | @pytorch_funcify.register(Repeat) 22 | def pytorch_funcify_Repeat(op, **kwargs): 23 | axis = op.axis 24 | 25 | def repeat(x, repeats): 26 | return x.repeat_interleave(repeats, dim=axis) 27 | 28 | return repeat 29 | 30 | 31 | @pytorch_funcify.register(Unique) 32 | def pytorch_funcify_Unique(op, **kwargs): 33 | return_index = op.return_index 34 | 35 | if return_index: 36 | # TODO: evaluate whether is worth implementing this param 37 | # (see https://github.com/pytorch/pytorch/issues/36748) 38 | raise NotImplementedError("return_index is not implemented for pytorch") 39 | 40 | axis = op.axis 41 | return_inverse = op.return_inverse 42 | return_counts = op.return_counts 43 | 44 | def unique(x): 45 | return torch.unique( 46 | x, 47 | sorted=True, 48 | return_inverse=return_inverse, 49 | return_counts=return_counts, 50 | dim=axis, 51 | ) 52 | 53 | return unique 54 | -------------------------------------------------------------------------------- /pytensor/link/jax/dispatch/sparse.py: -------------------------------------------------------------------------------- 1 | import jax.experimental.sparse as jsp 2 | from scipy.sparse import spmatrix 3 | 4 | from pytensor.graph.basic import Constant 5 | from pytensor.link.jax.dispatch import jax_funcify, jax_typify 6 | from pytensor.sparse.math import Dot, StructuredDot 7 | from pytensor.sparse.type import SparseTensorType 8 | 9 | 10 | @jax_typify.register(spmatrix) 11 | def jax_typify_spmatrix(matrix, dtype=None, **kwargs): 12 | # Note: This changes the type of the constants from CSR/CSC to BCOO 13 | # We could add BCOO as a PyTensor type but this would only be useful for JAX graphs 14 | # and it would break the premise of one graph -> multiple backends. 15 | # The same situation happens with RandomGenerators... 
16 | return jsp.BCOO.from_scipy_sparse(matrix) 17 | 18 | 19 | @jax_funcify.register(Dot) 20 | @jax_funcify.register(StructuredDot) 21 | def jax_funcify_sparse_dot(op, node, **kwargs): 22 | for input in node.inputs: 23 | if isinstance(input.type, SparseTensorType) and not isinstance(input, Constant): 24 | raise NotImplementedError( 25 | "JAX sparse dot only implemented for constant sparse inputs" 26 | ) 27 | 28 | if isinstance(node.outputs[0].type, SparseTensorType): 29 | raise NotImplementedError("JAX sparse dot only implemented for dense outputs") 30 | 31 | @jsp.sparsify 32 | def sparse_dot(x, y): 33 | out = x @ y 34 | if isinstance(out, jsp.BCOO): 35 | out = out.todense() 36 | return out 37 | 38 | return sparse_dot 39 | -------------------------------------------------------------------------------- /tests/xtensor/test_reduction.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | pytest.importorskip("xarray") 5 | 6 | from pytensor.xtensor.type import xtensor 7 | from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function 8 | 9 | 10 | @pytest.mark.parametrize( 11 | "dim", [..., None, "a", ("c", "a")], ids=["Ellipsis", "None", "a", "(a, c)"] 12 | ) 13 | @pytest.mark.parametrize( 14 | "method", 15 | ["sum", "prod", "all", "any", "max", "min", "mean", "cumsum", "cumprod"], 16 | ) 17 | def test_reduction(method, dim): 18 | x = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7)) 19 | out = getattr(x, method)(dim=dim) 20 | 21 | fn = xr_function([x], out) 22 | x_test = xr_arange_like(x) 23 | 24 | xr_assert_allclose( 25 | fn(x_test), 26 | getattr(x_test, method)(dim=dim), 27 | ) 28 | 29 | 30 | @pytest.mark.parametrize( 31 | "dim", [..., None, "a", ("c", "a")], ids=["Ellipsis", "None", "a", "(a, c)"] 32 | ) 33 | @pytest.mark.parametrize("method", ["std", "var"]) 34 | def test_std_var(method, dim): 35 | x = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7)) 36 | out = [ 37 | getattr(x, method)(dim=dim), 38 | getattr(x, method)(dim=dim, ddof=2), 39 | ] 40 | 41 | fn = xr_function([x], out) 42 | x_test = xr_arange_like(x) 43 | results = fn(x_test) 44 | 45 | xr_assert_allclose( 46 | results[0], 47 | getattr(x_test, method)(dim=dim), 48 | ) 49 | 50 | xr_assert_allclose( 51 | results[1], 52 | getattr(x_test, method)(dim=dim, ddof=2), 53 | ) 54 | -------------------------------------------------------------------------------- /doc/.templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | 3 | {% block footer %} 4 | {{ super() }} 5 | 16 | 17 | 18 | 35 | 36 | 39 | {% endblock %} 40 | -------------------------------------------------------------------------------- /tests/link/pytorch/test_blockwise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | 4 | import pytensor 5 | import pytensor.tensor as pt 6 | from pytensor.graph.basic import Apply 7 | from pytensor.graph.op import Op 8 | from pytensor.tensor.blockwise import Blockwise 9 | 10 | 11 | torch = pytest.importorskip("torch") 12 | basic = pytest.importorskip("pytensor.link.pytorch.dispatch.basic") 13 | 14 | 15 | class BatchedTestOp(Op): 16 | gufunc_signature = "(m,n),(n,p)->(m,p)" 17 | 18 | def __init__(self, final_shape): 19 | super().__init__() 20 | self.final_shape = final_shape 21 | self.call_shapes = [] 22 | 23 | def make_node(self, *args): 24 | return Apply(self, list(args), [pt.matrix("_", 
shape=self.final_shape)]) 25 | 26 | def perform(self, *_): 27 | raise RuntimeError("In perform") 28 | 29 | 30 | @basic.pytorch_funcify.register(BatchedTestOp) 31 | def evaluate_test_op(op, **_): 32 | def func(a, b): 33 | op.call_shapes.extend(map(torch.Tensor.size, [a, b])) 34 | return a @ b 35 | 36 | return func 37 | 38 | 39 | def test_blockwise_broadcast(): 40 | _x = np.random.rand(5, 1, 2, 3) 41 | _y = np.random.rand(3, 3, 2) 42 | 43 | x = pt.tensor4("x", shape=(5, 1, 2, 3)) 44 | y = pt.tensor3("y", shape=(3, 3, 2)) 45 | op = BatchedTestOp((2, 2)) 46 | z = Blockwise(op)(x, y) 47 | 48 | f = pytensor.function([x, y], z, mode="PYTORCH") 49 | res = f(_x, _y) 50 | assert tuple(res.shape) == (5, 3, 2, 2) 51 | np.testing.assert_allclose(res, _x @ _y) 52 | assert op.call_shapes == [(2, 3), (3, 2)] 53 | -------------------------------------------------------------------------------- /doc/library/typed_list.rst: -------------------------------------------------------------------------------- 1 | .. _libdoc_typed_list: 2 | 3 | =============================== 4 | :mod:`typed_list` -- Typed List 5 | =============================== 6 | 7 | .. note:: 8 | 9 | This has been added in release 0.7. 10 | 11 | .. note:: 12 | 13 | This works, but is not well integrated with the rest of PyTensor. If 14 | speed is important, it is probably better to pad to a dense 15 | tensor. 16 | 17 | This is a type that represents a list in PyTensor. All elements must have 18 | the same PyTensor type. Here is an example: 19 | 20 | >>> import pytensor.typed_list 21 | >>> tl = pytensor.typed_list.TypedListType(pytensor.tensor.fvector)() 22 | >>> v = pytensor.tensor.fvector() 23 | >>> o = pytensor.typed_list.append(tl, v) 24 | >>> f = pytensor.function([tl, v], o) 25 | >>> f([[1, 2, 3], [4, 5]], [2]) 26 | [array([ 1., 2., 3.], dtype=float32), array([ 4., 5.], dtype=float32), array([ 2.], dtype=float32)] 27 | 28 | A second example with Scan. Scan doesn't yet have direct support of 29 | TypedList, so you can only use it as non_sequences (not in sequences or 30 | as outputs): 31 | 32 | >>> import pytensor.typed_list 33 | >>> a = pytensor.typed_list.TypedListType(pytensor.tensor.fvector)() 34 | >>> l = pytensor.typed_list.length(a) 35 | >>> s, _ = pytensor.scan(fn=lambda i, tl: tl[i].sum(), 36 | ... non_sequences=[a], 37 | ... sequences=[pytensor.tensor.arange(l, dtype='int64')]) 38 | >>> f = pytensor.function([a], s) 39 | >>> f([[1, 2, 3], [4, 5]]) 40 | array([ 6., 9.], dtype=float32) 41 | 42 | .. 
automodule:: pytensor.typed_list.basic 43 | :members: 44 | -------------------------------------------------------------------------------- /tests/graph/test_types.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from pytensor.graph.basic import Variable 4 | from pytensor.graph.type import Type 5 | 6 | 7 | class MyType(Type): 8 | def __init__(self, thingy): 9 | self.thingy = thingy 10 | 11 | def filter(self, *args, **kwargs): 12 | raise NotImplementedError() 13 | 14 | def __eq__(self, other): 15 | return isinstance(other, MyType) and other.thingy == self.thingy 16 | 17 | def __str__(self): 18 | return f"R{self.thingy}" 19 | 20 | def __repr__(self): 21 | return f"R{self.thingy}" 22 | 23 | 24 | class MyType2(MyType): 25 | def is_super(self, other): 26 | if self.thingy <= other.thingy: 27 | return True 28 | 29 | 30 | def test_is_super(): 31 | t1 = MyType(1) 32 | t2 = MyType(2) 33 | 34 | assert t1.is_super(t2) is None 35 | 36 | t1_2 = MyType(1) 37 | assert t1.is_super(t1_2) 38 | 39 | 40 | def test_in_same_class(): 41 | t1 = MyType(1) 42 | t2 = MyType(2) 43 | 44 | assert t1.in_same_class(t2) is False 45 | 46 | t1_2 = MyType(1) 47 | assert t1.in_same_class(t1_2) 48 | 49 | 50 | def test_convert_variable(): 51 | t1 = MyType(1) 52 | v1 = Variable(MyType(1), None, None) 53 | v2 = Variable(MyType(2), None, None) 54 | v3 = Variable(MyType2(0), None, None) 55 | 56 | assert t1.convert_variable(v1) is v1 57 | assert t1.convert_variable(v2) is None 58 | 59 | with pytest.raises(NotImplementedError): 60 | t1.convert_variable(v3) 61 | 62 | 63 | def test_default_clone(): 64 | mt = MyType(1) 65 | assert isinstance(mt.clone(1), MyType) 66 | -------------------------------------------------------------------------------- /tests/link/pytorch/test_shape.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import pytensor.tensor as pt 4 | from pytensor.configdefaults import config 5 | from pytensor.tensor.shape import Shape, Shape_i, reshape 6 | from pytensor.tensor.type import iscalar, vector 7 | from tests.link.pytorch.test_basic import compare_pytorch_and_py 8 | 9 | 10 | def test_pytorch_shape_ops(): 11 | x_np = np.zeros((20, 3)) 12 | x = Shape()(pt.as_tensor_variable(x_np)) 13 | 14 | compare_pytorch_and_py([], [x], []) 15 | 16 | x = Shape_i(1)(pt.as_tensor_variable(x_np)) 17 | 18 | compare_pytorch_and_py([], [x], []) 19 | 20 | 21 | def test_pytorch_specify_shape(): 22 | in_pt = pt.matrix("in") 23 | x = pt.specify_shape(in_pt, (4, None)) 24 | compare_pytorch_and_py([in_pt], [x], [np.ones((4, 5)).astype(config.floatX)]) 25 | 26 | # When used to assert two arrays have similar shapes 27 | in_pt = pt.matrix("in") 28 | shape_pt = pt.matrix("shape") 29 | x = pt.specify_shape(in_pt, shape_pt.shape) 30 | 31 | compare_pytorch_and_py( 32 | [in_pt, shape_pt], 33 | [x], 34 | [np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)], 35 | ) 36 | 37 | 38 | def test_pytorch_Reshape_constant(): 39 | a = vector("a") 40 | x = reshape(a, (2, 2)) 41 | 42 | compare_pytorch_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)]) 43 | 44 | 45 | def test_pytorch_Reshape_dynamic(): 46 | a = vector("a") 47 | shape_pt = iscalar("b") 48 | x = reshape(a, (shape_pt, shape_pt)) 49 | 50 | compare_pytorch_and_py( 51 | [a, shape_pt], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2] 52 | ) 53 | -------------------------------------------------------------------------------- 
/doc/library/index.rst: -------------------------------------------------------------------------------- 1 | 2 | .. _libdoc: 3 | .. _Library documentation: 4 | 5 | ================= 6 | API Documentation 7 | ================= 8 | 9 | This documentation covers PyTensor module by module. This is suited to finding the 10 | Types and Ops that you can use to build and compile expression graphs. 11 | 12 | Modules 13 | ======= 14 | 15 | .. toctree:: 16 | :maxdepth: 1 17 | 18 | compile/index 19 | config 20 | d3viz/index 21 | graph/index 22 | gradient 23 | printing 24 | scan 25 | sparse/index 26 | tensor/index 27 | typed_list 28 | xtensor/index 29 | 30 | .. module:: pytensor 31 | :platform: Unix, Windows 32 | :synopsis: PyTensor top-level import 33 | .. moduleauthor:: LISA 34 | 35 | There are also some top-level imports that you might find more convenient: 36 | 37 | Graph 38 | ===== 39 | 40 | .. function:: shared(...) 41 | 42 | Alias for :func:`pytensor.compile.sharedvalue.shared` 43 | 44 | .. function:: function(...) 45 | 46 | Alias for :func:`pytensor.compile.function.function` 47 | 48 | .. autofunction:: pytensor.clone_replace(...) 49 | 50 | Alias for :func:`pytensor.graph.basic.clone_replace` 51 | 52 | Control flow 53 | ============ 54 | 55 | .. autofunction:: pytensor.scan(...) 56 | 57 | Alias for :func:`pytensor.scan.basic.scan` 58 | 59 | Convert to Variable 60 | ==================== 61 | 62 | .. autofunction:: pytensor.as_symbolic(...) 63 | 64 | Wrap JAX functions 65 | ================== 66 | 67 | .. autofunction:: wrap_jax(...) 68 | 69 | Alias for :func:`pytensor.link.jax.ops.wrap_jax` 70 | 71 | Debug 72 | ===== 73 | 74 | .. autofunction:: pytensor.dprint(...) 75 | 76 | Alias for :func:`pytensor.printing.debugprint` 77 | -------------------------------------------------------------------------------- /doc/tutorial/profiling_example_out.prof: -------------------------------------------------------------------------------- 1 | Function profiling 2 | ================== 3 | Message: None 4 | Time in 1 calls to Function.__call__: 5.698204e-05s 5 | Time in Function.vm.__call__: 1.192093e-05s (20.921%) 6 | Time in thunks: 6.198883e-06s (10.879%) 7 | Total compile time: 3.642474e+00s 8 | PyTensor rewrite time: 7.326508e-02s 9 | PyTensor validate time: 3.712177e-04s 10 | PyTensor Linker time (includes C, CUDA code generation/compiling): 9.584920e-01s 11 | 12 | Class 13 | --- 14 | <% time>