├── aesara ├── py.typed ├── bin │ └── __init__.py ├── link │ ├── __init__.py │ ├── c │ │ ├── __init__.py │ │ ├── exceptions.py │ │ ├── c_code │ │ │ └── aesara_mod_helper.h │ │ └── cvm.py │ ├── jax │ │ ├── __init__.py │ │ ├── jax_linker.py │ │ ├── jax_dispatch.py │ │ └── dispatch │ │ │ ├── __init__.py │ │ │ └── slinalg.py │ └── numba │ │ ├── __init__.py │ │ ├── dispatch │ │ └── __init__.py │ │ └── linker.py ├── misc │ ├── __init__.py │ ├── check_blas_many.sh │ ├── latence_gpu_transfert.py │ ├── may_share_memory.py │ ├── frozendict.py │ ├── elemwise_time_test.py │ ├── elemwise_openmp_speedup.py │ └── safe_asarray.py ├── sandbox │ ├── __init__.py │ ├── linalg │ │ └── __init__.py │ ├── solve.py │ └── minimal.py ├── sparse │ ├── sandbox │ │ └── __init__.py │ ├── opt.py │ ├── utils.py │ ├── sharedvar.py │ └── __init__.py ├── tensor │ ├── signal │ │ └── __init__.py │ ├── linalg.py │ ├── random │ │ ├── rewriting │ │ │ ├── __init__.py │ │ │ └── jax.py │ │ ├── __init__.py │ │ ├── opt.py │ │ └── var.py │ ├── math_opt.py │ ├── nnet │ │ ├── opt.py │ │ └── __init__.py │ ├── subtensor_opt.py │ ├── opt_uncanonicalize.py │ ├── exceptions.py │ ├── rewriting │ │ └── __init__.py │ ├── basic_opt.py │ ├── c_code │ │ ├── alt_blas_common.h │ │ └── dimshuffle.c │ └── xlogx.py ├── graph │ ├── rewriting │ │ └── __init__.py │ ├── toolbox.py │ ├── unify.py │ ├── kanren.py │ ├── __init__.py │ ├── optdb.py │ ├── opt.py │ ├── opt_utils.py │ └── null_type.py ├── d3viz │ ├── __init__.py │ ├── css │ │ ├── d3-context-menu.css │ │ └── d3viz.css │ └── js │ │ └── d3-context-menu.js ├── typed_list │ ├── __init__.py │ └── rewriting.py ├── scan │ ├── opt.py │ └── __init__.py ├── scalar │ ├── basic_scipy.py │ ├── __init__.py │ └── sharedvar.py ├── assert_op.py ├── version.py └── compile │ ├── __init__.py │ └── compilelock.py ├── tests ├── __init__.py ├── d3viz │ ├── __init__.py │ ├── models.py │ ├── test_d3viz.py │ └── test_formatting.py ├── graph │ ├── __init__.py │ ├── rewriting │ │ └── __init__.py │ ├── 
test_utils.py │ ├── test_types.py │ └── test_sched.py ├── link │ ├── __init__.py │ ├── c │ │ ├── __init__.py │ │ └── c_code │ │ │ ├── test_cenum.h │ │ │ └── test_quadratic_function.c │ └── numba │ │ ├── __init__.py │ │ └── test_performance.py ├── misc │ ├── __init__.py │ └── test_pkl_utils.py ├── scan │ ├── __init__.py │ └── test_checkpoints.py ├── compile │ ├── __init__.py │ ├── function │ │ └── __init__.py │ ├── test_ops.py │ └── test_misc.py ├── sandbox │ ├── __init__.py │ ├── linalg │ │ └── __init__.py │ └── test_minimal.py ├── scalar │ ├── __init__.py │ ├── test_basic_sympy.py │ └── test_type.py ├── sparse │ ├── __init__.py │ ├── sandbox │ │ └── __init__.py │ ├── test_sharedvar.py │ └── test_utils.py ├── tensor │ ├── __init__.py │ ├── nnet │ │ ├── __init__.py │ │ └── test_rewriting.py │ ├── random │ │ └── __init__.py │ ├── signal │ │ └── __init__.py │ ├── rewriting │ │ └── __init__.py │ ├── test_xlogx.py │ ├── _test_mpi_roundtrip.py │ ├── test_type_other.py │ ├── test_io.py │ ├── test_fourier.py │ ├── test_merge.py │ ├── test_utils.py │ └── test_misc.py ├── typed_list │ └── __init__.py └── test_updates.py ├── .github ├── FUNDING.yml ├── SECURITY.md ├── dependabot.yml ├── workflows │ ├── update-pre-commit.yml │ └── dev-builds.yml ├── PULL_REQUEST_TEMPLATE.md ├── ISSUE_TEMPLATE.md └── CONTRIBUTING.md ├── .gitattributes ├── doc ├── .build │ └── PLACEHOLDER ├── .static │ ├── PLACEHOLDER │ └── fix_rtd.css ├── extend │ ├── rewrite.rst │ ├── op │ │ ├── apply.png │ │ └── index.rst │ └── backend │ │ └── index.rst ├── .templates │ ├── PLACEHOLDER │ └── layout.html ├── images │ ├── lstm.png │ ├── talk2010.gif │ ├── talk2010.png │ ├── Elman_srnn.png │ ├── blocksparse.png │ ├── aesara_logo_200.png │ ├── aesara_logo_2400.png │ ├── lstm_memorycell.png │ └── aesara_overview_diagram.png ├── troubleshoot │ ├── breakpoints.rst │ ├── d3viz.png │ ├── optimisations.rst │ ├── interacting_with_graph.rst │ ├── d3viz │ │ ├── examples │ │ │ ├── mlp.png │ │ │ ├── mlp2.pdf │ │ │ ├── 
mlp2.png │ │ │ └── d3viz │ │ │ │ ├── css │ │ │ │ ├── d3-context-menu.css │ │ │ │ └── d3viz.css │ │ │ │ └── js │ │ │ │ └── d3-context-menu.js │ │ └── index_files │ │ │ ├── index_10_0.png │ │ │ ├── index_11_0.png │ │ │ ├── index_24_0.png │ │ │ └── index_25_0.png │ ├── logreg_pydotprint_train.png │ ├── logreg_pydotprint_predict.png │ ├── logreg_pydotprint_prediction.png │ ├── index.rst │ └── profiling_example_out.prof ├── reference │ ├── tensor │ │ ├── shared │ │ │ ├── shared.rst │ │ │ └── index.rst │ │ ├── bcast.png │ │ ├── utils.rst │ │ ├── operations.window_functions.rst │ │ ├── operations.statistics.rst │ │ ├── operations.padding.rst │ │ ├── operations.sorting.rst │ │ ├── sparse │ │ │ └── sandbox.rst │ │ ├── operations.discrete_fourier.rst │ │ ├── operations.binary_operations.rst │ │ ├── operations.logic.rst │ │ ├── operations.rst │ │ ├── index.rst │ │ ├── operations.indexing.rst │ │ └── operations.tensor_creation.rst │ ├── gradient │ │ ├── dlogistic.png │ │ ├── index.rst │ │ └── gradient_api.rst │ ├── index.rst │ └── loops │ │ └── index.rst ├── fundamentals │ ├── graph │ │ ├── apply.png │ │ ├── symbolic_graph_opt.png │ │ ├── symbolic_graph_unopt.png │ │ ├── index.rst │ │ └── graph │ │ │ ├── op.rst │ │ │ ├── graph.rst │ │ │ ├── type.rst │ │ │ ├── index.rst │ │ │ ├── utils.rst │ │ │ ├── features.rst │ │ │ └── fgraph.rst │ ├── compilation │ │ └── index.rst │ └── rewrites │ │ └── index.rst ├── sandbox │ ├── function.rst │ ├── tensoroptools.rst │ ├── functional.rst │ ├── index.rst │ ├── compilation.rst │ ├── index2.rst │ ├── performance.rst │ ├── software.rst │ └── interactive_debugger.rst ├── serializing │ ├── index.rst │ └── pkl_utils.rst ├── compile │ ├── ops.rst │ ├── index.rst │ ├── profilemode.rst │ ├── opfromgraph.rst │ ├── rewrites │ │ └── index.rst │ ├── modes_solution_1.py │ ├── nanguardmode.rst │ └── mode.rst ├── css.inc ├── environment.yml ├── help.rst ├── acknowledgement.rst ├── core_development_guide.rst ├── generate_dtype_tensor_table.py └── install.rst 
├── .git_archival.txt ├── requirements-rtd.txt ├── bin ├── __init__.py ├── aesara-cache └── aesara_cache.py ├── requirements.txt ├── .readthedocs.yaml ├── codecov.yml ├── .flake8 ├── DESCRIPTION.txt ├── .gitignore ├── environment.yml ├── environment-arm.yml ├── conftest.py ├── .pre-commit-config.yaml ├── CITATION.cff ├── LICENSE.txt └── Makefile /aesara/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aesara/bin/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aesara/link/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aesara/misc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/d3viz/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/graph/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/link/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /tests/misc/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/scan/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aesara/link/c/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aesara/sandbox/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/compile/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/link/c/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/link/numba/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/sandbox/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/scalar/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/sparse/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /tests/tensor/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/typed_list/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aesara/sparse/sandbox/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /aesara/tensor/signal/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/graph/rewriting/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/sandbox/linalg/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/sparse/sandbox/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/tensor/nnet/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/tensor/random/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/tensor/signal/__init__.py: -------------------------------------------------------------------------------- 1 | 
-------------------------------------------------------------------------------- /aesara/graph/rewriting/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/compile/function/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/tensor/rewriting/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | .git_archival.txt export-subst 2 | -------------------------------------------------------------------------------- /aesara/d3viz/__init__.py: -------------------------------------------------------------------------------- 1 | from aesara.d3viz.d3viz import d3viz, d3write 2 | -------------------------------------------------------------------------------- /aesara/link/jax/__init__.py: -------------------------------------------------------------------------------- 1 | from aesara.link.jax.linker import JAXLinker 2 | -------------------------------------------------------------------------------- /aesara/link/numba/__init__.py: -------------------------------------------------------------------------------- 1 | from aesara.link.numba.linker import NumbaLinker 2 | -------------------------------------------------------------------------------- /doc/.build/PLACEHOLDER: -------------------------------------------------------------------------------- 1 | sphinx doesn't like it when this repertory isn't available 2 | -------------------------------------------------------------------------------- /doc/.static/PLACEHOLDER: 
-------------------------------------------------------------------------------- 1 | sphinx doesn't like it when this repertory isn't available 2 | -------------------------------------------------------------------------------- /doc/extend/rewrite.rst: -------------------------------------------------------------------------------- 1 | ================= 2 | Add a new rewrite 3 | ================= 4 | -------------------------------------------------------------------------------- /doc/.templates/PLACEHOLDER: -------------------------------------------------------------------------------- 1 | sphinx doesn't like it when this repertory isn't available 2 | -------------------------------------------------------------------------------- /doc/images/lstm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/images/lstm.png -------------------------------------------------------------------------------- /doc/troubleshoot/breakpoints.rst: -------------------------------------------------------------------------------- 1 | =============== 2 | Add breakpoints 3 | =============== 4 | -------------------------------------------------------------------------------- /aesara/sandbox/linalg/__init__.py: -------------------------------------------------------------------------------- 1 | from aesara.sandbox.linalg.ops import spectral_radius_bound 2 | -------------------------------------------------------------------------------- /aesara/tensor/linalg.py: -------------------------------------------------------------------------------- 1 | from aesara.tensor.nlinalg import * 2 | from aesara.tensor.slinalg import * 3 | -------------------------------------------------------------------------------- /doc/extend/op/apply.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/extend/op/apply.png 
-------------------------------------------------------------------------------- /doc/images/talk2010.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/images/talk2010.gif -------------------------------------------------------------------------------- /doc/images/talk2010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/images/talk2010.png -------------------------------------------------------------------------------- /doc/reference/tensor/shared/shared.rst: -------------------------------------------------------------------------------- 1 | ================ 2 | Shared variables 3 | ================ 4 | -------------------------------------------------------------------------------- /doc/images/Elman_srnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/images/Elman_srnn.png -------------------------------------------------------------------------------- /doc/images/blocksparse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/images/blocksparse.png -------------------------------------------------------------------------------- /doc/troubleshoot/d3viz.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/troubleshoot/d3viz.png -------------------------------------------------------------------------------- /doc/fundamentals/graph/apply.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/fundamentals/graph/apply.png 
-------------------------------------------------------------------------------- /doc/images/aesara_logo_200.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/images/aesara_logo_200.png -------------------------------------------------------------------------------- /doc/images/aesara_logo_2400.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/images/aesara_logo_2400.png -------------------------------------------------------------------------------- /doc/images/lstm_memorycell.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/images/lstm_memorycell.png -------------------------------------------------------------------------------- /doc/reference/tensor/bcast.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/reference/tensor/bcast.png -------------------------------------------------------------------------------- /doc/troubleshoot/optimisations.rst: -------------------------------------------------------------------------------- 1 | ====================== 2 | Turn optimizations off 3 | ====================== 4 | -------------------------------------------------------------------------------- /doc/reference/gradient/dlogistic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/reference/gradient/dlogistic.png -------------------------------------------------------------------------------- /doc/troubleshoot/interacting_with_graph.rst: -------------------------------------------------------------------------------- 1 | ======================= 2 | Interact with the graph 3 | 
======================= 4 | -------------------------------------------------------------------------------- /doc/images/aesara_overview_diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/images/aesara_overview_diagram.png -------------------------------------------------------------------------------- /doc/troubleshoot/d3viz/examples/mlp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/troubleshoot/d3viz/examples/mlp.png -------------------------------------------------------------------------------- /doc/troubleshoot/d3viz/examples/mlp2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/troubleshoot/d3viz/examples/mlp2.pdf -------------------------------------------------------------------------------- /doc/troubleshoot/d3viz/examples/mlp2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/troubleshoot/d3viz/examples/mlp2.png -------------------------------------------------------------------------------- /doc/troubleshoot/logreg_pydotprint_train.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/troubleshoot/logreg_pydotprint_train.png -------------------------------------------------------------------------------- /doc/fundamentals/graph/symbolic_graph_opt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/fundamentals/graph/symbolic_graph_opt.png -------------------------------------------------------------------------------- /doc/fundamentals/graph/symbolic_graph_unopt.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/fundamentals/graph/symbolic_graph_unopt.png -------------------------------------------------------------------------------- /doc/reference/tensor/utils.rst: -------------------------------------------------------------------------------- 1 | ============ 2 | Tensor utils 3 | ============ 4 | 5 | .. automodule:: aesara.tensor.utils 6 | :members: 7 | -------------------------------------------------------------------------------- /doc/troubleshoot/logreg_pydotprint_predict.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/troubleshoot/logreg_pydotprint_predict.png -------------------------------------------------------------------------------- /.github/SECURITY.md: -------------------------------------------------------------------------------- 1 | To report a security vulnerability to Aesara, please go to 2 | https://tidelift.com/security and see the instructions there. 3 | 4 | -------------------------------------------------------------------------------- /doc/sandbox/function.rst: -------------------------------------------------------------------------------- 1 | 2 | .. 
_function: 3 | 4 | ================== 5 | function interface 6 | ================== 7 | 8 | WRITEME 9 | 10 | -------------------------------------------------------------------------------- /doc/troubleshoot/d3viz/index_files/index_10_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/troubleshoot/d3viz/index_files/index_10_0.png -------------------------------------------------------------------------------- /doc/troubleshoot/d3viz/index_files/index_11_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/troubleshoot/d3viz/index_files/index_11_0.png -------------------------------------------------------------------------------- /doc/troubleshoot/d3viz/index_files/index_24_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/troubleshoot/d3viz/index_files/index_24_0.png -------------------------------------------------------------------------------- /doc/troubleshoot/d3viz/index_files/index_25_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/troubleshoot/d3viz/index_files/index_25_0.png -------------------------------------------------------------------------------- /doc/troubleshoot/logreg_pydotprint_prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/aesara-devs/aesara/HEAD/doc/troubleshoot/logreg_pydotprint_prediction.png -------------------------------------------------------------------------------- /aesara/typed_list/__init__.py: -------------------------------------------------------------------------------- 1 | from aesara.typed_list import rewriting 2 | from aesara.typed_list.basic import * 
3 | from aesara.typed_list.type import TypedListType 4 | -------------------------------------------------------------------------------- /.git_archival.txt: -------------------------------------------------------------------------------- 1 | node: 4bfce9589c2df43753f963eabf216c16f48fe655 2 | node-date: 2024-09-04T01:19:45-05:00 3 | describe-name: rel-2.9.4-1-g4bfce9589c 4 | ref-names: HEAD -> main 5 | -------------------------------------------------------------------------------- /doc/fundamentals/compilation/index.rst: -------------------------------------------------------------------------------- 1 | .. _fundamental_compilation: 2 | 3 | Compilation 4 | =========== 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | pipeline 10 | -------------------------------------------------------------------------------- /doc/sandbox/tensoroptools.rst: -------------------------------------------------------------------------------- 1 | 2 | .. _tensoroptools: 3 | 4 | ================ 5 | Tensor Op Tools 6 | ================ 7 | 8 | WRITEME - describe how to use Elemwise here 9 | 10 | -------------------------------------------------------------------------------- /doc/serializing/index.rst: -------------------------------------------------------------------------------- 1 | .. _reference_serializing: 2 | 3 | Serialize 4 | ========= 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | loading_and_saving 10 | pkl_utils 11 | -------------------------------------------------------------------------------- /doc/fundamentals/rewrites/index.rst: -------------------------------------------------------------------------------- 1 | .. _fundamental_rewrites: 2 | 3 | Rewrites 4 | ======== 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | graph_rewriting 10 | optimizations 11 | -------------------------------------------------------------------------------- /doc/reference/tensor/shared/index.rst: -------------------------------------------------------------------------------- 1 | .. 
_reference_shared: 2 | 3 | Shared variables 4 | ================ 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | aliasing 10 | shared 11 | -------------------------------------------------------------------------------- /tests/link/c/c_code/test_cenum.h: -------------------------------------------------------------------------------- 1 | #ifndef AESARA_TEST_CENUM 2 | #define AESARA_TEST_CENUM 3 | 4 | #define MILLION 1000000 5 | #define BILLION 1000000000 6 | #define TWO_BILLIONS 2000000000 7 | 8 | #endif 9 | -------------------------------------------------------------------------------- /doc/sandbox/functional.rst: -------------------------------------------------------------------------------- 1 | 2 | ========== 3 | Functional 4 | ========== 5 | 6 | Want to know about Aesara's `function design 7 | `? 8 | -------------------------------------------------------------------------------- /requirements-rtd.txt: -------------------------------------------------------------------------------- 1 | -e ./ 2 | sphinx>=1.3.0 3 | sphinx-book-theme 4 | sphinx-design 5 | jinja2<3.1.0 6 | pygments 7 | pytest 8 | numpy 9 | gnumpy 10 | pydot 11 | pydot-ng 12 | Cython 13 | scipy>=0.13 14 | -------------------------------------------------------------------------------- /doc/extend/backend/index.rst: -------------------------------------------------------------------------------- 1 | .. _extend_backend: 2 | 3 | Extend a backend 4 | ================ 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | creating_a_c_op 10 | creating_a_numba_jax_op 11 | ctype 12 | -------------------------------------------------------------------------------- /doc/compile/ops.rst: -------------------------------------------------------------------------------- 1 | ================================================== 2 | :mod:`ops` -- Some Common Ops and extra Ops stuff 3 | ================================================== 4 | 5 | .. 
automodule:: aesara.compile.ops 6 | :members: 7 | -------------------------------------------------------------------------------- /doc/reference/index.rst: -------------------------------------------------------------------------------- 1 | Building Aesara graphs 2 | ====================== 3 | 4 | .. toctree:: 5 | :caption: Reference 6 | 7 | tensor/index 8 | random/index 9 | loops/index 10 | gradient/index 11 | conditionals 12 | -------------------------------------------------------------------------------- /doc/reference/tensor/operations.window_functions.rst: -------------------------------------------------------------------------------- 1 | .. _reference_tensor_window_functions: 2 | .. currentmodule:: aesara.tensor 3 | 4 | Window functions 5 | ================ 6 | 7 | .. autosummary:: 8 | :toctree: _autosummary 9 | 10 | bartlett 11 | -------------------------------------------------------------------------------- /doc/fundamentals/graph/index.rst: -------------------------------------------------------------------------------- 1 | .. _fundamental_graph: 2 | 3 | Aesara Graphs 4 | ============= 5 | 6 | .. toctree:: 7 | :maxdepth: 1 8 | 9 | graphstructures 10 | op 11 | type 12 | using_params 13 | other_ops 14 | graph/index 15 | -------------------------------------------------------------------------------- /doc/sandbox/index.rst: -------------------------------------------------------------------------------- 1 | ========================================================= 2 | Sandbox, this documentation may or may not be out-of-date 3 | ========================================================= 4 | 5 | .. toctree:: 6 | :glob: 7 | 8 | * 9 | 10 | -------------------------------------------------------------------------------- /doc/sandbox/compilation.rst: -------------------------------------------------------------------------------- 1 | 2 | .. _compilation: 3 | 4 | ======================= 5 | Compilation and Linking 6 | ======================= 7 | 8 | .. 
index:: 9 | single: Linker 10 | 11 | .. _linker: 12 | 13 | Linker 14 | ====== 15 | 16 | WRITEME 17 | 18 | 19 | -------------------------------------------------------------------------------- /bin/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.warn( 4 | message= "Importing 'bin.aesara_cache' is deprecated. Import from " 5 | "'aesara.bin.aesara_cache' instead.", 6 | category=DeprecationWarning, 7 | stacklevel=2, # Raise the warning on the import line 8 | ) 9 | -------------------------------------------------------------------------------- /doc/reference/tensor/operations.statistics.rst: -------------------------------------------------------------------------------- 1 | .. _reference_tensor_statistics: 2 | .. currentmodule:: aesara.tensor 3 | 4 | Statistics 5 | ========== 6 | 7 | .. autosummary:: 8 | :toctree: _autosummary 9 | 10 | ptp 11 | mean 12 | std 13 | cov 14 | bincount 15 | -------------------------------------------------------------------------------- /aesara/sandbox/solve.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | from aesara.tensor.slinalg import solve # noqa 5 | 6 | message = ( 7 | "The module aesara.sandbox.solve will soon be deprecated.\n" 8 | "Please use tensor.slinalg.solve instead." 
9 | ) 10 | 11 | warnings.warn(message) 12 | -------------------------------------------------------------------------------- /aesara/scan/opt.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.scan.opt` is deprecated; use `aesara.scan.rewriting` instead.", 6 | DeprecationWarning, 7 | stacklevel=2, 8 | ) 9 | 10 | from aesara.scan.rewriting import * # noqa: F401 E402 F403 11 | -------------------------------------------------------------------------------- /aesara/graph/toolbox.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.graph.toolbox` is deprecated " 6 | "and has been renamed to `aesara.graph.features`", 7 | DeprecationWarning, 8 | stacklevel=2, 9 | ) 10 | 11 | from aesara.graph.features import * 12 | -------------------------------------------------------------------------------- /aesara/sparse/opt.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.sparse.opt` is deprecated; use `aesara.sparse.rewriting` instead.", 6 | DeprecationWarning, 7 | stacklevel=2, 8 | ) 9 | 10 | from aesara.sparse.rewriting import * # noqa: F401 E402 F403 11 | -------------------------------------------------------------------------------- /aesara/scalar/basic_scipy.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.scalar.basic_scipy` is deprecated " 6 | "and has been renamed to `aesara.scalar.math`", 7 | DeprecationWarning, 8 | stacklevel=2, 9 | ) 10 | 11 | from aesara.scalar.math import * 12 | -------------------------------------------------------------------------------- /doc/reference/tensor/operations.padding.rst: 
-------------------------------------------------------------------------------- 1 | .. _reference_tensor_padding: 2 | .. currentmodule:: aesara.tensor 3 | 4 | 5 | Padding tensors 6 | =============== 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | shape_padleft 12 | shape_padright 13 | shape_padaxis 14 | -------------------------------------------------------------------------------- /aesara/graph/unify.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.graph.unify` is deprecated; use `aesara.graph.rewriting.unify` instead.", 6 | DeprecationWarning, 7 | stacklevel=2, 8 | ) 9 | 10 | from aesara.graph.rewriting.unify import * # noqa: F401 E402 F403 11 | -------------------------------------------------------------------------------- /aesara/link/jax/jax_linker.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.link.jax.jax_linker` is deprecated " 6 | "and has been renamed to `aesara.link.jax.linker`", 7 | DeprecationWarning, 8 | stacklevel=2, 9 | ) 10 | 11 | from aesara.link.jax.linker import * 12 | -------------------------------------------------------------------------------- /aesara/tensor/random/rewriting/__init__.py: -------------------------------------------------------------------------------- 1 | # TODO: This is for backward-compatibility; remove when reasonable. 
2 | from aesara.tensor.random.rewriting.basic import * 3 | 4 | 5 | # isort: off 6 | 7 | # Register JAX specializations 8 | import aesara.tensor.random.rewriting.jax 9 | 10 | # isort: on 11 | -------------------------------------------------------------------------------- /aesara/graph/kanren.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.graph.kanren` is deprecated; use `aesara.graph.rewriting.kanren` instead.", 6 | DeprecationWarning, 7 | stacklevel=2, 8 | ) 9 | 10 | from aesara.graph.rewriting.kanren import * # noqa: F401 E402 F403 11 | -------------------------------------------------------------------------------- /aesara/tensor/math_opt.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.tensor.math_opt` is deprecated; use `aesara.tensor.rewriting.math` instead.", 6 | DeprecationWarning, 7 | stacklevel=2, 8 | ) 9 | 10 | from aesara.tensor.rewriting.math import * # noqa: F401 E402 F403 11 | -------------------------------------------------------------------------------- /aesara/tensor/nnet/opt.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.tensor.nnet.opt` is deprecated; use `aesara.tensor.nnet.rewriting` instead.", 6 | DeprecationWarning, 7 | stacklevel=2, 8 | ) 9 | 10 | from aesara.tensor.nnet.rewriting import * # noqa: F401 E402 F403 11 | -------------------------------------------------------------------------------- /aesara/tensor/random/__init__.py: -------------------------------------------------------------------------------- 1 | # Initialize `RandomVariable` rewrites 2 | import aesara.tensor.random.rewriting 3 | import aesara.tensor.random.utils 4 | from aesara.tensor.random.basic import * 5 | from aesara.tensor.random.op import 
RandomState, default_rng 6 | from aesara.tensor.random.utils import RandomStream 7 | -------------------------------------------------------------------------------- /aesara/assert_op.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.assert_op` is deprecated " 6 | "and its `Op`s have been moved to `aesara.raise_op`", 7 | DeprecationWarning, 8 | stacklevel=2, 9 | ) 10 | 11 | from aesara.raise_op import Assert, assert_op # noqa: F401 E402 12 | -------------------------------------------------------------------------------- /aesara/link/jax/jax_dispatch.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.link.jax.jax_dispatch` is deprecated " 6 | "and has been renamed to `aesara.link.jax.dispatch`", 7 | DeprecationWarning, 8 | stacklevel=2, 9 | ) 10 | 11 | from aesara.link.jax.dispatch import * 12 | -------------------------------------------------------------------------------- /doc/compile/index.rst: -------------------------------------------------------------------------------- 1 | 2 | .. _reference_compile: 3 | 4 | ======= 5 | Compile 6 | ======= 7 | 8 | .. toctree:: 9 | :maxdepth: 1 10 | 11 | shared 12 | function 13 | io 14 | ops 15 | mode 16 | rewrites/index 17 | modes 18 | debugmode 19 | nanguardmode 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /doc/sandbox/index2.rst: -------------------------------------------------------------------------------- 1 | 2 | .. _advanced: 3 | 4 | ==================================== 5 | Advanced Topics (under construction) 6 | ==================================== 7 | 8 | .. 
toctree:: 9 | :maxdepth: 2 10 | 11 | compilation 12 | ccodegen 13 | function 14 | debugging_with_stepmode 15 | 16 | -------------------------------------------------------------------------------- /aesara/tensor/random/opt.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.tensor.random.opt` is deprecated; use `aesara.tensor.random.rewriting` instead.", 6 | DeprecationWarning, 7 | stacklevel=2, 8 | ) 9 | 10 | from aesara.tensor.random.rewriting import * # noqa: F401 E402 F403 11 | -------------------------------------------------------------------------------- /doc/.static/fix_rtd.css: -------------------------------------------------------------------------------- 1 | /* work around https://github.com/snide/sphinx_rtd_theme/issues/149 */ 2 | .rst-content table.field-list .field-body { 3 | padding-top: 8px; 4 | } 5 | .rst-versions-up { 6 | cursor: pointer; 7 | display: inline; 8 | } 9 | .wy-side-nav-search>div.version { 10 | color: white; 11 | } 12 | -------------------------------------------------------------------------------- /aesara/tensor/subtensor_opt.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.tensor.subtensor_opt` is deprecated; use `aesara.tensor.rewriting.subtensor` instead.", 6 | DeprecationWarning, 7 | stacklevel=2, 8 | ) 9 | 10 | from aesara.tensor.rewriting.subtensor import * # noqa: F401 E402 F403 11 | -------------------------------------------------------------------------------- /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | 4 | # 5 | - package-ecosystem: "github-actions" 6 | directory: "/" 7 | schedule: 8 | interval: "daily" 9 | -------------------------------------------------------------------------------- 
/aesara/tensor/opt_uncanonicalize.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.tensor.opt_uncanonicalize` is deprecated; use `aesara.tensor.rewriting.uncanonicalize` instead.", 6 | DeprecationWarning, 7 | stacklevel=2, 8 | ) 9 | 10 | from aesara.tensor.rewriting.uncanonicalize import * # noqa: F401 E402 F403 11 | -------------------------------------------------------------------------------- /bin/aesara-cache: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import warnings 3 | 4 | from aesara.bin.aesara_cache import main 5 | 6 | 7 | warnings.warn( 8 | message= "Using this bin/aesara-cache script is deprecated. Use the plain " 9 | "aesara-cache command which is installed along with aesara.", 10 | category=DeprecationWarning, 11 | ) 12 | main() 13 | -------------------------------------------------------------------------------- /doc/css.inc: -------------------------------------------------------------------------------- 1 | .. _css: 2 | 3 | .. raw:: html 4 | 5 | 12 | 13 | .. role:: black 14 | .. role:: blue 15 | .. role:: red 16 | .. role:: green 17 | .. role:: pink 18 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | -e ./ 2 | filelock 3 | flake8>=3.8.4 4 | pep8 5 | pyflakes 6 | black>=20.8b1 7 | pytest-cov>=2.6.1 8 | coverage>=5.1 9 | pytest 10 | cython 11 | sympy 12 | jax>=0.4.1,<=0.4.16 13 | jaxlib>=0.4.1 14 | numba>=0.57.0,<0.58.0 15 | numba-scipy>=0.3.0 16 | diff-cover 17 | pre-commit 18 | isort 19 | packaging 20 | typing_extensions 21 | pytest-benchmark 22 | -------------------------------------------------------------------------------- /doc/troubleshoot/index.rst: -------------------------------------------------------------------------------- 1 | .. 
_troubleshoot: 2 | 3 | ============ 4 | Troubleshoot 5 | ============ 6 | 7 | .. toctree:: 8 | :caption: Troubleshoot 9 | 10 | printing 11 | breakpoints 12 | optimisations 13 | interacting_with_graph 14 | printing_drawing 15 | profiling 16 | debug_faq 17 | nan_tutorial 18 | troubleshooting 19 | d3viz/index 20 | -------------------------------------------------------------------------------- /doc/reference/tensor/operations.sorting.rst: -------------------------------------------------------------------------------- 1 | .. _reference_tensor_sorting: 2 | .. currentmodule:: aesara.tensor 3 | 4 | 5 | Sorting, searching and counting 6 | =============================== 7 | 8 | .. autosummary:: 9 | :toctree: _autosummary 10 | 11 | sort 12 | argsort 13 | argmax 14 | argmin 15 | nonzero 16 | flatnonzero 17 | searchsorted 18 | where 19 | -------------------------------------------------------------------------------- /bin/aesara_cache.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import warnings 4 | 5 | from aesara.bin.aesara_cache import * 6 | from aesara.bin.aesara_cache import _logger 7 | 8 | if __name__ == "__main__": 9 | warnings.warn( 10 | message= "Running 'aesara_cache.py' is deprecated. 
Use the aesara-cache " 11 | "script instead.", 12 | category=DeprecationWarning, 13 | ) 14 | main() 15 | -------------------------------------------------------------------------------- /doc/environment.yml: -------------------------------------------------------------------------------- 1 | name: aesara-docs 2 | channels: 3 | - conda-forge 4 | - nodefaults 5 | dependencies: 6 | - python=3.9 7 | - gcc_linux-64 8 | - gxx_linux-64 9 | - numpy 10 | - scipy 11 | - six 12 | - sphinx>=3,<7 13 | - sphinx-book-theme 14 | - sphinx-design 15 | - sphinx_rtd_theme 16 | - jinja2<3.1.0 17 | - mock 18 | - pillow 19 | - pip 20 | - pip: 21 | - -e ..[doc] 22 | -------------------------------------------------------------------------------- /doc/fundamentals/graph/graph/op.rst: -------------------------------------------------------------------------------- 1 | 2 | .. _libdoc_graph_op: 3 | 4 | ============================================================== 5 | :mod:`graph` -- Objects and functions for computational graphs 6 | ============================================================== 7 | 8 | .. automodule:: aesara.graph.op 9 | :platform: Unix, Windows 10 | :synopsis: Interface for types of symbolic variables 11 | :members: 12 | .. moduleauthor:: LISA 13 | -------------------------------------------------------------------------------- /doc/fundamentals/graph/graph/graph.rst: -------------------------------------------------------------------------------- 1 | .. _libdoc_graph_graph: 2 | 3 | ============================================== 4 | :mod:`graph` -- Interface for the Aesara graph 5 | ============================================== 6 | 7 | --------- 8 | Reference 9 | --------- 10 | 11 | .. automodule:: aesara.graph.basic 12 | :platform: Unix, Windows 13 | :synopsis: Interface for types of symbolic variables 14 | :members: 15 | .. 
moduleauthor:: LISA 16 | -------------------------------------------------------------------------------- /aesara/tensor/exceptions.py: -------------------------------------------------------------------------------- 1 | class ShapeError(Exception): 2 | """Raised when the shape cannot be computed.""" 3 | 4 | 5 | class NotScalarConstantError(Exception): 6 | """ 7 | Raised by get_scalar_constant_value if called on something that is 8 | not a scalar constant. 9 | """ 10 | 11 | 12 | class AdvancedIndexingError(TypeError): 13 | """ 14 | Raised when Subtensor is asked to perform advanced indexing. 15 | 16 | """ 17 | -------------------------------------------------------------------------------- /doc/fundamentals/graph/graph/type.rst: -------------------------------------------------------------------------------- 1 | .. _libdoc_graph_type: 2 | 3 | ================================================ 4 | :mod:`type` -- Interface for types of variables 5 | ================================================ 6 | 7 | --------- 8 | Reference 9 | --------- 10 | 11 | .. automodule:: aesara.graph.type 12 | :platform: Unix, Windows 13 | :synopsis: Interface for types of symbolic variables 14 | :members: 15 | .. 
moduleauthor:: LISA 16 | -------------------------------------------------------------------------------- /aesara/tensor/rewriting/__init__.py: -------------------------------------------------------------------------------- 1 | import aesara.tensor.rewriting.basic 2 | import aesara.tensor.rewriting.elemwise 3 | import aesara.tensor.rewriting.extra_ops 4 | 5 | # Register JAX specializations 6 | import aesara.tensor.rewriting.jax 7 | import aesara.tensor.rewriting.math 8 | import aesara.tensor.rewriting.shape 9 | import aesara.tensor.rewriting.special 10 | import aesara.tensor.rewriting.subtensor 11 | import aesara.tensor.rewriting.uncanonicalize 12 | -------------------------------------------------------------------------------- /tests/sparse/test_sharedvar.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy as sp 3 | 4 | import aesara 5 | from aesara.sparse.sharedvar import SparseTensorSharedVariable 6 | 7 | 8 | def test_shared_basic(): 9 | x = aesara.shared( 10 | sp.sparse.csr_matrix(np.eye(100), dtype=np.float64), name="blah", borrow=True 11 | ) 12 | 13 | assert isinstance(x, SparseTensorSharedVariable) 14 | assert x.format == "csr" 15 | assert x.dtype == "float64" 16 | -------------------------------------------------------------------------------- /tests/graph/test_utils.py: -------------------------------------------------------------------------------- 1 | import aesara 2 | from aesara.tensor.type import vector 3 | 4 | 5 | def test_stack_trace(): 6 | with aesara.config.change_flags(traceback__limit=1): 7 | v = vector() 8 | assert len(v.tag.trace) == 1 9 | assert len(v.tag.trace[0]) == 1 10 | 11 | with aesara.config.change_flags(traceback__limit=2): 12 | v = vector() 13 | assert len(v.tag.trace) == 1 14 | assert len(v.tag.trace[0]) == 2 15 | -------------------------------------------------------------------------------- /doc/fundamentals/graph/graph/index.rst: 
-------------------------------------------------------------------------------- 1 | 2 | .. _libdoc_graph: 3 | 4 | ================================================ 5 | :mod:`graph` -- Aesara Internals [doc TODO] 6 | ================================================ 7 | 8 | .. module:: graph 9 | :platform: Unix, Windows 10 | :synopsis: Aesara Internals 11 | .. moduleauthor:: LISA 12 | 13 | .. toctree:: 14 | :maxdepth: 1 15 | 16 | graph 17 | fgraph 18 | features 19 | op 20 | type 21 | utils 22 | -------------------------------------------------------------------------------- /aesara/link/c/exceptions.py: -------------------------------------------------------------------------------- 1 | from setuptools._distutils.errors import CompileError as BaseCompileError 2 | 3 | 4 | class MissingGXX(Exception): 5 | """ 6 | This error is raised when we try to generate c code, 7 | but g++ is not available. 8 | 9 | """ 10 | 11 | 12 | class CompileError(BaseCompileError): 13 | """This custom `Exception` prints compilation errors with their original 14 | formatting. 15 | """ 16 | 17 | def __str__(self): 18 | return self.args[0] 19 | -------------------------------------------------------------------------------- /doc/compile/profilemode.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. _profilemode: 4 | 5 | ================================================ 6 | :mod:`profilemode` -- profiling Aesara functions 7 | ================================================ 8 | 9 | 10 | .. module:: aesara.compile.profilemode 11 | :platform: Unix, Windows 12 | :synopsis: profiling Aesara functions with ProfileMode 13 | .. moduleauthor:: LISA 14 | 15 | Guide 16 | ===== 17 | 18 | .. note:: 19 | 20 | ProfileMode is removed. Use :attr:`config.profile` instead.
21 | -------------------------------------------------------------------------------- /doc/serializing/pkl_utils.rst: -------------------------------------------------------------------------------- 1 | 2 | .. _libdoc_misc: 3 | 4 | ================================================ 5 | :mod:`misc.pkl_utils` - Tools for serialization. 6 | ================================================ 7 | 8 | .. testsetup:: * 9 | 10 | from aesara.misc.pkl_utils import * 11 | 12 | .. autofunction:: aesara.misc.pkl_utils.dump 13 | 14 | .. autofunction:: aesara.misc.pkl_utils.load 15 | 16 | .. autoclass:: aesara.misc.pkl_utils.StripPickler 17 | 18 | .. seealso:: 19 | 20 | :ref:`tutorial_loadsave` 21 | -------------------------------------------------------------------------------- /aesara/tensor/basic_opt.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.tensor.basic_opt` is deprecated; use `aesara.tensor.rewriting.basic` instead.", 6 | DeprecationWarning, 7 | stacklevel=2, 8 | ) 9 | 10 | from aesara.tensor.rewriting.basic import * # noqa: F401 E402 F403 11 | from aesara.tensor.rewriting.elemwise import * # noqa: F401 E402 F403 12 | from aesara.tensor.rewriting.extra_ops import * # noqa: F401 E402 F403 13 | from aesara.tensor.rewriting.shape import * # noqa: F401 E402 F403 14 | -------------------------------------------------------------------------------- /doc/fundamentals/graph/graph/utils.rst: -------------------------------------------------------------------------------- 1 | .. _libdoc_graph_utils: 2 | 3 | ========================================================== 4 | :mod:`utils` -- Utilities functions operating on the graph 5 | ========================================================== 6 | 7 | .. testsetup:: * 8 | 9 | from aesara.graph.utils import * 10 | 11 | --------- 12 | Reference 13 | --------- 14 | 15 | .. 
automodule:: aesara.graph.utils 16 | :platform: Unix, Windows 17 | :synopsis: Utilities functions operating on the graph 18 | :members: 19 | .. moduleauthor:: LISA 20 | -------------------------------------------------------------------------------- /aesara/d3viz/css/d3-context-menu.css: -------------------------------------------------------------------------------- 1 | .d3-context-menu { 2 | position: absolute; 3 | display: none; 4 | background-color: #f2f2f2; 5 | border-radius: 4px; 6 | 7 | font-family: Arial, sans-serif; 8 | font-size: 14px; 9 | min-width: 50px; 10 | border: 1px solid #d4d4d4; 11 | 12 | z-index:1200; 13 | } 14 | 15 | .d3-context-menu ul { 16 | list-style-type: none; 17 | margin: 4px 0px; 18 | padding: 0px; 19 | cursor: default; 20 | } 21 | 22 | .d3-context-menu ul li { 23 | padding: 4px 16px; 24 | } 25 | 26 | .d3-context-menu ul li:hover { 27 | background-color: #4677f8; 28 | color: #fefefe; 29 | } 30 | -------------------------------------------------------------------------------- /aesara/graph/__init__.py: -------------------------------------------------------------------------------- 1 | """Graph objects and manipulation functions.""" 2 | 3 | # isort: off 4 | from aesara.graph.basic import ( 5 | Apply, 6 | Variable, 7 | Constant, 8 | graph_inputs, 9 | clone, 10 | clone_replace, 11 | ancestors, 12 | ) 13 | from aesara.graph.op import Op 14 | from aesara.graph.type import Type 15 | from aesara.graph.fg import FunctionGraph 16 | from aesara.graph.rewriting.basic import node_rewriter, graph_rewriter 17 | from aesara.graph.rewriting.utils import rewrite_graph 18 | from aesara.graph.rewriting.db import RewriteDatabaseQuery 19 | 20 | # isort: on 21 | -------------------------------------------------------------------------------- /aesara/link/numba/dispatch/__init__.py: -------------------------------------------------------------------------------- 1 | # isort: off 2 | from aesara.link.numba.dispatch.basic import ( 3 | numba_funcify, 4 | 
numba_const_convert, 5 | numba_njit, 6 | ) 7 | 8 | # Load dispatch specializations 9 | import aesara.link.numba.dispatch.scalar 10 | import aesara.link.numba.dispatch.tensor_basic 11 | import aesara.link.numba.dispatch.extra_ops 12 | import aesara.link.numba.dispatch.nlinalg 13 | import aesara.link.numba.dispatch.random 14 | import aesara.link.numba.dispatch.elemwise 15 | import aesara.link.numba.dispatch.scan 16 | import aesara.link.numba.dispatch.sparse 17 | 18 | # isort: on 19 | -------------------------------------------------------------------------------- /aesara/link/jax/dispatch/__init__.py: -------------------------------------------------------------------------------- 1 | # isort: off 2 | from aesara.link.jax.dispatch.basic import jax_funcify, jax_typify 3 | 4 | # Load dispatch specializations 5 | import aesara.link.jax.dispatch.scalar 6 | import aesara.link.jax.dispatch.tensor_basic 7 | import aesara.link.jax.dispatch.subtensor 8 | import aesara.link.jax.dispatch.shape 9 | import aesara.link.jax.dispatch.extra_ops 10 | import aesara.link.jax.dispatch.nlinalg 11 | import aesara.link.jax.dispatch.slinalg 12 | import aesara.link.jax.dispatch.random 13 | import aesara.link.jax.dispatch.elemwise 14 | import aesara.link.jax.dispatch.scan 15 | 16 | # isort: on 17 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | version: 2 2 | sphinx: 3 | configuration: doc/conf.py 4 | conda: 5 | environment: doc/environment.yml 6 | build: 7 | os: "ubuntu-20.04" 8 | tools: 9 | python: "mambaforge-4.10" 10 | jobs: 11 | post_checkout: 12 | # This is necessary for setuptools_scm to properly read the tags. The default 13 | # depth of 50 often leads to 'assert version is not None' AssertionError during 14 | # the pip install build process. Alternatively, we could use 15 | # 'git fetch --unshallow', but that is rather intensive. 
16 | - git fetch --depth 1000 17 | post_install: 18 | - pip list 19 | -------------------------------------------------------------------------------- /doc/troubleshoot/d3viz/examples/d3viz/css/d3-context-menu.css: -------------------------------------------------------------------------------- 1 | .d3-context-menu { 2 | position: absolute; 3 | display: none; 4 | background-color: #f2f2f2; 5 | border-radius: 4px; 6 | 7 | font-family: Arial, sans-serif; 8 | font-size: 14px; 9 | min-width: 50px; 10 | border: 1px solid #d4d4d4; 11 | 12 | z-index:1200; 13 | } 14 | 15 | .d3-context-menu ul { 16 | list-style-type: none; 17 | margin: 4px 0px; 18 | padding: 0px; 19 | cursor: default; 20 | } 21 | 22 | .d3-context-menu ul li { 23 | padding: 4px 16px; 24 | } 25 | 26 | .d3-context-menu ul li:hover { 27 | background-color: #4677f8; 28 | color: #fefefe; 29 | } 30 | -------------------------------------------------------------------------------- /doc/reference/tensor/sparse/sandbox.rst: -------------------------------------------------------------------------------- 1 | .. ../../../../aesara/sparse/sandbox/sp.py 2 | .. ../../../../aesara/sparse/basic/truedot.py 3 | 4 | .. _libdoc_sparse_sandbox: 5 | 6 | =================================================================== 7 | :mod:`sparse.sandbox` -- Sparse Op Sandbox 8 | =================================================================== 9 | 10 | .. module:: sparse.sandbox 11 | :platform: Unix, Windows 12 | :synopsis: Sparse Op Sandbox 13 | .. moduleauthor:: LISA 14 | 15 | API 16 | === 17 | 18 | .. automodule:: aesara.sparse.sandbox.sp 19 | :members: 20 | .. automodule:: aesara.sparse.sandbox.sp2 21 | :members: 22 | -------------------------------------------------------------------------------- /doc/reference/gradient/index.rst: -------------------------------------------------------------------------------- 1 | .. _reference_grad: 2 | 3 | Gradients 4 | ========= 5 | 6 | .. 
module:: grad 7 | 8 | The module :mod:`aesara.gradient` can be used to compute the gradient of an Aesara graph. 9 | 10 | .. autofunction:: aesara.gradient.grad 11 | 12 | This section of the documentation is organized as follows: 13 | 14 | * :ref:`Derivatives in Aesara ` gives a hands-on introduction to how to build gradient graphs in Aesara. 15 | * :ref:`Gradients API ` is an API reference for the :mod:`aesara.gradient` module. 16 | 17 | 18 | .. toctree:: 19 | :maxdepth: 1 20 | :hidden: 21 | 22 | gradient_tutorial 23 | gradient_api 24 | -------------------------------------------------------------------------------- /aesara/misc/check_blas_many.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python misc/check_blas.py --print_only 4 | 5 | cat /proc/cpuinfo |grep "model name" |uniq 6 | cat /proc/cpuinfo |grep processor 7 | free 8 | uname -a 9 | 10 | TIME_PREFIX=time 11 | VAR=OMP_NUM_THREADS 12 | echo "numpy gemm take=" 13 | AESARA_FLAGS=blas__ldflags= $TIME_PREFIX python misc/check_blas.py --quiet 14 | for i in 1 2 4 8 15 | do 16 | export $VAR=$i 17 | x=`$TIME_PREFIX python misc/check_blas.py --quiet` 18 | echo "aesara gemm with $VAR=$i took: ${x}s" 19 | done 20 | 21 | #Fred to test distro numpy at LISA: PYTHONPATH=/u/bastienf/repos:/usr/lib64/python2.5/site-packages AESARA_FLAGS=blas__ldflags= OMP_NUM_THREADS=8 time python misc/check_blas.py 22 | -------------------------------------------------------------------------------- /doc/reference/tensor/operations.discrete_fourier.rst: -------------------------------------------------------------------------------- 1 | .. _reference_tensor_discrete_fourier: 2 | 3 | Discrete Fourier Transform 4 | ========================== 5 | 6 | ``aesara.tensor.fourier`` 7 | ~~~~~~~~~~~~~~~~~~~~~~~~~ 8 | 9 | .. note:: 10 | 11 | You need to import the ``aesara.tensor.fourier`` module first. 12 | 13 | 14 | .. currentmodule:: aesara.tensor.fourier 15 | .
autosummary:: 16 | :toctree: _autosummary 17 | 18 | fft 19 | 20 | 21 | ``aesara.tensor.fft`` 22 | ~~~~~~~~~~~~~~~~~~~~~~ 23 | 24 | .. note:: 25 | 26 | You need to import the ``aesara.tensor.fft`` module first. 27 | 28 | 29 | .. currentmodule:: aesara.tensor.fft 30 | .. autosummary:: 31 | :toctree: _autosummary 32 | 33 | rfft 34 | irfft 35 | -------------------------------------------------------------------------------- /aesara/misc/latence_gpu_transfert.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import numpy as np 4 | 5 | import aesara 6 | 7 | 8 | y = aesara.tensor.type.fvector() 9 | x = aesara.shared(np.zeros(1, dtype="float32")) 10 | f1 = aesara.function([y], updates={x: y}) 11 | f2 = aesara.function([], x.transfer("cpu")) 12 | print(f1.maker.fgraph.toposort()) 13 | print(f2.maker.fgraph.toposort()) 14 | for i in (1, 10, 100, 1000, 10000, 100000, 1000000, 10000000): 15 | o = np.zeros(i, dtype="float32") 16 | t0 = time.perf_counter() 17 | f1(o) 18 | t1 = time.perf_counter() 19 | tf1 = t1 - t0 20 | t0 = time.perf_counter() 21 | f2() 22 | t1 = time.perf_counter() 23 | 24 | print("%8i %6.1f us %7.1f us" % (i, tf1 * 1e6, (t1 - t0) * 1e6)) 25 | -------------------------------------------------------------------------------- /codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | require_ci_to_pass: yes 3 | 4 | coverage: 5 | precision: 2 6 | round: down 7 | range: "70...100" 8 | status: 9 | project: 10 | default: 11 | # basic 12 | target: auto 13 | threshold: 1% 14 | base: auto 15 | patch: 16 | default: 17 | # basic 18 | target: 100% 19 | threshold: 1% 20 | base: auto 21 | 22 | comment: 23 | layout: "reach, diff, flags, files" 24 | behavior: default 25 | require_changes: false # if true: only post the comment if coverage changes 26 | require_base: no # [yes :: must have a base report to post] 27 | require_head: yes # [yes :: must have a head report
to post] 28 | branches: null # branch names that can post comment 29 | -------------------------------------------------------------------------------- /doc/reference/loops/index.rst: -------------------------------------------------------------------------------- 1 | .. _reference_scan: 2 | 3 | Loops 4 | ===== 5 | 6 | The module :mod:`aesara.scan` provides the basic functionality needed to do loops 7 | in Aesara. 8 | 9 | .. automodule:: aesara.scan 10 | 11 | `aesara.scan` 12 | ------------- 13 | 14 | .. autofunction:: aesara.scan 15 | :noindex: 16 | 17 | Other ways to create loops 18 | -------------------------- 19 | 20 | :func:`aesara.scan` comes with bells and whistles that are not always all necessary, which is why Aesara provides several other functions to create a :class:`Scan` operator: 21 | 22 | .. autofunction:: aesara.map 23 | .. autofunction:: aesara.reduce 24 | .. autofunction:: aesara.foldl 25 | .. autofunction:: aesara.foldr 26 | 27 | .. toctree:: 28 | :maxdepth: 1 29 | 30 | loops_api 31 | loops_tutorial 32 | scan_extend 33 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | select = C,E,F,W 3 | ignore = E203,E231,E501,E741,W503,W504,C901 4 | max-line-length = 88 5 | per-file-ignores = 6 | **/__init__.py:F401,E402,F403 7 | aesara/tensor/linalg.py:F401,F403 8 | aesara/scalar/basic_scipy.py:E402,F403,F401 9 | aesara/graph/toolbox.py:E402,F403,F401 10 | aesara/link/jax/jax_dispatch.py:E402,F403,F401 11 | aesara/link/jax/jax_linker.py:E402,F403,F401 12 | aesara/sparse/sandbox/sp2.py:F401 13 | tests/tensor/test_math_scipy.py:E402 14 | tests/sparse/test_basic.py:E402 15 | tests/sparse/test_opt.py:E402 16 | tests/sparse/test_sp2.py:E402 17 | tests/sparse/test_utils.py:E402,F401 18 | tests/sparse/sandbox/test_sp.py:E402,F401 19 | tests/scalar/test_basic_sympy.py:E402 20 | aesara/graph/rewriting/unify.py:F811 21 | exclude = 
22 | doc/ 23 | -------------------------------------------------------------------------------- /doc/reference/tensor/operations.binary_operations.rst: -------------------------------------------------------------------------------- 1 | .. _reference_tensor_binary_operations: 2 | .. currentmodule:: aesara.tensor 3 | 4 | Binary operations 5 | ================== 6 | 7 | .. note:: 8 | 9 | The bitwise operators take an integer as an input. 10 | 11 | .. autosummary:: 12 | :toctree: _autosummary 13 | 14 | bitwise_and 15 | bitwise_or 16 | bitwise_xor 17 | bitwise_not 18 | invert 19 | 20 | .. doctest:: 21 | :options: +SKIP 22 | 23 | >>> a, b = at.itensor3(), at.itensor3() # example inputs 24 | >>> a & b # at.and_(a,b) bitwise and (alias at.bitwise_and) 25 | >>> a ^ 1 # at.xor(a,1) bitwise xor (alias at.bitwise_xor) 26 | >>> a | b # at.or_(a,b) bitwise or (alias at.bitwise_or) 27 | >>> ~a # at.invert(a) bitwise invert (alias at.bitwise_not) 28 | -------------------------------------------------------------------------------- /DESCRIPTION.txt: -------------------------------------------------------------------------------- 1 | Aesara is a Python library that allows you to define, optimize, and efficiently evaluate mathematical expressions involving multi-dimensional arrays. It is built on top of NumPy_. Aesara features: 2 | 3 | * **tight integration with NumPy:** a similar interface to NumPy's. numpy.ndarrays are also used internally in Aesara-compiled functions. 4 | * **efficient symbolic differentiation:** Aesara can compute derivatives for functions of one or many inputs. 5 | * **speed and stability optimizations:** avoid nasty bugs when computing expressions such as log(1 + exp(x)) for large values of x. 6 | * **dynamic C code generation:** evaluate expressions faster. 7 | * **extensive unit-testing and self-verification:** includes tools for detecting and diagnosing bugs and/or potential problems. 8 | 9 | .. 
_NumPy: http://numpy.scipy.org/ 10 | -------------------------------------------------------------------------------- /aesara/scalar/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | from aesara.scalar.basic import * 4 | from aesara.scalar.math import * 5 | 6 | 7 | # isort: off 8 | from aesara.scalar.basic import DEPRECATED_NAMES as BASIC_DEPRECATIONS 9 | 10 | # isort: on 11 | 12 | DEPRECATED_NAMES: List[Tuple[str, str, object]] = BASIC_DEPRECATIONS 13 | 14 | 15 | def __getattr__(name): 16 | """Intercept module-level attribute access of deprecated symbols. 17 | 18 | Adapted from https://stackoverflow.com/a/55139609/3006474. 19 | 20 | """ 21 | from warnings import warn 22 | 23 | for old_name, msg, old_object in DEPRECATED_NAMES: 24 | if name == old_name: 25 | warn(msg, DeprecationWarning, stacklevel=2) 26 | return old_object 27 | 28 | raise AttributeError(f"module {__name__} has no attribute {name}") 29 | -------------------------------------------------------------------------------- /tests/tensor/test_xlogx.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | import aesara 4 | from aesara.tensor import as_tensor_variable 5 | from aesara.tensor.xlogx import xlogx, xlogy0 6 | from tests import unittest_tools as utt 7 | 8 | 9 | def test_xlogx(): 10 | x = as_tensor_variable([1, 0]) 11 | y = xlogx(x) 12 | f = aesara.function([], y) 13 | assert np.array_equal(f(), np.asarray([0, 0.0])) 14 | 15 | rng = np.random.default_rng(24982) 16 | utt.verify_grad(xlogx, [rng.random((3, 4))]) 17 | 18 | 19 | def test_xlogy0(): 20 | x = as_tensor_variable([1, 0]) 21 | y = as_tensor_variable([1, 0]) 22 | z = xlogy0(x, y) 23 | f = aesara.function([], z) 24 | assert np.array_equal(f(), np.asarray([0, 0.0])) 25 | 26 | rng = np.random.default_rng(24982) 27 | utt.verify_grad(xlogy0, [rng.random((3, 4)), rng.random((3, 4))]) 28 | 
-------------------------------------------------------------------------------- /aesara/sparse/utils.py: -------------------------------------------------------------------------------- 1 | from aesara.utils import hash_from_code 2 | 3 | 4 | def hash_from_sparse(data): 5 | # We need to hash the shapes as hash_from_code only hashes 6 | # the data buffer. Otherwise, this will cause problem with shapes like: 7 | # (1, 0) and (2, 0) 8 | # We also need to add the dtype to make the distinction between 9 | # uint32 and int32 of zeros with the same shape. 10 | 11 | # Python hash is not strong, so use sha256 instead. To avoid having a too 12 | # long hash, I call it again on the contatenation of all parts. 13 | return hash_from_code( 14 | hash_from_code(data.data) 15 | + hash_from_code(data.indices) 16 | + hash_from_code(data.indptr) 17 | + hash_from_code(str(data.shape)) 18 | + hash_from_code(str(data.dtype)) 19 | + hash_from_code(data.format) 20 | ) 21 | -------------------------------------------------------------------------------- /doc/sandbox/performance.rst: -------------------------------------------------------------------------------- 1 | 2 | =========== 3 | Performance 4 | =========== 5 | 6 | Aesara uses several tricks to obtain good performance: 7 | * common sub-expression elimination 8 | * [custom generated] C code for many operations 9 | * pre-allocation of temporary storage 10 | * loop fusion (which gcc normally can't do) 11 | 12 | On my neural net experiments for my course projects, I was getting around 10x 13 | speed improvements over basic numpy by using aesara. 14 | [More specific speed tests would be nice.] 15 | 16 | 17 | With a little work, Aesara could also implement more sophisticated 18 | rewrites: 19 | 20 | * automatic ordering of matrix multiplications 21 | * profile-based memory layout decisions (e.g. row-major vs. 
col-major) 22 | * gcc intrinsics to use MMX, SSE2 parallelism for faster element-wise arithmetic 23 | * conditional expressions 24 | -------------------------------------------------------------------------------- /aesara/link/c/c_code/aesara_mod_helper.h: -------------------------------------------------------------------------------- 1 | #ifndef AESARA_MOD_HELPER 2 | #define AESARA_MOD_HELPER 3 | 4 | #include 5 | 6 | #ifndef _WIN32 7 | #define MOD_PUBLIC __attribute__((visibility ("default"))) 8 | #else 9 | /* MOD_PUBLIC is only used in PyMODINIT_FUNC, which is declared 10 | * and implemented in mod.cu/cpp, not in headers, so dllexport 11 | * is always correct. */ 12 | #define MOD_PUBLIC __declspec( dllexport ) 13 | #endif 14 | 15 | #ifdef __cplusplus 16 | #define AESARA_EXTERN extern "C" 17 | #else 18 | #define AESARA_EXTERN 19 | #endif 20 | 21 | #if PY_MAJOR_VERSION < 3 22 | #define AESARA_RTYPE void 23 | #else 24 | #define AESARA_RTYPE PyObject * 25 | #endif 26 | 27 | /* We need to redefine PyMODINIT_FUNC to add MOD_PUBLIC in the middle */ 28 | #undef PyMODINIT_FUNC 29 | #define PyMODINIT_FUNC AESARA_EXTERN MOD_PUBLIC AESARA_RTYPE 30 | 31 | #endif 32 | -------------------------------------------------------------------------------- /aesara/typed_list/rewriting.py: -------------------------------------------------------------------------------- 1 | from aesara.compile import optdb 2 | from aesara.graph.rewriting.basic import WalkingGraphRewriter, node_rewriter 3 | from aesara.typed_list.basic import Append, Extend, Insert, Remove, Reverse 4 | 5 | 6 | @node_rewriter([Append, Extend, Insert, Reverse, Remove], inplace=True) 7 | def typed_list_inplace_rewrite(fgraph, node): 8 | if ( 9 | isinstance(node.op, (Append, Extend, Insert, Reverse, Remove)) 10 | and not node.op.inplace 11 | ): 12 | new_op = node.op.__class__(inplace=True) 13 | new_node = new_op(*node.inputs) 14 | return [new_node] 15 | return False 16 | 17 | 18 | optdb.register( 19 | 
"typed_list_inplace_rewrite", 20 | WalkingGraphRewriter( 21 | typed_list_inplace_rewrite, failure_callback=WalkingGraphRewriter.warn_inplace 22 | ), 23 | "fast_run", 24 | "inplace", 25 | position=60, 26 | ) 27 | -------------------------------------------------------------------------------- /aesara/graph/optdb.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.graph.optdb` is deprecated; use `aesara.graph.rewriting.db` instead.", 6 | DeprecationWarning, 7 | stacklevel=2, 8 | ) 9 | 10 | from aesara.graph.rewriting.db import * # noqa: F401 E402 F403 11 | from aesara.graph.rewriting.db import DEPRECATED_NAMES # noqa: F401 E402 F403 12 | 13 | 14 | def __getattr__(name): 15 | """Intercept module-level attribute access of deprecated symbols. 16 | 17 | Adapted from https://stackoverflow.com/a/55139609/3006474. 18 | 19 | """ 20 | global DEPRECATED_NAMES 21 | 22 | from warnings import warn 23 | 24 | for old_name, msg, old_object in DEPRECATED_NAMES: 25 | if name == old_name: 26 | warn(msg, DeprecationWarning, stacklevel=2) 27 | return old_object 28 | 29 | raise AttributeError(f"module {__name__} has no attribute {name}") 30 | -------------------------------------------------------------------------------- /aesara/graph/opt.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.graph.opt` is deprecated; use `aesara.graph.rewriting.basic` instead.", 6 | DeprecationWarning, 7 | stacklevel=2, 8 | ) 9 | 10 | from aesara.graph.rewriting.basic import * # noqa: F401 E402 F403 11 | from aesara.graph.rewriting.basic import DEPRECATED_NAMES # noqa: F401 E402 F403 12 | 13 | 14 | def __getattr__(name): 15 | """Intercept module-level attribute access of deprecated symbols. 16 | 17 | Adapted from https://stackoverflow.com/a/55139609/3006474. 
18 | 19 | """ 20 | global DEPRECATED_NAMES 21 | 22 | from warnings import warn 23 | 24 | for old_name, msg, old_object in DEPRECATED_NAMES: 25 | if name == old_name: 26 | warn(msg, DeprecationWarning, stacklevel=2) 27 | return old_object 28 | 29 | raise AttributeError(f"module {__name__} has no attribute {name}") 30 | -------------------------------------------------------------------------------- /doc/reference/tensor/operations.logic.rst: -------------------------------------------------------------------------------- 1 | .. _reference_tensor_logic: 2 | .. currentmodule:: aesara.tensor 3 | 4 | 5 | Logic functions 6 | =============== 7 | 8 | 9 | Truth value testing 10 | ------------------- 11 | 12 | .. autosummary:: 13 | :toctree: _autosummary 14 | 15 | allclose 16 | any 17 | 18 | Array contents 19 | -------------- 20 | 21 | .. autosummary:: 22 | :toctree: _autosummary 23 | 24 | isinf 25 | isnan 26 | 27 | Comparisons 28 | ----------- 29 | 30 | .. note:: 31 | 32 | The results of comparison operators are tensors with the ``bool`` dtype. 33 | 34 | .. danger:: 35 | 36 | The Python operator ``==`` does not work as a comparison operator in the usual sense in Aesara. Use :func:`eq` instead. 37 | 38 | .. 
autosummary:: 39 | :toctree: _autosummary 40 | 41 | lt 42 | le 43 | gt 44 | ge 45 | eq 46 | neq 47 | allclose 48 | isclose 49 | -------------------------------------------------------------------------------- /.github/workflows/update-pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: Pre-commit auto-update 2 | 3 | on: 4 | # On a schedule 5 | schedule: 6 | # Automatically run on 07:27 UTC every Monday 7 | - cron: '27 7 * * 1' 8 | # On demand 9 | workflow_dispatch: 10 | 11 | jobs: 12 | auto-update: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v3 16 | 17 | - uses: actions/setup-python@v4 18 | with: 19 | cache: "pip" 20 | cache-dependency-path: "pyproject.toml" 21 | 22 | - uses: browniebroke/pre-commit-autoupdate-action@main 23 | 24 | - uses: peter-evans/create-pull-request@v5 25 | with: 26 | token: ${{ secrets.GITHUB_TOKEN }} 27 | branch: update/pre-commit-hooks 28 | title: Update pre-commit hooks 29 | commit-message: "Update pre-commit hook versions" 30 | body: Update pre-commit hooks to their latest versions. 31 | -------------------------------------------------------------------------------- /aesara/graph/opt_utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.graph.opt_utils` is deprecated; use `aesara.graph.rewriting.utils` instead.", 6 | DeprecationWarning, 7 | stacklevel=2, 8 | ) 9 | 10 | from aesara.graph.rewriting.utils import * # noqa: F401 E402 F403 11 | from aesara.graph.rewriting.utils import DEPRECATED_NAMES # noqa: F401 E402 F403 12 | 13 | 14 | def __getattr__(name): 15 | """Intercept module-level attribute access of deprecated symbols. 16 | 17 | Adapted from https://stackoverflow.com/a/55139609/3006474.
18 | 19 | """ 20 | global DEPRECATED_NAMES 21 | 22 | from warnings import warn 23 | 24 | for old_name, msg, old_object in DEPRECATED_NAMES: 25 | if name == old_name: 26 | warn(msg, DeprecationWarning, stacklevel=2) 27 | return old_object 28 | 29 | raise AttributeError(f"module {__name__} has no attribute {name}") 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pkl 2 | _build 3 | __pycache__ 4 | .coverage 5 | *.linkinfo 6 | *.o 7 | *.orig 8 | *.pyc 9 | *.pyo 10 | *.so 11 | *.sw? 12 | *~ 13 | *.aux 14 | *.log 15 | *.nav 16 | *.out 17 | *.snm 18 | *.toc 19 | *.vrb 20 | *.nbc 21 | *.nbi 22 | .noseids 23 | *.DS_Store 24 | *.bak 25 | *.egg-info/ 26 | \#*\# 27 | build 28 | compiled/*.cpp 29 | core.* 30 | cutils_ext.cpp 31 | dist 32 | doc/.build/* 33 | !doc/.build/PLACEHOLDER 34 | doc/indexes/oplist.txt 35 | doc/indexes/typelist.txt 36 | /html 37 | pdf 38 | setuptools-*.egg 39 | aesara/generated_version.py 40 | aesara/generated_version.py.out 41 | distribute-*.egg 42 | distribute-*.tar.gz 43 | Aesara.suo 44 | .ipynb_checkpoints 45 | .pydevproject 46 | .ropeproject 47 | core 48 | .idea 49 | .vs 50 | .mypy_cache/ 51 | .pytest_cache/ 52 | /htmlcov/ 53 | 54 | aesara-venv/ 55 | /notebooks/Sandbox* 56 | .vscode/ 57 | testing-report.html 58 | coverage.xml 59 | .coverage.* 60 | aesara/_version.py 61 | _autosummary/ 62 | -------------------------------------------------------------------------------- /doc/fundamentals/graph/graph/features.rst: -------------------------------------------------------------------------------- 1 | .. _libdoc_graph_features: 2 | 3 | ================================================ 4 | :mod:`features` -- [doc TODO] 5 | ================================================ 6 | 7 | .. module:: aesara.graph.features 8 | :platform: Unix, Windows 9 | :synopsis: Aesara Internals 10 | .. 
moduleauthor:: LISA 11 | 12 | Guide 13 | ===== 14 | 15 | .. class:: Bookkeeper(object) 16 | 17 | .. class:: History(object) 18 | 19 | .. method:: revert(fgraph, checkpoint) 20 | 21 | Reverts the graph to whatever it was at the provided 22 | checkpoint (undoes all replacements). A checkpoint at any 23 | given time can be obtained using ``self.checkpoint()``. 24 | 25 | .. class:: Validator(object) 26 | 27 | .. class:: ReplaceValidate(History, Validator) 28 | 29 | .. method:: replace_validate(fgraph, var, new_var, reason=None) 30 | 31 | .. class:: NodeFinder(Bookkeeper) 32 | 33 | .. class:: PrintListener(object) 34 | -------------------------------------------------------------------------------- /aesara/sparse/sharedvar.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import scipy.sparse 4 | 5 | from aesara.compile import shared_constructor 6 | from aesara.sparse.basic import SparseTensorType, _sparse_py_operators 7 | from aesara.tensor.sharedvar import TensorSharedVariable 8 | 9 | 10 | class SparseTensorSharedVariable(TensorSharedVariable, _sparse_py_operators): 11 | @property 12 | def format(self): 13 | return self.type.format 14 | 15 | 16 | @shared_constructor.register(scipy.sparse.spmatrix) 17 | def sparse_constructor( 18 | value, name=None, strict=False, allow_downcast=None, borrow=False, format=None 19 | ): 20 | if format is None: 21 | format = value.format 22 | 23 | type = SparseTensorType(format=format, dtype=value.dtype) 24 | 25 | if not borrow: 26 | value = copy.deepcopy(value) 27 | 28 | return SparseTensorSharedVariable( 29 | type=type, value=value, strict=strict, allow_downcast=allow_downcast, name=name 30 | ) 31 | -------------------------------------------------------------------------------- /doc/help.rst: -------------------------------------------------------------------------------- 1 | Get help 2 | ======== 3 | 4 | If you need help with any aspect of Aesara, let us know!
We would be happy to hear from you. 5 | 6 | Bugs and feature requests 7 | ------------------------- 8 | 9 | .. tip:: 10 | Take a look at the :ref:`troubleshoot` section if you are encountering difficulties with Aesara. 11 | 12 | Report bugs on GitHub on the Aesara `Issues page `__. You can post feature requests on the `Discussions page `__. 13 | 14 | Help using Aesara 15 | ----------------- 16 | 17 | Don’t know how to do X with Aesara? Start a `discussion `__ on GitHub. 18 | 19 | Community 20 | --------- 21 | 22 | For real-time feedback or more general chat about Aesara, join the `Discord server `__ or the `Gitter room `__ to connect with the Aesara community. 23 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | # To use: 2 | # 3 | # $ conda env create -f environment.yml # `mamba` works too for this command 4 | # $ conda activate aesara-dev 5 | # 6 | name: aesara-dev 7 | channels: 8 | - conda-forge 9 | - numba 10 | dependencies: 11 | - python 12 | - compilers 13 | - numpy>=1.17.0,<2.0.0 14 | - scipy>=0.14,<=1.12.0 15 | - filelock 16 | - etuples 17 | - logical-unification 18 | - miniKanren 19 | - cons 20 | # Intel BLAS 21 | - mkl 22 | - mkl-service 23 | - libblas=*=*mkl 24 | # numba backend 25 | - numba>=0.57.0,<0.58.0 26 | - numba-scipy 27 | # For testing 28 | - coveralls 29 | - diff-cover 30 | - pytest 31 | - pytest-cov 32 | - pytest-xdist 33 | - pytest-benchmark 34 | # For building docs 35 | - sphinx>=1.3 36 | - sphinx_rtd_theme 37 | - pygments 38 | - pydot 39 | - ipython 40 | # developer tools 41 | - pre-commit 42 | - packaging 43 | - typing_extensions 44 | # optional 45 | - sympy 46 | - cython 47 | -------------------------------------------------------------------------------- /tests/sandbox/test_minimal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import 
pytest 3 | 4 | from aesara import function 5 | from aesara.sandbox.minimal import minimal 6 | from aesara.tensor.type import matrix, vector 7 | from tests import unittest_tools as utt 8 | 9 | 10 | @pytest.mark.skip(reason="Unfinished test") 11 | class TestMinimal: 12 | """ 13 | TODO: test dtype conversion 14 | TODO: test that invalid types are rejected by make_node 15 | TODO: test that each valid type for A and b works correctly 16 | """ 17 | 18 | def setup_method(self): 19 | self.rng = np.random.default_rng(utt.fetch_seed(666)) 20 | 21 | def test_minimal(self): 22 | A = matrix() 23 | b = vector() 24 | 25 | print("building function") 26 | f = function([A, b], minimal(A, A, b, b, A)) 27 | print("built") 28 | 29 | Aval = self.rng.standard_normal((5, 5)) 30 | bval = np.arange(5, dtype=float) 31 | f(Aval, bval) 32 | print("done") 33 | -------------------------------------------------------------------------------- /aesara/tensor/c_code/alt_blas_common.h: -------------------------------------------------------------------------------- 1 | /** C Implementation (with NumPy back-end) of BLAS functions used in Aesara. 2 | * Used instead of BLAS when Aesara flag ``blas__ldflags`` is empty. 3 | * This file contains some useful header code not templated. 4 | * File alt_blas_template.c currently contains template code for: 5 | * - [sd]gemm_ 6 | * - [sd]gemv_ 7 | * - [sd]dot_ 8 | **/ 9 | 10 | #define alt_fatal_error(message) { if (PyErr_Occurred()) PyErr_Print(); if(message != NULL) fprintf(stderr, message); exit(-1); } 11 | 12 | #define alt_trans_to_bool(trans) (*trans != 'N' && *trans != 'n') 13 | 14 | /**Template code for BLAS functions follows in file alt_blas_template.c 15 | * (as Python string to be used with old formatting). 16 | * PARAMETERS: 17 | * float_type: "float" or "double". 18 | * float_size: 4 for float32 (sgemm_), 8 for float64 (dgemm_). 19 | * npy_float: "NPY_FLOAT32" or "NPY_FLOAT64". 20 | * precision: "s" for single, "d" for double. 
21 | * See blas_headers.py for current use.**/ 22 | -------------------------------------------------------------------------------- /doc/compile/opfromgraph.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | .. _opfromgraph: 4 | 5 | ============= 6 | `OpFromGraph` 7 | ============= 8 | 9 | This page describes :class:`~aesara.compile.builders.OpFromGraph`, an 10 | `Op` constructor that allows one to 11 | encapsulate an Aesara graph in a single `Op`. 12 | 13 | This can be used to encapsulate some functionality in one block. It is 14 | useful for scaling Aesara compilation to large graphs that reuse the 15 | encapsulated functionality many times with different inputs. Because of 16 | this encapsulation, Aesara's compilation phase can be faster for graphs 17 | with many nodes. 18 | 19 | Using this for small graphs is not recommended, as it disables 20 | rewrites between what is inside the encapsulation and what is outside of it. 21 | 22 | .. note:: 23 | 24 | This feature has not been widely used so far. If you have any 25 | questions or comments, do not hesitate to contact us on GitHub Discussions. 26 | 27 | 28 | 29 | .. 
autoclass:: aesara.compile.builders.OpFromGraph 30 | -------------------------------------------------------------------------------- /environment-arm.yml: -------------------------------------------------------------------------------- 1 | # To use: 2 | # 3 | # $ conda env create -f environment-arm.yml # `mamba` works too for this command 4 | # $ conda activate aesara-dev 5 | # 6 | name: aesara-dev 7 | channels: 8 | - conda-forge 9 | - numba 10 | dependencies: 11 | - python 12 | - compilers 13 | - numpy>=1.17.0,<2.0.0 14 | - scipy>=0.14,<=1.12.0 15 | - filelock 16 | - etuples 17 | - logical-unification 18 | - miniKanren 19 | - cons 20 | # Non-Intel BLAS 21 | - nomkl 22 | - openblas 23 | - libblas=*=*openblas 24 | # numba backend 25 | - numba>=0.57.0,<0.58.0 26 | - llvmlite>=0.38.1 27 | - numba-scipy 28 | # For testing 29 | - coveralls 30 | - diff-cover 31 | - pytest 32 | - pytest-cov 33 | - pytest-xdist 34 | - pytest-benchmark 35 | # For building docs 36 | - sphinx>=1.3 37 | - sphinx_rtd_theme 38 | - pygments 39 | - pydot 40 | - ipython 41 | # developer tools 42 | - pre-commit 43 | - packaging 44 | - typing_extensions 45 | # optional 46 | - sympy 47 | - cython 48 | - jax<=0.4.16 49 | - jaxlib 50 | -------------------------------------------------------------------------------- /conftest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pytest 4 | 5 | 6 | def pytest_sessionstart(session): 7 | os.environ["AESARA_FLAGS"] = ",".join( 8 | [ 9 | os.environ.setdefault("AESARA_FLAGS", ""), 10 | "warn__ignore_bug_before=all,on_opt_error=raise,on_shape_error=raise,cmodule__warn_no_version=True", 11 | ] 12 | ) 13 | os.environ["NUMBA_BOUNDSCHECK"] = "1" 14 | 15 | 16 | def pytest_addoption(parser): 17 | parser.addoption( 18 | "--runslow", action="store_true", default=False, help="run slow tests" 19 | ) 20 | 21 | 22 | def pytest_configure(config): 23 | config.addinivalue_line("markers", "slow: mark test as slow to
run") 24 | 25 | 26 | def pytest_collection_modifyitems(config, items): 27 | if config.getoption("--runslow"): 28 | # --runslow given in cli: do not skip slow tests 29 | return 30 | skip_slow = pytest.mark.skip(reason="need --runslow option to run") 31 | for item in items: 32 | if "slow" in item.keywords: 33 | item.add_marker(skip_slow) 34 | -------------------------------------------------------------------------------- /aesara/version.py: -------------------------------------------------------------------------------- 1 | try: 2 | from aesara._version import __version__ as version 3 | except ImportError: 4 | raise RuntimeError( 5 | "Unable to find the version number that is generated when either building or " 6 | "installing from source. Please make sure that this Aesara has been properly " 7 | "installed, e.g. with\n\n pip install -e .\n" 8 | ) 9 | 10 | deprecated_names = [ 11 | "FALLBACK_VERSION", 12 | "full_version", 13 | "git_revision", 14 | "short_version", 15 | "release", 16 | ] 17 | 18 | 19 | def __getattr__(name): 20 | # (Called when the module attribute is not found.) 21 | if name in deprecated_names: 22 | raise RuntimeError( 23 | f"{name} was deprecated when migrating away from versioneer. If you " 24 | f"need it, please search for or open an issue on GitHub entitled " 25 | f"'Restore deprecated versioneer variable {name}'.", 26 | ) 27 | raise AttributeError(f"module {__name__!r} has no attribute {name!r}") 28 | 29 | 30 | __all__ = ["version"] 31 | -------------------------------------------------------------------------------- /doc/sandbox/software.rst: -------------------------------------------------------------------------------- 1 | =============== 2 | Other software 3 | =============== 4 | 5 | Other software to look at and maybe recommend to users: 6 | 7 | * [http://www.pytables.org/moin PyTables] - This is looking really 8 | promising for dataset storage and experiment logging... This might 9 | actually be useful for large data sets.
10 | * [http://matplotlib.sourceforge.net/ MatPlotLib] - visualization tools 11 | (plot curves interactively, like matlab's figure window) 12 | * [http://www.pythonware.com/products/pil/ PIL] - Python Image Library: 13 | write your matrices out in png! (Kinda a weird recommendation, I think) 14 | * [http://www.logilab.org/857 pylint] - Syntax checker for python to 15 | help beautify your code. (We'd be hypocrites to recommend this :) 16 | * [http://www.winpdb.org/ Winpdb] - A Platform Independent Python 17 | Debugger. (Except it doesn't really help you debug Aesara graphs) 18 | * [http://wiki.python.org/moin/IntegratedDevelopmentEnvironments Python 19 | Integrated Development Environments] - for all your coding needs 20 | -------------------------------------------------------------------------------- /aesara/misc/may_share_memory.py: -------------------------------------------------------------------------------- 1 | """ 2 | Function to detect memory sharing for ndarray AND sparse type. 3 | numpy version support only ndarray. 
4 | """ 5 | 6 | 7 | import numpy as np 8 | 9 | from aesara.tensor.type import TensorType 10 | 11 | 12 | try: 13 | import scipy.sparse 14 | 15 | from aesara.sparse.basic import SparseTensorType 16 | 17 | def _is_sparse(a): 18 | return scipy.sparse.issparse(a) 19 | 20 | except ImportError: 21 | 22 | def _is_sparse(a): 23 | return False 24 | 25 | 26 | def may_share_memory(a, b, raise_other_type=True): 27 | a_ndarray = isinstance(a, np.ndarray) 28 | b_ndarray = isinstance(b, np.ndarray) 29 | if a_ndarray and b_ndarray: 30 | return TensorType.may_share_memory(a, b) 31 | 32 | a_sparse = _is_sparse(a) 33 | b_sparse = _is_sparse(b) 34 | if not (a_ndarray or a_sparse) or not (b_ndarray or b_sparse): 35 | if raise_other_type: 36 | raise TypeError("may_share_memory support only ndarray" " and scipy.sparse") 37 | return False 38 | 39 | return SparseTensorType.may_share_memory(a, b) 40 | -------------------------------------------------------------------------------- /doc/reference/tensor/operations.rst: -------------------------------------------------------------------------------- 1 | .. _reference_tensor_operations: 2 | 3 | Tensor operations 4 | ================= 5 | 6 | .. currentmodule:: aesara.tensor 7 | 8 | The module :mod:`aesara.tensor` allows you to create tensors and express symbolic calculations using the NumPy and SciPy APIs. Docstrings are grouped by functionality and assume that ``aesara.tensor`` is imported as: 9 | 10 | >>> import aesara.tensor as at 11 | 12 | Aesara's API tries to mirror NumPy's, so in most cases it is safe to assume that the basic NumPy array functions and methods will be available. If you find an inconsistency, or if a function is missing, please open an `Issue `__. 13 | 14 | .. 
toctree:: 15 | :maxdepth: 1 16 | 17 | operations.tensor_creation 18 | operations.tensor_manipulation 19 | operations.indexing 20 | operations.binary_operations 21 | operations.discrete_fourier 22 | operations.linalg 23 | operations.logic 24 | operations.mathematical_functions 25 | operations.padding 26 | operations.sorting 27 | operations.statistics 28 | operations.window_functions 29 | -------------------------------------------------------------------------------- /doc/fundamentals/graph/graph/fgraph.rst: -------------------------------------------------------------------------------- 1 | 2 | .. _libdoc_graph_fgraph: 3 | 4 | ================================================ 5 | :mod:`fg` -- Graph Container [doc TODO] 6 | ================================================ 7 | 8 | .. module:: aesara.graph.fg 9 | :platform: Unix, Windows 10 | :synopsis: Aesara Internals 11 | .. moduleauthor:: LISA 12 | 13 | 14 | .. _fgraph: 15 | 16 | FunctionGraph 17 | ------------- 18 | 19 | .. autoclass:: aesara.graph.fg.FunctionGraph 20 | :members: 21 | 22 | ***TODO*** 23 | 24 | .. note:: FunctionGraph(inputs, outputs) clones the inputs by 25 | default. To avoid this behavior, add the parameter 26 | clone=False. This is needed as we do not want cached constants 27 | in fgraph. 28 | 29 | .. _libdoc_graph_fgraphfeature: 30 | 31 | .. _fgraphfeature: 32 | 33 | FunctionGraph Features 34 | ---------------------- 35 | 36 | .. autoclass:: aesara.graph.features.Feature 37 | :members: 38 | 39 | .. _libdoc_graph_fgraphfeaturelist: 40 | 41 | FunctionGraph Feature List 42 | ^^^^^^^^^^^^^^^^^^^^^^^^^^ 43 | * ReplaceValidate 44 | * DestroyHandler 45 | -------------------------------------------------------------------------------- /doc/compile/rewrites/index.rst: -------------------------------------------------------------------------------- 1 | .. 
_reference_rewrites: 2 | 3 | Rewrites 4 | ======== 5 | 6 | This section of the documentation references all the rewrites that can be applied during the compilation of an Aesara graph. 7 | 8 | Tensor rewrites 9 | --------------- 10 | 11 | These rewrites are implemented in the module :mod:`aesara.tensor.rewriting.basic`. 12 | 13 | .. automodule:: aesara.tensor.rewriting.basic 14 | :members: 15 | 16 | Indexing 17 | -------- 18 | 19 | .. automodule:: aesara.tensor.rewriting.subtensor 20 | :members: 21 | 22 | Shape 23 | ----- 24 | 25 | .. automodule:: aesara.tensor.rewriting.shape 26 | :members: 27 | 28 | Mathematical operations 29 | ----------------------- 30 | 31 | .. automodule:: aesara.tensor.rewriting.math 32 | :members: 33 | 34 | .. automodule:: aesara.tensor.rewriting.elemwise 35 | :members: 36 | 37 | .. automodule:: aesara.tensor.rewriting.extra_ops 38 | :members: 39 | 40 | .. automodule:: aesara.tensor.rewriting.special 41 | :members: 42 | 43 | Random variables 44 | ---------------- 45 | 46 | .. automodule:: aesara.tensor.random.rewriting.basic 47 | :members: 48 | -------------------------------------------------------------------------------- /aesara/link/c/cvm.py: -------------------------------------------------------------------------------- 1 | from aesara.configdefaults import config 2 | from aesara.link.c.exceptions import MissingGXX 3 | from aesara.link.vm import VM 4 | 5 | 6 | try: 7 | # If cxx is explicitly set to an empty string, we do not want to import 8 | # either lazy-linker C code or lazy-linker compiled C code from the cache. 9 | if not config.cxx: 10 | raise MissingGXX( 11 | "lazylinker will not be imported if aesara.config.cxx is not set."
12 | ) 13 | from aesara.link.c.lazylinker_c import CLazyLinker 14 | 15 | class CVM(CLazyLinker, VM): 16 | def __init__(self, fgraph, *args, **kwargs): 17 | self.fgraph = fgraph 18 | # skip VM.__init__ 19 | CLazyLinker.__init__(self, *args, **kwargs) 20 | 21 | except ImportError: 22 | pass 23 | except (OSError, MissingGXX): 24 | # OSError happens when g++ is not installed. In that case, we 25 | # already changed the default linker to something else then CVM. 26 | # Currently this is the py linker. 27 | # Here we assert that the default linker is not cvm. 28 | if config._config_var_dict["linker"].default.startswith("cvm"): 29 | raise 30 | -------------------------------------------------------------------------------- /tests/scalar/test_basic_sympy.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | import aesara 4 | from aesara.graph.fg import FunctionGraph 5 | from aesara.link.c.basic import CLinker 6 | from aesara.scalar.basic import floats 7 | from aesara.scalar.basic_sympy import SymPyCCode 8 | from tests.link.test_link import make_function 9 | 10 | 11 | sympy = pytest.importorskip("sympy") 12 | 13 | 14 | xs = sympy.Symbol("x") 15 | ys = sympy.Symbol("y") 16 | 17 | xt, yt = floats("xy") 18 | 19 | 20 | @pytest.mark.skipif(not aesara.config.cxx, reason="Need cxx for this test") 21 | def test_SymPyCCode(): 22 | op = SymPyCCode([xs, ys], xs + ys) 23 | e = op(xt, yt) 24 | g = FunctionGraph([xt, yt], [e]) 25 | fn = make_function(CLinker().accept(g)) 26 | assert fn(1.0, 2.0) == 3.0 27 | 28 | 29 | def test_grad(): 30 | op = SymPyCCode([xs], xs**2) 31 | zt = op(xt) 32 | ztprime = aesara.grad(zt, xt) 33 | assert ztprime.owner.op.expr == 2 * xs 34 | 35 | 36 | def test_multivar_grad(): 37 | op = SymPyCCode([xs, ys], xs**2 + ys**3) 38 | zt = op(xt, yt) 39 | dzdx, dzdy = aesara.grad(zt, [xt, yt]) 40 | assert dzdx.owner.op.expr == 2 * xs 41 | assert dzdy.owner.op.expr == 3 * ys**2 42 | 
-------------------------------------------------------------------------------- /doc/reference/tensor/index.rst: -------------------------------------------------------------------------------- 1 | .. _reference_tensor: 2 | 3 | Tensors 4 | ======= 5 | 6 | 7 | Aesara supports symbolic tensor expressions. When you type, 8 | 9 | >>> import aesara.tensor as at 10 | >>> x = at.fmatrix() 11 | 12 | the ``x`` is a :class:`TensorVariable` instance. 13 | 14 | The ``at.fmatrix`` object itself is an instance of :class:`TensorType`. 15 | Aesara knows what type of variable ``x`` is because ``x.type`` 16 | points back to ``at.fmatrix``. 17 | 18 | This section of the documentation is organized as follows: 19 | 20 | * :ref:`Tensor objects ` page explains the various ways in which a tensor variable can be created, the attributes and methods of :class:`TensorVariable` and :class:`TensorType`. 21 | * :ref:`Tensor creation ` describes all the ways one can create a :class:`TensorVariable`. 22 | * :ref:`Tensor operations ` lists the available operations on :class:`TensorVariable`. 23 | 24 | .. 
toctree:: 25 | :maxdepth: 1 26 | :hidden: 27 | 28 | tensor 29 | Creation 30 | Operations 31 | shapes 32 | sparse/index 33 | shared/index 34 | Utils 35 | -------------------------------------------------------------------------------- /aesara/link/numba/linker.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Any 2 | 3 | import numpy as np 4 | 5 | import aesara 6 | from aesara.link.basic import JITLinker 7 | 8 | 9 | if TYPE_CHECKING: 10 | from aesara.graph.basic import Variable 11 | 12 | 13 | class NumbaLinker(JITLinker): 14 | """A `Linker` that JIT-compiles NumPy-based operations using Numba.""" 15 | 16 | def output_filter(self, var: "Variable", out: Any) -> Any: 17 | if not isinstance(var, np.ndarray) and isinstance( 18 | var.type, aesara.tensor.TensorType 19 | ): 20 | return var.type.filter(out, allow_downcast=True) 21 | 22 | return out 23 | 24 | def fgraph_convert(self, fgraph, **kwargs): 25 | from aesara.link.numba.dispatch import numba_funcify 26 | 27 | return numba_funcify(fgraph, **kwargs) 28 | 29 | def jit_compile(self, fn): 30 | from aesara.link.numba.dispatch import numba_njit 31 | 32 | jitted_fn = numba_njit(fn) 33 | return jitted_fn 34 | 35 | def create_thunk_inputs(self, storage_map): 36 | thunk_inputs = [] 37 | for n in self.fgraph.inputs: 38 | thunk_inputs.append(storage_map[n]) 39 | 40 | return thunk_inputs 41 | -------------------------------------------------------------------------------- /doc/acknowledgement.rst: -------------------------------------------------------------------------------- 1 | .. _acknowledgement: 2 | 3 | 4 | Acknowledgements 5 | ================ 6 | 7 | .. note: 8 | 9 | This page is in construction. We are missing sources. 10 | 11 | 12 | * The developers of `NumPy `_. Theano is based on its ndarray object and uses much of its implementation. 13 | * The developers of `SciPy `_. Our sparse matrix support uses their sparse matrix objects. 
We also reuse other parts. 14 | * The developers of `Theano `_ 15 | * All `Aesara contributors `_. 16 | * All Theano users that have given us feedback. 17 | * Our random number generator implementation on CPU and GPU uses the MRG31k3p algorithm that is described in: 18 | 19 | P. L'Ecuyer and R. Touzin, `Fast Combined Multiple Recursive Generators with Multipliers of the form a = +/- 2^d +/- 2^e `_, Proceedings of the 2000 Winter Simulation Conference, Dec. 2000, 683--689. 20 | 21 | We were authorized by Pierre L'Ecuyer to copy/modify his Java implementation in the `SSJ `_ software and to relicense it under BSD 3-Clauses in Theano. 22 | -------------------------------------------------------------------------------- /aesara/sparse/__init__.py: -------------------------------------------------------------------------------- 1 | from aesara.sparse import rewriting, sharedvar 2 | from aesara.sparse.basic import * 3 | from aesara.sparse.sharedvar import sparse_constructor as shared 4 | from aesara.sparse.type import SparseTensorType, _is_sparse 5 | 6 | 7 | def sparse_grad(var): 8 | """This function return a new variable whose gradient will be 9 | stored in a sparse format instead of dense. 10 | 11 | Currently only variable created by AdvancedSubtensor1 is supported. 12 | i.e. a_tensor_var[an_int_vector]. 13 | 14 | .. 
versionadded:: 0.6rc4 15 | """ 16 | from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1 17 | 18 | if var.owner is None or not isinstance( 19 | var.owner.op, (AdvancedSubtensor, AdvancedSubtensor1) 20 | ): 21 | raise TypeError( 22 | "Sparse gradient is only implemented for AdvancedSubtensor and AdvancedSubtensor1" 23 | ) 24 | 25 | x = var.owner.inputs[0] 26 | indices = var.owner.inputs[1:] 27 | 28 | if len(indices) > 1: 29 | raise TypeError( 30 | "Sparse gradient is only implemented for single advanced indexing" 31 | ) 32 | 33 | ret = AdvancedSubtensor1(sparse_grad=True)(x, indices[0]) 34 | return ret 35 | -------------------------------------------------------------------------------- /aesara/link/jax/dispatch/slinalg.py: -------------------------------------------------------------------------------- 1 | import jax 2 | 3 | from aesara.link.jax.dispatch.basic import jax_funcify 4 | from aesara.tensor.slinalg import Cholesky, Solve, SolveTriangular 5 | 6 | 7 | @jax_funcify.register(Cholesky) 8 | def jax_funcify_Cholesky(op, **kwargs): 9 | lower = op.lower 10 | 11 | def cholesky(a, lower=lower): 12 | return jax.scipy.linalg.cholesky(a, lower=lower).astype(a.dtype) 13 | 14 | return cholesky 15 | 16 | 17 | @jax_funcify.register(Solve) 18 | def jax_funcify_Solve(op, **kwargs): 19 | if op.assume_a != "gen" and op.lower: 20 | lower = True 21 | else: 22 | lower = False 23 | 24 | def solve(a, b, lower=lower): 25 | return jax.scipy.linalg.solve(a, b, lower=lower) 26 | 27 | return solve 28 | 29 | 30 | @jax_funcify.register(SolveTriangular) 31 | def jax_funcify_SolveTriangular(op, **kwargs): 32 | lower = op.lower 33 | trans = op.trans 34 | unit_diagonal = op.unit_diagonal 35 | check_finite = op.check_finite 36 | 37 | def solve_triangular(A, b): 38 | return jax.scipy.linalg.solve_triangular( 39 | A, 40 | b, 41 | lower=lower, 42 | trans=trans, 43 | unit_diagonal=unit_diagonal, 44 | check_finite=check_finite, 45 | ) 46 | 47 | return solve_triangular 48 | 
-------------------------------------------------------------------------------- /aesara/graph/null_type.py: -------------------------------------------------------------------------------- 1 | from aesara.graph.type import Type 2 | 3 | 4 | class NullType(Type): 5 | """ 6 | A type that allows no values. 7 | 8 | Used to represent expressions 9 | that are undefined, either because they do not exist mathematically 10 | or because the code to generate the expression has not been 11 | implemented yet. 12 | 13 | Parameters 14 | ---------- 15 | why_null : str 16 | A string explaining why this variable can't take on any values. 17 | 18 | """ 19 | 20 | def __init__(self, why_null="(no explanation given)"): 21 | self.why_null = why_null 22 | 23 | def filter(self, data, strict=False, allow_downcast=None): 24 | raise ValueError("No values may be assigned to a NullType") 25 | 26 | def filter_variable(self, other, allow_convert=True): 27 | raise ValueError("No values may be assigned to a NullType") 28 | 29 | def may_share_memory(a, b): 30 | return False 31 | 32 | def values_eq(self, a, b, force_same_dtype=True): 33 | raise ValueError("NullType has no values to compare") 34 | 35 | def __eq__(self, other): 36 | return type(self) == type(other) 37 | 38 | def __hash__(self): 39 | return hash(type(self)) 40 | 41 | def __str__(self): 42 | return "NullType" 43 | 44 | 45 | null_type = NullType() 46 | -------------------------------------------------------------------------------- /doc/core_development_guide.rst: -------------------------------------------------------------------------------- 1 | :orphan: 2 | 3 | Core Development Guide 4 | ======================= 5 | 6 | The documentation of the core components of Aesara is still a work in 7 | progress. For now this is a list of bits and pieces on the subject, 8 | some of them might be outdated though: 9 | 10 | 11 | * :ref:`aesara_type` -- Tutorial for writing a new type in Aesara. 
It 12 | introduces the basics concerning Aesara datatypes. 13 | 14 | * :ref:`aesara_ctype` -- Tutorial on how to make your type C-friendly. 15 | 16 | * :ref:`views_and_inplace` -- This is somewhere between extending Aesara and 17 | describing how Aesara works internally; it talks about views and inplace 18 | operations. 19 | 20 | * :ref:`graph_rewriting` -- Tutorial on how graph rewriting works in Aesara. 21 | 22 | * :ref:`pipeline` -- Describes the steps of compiling an Aesara Function. 23 | 24 | * :ref:`graphstructures` -- Describes the symbolic graphs generated by 25 | :mod:`aesara.scan`. 26 | 27 | * :ref:`sandbox_debugging_step_mode` -- How to step through the execution of 28 | an Aesara function and print the inputs and outputs of each op. 29 | 30 | * :ref:`sandbox_elemwise` -- Description of element wise operations. 31 | 32 | * :ref:`sandbox_randnb` -- Description of how Aesara deals with random 33 | numbers. 34 | 35 | * :ref:`sparse` -- Description of the ``sparse`` type in Aesara. 
36 | -------------------------------------------------------------------------------- /aesara/d3viz/js/d3-context-menu.js: -------------------------------------------------------------------------------- 1 | d3.contextMenu = function (menu, openCallback) { 2 | 3 | // create the div element that will hold the context menu 4 | d3.selectAll('.d3-context-menu').data([1]) 5 | .enter() 6 | .append('div') 7 | .attr('class', 'd3-context-menu'); 8 | 9 | // close menu 10 | d3.select('body').on('click.d3-context-menu', function() { 11 | d3.select('.d3-context-menu').style('display', 'none'); 12 | }); 13 | 14 | // this gets executed when a contextmenu event occurs 15 | return function(data, index) { 16 | var elm = this; 17 | 18 | d3.selectAll('.d3-context-menu').html(''); 19 | var list = d3.selectAll('.d3-context-menu').append('ul'); 20 | list.selectAll('li').data(menu).enter() 21 | .append('li') 22 | .html(function(d) { 23 | return d.title; 24 | }) 25 | .on('click', function(d, i) { 26 | d.action(elm, data, index); 27 | d3.select('.d3-context-menu').style('display', 'none'); 28 | }); 29 | 30 | // the openCallback allows an action to fire before the menu is displayed 31 | // an example usage would be closing a tooltip 32 | if (openCallback) openCallback(data, index); 33 | 34 | // display context menu 35 | d3.select('.d3-context-menu') 36 | .style('left', (d3.event.pageX - 2) + 'px') 37 | .style('top', (d3.event.pageY - 2) + 'px') 38 | .style('display', 'block'); 39 | 40 | d3.event.preventDefault(); 41 | }; 42 | }; 43 | -------------------------------------------------------------------------------- /doc/reference/tensor/operations.indexing.rst: -------------------------------------------------------------------------------- 1 | .. _reference_tensor_indexing: 2 | .. currentmodule:: aesara.tensor 3 | 4 | 5 | Indexing 6 | -------- 7 | 8 | Like NumPy, Aesara distinguishes between *basic* and *advanced* indexing. 
9 | Aesara fully supports basic indexing 10 | (see `NumPy's indexing `_) 11 | and `integer advanced indexing 12 | `_. 13 | 14 | Index-assignment is *not* supported. If you want to do something like ``a[5] 15 | = b`` or ``a[5]+=b``, see :func:`aesara.tensor.subtensor.set_subtensor` and 16 | :func:`aesara.tensor.subtensor.inc_subtensor` below. 17 | 18 | 19 | Generating index tensors 20 | ~~~~~~~~~~~~~~~~~~~~~~~~ 21 | 22 | .. autosummary:: 23 | :toctree: _autosummary 24 | 25 | where 26 | ogrid 27 | ravel_multi_index 28 | unravel_index 29 | tril_indices_from 30 | tril_indices 31 | triu_indices 32 | triu_indices_from 33 | 34 | 35 | Indexing-like operations 36 | ~~~~~~~~~~~~~~~~~~~~~~~~ 37 | 38 | .. autosummary:: 39 | :toctree: _autosummary 40 | 41 | take 42 | take_along_axis 43 | choose 44 | compress 45 | diag 46 | diagonal 47 | 48 | Inserting data into tensors 49 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~ 50 | 51 | 52 | .. autosummary:: 53 | :toctree: _autosummary 54 | 55 | set_subtensor 56 | inc_subtensor 57 | -------------------------------------------------------------------------------- /doc/troubleshoot/d3viz/examples/d3viz/js/d3-context-menu.js: -------------------------------------------------------------------------------- 1 | d3.contextMenu = function (menu, openCallback) { 2 | 3 | // create the div element that will hold the context menu 4 | d3.selectAll('.d3-context-menu').data([1]) 5 | .enter() 6 | .append('div') 7 | .attr('class', 'd3-context-menu'); 8 | 9 | // close menu 10 | d3.select('body').on('click.d3-context-menu', function() { 11 | d3.select('.d3-context-menu').style('display', 'none'); 12 | }); 13 | 14 | // this gets executed when a contextmenu event occurs 15 | return function(data, index) { 16 | var elm = this; 17 | 18 | d3.selectAll('.d3-context-menu').html(''); 19 | var list = d3.selectAll('.d3-context-menu').append('ul'); 20 | list.selectAll('li').data(menu).enter() 21 | .append('li') 22 | .html(function(d) { 23 | return d.title; 24 | }) 25 | .on('click', 
function(d, i) { 26 | d.action(elm, data, index); 27 | d3.select('.d3-context-menu').style('display', 'none'); 28 | }); 29 | 30 | // the openCallback allows an action to fire before the menu is displayed 31 | // an example usage would be closing a tooltip 32 | if (openCallback) openCallback(data, index); 33 | 34 | // display context menu 35 | d3.select('.d3-context-menu') 36 | .style('left', (d3.event.pageX - 2) + 'px') 37 | .style('top', (d3.event.pageY - 2) + 'px') 38 | .style('display', 'block'); 39 | 40 | d3.event.preventDefault(); 41 | }; 42 | }; 43 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | 2 | **Thank you for opening a PR!** 3 | 4 | Here are a few important guidelines and requirements to check before your PR can be merged: 5 | + [ ] There is an informative high-level description of the changes. 6 | + [ ] The description and/or commit message(s) references the relevant GitHub issue(s). 7 | + [ ] [`pre-commit`](https://pre-commit.com/#installation) is installed and [set up](https://pre-commit.com/#3-install-the-git-hook-scripts). 8 | + [ ] The commit messages follow [these guidelines](https://tbaggery.com/2008/04/19/a-note-about-git-commit-messages.html). 9 | + [ ] The commits correspond to [_relevant logical changes_](https://wiki.openstack.org/wiki/GitCommitMessages#Structural_split_of_changes), and there are **no commits that fix changes introduced by other commits in the same branch/BR**. 10 | + [ ] There are tests covering the changes introduced in the PR. 11 | 12 | Don't worry, your PR doesn't need to be in perfect order to submit it. As development progresses and/or reviewers request changes, you can always [rewrite the history](https://git-scm.com/book/en/v2/Git-Tools-Rewriting-History#_rewriting_history) of your feature/PR branches. 
13 | 14 | If your PR is an ongoing effort and you would like to involve us in the process, simply make it a [draft PR](https://docs.github.com/en/free-pro-team@latest/github/collaborating-with-issues-and-pull-requests/about-pull-requests#draft-pull-requests). 15 | -------------------------------------------------------------------------------- /doc/generate_dtype_tensor_table.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | letters = [ 4 | ('b', 'int8'), 5 | ('w', 'int16'), 6 | ('i', 'int32'), 7 | ('l', 'int64'), 8 | ('d', 'float64'), 9 | ('f', 'float32'), 10 | ('c', 'complex64'), 11 | ('z', 'complex128') ] 12 | 13 | shapes = [ 14 | ('scalar', ()), 15 | ('vector', (False,)), 16 | ('row', (True, False)), 17 | ('col', (False, True)), 18 | ('matrix', (False,False)), 19 | ('tensor3', (False,False,False)), 20 | ('tensor4', (False,False,False,False)), 21 | ('tensor5', (False,False,False,False,False)), 22 | ('tensor6', (False,) * 6), 23 | ('tensor7', (False,) * 7),] 24 | 25 | hdr = '============ =========== ==== ================ ===================================' 26 | print(hdr) 27 | print('Constructor dtype ndim shape broadcastable') 28 | print(hdr) 29 | for letter in letters: 30 | for shape in shapes: 31 | suff = ',)' if len(shape[1])==1 else ')' 32 | s = '(' + ','.join('1' if b else '?' 
for b in shape[1]) + suff 33 | if len(shape[1]) < 6 or len(set(shape[1])) > 1: 34 | broadcastable_str = str(shape[1]) 35 | else: 36 | broadcastable_str = '(%s,) * %d' % (str(shape[1][0]), len(shape[1])) 37 | print('%s%-10s %-10s %-4s %-15s %-20s' %( 38 | letter[0], shape[0], letter[1], len(shape[1]), s, broadcastable_str 39 | )) 40 | print(hdr) 41 | -------------------------------------------------------------------------------- /doc/.templates/layout.html: -------------------------------------------------------------------------------- 1 | {% extends "!layout.html" %} 2 | 3 | {% block footer %} 4 | {{ super() }} 5 | 16 | 17 | 18 | 35 | 36 | 39 | {% endblock %} 40 | -------------------------------------------------------------------------------- /doc/extend/op/index.rst: -------------------------------------------------------------------------------- 1 | .. _extend_op: 2 | 3 | Add a new Op 4 | ============ 5 | 6 | 7 | Don't define new :class:`Op`\s unless you have to 8 | ------------------------------------------------- 9 | 10 | It is usually not useful to define :class:`Op`\s that can be easily 11 | implemented using other already existing :class:`Op`\s. For example, instead of 12 | writing a "sum_square_difference" :class:`Op`, you should probably just write a 13 | simple function: 14 | 15 | .. code:: 16 | 17 | from aesara import tensor as at 18 | 19 | def sum_square_difference(a, b): 20 | return at.sum((a - b)**2) 21 | 22 | Even without taking Aesara's rewrites into account, it is likely 23 | to work just as well as a custom implementation. It also supports all 24 | data types, tensors of all dimensions as well as broadcasting, whereas 25 | a custom implementation would probably only bother to support 26 | contiguous vectors/matrices of doubles... 
27 | 28 | 29 | Use Aesara's high order :class:`Op`\s when applicable 30 | ----------------------------------------------------- 31 | 32 | Aesara provides some generic :class:`Op` classes which allow you to generate a 33 | lot of :class:`Op`\s at a lesser effort. For instance, :class:`Elemwise` can be used to 34 | make :term:`elemwise` operations easily, whereas :class:`DimShuffle` can be 35 | used to make transpose-like transformations. These higher order :class:`Op`\s 36 | are mostly tensor-related, as this is Aesara's specialty. 37 | 38 | .. toctree:: 39 | :maxdepth: 1 40 | 41 | creating_an_op 42 | inplace 43 | how_to_make_ops 44 | -------------------------------------------------------------------------------- /aesara/misc/frozendict.py: -------------------------------------------------------------------------------- 1 | # License : https://github.com/slezica/python-frozendict/blob/master/LICENSE.txt 2 | 3 | 4 | import collections 5 | import functools 6 | import operator 7 | from collections.abc import Mapping 8 | 9 | 10 | class frozendict(Mapping): 11 | """ 12 | An immutable wrapper around dictionaries that implements the complete :py:class:`collections.abc.Mapping` 13 | interface. It can be used as a drop-in replacement for dictionaries where immutability and ordering are desired. 
14 | """ 15 | 16 | dict_cls = dict 17 | 18 | def __init__(self, *args, **kwargs): 19 | self._dict = self.dict_cls(*args, **kwargs) 20 | self._hash = None 21 | 22 | def __getitem__(self, key): 23 | return self._dict[key] 24 | 25 | def __contains__(self, key): 26 | return key in self._dict 27 | 28 | def copy(self, **add_or_replace): 29 | return self.__class__(self, **add_or_replace) 30 | 31 | def __iter__(self): 32 | return iter(self._dict) 33 | 34 | def __len__(self): 35 | return len(self._dict) 36 | 37 | def __repr__(self): 38 | return f"<{self.__class__.__name__} {self._dict!r}>" 39 | 40 | def __hash__(self): 41 | if self._hash is None: 42 | hashes = map(hash, self.items()) 43 | self._hash = functools.reduce(operator.xor, hashes, 0) 44 | 45 | return self._hash 46 | 47 | 48 | class FrozenOrderedDict(frozendict): 49 | """ 50 | A FrozenDict subclass that maintains key order 51 | """ 52 | 53 | dict_cls = collections.OrderedDict 54 | -------------------------------------------------------------------------------- /tests/graph/test_types.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from aesara.graph.basic import Variable 4 | from aesara.graph.type import Type 5 | 6 | 7 | class MyType(Type): 8 | def __init__(self, thingy): 9 | self.thingy = thingy 10 | 11 | def filter(self, *args, **kwargs): 12 | raise NotImplementedError() 13 | 14 | def __eq__(self, other): 15 | return isinstance(other, MyType) and other.thingy == self.thingy 16 | 17 | def __str__(self): 18 | return f"R{self.thingy}" 19 | 20 | def __repr__(self): 21 | return f"R{self.thingy}" 22 | 23 | 24 | class MyType2(MyType): 25 | def is_super(self, other): 26 | if self.thingy <= other.thingy: 27 | return True 28 | 29 | 30 | def test_is_super(): 31 | t1 = MyType(1) 32 | t2 = MyType(2) 33 | 34 | assert t1.is_super(t2) is None 35 | 36 | t1_2 = MyType(1) 37 | assert t1.is_super(t1_2) 38 | 39 | 40 | def test_in_same_class(): 41 | t1 = MyType(1) 42 | t2 = 
MyType(2) 43 | 44 | assert t1.in_same_class(t2) is False 45 | 46 | t1_2 = MyType(1) 47 | assert t1.in_same_class(t1_2) 48 | 49 | 50 | def test_convert_variable(): 51 | t1 = MyType(1) 52 | v1 = Variable(MyType(1), None, None) 53 | v2 = Variable(MyType(2), None, None) 54 | v3 = Variable(MyType2(0), None, None) 55 | 56 | assert t1.convert_variable(v1) is v1 57 | assert t1.convert_variable(v2) is None 58 | 59 | with pytest.raises(NotImplementedError): 60 | t1.convert_variable(v3) 61 | 62 | 63 | def test_default_clone(): 64 | mt = MyType(1) 65 | assert isinstance(mt.clone(1), MyType) 66 | -------------------------------------------------------------------------------- /.github/workflows/dev-builds.yml: -------------------------------------------------------------------------------- 1 | name: Development builds 2 | on: 3 | push: 4 | branches: 5 | - main 6 | 7 | jobs: 8 | build_and_publish: 9 | name: Build source distribution 10 | if: github.repository == 'aesara-devs/aesara' 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | with: 15 | fetch-depth: 0 16 | - uses: actions/setup-python@v4 17 | with: 18 | python-version: "3.9" 19 | cache: "pip" 20 | cache-dependency-path: "pyproject.toml" 21 | - name: Install dependencies and customize pyproject.toml 22 | run: | 23 | # Download dasel to modify pyproject.toml 24 | curl -sSLf https://github.com/TomWright/dasel/releases/download/v2.0.2/dasel_linux_amd64 \ 25 | -L -o /tmp/dasel && chmod +x /tmp/dasel 26 | 27 | # Modify pyproject.toml to set the nightly version in the form of 28 | # x.y.z.postN, where N is the number of commits since the last release 29 | /tmp/dasel put -f pyproject.toml project.name -v aesara-nightly 30 | /tmp/dasel put -f pyproject.toml tool.hatch.version.raw-options.version_scheme -v post-release 31 | /tmp/dasel put -f pyproject.toml tool.hatch.version.raw-options.local_scheme -v no-local-version 32 | 33 | # Install build prerequisites 34 | python -m pip install -U pip build 35 | - 
name: Build the sdist 36 | run: python -m build --sdist . 37 | - uses: pypa/gh-action-pypi-publish@release/v1 38 | with: 39 | password: ${{ secrets.nightly_pypi_secret }} 40 | -------------------------------------------------------------------------------- /aesara/sandbox/minimal.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from aesara.graph.basic import Apply 4 | from aesara.graph.op import Op 5 | from aesara.tensor.type import lscalar 6 | 7 | 8 | class Minimal(Op): 9 | # TODO : need description for class 10 | 11 | # if the Op has any attributes, consider using them in the eq function. 12 | # If two Apply nodes have the same inputs and the ops compare equal... 13 | # then they will be MERGED so they had better have computed the same thing! 14 | 15 | __props__ = () 16 | 17 | def __init__(self): 18 | # If you put things here, think about whether they change the outputs 19 | # computed by # self.perform() 20 | # - If they do, then you should take them into consideration in 21 | # __eq__ and __hash__ 22 | # - If they do not, then you should not use them in 23 | # __eq__ and __hash__ 24 | 25 | super().__init__() 26 | 27 | def make_node(self, *args): 28 | # HERE `args` must be AESARA VARIABLES 29 | return Apply(op=self, inputs=args, outputs=[lscalar()]) 30 | 31 | def perform(self, node, inputs, out_): 32 | (output,) = out_ 33 | # HERE `inputs` are PYTHON OBJECTS 34 | 35 | # do what you want here, 36 | # but do not modify any of the arguments [inplace]. 37 | print("perform got %i arguments" % len(inputs)) 38 | 39 | print("Max of input[0] is ", np.max(inputs[0])) 40 | 41 | # return some computed value. 42 | # do not return something that is aliased to one of the inputs. 
43 | output[0] = np.asarray(0, dtype="int64") 44 | 45 | 46 | minimal = Minimal() 47 | -------------------------------------------------------------------------------- /aesara/tensor/nnet/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | warnings.warn( 5 | "The module `aesara.tensor.nnet` is deprecated and will " 6 | "be removed from Aesara in version 2.9.0", 7 | DeprecationWarning, 8 | stacklevel=2, 9 | ) 10 | 11 | import aesara.tensor.nnet.rewriting 12 | from aesara.tensor.nnet.abstract_conv import ( 13 | abstract_conv2d, 14 | conv2d, 15 | conv2d_grad_wrt_inputs, 16 | conv2d_transpose, 17 | conv3d, 18 | separable_conv2d, 19 | ) 20 | from aesara.tensor.nnet.basic import ( 21 | binary_crossentropy, 22 | categorical_crossentropy, 23 | confusion_matrix, 24 | crossentropy_categorical_1hot, 25 | crossentropy_categorical_1hot_grad, 26 | crossentropy_softmax_1hot, 27 | crossentropy_softmax_1hot_with_bias, 28 | crossentropy_softmax_1hot_with_bias_dx, 29 | crossentropy_softmax_argmax_1hot_with_bias, 30 | crossentropy_softmax_max_and_argmax_1hot, 31 | crossentropy_softmax_max_and_argmax_1hot_with_bias, 32 | crossentropy_to_crossentropy_with_softmax, 33 | crossentropy_to_crossentropy_with_softmax_with_bias, 34 | elu, 35 | graph_merge_softmax_with_crossentropy_softmax, 36 | h_softmax, 37 | logsoftmax, 38 | prepend_0_to_each_row, 39 | prepend_1_to_each_row, 40 | prepend_scalar_to_each_row, 41 | relu, 42 | selu, 43 | sigmoid_binary_crossentropy, 44 | softmax, 45 | softmax_grad_legacy, 46 | softmax_legacy, 47 | softmax_simplifier, 48 | softmax_with_bias, 49 | softsign, 50 | ) 51 | from aesara.tensor.nnet.batchnorm import batch_normalization 52 | from aesara.tensor.nnet.sigm import hard_sigmoid, ultra_fast_sigmoid 53 | -------------------------------------------------------------------------------- /aesara/compile/__init__.py: -------------------------------------------------------------------------------- 
1 | from aesara.compile.function.pfunc import pfunc, rebuild_collect_shared 2 | from aesara.compile.function.types import ( 3 | AliasedMemoryError, 4 | Function, 5 | FunctionMaker, 6 | Supervisor, 7 | UnusedInputError, 8 | alias_root, 9 | convert_function_input, 10 | fgraph_updated_vars, 11 | get_info_on_inputs, 12 | infer_reuse_pattern, 13 | insert_deepcopy, 14 | orig_function, 15 | std_fgraph, 16 | view_tree_set, 17 | ) 18 | from aesara.compile.io import In, Out, SymbolicInput, SymbolicOutput 19 | from aesara.compile.mode import ( 20 | FAST_COMPILE, 21 | FAST_RUN, 22 | JAX, 23 | NUMBA, 24 | OPT_FAST_COMPILE, 25 | OPT_FAST_RUN, 26 | OPT_FAST_RUN_STABLE, 27 | OPT_MERGE, 28 | OPT_NONE, 29 | OPT_O2, 30 | OPT_O3, 31 | OPT_STABILIZE, 32 | OPT_UNSAFE, 33 | AddDestroyHandler, 34 | AddFeatureOptimizer, 35 | Mode, 36 | PrintCurrentFunctionGraph, 37 | get_default_mode, 38 | get_mode, 39 | instantiated_default_mode, 40 | local_useless, 41 | optdb, 42 | predefined_linkers, 43 | predefined_modes, 44 | predefined_optimizers, 45 | register_linker, 46 | register_mode, 47 | register_optimizer, 48 | ) 49 | from aesara.compile.monitormode import MonitorMode 50 | from aesara.compile.ops import ( 51 | DeepCopyOp, 52 | FromFunctionOp, 53 | ViewOp, 54 | as_op, 55 | deep_copy_op, 56 | register_deep_copy_op_c_code, 57 | register_view_op_c_code, 58 | view_op, 59 | ) 60 | from aesara.compile.profiling import ProfileStats 61 | from aesara.compile.sharedvalue import SharedVariable, shared, shared_constructor 62 | -------------------------------------------------------------------------------- /doc/troubleshoot/profiling_example_out.prof: -------------------------------------------------------------------------------- 1 | Function profiling 2 | ================== 3 | Message: None 4 | Time in 1 calls to Function.__call__: 5.698204e-05s 5 | Time in Function.vm.__call__: 1.192093e-05s (20.921%) 6 | Time in thunks: 6.198883e-06s (10.879%) 7 | Total compile time: 3.642474e+00s 8 | Aesara rewrite 
time: 7.326508e-02s 9 | Aesara validate time: 3.712177e-04s 10 | Aesara Linker time (includes C, CUDA code generation/compiling): 9.584920e-01s 11 | 12 | Class 13 | --- 14 | <% time>