13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 | {% endif %}
--------------------------------------------------------------------------------
/pytensor/misc/frozendict.py:
--------------------------------------------------------------------------------
1 | # License : https://github.com/slezica/python-frozendict/blob/master/LICENSE.txt
2 |
3 |
4 | import functools
5 | import operator
6 | from collections.abc import Mapping
7 |
8 |
9 | class frozendict(Mapping):
10 | """
11 | An immutable wrapper around dictionaries that implements the complete :py:class:`collections.abc.Mapping`
12 | interface. It can be used as a drop-in replacement for dictionaries where immutability and ordering are desired.
13 | """
14 |
15 | dict_cls = dict
16 |
17 | def __init__(self, *args, **kwargs):
18 | self._dict = self.dict_cls(*args, **kwargs)
19 | self._hash = None
20 |
21 | def __getitem__(self, key):
22 | return self._dict[key]
23 |
24 | def __contains__(self, key):
25 | return key in self._dict
26 |
27 | def copy(self, **add_or_replace):
28 | return self.__class__(self, **add_or_replace)
29 |
30 | def __iter__(self):
31 | return iter(self._dict)
32 |
33 | def __len__(self):
34 | return len(self._dict)
35 |
36 | def __repr__(self):
37 | return f"<{self.__class__.__name__} {self._dict!r}>"
38 |
39 | def __hash__(self):
40 | if self._hash is None:
41 | hashes = map(hash, self.items())
42 | self._hash = functools.reduce(operator.xor, hashes, 0)
43 |
44 | return self._hash
45 |
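A minimal usage sketch (illustrative only, not part of the module): instances behave like
read-only dicts, copy() builds a new frozendict with entries added or replaced, and the
XOR-of-item-hashes __hash__ makes them usable as dict keys or set members.

    from pytensor.misc.frozendict import frozendict

    d = frozendict({"a": 1, "b": 2})
    assert d["a"] == 1 and "b" in d and len(d) == 2

    # copy() returns a new frozendict; the original is untouched
    d2 = d.copy(b=3, c=4)
    assert d2["b"] == 3 and d["b"] == 2

    # hashable (the values must themselves be hashable), so usable as a dict key
    cache = {d: "original", d2: "updated"}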
--------------------------------------------------------------------------------
/pytensor/d3viz/js/d3-context-menu.js:
--------------------------------------------------------------------------------
1 | d3.contextMenu = function (menu, openCallback) {
2 |
3 | // create the div element that will hold the context menu
4 | d3.selectAll('.d3-context-menu').data([1])
5 | .enter()
6 | .append('div')
7 | .attr('class', 'd3-context-menu');
8 |
9 | // close menu
10 | d3.select('body').on('click.d3-context-menu', function() {
11 | d3.select('.d3-context-menu').style('display', 'none');
12 | });
13 |
14 | // this gets executed when a contextmenu event occurs
15 | return function(data, index) {
16 | var elm = this;
17 |
18 | d3.selectAll('.d3-context-menu').html('');
19 | var list = d3.selectAll('.d3-context-menu').append('ul');
20 | list.selectAll('li').data(menu).enter()
21 | .append('li')
22 | .html(function(d) {
23 | return d.title;
24 | })
25 | .on('click', function(d, i) {
26 | d.action(elm, data, index);
27 | d3.select('.d3-context-menu').style('display', 'none');
28 | });
29 |
30 | // the openCallback allows an action to fire before the menu is displayed
31 | // an example usage would be closing a tooltip
32 | if (openCallback) openCallback(data, index);
33 |
34 | // display context menu
35 | d3.select('.d3-context-menu')
36 | .style('left', (d3.event.pageX - 2) + 'px')
37 | .style('top', (d3.event.pageY - 2) + 'px')
38 | .style('display', 'block');
39 |
40 | d3.event.preventDefault();
41 | };
42 | };
43 |
--------------------------------------------------------------------------------
/doc/library/d3viz/examples/d3viz/js/d3-context-menu.js:
--------------------------------------------------------------------------------
1 | d3.contextMenu = function (menu, openCallback) {
2 |
3 | // create the div element that will hold the context menu
4 | d3.selectAll('.d3-context-menu').data([1])
5 | .enter()
6 | .append('div')
7 | .attr('class', 'd3-context-menu');
8 |
9 | // close menu
10 | d3.select('body').on('click.d3-context-menu', function() {
11 | d3.select('.d3-context-menu').style('display', 'none');
12 | });
13 |
14 | // this gets executed when a contextmenu event occurs
15 | return function(data, index) {
16 | var elm = this;
17 |
18 | d3.selectAll('.d3-context-menu').html('');
19 | var list = d3.selectAll('.d3-context-menu').append('ul');
20 | list.selectAll('li').data(menu).enter()
21 | .append('li')
22 | .html(function(d) {
23 | return d.title;
24 | })
25 | .on('click', function(d, i) {
26 | d.action(elm, data, index);
27 | d3.select('.d3-context-menu').style('display', 'none');
28 | });
29 |
30 | // the openCallback allows an action to fire before the menu is displayed
31 | // an example usage would be closing a tooltip
32 | if (openCallback) openCallback(data, index);
33 |
34 | // display context menu
35 | d3.select('.d3-context-menu')
36 | .style('left', (d3.event.pageX - 2) + 'px')
37 | .style('top', (d3.event.pageY - 2) + 'px')
38 | .style('display', 'block');
39 |
40 | d3.event.preventDefault();
41 | };
42 | };
43 |
--------------------------------------------------------------------------------
/doc/tutorial/index.rst:
--------------------------------------------------------------------------------
1 |
2 | .. _tutorial:
3 |
4 | ========
5 | Tutorial
6 | ========
7 |
8 | Let us start an interactive session (e.g. with ``python`` or ``ipython``) and import PyTensor.
9 |
10 | >>> from pytensor import *
11 |
12 | Several of the symbols you will need to use are in the ``tensor`` subpackage
13 | of PyTensor. Let us import that subpackage under a handy name like
14 | ``pt`` (the tutorials will frequently use this convention).
15 |
16 | >>> import pytensor.tensor as pt
17 |
18 | If that succeeded you are ready for the tutorial, otherwise check your
19 | installation (see :ref:`install`).
20 |
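As a quick sanity check (a minimal sketch; the following tutorials cover this in detail),
you can already build, compile, and evaluate a tiny expression:

>>> from pytensor import function
>>> x = pt.dscalar('x')
>>> y = pt.dscalar('y')
>>> f = function([x, y], x + y)
>>> float(f(2, 3))
5.0
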
21 | Throughout the tutorial, bear in mind that there is a :ref:`glossary` as well
22 | as *index* and *modules* links in the upper-right corner of each page to help
23 | you out.
24 |
25 | Basics
26 | ------
27 |
28 | .. toctree::
29 |
30 | adding
31 | examples
32 | gradients
33 | conditions
34 | loop
35 | shape_info
36 | broadcasting
37 |
38 | Advanced
39 | --------
40 |
41 | .. toctree::
42 |
43 | sparse
44 | prng
45 |
46 | Advanced configuration and debugging
47 | ------------------------------------
48 |
49 | .. toctree::
50 |
51 | modes
52 | printing_drawing
53 | debug_faq
54 | nan_tutorial
55 | profiling
56 |
57 | Further reading
58 | ---------------
59 |
60 | .. toctree::
61 |
62 | loading_and_saving
63 | aliasing
64 | multi_cores
65 | faq_tutorial
66 |
--------------------------------------------------------------------------------
/doc/generate_dtype_tensor_table.py:
--------------------------------------------------------------------------------
1 | letters = [
2 | ('b', 'int8'),
3 | ('w', 'int16'),
4 | ('i', 'int32'),
5 | ('l', 'int64'),
6 | ('d', 'float64'),
7 | ('f', 'float32'),
8 | ('c', 'complex64'),
9 | ('z', 'complex128') ]
10 |
11 | shapes = [
12 | ('scalar', ()),
13 | ('vector', (False,)),
14 | ('row', (True, False)),
15 | ('col', (False, True)),
16 | ('matrix', (False,False)),
17 | ('tensor3', (False,False,False)),
18 | ('tensor4', (False,False,False,False)),
19 | ('tensor5', (False,False,False,False,False)),
20 | ('tensor6', (False,) * 6),
21 | ('tensor7', (False,) * 7),]
22 |
23 | hdr = '============ =========== ==== ================ ==================================='
24 | print(hdr)
25 | print('Constructor dtype ndim shape broadcastable')
26 | print(hdr)
27 | for letter in letters:
28 | for shape in shapes:
29 | suff = ',)' if len(shape[1])==1 else ')'
30 | s = '(' + ','.join('1' if b else '?' for b in shape[1]) + suff
31 | if len(shape[1]) < 6 or len(set(shape[1])) > 1:
32 | broadcastable_str = str(shape[1])
33 | else:
34 | broadcastable_str = f'({shape[1][0]},) * {len(shape[1])}'
35 | print('%s%-10s %-10s %-4s %-15s %-20s' %(
36 | letter[0], shape[0], letter[1], len(shape[1]), s, broadcastable_str
37 | ))
38 | print(hdr)
39 |
--------------------------------------------------------------------------------
/tests/tensor/rewriting/test_einsum.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 |
3 | from pytensor.graph import ancestors, rewrite_graph
4 | from pytensor.tensor import einsum, specify_shape, tensor
5 | from pytensor.tensor.einsum import Einsum
6 |
7 |
8 | specialize_rewrite = partial(rewrite_graph, include=("specialize",), clone=True)
9 |
10 |
11 | def test_einsum_optimization():
12 | a = tensor("a", shape=(None, None))
13 | b = tensor("b", shape=(None, None))
14 | c = tensor("c", shape=(None, None))
15 |
16 | dynamic_shape_einsum = einsum("ij,ij,jk->ik", a, b, c)
17 | assert not dynamic_shape_einsum.owner.op.optimized
18 |
19 | rewritten_out = specialize_rewrite(dynamic_shape_einsum)
20 | assert isinstance(rewritten_out.owner.op, Einsum)
21 |
22 | a = specify_shape(a, (2, 3))
23 | b = specify_shape(b, (2, 3))
24 | c = specify_shape(c, (3, 5))
25 |
26 | static_shape_einsum = dynamic_shape_einsum.owner.clone_with_new_inputs(
27 | [a, b, c]
28 | ).default_output()
29 | assert not static_shape_einsum.owner.op.optimized
30 |
31 | rewritten_out = specialize_rewrite(static_shape_einsum)
32 | # Einsum was inlined because it was optimized
33 | assert not isinstance(rewritten_out.owner.op, Einsum)
34 | # Sanity check that it's not buried in the graph
35 | assert not any(
36 | isinstance(var.owner.op, Einsum)
37 | for var in ancestors([rewritten_out])
38 | if var.owner
39 | )
40 |
--------------------------------------------------------------------------------
/tests/link/jax/test_blockwise.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | from pytensor import config
5 | from pytensor.tensor import tensor
6 | from pytensor.tensor.blockwise import Blockwise
7 | from pytensor.tensor.math import Dot, matmul
8 | from tests.link.jax.test_basic import compare_jax_and_py
9 | from tests.tensor.test_blockwise import check_blockwise_runtime_broadcasting
10 |
11 |
12 | jax = pytest.importorskip("jax")
13 |
14 |
15 | def test_runtime_broadcasting():
16 | check_blockwise_runtime_broadcasting("JAX")
17 |
18 |
19 | # Equivalent blockwise to matmul but with dumb signature
20 | odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)")
21 |
22 |
23 | @pytest.mark.parametrize("matmul_op", (matmul, odd_matmul))
24 | def test_matmul(matmul_op):
25 | rng = np.random.default_rng(14)
26 | a = tensor("a", shape=(2, 3, 5))
27 | b = tensor("b", shape=(2, 5, 3))
28 | test_values = [
29 | rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (a, b)
30 | ]
31 |
32 | out = matmul_op(a, b)
33 | assert isinstance(out.owner.op, Blockwise)
34 | fn, _ = compare_jax_and_py([a, b], [out], test_values)
35 |
36 | # Check we are not adding any unnecessary stuff
37 | jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values))
38 | jaxpr = jaxpr.replace("name=jax_funcified_fgraph", "name=matmul")
39 | expected_jaxpr = str(jax.make_jaxpr(jax.jit(jax.numpy.matmul))(*test_values))
40 | assert jaxpr == expected_jaxpr
41 |
--------------------------------------------------------------------------------
/doc/library/tensor/fft.rst:
--------------------------------------------------------------------------------
1 | .. _libdoc_tensor_fft:
2 |
3 | ==============================================
4 | :mod:`tensor.fft` -- Fast Fourier Transforms
5 | ==============================================
6 |
7 | Performs Fast Fourier Transforms (FFT).
8 |
9 | FFT gradients are implemented as the opposite Fourier transform of the output gradients.
10 |
11 | .. warning ::
12 |     The real and imaginary parts of the Fourier domain arrays are stored as a pair of float
13 |     arrays, emulating complex numbers. Since PyTensor has limited support for complex number
14 |     operations, care must be taken to manually implement operations such as gradients.
15 |
16 | .. automodule:: pytensor.tensor.fft
17 | :members: rfft, irfft
18 |
19 | For example, the code below performs the real-input FFT of a box function,
20 | whose transform is a sinc function. The absolute value is plotted, since the
21 | phase oscillates because the box function is shifted to the middle of the array.
22 |
23 | .. testcode::
24 |
25 | import numpy as np
26 | import pytensor
27 | import pytensor.tensor as pt
28 | from pytensor.tensor import fft
29 |
30 | x = pt.matrix('x', dtype='float64')
31 |
32 | rfft = fft.rfft(x, norm='ortho')
33 | f_rfft = pytensor.function([x], rfft)
34 |
35 | N = 1024
36 | box = np.zeros((1, N), dtype='float64')
37 | box[:, N//2-10: N//2+10] = 1
38 |
39 | out = f_rfft(box)
40 | c_out = np.asarray(out[0, :, 0] + 1j*out[0, :, 1])
41 | abs_out = abs(c_out)
42 |
43 | .. image:: plot_fft.png
44 |
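The transform can be inverted with ``irfft``. Continuing from the variables defined above,
a round-trip sketch (using ``norm='ortho'`` in both directions, so the box is recovered
without rescaling):

.. testcode::

    irfft_input = pt.tensor3('irfft_input', dtype='float64')
    inverse = fft.irfft(irfft_input, norm='ortho')
    f_irfft = pytensor.function([irfft_input], inverse)

    recovered = f_irfft(out)
    assert np.allclose(recovered, box)
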
--------------------------------------------------------------------------------
/pytensor/link/pytorch/dispatch/extra_ops.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
4 | from pytensor.tensor.extra_ops import CumOp, Repeat, Unique
5 |
6 |
7 | @pytorch_funcify.register(CumOp)
8 | def pytorch_funcify_Cumop(op, **kwargs):
9 | axis = op.axis
10 | mode = op.mode
11 |
12 | def cumop(x):
13 | if mode == "add":
14 | return torch.cumsum(x, dim=axis)
15 | else:
16 | return torch.cumprod(x, dim=axis)
17 |
18 | return cumop
19 |
20 |
21 | @pytorch_funcify.register(Repeat)
22 | def pytorch_funcify_Repeat(op, **kwargs):
23 | axis = op.axis
24 |
25 | def repeat(x, repeats):
26 | return x.repeat_interleave(repeats, dim=axis)
27 |
28 | return repeat
29 |
30 |
31 | @pytorch_funcify.register(Unique)
32 | def pytorch_funcify_Unique(op, **kwargs):
33 | return_index = op.return_index
34 |
35 | if return_index:
36 |         # TODO: evaluate whether it is worth implementing this param
37 |         # (see https://github.com/pytorch/pytorch/issues/36748)
38 | raise NotImplementedError("return_index is not implemented for pytorch")
39 |
40 | axis = op.axis
41 | return_inverse = op.return_inverse
42 | return_counts = op.return_counts
43 |
44 | def unique(x):
45 | return torch.unique(
46 | x,
47 | sorted=True,
48 | return_inverse=return_inverse,
49 | return_counts=return_counts,
50 | dim=axis,
51 | )
52 |
53 | return unique
54 |
--------------------------------------------------------------------------------
/pytensor/link/jax/dispatch/sparse.py:
--------------------------------------------------------------------------------
1 | import jax.experimental.sparse as jsp
2 | from scipy.sparse import spmatrix
3 |
4 | from pytensor.graph.basic import Constant
5 | from pytensor.link.jax.dispatch import jax_funcify, jax_typify
6 | from pytensor.sparse.math import Dot, StructuredDot
7 | from pytensor.sparse.type import SparseTensorType
8 |
9 |
10 | @jax_typify.register(spmatrix)
11 | def jax_typify_spmatrix(matrix, dtype=None, **kwargs):
12 | # Note: This changes the type of the constants from CSR/CSC to BCOO
13 | # We could add BCOO as a PyTensor type but this would only be useful for JAX graphs
14 | # and it would break the premise of one graph -> multiple backends.
15 | # The same situation happens with RandomGenerators...
16 | return jsp.BCOO.from_scipy_sparse(matrix)
17 |
18 |
19 | @jax_funcify.register(Dot)
20 | @jax_funcify.register(StructuredDot)
21 | def jax_funcify_sparse_dot(op, node, **kwargs):
22 | for input in node.inputs:
23 | if isinstance(input.type, SparseTensorType) and not isinstance(input, Constant):
24 | raise NotImplementedError(
25 | "JAX sparse dot only implemented for constant sparse inputs"
26 | )
27 |
28 | if isinstance(node.outputs[0].type, SparseTensorType):
29 | raise NotImplementedError("JAX sparse dot only implemented for dense outputs")
30 |
31 | @jsp.sparsify
32 | def sparse_dot(x, y):
33 | out = x @ y
34 | if isinstance(out, jsp.BCOO):
35 | out = out.todense()
36 | return out
37 |
38 | return sparse_dot
39 |
--------------------------------------------------------------------------------
/tests/xtensor/test_reduction.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 |
4 | pytest.importorskip("xarray")
5 |
6 | from pytensor.xtensor.type import xtensor
7 | from tests.xtensor.util import xr_arange_like, xr_assert_allclose, xr_function
8 |
9 |
10 | @pytest.mark.parametrize(
11 | "dim", [..., None, "a", ("c", "a")], ids=["Ellipsis", "None", "a", "(a, c)"]
12 | )
13 | @pytest.mark.parametrize(
14 | "method",
15 | ["sum", "prod", "all", "any", "max", "min", "mean", "cumsum", "cumprod"],
16 | )
17 | def test_reduction(method, dim):
18 | x = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7))
19 | out = getattr(x, method)(dim=dim)
20 |
21 | fn = xr_function([x], out)
22 | x_test = xr_arange_like(x)
23 |
24 | xr_assert_allclose(
25 | fn(x_test),
26 | getattr(x_test, method)(dim=dim),
27 | )
28 |
29 |
30 | @pytest.mark.parametrize(
31 | "dim", [..., None, "a", ("c", "a")], ids=["Ellipsis", "None", "a", "(a, c)"]
32 | )
33 | @pytest.mark.parametrize("method", ["std", "var"])
34 | def test_std_var(method, dim):
35 | x = xtensor("x", dims=("a", "b", "c"), shape=(3, 5, 7))
36 | out = [
37 | getattr(x, method)(dim=dim),
38 | getattr(x, method)(dim=dim, ddof=2),
39 | ]
40 |
41 | fn = xr_function([x], out)
42 | x_test = xr_arange_like(x)
43 | results = fn(x_test)
44 |
45 | xr_assert_allclose(
46 | results[0],
47 | getattr(x_test, method)(dim=dim),
48 | )
49 |
50 | xr_assert_allclose(
51 | results[1],
52 | getattr(x_test, method)(dim=dim, ddof=2),
53 | )
54 |
--------------------------------------------------------------------------------
/doc/.templates/layout.html:
--------------------------------------------------------------------------------
1 | {% extends "!layout.html" %}
2 |
3 | {% block footer %}
4 | {{ super() }}
5 |
16 |
17 |
18 |
35 |
36 |
39 | {% endblock %}
40 |
--------------------------------------------------------------------------------
/tests/link/pytorch/test_blockwise.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 |
4 | import pytensor
5 | import pytensor.tensor as pt
6 | from pytensor.graph.basic import Apply
7 | from pytensor.graph.op import Op
8 | from pytensor.tensor.blockwise import Blockwise
9 |
10 |
11 | torch = pytest.importorskip("torch")
12 | basic = pytest.importorskip("pytensor.link.pytorch.dispatch.basic")
13 |
14 |
15 | class BatchedTestOp(Op):
16 | gufunc_signature = "(m,n),(n,p)->(m,p)"
17 |
18 | def __init__(self, final_shape):
19 | super().__init__()
20 | self.final_shape = final_shape
21 | self.call_shapes = []
22 |
23 | def make_node(self, *args):
24 | return Apply(self, list(args), [pt.matrix("_", shape=self.final_shape)])
25 |
26 | def perform(self, *_):
27 | raise RuntimeError("In perform")
28 |
29 |
30 | @basic.pytorch_funcify.register(BatchedTestOp)
31 | def evaluate_test_op(op, **_):
32 | def func(a, b):
33 | op.call_shapes.extend(map(torch.Tensor.size, [a, b]))
34 | return a @ b
35 |
36 | return func
37 |
38 |
39 | def test_blockwise_broadcast():
40 | _x = np.random.rand(5, 1, 2, 3)
41 | _y = np.random.rand(3, 3, 2)
42 |
43 | x = pt.tensor4("x", shape=(5, 1, 2, 3))
44 | y = pt.tensor3("y", shape=(3, 3, 2))
45 | op = BatchedTestOp((2, 2))
46 | z = Blockwise(op)(x, y)
47 |
48 | f = pytensor.function([x, y], z, mode="PYTORCH")
49 | res = f(_x, _y)
50 | assert tuple(res.shape) == (5, 3, 2, 2)
51 | np.testing.assert_allclose(res, _x @ _y)
52 | assert op.call_shapes == [(2, 3), (3, 2)]
53 |
--------------------------------------------------------------------------------
/doc/library/typed_list.rst:
--------------------------------------------------------------------------------
1 | .. _libdoc_typed_list:
2 |
3 | ===============================
4 | :mod:`typed_list` -- Typed List
5 | ===============================
6 |
7 | .. note::
8 |
9 | This has been added in release 0.7.
10 |
11 | .. note::
12 |
13 | This works, but is not well integrated with the rest of PyTensor. If
14 | speed is important, it is probably better to pad to a dense
15 | tensor.
16 |
17 | This is a type that represents a list in PyTensor. All elements must have
18 | the same PyTensor type. Here is an example:
19 |
20 | >>> import pytensor.typed_list
21 | >>> tl = pytensor.typed_list.TypedListType(pytensor.tensor.fvector)()
22 | >>> v = pytensor.tensor.fvector()
23 | >>> o = pytensor.typed_list.append(tl, v)
24 | >>> f = pytensor.function([tl, v], o)
25 | >>> f([[1, 2, 3], [4, 5]], [2])
26 | [array([ 1., 2., 3.], dtype=float32), array([ 4., 5.], dtype=float32), array([ 2.], dtype=float32)]
27 |
28 | Here is a second example, using Scan. Scan does not yet have direct support
29 | for TypedList, so you can only use it as non_sequences (not in sequences or
30 | as outputs):
31 |
32 | >>> import pytensor.typed_list
33 | >>> a = pytensor.typed_list.TypedListType(pytensor.tensor.fvector)()
34 | >>> l = pytensor.typed_list.length(a)
35 | >>> s, _ = pytensor.scan(fn=lambda i, tl: tl[i].sum(),
36 | ... non_sequences=[a],
37 | ... sequences=[pytensor.tensor.arange(l, dtype='int64')])
38 | >>> f = pytensor.function([a], s)
39 | >>> f([[1, 2, 3], [4, 5]])
40 | array([ 6., 9.], dtype=float32)
41 |
42 | .. automodule:: pytensor.typed_list.basic
43 | :members:
44 |
--------------------------------------------------------------------------------
/tests/graph/test_types.py:
--------------------------------------------------------------------------------
1 | import pytest
2 |
3 | from pytensor.graph.basic import Variable
4 | from pytensor.graph.type import Type
5 |
6 |
7 | class MyType(Type):
8 | def __init__(self, thingy):
9 | self.thingy = thingy
10 |
11 | def filter(self, *args, **kwargs):
12 | raise NotImplementedError()
13 |
14 | def __eq__(self, other):
15 | return isinstance(other, MyType) and other.thingy == self.thingy
16 |
17 | def __str__(self):
18 | return f"R{self.thingy}"
19 |
20 | def __repr__(self):
21 | return f"R{self.thingy}"
22 |
23 |
24 | class MyType2(MyType):
25 | def is_super(self, other):
26 | if self.thingy <= other.thingy:
27 | return True
28 |
29 |
30 | def test_is_super():
31 | t1 = MyType(1)
32 | t2 = MyType(2)
33 |
34 | assert t1.is_super(t2) is None
35 |
36 | t1_2 = MyType(1)
37 | assert t1.is_super(t1_2)
38 |
39 |
40 | def test_in_same_class():
41 | t1 = MyType(1)
42 | t2 = MyType(2)
43 |
44 | assert t1.in_same_class(t2) is False
45 |
46 | t1_2 = MyType(1)
47 | assert t1.in_same_class(t1_2)
48 |
49 |
50 | def test_convert_variable():
51 | t1 = MyType(1)
52 | v1 = Variable(MyType(1), None, None)
53 | v2 = Variable(MyType(2), None, None)
54 | v3 = Variable(MyType2(0), None, None)
55 |
56 | assert t1.convert_variable(v1) is v1
57 | assert t1.convert_variable(v2) is None
58 |
59 | with pytest.raises(NotImplementedError):
60 | t1.convert_variable(v3)
61 |
62 |
63 | def test_default_clone():
64 | mt = MyType(1)
65 | assert isinstance(mt.clone(1), MyType)
66 |
--------------------------------------------------------------------------------
/tests/link/pytorch/test_shape.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import pytensor.tensor as pt
4 | from pytensor.configdefaults import config
5 | from pytensor.tensor.shape import Shape, Shape_i, reshape
6 | from pytensor.tensor.type import iscalar, vector
7 | from tests.link.pytorch.test_basic import compare_pytorch_and_py
8 |
9 |
10 | def test_pytorch_shape_ops():
11 | x_np = np.zeros((20, 3))
12 | x = Shape()(pt.as_tensor_variable(x_np))
13 |
14 | compare_pytorch_and_py([], [x], [])
15 |
16 | x = Shape_i(1)(pt.as_tensor_variable(x_np))
17 |
18 | compare_pytorch_and_py([], [x], [])
19 |
20 |
21 | def test_pytorch_specify_shape():
22 | in_pt = pt.matrix("in")
23 | x = pt.specify_shape(in_pt, (4, None))
24 | compare_pytorch_and_py([in_pt], [x], [np.ones((4, 5)).astype(config.floatX)])
25 |
26 | # When used to assert two arrays have similar shapes
27 | in_pt = pt.matrix("in")
28 | shape_pt = pt.matrix("shape")
29 | x = pt.specify_shape(in_pt, shape_pt.shape)
30 |
31 | compare_pytorch_and_py(
32 | [in_pt, shape_pt],
33 | [x],
34 | [np.ones((4, 5)).astype(config.floatX), np.ones((4, 5)).astype(config.floatX)],
35 | )
36 |
37 |
38 | def test_pytorch_Reshape_constant():
39 | a = vector("a")
40 | x = reshape(a, (2, 2))
41 |
42 | compare_pytorch_and_py([a], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
43 |
44 |
45 | def test_pytorch_Reshape_dynamic():
46 | a = vector("a")
47 | shape_pt = iscalar("b")
48 | x = reshape(a, (shape_pt, shape_pt))
49 |
50 | compare_pytorch_and_py(
51 | [a, shape_pt], [x], [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2]
52 | )
53 |
--------------------------------------------------------------------------------
/doc/library/index.rst:
--------------------------------------------------------------------------------
1 |
2 | .. _libdoc:
3 | .. _Library documentation:
4 |
5 | =================
6 | API Documentation
7 | =================
8 |
9 | This documentation covers PyTensor module by module. It is suited to finding the
10 | Types and Ops that you can use to build and compile expression graphs.
11 |
12 | Modules
13 | =======
14 |
15 | .. toctree::
16 | :maxdepth: 1
17 |
18 | compile/index
19 | config
20 | d3viz/index
21 | graph/index
22 | gradient
23 | printing
24 | scan
25 | sparse/index
26 | tensor/index
27 | typed_list
28 | xtensor/index
29 |
30 | .. module:: pytensor
31 | :platform: Unix, Windows
32 | :synopsis: PyTensor top-level import
33 | .. moduleauthor:: LISA
34 |
35 | There are also some top-level imports that you might find more convenient:
36 |
37 | Graph
38 | =====
39 |
40 | .. function:: shared(...)
41 |
42 | Alias for :func:`pytensor.compile.sharedvalue.shared`
43 |
44 | .. function:: function(...)
45 |
46 | Alias for :func:`pytensor.compile.function.function`
47 |
48 | .. autofunction:: pytensor.clone_replace(...)
49 |
50 | Alias for :func:`pytensor.graph.basic.clone_replace`
51 |
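For example (a minimal sketch), the ``shared`` and ``function`` aliases above can be used
directly from the top-level namespace:

>>> import pytensor
>>> state = pytensor.shared(0.0, name="state")
>>> step = pytensor.function([], state + 1.0)
>>> float(step())
1.0
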
52 | Control flow
53 | ============
54 |
55 | .. autofunction:: pytensor.scan(...)
56 |
57 | Alias for :func:`pytensor.scan.basic.scan`
58 |
59 | Convert to Variable
60 | ====================
61 |
62 | .. autofunction:: pytensor.as_symbolic(...)
63 |
64 | Wrap JAX functions
65 | ==================
66 |
67 | .. autofunction:: wrap_jax(...)
68 |
69 | Alias for :func:`pytensor.link.jax.ops.wrap_jax`
70 |
71 | Debug
72 | =====
73 |
74 | .. autofunction:: pytensor.dprint(...)
75 |
76 | Alias for :func:`pytensor.printing.debugprint`
77 |
--------------------------------------------------------------------------------
/doc/tutorial/profiling_example_out.prof:
--------------------------------------------------------------------------------
1 | Function profiling
2 | ==================
3 | Message: None
4 | Time in 1 calls to Function.__call__: 5.698204e-05s
5 | Time in Function.vm.__call__: 1.192093e-05s (20.921%)
6 | Time in thunks: 6.198883e-06s (10.879%)
7 | Total compile time: 3.642474e+00s
8 | PyTensor rewrite time: 7.326508e-02s
9 | PyTensor validate time: 3.712177e-04s
10 | PyTensor Linker time (includes C, CUDA code generation/compiling): 9.584920e-01s
11 |
12 | Class
13 | ---
14 | <% time>