├── .github
├── dependabot.yml
└── workflows
│ ├── build_and_test.yml
│ ├── docs.yml
│ ├── docs
│ └── requirements.txt
│ └── lint.yml
├── .gitignore
├── .pylintrc
├── CITATION.cff
├── LICENSE
├── README.md
├── docs
├── _static
│ └── css
│ │ └── custom.css
├── conf.py
├── densitytensor.rst
├── densitytensor
│ ├── all_zeros_densitytensor.rst
│ ├── densitytensor_to_measured_densitytensor.rst
│ ├── densitytensor_to_measurement_probabilities.rst
│ ├── get_densitytensor_to_expectation_func.rst
│ ├── get_densitytensor_to_sampled_expectation_func.rst
│ ├── get_params_to_densitytensor_func.rst
│ ├── kraus.rst
│ ├── partial_trace.rst
│ └── statetensor_to_densitytensor.rst
├── examples.md
├── gates.rst
├── getting_started.rst
├── index.rst
├── logo.svg
├── logo_dark_mode.svg
├── requirements.txt
├── statetensor.rst
├── statetensor
│ ├── all_zeros_statetensor.rst
│ ├── apply_gate.rst
│ ├── get_params_to_statetensor_func.rst
│ ├── get_params_to_unitarytensor_func.rst
│ ├── get_statetensor_to_expectation_func.rst
│ └── get_statetensor_to_sampled_expectation_func.rst
├── utils.rst
└── utils
│ ├── bitstrings_to_integers.rst
│ ├── check_circuit.rst
│ ├── integers_to_bitstrings.rst
│ ├── print_circuit.rst
│ ├── repeat_circuit.rst
│ ├── sample_bitstrings.rst
│ └── sample_integers.rst
├── examples
├── barren_plateaus.ipynb
├── classification.ipynb
├── generative_modelling.ipynb
├── heisenberg_vqe.ipynb
├── hidalgo_stamp.txt
├── maxcut_vqe.ipynb
├── noise_channel.ipynb
├── qaoa.ipynb
├── reducing_jit_compilation_time.ipynb
└── variational_inference.ipynb
├── qujax
├── __init__.py
├── densitytensor.py
├── densitytensor_observable.py
├── gates.py
├── statetensor.py
├── statetensor_observable.py
├── typing.py
├── utils.py
└── version.py
├── setup.py
└── tests
├── test_densitytensor.py
├── test_expectations.py
├── test_gates.py
├── test_statetensor.py
└── test_utils.py
/.github/dependabot.yml:
--------------------------------------------------------------------------------
# Dependabot configuration: checks the GitHub Actions referenced by this
# repository for updates once a week.
version: 2
updates:
  - package-ecosystem: 'github-actions'
    # Location searched for workflow definitions.
    directory: '/'
    schedule:
      interval: 'weekly'
--------------------------------------------------------------------------------
/.github/workflows/build_and_test.yml:
--------------------------------------------------------------------------------
# Builds the package, installs the resulting wheel and runs the test suite.
# On release events the wheel is uploaded as an artifact and published to PyPI.
name: Build and test

on:
  pull_request:
    branches:
      - main
      - develop
  push:
    branches:
      - develop
      - main
  release:
    types:
      - created
      - edited

jobs:
  linux-checks:
    name: Linux - Build and test module
    # The ubuntu-20.04 hosted runner image has been retired by GitHub, so jobs
    # pinned to it can no longer be scheduled.
    runs-on: ubuntu-22.04

    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: '0'
      - name: Set up Python 3.9
        uses: actions/setup-python@v5
        with:
          python-version: '3.9'
      - name: Build and test
        # Paths are quoted so that a workspace or wheel path containing
        # whitespace does not get word-split by the shell.
        run: |
          ARTIFACTSDIR="${GITHUB_WORKSPACE}/wheelhouse"
          rm -rf "${ARTIFACTSDIR}" && mkdir "${ARTIFACTSDIR}"
          python -m pip install --upgrade pip wheel build pytest
          python -m build
          for w in dist/*.whl ; do
            python -m pip install "$w"
            cp "$w" "${ARTIFACTSDIR}"
          done
          cd tests
          pytest
      # Only release events need the wheel handed over to the publish job.
      - uses: actions/upload-artifact@v4
        if: github.event_name == 'release'
        with:
          name: artefacts
          path: wheelhouse/

  publish_to_pypi:
    name: Publish to pypi
    if: github.event_name == 'release'
    needs: linux-checks
    runs-on: ubuntu-22.04
    steps:
      - name: Download all wheels
        uses: actions/download-artifact@v4
        with:
          path: wheelhouse
      - name: Put them all in the dist folder
        # Artifacts are downloaded into per-artifact subdirectories, hence the
        # recursive search; -exec avoids word-splitting the found paths.
        run: |
          mkdir dist
          find wheelhouse/ -type f -name "*.whl" -exec cp {} dist/ \;
      - name: Publish wheels
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          user: __token__
          password: ${{ secrets.PYPI_QUJAX_API_TOKEN }}
          verbose: true
--------------------------------------------------------------------------------
/.github/workflows/docs.yml:
--------------------------------------------------------------------------------
# Builds the Sphinx documentation on pull requests and on pushes to main, and
# deploys it to GitHub Pages for pushes to main.
name: Qujax Docs

on:
  pull_request:
    branches:
      - main
      - develop
  push:
    branches:
      - main

jobs:
  build_docs:
    name: Build and publish docs
    # The ubuntu-20.04 hosted runner image has been retired by GitHub; this now
    # matches the image already used by the publish_docs job below.
    runs-on: ubuntu-22.04
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: '0'
      - name: Set up Python 3.9
        uses: actions/setup-python@v5
        with:
          python-version: '3.9'
      - name: Upgrade pip and install wheel
        run: pip install --upgrade pip wheel
      - name: Install qujax
        run: |
          pip install .
      - name: Install docs dependencies
        run: |
          pip install -r .github/workflows/docs/requirements.txt
      - name: Build docs
        timeout-minutes: 20
        # The HTML is written to .github/workflows/docs/qujax, which is the
        # directory uploaded by the next step. -p keeps mkdir from failing if
        # the directory already exists.
        run: |
          cd .github/workflows/docs
          mkdir -p qujax
          cd qujax
          sphinx-build ../../../../docs . -a
      - name: Upload docs as artefact
        uses: actions/upload-pages-artifact@v3
        with:
          path: .github/workflows/docs/qujax

  publish_docs:
    name: Publish docs
    # Exact comparison: a substring test would also match any other branch
    # whose name merely contains "main".
    if: github.event_name == 'push' && github.ref_name == 'main'
    needs: build_docs
    runs-on: ubuntu-22.04
    permissions:
      pages: write
      id-token: write
    environment:
      name: github-pages
      url: ${{ steps.deployment.outputs.page_url }}
    steps:
      - name: Deploy to GitHub Pages
        id: deployment
        uses: actions/deploy-pages@v4
--------------------------------------------------------------------------------
/.github/workflows/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx ~= 7.2
2 | sphinx_rtd_theme
3 | myst-parser ~= 2.0
--------------------------------------------------------------------------------
/.github/workflows/lint.yml:
--------------------------------------------------------------------------------
# Checks formatting with black and runs pylint on pull requests and pushes
# targeting main or develop.
name: Lint python projects

on:
  pull_request:
    branches:
      - main
      - develop
  push:
    branches:
      - main
      - develop

jobs:
  lint:
    runs-on: ubuntu-22.04

    steps:
      - uses: actions/checkout@v4

      - name: Set up Python 3.x
        uses: actions/setup-python@v5
        with:
          python-version: "3.x"

      - name: Update pip
        run: pip install --upgrade pip

      - name: Install black and pylint
        run: pip install black pylint

      # --check only reports files that would be reformatted; nothing is rewritten.
      - name: Check files are formatted with black
        run: |
          black --check .

      - name: Run pylint
        run: |
          pylint */
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.egg-info/
2 | .idea/
3 | dist
4 | .DS_Store
5 | .vscode
6 | .pytest_cache
7 | __pycache__
8 | *.ipynb_checkpoints/
9 | build
10 | Makefile
11 | make.bat
12 | _build
13 |
--------------------------------------------------------------------------------
/.pylintrc:
--------------------------------------------------------------------------------
1 | [MASTER]
2 | output-format=colorized
3 | disable=all
4 | enable=
5 | anomalous-backslash-in-string,
6 | assert-on-tuple,
7 | bad-indentation,
8 | bad-option-value,
9 | bad-reversed-sequence,
10 | bad-super-call,
11 | consider-merging-isinstance,
12 | continue-in-finally,
13 | dangerous-default-value,
14 | duplicate-argument-name,
15 | expression-not-assigned,
16 | function-redefined,
17 | inconsistent-mro,
18 | init-is-generator,
19 | line-too-long,
20 | lost-exception,
21 | missing-kwoa,
22 | mixed-line-endings,
23 | not-callable,
24 | no-value-for-parameter,
25 | nonexistent-operator,
26 | not-in-loop,
27 | pointless-statement,
28 | redefined-builtin,
29 | return-arg-in-generator,
30 | return-in-init,
31 | return-outside-function,
32 | simplifiable-if-statement,
33 | syntax-error,
34 | too-many-function-args,
35 | trailing-whitespace,
36 | undefined-variable,
37 | unexpected-keyword-arg,
38 | unhashable-dict-key,
39 | unnecessary-pass,
40 | unreachable,
41 | unrecognized-inline-option,
42 | unused-import,
43 | unnecessary-semicolon,
44 | unused-variable,
45 | unused-wildcard-import,
46 | wildcard-import,
47 | wrong-import-order,
48 | wrong-import-position,
49 | yield-outside-function
50 |
51 |
52 | # Ignore long lines that contain URLs or pylint disable directives.
53 | ignore-long-lines=^(.*#\w*pylint: disable.*|\s*(# )??)$
54 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: "1.2.0"
2 | authors:
3 | - family-names: Duffield
4 | given-names: Samuel
5 | orcid: "https://orcid.org/0000-0002-8656-8734"
6 | - family-names: Matos
7 | given-names: Gabriel
8 | orcid: "https://orcid.org/0000-0002-3373-0128"
9 | - family-names: Johannsen
10 | given-names: Melf
11 | contact:
12 | - family-names: Duffield
13 | given-names: Samuel
14 | orcid: "https://orcid.org/0000-0002-8656-8734"
15 | doi: 10.5281/zenodo.8268973
16 | message: If you use this software, please cite our article in the
17 | Journal of Open Source Software.
18 | preferred-citation:
19 | authors:
20 | - family-names: Duffield
21 | given-names: Samuel
22 | orcid: "https://orcid.org/0000-0002-8656-8734"
23 | - family-names: Matos
24 | given-names: Gabriel
25 | orcid: "https://orcid.org/0000-0002-3373-0128"
26 | - family-names: Johannsen
27 | given-names: Melf
28 | date-published: 2023-09-12
29 | doi: 10.21105/joss.05504
30 | issn: 2475-9066
31 | issue: 89
32 | journal: Journal of Open Source Software
33 | publisher:
34 | name: Open Journals
35 | start: 5504
36 | title: "qujax: Simulating quantum circuits with JAX"
37 | type: article
38 | url: "https://joss.theoj.org/papers/10.21105/joss.05504"
39 | volume: 8
40 | title: "qujax: Simulating quantum circuits with JAX"
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # qujax
2 |
3 |
11 |
12 | [PyPI](https://pypi.org/project/qujax/)
13 | [DOI: 10.21105/joss.05504](https://doi.org/10.21105/joss.05504)
14 |
15 | [**Documentation**](https://cqcl.github.io/qujax/) | [**Installation**](#installation) | [**Quick start**](#quick-start) | [**Examples**](https://cqcl.github.io/qujax/examples.html) | [**Contributing**](#contributing) | [**Citing qujax**](#citing-qujax)
16 |
17 | qujax is a [JAX](https://github.com/google/jax)-based Python library for the classical simulation of quantum circuits. It is designed to be *simple*, *fast* and *flexible*.
18 |
19 | It follows a functional programming design by translating circuits into pure functions. This allows qujax to [seamlessly interface with JAX](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions), enabling direct access to its powerful automatic differentiation tools, just-in-time compiler, vectorization capabilities, GPU/TPU integration and growing ecosystem of packages.
20 |
21 | qujax can be used both for pure and for mixed quantum state simulation. It not only supports the standard gate set, but also allows user-defined custom operations, including general quantum channels, enabling the user to e.g. model device noise and errors.
22 |
23 | A summary of the core functionalities of qujax can be found in the [Quick start](#quick-start) section. More advanced use-cases, including the training of parameterised quantum circuits, can be found in the [Examples](https://cqcl.github.io/qujax/examples.html) section of the documentation.
24 |
25 |
26 | ## Installation
27 |
28 | qujax is [hosted on PyPI](https://pypi.org/project/qujax/) and can be installed via the pip package manager
29 | ```
30 | pip install qujax
31 | ```
32 |
33 | ## Quick start
34 |
35 | **Important note: qujax circuit parameters are expressed in units of $\pi$ (e.g. in the range $[0,2]$ as opposed to $[0, 2\pi]$)**.
36 |
37 | Start by defining the quantum gates making up the circuit, the qubits that they act on, and the indices of the parameters for each gate.
38 |
39 | A list of all gates can be found [here](https://github.com/CQCL/qujax/blob/main/qujax/gates.py) (custom operations can be included by [passing an array or function](https://cqcl.github.io/qujax/statetensor/get_params_to_statetensor_func.html) instead of a string).
40 |
41 | ```python
42 | from jax import numpy as jnp
43 | import qujax
44 |
45 | # List of quantum gates
46 | circuit_gates = ['H', 'Ry', 'CZ']
47 | # Indices of qubits the gates will be applied to
48 | circuit_qubit_inds = [[0], [0], [0, 1]]
49 | # Indices of parameters each parameterised gate will use
50 | circuit_params_inds = [[], [0], []]
51 |
52 | qujax.print_circuit(circuit_gates, circuit_qubit_inds, circuit_params_inds);
53 | # q0: -----H-----Ry[0]-----◯---
54 | # |
55 | # q1: ---------------------CZ--
56 | ```
57 |
58 | Translate the circuit to a pure function `param_to_st` that takes a set of parameters and an (optional) initial quantum state as its input.
59 |
60 | ```python
61 | param_to_st = qujax.get_params_to_statetensor_func(circuit_gates,
62 | circuit_qubit_inds,
63 | circuit_params_inds)
64 |
65 | param_to_st(jnp.array([0.1]))
66 | # Array([[0.58778524+0.j, 0. +0.j],
67 | # [0.80901706+0.j, 0. +0.j]], dtype=complex64)
68 | ```
69 |
70 | The optional initial state can be passed to `param_to_st` using the `statetensor_in` argument. When it is not provided, the initial state defaults to $\ket{0...0}$.
71 |
72 | Map the state to an expectation value by defining an observable using lists of Pauli matrices, the qubits they act on, and the associated coefficients.
73 |
74 | ```python
75 | st_to_expectation = qujax.get_statetensor_to_expectation_func([['Z']], [[0]], [1.])
76 | ```
77 |
78 | Combining `param_to_st` and `st_to_expectation` gives us a parameter to expectation function that can be automatically differentiated using JAX.
79 |
80 | ```python
81 | from jax import value_and_grad
82 |
83 | param_to_expectation = lambda param: st_to_expectation(param_to_st(param))
84 | expectation_and_grad = value_and_grad(param_to_expectation)
85 | expectation_and_grad(jnp.array([0.1]))
86 | # (Array(-0.3090171, dtype=float32),
87 | # Array([-2.987832], dtype=float32))
88 | ```
89 |
90 | Mixed state simulations are analogous to the above, but with calls to [`get_params_to_densitytensor_func`](https://cqcl.github.io/qujax/densitytensor/get_params_to_densitytensor_func.html) and [`get_densitytensor_to_expectation_func`](https://cqcl.github.io/qujax/densitytensor/get_densitytensor_to_expectation_func.html) instead.
91 |
92 | A more in-depth version of the above can be found in the [Getting started](https://cqcl.github.io/qujax/getting_started.html) section of the documentation. More advanced use-cases, including the training of parameterised quantum circuits, can be found in the [Examples](https://cqcl.github.io/qujax/examples.html) section of the documentation.
93 |
94 | ## Converting from TKET
95 |
96 | A [`pytket`](https://cqcl.github.io/tket/pytket/api/) circuit can be directly converted using the [`tk_to_qujax`](https://cqcl.github.io/pytket-qujax/api/api.html#pytket.extensions.qujax.qujax_convert.tk_to_qujax) and [`tk_to_qujax_symbolic`](https://cqcl.github.io/pytket-qujax/api/api.html#pytket.extensions.qujax.qujax_convert.tk_to_qujax_symbolic) functions in the [**`pytket-qujax`**](https://github.com/CQCL/pytket-qujax) extension. See [`pytket-qujax_heisenberg_vqe.ipynb`](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax_heisenberg_vqe.ipynb) for an example.
97 |
98 | ## Contributing
99 |
100 | You can open a bug report or a feature request by creating a new [issue on GitHub](https://github.com/CQCL/qujax/issues).
101 |
102 | Pull requests are welcome! To open a new one, please go through the following steps:
103 |
104 | 1. First fork the repo and create your branch from [`develop`](https://github.com/CQCL/qujax/tree/develop).
105 | 2. Commit your code and tests.
106 | 3. Update the documentation, if required.
107 | 4. Check that the code passes linting (run `black . --check` and `pylint */`).
108 | 5. Issue a pull request into the [`develop`](https://github.com/CQCL/qujax/tree/develop) branch.
109 |
110 | New commits on [`develop`](https://github.com/CQCL/qujax/tree/develop) will be merged into
111 | [`main`](https://github.com/CQCL/qujax/tree/main) in the next release.
112 |
113 |
114 | ## Citing qujax
115 |
116 | If you have used qujax in your code or research, we kindly ask that you cite it. You can use the following BibTeX entry for this:
117 |
118 | ```bibtex
119 | @article{qujax2023,
120 | author = {Duffield, Samuel and Matos, Gabriel and Johannsen, Melf},
121 | doi = {10.21105/joss.05504},
122 | journal = {Journal of Open Source Software},
123 | month = sep,
124 | number = {89},
125 | pages = {5504},
126 | title = {{qujax: Simulating quantum circuits with JAX}},
127 | url = {https://joss.theoj.org/papers/10.21105/joss.05504},
128 | volume = {8},
129 | year = {2023}
130 | }
131 | ```
--------------------------------------------------------------------------------
/docs/_static/css/custom.css:
--------------------------------------------------------------------------------
1 |
2 | .wy-nav-top{
3 | background-color: #203847
4 | }
5 |
6 | .wy-side-nav-search{
7 | background-color: white;
8 | }
9 |
10 | .icon.icon-home{
11 | color: #000000
12 | }
13 |
14 | .wy-menu-vertical p.caption{
15 | color: #85cfcb;
16 | }
17 |
18 | .wy-side-nav-search > a{
19 | color: #000000;
20 | }
21 |
22 | .wy-side-nav-search > div.version{
23 | color: #000000;
24 | }
25 |
26 | .sig {
27 | background: #85cfcb;
28 | }
29 |
30 | @media screen and (min-width: 1000px) {
31 | .wy-nav-content {
32 | max-width: 1000px;
33 | }
34 | }
35 |
36 | html.writer-html4 .rst-content dl:not(.docutils) > dt, html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.citation):not(.glossary):not(.simple) > dt {
37 | display:block;
38 | background-color: #ebeef1;
39 | }
40 |
41 | #examples ul, ul.simple {
42 | list-style: none;
43 | }
44 |
45 | #examples ul li, ul.simple li {
46 | margin-bottom: 10px;
47 | }
48 |
49 | h1, h2, h3, h4, h5, h6 {
50 | color: #203847
51 | }
52 |
53 | div.toctree-wrapper .caption-text{
54 | color: #203847;
55 | }
56 |
57 | .rst-content .viewcode-back, .rst-content .viewcode-link {
58 | padding-left: 6px;
59 | }
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import importlib
4 | import inspect
5 | import pathlib
6 |
7 | sys.path.insert(0, os.path.abspath("..")) # pylint: disable=wrong-import-position
8 |
9 | from qujax.version import __version__
10 |
11 | # -- Project information -----------------------------------------------------
12 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
13 |
14 | project = "qujax"
15 | project_copyright = "2023, The qujax authors"
16 | author = "Sam Duffield, Gabriel Matos, Melf Johannsen"
17 | version = __version__
18 | release = __version__
19 |
20 | # -- General configuration ---------------------------------------------------
21 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
22 |
23 | extensions = [
24 | "sphinx.ext.autodoc",
25 | "sphinx_rtd_theme",
26 | "sphinx.ext.napoleon",
27 | "sphinx.ext.mathjax",
28 | "sphinx.ext.linkcode",
29 | "myst_parser",
30 | ]
31 |
32 | templates_path = ["_templates"]
33 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"]
34 |
35 | # -- Options for HTML output -------------------------------------------------
36 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
37 |
38 | html_theme = "sphinx_rtd_theme"
39 |
40 | autodoc_typehints = "description"
41 |
42 | autodoc_type_aliases = {
43 | "random.PRNGKeyArray": "jax.random.PRNGKeyArray",
44 | "UnionCallableOptionalArray": "Union[Callable[[ndarray, Optional[ndarray]], ndarray], "
45 | "Callable[[Optional[ndarray]], ndarray]]",
46 | }
47 |
48 | latex_engine = "pdflatex"
49 |
50 | titles_only = True
51 |
52 | rst_prolog = """
53 | .. role:: python(code)
54 | :language: python
55 | """
56 |
57 | html_logo = "logo.svg"
58 |
59 | html_static_path = ["_static"]
60 | html_css_files = [
61 | "css/custom.css",
62 | ]
63 |
64 | html_theme_options = {
65 | "collapse_navigation": False,
66 | "prev_next_buttons_location": "None",
67 | }
68 |
69 |
def linkcode_resolve(domain, info):
    """
    Called by sphinx's linkcode extension, which adds links directing the user to the
    source code of the API objects being documented. The `domain` argument specifies which
    programming language the object belongs to. The `info` argument is a dictionary with
    information specific to the programming language of the object.

    For Python objects, this dictionary contains a `module` key with the module the object is in
    and a `fullname` key with the name of the object. This function uses this information to find
    the source file and range of lines the object is defined in and to generate a link pointing to
    those lines on GitHub.

    Args:
        domain: Language domain of the documented object; only "py" is handled.
        info: Dictionary with the `module` and `fullname` of the documented object.

    Returns:
        URL of the lines defining the object, or None if the object cannot be
        resolved or has no retrievable source (no link is generated in that case).
    """
    github_url = "https://github.com/CQCL/qujax/tree/develop/qujax"

    if domain != "py" or not info.get("module"):
        return None

    try:
        obj = importlib.import_module(info["module"])
        # `fullname` may be dotted (e.g. "Class.method"), so walk it one attribute at a time
        # rather than passing the whole string to a single getattr call.
        for attribute in info["fullname"].split("."):
            obj = getattr(obj, attribute)
    except (ImportError, AttributeError):
        return None

    try:
        source_file = inspect.getsourcefile(obj)
        source_lines, start_line = inspect.getsourcelines(obj)
    except (TypeError, OSError):
        # TypeError: built-in objects or plain values; OSError: source cannot be retrieved.
        return None

    if source_file is None:
        return None

    file_name = pathlib.Path(source_file).name
    end_line = start_line + len(source_lines) - 1

    return f"{github_url}/{file_name}#L{start_line}-L{end_line}"
100 |
--------------------------------------------------------------------------------
/docs/densitytensor.rst:
--------------------------------------------------------------------------------
1 | Mixed state simulation
2 | =======================
3 |
4 | .. toctree::
5 | :titlesonly:
6 |
7 | densitytensor/all_zeros_densitytensor
8 | densitytensor/kraus
9 | densitytensor/get_params_to_densitytensor_func
10 | densitytensor/partial_trace
11 | densitytensor/get_densitytensor_to_expectation_func
12 | densitytensor/get_densitytensor_to_sampled_expectation_func
13 | densitytensor/densitytensor_to_measurement_probabilities
14 | densitytensor/densitytensor_to_measured_densitytensor
15 | densitytensor/statetensor_to_densitytensor
16 |
17 |
--------------------------------------------------------------------------------
/docs/densitytensor/all_zeros_densitytensor.rst:
--------------------------------------------------------------------------------
1 | all_zeros_densitytensor
2 | ==============================================
3 |
4 | .. autofunction:: qujax.all_zeros_densitytensor
5 |
--------------------------------------------------------------------------------
/docs/densitytensor/densitytensor_to_measured_densitytensor.rst:
--------------------------------------------------------------------------------
1 | densitytensor_to_measured_densitytensor
2 | ==============================================
3 |
4 | .. autofunction:: qujax.densitytensor_to_measured_densitytensor
5 |
--------------------------------------------------------------------------------
/docs/densitytensor/densitytensor_to_measurement_probabilities.rst:
--------------------------------------------------------------------------------
1 | densitytensor_to_measurement_probabilities
2 | ==============================================
3 |
4 | .. autofunction:: qujax.densitytensor_to_measurement_probabilities
5 |
--------------------------------------------------------------------------------
/docs/densitytensor/get_densitytensor_to_expectation_func.rst:
--------------------------------------------------------------------------------
1 | get_densitytensor_to_expectation_func
2 | ==============================================
3 |
4 | .. autofunction:: qujax.get_densitytensor_to_expectation_func
5 |
--------------------------------------------------------------------------------
/docs/densitytensor/get_densitytensor_to_sampled_expectation_func.rst:
--------------------------------------------------------------------------------
1 | get_densitytensor_to_sampled_expectation_func
2 | ==============================================
3 |
4 | .. autofunction:: qujax.get_densitytensor_to_sampled_expectation_func
5 |
--------------------------------------------------------------------------------
/docs/densitytensor/get_params_to_densitytensor_func.rst:
--------------------------------------------------------------------------------
1 | get_params_to_densitytensor_func
2 | ==============================================
3 |
4 | .. autofunction:: qujax.get_params_to_densitytensor_func
5 |
--------------------------------------------------------------------------------
/docs/densitytensor/kraus.rst:
--------------------------------------------------------------------------------
1 | kraus
2 | ==============================================
3 |
4 | .. autofunction:: qujax.kraus
5 |
--------------------------------------------------------------------------------
/docs/densitytensor/partial_trace.rst:
--------------------------------------------------------------------------------
1 | partial_trace
2 | ==============================================
3 |
4 | .. autofunction:: qujax.partial_trace
5 |
--------------------------------------------------------------------------------
/docs/densitytensor/statetensor_to_densitytensor.rst:
--------------------------------------------------------------------------------
1 | statetensor_to_densitytensor
2 | ==============================================
3 |
4 | .. autofunction:: qujax.statetensor_to_densitytensor
5 |
--------------------------------------------------------------------------------
/docs/examples.md:
--------------------------------------------------------------------------------
1 | # Examples
2 |
3 | Below are some use-case notebooks. These illustrate both the flexibility of qujax and the power of directly interfacing with JAX and its package ecosystem.
4 |
5 | - [heisenberg_vqe.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/heisenberg_vqe.ipynb) - an implementation of the variational quantum eigensolver to find the ground state of a quantum Hamiltonian.
6 | - [maxcut_vqe.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/maxcut_vqe.ipynb) - an implementation of the variational quantum eigensolver to solve a MaxCut problem. Trains with Adam via [`optax`](https://github.com/deepmind/optax) and uses more realistic stochastic parameter shift gradients.
7 | - [noise_channel.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/noise_channel.ipynb) - uses the densitytensor simulator to fit the parameters of a depolarising noise channel.
8 | - [qaoa.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/qaoa.ipynb) - uses a problem-inspired QAOA ansatz to find the ground state of a quantum Hamiltonian. Demonstrates how to encode more sophisticated parameters that control multiple gates.
9 | - [barren_plateaus.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/barren_plateaus.ipynb) - illustrates how to sample gradients of a cost function to identify the presence of barren plateaus. Uses batched/vectorized evaluation to speed up computation.
10 | - [reducing_jit_compilation_time.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/reducing_jit_compilation_time.ipynb) - explains how JAX compilation works and how that can lead to excessive compilation times when executing quantum circuits. Presents a solution for the case of circuits with a repeating structure.
11 | - [variational_inference.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/variational_inference.ipynb) - uses a parameterised quantum circuit as a variational distribution to fit to a target probability mass function. Uses Adam via [`optax`](https://github.com/deepmind/optax) to minimise the KL divergence between circuit and target distributions.
12 | - [classification.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/classification.ipynb) - train a quantum circuit for binary classification using data re-uploading.
13 | - [generative_modelling.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/generative_modelling.ipynb) - uses a parameterised quantum circuit as a generative model for a real life dataset. Trains via stochastic gradient Langevin dynamics on the maximum mean discrepancy between statetensor and dataset.
14 |
15 | The [pytket](https://github.com/CQCL/pytket) repository also contains `tk_to_qujax` implementations for some of the above at [pytket-qujax_classification.ipynb](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax-classification.ipynb),
16 | [pytket-qujax_heisenberg_vqe.ipynb](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax_heisenberg_vqe.ipynb)
17 | and [pytket-qujax_qaoa.ipynb](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax_qaoa.ipynb).
--------------------------------------------------------------------------------
/docs/gates.rst:
--------------------------------------------------------------------------------
1 | Quantum gates
2 | =======================
3 |
4 | This is a list of gates that qujax supports natively. You can also define custom operations by directly passing an array or function instead of a string, as documented in :doc:`statetensor/get_params_to_statetensor_func` and :doc:`densitytensor/get_params_to_densitytensor_func`.
5 |
6 | .. list-table::
7 | :widths: 25 25 50
8 | :header-rows: 1
9 |
10 | * - Name(s)
11 | - String
12 | - Matrix representation
13 | * - Pauli X gate
14 |
15 | NOT gate
16 |
17 | Bit flip gate
18 | - :python:`"X"`
19 | - .. math:: X = NOT = \begin{bmatrix}0 & 1\\ 1 & 0 \end{bmatrix}
20 | * - Pauli Y gate
21 | - :python:`"Y"`
22 | - .. math:: Y = \begin{bmatrix}0 & -i\\ i & 0 \end{bmatrix}
23 | * - Pauli Z gate
24 |
25 | Phase flip gate
26 | - :python:`"Z"`
27 | - .. math:: Z = \begin{bmatrix}1 & 0\\ 0 & -1 \end{bmatrix}
28 | * - Hadamard gate
29 | - :python:`"H"`
30 | - .. math:: H = \frac{1}{\sqrt{2}}\begin{bmatrix}1 & 1\\ 1 & -1 \end{bmatrix}
31 | * - S gate
32 |
33 | P (phase) gate
34 | - :python:`"S"`
35 | - .. math:: S = P = \sqrt{Z} = \begin{bmatrix}1 & 0\\ 0 & i \end{bmatrix}
36 | * - Conjugated S gate
37 | - :python:`"Sdg"`
38 | - .. math:: S^\dagger = \begin{bmatrix}1 & 0\\ 0 & -i \end{bmatrix}
39 | * - T gate
40 | - :python:`"T"`
41 | - .. math:: T = \sqrt[4]{Z} = \begin{bmatrix}1 & 0\\ 0 & \exp(\frac{\pi i}{4}) \end{bmatrix}
42 | * - Conjugated T gate
43 | - :python:`"Tdg"`
44 | - .. math:: T^\dagger = \begin{bmatrix}1 & 0\\ 0 & \exp(-\frac{\pi i}{4}) \end{bmatrix}
45 | * - V gate
46 | - :python:`"V"`
47 | - .. math:: V = \sqrt{X} = \frac{1}{\sqrt{2}}\begin{bmatrix}1 & -i\\ -i & 1 \end{bmatrix}
48 | * - Conjugated V gate
49 | - :python:`"Vdg"`
50 | - .. math:: V^\dagger = \frac{1}{\sqrt{2}}\begin{bmatrix}1 & i\\ i & 1 \end{bmatrix}
51 | * - SX gate
52 | - :python:`"SX"`
53 | - .. math:: SX = \sqrt{X} = \frac{1}{2}\begin{bmatrix}1 + i & 1 - i\\ 1 - i & 1 + i \end{bmatrix}
54 | * - Conjugated SX gate
55 | - :python:`"SXdg"`
56 | - .. math:: SX^\dagger = \frac{1}{2}\begin{bmatrix}1 - i & 1 + i\\ 1 + i & 1 - i \end{bmatrix}
57 | * - CX (Controlled X) gate
58 |
59 | CNOT gate
60 | - :python:`"CX"`
61 | - .. math:: CX = CNOT = \begin{bmatrix}I & 0\\ 0 & X \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 1 \\ 0 & 0 & 1 & 0 \end{bmatrix}
62 | * - CY (Controlled Y) gate
63 | - :python:`"CY"`
64 | - .. math:: CY = \begin{bmatrix}I & 0\\ 0 & Y \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & -i \\ 0 & 0 & i & 0 \end{bmatrix}
65 | * - Controlled Z gate
66 | - :python:`"CZ"`
67 | - .. math:: CZ = \begin{bmatrix}I & 0\\ 0 & Z \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & -1 \end{bmatrix}
68 | * - Controlled Hadamard gate
69 | - :python:`"CH"`
70 | - .. math:: CH = \begin{bmatrix}I & 0\\ 0 & H \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & \frac{1}{\sqrt{2}} & \frac{1}{\sqrt{2}} \\ 0 & 0 & \frac{1}{\sqrt{2}} & -\frac{1}{\sqrt{2}} \end{bmatrix}
71 | * - Controlled V gate
72 | - :python:`"CV"`
73 | - .. math:: CV = \begin{bmatrix}I & 0\\ 0 & V \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & \frac{1}{\sqrt{2}} & -\frac{i}{\sqrt{2}} \\ 0 & 0 & -\frac{i}{\sqrt{2}} & \frac{1}{\sqrt{2}} \end{bmatrix}
74 | * - Conjugated controlled V gate
75 | - :python:`"CVdg"`
76 | - .. math:: CV^\dagger = \begin{bmatrix}I & 0\\ 0 & V^\dagger \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & \frac{1}{\sqrt{2}} & \frac{i}{\sqrt{2}} \\ 0 & 0 & \frac{i}{\sqrt{2}} & \frac{1}{\sqrt{2}} \end{bmatrix}
77 | * - Controlled SX gate
78 | - :python:`"CSX"`
79 | - .. math:: CSX = \begin{bmatrix}I & 0\\ 0 & SX \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & \frac{1+i}{2} & \frac{1-i}{2} \\ 0 & 0 & \frac{1-i}{2} & \frac{1+i}{2} \end{bmatrix}
80 | * - Conjugated controlled SX gate
81 | - :python:`"CSXdg"`
82 | - .. math:: CSX^\dagger = \begin{bmatrix}I & 0\\ 0 & SX^\dagger \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & \frac{1-i}{2} & \frac{1+i}{2} \\ 0 & 0 & \frac{1+i}{2} & \frac{1-i}{2} \end{bmatrix}
83 | * - Toffoli gate
84 |
85 | CCX
86 |
87 | CCNOT
88 | - :python:`"CCX"`
89 | - .. math:: CCX = \begin{bmatrix}I & 0 & 0 & 0\\ 0 & I & 0 & 0 \\ 0 & 0 & I & 0 \\ 0 & 0 & 0 & X \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 1 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 1 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 \\ 0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 \\ \end{bmatrix}
90 | * - Echoed cross-resonance gate
91 | - :python:`"ECR"`
92 | - .. math:: ECR = \begin{bmatrix}0 & V^\dagger \\ V & 0 \end{bmatrix} = \frac{1}{\sqrt{2}}\begin{bmatrix}0 & 0 & 1 & i \\ 0 & 0 & i & 1 \\ 1 & -i & 0 & 0 \\ -i & 1 & 0 & 0 \end{bmatrix}
93 | * - Swap gate
94 | - :python:`"SWAP"`
95 | - .. math:: SWAP = \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix}
96 | * - Controlled swap gate
97 | - :python:`"CSWAP"`
98 | - .. math:: CSWAP = \begin{bmatrix}I & 0 \\ 0 & SWAP \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 1 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 1 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 \end{bmatrix}
99 | * - Rotation around X axis
100 | - :python:`"Rx"`
101 | - .. math:: R_X(\theta) = \exp\left(-i \frac{\pi}{2} \theta X\right) = \begin{bmatrix} \cos( \frac{\pi}{2} \theta) & - i \sin( \frac{\pi}{2} \theta) \\ - i \sin( \frac{\pi}{2} \theta) & \cos( \frac{\pi}{2} \theta) \end{bmatrix}
102 | * - Rotation around Y axis
103 | - :python:`"Ry"`
104 | - .. math:: R_Y(\theta) = \exp\left(-i \frac{\pi}{2} \theta Y\right) = \begin{bmatrix} \cos( \frac{\pi}{2} \theta) & - \sin( \frac{\pi}{2} \theta) \\ \sin( \frac{\pi}{2} \theta) & \cos( \frac{\pi}{2} \theta) \end{bmatrix}
105 | * - Rotation around Z axis
106 | - :python:`"Rz"`
107 | - .. math:: R_Z(\theta) = \exp\left(-i \frac{\pi}{2} \theta Z\right) = \begin{bmatrix} \exp( -i \frac{\pi}{2} \theta) & 0 \\ 0 & \exp( i \frac{\pi}{2} \theta) \end{bmatrix}
108 | * - Controlled rotation around X axis
109 | - :python:`"CRx"`
110 | - .. math:: CR_X(\theta) = \begin{bmatrix}I & 0\\ 0 & R_X(\theta) \end{bmatrix} = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & \cos( \frac{\pi}{2} \theta) & - i \sin( \frac{\pi}{2} \theta) \\ 0 & 0 & - i \sin( \frac{\pi}{2} \theta) & \cos( \frac{\pi}{2} \theta) \end{bmatrix}
111 | * - Controlled rotation around Y axis
112 | - :python:`"CRy"`
113 | - .. math:: CR_Y(\theta) = \begin{bmatrix}I & 0\\ 0 & R_Y(\theta) \end{bmatrix} = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & \cos( \frac{\pi}{2} \theta) & - \sin( \frac{\pi}{2} \theta) \\ 0 & 0 & \sin( \frac{\pi}{2} \theta) & \cos( \frac{\pi}{2} \theta) \end{bmatrix}
114 | * - Controlled rotation around Z axis
115 | - :python:`"CRz"`
116 | - .. math:: CR_Z(\theta) = \begin{bmatrix}I & 0\\ 0 & R_Z(\theta)\end{bmatrix} = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & \exp( -i \frac{\pi}{2} \theta) & 0 \\ 0 & 0 & 0 & \exp( i \frac{\pi}{2} \theta) \end{bmatrix}
117 | * - U3
118 | - :python:`"U3"`
119 | - .. math:: U3(\alpha,\beta,\gamma) = \exp((\alpha + \beta) i \frac{\pi}{2}) R_Z(\beta) R_Y(\alpha) R_Z(\gamma)
120 | * - U1
121 | - :python:`"U1"`
122 | - .. math:: U1(\gamma) = U3(0, 0, \gamma)
123 | * - U2
124 | - :python:`"U2"`
125 | - .. math:: U2(\beta, \gamma) = U3(0.5, \beta, \gamma)
126 | * - Controlled U3
127 | - :python:`"CU3"`
128 | - .. math:: CU3(\alpha,\beta,\gamma) = \begin{bmatrix}I & 0\\ 0 & U3(\alpha,\beta,\gamma)\end{bmatrix}
129 | * - Controlled U1
130 | - :python:`"CU1"`
131 | - .. math:: CU1(\gamma) = \begin{bmatrix}I & 0\\ 0 & U1(\gamma)\end{bmatrix}
132 | * - Controlled U2
133 | - :python:`"CU2"`
134 | - .. math:: CU2(\beta, \gamma) = \begin{bmatrix}I & 0\\ 0 & U2(\beta, \gamma)\end{bmatrix}
135 | * - Imaginary swap
136 | - :python:`"ISWAP"`
137 | - .. math:: iSWAP(\theta) = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & \cos( \frac{\pi}{2} \theta) & i \sin( \frac{\pi}{2} \theta) & 0 \\ 0 & i \sin( \frac{\pi}{2} \theta) & \cos( \frac{\pi}{2} \theta) & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix}
138 | * - Phased imaginary swap
139 | - :python:`"PhasedISWAP"`
140 | - .. math:: PhasedISWAP(\phi, \theta) = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & \cos( \frac{\pi}{2} \theta) & \exp(2i\pi \phi) i \sin( \frac{\pi}{2} \theta) & 0 \\ 0 & \exp(- 2i\pi \phi) i \sin( \frac{\pi}{2} \theta) & \cos( \frac{\pi}{2} \theta) & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix}
141 | * - XXPhase
142 |
143 | XX interaction
144 | - :python:`"XXPhase"`
145 | - .. math:: R_{XX}(\theta) = \exp\left(-i \frac{\pi}{2} \theta X\otimes X\right) = \begin{bmatrix} \cos( \frac{\pi}{2} \theta) & 0 & 0 & -i \sin( \frac{\pi}{2} \theta) \\ 0 & \cos( \frac{\pi}{2} \theta) & -i \sin( \frac{\pi}{2} \theta) & 0 \\ 0 & -i \sin( \frac{\pi}{2} \theta) & \cos( \frac{\pi}{2} \theta) & 0 \\ -i \sin( \frac{\pi}{2} \theta) & 0 & 0 & \cos( \frac{\pi}{2} \theta) \end{bmatrix}
146 | * - YYPhase
147 |
148 | YY interaction
149 | - :python:`"YYPhase"`
150 | - .. math:: R_{YY}(\theta) = \exp\left(-i \frac{\pi}{2} \theta Y\otimes Y\right) = \begin{bmatrix} \cos( \frac{\pi}{2} \theta) & 0 & 0 & i \sin( \frac{\pi}{2} \theta) \\ 0 & \cos( \frac{\pi}{2} \theta) & -i \sin( \frac{\pi}{2} \theta) & 0 \\ 0 & -i \sin( \frac{\pi}{2} \theta) & \cos( \frac{\pi}{2} \theta) & 0 \\ i \sin( \frac{\pi}{2} \theta) & 0 & 0 & \cos( \frac{\pi}{2} \theta) \end{bmatrix}
151 | * - ZZPhase
152 |
153 | ZZ interaction
154 | - :python:`"ZZPhase"`
155 | - .. math:: R_{ZZ}(\theta) = \exp\left(-i \frac{\pi}{2} \theta Z\otimes Z\right) = \begin{bmatrix} \exp( -i \frac{\pi}{2} \theta) & 0 & 0 & 0 \\ 0 & \exp( i \frac{\pi}{2} \theta) & 0 & 0 \\ 0 & 0 & \exp( i \frac{\pi}{2} \theta) & 0 \\ 0 & 0 & 0 & \exp( -i \frac{\pi}{2} \theta) \end{bmatrix}
156 | * - ZZMax
157 | - :python:`"ZZMax"`
158 | - .. math:: ZZMax = R_{ZZ}(0.5)
159 | * - PhasedX
160 | - :python:`"PhasedX"`
161 | - .. math:: PhasedX(\theta, \phi) = R_Z(\phi)R_X(\theta)R_Z(-\phi)
--------------------------------------------------------------------------------
/docs/getting_started.rst:
--------------------------------------------------------------------------------
1 | Getting started
2 | #################
3 |
4 | **Important note**: qujax circuit parameters are expressed in units of :math:`\pi` (e.g. in the range :math:`[0,2]` as opposed to :math:`[0, 2\pi]`).
5 |
6 | *********************
7 | Pure state simulation
8 | *********************
9 |
10 | We start by defining the quantum gates making up the circuit, along with the qubits that they act on and the indices of the parameters for each gate.
11 |
12 | A list of all gates can be found in :doc:`gates` (custom operations can be included by passing an array or function instead of a string, as documented in :doc:`statetensor/get_params_to_statetensor_func`).
13 |
14 | .. code-block:: python
15 |
16 | from jax import numpy as jnp
17 | import qujax
18 |
19 | # List of quantum gates
20 | circuit_gates = ['H', 'Ry', 'CZ']
21 | # Indices of qubits the gates will be applied to
22 | circuit_qubit_inds = [[0], [0], [0, 1]]
23 | # Indices of parameters each parameterised gate will use
24 | circuit_params_inds = [[], [0], []]
25 |
26 | qujax.print_circuit(circuit_gates, circuit_qubit_inds, circuit_params_inds);
27 | # q0: -----H-----Ry[0]-----◯---
28 | # |
29 | # q1: ---------------------CZ--
30 |
31 | We then translate the circuit to a pure function :python:`param_to_st` that takes a set of parameters and an (optional) initial quantum state as its input.
32 |
33 | .. code-block:: python
34 |
35 | param_to_st = qujax.get_params_to_statetensor_func(circuit_gates,
36 | circuit_qubit_inds,
37 | circuit_params_inds)
38 |
39 | param_to_st(jnp.array([0.1]))
40 | # Array([[0.58778524+0.j, 0. +0.j],
41 | # [0.80901706+0.j, 0. +0.j]], dtype=complex64)
42 |
43 | The optional initial state can be passed to :python:`param_to_st` using the :python:`statetensor_in` argument. When it is not provided, the initial state defaults to :math:`\ket{0...0}`.
44 |
45 | Note that qujax represents quantum states as *statetensors*. For example, for :math:`N=4` qubits, the corresponding vector space has :math:`2^4` dimensions, and a quantum state in this space is represented by an array with shape :python:`(2,2,2,2)`. The usual statevector representation with shape :python:`(16,)` can be obtained by calling :python:`.flatten()` or :python:`.reshape(-1)` or :python:`.reshape(2**N)` on this array.
46 |
47 | In the statetensor representation, the coefficient associated with e.g. basis state :math:`\ket{0101}` is given by :python:`arr[0,1,0,1]`; each axis corresponds to one qubit.
48 |
49 | .. code-block:: python
50 |
51 | param_to_st(jnp.array([0.1])).flatten()
52 | # Array([0.58778524+0.j, 0.+0.j, 0.80901706+0.j, 0.+0.j], dtype=complex64)
53 |
54 |
55 | Finally, by defining an observable, we can map the statetensor to an expectation value. A general observable is specified using lists of Pauli matrices, the qubits they act on, and the associated coefficients.
56 |
57 | For example, :math:`Z_1Z_2Z_3Z_4 - 2 X_3` would be written as :python:`[['Z','Z','Z','Z'], ['X']], [[1,2,3,4], [3]], [1., -2.]`.
58 |
59 | .. code-block:: python
60 |
61 | st_to_expectation = qujax.get_statetensor_to_expectation_func([['Z']], [[0]], [1.])
62 |
63 |
64 | Combining :python:`param_to_st` and :python:`st_to_expectation` gives us a parameter to expectation function that can be automatically differentiated using JAX.
65 |
66 | .. code-block:: python
67 |
68 | from jax import value_and_grad
69 |
70 | param_to_expectation = lambda param: st_to_expectation(param_to_st(param))
71 | expectation_and_grad = value_and_grad(param_to_expectation)
72 | expectation_and_grad(jnp.array([0.1]))
73 | # (Array(-0.3090171, dtype=float32),
74 | # Array([-2.987832], dtype=float32))
75 |
76 | ***********************
77 | Mixed state simulation
78 | ***********************
79 | Mixed state simulations are analogous to the above, but with calls to :doc:`densitytensor/get_params_to_densitytensor_func` and :doc:`densitytensor/get_densitytensor_to_expectation_func` instead.
80 |
81 | .. code-block:: python
82 |
83 | param_to_dt = qujax.get_params_to_densitytensor_func(circuit_gates,
84 | circuit_qubit_inds,
85 | circuit_params_inds)
86 | dt = param_to_dt(jnp.array([0.1]))
87 | dt.shape
88 | # (2, 2, 2, 2)
89 |
90 | dt_to_expectation = qujax.get_densitytensor_to_expectation_func([['Z']], [[0]], [1.])
91 | dt_to_expectation(dt)
92 | # Array(-0.3090171, dtype=float32)
93 |
94 | Similarly to a statetensor, which represents the reshaped :math:`2^N`-dimensional statevector of a pure quantum state, a *densitytensor* represents the reshaped :math:`2^N \times 2^N` density matrix of a mixed quantum state. This densitytensor has shape :python:`(2,) * 2 * N`.
95 |
96 | For example, for :math:`N=2`, and a mixed state :math:`\frac{1}{2} (\ket{00}\bra{11} + \ket{11}\bra{00} + \ket{11}\bra{11} + \ket{00}\bra{00})`, the corresponding densitytensor :python:`dt` is such that :python:`dt[0,0,1,1] = dt[1,1,0,0] = dt[1,1,1,1] = dt[0,0,0,0] = 1/2`, and all other entries are zero.
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 |
2 | Welcome to qujax's documentation!
3 | =================================
4 |
5 | ``qujax`` is a `JAX <https://github.com/google/jax>`_-based Python library for the classical simulation of quantum circuits. It is designed to be *simple*, *fast* and *flexible*.
6 |
7 |
8 | It follows a functional programming design by translating circuits into pure functions. This allows qujax to `seamlessly and directly interface with JAX `_, enabling direct access to its powerful automatic differentiation tools, just-in-time compiler, vectorization capabilities, GPU/TPU integration and growing ecosystem of packages.
9 |
10 | If you are new to the library, we recommend that you head to the :doc:`getting_started` section of the documentation. More advanced use-cases, including the training of parameterised quantum circuits, can be found in :doc:`examples`.
11 |
12 | The source code can be found on `GitHub <https://github.com/CQCL/qujax>`_. The `pytket-qujax <https://github.com/CQCL/pytket-qujax>`_ extension can be used to translate a `tket <https://github.com/CQCL/tket>`_ circuit directly into ``qujax``.
13 |
14 | **Important note**: qujax circuit parameters are expressed in units of :math:`\pi` (e.g. in the range :math:`[0,2]` as opposed to :math:`[0, 2\pi]`).
15 |
16 | Install
17 | =================================
18 | ``qujax`` is hosted on `PyPI <https://pypi.org/project/qujax/>`_ and can be installed with
19 |
20 | .. code-block:: bash
21 |
22 | pip install qujax
23 |
24 | Cite
25 | =================================
26 | If you have used qujax in your code or research, we kindly ask that you cite it. You can use the following BibTeX entry for this:
27 |
28 | .. code-block:: bibtex
29 |
30 | @article{qujax2023,
31 | author = {Duffield, Samuel and Matos, Gabriel and Johannsen, Melf},
32 | doi = {10.21105/joss.05504},
33 | journal = {Journal of Open Source Software},
34 | month = sep,
35 | number = {89},
36 | pages = {5504},
37 | title = {{qujax: Simulating quantum circuits with JAX}},
38 | url = {https://joss.theoj.org/papers/10.21105/joss.05504},
39 | volume = {8},
40 | year = {2023}
41 | }
42 |
43 | Contents
44 | =================================
45 |
46 | .. toctree::
47 | :caption: Documentation:
48 | :titlesonly:
49 |
50 | Getting started <getting_started>
51 | Examples <examples>
52 | List of gates <gates>
53 |
54 | .. toctree::
55 | :caption: API Reference:
56 | :titlesonly:
57 | :maxdepth: 1
58 |
59 | Pure state simulation <statetensor>
60 | Mixed state simulation <densitytensor>
61 | Utility functions <utils>
62 |
63 | .. toctree::
64 | :caption: Links:
65 | :hidden:
66 |
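67 | GitHub <https://github.com/CQCL/qujax>
68 | Paper <https://joss.theoj.org/papers/10.21105/joss.05504>
69 | PyPI <https://pypi.org/project/qujax/>
70 | pytket-qujax <https://github.com/CQCL/pytket-qujax>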
71 |
--------------------------------------------------------------------------------
/docs/logo.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
53 |
--------------------------------------------------------------------------------
/docs/logo_dark_mode.svg:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
56 |
--------------------------------------------------------------------------------
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx >= 3
2 | sphinx-autodoc-typehints
3 | sphinx-rtd-theme >= 1.0.0
4 |
5 | jax
6 | jaxlib
--------------------------------------------------------------------------------
/docs/statetensor.rst:
--------------------------------------------------------------------------------
1 | Pure state simulation
2 | =======================
3 |
4 | .. toctree::
5 | :titlesonly:
6 |
7 | statetensor/all_zeros_statetensor
8 | statetensor/apply_gate
9 | statetensor/get_params_to_statetensor_func
10 | statetensor/get_params_to_unitarytensor_func
11 | statetensor/get_statetensor_to_expectation_func
12 | statetensor/get_statetensor_to_sampled_expectation_func
13 |
14 |
--------------------------------------------------------------------------------
/docs/statetensor/all_zeros_statetensor.rst:
--------------------------------------------------------------------------------
1 | all_zeros_statetensor
2 | ==============================================
3 |
4 | .. autofunction:: qujax.all_zeros_statetensor
5 |
--------------------------------------------------------------------------------
/docs/statetensor/apply_gate.rst:
--------------------------------------------------------------------------------
1 | apply_gate
2 | ==============================================
3 |
4 | .. autofunction:: qujax.apply_gate
5 |
--------------------------------------------------------------------------------
/docs/statetensor/get_params_to_statetensor_func.rst:
--------------------------------------------------------------------------------
1 | get_params_to_statetensor_func
2 | ==============================================
3 |
4 | .. autofunction:: qujax.get_params_to_statetensor_func
5 |
--------------------------------------------------------------------------------
/docs/statetensor/get_params_to_unitarytensor_func.rst:
--------------------------------------------------------------------------------
1 | get_params_to_unitarytensor_func
2 | ==============================================
3 |
4 | .. autofunction:: qujax.get_params_to_unitarytensor_func
5 |
--------------------------------------------------------------------------------
/docs/statetensor/get_statetensor_to_expectation_func.rst:
--------------------------------------------------------------------------------
1 | get_statetensor_to_expectation_func
2 | ==============================================
3 |
4 | .. autofunction:: qujax.get_statetensor_to_expectation_func
5 |
--------------------------------------------------------------------------------
/docs/statetensor/get_statetensor_to_sampled_expectation_func.rst:
--------------------------------------------------------------------------------
1 | get_statetensor_to_sampled_expectation_func
2 | ==============================================
3 |
4 | .. autofunction:: qujax.get_statetensor_to_sampled_expectation_func
5 |
--------------------------------------------------------------------------------
/docs/utils.rst:
--------------------------------------------------------------------------------
1 | Utility functions
2 | =======================
3 |
4 | .. toctree::
5 | :titlesonly:
6 |
7 | utils/bitstrings_to_integers
8 | utils/check_circuit
9 | utils/integers_to_bitstrings
10 | utils/print_circuit
11 | utils/repeat_circuit
12 | utils/sample_bitstrings
13 | utils/sample_integers
14 |
--------------------------------------------------------------------------------
/docs/utils/bitstrings_to_integers.rst:
--------------------------------------------------------------------------------
1 | bitstrings_to_integers
2 | ==============================================
3 |
4 | .. autofunction:: qujax.bitstrings_to_integers
5 |
--------------------------------------------------------------------------------
/docs/utils/check_circuit.rst:
--------------------------------------------------------------------------------
1 | check_circuit
2 | ==============================================
3 |
4 | .. autofunction:: qujax.check_circuit
5 |
--------------------------------------------------------------------------------
/docs/utils/integers_to_bitstrings.rst:
--------------------------------------------------------------------------------
1 | integers_to_bitstrings
2 | ==============================================
3 |
4 | .. autofunction:: qujax.integers_to_bitstrings
5 |
--------------------------------------------------------------------------------
/docs/utils/print_circuit.rst:
--------------------------------------------------------------------------------
1 | print_circuit
2 | ==============================================
3 |
4 | .. autofunction:: qujax.print_circuit
5 |
--------------------------------------------------------------------------------
/docs/utils/repeat_circuit.rst:
--------------------------------------------------------------------------------
1 | repeat_circuit
2 | ==============================================
3 |
4 | .. autofunction:: qujax.repeat_circuit
5 |
--------------------------------------------------------------------------------
/docs/utils/sample_bitstrings.rst:
--------------------------------------------------------------------------------
1 | sample_bitstrings
2 | ==============================================
3 |
4 | .. autofunction:: qujax.sample_bitstrings
5 |
--------------------------------------------------------------------------------
/docs/utils/sample_integers.rst:
--------------------------------------------------------------------------------
1 | sample_integers
2 | ==============================================
3 |
4 | .. autofunction:: qujax.sample_integers
5 |
--------------------------------------------------------------------------------
/examples/hidalgo_stamp.txt:
--------------------------------------------------------------------------------
1 | 0.060 0.064 0.064 0.065 0.066 0.068 0.069 0.069 0.069 0.069 0.069 0.069
2 | 0.069 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.070
3 | 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.070
4 | 0.070 0.070 0.070 0.071 0.071 0.071 0.071 0.071 0.071 0.071 0.071 0.071
5 | 0.071 0.071 0.071 0.071 0.071 0.071 0.071 0.071 0.071 0.071 0.071 0.072
6 | 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072
7 | 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072
8 | 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.073 0.073 0.073 0.073 0.073
9 | 0.073 0.073 0.073 0.073 0.073 0.073 0.074 0.074 0.074 0.074 0.074 0.074
10 | 0.074 0.074 0.074 0.074 0.075 0.075 0.075 0.075 0.075 0.075 0.075 0.075
11 | 0.075 0.075 0.075 0.075 0.075 0.075 0.075 0.075 0.075 0.075 0.075 0.075
12 | 0.076 0.076 0.076 0.076 0.076 0.076 0.076 0.076 0.076 0.076 0.076 0.076
13 | 0.076 0.076 0.076 0.076 0.076 0.076 0.077 0.077 0.077 0.077 0.077 0.077
14 | 0.077 0.077 0.077 0.077 0.077 0.078 0.078 0.078 0.078 0.078 0.078 0.078
15 |
16 | 0.078 0.078 0.078 0.078 0.078 0.078 0.078 0.078 0.078 0.078 0.078 0.078
17 | 0.078 0.078 0.078 0.078 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079
18 | 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079
19 | 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079
20 | 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.080 0.080
21 | 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080
22 | 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080
23 | 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.081
24 | 0.081 0.081 0.081 0.081 0.081 0.081 0.081 0.081 0.081 0.081 0.081 0.081
25 | 0.081 0.081 0.082 0.082 0.082 0.082 0.082 0.082 0.082 0.082 0.082 0.082
26 | 0.082 0.082 0.082 0.082 0.082 0.082 0.082 0.082 0.083 0.083 0.083 0.083
27 | 0.083 0.083 0.083 0.084 0.084 0.084 0.085 0.085 0.086 0.086 0.087 0.088
28 | 0.088 0.089 0.089 0.089 0.089 0.089 0.089 0.089 0.089 0.089 0.089 0.090
29 | 0.090 0.090 0.090 0.090 0.090 0.090 0.090 0.090 0.091 0.091 0.091 0.092
30 |
31 | 0.092 0.092 0.092 0.092 0.093 0.093 0.093 0.093 0.093 0.093 0.094 0.094
32 | 0.094 0.095 0.095 0.096 0.096 0.096 0.097 0.097 0.097 0.097 0.097 0.097
33 | 0.097 0.098 0.098 0.098 0.098 0.098 0.099 0.099 0.099 0.099 0.099 0.100
34 | 0.100 0.100 0.100 0.100 0.100 0.100 0.100 0.100 0.100 0.100 0.100 0.100
35 | 0.100 0.100 0.101 0.101 0.101 0.101 0.101 0.101 0.101 0.101 0.101 0.102
36 | 0.102 0.102 0.102 0.102 0.102 0.102 0.102 0.103 0.103 0.103 0.103 0.103
37 | 0.103 0.103 0.104 0.104 0.105 0.105 0.105 0.105 0.105 0.106 0.106 0.106
38 | 0.106 0.107 0.107 0.107 0.108 0.108 0.108 0.108 0.108 0.108 0.108 0.109
39 | 0.109 0.109 0.109 0.109 0.109 0.109 0.110 0.110 0.110 0.110 0.110 0.110
40 | 0.110 0.110 0.110 0.110 0.110 0.111 0.111 0.111 0.111 0.112 0.112 0.112
41 | 0.112 0.112 0.114 0.114 0.114 0.115 0.115 0.115 0.117 0.119 0.119 0.119
42 | 0.119 0.120 0.120 0.120 0.121 0.122 0.122 0.123 0.123 0.125 0.125 0.128
43 | 0.129 0.129 0.129 0.130 0.131
44 |
--------------------------------------------------------------------------------
/examples/qaoa.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "a52b609d",
6 | "metadata": {},
7 | "source": [
8 | "# QAOA with `qujax`\n",
9 | "\n",
10 | "In this notebook, we will consider QAOA on an Ising Hamiltonian. In particular, we will demonstrate how to encode a circuit with parameters that control multiple gates."
11 | ]
12 | },
13 | {
14 | "cell_type": "code",
15 | "execution_count": 1,
16 | "id": "2678d061",
17 | "metadata": {},
18 | "outputs": [],
19 | "source": [
20 | "from jax import numpy as jnp, random, value_and_grad, jit\n",
21 | "import matplotlib.pyplot as plt\n",
22 | "\n",
23 | "import qujax"
24 | ]
25 | },
26 | {
27 | "cell_type": "markdown",
28 | "id": "57aaa7eb",
29 | "metadata": {},
30 | "source": [
31 | "# QAOA\n",
32 | "\n",
33 | "The Quantum Approximate Optimization Algorithm (QAOA), first introduced by [Farhi et al.](https://arxiv.org/pdf/1411.4028.pdf), is a quantum variational algorithm used to solve optimization problems. It consists of a unitary $U(\\beta, \\gamma)$ formed by alternating repetitions of $U(\\beta)=e^{-i\\beta H_B}$ and $U(\\gamma)=e^{-i\\gamma H_P}$, where $H_B$ is the mixing Hamiltonian and $H_P$ the problem Hamiltonian. The goal is to find the parameters that minimize the expectation value of $H_P$.\n",
34 | "\n",
35 | "Given a depth $d$, the expression of the final unitary is $U(\\beta, \\gamma) = U(\\beta_d)U(\\gamma_d)\\cdots U(\\beta_1)U(\\gamma_1)$. Notice that for each repetition the parameters are different.\n",
36 | "\n",
37 | "\n",
38 | "## Problem Hamiltonian\n",
39 | "QAOA uses a problem-dependent ansatz. Therefore, we first need to know the problem that we want to solve. In this case we will consider an Ising Hamiltonian with only $Z$ interactions. Given a set $E$ of pairs of qubit indices, the problem Hamiltonian will be:\n",
40 | "$$H_P = \\sum_{(i, j) \\in E}\\alpha_{ij}Z_iZ_j,$$ \n",
41 | "where $\\alpha_{ij}$ are the coefficients.\n",
42 | "\n",
43 | "Let's build our problem Hamiltonian with random coefficients and a set of pairs for a given number of qubits:"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 2,
49 | "id": "6602f4ad",
50 | "metadata": {},
51 | "outputs": [],
52 | "source": [
53 | "n_qubits = 4\n",
54 | "hamiltonian_qubit_inds = [(0, 1), (1, 2), (0, 2), (1, 3)]\n",
55 | "hamiltonian_gates = [[\"Z\", \"Z\"]] * (len(hamiltonian_qubit_inds))\n",
56 | "\n",
57 | "\n",
58 | "# Notice that in order to use the random package from jax we first need to define a seeded key\n",
59 | "seed = 13\n",
60 | "key = random.PRNGKey(seed)\n",
61 | "coefficients = random.uniform(key, shape=(len(hamiltonian_qubit_inds),))"
62 | ]
63 | },
64 | {
65 | "cell_type": "code",
66 | "execution_count": 3,
67 | "id": "e80df4ff",
68 | "metadata": {},
69 | "outputs": [
70 | {
71 | "name": "stdout",
72 | "output_type": "stream",
73 | "text": [
74 | "Gates:\t [['Z', 'Z'], ['Z', 'Z'], ['Z', 'Z'], ['Z', 'Z']]\n",
75 | "Qubits:\t [(0, 1), (1, 2), (0, 2), (1, 3)]\n",
76 | "Coefficients:\t [0.6794174 0.2963785 0.2863201 0.31746793]\n"
77 | ]
78 | }
79 | ],
80 | "source": [
81 | "print(\"Gates:\\t\", hamiltonian_gates)\n",
82 | "print(\"Qubits:\\t\", hamiltonian_qubit_inds)\n",
83 | "print(\"Coefficients:\\t\", coefficients)"
84 | ]
85 | },
86 | {
87 | "cell_type": "markdown",
88 | "id": "b03a1f9d",
89 | "metadata": {},
90 | "source": [
91 | "## Variational Circuit\n",
92 | "\n",
93 | "Before constructing the circuit, we still need to select the mixing Hamiltonian. In our case, we will be using $X$ gates on each qubit, so $H_B = \\sum_{i=1}^{n}X_i$, where $n$ is the number of qubits. Notice that the unitary $U(\\beta)$, given this mixing Hamiltonian, is an $X$ rotation on each qubit with angle $\\beta$.\n",
94 | "\n",
95 | "As for the unitary corresponding to the problem Hamiltonian, $U(\\gamma)$, it has the following form:\n",
96 | "$$U(\\gamma)=\\prod_{(i, j) \\in E}e^{-i\\gamma\\alpha_{ij}Z_iZ_j}$$ \n",
97 | "\n",
98 | "The operation $e^{-i\\gamma\\alpha_{ij}Z_iZ_j}$ can be performed using two CNOT gates, with qubit $i$ as control and qubit $j$ as target, and a $Z$ rotation on qubit $j$ between them, with angle $\\gamma\\alpha_{ij}$.\n",
99 | "\n",
100 | "Finally, the initial state generally used with QAOA is an equal superposition of all the basis states. This can be achieved by adding a layer of Hadamard gates, one on each qubit, at the beginning of the circuit."
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": 4,
106 | "id": "008e895f",
107 | "metadata": {},
108 | "outputs": [],
109 | "source": [
110 | "depth = 3"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": 5,
116 | "id": "cf0fbfaa",
117 | "metadata": {},
118 | "outputs": [],
119 | "source": [
120 | "circuit_gates = []\n",
121 | "circuit_qubit_inds = []\n",
122 | "circuit_param_inds = []\n",
123 | "\n",
124 | "param_ind = 0\n",
125 | "\n",
126 | "# Initial State\n",
127 | "for i in range(n_qubits):\n",
128 | " circuit_gates.append(\"H\")\n",
129 | " circuit_qubit_inds.append([i])\n",
130 | " circuit_param_inds.append([])\n",
131 | "\n",
132 | "for d in range(depth):\n",
133 | " # Mixing Unitary\n",
134 | " for i in range(n_qubits):\n",
135 | " circuit_gates.append(\"Rx\")\n",
136 | " circuit_qubit_inds.append([i])\n",
137 | " circuit_param_inds.append([param_ind])\n",
138 | " param_ind += 1\n",
139 | "\n",
140 | " # Hamiltonian\n",
141 | " for index in range(len(hamiltonian_qubit_inds)):\n",
142 | " pair = hamiltonian_qubit_inds[index]\n",
143 | " coef = coefficients[index]\n",
144 | "\n",
145 | " circuit_gates.append(\"CX\")\n",
146 | " circuit_qubit_inds.append([pair[0], pair[1]])\n",
147 | " circuit_param_inds.append([])\n",
148 | "\n",
149 | " circuit_gates.append((lambda c: lambda p: qujax.gates.Rz(p * c))(coef))  # bind this term's coef now; a bare closure would late-bind and use only the last coef\n",
150 | " circuit_qubit_inds.append([pair[1]])\n",
151 | " circuit_param_inds.append([param_ind])\n",
152 | "\n",
153 | " circuit_gates.append(\"CX\")\n",
154 | " circuit_qubit_inds.append([pair[0], pair[1]])\n",
155 | " circuit_param_inds.append([])\n",
156 | " param_ind += 1"
157 | ]
158 | },
159 | {
160 | "cell_type": "markdown",
161 | "id": "24ce1e3c",
162 | "metadata": {},
163 | "source": [
164 | "Let's check the circuit:"
165 | ]
166 | },
167 | {
168 | "cell_type": "code",
169 | "execution_count": 6,
170 | "id": "c3fb239c",
171 | "metadata": {
172 | "scrolled": false
173 | },
174 | "outputs": [
175 | {
176 | "name": "stdout",
177 | "output_type": "stream",
178 | "text": [
179 | "q0: -----H-----Rx[0]-----◯---------------◯-------------------------------◯---------------◯---------------------Rx[2]-\n",
180 | " | | | | \n",
181 | "q1: -----H-----Rx[0]-----CX---Func[1]----CX------◯---------------◯-------|---------------|-------◯---------------◯---\n",
182 | " | | | | | | \n",
183 | "q2: -----H-----Rx[0]-----------------------------CX---Func[1]----CX------CX---Func[1]----CX------|---------------|---\n",
184 | " | | \n",
185 | "q3: -----H-----Rx[0]-----------------------------------------------------------------------------CX---Func[1]----CX--\n"
186 | ]
187 | }
188 | ],
189 | "source": [
190 | "qujax.print_circuit(\n",
191 | " circuit_gates, circuit_qubit_inds, circuit_param_inds, gate_ind_max=20\n",
192 | ");"
193 | ]
194 | },
195 | {
196 | "cell_type": "markdown",
197 | "id": "8c421d1d",
198 | "metadata": {},
199 | "source": [
200 | "Then, we invoke `qujax.get_params_to_statetensor_func`:"
201 | ]
202 | },
203 | {
204 | "cell_type": "code",
205 | "execution_count": 7,
206 | "id": "1af80d12",
207 | "metadata": {},
208 | "outputs": [],
209 | "source": [
210 | "param_to_st = qujax.get_params_to_statetensor_func(\n",
211 | " circuit_gates, circuit_qubit_inds, circuit_param_inds\n",
212 | ")"
213 | ]
214 | },
215 | {
216 | "cell_type": "markdown",
217 | "id": "689e7811",
218 | "metadata": {},
219 | "source": [
220 | "And we also construct the expectation map using the problem Hamiltonian via qujax: "
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": 8,
226 | "id": "0857d4bb",
227 | "metadata": {},
228 | "outputs": [],
229 | "source": [
230 | "st_to_expectation = qujax.get_statetensor_to_expectation_func(\n",
231 | " hamiltonian_gates, hamiltonian_qubit_inds, coefficients\n",
232 | ")\n",
233 | "\n",
234 | "param_to_expectation = lambda param: st_to_expectation(param_to_st(param))"
235 | ]
236 | },
237 | {
238 | "cell_type": "markdown",
239 | "id": "6144236c",
240 | "metadata": {},
241 | "source": [
242 | "# Training process\n",
243 | "We construct a function that, given a parameter vector, returns the value of the cost function and the gradient (we also `jit` to avoid recompilation):"
244 | ]
245 | },
246 | {
247 | "cell_type": "code",
248 | "execution_count": 9,
249 | "id": "069331d5",
250 | "metadata": {},
251 | "outputs": [],
252 | "source": [
253 | "cost_and_grad = jit(value_and_grad(param_to_expectation))"
254 | ]
255 | },
256 | {
257 | "cell_type": "markdown",
258 | "id": "8f8e0741",
259 | "metadata": {},
260 | "source": [
261 | "For the training process we'll use vanilla gradient descent with a constant stepsize:"
262 | ]
263 | },
264 | {
265 | "cell_type": "code",
266 | "execution_count": 10,
267 | "id": "b28018e2",
268 | "metadata": {},
269 | "outputs": [],
270 | "source": [
271 | "seed = 123\n",
272 | "key = random.PRNGKey(seed)\n",
273 | "init_param = random.uniform(key, shape=(param_ind,))\n",
274 | "\n",
275 | "n_steps = 150\n",
276 | "stepsize = 0.01"
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": 11,
282 | "id": "c0f53be5",
283 | "metadata": {},
284 | "outputs": [
285 | {
286 | "name": "stdout",
287 | "output_type": "stream",
288 | "text": [
289 | "Iteration: 1 \tCost: -0.09861969\r",
290 | "Iteration: 2 \tCost: -0.30488127\r",
291 | "Iteration: 3 \tCost: -0.43221825\r",
292 | "Iteration: 4 \tCost: -0.49104044\r",
293 | "Iteration: 5 \tCost: -0.51378363\r",
294 | "Iteration: 6 \tCost: -0.5220106\r",
295 | "Iteration: 7 \tCost: -0.52507806\r",
296 | "Iteration: 8 \tCost: -0.5263982\r",
297 | "Iteration: 9 \tCost: -0.52713084\r",
298 | "Iteration: 10 \tCost: -0.5276594\r",
299 | "Iteration: 11 \tCost: -0.5281093\r",
300 | "Iteration: 12 \tCost: -0.5285234\r",
301 | "Iteration: 13 \tCost: -0.5289158\r",
302 | "Iteration: 14 \tCost: -0.5292916\r",
303 | "Iteration: 15 \tCost: -0.5296537\r",
304 | "Iteration: 16 \tCost: -0.53000385\r",
305 | "Iteration: 17 \tCost: -0.5303429\r",
306 | "Iteration: 18 \tCost: -0.5306716\r",
307 | "Iteration: 19 \tCost: -0.5309909\r",
308 | "Iteration: 20 \tCost: -0.5313021\r",
309 | "Iteration: 21 \tCost: -0.53160506\r",
310 | "Iteration: 22 \tCost: -0.53190106\r",
311 | "Iteration: 23 \tCost: -0.53219026\r",
312 | "Iteration: 24 \tCost: -0.5324728\r",
313 | "Iteration: 25 \tCost: -0.53274995\r",
314 | "Iteration: 26 \tCost: -0.533021\r",
315 | "Iteration: 27 \tCost: -0.5332869\r",
316 | "Iteration: 28 \tCost: -0.53354824\r",
317 | "Iteration: 29 \tCost: -0.5338048\r",
318 | "Iteration: 30 \tCost: -0.53405714\r",
319 | "Iteration: 31 \tCost: -0.5343054\r",
320 | "Iteration: 32 \tCost: -0.53454924\r",
321 | "Iteration: 33 \tCost: -0.5347898\r",
322 | "Iteration: 34 \tCost: -0.53502667\r",
323 | "Iteration: 35 \tCost: -0.5352601\r",
324 | "Iteration: 36 \tCost: -0.53549004\r",
325 | "Iteration: 37 \tCost: -0.53571683\r",
326 | "Iteration: 38 \tCost: -0.5359408\r",
327 | "Iteration: 39 \tCost: -0.5361613\r",
328 | "Iteration: 40 \tCost: -0.5363793\r",
329 | "Iteration: 41 \tCost: -0.53659475\r",
330 | "Iteration: 42 \tCost: -0.5368071\r",
331 | "Iteration: 43 \tCost: -0.53701735\r",
332 | "Iteration: 44 \tCost: -0.5372245\r",
333 | "Iteration: 45 \tCost: -0.5374299\r",
334 | "Iteration: 46 \tCost: -0.5376322\r",
335 | "Iteration: 47 \tCost: -0.53783244\r",
336 | "Iteration: 48 \tCost: -0.5380306\r",
337 | "Iteration: 49 \tCost: -0.53822607\r",
338 | "Iteration: 50 \tCost: -0.53841996\r",
339 | "Iteration: 51 \tCost: -0.53861094\r",
340 | "Iteration: 52 \tCost: -0.53880095\r",
341 | "Iteration: 53 \tCost: -0.53898793\r",
342 | "Iteration: 54 \tCost: -0.5391733\r",
343 | "Iteration: 55 \tCost: -0.5393568\r",
344 | "Iteration: 56 \tCost: -0.5395382\r",
345 | "Iteration: 57 \tCost: -0.53971833\r",
346 | "Iteration: 58 \tCost: -0.53989595\r",
347 | "Iteration: 59 \tCost: -0.54007185\r",
348 | "Iteration: 60 \tCost: -0.5402465\r",
349 | "Iteration: 61 \tCost: -0.54041874\r",
350 | "Iteration: 62 \tCost: -0.5405901\r",
351 | "Iteration: 63 \tCost: -0.5407588\r",
352 | "Iteration: 64 \tCost: -0.5409268\r",
353 | "Iteration: 65 \tCost: -0.54109275\r",
354 | "Iteration: 66 \tCost: -0.5412577\r",
355 | "Iteration: 67 \tCost: -0.54141986\r",
356 | "Iteration: 68 \tCost: -0.5415817\r",
357 | "Iteration: 69 \tCost: -0.5417414\r",
358 | "Iteration: 70 \tCost: -0.5419003\r",
359 | "Iteration: 71 \tCost: -0.5420571\r",
360 | "Iteration: 72 \tCost: -0.5422127\r",
361 | "Iteration: 73 \tCost: -0.5423671\r",
362 | "Iteration: 74 \tCost: -0.5425201\r",
363 | "Iteration: 75 \tCost: -0.5426715\r",
364 | "Iteration: 76 \tCost: -0.54282165\r",
365 | "Iteration: 77 \tCost: -0.54297113\r",
366 | "Iteration: 78 \tCost: -0.5431187\r",
367 | "Iteration: 79 \tCost: -0.543265\r",
368 | "Iteration: 80 \tCost: -0.5434102\r",
369 | "Iteration: 81 \tCost: -0.54355454\r",
370 | "Iteration: 82 \tCost: -0.54369736\r",
371 | "Iteration: 83 \tCost: -0.5438389\r",
372 | "Iteration: 84 \tCost: -0.5439794\r",
373 | "Iteration: 85 \tCost: -0.54411876\r",
374 | "Iteration: 86 \tCost: -0.54425716\r",
375 | "Iteration: 87 \tCost: -0.5443945\r",
376 | "Iteration: 88 \tCost: -0.5445305\r",
377 | "Iteration: 89 \tCost: -0.5446654\r",
378 | "Iteration: 90 \tCost: -0.5447996\r",
379 | "Iteration: 91 \tCost: -0.5449326\r",
380 | "Iteration: 92 \tCost: -0.54506487\r",
381 | "Iteration: 93 \tCost: -0.54519576\r",
382 | "Iteration: 94 \tCost: -0.5453257\r",
383 | "Iteration: 95 \tCost: -0.5454546\r",
384 | "Iteration: 96 \tCost: -0.545583\r",
385 | "Iteration: 97 \tCost: -0.5457101\r",
386 | "Iteration: 98 \tCost: -0.54583657\r",
387 | "Iteration: 99 \tCost: -0.5459618\r",
388 | "Iteration: 100 \tCost: -0.54608655\r",
389 | "Iteration: 101 \tCost: -0.54621017\r",
390 | "Iteration: 102 \tCost: -0.54633296\r",
391 | "Iteration: 103 \tCost: -0.54645467\r",
392 | "Iteration: 104 \tCost: -0.54657626\r",
393 | "Iteration: 105 \tCost: -0.54669654\r",
394 | "Iteration: 106 \tCost: -0.54681575\r",
395 | "Iteration: 107 \tCost: -0.5469348\r",
396 | "Iteration: 108 \tCost: -0.5470527\r",
397 | "Iteration: 109 \tCost: -0.54716986\r",
398 | "Iteration: 110 \tCost: -0.54728657\r",
399 | "Iteration: 111 \tCost: -0.5474022\r",
400 | "Iteration: 112 \tCost: -0.5475172\r",
401 | "Iteration: 113 \tCost: -0.5476318\r",
402 | "Iteration: 114 \tCost: -0.5477449\r",
403 | "Iteration: 115 \tCost: -0.5478576\r",
404 | "Iteration: 116 \tCost: -0.54797024\r",
405 | "Iteration: 117 \tCost: -0.5480816\r",
406 | "Iteration: 118 \tCost: -0.548192\r",
407 | "Iteration: 119 \tCost: -0.54830235\r",
408 | "Iteration: 120 \tCost: -0.5484118\r",
409 | "Iteration: 121 \tCost: -0.54852045\r",
410 | "Iteration: 122 \tCost: -0.548629\r",
411 | "Iteration: 123 \tCost: -0.5487364\r",
412 | "Iteration: 124 \tCost: -0.5488433\r",
413 | "Iteration: 125 \tCost: -0.5489498\r",
414 | "Iteration: 126 \tCost: -0.54905534\r",
415 | "Iteration: 127 \tCost: -0.5491607\r",
416 | "Iteration: 128 \tCost: -0.5492652\r",
417 | "Iteration: 129 \tCost: -0.54936945\r",
418 | "Iteration: 130 \tCost: -0.54947305\r",
419 | "Iteration: 131 \tCost: -0.54957575\r",
420 | "Iteration: 132 \tCost: -0.5496777\r",
421 | "Iteration: 133 \tCost: -0.5497798\r",
422 | "Iteration: 134 \tCost: -0.5498811\r",
423 | "Iteration: 135 \tCost: -0.5499818\r",
424 | "Iteration: 136 \tCost: -0.5500822\r",
425 | "Iteration: 137 \tCost: -0.5501819\r",
426 | "Iteration: 138 \tCost: -0.550281\r",
427 | "Iteration: 139 \tCost: -0.55037963\r",
428 | "Iteration: 140 \tCost: -0.55047804\r",
429 | "Iteration: 141 \tCost: -0.5505761\r",
430 | "Iteration: 142 \tCost: -0.55067325\r",
431 | "Iteration: 143 \tCost: -0.5507698\r",
432 | "Iteration: 144 \tCost: -0.55086654\r",
433 | "Iteration: 145 \tCost: -0.5509624\r",
434 | "Iteration: 146 \tCost: -0.5510575\r",
435 | "Iteration: 147 \tCost: -0.55115235\r",
436 | "Iteration: 148 \tCost: -0.55124706\r",
437 | "Iteration: 149 \tCost: -0.5513411\r"
438 | ]
439 | }
440 | ],
441 | "source": [
442 | "param = init_param\n",
443 | "\n",
444 | "cost_vals = jnp.zeros(n_steps)\n",
445 | "cost_vals = cost_vals.at[0].set(param_to_expectation(init_param))\n",
446 | "\n",
447 | "for step in range(1, n_steps):\n",
448 | " cost_val, cost_grad = cost_and_grad(param)\n",
449 | " cost_vals = cost_vals.at[step].set(cost_val)\n",
450 | " param = param - stepsize * cost_grad\n",
451 | " print(\"Iteration:\", step, \"\\tCost:\", cost_val, end=\"\\r\")"
452 | ]
453 | },
454 | {
455 | "cell_type": "markdown",
456 | "id": "b9305afc",
457 | "metadata": {},
458 | "source": [
459 | "Let's visualise the gradient descent:"
460 | ]
461 | },
462 | {
463 | "cell_type": "code",
464 | "execution_count": 12,
465 | "id": "114ed37d",
466 | "metadata": {},
467 | "outputs": [
468 | {
469 | "data": {
470 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAEGCAYAAAB7DNKzAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAAbhklEQVR4nO3de5Scd33f8fdnbnuzLGllIwtLtri4hEuMgYVAuJTEIoGEYjcnxU5oIw5xTZq2SRrSVOCentPTpnEOaUIaEnJUp0UQDumJg7GOoYAtbCAXAzL1BTAgGwcQ1s2yLVnSauf27R/Pb3ZH65nZkbSzz1jzeZ2zZ57Lb5757iPtfvb5/Z6LIgIzM7NuCnkXYGZmw81BYWZmPTkozMysJweFmZn15KAwM7OeSnkXsNwuuOCC2Lx5c95lmJk9o9xzzz2PRcSFndadc0GxefNmdu/enXcZZmbPKJK+122du57MzKwnB4WZmfXkoDAzs54cFGZm1pODwszMenJQmJlZT7kEhaRpSbdL2pNe13Zp9xlJT0q6baVrNDOzTF7XUWwDdkXEjZK2pfn/0KHd+4FJ4N2DLuhEtc6f3fUwAMVCgWtftYn1548P+mPNzIZeXkFxFfDGNL0DuIsOQRERuyS9cfHyQZitNvjjOx+i9XiO8XKBd//j563ER5uZDbW8xijWR8S+NL0fWH82G5N0vaTdknYfOnTojLax7rwxHvndn+Wh33kLAHP15tmUZGZ2zhjYEYWkO4CLOqy6oX0mIkLSWT1mLyK2A9sBZmZmzmpbxYKQoNZwUJiZwQCDIiK2dFsn6YCkDRGxT9IG4OCg6jhdkigXC1QdFGZmQH5dTzuBrWl6K3BrTnV0VCkWqDf8LHEzM8gvKG4E3iRpD7AlzSNpRtJNrUaSvgT8FXClpL2SfnoliisX5a4nM7Mkl7OeIuIwcGWH5buB69rmX7+SdbWUigUHhZlZ4iuzO6gUC1Tr7noyMwMHRUfuejIzW+Cg6KBcLFBvOijMzMBB0VHZXU9mZvMcFB2468nMbIGDooOyz3oyM5vnoOjAQWFmtsBB0UG5VKDmK7PNzAAHRUcVj1GYmc1zUHRQKrjrycysxUHRgbuezMwWOCg6KBdF1Q8uMjMDHBQdVXzWk5nZPAdFB9ktPNz1ZGYGDoqOysUCNXc9mZkBDoqOykX5UahmZomDogNfmW1mtsBB0UG5WKAZ0PA4hZmZg6KTckkAPqowM8NB0VGlmO0WB4WZmYOio1KhdUThriczMwdFB+WSjyjMzFocFB2UU9eTb+NhZuag6Kg1RuGrs83MHBQdlT2YbWY2z0HRQamYDWa768nMzEHRkU+PNTNb4KDoYKHryWMUZmYOig7Kqeup7iMKMzMHRSet6yh8B1kzMwdFR+WCu57MzFocFB34poBmZgscFB34OgozswUOig4qPuvJzGxeLkEhaVrS7ZL2pNe1HdpcIenvJX1D0v2Srlmp+nxEYWa2IK8jim3Aroi4DNiV5hc7AfxSRLwYeDPwAUlrVqK41umxDgozs/yC4ipgR5reAVy9uEFEfCci9qTpR4GDwIUrUVzJd481M5uXV1Csj4h9aXo/sL5XY0mvAirAw4MuDDxGYWbWrjSoDUu6A7iow6ob2mciIiR1/Y0saQPwUWBrRHT8E1/S9cD1AJdccskZ19ziriczswUDC4qI2NJtnaQDkjZExL4UBAe7tDsf+BRwQ0Tc3eOztgPbAWZmZs76MKBYEJJv4WFmBvl1Pe0EtqbprcCtixtIqgC3AB+JiJtXsDYkUS4WqLrrycwst6C4EXiTpD3AljSPpBlJN6U2bwfeALxT0r3p64qVKrBckLuezMwYYNdTLxFxGLiyw/LdwHVp+i+Av1jh0uaVSwUHhZkZvjK7q3LRQWFmBg6KrirFgk+PNTPDQdFVuegxCjMzcFB0VXLXk5kZ4KDoqlwsUK2768nMzEHRRcVdT2ZmgIOiq3Kx
QL3poDAzc1B0US4WqLnryczMQdFNqSiq7noyM3NQdFPxWU9mZoCDoitfmW1mlnFQdFEuFaj7ymwzMwdFN2WPUZiZAQ6KrjxGYWaWcVB0USrKNwU0M8NB0VV2HYWPKMzMHBRdVIoFar4y28zMQdFN2c+jMDMDHBRdlYsFGs2g0XRYmNloc1B0USoKwGc+mdnIc1B0USlmu8ZBYWajzkHRRXn+iMJdT2Y22hwUXZRL2a6p+4jCzEacg6KLcup68m08zGzUOSi6cNeTmVnGQdFF2YPZZmaAg6Kr+a4n38bDzEacg6KL1umxdV9wZ2YjzkHRhbuezMwyDoou5q/MdteTmY04B0UXPj3WzCzjoOhi4RYeHqMws9HmoOiiXMq6nnxltpmNOgdFF+56MjPLOCi6cNeTmVkml6CQNC3pdkl70uvaDm0ulfQ1SfdK+oakX1nJGv08CjOzTF5HFNuAXRFxGbArzS+2D3hNRFwB/BiwTdKzV6pAX0dhZpbJKyiuAnak6R3A1YsbREQ1IubS7BgrXGvZXU9mZkB+QbE+Ival6f3A+k6NJG2SdD/wA+D3IuLRLu2ul7Rb0u5Dhw4tS4F+wp2ZWaY0qA1LugO4qMOqG9pnIiIkdfyzPSJ+AFyeupw+KenmiDjQod12YDvAzMzMshwClH1ltpkZMMCgiIgt3dZJOiBpQ0Tsk7QBOLjEth6V9HXg9cDNy1xqR8WCkHx6rJlZX11Pkj7az7LTsBPYmqa3Ard22P5GSRNpei3wOuDbZ/GZp0US46UiJ2uNlfpIM7Oh1O8YxYvbZyQVgVecxefeCLxJ0h5gS5pH0oykm1KbFwJflnQf8AXg9yPigbP4zNM2USky66AwsxHXs+tJ0nuB9wETko62FgNV0pjAmYiIw8CVHZbvBq5L07cDl5/pZyyH8VKBkzV3PZnZaOt5RBERvxsRq4D3R8T56WtVRKyLiPeuUI25GS+768nMrN+up9skTQFI+ueS/kDSpQOsayg4KMzM+g+KDwEnJL0UeA/wMPCRgVU1JMbL7noyM+s3KOoREWRXVH8wIv4EWDW4soaDB7PNzPoPiqfSwPa/AD4lqQCUB1fWcPDpsWZm/QfFNcAc8K6I2A9sBN4/sKqGxHjZRxRmZn0FRQqHjwGrJb0VOBkRIzBGUWTOYxRmNuL6vTL77cBXgH8GvJ3sQrifH2RhwyAbzPYRhZmNtn7v9XQD8MqIOAgg6ULgDlbovkt5mXDXk5lZ32MUhVZIJIdP473PWK3rKLITvszMRlO/RxSfkfRZ4ONp/hrg04MpaXiMlws0I7uD7FipmHc5Zma5WOpeT88ne8jQv5f0c2R3cAX4e7LB7XPaeDkLh5M1B4WZja6luo8+ABwFiIhPRMRvRsRvArekdee0VlDMeZzCzEbYUkGxvtOtvdOyzQOpaIhMpKDwgLaZjbKlgmJNj3UTy1jHUGrvejIzG1VLBcVuSf9y8UJJ1wH3DKak4TFeznaPjyjMbJQtddbTbwC3SHoHC8EwA1SAfzrAuobCxPwRhYPCzEZXz6CIiAPAj0v6CeAlafGnIuLzA69sCIw5KMzM+ruOIiLuBO4ccC1Dx0cUZmYjcHX12WiNUXgw28xGmYOih4mKT481M3NQ9DBecteTmZmDoodxX3BnZuag6GWs5DEKMzMHRQ+FghgrFXyvJzMbaQ6KJUxU/PAiMxttDooljJeKHsw2s5HmoFjCeLnArMcozGyEOSiW0HocqpnZqHJQLMFBYWajzkGxhAkHhZmNOAfFEsbLBV9HYWYjzUGxhPGyT481s9HmoFiCu57MbNQ5KJYwVi6668nMRlouQSFpWtLtkvak17U92p4vaa+kD65kjS0+ojCzUZfXEcU2YFdEXAbsSvPd/BfgiytSVQfZYLaDwsxGV15BcRWwI03vAK7u1EjSK4D1wOdWpqynGy8XqTeDWsPdT2Y2mvIKivURsS9N7ycLg1NIKgD/HfitpTYm6XpJuyXtPnTo0LIW
6udmm9moKw1qw5LuAC7qsOqG9pmICEnRod2vAp+OiL2Sen5WRGwHtgPMzMx02tYZa39u9qrx5dyymdkzw8CCIiK2dFsn6YCkDRGxT9IG4GCHZq8BXi/pV4HzgIqkYxHRazxj2Y37iMLMRtzAgmIJO4GtwI3p9dbFDSLiHa1pSe8EZlY6JMBBYWaW1xjFjcCbJO0BtqR5JM1Iuimnmjryc7PNbNTlckQREYeBKzss3w1c12H5h4EPD7ywDhYGs33Wk5mNJl+ZvYTWYLaPKMxsVDkoluAxCjMbdQ6KJTgozGzUOSiWsHAdhYPCzEaTg2IJHsw2s1HnoFiCT481s1HnoFiCxyjMbNQ5KJZQLIhKseAjCjMbWQ6KPpw/UeLobD3vMszMcuGg6MPayQpPHK/mXYaZWS4cFH1YO1Xh8RMOCjMbTQ6KPqybqvC4jyjMbEQ5KPqwdspdT2Y2uhwUfZierPDEiSrN5rI+PM/M7BnBQdGH6akKzYAjs7W8SzEzW3EOij5MT1UAPKBtZiPJQdGHtSkoPE5hZqPIQdGHdSkoDjsozGwEOSj64CMKMxtlDoo+TE96jMLMRpeDog8TlSIT5SKPH3NQmNnocVD0adq38TCzEeWg6NPaqbLHKMxsJDko+jQ9Neb7PZnZSHJQ9Gl6suyuJzMbSQ6KPk1PjfHEcd/Cw8xGj4OiT9NTZY7N1Zmr+5GoZjZaHBR9WrjozkcVZjZaHBR9at3GwwPaZjZqHBR9WjvpoDCz0eSg6JNvNW5mo8pB0adp3xjQzEaUg6JPqyfKSL7VuJmNHgdFn0rFAmsnKxx66mTepZiZrahcgkLStKTbJe1Jr2u7tGtIujd97VzpOhd7wfpVfOPRo3mXYWa2ovI6otgG7IqIy4Bdab6T2Yi4In29beXK6+zyTat5cN9RX3RnZiMlr6C4CtiRpncAV+dUx2l56cY11BrBt/c/lXcpZmYrJq+gWB8R+9L0fmB9l3bjknZLulvS1d02Jun61G73oUOHlrvWeT968WoA7tt7ZGCfYWY2bEqD2rCkO4CLOqy6oX0mIkJSdNnMpRHxQ0nPBT4v6YGIeHhxo4jYDmwHmJmZ6bats7Zx7QTTUxUe2PskcOmgPsbMbKgMLCgiYku3dZIOSNoQEfskbQAOdtnGD9PrdyXdBbwMeFpQrBRJXL5xNff7iMLMRkheXU87ga1peitw6+IGktZKGkvTFwCvBb65YhV2cfnFq/nOgac4Ua3nXYqZ2YrIKyhuBN4kaQ+wJc0jaUbSTanNC4Hdku4D7gRujIj8g2LjGpqBT5M1s5ExsK6nXiLiMHBlh+W7gevS9N8BP7rCpS3p8o3ZgPb9e4/wys3TOVdjZjZ4vjL7ND3r/HEuXjPB576xn4iBjZubmQ0NB8UZuP4Nz+XLjzzOHQ92HIM3MzunOCjOwC/+2CU878Ip/tunH6Rab+ZdjpnZQDkozkC5WOA/vvVFPPLYcT5wx3doNt0FZWbnLgfFGfqJFzyLt7302fzpXQ9z7f+8mwf2HvGYhZmdk3I56+lc8UfXXsFrn7+O/3rbg/yTD/4Nz1o1xisuXcsl6ya5dHqKS9dN8uw1E6w7r8KqsRKS8i7ZzOy0OSjOgiSueeUlbHnhenZ96yBf+M4hHnz0KHc8eIBa49Sji0qxwLrzKqw7r8L01BgXTFXS/Bjr0vTqiTKrJ1qvZSolH/CZWf4cFMtg3XljvH1mE2+f2QRAoxnsP3qS7x8+wb4jszx+vMpjx6ocPjbH4ePZ68MHj3H4+Bwna90HwycrRdZMlFk9WWH1RIk1KUTWTJY5P72uniizarzMqvESq8ZKrBovc954ialK0UcwZrYsHBQDUCyIi9dMcPGaiSXbnqjWOXysymPH5jgyW1v4OlHjyTT95IkaR2drfPexYxyZrfHEidqSZ1sVBFNjJc4fL3PeWIlV4yXOG09BMlbi/PFS2/LOQTM5VmKyXKRQcOCYjTIHRc4mKyUmp0tsmp48rfed
rDXmQ+TYXI2jJ+scO1nnqZN1js3VeCpNt88/frzK9w6fSMtrzPV5au9EucjUWJHJSompsYUQmaq0lqXXSjFbP7Z4eYnJsSJTbcuKDh+zZwwHxTPUeLnIeLnI+vPHz3gb1XqTY3NZwBw9WePYXBYsx+fqHK/WOTHXyF6rDY7PZa/H5uqcqNY5Oltj/5FZjs81OFGtc7zaOK1rSsbLBaYqJcbLRSYrRSYq2fczkb4mK0XGKwvzE5VTX9vfd8p8ajNWKrjrzWyZOChGWKVUYLpUYXqqsizbqzWabeFS53hrelHgLIRLndlqk9landlqg9lagydna+w7MstsrcFstcnJWtb2dC9VkWC8lAKnQ9C0prPALTBWOvV1vJyFTbf12WuRsXKB8VKRclEOJjtnOShs2ZSLBVZPFlg9WV7W7UYE1UaTk9Umsyk4ZmsNTqYwOXW+wYlag5MpeFqB0yuM5moN5upNqo0zv8q+IDqGzFipwFg64pkPnlJhPmA6BdJYucBYqUClVKBSzOYrxWx+fnkptS1l6zyOZIPkoLChJyn9UiyymuUNoXaNZjBXbzBXa3Ky3uBkrclcej2ZwuSU17bpbm1P1rMgOjJbO6X9wvsap3201Em5qLYwKc6HSaVYWBQ0xfmwGXtam2KHMCq0tS923WZrebkoigUfXZ1rHBRmSbGg7OSC5emJ60tEUG/G08Kmmo5wqvVsWbXemk6vjSyAFto059vMv6dt3Vy9yVMn6zxWr1JN6+ZqzVPaNJbpVjRSdnQ5VixQLmXhUS4uBEu52LZsPmAW2o7Ntymk96ntfVm7xct6bjttt/U5lVKBksPstDgozHIkaf4X26ozPy9hWTSacUowzbUHUI9gykIpqDWa1FoB12hSq6dljYVtZPMxP39srp7eF/PbrnVoNwitI6BTAqe0sKxUyIKmlP59ykVRKqTgKWTLTlnXCrKCKKcwam2nlIKqvX25mK2bn55v075e6fPyPVpzUJgZkB1RTaQzyYZJ66jr1MAJailU5uoLwVJrNJ8eOB1CqJpCaCGUmlRTu9Z2q20hd7zaoFZvUm8ufE6t0aQ+Px3z6watFTjtQVQuiXKhwIsvXs0f/8LLlv0zHRRmNtTaj7pWslvwTLSHWq0R1BuLgiUdtdWb2brqorDJ2rRNNxZvq0mtGSm0Ir0/a1dtNNm0dumLfM+Eg8LMbJm0h9q55Nz6bszMbNk5KMzMrCcHhZmZ9eSgMDOznhwUZmbWk4PCzMx6clCYmVlPDgozM+tJEYO/5HwlSToEfO8sNnEB8NgylTMow17jsNcHrnG5uMblMQw1XhoRF3Zacc4FxdmStDsiZvKuo5dhr3HY6wPXuFxc4/IY9hrd9WRmZj05KMzMrCcHxdNtz7uAPgx7jcNeH7jG5eIal8dQ1+gxCjMz68lHFGZm1pODwszMenJQJJLeLOnbkh6StC3vegAkbZJ0p6RvSvqGpF9Py6cl3S5pT3pdOwS1FiX9P0m3pfnnSPpy2p//R1KuzyaTtEbSzZK+JelBSa8Zpv0o6d+lf+OvS/q4pPFh2IeS/pekg5K+3ras435T5n+keu+X9PKc6nt/+ne+X9Itkta0rXtvqu/bkn560PV1q7Ft3XskhaQL0vyK78N+OCjIfskBfwK8BXgR8AuSXpRvVQDUgfdExIuAVwP/OtW1DdgVEZcBu9J83n4deLBt/veAP4yI5wNPAL+cS1UL/gj4TET8CPBSslqHYj9Kuhj4NWAmIl4CFIFrGY59+GHgzYuWddtvbwEuS1/XAx/Kqb7bgZdExOXAd4D3AqSfnWuBF6f3/Gn62c+jRiRtAn4K+H7b4jz24ZIcFJlXAQ9FxHcjogr8JXBVzjUREfsi4mtp+imyX24Xk9W2IzXbAVydS4GJpI3AzwI3pXkBPwncnJrkWqOk1cAbgD8HiIhqRDzJcO3HEjAhqQRMAvsYgn0YEV8EHl+0uNt+uwr4SGTuBtZI2rDS9UXE5yKinmbvBja2
1feXETEXEY8AD5H97A9Ul30I8IfAbwPtZxSt+D7sh4MiczHwg7b5vWnZ0JC0GXgZ8GVgfUTsS6v2A+vzqiv5ANl/+GaaXwc82fbDmvf+fA5wCPjfqXvsJklTDMl+jIgfAr9P9pflPuAIcA/DtQ/bddtvw/hz9C7g/6bpoalP0lXADyPivkWrhqbGdg6KZwBJ5wF/DfxGRBxtXxfZ+c25neMs6a3AwYi4J68a+lACXg58KCJeBhxnUTdTnvsx9fFfRRZozwam6NBVMYzy/v/Xi6QbyLpvP5Z3Le0kTQLvA/5T3rX0y0GR+SGwqW1+Y1qWO0llspD4WER8Ii0+0DocTa8H86oPeC3wNkn/QNZl95Nk4wFrUjcK5L8/9wJ7I+LLaf5msuAYlv24BXgkIg5FRA34BNl+HaZ92K7bfhuanyNJ7wTeCrwjFi4WG5b6nkf2R8F96edmI/A1SRcxPDWewkGR+SpwWTrLpEI24LUz55paff1/DjwYEX/QtmonsDVNbwVuXenaWiLivRGxMSI2k+23z0fEO4A7gZ9PzfKucT/wA0kvSIuuBL7J8OzH7wOvljSZ/s1b9Q3NPlyk237bCfxSOnPn1cCRti6qFSPpzWRdoW+LiBNtq3YC10oak/QcsgHjr6x0fRHxQEQ8KyI2p5+bvcDL0//TodiHTxMR/sr+4PgZsjMkHgZuyLueVNPryA7r7wfuTV8/QzYGsAvYA9wBTOdda6r3jcBtafq5ZD+EDwF/BYzlXNsVwO60Lz8JrB2m/Qj8Z+BbwNeBjwJjw7APgY+TjZvUyH6h/XK3/QaI7OzBh4EHyM7iyqO+h8j6+Vs/M3/W1v6GVN+3gbfktQ8Xrf8H4IK89mE/X76Fh5mZ9eSuJzMz68lBYWZmPTkozMysJweFmZn15KAwM7OeHBRmPUg6ll43S/rFZd72+xbN/91ybt9suTgozPqzGTitoGi7qrqbU4IiIn78NGsyWxEOCrP+3Ai8XtK96dkRxfTcg6+m5wa8G0DSGyV9SdJOsqurkfRJSfcoe97E9WnZjWR3i71X0sfSstbRi9K2vy7pAUnXtG37Li08V+Nj6Upus4Fa6i8eM8tsA34rIt4KkH7hH4mIV0oaA/5W0udS25eTPQ/hkTT/roh4XNIE8FVJfx0R2yT9m4i4osNn/RzZleQvBS5I7/liWvcysucpPAr8Ldk9of5mub9Zs3Y+ojA7Mz9Fdk+ee8lu/b6O7N5BAF9pCwmAX5N0H9mzETa1tevmdcDHI6IREQeALwCvbNv23ohokt2eYvMyfC9mPfmIwuzMCPi3EfHZUxZKbyS7jXn7/BbgNRFxQtJdwPhZfO5c23QD/wzbCvARhVl/ngJWtc1/FvhX6TbwSPpH6WFIi60Gnkgh8SNkj7RtqbXev8iXgGvSOMiFZE/nW/G7nJq1+K8Rs/7cDzRSF9KHyZ65sZnsOQIie4Le1R3e9xngVyQ9SHbH0rvb1m0H7pf0tchuzd5yC/Aa4D6yuwf/dkTsT0FjtuJ891gzM+vJXU9mZtaTg8LMzHpyUJiZWU8OCjMz68lBYWZmPTkozMysJweFmZn19P8BsfyWFS7UICQAAAAASUVORK5CYII=\n",
471 | "text/plain": [
472 | ""
473 | ]
474 | },
475 | "metadata": {
476 | "needs_background": "light"
477 | },
478 | "output_type": "display_data"
479 | }
480 | ],
481 | "source": [
482 | "plt.plot(cost_vals)\n",
483 | "plt.xlabel(\"Iteration\")\n",
484 | "plt.ylabel(\"Cost\");"
485 | ]
486 | }
487 | ],
488 | "metadata": {
489 | "kernelspec": {
490 | "display_name": "Python 3 (ipykernel)",
491 | "language": "python",
492 | "name": "python3"
493 | },
494 | "language_info": {
495 | "codemirror_mode": {
496 | "name": "ipython",
497 | "version": 3
498 | },
499 | "file_extension": ".py",
500 | "mimetype": "text/x-python",
501 | "name": "python",
502 | "nbconvert_exporter": "python",
503 | "pygments_lexer": "ipython3",
504 | "version": "3.9.10"
505 | }
506 | },
507 | "nbformat": 4,
508 | "nbformat_minor": 5
509 | }
510 |
--------------------------------------------------------------------------------
/qujax/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Simulating quantum circuits with JAX
3 | """
4 |
5 | from qujax.version import __version__
6 |
7 | from qujax import gates
8 |
9 | from qujax.statetensor import all_zeros_statetensor
10 | from qujax.statetensor import apply_gate
11 | from qujax.statetensor import get_params_to_statetensor_func
12 | from qujax.statetensor import get_params_to_unitarytensor_func
13 |
14 | from qujax.statetensor_observable import statetensor_to_single_expectation
15 | from qujax.statetensor_observable import get_statetensor_to_expectation_func
16 | from qujax.statetensor_observable import get_statetensor_to_sampled_expectation_func
17 |
18 | from qujax.densitytensor import all_zeros_densitytensor
19 | from qujax.densitytensor import _kraus_single
20 | from qujax.densitytensor import kraus
21 | from qujax.densitytensor import get_params_to_densitytensor_func
22 | from qujax.densitytensor import partial_trace
23 |
24 | from qujax.densitytensor_observable import densitytensor_to_single_expectation
25 | from qujax.densitytensor_observable import get_densitytensor_to_expectation_func
26 | from qujax.densitytensor_observable import get_densitytensor_to_sampled_expectation_func
27 | from qujax.densitytensor_observable import densitytensor_to_measurement_probabilities
28 | from qujax.densitytensor_observable import densitytensor_to_measured_densitytensor
29 |
30 | from qujax.utils import check_unitary
31 | from qujax.utils import check_hermitian
32 | from qujax.utils import check_circuit
33 | from qujax.utils import print_circuit
34 | from qujax.utils import integers_to_bitstrings
35 | from qujax.utils import bitstrings_to_integers
36 | from qujax.utils import repeat_circuit
37 | from qujax.utils import sample_integers
38 | from qujax.utils import sample_bitstrings
39 | from qujax.utils import statetensor_to_densitytensor
40 |
41 | import qujax.typing
42 |
# The `from qujax.<submodule> import ...` statements above also bind each
# submodule as a name in this package's namespace. Remove those names so that
# only the re-exported objects stay visible at package level.
# `gates` and `typing` are not deleted here and therefore remain accessible.
# pylint: disable=undefined-variable
del version
del statetensor
del statetensor_observable
del densitytensor
del densitytensor_observable
del utils
50 |
--------------------------------------------------------------------------------
/qujax/densitytensor.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Iterable, Sequence, Tuple, Union, Optional
4 |
5 | import jax
6 | from jax import numpy as jnp
7 | from jax.typing import ArrayLike
8 | from jax.lax import scan
9 | from jax._src.dtypes import canonicalize_dtype
10 | from jax._src.typing import DTypeLike
11 | from qujax.statetensor import (
12 | _arrayify_inds,
13 | _gate_func_to_unitary,
14 | _to_gate_func,
15 | apply_gate,
16 | )
17 | from qujax.utils import check_circuit
18 | from qujax.typing import (
19 | MixedCircuitFunction,
20 | KrausOp,
21 | GateFunction,
22 | GateParameterIndices,
23 | )
24 |
25 |
def _kraus_single(
    densitytensor: jax.Array, array: jax.Array, qubit_inds: Sequence[int]
) -> jax.Array:
    r"""
    Applies one Kraus operator to a densitytensor.

    .. math::

        \rho_\text{out} = B \rho_\text{in} B^{\dagger}

    Args:
        densitytensor: Input density matrix of shape=(2, 2, ...) and ndim=2*n_qubits
        array: Kraus operator B in tensor form.
        qubit_inds: Qubit indices that the Kraus operator acts on.

    Returns:
        Updated density matrix.
    """
    # B is contracted with the axes listed in qubit_inds, conj(B) with the
    # matching axes offset by n_qubits (the second half of the tensor's axes).
    axis_offset = densitytensor.ndim // 2
    conjugate_axes = [axis_offset + qubit for qubit in qubit_inds]
    left_applied = apply_gate(densitytensor, array, qubit_inds)
    return apply_gate(left_applied, array.conj(), conjugate_axes)
50 |
51 |
def kraus(
    densitytensor: jax.Array, arrays: Iterable[jax.Array], qubit_inds: Sequence[int]
) -> jax.Array:
    r"""
    Applies a Kraus map (a sum over Kraus operators) to a densitytensor.

    .. math::
        \rho_\text{out} = \sum_i B_i \rho_\text{in} B_i^{\dagger}

    Args:
        densitytensor: Input density matrix of shape=(2, 2, ...) and ndim=2*n_qubits
        arrays: Sequence of arrays containing the Kraus operators (in tensor form).
        qubit_inds: Sequence of qubit indices on which to apply the Kraus operation.

    Returns:
        Updated density matrix.
    """
    operators = jnp.array(arrays)
    if operators.ndim % 2 == 0:
        # An even number of axes is treated as a single operator: add a leading
        # axis so that the first dimension always indexes the Kraus operators.
        operators = operators[jnp.newaxis]
    n_operators = operators.shape[0]
    operator_shape = (2,) * 2 * len(qubit_inds)
    operators = operators.reshape((n_operators,) + operator_shape)

    def _accumulate(running_sum, operator):
        # Every term is computed from the *input* densitytensor and added on.
        term = _kraus_single(densitytensor, operator, qubit_inds)
        return running_sum + term, None

    summed, _ = scan(
        _accumulate,
        init=jnp.zeros_like(densitytensor) * 0.0j,
        xs=operators,
    )
    return summed
84 |
85 |
def _to_kraus_operator_seq_funcs(
    kraus_op: KrausOp,
    param_inds: Optional[Union[GateParameterIndices, Sequence[GateParameterIndices]]],
) -> Tuple[Sequence[GateFunction], Sequence[jax.Array]]:
    """
    Ensures Kraus operators are a sequence of functions that map (possibly empty) parameters to
    tensors and that each element of param_inds_seq is a sequence of arrays that correspond to the
    parameter indices of each Kraus operator.

    Args:
        kraus_op: Either a normal Gate (string, array or callable) or a list/tuple of Gates
            representing Kraus operators.
        param_inds: If kraus_op is a normal Gate then a sequence of parameter indices,
            if kraus_op is a sequence of Kraus operators then a sequence of sequences of
            parameter indices

    Returns:
        Tuple containing sequence of functions mapping to Kraus operators
        and sequence of arrays with parameter indices

    Raises:
        ValueError: If kraus_op is not a list/tuple, string, array-like or callable.

    """
    if isinstance(kraus_op, (list, tuple)):
        kraus_op_funcs = [_to_gate_func(ko) for ko in kraus_op]
        if param_inds is None:
            param_inds = [None for _ in kraus_op]
    elif (
        isinstance(kraus_op, (str, jax.Array))
        # Accept any array-like exposing `__array__` (e.g. NumPy arrays), matching
        # what `_to_gate_func` itself is able to convert.
        or hasattr(kraus_op, "__array__")
        or callable(kraus_op)
    ):
        # A single operator is wrapped so that both branches yield sequences.
        kraus_op_funcs = [_to_gate_func(kraus_op)]
        param_inds = [param_inds]
    else:
        raise ValueError(f"Invalid Kraus operator specification: {kraus_op}")
    return kraus_op_funcs, _arrayify_inds(param_inds)
116 |
117 |
def partial_trace(
    densitytensor: jax.Array, indices_to_trace: Sequence[int]
) -> jax.Array:
    """
    Discards (traces out) the given qubits and returns the densitytensor
    of the mixed state on the qubits that are left.

    Args:
        densitytensor: Input densitytensor.
        indices_to_trace: Indices of qubits to trace out/discard.

    Returns:
        Resulting densitytensor on remaining qubits.

    """
    n_axes = densitytensor.ndim
    axis_offset = n_axes // 2
    axis_labels = [*range(n_axes)]
    for qubit in indices_to_trace:
        # Giving both axes of a qubit the same label makes einsum sum over them.
        axis_labels[axis_offset + qubit] = axis_labels[qubit]
    return jnp.einsum(densitytensor, axis_labels)
139 |
140 |
def all_zeros_densitytensor(n_qubits: int, dtype: DTypeLike = complex) -> jax.Array:
    """
    Builds the densitytensor of the all-zeros state |00...0> on `n_qubits` qubits.

    Args:
        n_qubits: Number of qubits that the state is defined on.
        dtype: Data type of the densitytensor returned.

    Returns:
        Densitytensor representing the state having all qubits set to zero.
    """
    n_axes = 2 * n_qubits
    empty = jnp.zeros((2,) * n_axes, canonicalize_dtype(dtype))
    # The only non-zero entry sits at index (0, 0, ..., 0).
    return empty.at[(0,) * n_axes].set(1.0)
155 |
156 |
def get_params_to_densitytensor_func(
    kraus_ops_seq: Sequence[KrausOp],
    qubit_inds_seq: Sequence[Sequence[int]],
    param_inds_seq: Sequence[
        Union[GateParameterIndices, Sequence[GateParameterIndices]]
    ],
    n_qubits: Optional[int] = None,
) -> MixedCircuitFunction:
    """
    Builds a function mapping circuit parameters to a density tensor
    (a density matrix in tensor form), where
    densitytensor = densitymatrix.reshape((2,) * 2 * n_qubits)
    densitymatrix = densitytensor.reshape(2 ** n_qubits, 2 ** n_qubits)

    Args:
        kraus_ops_seq: Sequence of gates/Kraus operations.
            Each element is either a string matching a unitary array or function in qujax.gates,
            a custom unitary array, a custom function taking parameters and returning a unitary
            array, or a list/tuple of such entries giving several Kraus operators.
            Arrays will be reshaped into tensor form (2, 2,...)
        qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are
            acting on.
            i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the
            zeroth qubit, the second gate is a two qubit gate acting on the zeroth and
            first qubit etc.
        param_inds_seq: Sequence of sequences representing parameter indices that gates are using,
            i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter
            (the float at position zero in the parameter vector/array), the second gate is not
            parameterised and the third gate uses the parameters at position five and two.
        n_qubits: Number of qubits, if fixed.

    Returns:
        Function which maps parameters (and optional densitytensor_in) to a densitytensor.
        If no parameters are found then the function only takes optional densitytensor_in.

    """

    check_circuit(kraus_ops_seq, qubit_inds_seq, param_inds_seq, n_qubits, False)

    if n_qubits is None:
        # Infer the register size from the largest qubit index in the circuit.
        highest_qubit = max([max(inds) for inds in qubit_inds_seq])
        n_qubits = highest_qubit + 1

    # Normalise every circuit entry into (list of gate functions, list of index arrays).
    kraus_funcs_seq = []
    kraus_param_inds_seq = []
    for kraus_op, op_param_inds in zip(kraus_ops_seq, param_inds_seq):
        op_funcs, op_inds_arrays = _to_kraus_operator_seq_funcs(kraus_op, op_param_inds)
        kraus_funcs_seq.append(op_funcs)
        kraus_param_inds_seq.append(op_inds_arrays)

    def params_to_densitytensor_func(
        params: ArrayLike, densitytensor_in: Optional[jax.Array] = None
    ) -> jax.Array:
        """
        Runs the parameterised circuit on densitytensor_in
        (default is |0>^N <0|^N).

        Args:
            params: Parameters of the circuit.
            densitytensor_in: Optional. Input densitytensor.
                Defaults to |0>^N <0|^N (tensor of size 2^(2*N) with all zeroes except one in
                the [0]*(2*N) index).

        Returns:
            Updated densitytensor.

        """
        densitytensor = (
            all_zeros_densitytensor(n_qubits)
            if densitytensor_in is None
            else densitytensor_in
        )
        params = jnp.atleast_1d(params)
        # Guarantee `params` has the right type for type-checking purposes
        if not isinstance(params, jax.Array):
            raise ValueError("This should not happen. Please open an issue on GitHub.")
        for op_funcs, qubit_inds, op_inds_arrays in zip(
            kraus_funcs_seq, qubit_inds_seq, kraus_param_inds_seq
        ):
            kraus_operators = []
            for gate_func, gate_param_inds in zip(op_funcs, op_inds_arrays):
                kraus_operators.append(
                    _gate_func_to_unitary(gate_func, qubit_inds, gate_param_inds, params)
                )
            densitytensor = kraus(densitytensor, kraus_operators, qubit_inds)
        return densitytensor

    uses_no_params = all(
        inds.size == 0
        for op_inds_arrays in kraus_param_inds_seq
        for inds in op_inds_arrays
    )
    if not uses_no_params:
        return params_to_densitytensor_func

    def no_params_to_densitytensor_func(
        densitytensor_in: Optional[jax.Array] = None,
    ) -> jax.Array:
        """
        Runs the parameter-free circuit on densitytensor_in
        (default is |0>^N <0|^N).

        Args:
            densitytensor_in: Optional. Input densitytensor.
                Defaults to |0>^N <0|^N (tensor of size 2^(2*N) with all zeroes except one in
                the [0]*(2*N) index).

        Returns:
            Updated densitytensor.

        """
        return params_to_densitytensor_func(jnp.array([]), densitytensor_in)

    return no_params_to_densitytensor_func
268 |
--------------------------------------------------------------------------------
/qujax/densitytensor_observable.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Callable, Sequence, Union
4 |
5 | import jax
6 | from jax.typing import ArrayLike
7 | from jax import numpy as jnp
8 | from jax import random
9 |
10 | from qujax.densitytensor import _kraus_single, partial_trace
11 | from qujax.statetensor_observable import _get_tensor_to_expectation_func, sample_probs
12 | from qujax.utils import bitstrings_to_integers, check_hermitian
13 |
14 |
def densitytensor_to_single_expectation(
    densitytensor: jax.Array, hermitian: jax.Array, qubit_inds: Sequence[int]
) -> jax.Array:
    """
    Computes the expectation value of an observable given as a Hermitian matrix (in tensor form).

    Args:
        densitytensor: Input densitytensor.
        hermitian: Hermitian matrix representing observable
            must be in tensor form with shape (2,2,...).
        qubit_inds: Sequence of qubit indices for Hermitian matrix to be applied to.
            Must have 2 * len(qubit_inds) == hermitian.ndim
    Returns:
        Expected value (float).
    """
    n_qubits = densitytensor.ndim // 2
    n_targets = len(qubit_inds)
    # The Hermitian's axes get labels that do not clash with the qubit labels.
    herm_labels = list(range(n_qubits, n_qubits + hermitian.ndim))
    # Untouched qubits share a label across both halves, so they are traced out.
    dt_labels = list(range(n_qubits)) * 2
    for position, qubit in enumerate(qubit_inds):
        dt_labels[qubit] = herm_labels[position + n_targets]
        dt_labels[qubit + n_qubits] = herm_labels[position]
    contracted = jnp.einsum(densitytensor, dt_labels, hermitian, herm_labels)
    return contracted.real
37 |
38 |
def get_densitytensor_to_expectation_func(
    hermitian_seq_seq: Sequence[Sequence[Union[str, jax.Array]]],
    qubits_seq_seq: Sequence[Sequence[int]],
    coefficients: Union[Sequence[float], jax.Array],
) -> Callable[[jax.Array], float]:
    """
    Builds a function that converts a densitytensor into an expected value, from strings
    (or arrays) representing Hermitian matrices, their qubit indices and a list of
    coefficients.

    Args:
        hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors.
            Each Hermitian matrix is either represented by a tensor (jax.Array)
            or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices.
            E.g. [['Z', 'Z'], ['X']]
        qubits_seq_seq: Sequence of sequences of integer qubit indices.
            E.g. [[0,1], [2]]
        coefficients: Sequence of float coefficients to scale the expected values.

    Returns:
        Function that takes densitytensor and returns expected value (float).
    """
    # Delegates to the shared builder, supplying the densitytensor-specific
    # single-term expectation routine.
    densitytensor_to_expectation_func = _get_tensor_to_expectation_func(
        hermitian_seq_seq,
        qubits_seq_seq,
        coefficients,
        densitytensor_to_single_expectation,
    )
    return densitytensor_to_expectation_func
68 |
69 |
def get_densitytensor_to_sampled_expectation_func(
    hermitian_seq_seq: Sequence[Sequence[Union[str, jax.Array]]],
    qubits_seq_seq: Sequence[Sequence[int]],
    coefficients: Union[Sequence[float], jax.Array],
) -> Callable[[jax.Array, random.PRNGKeyArray, int], float]:
    """
    Converts strings (or arrays) representing Hermitian matrices, qubit indices and
    coefficients into a function that converts a densitytensor into a sampled expected value.

    On a quantum device, measurements are always taken in the computational basis, as such
    sampled expectation values should be taken with respect to an observable that commutes
    with the Pauli Z - a warning will be raised if it does not.

    qujax applies an importance sampling heuristic for sampled expectation values that only
    reflects the physical notion of measurement in the case that the observable commutes with Z.
    In the case that it does not, the expectation value will still be asymptotically unbiased
    but not representative of an experiment on a real quantum device.

    Args:
        hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors.
            Each Hermitian is either a tensor (jax.Array) or a string in ('X', 'Y', 'Z').
            E.g. [['Z', 'Z'], ['X']]
        qubits_seq_seq: Sequence of sequences of integer qubit indices.
            E.g. [[0,1], [2]]
        coefficients: Sequence of float coefficients to scale the expected values.

    Returns:
        Function that takes densitytensor, random key and integer number of shots
        and returns sampled expected value (float).
    """
    densitytensor_to_expectation_func = get_densitytensor_to_expectation_func(
        hermitian_seq_seq, qubits_seq_seq, coefficients
    )

    # Validate every observable up front, including the commutes-with-Z check.
    for hermitian_seq in hermitian_seq_seq:
        for h in hermitian_seq:
            check_hermitian(h, check_z_commutes=True)

    def densitytensor_to_sampled_expectation_func(
        densitytensor: jax.Array, random_key: random.PRNGKeyArray, n_samps: int
    ) -> float:
        """
        Maps densitytensor to sampled expected value.

        Args:
            densitytensor: Input densitytensor.
            random_key: JAX random key
            n_samps: Number of samples contributing to sampled expectation.

        Returns:
            Sampled expected value (float).

        """
        n_qubits = densitytensor.ndim // 2
        dm = densitytensor.reshape((2**n_qubits, 2**n_qubits))
        # The diagonal of the density matrix holds the computational-basis probabilities.
        measure_probs = jnp.diag(dm).real
        sampled_probs = sample_probs(measure_probs, random_key, n_samps)
        # Basis states with zero probability would give 0 / 0 = NaN, which then
        # contaminates the whole expectation value. Divide by a safe denominator
        # and give those states weight zero instead; the inner `where` keeps the
        # unselected branch finite as well.
        has_support = measure_probs > 0
        safe_probs = jnp.where(has_support, measure_probs, 1.0)
        iweights = jnp.where(has_support, jnp.sqrt(sampled_probs / safe_probs), 0.0)
        return densitytensor_to_expectation_func(
            densitytensor * jnp.outer(iweights, iweights).reshape(densitytensor.shape)
        )

    return densitytensor_to_sampled_expectation_func
133 |
134 |
def densitytensor_to_measurement_probabilities(
    densitytensor: jax.Array, qubit_inds: Sequence[int]
) -> jax.Array:
    """
    Extract array of measurement probabilities given a densitytensor and some qubit indices to
    measure (in the computational basis).
    I.e. the ith element of the array corresponds to the probability of observing the bitstring
    represented by the integer i on the measured qubits.

    Args:
        densitytensor: Input densitytensor.
        qubit_inds: Sequence of qubit indices to measure.

    Returns:
        Normalised array of measurement probabilities
        (length 2 ** len(qubit_inds)).
    """
    n_qubits = densitytensor.ndim // 2
    n_qubits_measured = len(qubit_inds)
    qubit_inds_trace_out = [i for i in range(n_qubits) if i not in qubit_inds]
    # NOTE(review): the reduced tensor keeps the measured qubits in ascending index
    # order, so the ordering of `qubit_inds` does not appear to affect the result -
    # confirm this is intended.
    reduced_densitytensor = partial_trace(densitytensor, qubit_inds_trace_out)
    # The reduced density matrix is 2**k x 2**k for k measured qubits. Using
    # 2 * k here only coincides with that for k = 1 and k = 2.
    dim = 2**n_qubits_measured
    return jnp.diag(reduced_densitytensor.reshape(dim, dim)).real
159 |
160 |
def densitytensor_to_measured_densitytensor(
    densitytensor: jax.Array,
    qubit_inds: Sequence[int],
    measurement: ArrayLike,
) -> jax.Array:
    """
    Computes the post-measurement densitytensor, given that qubit_inds are measured
    (in the computational basis) and the given measurement (integer or bitstring) is observed.

    Args:
        densitytensor: Input densitytensor.
        qubit_inds: Sequence of qubit indices to measure.
        measurement: Observed integer or bitstring.

    Returns:
        Post-measurement densitytensor (same shape as input densitytensor).
    """
    measurement = jnp.array(measurement)
    if measurement.ndim == 1:
        # A one-dimensional input is interpreted as a bitstring.
        measured_int = bitstrings_to_integers(measurement)
    else:
        measured_int = measurement

    n_measured = len(qubit_inds)
    # Projector onto the observed basis state, in tensor form.
    basis_vector = jnp.zeros(2**n_measured).at[measured_int].set(1)
    projector = jnp.diag(basis_vector).reshape((2,) * 2 * n_measured)
    projected = _kraus_single(densitytensor, projector, qubit_inds)

    # Renormalise by the trace of the projected state.
    dim = 2 ** (densitytensor.ndim // 2)
    trace = jnp.trace(projected.reshape(dim, dim)).real
    return projected / trace
193 |
--------------------------------------------------------------------------------
/qujax/gates.py:
--------------------------------------------------------------------------------
1 | import jax
2 | from jax import numpy as jnp
3 |
# Constant single-qubit gates, stored as 2x2 matrices.
I = jnp.eye(2)

# 2x2 zero block, used as filler when assembling block matrices below.
_0 = jnp.zeros((2, 2))

X = jnp.array([[0.0, 1.0], [1.0, 0.0]])

Y = jnp.array([[0.0, -1.0j], [1.0j, 0.0]])

Z = jnp.array([[1.0, 0.0], [0.0, -1.0]])

H = jnp.array([[1.0, 1.0], [1.0, -1]]) / jnp.sqrt(2)

S = jnp.array([[1.0, 0.0], [0.0, 1.0j]])

# The `dg` suffix marks the conjugate transpose of the gate of the same name.
Sdg = jnp.array([[1.0, 0.0], [0.0, -1.0j]])

T = jnp.array([[1.0, 0.0], [0.0, jnp.exp(jnp.pi * 1.0j / 4)]])

Tdg = jnp.array([[1.0, 0.0], [0.0, jnp.exp(-jnp.pi * 1.0j / 4)]])

V = jnp.array([[1.0, -1.0j], [-1.0j, 1.0]]) / jnp.sqrt(2)

Vdg = jnp.array([[1.0, 1.0j], [1.0j, 1.0]]) / jnp.sqrt(2)

SX = jnp.array([[1.0 + 1.0j, 1.0 - 1.0j], [1.0 - 1.0j, 1.0 + 1.0j]]) / 2

SXdg = jnp.array([[1.0 - 1.0j, 1.0 + 1.0j], [1.0 + 1.0j, 1.0 - 1.0j]]) / 2

# Controlled two-qubit gates: identity in the top-left 2x2 block and the target
# gate in the bottom-right block, reshaped into tensor form (2, 2, 2, 2).
CX = jnp.block([[I, _0], [_0, X]]).reshape((2,) * 4)

CY = jnp.block([[I, _0], [_0, Y]]).reshape((2,) * 4)

CZ = jnp.block([[I, _0], [_0, Z]]).reshape((2,) * 4)

CH = jnp.block([[I, _0], [_0, H]]).reshape((2,) * 4)

CV = jnp.block([[I, _0], [_0, V]]).reshape((2,) * 4)

CVdg = jnp.block([[I, _0], [_0, Vdg]]).reshape((2,) * 4)

CSX = jnp.block([[I, _0], [_0, SX]]).reshape((2,) * 4)

CSXdg = jnp.block([[I, _0], [_0, SXdg]]).reshape((2,) * 4)

# Three-qubit gate: X on the last 2x2 block only, in tensor form (2,) * 6.
CCX = jnp.block(
    [[I, _0, _0, _0], [_0, I, _0, _0], [_0, _0, I, _0], [_0, _0, _0, X]]  # Toffoli gate
).reshape((2,) * 6)

# Anti-diagonal block layout: Vdg in the top-right block, V in the bottom-left.
ECR = jnp.block([[_0, Vdg], [V, _0]]).reshape((2,) * 4)

# SWAP is left as a 4x4 matrix (not reshaped); it is reused as a 4x4 block
# when building CSWAP below.
SWAP = jnp.array(
    [
        [1.0, 0.0, 0.0, 0.0],
        [0.0, 0.0, 1.0, 0.0],
        [0.0, 1.0, 0.0, 0.0],
        [0.0, 0.0, 0.0, 1],
    ]
)

# Three-qubit gate: 4x4 identity block followed by the SWAP block, in tensor form (2,) * 6.
CSWAP = jnp.block([[jnp.eye(4), jnp.zeros((4, 4))], [jnp.zeros((4, 4)), SWAP]]).reshape(
    (2,) * 6
)
66 |
67 |
def Rx(param: float) -> jax.Array:
    """Returns cos(pi * param / 2) * I - i * sin(pi * param / 2) * X as a 2x2 matrix."""
    half_angle = param * jnp.pi / 2
    identity_part = jnp.cos(half_angle) * I
    pauli_part = jnp.sin(half_angle) * X * 1.0j
    return identity_part - pauli_part
71 |
72 |
def Ry(param: float) -> jax.Array:
    """Returns cos(pi * param / 2) * I - i * sin(pi * param / 2) * Y as a 2x2 matrix."""
    half_angle = param * jnp.pi / 2
    identity_part = jnp.cos(half_angle) * I
    pauli_part = jnp.sin(half_angle) * Y * 1.0j
    return identity_part - pauli_part
76 |
77 |
def Rz(param: float) -> jax.Array:
    """Returns cos(pi * param / 2) * I - i * sin(pi * param / 2) * Z as a 2x2 matrix."""
    half_angle = param * jnp.pi / 2
    identity_part = jnp.cos(half_angle) * I
    pauli_part = jnp.sin(half_angle) * Z * 1.0j
    return identity_part - pauli_part
81 |
82 |
def CRx(param: float) -> jax.Array:
    """Controlled Rx: block-diagonal matrix of I and Rx(param), in tensor form (2, 2, 2, 2)."""
    target_block = Rx(param)
    matrix = jnp.block([[I, _0], [_0, target_block]])
    return matrix.reshape((2,) * 4)
85 |
86 |
def CRy(param: float) -> jax.Array:
    """Controlled Ry: block-diagonal matrix of I and Ry(param), in tensor form (2, 2, 2, 2)."""
    target_block = Ry(param)
    matrix = jnp.block([[I, _0], [_0, target_block]])
    return matrix.reshape((2,) * 4)
89 |
90 |
def CRz(param: float) -> jax.Array:
    """Controlled Rz: block-diagonal matrix of I and Rz(param), in tensor form (2, 2, 2, 2)."""
    target_block = Rz(param)
    matrix = jnp.block([[I, _0], [_0, target_block]])
    return matrix.reshape((2,) * 4)
93 |
94 |
def U1(param: float) -> jax.Array:
    """Special case of U3 with the first two parameters fixed to 0."""
    return U3(param0=0, param1=0, param2=param)
97 |
98 |
def U2(param0: float, param1: float) -> jax.Array:
    """Special case of U3 with its first parameter fixed to 0.5."""
    return U3(param0=0.5, param1=param0, param2=param1)
101 |
102 |
def U3(param0: float, param1: float, param2: float) -> jax.Array:
    """
    Returns exp(i * pi * (param1 + param2) / 2) * Rz(param1) @ Ry(param0) @ Rz(param2)
    as a 2x2 matrix.
    """
    phase = jnp.exp((param1 + param2) * jnp.pi * 1.0j / 2)
    # The scalar phase is folded into the leftmost factor before the matrix products.
    phased_rz = phase * Rz(param1)
    return phased_rz @ Ry(param0) @ Rz(param2)
110 |
111 |
def CU1(param: float) -> jax.Array:
    """Controlled U1: block-diagonal matrix of I and U1(param), in tensor form (2, 2, 2, 2)."""
    target_block = U1(param)
    matrix = jnp.block([[I, _0], [_0, target_block]])
    return matrix.reshape((2,) * 4)
114 |
115 |
def CU2(param0: float, param1: float) -> jax.Array:
    """Controlled U2: block-diagonal matrix of I and U2(param0, param1), in tensor form."""
    target_block = U2(param0, param1)
    matrix = jnp.block([[I, _0], [_0, target_block]])
    return matrix.reshape((2,) * 4)
118 |
119 |
def CU3(param0: float, param1: float, param2: float) -> jax.Array:
    """Controlled U3: block-diagonal matrix of I and U3(param0, param1, param2), in tensor form."""
    target_block = U3(param0, param1, param2)
    matrix = jnp.block([[I, _0], [_0, target_block]])
    return matrix.reshape((2,) * 4)
122 |
123 |
def ISWAP(param: float) -> jax.Array:
    """
    Parameterised two-qubit gate in tensor form (2, 2, 2, 2).

    The outer two basis states are left unchanged; the middle 2x2 block is
    [[cos, i*sin], [i*sin, cos]] evaluated at the angle pi * param / 2.
    """
    half_angle = param * jnp.pi / 2
    cos_term = jnp.cos(half_angle)
    isin_term = 1.0j * jnp.sin(half_angle)
    rows = [
        [1.0, 0.0, 0.0, 0.0],
        [0.0, cos_term, isin_term, 0.0],
        [0.0, isin_term, cos_term, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ]
    return jnp.array(rows).reshape((2,) * 4)
136 |
137 |
def PhasedISWAP(param0: float, param1: float) -> jax.Array:
    """
    Parameterised two-qubit gate in tensor form (2, 2, 2, 2).

    Same layout as ISWAP(param1), except that the two off-diagonal entries of the
    middle block are multiplied by exp(+2i*pi*param0) and exp(-2i*pi*param0).
    """
    half_angle = param1 * jnp.pi / 2
    cos_term = jnp.cos(half_angle)
    isin_term = 1.0j * jnp.sin(half_angle)
    upper_entry = isin_term * jnp.exp(2.0j * jnp.pi * param0)
    lower_entry = isin_term * jnp.exp(-2.0j * jnp.pi * param0)
    rows = [
        [1.0, 0.0, 0.0, 0.0],
        [0.0, cos_term, upper_entry, 0.0],
        [0.0, lower_entry, cos_term, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ]
    return jnp.array(rows).reshape((2,) * 4)
150 |
151 |
def XXPhase(param: float) -> jax.Array:
    """
    Parameterised two-qubit gate in tensor form (2, 2, 2, 2).

    cos(pi * param / 2) on the diagonal and -i * sin(pi * param / 2) on the
    anti-diagonal.
    """
    half_angle = param * jnp.pi / 2
    cos_term = jnp.cos(half_angle)
    isin_term = 1.0j * jnp.sin(half_angle)
    rows = [
        [cos_term, 0.0, 0.0, -isin_term],
        [0.0, cos_term, -isin_term, 0.0],
        [0.0, -isin_term, cos_term, 0.0],
        [-isin_term, 0.0, 0.0, cos_term],
    ]
    return jnp.array(rows).reshape((2,) * 4)
164 |
165 |
def YYPhase(param: float) -> jax.Array:
    """
    Parameterised two-qubit gate in tensor form (2, 2, 2, 2).

    cos(pi * param / 2) on the diagonal; on the anti-diagonal the outer entries are
    +i * sin(pi * param / 2) and the inner entries are -i * sin(pi * param / 2).
    """
    half_angle = param * jnp.pi / 2
    cos_term = jnp.cos(half_angle)
    isin_term = 1.0j * jnp.sin(half_angle)
    rows = [
        [cos_term, 0.0, 0.0, isin_term],
        [0.0, cos_term, -isin_term, 0.0],
        [0.0, -isin_term, cos_term, 0.0],
        [isin_term, 0.0, 0.0, cos_term],
    ]
    return jnp.array(rows).reshape((2,) * 4)
178 |
179 |
def ZZPhase(param: float) -> jax.Array:
    """
    Diagonal two-qubit gate in tensor form (2, 2, 2, 2).

    The diagonal is (e^{-ia}, e^{+ia}, e^{+ia}, e^{-ia}) with a = pi * param / 2.
    """
    half_angle = param * jnp.pi / 2
    phase_minus = jnp.exp(-1.0j * half_angle)
    phase_plus = jnp.exp(1.0j * half_angle)
    diagonal = jnp.array([phase_minus, phase_plus, phase_plus, phase_minus])
    return jnp.diag(diagonal).reshape((2,) * 4)


# ZZPhase evaluated once at param = 0.5, exposed as a constant gate.
ZZMax = ZZPhase(0.5)
188 |
189 |
def PhasedX(param0: float, param1: float) -> jax.Array:
    """Returns Rz(param1) @ Rx(param0) @ Rz(-param1) as a 2x2 matrix."""
    left = Rz(param1)
    middle = Rx(param0)
    right = Rz(-param1)
    return left @ middle @ right
192 |
--------------------------------------------------------------------------------
/qujax/statetensor.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from functools import partial
4 | from typing import Callable, Sequence, Optional
5 |
6 | import jax
7 | from jax import numpy as jnp
8 | from jax.typing import ArrayLike
9 | from jax._src.dtypes import canonicalize_dtype
10 | from jax._src.typing import DTypeLike
11 |
12 | from qujax import gates
13 | from qujax.utils import _arrayify_inds, check_circuit
14 |
15 | from qujax.typing import Gate, PureCircuitFunction, GateFunction, GateParameterIndices
16 |
17 |
def apply_gate(
    statetensor: jax.Array, gate_unitary: jax.Array, qubit_inds: Sequence[int]
) -> jax.Array:
    """
    Contracts a gate (a unitary matrix in tensor form) with a statetensor
    and returns the updated statetensor.

    Args:
        statetensor: Input statetensor.
        gate_unitary: Unitary array representing gate
            must be in tensor form with shape (2,2,...).
        qubit_inds: Sequence of indices for gate to be applied to.
            Must have 2 * len(qubit_inds) = gate_unitary.ndim

    Returns:
        Updated statetensor.
    """
    n_targets = len(qubit_inds)
    # The gate's trailing axes are contracted with the targeted statetensor axes.
    gate_input_axes = list(range(-n_targets, 0))
    contracted = jnp.tensordot(
        gate_unitary, statetensor, axes=(gate_input_axes, qubit_inds)
    )
    # tensordot puts the gate's remaining axes first; move them back into place.
    leading_axes = list(range(n_targets))
    return jnp.moveaxis(contracted, leading_axes, qubit_inds)
40 |
41 |
42 | def _to_gate_func(
43 | gate: Gate,
44 | ) -> GateFunction:
45 | """
46 | Ensures a gate_seq element is a function that map (possibly empty) parameters
47 | to a unitary tensor.
48 |
49 | Args:
50 | gate: Either a string matching an array or function in qujax.gates,
51 | a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) )
52 | or a function taking parameters and returning gate unitary in tensor form.
53 |
54 | Returns:
55 | Gate parameter to unitary functions
56 | """
57 |
58 | def _array_to_callable(arr: jax.Array) -> Callable[[], jax.Array]:
59 | return lambda: arr
60 |
61 | if isinstance(gate, str):
62 | gate = gates.__dict__[gate]
63 |
64 | if callable(gate):
65 | gate_func = gate
66 | elif hasattr(gate, "__array__"):
67 | gate_func = _array_to_callable(jnp.array(gate))
68 | else:
69 | raise TypeError(
70 | f"Unsupported gate type - gate must be either a string in qujax.gates, an array or "
71 | f"callable: {gate}"
72 | )
73 | return gate_func
74 |
75 |
76 | def _gate_func_to_unitary(
77 | gate_func: GateFunction,
78 | qubit_inds: Sequence[int],
79 | param_inds: jax.Array,
80 | params: jax.Array,
81 | ) -> jax.Array:
82 | """
83 | Extract gate unitary.
84 |
85 | Args:
86 | gate_func: Function that maps a (possibly empty) parameter array to a unitary tensor (array)
87 | qubit_inds: Indices of qubits to apply gate to
88 | (only needed to ensure gate is in tensor form)
89 | param_inds: Indices of full parameter to extract gate specific parameters
90 | params: Full parameter vector
91 |
92 | Returns:
93 | Array containing gate unitary in tensor form.
94 | """
95 | gate_params = jnp.take(params, param_inds)
96 | gate_unitary = gate_func(*gate_params)
97 | gate_unitary = gate_unitary.reshape(
98 | (2,) * (2 * len(qubit_inds))
99 | ) # Ensure gate is in tensor form
100 | return gate_unitary
101 |
102 |
def all_zeros_statetensor(n_qubits: int, dtype: DTypeLike = complex) -> jax.Array:
    """
    Builds the statetensor of the all-zeros state |00...0> on `n_qubits` qubits.

    Args:
        n_qubits: Number of qubits that the state is defined on.
        dtype: Data type of the statetensor returned.

    Returns:
        Statetensor representing the state having all qubits set to zero.
    """
    empty = jnp.zeros((2,) * n_qubits, dtype=canonicalize_dtype(dtype))
    # The only non-zero amplitude sits at index (0, 0, ..., 0).
    origin = (0,) * n_qubits
    return empty.at[origin].set(1.0)
117 |
118 |
def get_params_to_statetensor_func(
    gate_seq: Sequence[Gate],
    qubit_inds_seq: Sequence[Sequence[int]],
    param_inds_seq: Sequence[GateParameterIndices],
    n_qubits: Optional[int] = None,
) -> PureCircuitFunction:
    """
    Builds a function mapping circuit parameters to a statetensor.

    Args:
        gate_seq: Sequence of gates.
            Each element is either a string matching a unitary array or function in qujax.gates,
            a custom unitary array or a custom function taking parameters and returning a
            unitary array. Unitary arrays will be reshaped into tensor form (2, 2,...)
        qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are
            acting on.
            i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the
            zeroth qubit, the second gate is a two qubit gate acting on the zeroth and first qubit
            etc.
        param_inds_seq: Sequence of sequences representing parameter indices that gates are using,
            i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter
            (the float at position zero in the parameter vector/array), the second gate is not
            parameterised and the third gate uses the parameters at position five and two.
        n_qubits: Number of qubits, if fixed.

    Returns:
        Function which maps parameters (and optional statetensor_in) to a statetensor.
        If no parameters are found then the function only takes optional statetensor_in.

    """

    check_circuit(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits)

    if n_qubits is None:
        # Infer the register size from the largest qubit index in the circuit.
        highest_qubit = max([max(inds) for inds in qubit_inds_seq])
        n_qubits = highest_qubit + 1

    gate_funcs = [_to_gate_func(gate) for gate in gate_seq]
    gate_param_inds = _arrayify_inds(param_inds_seq)

    def params_to_statetensor_func(
        params: ArrayLike, statetensor_in: Optional[jax.Array] = None
    ) -> jax.Array:
        """
        Runs the parameterised circuit on statetensor_in (default is |0>^N).

        Args:
            params: Parameters of the circuit.
            statetensor_in: Optional. Input statetensor.
                Defaults to |0>^N (tensor of size 2^n with all zeroes except one in [0]*N index).

        Returns:
            Updated statetensor.

        """
        statetensor = (
            all_zeros_statetensor(n_qubits)
            if statetensor_in is None
            else statetensor_in
        )

        params = jnp.atleast_1d(params)
        # Guarantee `params` has the right type for type-checking purposes
        if not isinstance(params, jax.Array):
            raise ValueError("This should not happen. Please open an issue on GitHub.")

        for gate_func, qubit_inds, param_inds in zip(
            gate_funcs, qubit_inds_seq, gate_param_inds
        ):
            unitary = _gate_func_to_unitary(gate_func, qubit_inds, param_inds, params)
            statetensor = apply_gate(statetensor, unitary, qubit_inds)
        return statetensor

    uses_no_params = all(inds.size == 0 for inds in gate_param_inds)
    if not uses_no_params:
        return params_to_statetensor_func

    def no_params_to_statetensor_func(
        statetensor_in: Optional[jax.Array] = None,
    ) -> jax.Array:
        """
        Runs the parameter-free circuit on statetensor_in
        (default is |0>^N).

        Args:
            statetensor_in: Optional. Input statetensor.
                Defaults to |0>^N (tensor of size 2^n with all zeroes except one in
                the [0]*N index).

        Returns:
            Updated statetensor.

        """
        return params_to_statetensor_func(jnp.array([]), statetensor_in)

    return no_params_to_statetensor_func
216 |
217 |
def get_params_to_unitarytensor_func(
    gate_seq: Sequence[Gate],
    qubit_inds_seq: Sequence[Sequence[int]],
    param_inds_seq: Sequence[GateParameterIndices],
    n_qubits: Optional[int] = None,
) -> PureCircuitFunction:
    """
    Builds a function that maps circuit parameters to the circuit's unitarytensor,
    i.e. the full unitary of the circuit stored as an array of shape (2,) * 2 * n_qubits.

    The returned function is the statetensor circuit function with its input
    pre-bound to the identity unitarytensor.

    Args:
        gate_seq: Sequence of gates.
            Each element is either a string matching a unitary array or function in qujax.gates,
            a custom unitary array or a custom function taking parameters and returning a unitary
            array. Unitary arrays will be reshaped into tensor form (2, 2,...)
        qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are
            acting on.
            i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the
            zeroth qubit, the second gate is a two qubit gate acting on the zeroth and first qubit
            etc.
        param_inds_seq: Sequence of sequences representing parameter indices that gates are using,
            i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter
            (the float at position zero in the parameter vector/array), the second gate is not
            parameterised and the third gate uses the parameters at position five and two.
        n_qubits: Number of qubits, if fixed.
            Inferred from the largest qubit index when omitted.

    Returns:
        Function which maps any parameters to a unitarytensor.

    """
    if n_qubits is None:
        largest_qubit_index = max(max(qi) for qi in qubit_inds_seq)
        n_qubits = largest_qubit_index + 1

    circuit_func = get_params_to_statetensor_func(
        gate_seq, qubit_inds_seq, param_inds_seq, n_qubits
    )

    # Identity matrix on the full Hilbert space, reshaped into tensor form.
    unitarytensor_shape = (2,) * 2 * n_qubits
    identity_unitarytensor = jnp.eye(2**n_qubits).reshape(unitarytensor_shape)

    return partial(circuit_func, statetensor_in=identity_unitarytensor)
258 |
--------------------------------------------------------------------------------
/qujax/statetensor_observable.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | from typing import Callable, Sequence, Union
4 |
5 | import jax
6 | from jax import numpy as jnp
7 | from jax import random
8 | from jax.lax import fori_loop
9 |
10 | from qujax.statetensor import apply_gate
11 | from qujax.utils import check_hermitian, paulis
12 |
13 |
def statetensor_to_single_expectation(
    statetensor: jax.Array, hermitian: jax.Array, qubit_inds: Sequence[int]
) -> jax.Array:
    """
    Computes the expectation value of one observable, given as a Hermitian tensor,
    with respect to a statetensor.

    Args:
        statetensor: Input statetensor.
        hermitian: Hermitian array
            must be in tensor form with shape (2,2,...).
        qubit_inds: Sequence of qubit indices for Hermitian matrix to be applied to.
            Must have 2 * len(qubit_inds) == hermitian.ndim

    Returns:
        Expected value (float).
    """
    all_axes = tuple(range(statetensor.ndim))
    # Apply the observable to the state, then contract with the conjugated state
    # over every axis.
    transformed = apply_gate(statetensor, hermitian, qubit_inds)
    overlap = jnp.tensordot(
        statetensor.conjugate(), transformed, axes=(all_axes, all_axes)
    )
    return overlap.real
35 |
36 |
def get_hermitian_tensor(hermitian_seq: Sequence[Union[str, jax.Array]]) -> jax.Array:
    """
    Convert a sequence of observables represented by Pauli strings or Hermitian matrices
    in tensor form into single array (in tensor form).

    The individual observables are combined with a Kronecker product, in the order given.

    Args:
        hermitian_seq: Sequence of Hermitian strings or arrays.

    Returns:
        Hermitian matrix in tensor form (array).
    """
    for h in hermitian_seq:
        check_hermitian(h)

    def _as_square_matrix(h: Union[str, jax.Array]) -> jax.Array:
        # Resolve Pauli labels, then flatten any tensor-form array back to a
        # (2^n, 2^n) matrix so that jnp.kron performs a true matrix Kronecker product.
        h_arr = paulis[h] if isinstance(h, str) else h
        n_axes = int(jnp.rint(jnp.log2(h_arr.size)))
        dim = 2 ** (n_axes // 2)
        return h_arr.reshape(dim, dim)

    matrices = [_as_square_matrix(h) for h in hermitian_seq]

    # The product must be taken on matrices: applying jnp.kron to (2, 2, ...) tensors
    # pairs axes position-wise, which interleaves output and input axes as soon as
    # one of the factors acts on more than one qubit.
    full_mat = matrices[0]
    for single_matrix in matrices[1:]:
        full_mat = jnp.kron(full_mat, single_matrix)

    return full_mat.reshape((2,) * int(jnp.rint(jnp.log2(full_mat.size))))
62 |
63 |
64 | def _get_tensor_to_expectation_func(
65 | hermitian_seq_seq: Sequence[Sequence[Union[str, jax.Array]]],
66 | qubits_seq_seq: Sequence[Sequence[int]],
67 | coefficients: Union[Sequence[float], jax.Array],
68 | contraction_function: Callable,
69 | ) -> Callable[[jax.Array], float]:
70 | """
71 | Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and
72 | a list of coefficients and returns a function that converts a tensor into an expected value.
73 | The contraction function performs the tensor contraction according to the type of tensor
74 | provided (i.e. whether it is a statetensor or a densitytensor).
75 |
76 | Args:
77 | hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors.
78 | Each Hermitian matrix is either represented by a tensor (jax.Array) or by a
79 | list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices.
80 | E.g. [['Z', 'Z'], ['X']]
81 | qubits_seq_seq: Sequence of sequences of integer qubit indices.
82 | E.g. [[0,1], [2]]
83 | coefficients: Sequence of float coefficients to scale the expected values.
84 | contraction_function: Function that performs the tensor contraction.
85 |
86 | Returns:
87 | Function that takes tensor and returns expected value (float).
88 | """
89 |
90 | hermitian_tensors = [get_hermitian_tensor(h_seq) for h_seq in hermitian_seq_seq]
91 |
92 | def tensor_to_expectation_func(tensor: jax.Array) -> float:
93 | """
94 | Maps tensor to expected value.
95 |
96 | Args:
97 | tensor: Input tensor.
98 |
99 | Returns:
100 | Expected value (float).
101 | """
102 | out = 0
103 | for hermitian, qubit_inds, coeff in zip(
104 | hermitian_tensors, qubits_seq_seq, coefficients
105 | ):
106 | out += coeff * contraction_function(tensor, hermitian, qubit_inds)
107 | return out
108 |
109 | return tensor_to_expectation_func
110 |
111 |
def get_statetensor_to_expectation_func(
    hermitian_seq_seq: Sequence[Sequence[Union[str, jax.Array]]],
    qubits_seq_seq: Sequence[Sequence[int]],
    coefficients: Union[Sequence[float], jax.Array],
) -> Callable[[jax.Array], float]:
    """
    Builds a function that converts a statetensor into an expected value, from
    strings (or arrays) representing Hermitian matrices, the qubit indices they
    act on and a list of coefficients.

    Args:
        hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors.
            Each Hermitian matrix is either represented by a tensor (jax.Array)
            or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices.
            E.g. [['Z', 'Z'], ['X']]
        qubits_seq_seq: Sequence of sequences of integer qubit indices.
            E.g. [[0,1], [2]]
        coefficients: Sequence of float coefficients to scale the expected values.

    Returns:
        Function that takes statetensor and returns expected value (float).
    """
    # Same construction as the generic tensor version, specialised with the
    # statetensor contraction.
    expectation_func = _get_tensor_to_expectation_func(
        hermitian_seq_seq=hermitian_seq_seq,
        qubits_seq_seq=qubits_seq_seq,
        coefficients=coefficients,
        contraction_function=statetensor_to_single_expectation,
    )
    return expectation_func
141 |
142 |
def get_statetensor_to_sampled_expectation_func(
    hermitian_seq_seq: Sequence[Sequence[Union[str, jax.Array]]],
    qubits_seq_seq: Sequence[Sequence[int]],
    coefficients: Union[Sequence[float], jax.Array],
) -> Callable[[jax.Array, random.PRNGKeyArray, int], float]:
    """
    Converts strings (or arrays) representing Hermitian matrices, qubit indices and
    coefficients into a function that converts a statetensor into a sampled expected value.

    On a quantum device, measurements are always taken in the computational basis, as such
    sampled expectation values should be taken with respect to an observable that commutes
    with the Pauli Z - a warning will be raised if it does not.

    qujax applies an importance sampling heuristic for sampled expectation values that only
    reflects the physical notion of measurement in the case that the observable commutes with Z.
    In the case that it does not, the expectation value will still be asymptotically unbiased
    but not representative of an experiment on a real quantum device.

    Args:
        hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors.
            Each Hermitian is either a tensor (jax.Array) or a string in ('X', 'Y', 'Z').
            E.g. [['Z', 'Z'], ['X']]
        qubits_seq_seq: Sequence of sequences of integer qubit indices.
            E.g. [[0,1], [2]]
        coefficients: Sequence of float coefficients to scale the expected values.

    Returns:
        Function that takes statetensor, random key and integer number of shots
        and returns sampled expected value (float).
    """
    statetensor_to_expectation_func = get_statetensor_to_expectation_func(
        hermitian_seq_seq, qubits_seq_seq, coefficients
    )

    # Second pass over the observables, this time only to emit the
    # "does not commute with Z" warning where applicable.
    for hermitian_seq in hermitian_seq_seq:
        for h in hermitian_seq:
            check_hermitian(h, check_z_commutes=True)

    def statetensor_to_sampled_expectation_func(
        statetensor: jax.Array, random_key: random.PRNGKeyArray, n_samps: int
    ) -> float:
        """
        Maps statetensor to sampled expected value.

        Args:
            statetensor: Input statetensor.
            random_key: JAX random key
            n_samps: Number of samples contributing to sampled expectation.

        Returns:
            Sampled expected value (float).
        """
        measure_probs = jnp.abs(statetensor) ** 2
        sampled_probs = sample_probs(measure_probs, random_key, n_samps)

        # Outcomes with zero probability are never sampled, so the plain ratio
        # sampled_probs / measure_probs is 0 / 0 = NaN there and would poison the
        # whole expectation value. Their amplitude is zero anyway, so a weight of
        # zero is the correct contribution.
        has_support = measure_probs > 0
        safe_measure_probs = jnp.where(has_support, measure_probs, 1.0)
        prob_ratio = jnp.where(has_support, sampled_probs / safe_measure_probs, 0.0)
        iweights = jnp.sqrt(prob_ratio)
        return statetensor_to_expectation_func(statetensor * iweights)

    return statetensor_to_sampled_expectation_func
201 |
202 |
def sample_probs(
    measure_probs: jax.Array, random_key: random.PRNGKeyArray, n_samps: int
):
    """
    Draws samples from a probability distribution and returns the resulting
    empirical distribution (relative frequencies), in the same shape as the input.

    Args:
        measure_probs: Probability distribution.
        random_key: JAX random key
        n_samps: Number of samples contributing to empirical distribution.

    Returns:
        Empirical distribution (jax.Array).
    """
    flat_probs = measure_probs.flatten()
    outcomes = jnp.arange(measure_probs.size)
    draws = random.choice(random_key, a=outcomes, shape=(n_samps,), p=flat_probs)

    def _tally(i, frequencies):
        # Each draw contributes a weight of 1 / n_samps to its outcome.
        return frequencies.at[draws[i]].add(1 / n_samps)

    empirical = fori_loop(0, n_samps, _tally, jnp.zeros_like(flat_probs))
    return empirical.reshape(measure_probs.shape)
231 |
--------------------------------------------------------------------------------
/qujax/typing.py:
--------------------------------------------------------------------------------
1 | from typing import Union, Optional, Protocol, Callable, Iterable, Sequence
2 |
3 | # Backwards compatibility with Python <3.11 (TypeVarTuple and Unpack were added to `typing` in 3.11)
4 | from typing_extensions import TypeVarTuple, Unpack
5 |
6 | import jax
7 | from jax.typing import ArrayLike
8 |
9 |
class PureParameterizedCircuit(Protocol):
    """
    Call signature of a parameterised pure-state circuit function:
    takes circuit parameters and an optional input statetensor, returns a jax.Array.
    """

    def __call__(
        self, params: ArrayLike, statetensor_in: Optional[jax.Array] = None
    ) -> jax.Array: ...
14 |
15 |
class PureUnparameterizedCircuit(Protocol):
    """
    Call signature of a pure-state circuit function without parameters:
    takes only an optional input statetensor and returns a jax.Array.
    """

    def __call__(self, statetensor_in: Optional[jax.Array] = None) -> jax.Array: ...
18 |
19 |
class MixedParameterizedCircuit(Protocol):
    """
    Call signature of a parameterised mixed-state circuit function:
    takes circuit parameters and an optional input densitytensor, returns a jax.Array.
    """

    def __call__(
        self, params: ArrayLike, densitytensor_in: Optional[jax.Array] = None
    ) -> jax.Array: ...
24 |
25 |
class MixedUnparameterizedCircuit(Protocol):
    """
    Call signature of a mixed-state circuit function without parameters:
    takes only an optional input densitytensor and returns a jax.Array.
    """

    def __call__(self, densitytensor_in: Optional[jax.Array] = None) -> jax.Array: ...
28 |
29 |
# Variadic type variable standing in for the argument list of a gate function
GateArgs = TypeVarTuple("GateArgs")
# Function that takes arbitrary nr. of parameters and returns an array representing the gate
# Currently Python does not allow us to restrict the type of the arguments using a TypeVarTuple
GateFunction = Callable[[Unpack[GateArgs]], jax.Array]
# Parameter indices used by a single gate; None is accepted in place of a sequence
GateParameterIndices = Optional[Sequence[int]]

# A pure-state circuit function, with or without a leading `params` argument
PureCircuitFunction = Union[PureUnparameterizedCircuit, PureParameterizedCircuit]
# A mixed-state circuit function, with or without a leading `params` argument
MixedCircuitFunction = Union[MixedUnparameterizedCircuit, MixedParameterizedCircuit]

# A gate is specified as a string, an array, or a function returning an array
Gate = Union[str, jax.Array, GateFunction]

# A Kraus operation is either a single gate or an iterable of gates
KrausOp = Union[Gate, Iterable[Gate]]
42 |
--------------------------------------------------------------------------------
/qujax/utils.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import collections.abc
4 | from inspect import signature
5 | from typing import Callable, List, Optional, Sequence, Tuple, Union
6 | from warnings import warn
7 |
8 | import jax
9 | from jax.typing import ArrayLike
10 | from jax import numpy as jnp
11 | from jax import random
12 |
13 | from qujax import gates
14 |
15 | from qujax.typing import (
16 | Gate,
17 | KrausOp,
18 | GateParameterIndices,
19 | PureParameterizedCircuit,
20 | MixedParameterizedCircuit,
21 | )
22 |
# Lookup from the Pauli labels accepted as Hermitian strings to the matching
# objects in qujax.gates.
paulis = {"X": gates.X, "Y": gates.Y, "Z": gates.Z}
24 |
25 |
def check_unitary(gate: Gate) -> None:
    """
    Verifies that a qujax Gate represents a unitary matrix.
    Raises a TypeError when it does not.

    Args:
        gate: Gate to check, given as a string naming an entry of qujax.gates,
            an array, or a function
            (which will be evaluated with all arguments set to 0.1).

    Raises:
        KeyError: if a gate string is not found in qujax.gates.
        TypeError: if the gate has an unsupported type or is not unitary.

    """
    # Resolve a gate name to the object it refers to in qujax.gates.
    if isinstance(gate, str):
        if gate not in gates.__dict__:
            raise KeyError(
                f"Gate string '{gate}' not found in qujax.gates "
                f"- consider changing input to an array or callable"
            )
        gate = gates.__dict__[gate]

    if callable(gate):
        # Evaluate a parameterised gate at a fixed test point (every argument 0.1).
        n_args = len(signature(gate).parameters)
        probe_args = jnp.ones(n_args) * 0.1
        gate_arr = gate(*probe_args)
    elif isinstance(gate, jax.Array):
        gate_arr = gate
    else:
        raise TypeError(
            f"Unsupported gate type - gate must be either a string in qujax.gates, an array or "
            f"callable: {gate}"
        )

    # Flatten tensor-form gates back to a square matrix before testing U U^dagger = I.
    dim = int(jnp.sqrt(gate_arr.size))
    matrix = gate_arr.reshape(dim, dim)
    deviation = jnp.abs(matrix @ jnp.conjugate(matrix).T - jnp.eye(dim))

    if jnp.any(deviation > 1e-3):
        raise TypeError(f"Gate not unitary: {gate}")
63 |
64 |
def check_hermitian(hermitian: Union[str, jax.Array], check_z_commutes: bool = False):
    """
    Checks whether a matrix or tensor is Hermitian.

    Args:
        hermitian: Either one of the Pauli strings in `paulis`, or an array holding the
            potentially Hermitian matrix, in matrix form (2^n, 2^n) or tensor form (2, 2, ...).
        check_z_commutes: boolean on whether to check if the matrix commutes with Z

    Raises:
        TypeError: if a string is not a recognised Pauli label, or if an array is
            not Hermitian.

    """
    if isinstance(hermitian, str):
        if hermitian not in paulis:
            raise TypeError(
                f"qujax only accepts {tuple(paulis.keys())} as Hermitian strings, "
                f"received: {hermitian}"
            )
        n_qubits = 1
        hermitian_mat = paulis[hermitian]

    else:
        # Derive the qubit count from the number of entries (4^n) rather than from
        # ndim, so that both matrix-form and tensor-form arrays are handled, and
        # reshape to the 2^n x 2^n matrix (not 2n x 2n).
        n_qubits = int(jnp.rint(jnp.log2(hermitian.size))) // 2
        dim = 2**n_qubits
        hermitian_mat = hermitian.reshape(dim, dim)
        if not jnp.allclose(hermitian_mat, hermitian_mat.T.conj()):
            raise TypeError(f"Array not Hermitian: {hermitian}")

    if check_z_commutes:
        # Diagonal matrix with alternating +1 / -1 entries on the full space.
        big_z = jnp.diag(jnp.where(jnp.arange(2**n_qubits) % 2 == 0, 1, -1))
        z_commutes = jnp.allclose(hermitian_mat @ big_z, big_z @ hermitian_mat)
        if not z_commutes:
            warn(
                "Hermitian matrix does not commute with Z. \n"
                "For sampled expectation values, this may lead to unexpected results, "
                "measurements on a quantum device are always taken in the computational basis. "
                "Additional gates can be applied in the circuit to change the basis such "
                "that an observable that commutes with Z can be measured."
            )
100 |
101 |
102 | def _arrayify_inds(
103 | param_inds_seq: Optional[Sequence[GateParameterIndices]],
104 | ) -> Sequence[jax.Array]:
105 | """
106 | Ensure each element of param_inds_seq is an array (and therefore valid for jnp.take)
107 |
108 | Args:
109 | param_inds_seq: Sequence of sequences representing parameter indices that gates are using,
110 | i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter
111 | (the float at position zero in the parameter vector/array), the second gate is not
112 | parameterised and the third gates used the parameters at position five and two.
113 |
114 | Returns:
115 | Sequence of arrays representing parameter indices.
116 | """
117 | if param_inds_seq is None:
118 | param_inds_seq = [None]
119 | array_param_inds = [jnp.array(p) for p in param_inds_seq]
120 | array_param_inds = [
121 | jnp.array([]) if jnp.any(jnp.isnan(p)) else p.astype(int)
122 | for p in array_param_inds
123 | ]
124 | return array_param_inds
125 |
126 |
def check_circuit(
    gate_seq: Sequence[KrausOp],
    qubit_inds_seq: Sequence[Sequence[int]],
    param_inds_seq: Sequence[
        Union[GateParameterIndices, Sequence[GateParameterIndices]]
    ],
    n_qubits: Optional[int] = None,
    check_unitaries: bool = True,
):
    """
    Basic checks that circuit arguments conform.

    Args:
        gate_seq: Sequence of gates.
            Each element is either a string matching an array or function in qujax.gates,
            a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) )
            or a function taking parameters and returning gate unitary in tensor form.
            Or alternatively a sequence of the above representing Kraus operators.
        qubit_inds_seq: Sequences of qubits (ints) that gates are acting on.
        param_inds_seq: Sequence of parameter indices that gates are using,
            i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the first parameter,
            the second gate is not parameterised and the third gates used the fifth and second
            parameters.
        n_qubits: Number of qubits, if fixed.
        check_unitaries: boolean on whether to check if each gate represents a unitary matrix

    Raises:
        TypeError: if an argument is not a sequence of the expected form, if the three
            sequences differ in length, or if n_qubits is too small for the qubit indices used.

    """
    if not isinstance(gate_seq, collections.abc.Sequence):
        raise TypeError("gate_seq must be Sequence e.g. ['H', 'Rx', 'CX']")

    def _is_sequence_like(obj) -> bool:
        # Plain sequences and array-likes are both accepted as index containers.
        return isinstance(obj, collections.abc.Sequence) or hasattr(obj, "__array__")

    if not isinstance(qubit_inds_seq, collections.abc.Sequence) or any(
        not _is_sequence_like(q) for q in qubit_inds_seq
    ):
        raise TypeError(
            "qubit_inds_seq must be Sequence of Sequences e.g. [[0,1], [0], []]"
        )

    # Parameter indices may additionally be None (gate without parameters).
    if not isinstance(param_inds_seq, collections.abc.Sequence) or any(
        not (_is_sequence_like(p) or p is None) for p in param_inds_seq
    ):
        raise TypeError(
            "param_inds_seq must be Sequence of Sequences e.g. [[0,1], [0], []]"
        )

    # All three sequences describe the same gates, so they must agree in length.
    if len(gate_seq) != len(qubit_inds_seq) or len(gate_seq) != len(param_inds_seq):
        raise TypeError(
            f"gate_seq ({len(gate_seq)}), qubit_inds_seq ({len(qubit_inds_seq)}) "
            f"and param_inds_seq ({len(param_inds_seq)}) must have matching lengths"
        )

    if n_qubits is not None:
        # default=-1 keeps an empty circuit (or gates without qubit indices) valid
        # instead of failing inside max().
        largest_qubit_index = max(
            (max(qi) for qi in qubit_inds_seq if len(qi) > 0), default=-1
        )
        if n_qubits < largest_qubit_index + 1:
            raise TypeError(
                "n_qubits must be larger than largest qubit index in qubit_inds_seq"
            )

    if check_unitaries:
        for g in gate_seq:
            check_unitary(g)
201 |
202 |
203 | def _get_gate_str(
204 | gate_obj: KrausOp,
205 | param_inds: Union[GateParameterIndices, Sequence[GateParameterIndices]],
206 | ) -> str:
207 | """
208 | Maps single gate object to a four character string representation
209 |
210 | Args:
211 | gate_obj: Either a string matching a function in qujax.gates,
212 | a unitary array (which will be reshaped into a tensor of shape e.g. (2,2,2,...) )
213 | or a function taking parameters (can be empty) and returning gate unitary
214 | in tensor form.
215 | Or alternatively, a sequence of Krause operators represented by strings, arrays or
216 | functions.
217 | param_inds: Parameter indices that gates are using, i.e. gate uses 1st and 5th parameter.
218 |
219 | Returns:
220 | Four character string representation of the gate
221 |
222 | """
223 | if isinstance(gate_obj, (tuple, list)) or (
224 | hasattr(gate_obj, "__array__") and gate_obj.ndim % 2 == 1
225 | ):
226 | # Kraus operators
227 | gate_obj = "Kr"
228 | param_inds = jnp.unique(jnp.concatenate(_arrayify_inds(param_inds), axis=0))
229 |
230 | if isinstance(gate_obj, str):
231 | gate_str = gate_obj
232 | elif hasattr(gate_obj, "__array__"):
233 | gate_str = "Arr"
234 | elif callable(gate_obj):
235 | gate_str = "Func"
236 | else:
237 | if hasattr(gate_obj, "__name__"):
238 | gate_str = gate_obj.__name__
239 | elif hasattr(gate_obj, "__class__") and hasattr(gate_obj.__class__, "__name__"):
240 | gate_str = gate_obj.__class__.__name__
241 | else:
242 | gate_str = "Other"
243 |
244 | if hasattr(param_inds, "tolist"):
245 | param_inds = param_inds.tolist()
246 |
247 | if isinstance(param_inds, tuple):
248 | param_inds = list(param_inds)
249 |
250 | if param_inds == [] or param_inds == [None] or param_inds is None:
251 | if len(gate_str) > 7:
252 | gate_str = gate_str[:6] + "."
253 | else:
254 | param_str = str(param_inds).replace(" ", "")
255 |
256 | if len(param_str) > 5:
257 | param_str = "[.]"
258 |
259 | if (len(gate_str) + len(param_str)) > 7:
260 | gate_str = gate_str[:1] + "."
261 |
262 | gate_str += param_str
263 |
264 | gate_str = gate_str.center(7, "-")
265 |
266 | return gate_str
267 |
268 |
269 | def _pad_rows(rows: List[str]) -> Tuple[List[str], List[bool]]:
270 | """
271 | Pad string representation of circuit to be rectangular.
272 | Fills qubit rows with '-' and between-qubit rows with ' '.
273 |
274 | Args:
275 | rows: String representation of circuit
276 |
277 | Returns:
278 | Rectangular string representation of circuit with right padding.
279 |
280 | """
281 |
282 | max_len = max([len(r) for r in rows])
283 |
284 | def extend_row(row: str, qubit_row: bool) -> str:
285 | lr = len(row)
286 | if lr < max_len:
287 | if qubit_row:
288 | row += "-" * (max_len - lr)
289 | else:
290 | row += " " * (max_len - lr)
291 | return row
292 |
293 | out_rows = [extend_row(r, i % 2 == 0) for i, r in enumerate(rows)]
294 | return out_rows, [True] * len(rows)
295 |
296 |
def print_circuit(
    gate_seq: Sequence[KrausOp],
    qubit_inds_seq: Sequence[Sequence[int]],
    param_inds_seq: Sequence[
        Union[GateParameterIndices, Sequence[GateParameterIndices]]
    ],
    n_qubits: Optional[int] = None,
    qubit_min: int = 0,
    qubit_max: Optional[int] = None,
    gate_ind_min: int = 0,
    gate_ind_max: Optional[int] = None,
    sep_length: int = 1,
) -> List[str]:
    """
    Returns and prints basic string representation of circuit.

    The drawing is a list of text rows: rows at even positions are qubit wires,
    rows at odd positions are the spacer lines between neighbouring wires.

    Args:
        gate_seq: Sequence of gates.
            Each element is either a string matching an array or function in qujax.gates,
            a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) )
            or a function taking parameters and returning gate unitary in tensor form.
            Or alternatively a sequence of the above representing Kraus operators.
        qubit_inds_seq: Sequences of qubits (ints) that gates are acting on.
        param_inds_seq: Sequence of parameter indices that gates are using,
            i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the first parameter,
            the second gate is not parameterised and the third gates used the fifth and
            second parameters.
        n_qubits: Number of qubits, if fixed.
            Inferred from the largest qubit index when omitted.
        qubit_min: Index of first qubit to display.
        qubit_max: Index of final qubit to display.
            Clipped to the last qubit of the circuit.
        gate_ind_min: Index of gate to start circuit printing.
        gate_ind_max: Index of gate to stop circuit printing.
            Clipped to the last gate of the circuit.
        sep_length: Number of dashes to separate gates.

    Returns:
        String representation of circuit

    Raises:
        TypeError: if gate_ind_min > gate_ind_max or qubit_min > qubit_max
            (after clipping), or if the circuit arguments fail check_circuit.

    """
    # Structural validation only; unitarity of the gates is not checked for printing.
    check_circuit(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits, False)

    if gate_ind_max is None:
        gate_ind_max = len(gate_seq) - 1
    else:
        gate_ind_max = min(len(gate_seq) - 1, gate_ind_max)

    if gate_ind_min > gate_ind_max:
        raise TypeError("gate_ind_max must be larger or equal to gate_ind_min")

    if n_qubits is None:
        n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1

    if qubit_max is None:
        qubit_max = n_qubits - 1
    else:
        qubit_max = min(n_qubits - 1, qubit_max)

    if qubit_min > qubit_max:
        raise TypeError("qubit_max must be larger or equal to qubit_min")

    # Pre-compute the fixed-width label of every gate.
    gate_str_seq = [_get_gate_str(g, p) for g, p in zip(gate_seq, param_inds_seq)]

    n_qubits_disp = qubit_max - qubit_min + 1

    # Start with one labelled wire per displayed qubit, separated by spacer rows.
    rows = [f"q{qubit_min}: ".ljust(3) + "-" * sep_length]
    if n_qubits_disp > 1:
        for i in range(qubit_min + 1, qubit_max + 1):
            rows += [" ", f"q{i}: ".ljust(3) + "-" * sep_length]
    rows, rows_free = _pad_rows(rows)

    for gate_ind in range(gate_ind_min, gate_ind_max + 1):
        g = gate_str_seq[gate_ind]
        qi = qubit_inds_seq[gate_ind]

        qi_min = min(qi)
        qi_max = max(qi)
        # NOTE(review): row indices are derived from absolute qubit indices, while `rows`
        # starts at qubit_min - this looks like it assumes qubit_min == 0; confirm.
        ri_min = 2 * qi_min  # index of top row used by gate
        ri_max = 2 * qi_max  # index of bottom row used by gate

        # If a row spanned by this gate was already written to in the current column,
        # pad all rows to equal length, which starts a new column.
        # NOTE(review): the range excludes ri_max, so the bottom row (and the only row of
        # a single-qubit gate) is never inspected here - confirm this is intended.
        if not all([rows_free[i] for i in range(ri_min, ri_max)]):
            rows, rows_free = _pad_rows(rows)

        for row_ind in range(ri_min, ri_max + 1):
            if row_ind == 2 * qi[-1]:
                # The gate label is drawn on the wire of the last listed qubit.
                rows[row_ind] += "-" * sep_length + g
            elif row_ind % 2 == 1:
                # Spacer row between wires: vertical connector.
                rows[row_ind] += " " * sep_length + " " + "|" + " "
            elif row_ind / 2 in qi:
                # Wire of another qubit the gate acts on.
                rows[row_ind] += "-" * sep_length + "---" + "◯" + "---"
            else:
                # Wire the connector merely passes through.
                rows[row_ind] += "-" * sep_length + "---" + "|" + "---"

            rows_free[row_ind] = False

    # Final padding so that the printed circuit is rectangular.
    rows, _ = _pad_rows(rows)

    for p in rows:
        print(p)

    return rows
396 |
397 |
def integers_to_bitstrings(
    integers: Union[int, jax.Array], nbits: Optional[int] = None
) -> jax.Array:
    """
    Expands an integer, or each integer of an array, into its binary digits
    (most significant bit first).

    Args:
        integers: Integer or array of integers to be converted.
        nbits: Length of output binary expansion.
            Defaults to smallest possible.

    Returns:
        Array of binary expansion(s).
    """
    integers = jnp.atleast_1d(integers)
    # Guarantee `integers` has the right type for type-checking purposes
    if not isinstance(integers, jax.Array):
        raise ValueError("This should not happen. Please open an issue on GitHub.")

    if nbits is None:
        # Clamp the largest value to at least 1 before taking the logarithm.
        largest = jnp.maximum(integers.max(), 1)
        nbits = int(jnp.ceil(jnp.log2(largest) + 1e-5).item())

    # One mask per bit position, highest bit first.
    bit_masks = 1 << jnp.arange(nbits - 1, -1, -1)
    bits_set = (integers[:, None] & bit_masks) > 0
    return jnp.squeeze(bits_set.astype(int))
423 |
424 |
def bitstrings_to_integers(bitstrings: ArrayLike) -> jax.Array:
    """
    Converts binary expansion(s) (most significant bit first) into the
    corresponding integer(s).

    Args:
        bitstrings: Bitstring array or array of bitstring arrays.

    Returns:
        Array of integers.
    """
    bitstrings = jnp.atleast_2d(bitstrings)

    # Guarantee `bitstrings` has the right type for type-checking purposes
    if not isinstance(bitstrings, jax.Array):
        raise ValueError("This should not happen. Please open an issue on GitHub.")

    # Powers of two matching each bit position along the last axis.
    n_bits = bitstrings.shape[-1]
    place_values = 2 ** jnp.arange(n_bits - 1, -1, -1)
    totals = bitstrings.dot(place_values)
    return jnp.squeeze(totals).astype(int)
443 |
444 |
def sample_integers(
    random_key: random.PRNGKeyArray,
    statetensor: jax.Array,
    n_samps: int = 1,
) -> jax.Array:
    """
    Draws random integers, each indexing an entry of the flattened statetensor,
    with probability given by the squared magnitude of that entry.

    Args:
        random_key: JAX random key to seed samples.
        statetensor: Statetensor encoding sampling probabilities (in the form of amplitudes).
        n_samps: Number of samples to generate. Defaults to 1.

    Returns:
        Array with sampled integers, shape=(n_samps,).

    """
    probabilities = jnp.square(jnp.abs(statetensor.flatten()))
    outcomes = jnp.arange(statetensor.size)
    return random.choice(random_key, a=outcomes, shape=(n_samps,), p=probabilities)
467 |
468 |
def sample_bitstrings(
    random_key: random.PRNGKeyArray,
    statetensor: jax.Array,
    n_samps: int = 1,
) -> jax.Array:
    """
    Draws random bitstrings according to a statetensor: integers are sampled first
    and then expanded into one bit per statetensor axis.

    Args:
        random_key: JAX random key to seed samples.
        statetensor: Statetensor encoding sampling probabilities (in the form of amplitudes).
        n_samps: Number of samples to generate. Defaults to 1.

    Returns:
        Array with sampled bitstrings, shape=(n_samps, statetensor.ndim).
    """
    bits_per_sample = statetensor.ndim
    sampled_integers = sample_integers(random_key, statetensor, n_samps)
    return integers_to_bitstrings(sampled_integers, bits_per_sample)
488 |
489 |
def statetensor_to_densitytensor(statetensor: jax.Array) -> jax.Array:
    """
    Computes a densitytensor representation of a pure quantum state
    from its statetensor representation

    Args:
        statetensor: Input statetensor.

    Returns:
        A densitytensor representing the quantum state,
        with twice as many axes (each of size 2) as the statetensor.
    """
    n_qubits = statetensor.ndim
    # Outer product |psi><psi| of the flattened state with its conjugate.
    ket = statetensor.reshape(-1, 1)
    bra = statetensor.reshape(1, -1).conj()
    # Pass the target shape as an explicit tuple rather than a generator expression.
    return (ket @ bra).reshape((2,) * 2 * n_qubits)
507 |
508 |
def repeat_circuit(
    circuit: Union[PureParameterizedCircuit, MixedParameterizedCircuit],
    nr_of_parameters: int,
) -> Callable[[jax.Array, jax.Array], jax.Array]:
    """
    Wraps `circuit` so that it is applied repeatedly, once per block of
    `nr_of_parameters` parameters, using `jax.lax.scan`.
    Avoids compilation overhead with increasing circuit depth.

    Args:
        circuit: The function encoding the circuit.
        nr_of_parameters: The number of parameters that `circuit` takes.

    Returns:
        A function taking an arbitrary number of parameters and returning as
        many applications of `circuit` as the number of parameters allows.
        An exception is thrown if this function is supplied with a parameter array
        of size not divisible by `nr_of_parameters`.
    """

    def repeated_circuit(params: jax.Array, statetensor_in: jax.Array) -> jax.Array:
        def apply_once(current_state, param_block):
            # scan expects (new carry, per-step output); no per-step output is kept.
            return circuit(param_block, current_state), None

        # One row of parameters per application of the circuit.
        param_blocks = params.reshape(-1, nr_of_parameters)
        final_state = jax.lax.scan(apply_once, statetensor_in, param_blocks)[0]
        return final_state

    return repeated_circuit
537 |
--------------------------------------------------------------------------------
/qujax/version.py:
--------------------------------------------------------------------------------
# Package version string.
__version__ = "1.1.0"
2 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
# Execute qujax/version.py in this namespace so that `__version__` is defined below.
# The file is opened in a `with` block so the handle is closed deterministically.
with open("qujax/version.py", encoding="utf-8") as version_file:
    exec(version_file.read())

# Read the README with an explicit encoding instead of the platform default.
with open("README.md", encoding="utf-8") as readme_file:
    long_description = readme_file.read()

setup(
    name="qujax",
    author="Sam Duffield",
    author_email="sam.duffield@quantinuum.com",
    url="https://github.com/CQCL/qujax",
    description="Simulating quantum circuits with JAX",
    long_description=long_description,
    long_description_content_type="text/markdown",
    license="Apache 2",
    packages=find_packages(),
    python_requires=">=3.8",
    install_requires=["jax>=0.4.1", "jaxlib", "typing_extensions"],
    classifiers=[
        "Programming Language :: Python",
        "Intended Audience :: Developers",
        "Intended Audience :: Science/Research",
        "Topic :: Scientific/Engineering",
    ],
    include_package_data=True,
    platforms="any",
    version=__version__,
)
27 |
--------------------------------------------------------------------------------
/tests/test_densitytensor.py:
--------------------------------------------------------------------------------
1 | from itertools import combinations
2 |
3 | from jax import jit
4 | from jax import numpy as jnp
5 |
6 | import qujax
7 |
8 |
def test_kraus_single():
    """
    Applies a single one-qubit Kraus operator to a 3-qubit density tensor and
    compares `qujax._kraus_single` and `qujax.kraus` (eager and jitted)
    against an explicit matrix computation.
    """
    n_qubits = 3
    dim = 2**n_qubits
    density_matrix = jnp.arange(dim**2).reshape(dim, dim)
    density_tensor = density_matrix.reshape((2,) * 2 * n_qubits)
    kraus_operator = qujax.gates.Rx(0.2)

    qubit_inds = (1,)

    # Embed the operator in the full space: identity on the qubits before and
    # after the target. k qubits span a space of dimension 2**k (not 2 * k;
    # the two only coincide for k == 1).
    n_qubits_before = qubit_inds[0]
    n_qubits_after = n_qubits - qubit_inds[-1] - 1
    unitary_matrix = jnp.kron(jnp.eye(2**n_qubits_before), kraus_operator)
    unitary_matrix = jnp.kron(unitary_matrix, jnp.eye(2**n_qubits_after))
    check_kraus_dm = unitary_matrix @ density_matrix @ unitary_matrix.conj().T

    # qujax._kraus_single
    qujax_kraus_dt = qujax._kraus_single(density_tensor, kraus_operator, qubit_inds)
    qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim)

    assert jnp.allclose(qujax_kraus_dm, check_kraus_dm)

    # qubit_inds is a tuple of Python ints, so it is marked static for jit.
    qujax_kraus_dt_jit = jit(qujax._kraus_single, static_argnums=(2,))(
        density_tensor, kraus_operator, qubit_inds
    )
    qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim)
    assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm)

    # qujax.kraus (but for a single array)
    qujax_kraus_dt = qujax.kraus(density_tensor, kraus_operator, qubit_inds)
    qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim)
    assert jnp.allclose(qujax_kraus_dm, check_kraus_dm)

    qujax_kraus_dt_jit = jit(qujax.kraus, static_argnums=(2,))(
        density_tensor, kraus_operator, qubit_inds
    )
    qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim)
    assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm)
46 |
47 |
def test_kraus_single_2qubit():
    """
    Applies a single two-qubit Kraus operator (ZZPhase) to a 4-qubit density
    tensor and compares `qujax._kraus_single` and `qujax.kraus` (eager and
    jitted, tensor-shaped and matrix-shaped operator) against an explicit
    matrix computation.
    """
    n_qubits = 4
    dim = 2**n_qubits
    density_matrix = jnp.arange(dim**2).reshape(dim, dim)
    density_tensor = density_matrix.reshape((2,) * 2 * n_qubits)
    kraus_operator_tensor = qujax.gates.ZZPhase(0.1)
    kraus_operator = qujax.gates.ZZPhase(0.1).reshape(4, 4)

    qubit_inds = (1, 2)

    # Embed the operator in the full space: identity on the qubits before and
    # after the (adjacent) targets. k qubits span a space of dimension 2**k
    # (not 2 * k; the two only coincide for k == 1).
    n_qubits_before = qubit_inds[0]
    n_qubits_after = n_qubits - qubit_inds[-1] - 1
    unitary_matrix = jnp.kron(jnp.eye(2**n_qubits_before), kraus_operator)
    unitary_matrix = jnp.kron(unitary_matrix, jnp.eye(2**n_qubits_after))
    check_kraus_dm = unitary_matrix @ density_matrix @ unitary_matrix.conj().T

    # qujax._kraus_single
    qujax_kraus_dt = qujax._kraus_single(
        density_tensor, kraus_operator_tensor, qubit_inds
    )
    qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim)

    assert jnp.allclose(qujax_kraus_dm, check_kraus_dm)

    # qubit_inds is a tuple of Python ints, so it is marked static for jit.
    qujax_kraus_dt_jit = jit(qujax._kraus_single, static_argnums=(2,))(
        density_tensor, kraus_operator_tensor, qubit_inds
    )
    qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim)
    assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm)

    # qujax.kraus (but for a single array)
    qujax_kraus_dt = qujax.kraus(density_tensor, kraus_operator_tensor, qubit_inds)
    qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim)
    assert jnp.allclose(qujax_kraus_dm, check_kraus_dm)

    qujax_kraus_dt = qujax.kraus(
        density_tensor, kraus_operator, qubit_inds
    )  # check reshape kraus_operator correctly
    qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim)
    assert jnp.allclose(qujax_kraus_dm, check_kraus_dm)

    qujax_kraus_dt_jit = jit(qujax.kraus, static_argnums=(2,))(
        density_tensor, kraus_operator_tensor, qubit_inds
    )
    qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim)
    assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm)
94 |
95 |
def test_kraus_multiple():
    """
    Applies a list of one-qubit Kraus operators to a 3-qubit density tensor
    and compares `qujax.kraus` (eager and jitted) against the explicit sum
    of K @ rho @ K^dagger over the embedded operators.
    """
    n_qubits = 3
    dim = 2**n_qubits
    density_matrix = jnp.arange(dim**2).reshape(dim, dim)
    density_tensor = density_matrix.reshape((2,) * 2 * n_qubits)

    kraus_operators = [
        0.25 * qujax.gates.H,
        0.25 * qujax.gates.Rx(0.3),
        0.5 * qujax.gates.Ry(0.1),
    ]

    qubit_inds = (1,)

    # Embed every operator in the full space: identity on the qubits before
    # and after the target. k qubits span a space of dimension 2**k (not
    # 2 * k; the two only coincide for k == 1).
    n_qubits_before = qubit_inds[0]
    n_qubits_after = n_qubits - qubit_inds[0] - 1
    unitary_matrices = [
        jnp.kron(jnp.eye(2**n_qubits_before), ko) for ko in kraus_operators
    ]
    unitary_matrices = [
        jnp.kron(um, jnp.eye(2**n_qubits_after)) for um in unitary_matrices
    ]

    check_kraus_dm = jnp.zeros_like(density_matrix)
    for um in unitary_matrices:
        check_kraus_dm += um @ density_matrix @ um.conj().T

    qujax_kraus_dt = qujax.kraus(density_tensor, kraus_operators, qubit_inds)
    qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim)

    assert jnp.allclose(qujax_kraus_dm, check_kraus_dm)

    # qubit_inds is a tuple of Python ints, so it is marked static for jit.
    qujax_kraus_dt_jit = jit(qujax.kraus, static_argnums=(2,))(
        density_tensor, kraus_operators, qubit_inds
    )
    qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim)
    assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm)
132 |
133 |
def test_params_to_densitytensor_func():
    """
    For a circuit of Rx gates followed by a CZ chain, the density tensor from
    `get_params_to_densitytensor_func` must equal the density tensor built
    from the statetensor of the same circuit (eager and jitted).
    """
    n_qubits = 2
    n_pairs = n_qubits - 1

    # Rx on every qubit (one parameter each), then CZ on neighbouring pairs.
    gate_seq = ["Rx"] * n_qubits + ["CZ"] * n_pairs
    qubit_inds_seq = [(q,) for q in range(n_qubits)] + [
        (q, q + 1) for q in range(n_pairs)
    ]
    param_inds_seq = [(q,) for q in range(n_qubits)] + [()] * n_pairs

    circuit = (gate_seq, qubit_inds_seq, param_inds_seq, n_qubits)
    params_to_dt = qujax.get_params_to_densitytensor_func(*circuit)
    params_to_st = qujax.get_params_to_statetensor_func(*circuit)

    params = jnp.arange(n_qubits) / 10.0

    expected_dt = qujax.statetensor_to_densitytensor(params_to_st(params))

    assert jnp.allclose(params_to_dt(params), expected_dt)
    assert jnp.allclose(jit(params_to_dt)(params), expected_dt)
163 |
164 |
def test_params_to_densitytensor_func_with_bit_flip():
    """
    Appends a Kraus operation (weighted identity and X on qubit 0) to a
    unitary circuit and checks that `get_params_to_densitytensor_func`
    matches applying `qujax.kraus` by hand to the pre-channel density tensor.
    """
    n_qubits = 2
    n_pairs = n_qubits - 1

    # Unitary part: Rx on every qubit, then CZ on neighbouring pairs.
    gate_seq = ["Rx"] * n_qubits + ["CZ"] * n_pairs
    qubit_inds_seq = [(q,) for q in range(n_qubits)] + [
        (q, q + 1) for q in range(n_pairs)
    ]
    param_inds_seq = [(q,) for q in range(n_qubits)] + [()] * n_pairs

    # Statetensor function for the unitary part only (before the bit flip).
    params_to_pre_bf_st = qujax.get_params_to_statetensor_func(
        gate_seq, qubit_inds_seq, param_inds_seq, n_qubits
    )

    kraus_ops = [[0.3 * jnp.eye(2), 0.7 * qujax.gates.X]]
    kraus_qubit_inds = [(0,)]
    kraus_param_inds = [None]

    # Extend the same list objects in place with the Kraus operation.
    gate_seq.extend(kraus_ops)
    qubit_inds_seq.extend(kraus_qubit_inds)
    param_inds_seq.extend(kraus_param_inds)

    _ = qujax.print_circuit(gate_seq, qubit_inds_seq, param_inds_seq)

    params_to_dt = qujax.get_params_to_densitytensor_func(
        gate_seq, qubit_inds_seq, param_inds_seq, n_qubits
    )

    params = jnp.arange(n_qubits) / 10.0

    # Outer product |psi><psi| of the pre-channel state, as a density tensor.
    pre_bf_st = params_to_pre_bf_st(params)
    ket = pre_bf_st.reshape(-1, 1)
    bra = pre_bf_st.reshape(1, -1).conj()
    pre_bf_dt = (ket @ bra).reshape(tuple(2 for _ in range(2 * n_qubits)))
    expected_dt = qujax.kraus(pre_bf_dt, kraus_ops[0], kraus_qubit_inds[0])

    assert jnp.allclose(params_to_dt(params), expected_dt)
    assert jnp.allclose(jit(params_to_dt)(params), expected_dt)
208 |
209 |
def test_partial_trace_1():
    """
    Tracing any one qubit out of a three-qubit product of identical states
    must give the two-qubit product; tracing out any two must give the
    single-qubit state.
    """
    plus = 1 / jnp.sqrt(2) * jnp.array([1.0, 1.0])
    plus_plus = jnp.kron(plus, plus)
    plus_plus_plus = jnp.kron(plus, plus_plus)

    def as_densitytensor(state, n_qubits):
        # |psi><psi| reshaped to one axis of size 2 per row/column qubit index.
        return jnp.outer(state, state.conj()).reshape((2,) * (2 * n_qubits))

    dt1 = as_densitytensor(plus, 1)
    dt2 = as_densitytensor(plus_plus, 2)
    dt3 = as_densitytensor(plus_plus_plus, 3)

    for qubit in range(3):
        assert jnp.allclose(qujax.partial_trace(dt3, [qubit]), dt2)

    for qubit_pair in combinations(range(3), 2):
        assert jnp.allclose(qujax.partial_trace(dt3, qubit_pair), dt1)
224 |
225 |
def test_partial_trace_2():
    """
    Compares `qujax.partial_trace` over qubit 0 with `jnp.trace` taken over
    the matching pair of axes (0 and n_qubits) of the density tensor.
    """
    n_qubits = 3
    n_pairs = n_qubits - 1

    # Rx on every qubit (one parameter each), then CZ on neighbouring pairs.
    gate_seq = ["Rx"] * n_qubits + ["CZ"] * n_pairs
    qubit_inds_seq = [(q,) for q in range(n_qubits)] + [
        (q, q + 1) for q in range(n_pairs)
    ]
    param_inds_seq = [(q,) for q in range(n_qubits)] + [()] * n_pairs

    params_to_dt = qujax.get_params_to_densitytensor_func(
        gate_seq, qubit_inds_seq, param_inds_seq, n_qubits
    )

    params = jnp.arange(1, n_qubits + 1) / 10.0
    dt = params_to_dt(params)

    expected = jnp.trace(dt, axis1=0, axis2=n_qubits)
    actual = qujax.partial_trace(dt, [0])

    assert jnp.allclose(actual, expected)
248 |
249 |
def test_measure():
    """
    Checks measurement probabilities of qubit 0 against the marginalised
    diagonal of the density matrix, and the post-measurement density tensor
    (outcome 0) against an explicit projector computation.
    """
    n_qubits = 3
    n_pairs = n_qubits - 1
    dim = 2**n_qubits

    # Rx on every qubit (one parameter each), then CZ on neighbouring pairs.
    gate_seq = ["Rx"] * n_qubits + ["CZ"] * n_pairs
    qubit_inds_seq = [(q,) for q in range(n_qubits)] + [
        (q, q + 1) for q in range(n_pairs)
    ]
    param_inds_seq = [(q,) for q in range(n_qubits)] + [()] * n_pairs

    params_to_dt = qujax.get_params_to_densitytensor_func(
        gate_seq, qubit_inds_seq, param_inds_seq, n_qubits
    )

    params = jnp.arange(1, n_qubits + 1) / 10.0
    dt = params_to_dt(params)

    qubit_inds = [0]

    # Reference probabilities: diagonal of the density matrix, summed over
    # every qubit that is not measured.
    all_probs = jnp.diag(dt.reshape(dim, dim)).real
    unmeasured_axes = tuple(q for q in range(n_qubits) if q not in qubit_inds)
    all_probs_marginalise = all_probs.reshape((2,) * n_qubits).sum(
        axis=unmeasured_axes
    )

    probs = qujax.densitytensor_to_measurement_probabilities(dt, qubit_inds)

    assert jnp.isclose(probs.sum(), 1.0)
    assert jnp.isclose(all_probs.sum(), 1.0)
    assert jnp.allclose(probs, all_probs_marginalise)

    # Reference post-measurement state: |0><0| on the first qubit, identity on
    # the rest, applied to the density matrix and renormalised.
    dm = dt.reshape(dim, dim)
    projector = jnp.array([[1, 0], [0, 0]])
    for _ in range(n_qubits - 1):
        projector = jnp.kron(projector, jnp.eye(2))
    projector_dagger = projector.T.conj()
    measured_dm = projector @ dm @ projector_dagger
    measured_dm = measured_dm / jnp.trace(projector_dagger @ projector @ dm)
    measured_dt_true = measured_dm.reshape((2,) * 2 * n_qubits)

    # The measurement outcome is passed once as an integer and once as a
    # bit array.
    measured_dt = qujax.densitytensor_to_measured_densitytensor(dt, qubit_inds, 0)
    measured_dt_bits = qujax.densitytensor_to_measured_densitytensor(
        dt, qubit_inds, jnp.zeros(n_qubits)
    )
    assert jnp.allclose(measured_dt_true, measured_dt)
    assert jnp.allclose(measured_dt_true, measured_dt_bits)
296 |
--------------------------------------------------------------------------------
/tests/test_expectations.py:
--------------------------------------------------------------------------------
1 | from jax import config, grad, jit
2 | from jax import numpy as jnp
3 | from jax import random
4 |
5 | import qujax
6 |
7 |
def test_pauli_hermitian():
    """Runs `qujax.check_hermitian` on X, Y, Z, both by name and as arrays."""
    gate_namespace = qujax.gates.__dict__
    for pauli_name in ("X", "Y", "Z"):
        pauli_array = gate_namespace[pauli_name]
        qujax.check_hermitian(pauli_name)
        qujax.check_hermitian(pauli_array)
12 |
13 |
def test_single_expectation():
    """
    Checks the ZZ expectation on qubits 0 and 1 for the basis states |000>
    (+1) and |100> (-1), for both statetensors and densitytensors.
    """
    Z = qujax.gates.Z
    ZZ = jnp.kron(Z, Z).reshape(2, 2, 2, 2)
    measured_qubits = [0, 1]

    # Basis states |000> and |100> as statetensors.
    st1 = jnp.zeros((2, 2, 2)).at[(0, 0, 0)].set(1.0)
    st2 = jnp.zeros((2, 2, 2)).at[(1, 0, 0)].set(1.0)
    dt1 = qujax.statetensor_to_densitytensor(st1)
    dt2 = qujax.statetensor_to_densitytensor(st2)

    est1 = qujax.statetensor_to_single_expectation(st1, ZZ, measured_qubits)
    est2 = qujax.statetensor_to_single_expectation(st2, ZZ, measured_qubits)
    edt1 = qujax.densitytensor_to_single_expectation(dt1, ZZ, measured_qubits)
    edt2 = qujax.densitytensor_to_single_expectation(dt2, ZZ, measured_qubits)

    assert est1.item() == edt1.item() == 1
    assert est2.item() == edt2.item() == -1
32 |
33 |
def test_bitstring_expectation():
    """
    Compares a cost expectation computed from measurement probabilities with
    a brute-force <psi|diag(costs)|psi> computation, for values and gradients,
    eager and jitted.
    """
    n_qubits = 4
    n_params = n_qubits * 4

    single_qubit_layer = [[q] for q in range(n_qubits)]
    entangling_layer = [[q, q + 1] for q in range(n_qubits - 1)]

    # Layers: H, Ry, Rz, CX chain, Ry, Rz.
    gates = (
        ["H"] * n_qubits
        + ["Ry"] * n_qubits
        + ["Rz"] * n_qubits
        + ["CX"] * (n_qubits - 1)
        + ["Ry"] * n_qubits
        + ["Rz"] * n_qubits
    )
    qubits = single_qubit_layer * 3 + entangling_layer + single_qubit_layer * 2
    param_inds = (
        [[]] * n_qubits
        + [[k] for k in range(n_qubits * 2)]
        + [[]] * (n_qubits - 1)
        + [[k] for k in range(n_qubits * 2, n_qubits * 4)]
    )

    param_to_st = qujax.get_params_to_statetensor_func(gates, qubits, param_inds)

    params = random.uniform(random.PRNGKey(0), shape=(n_params,))
    costs = random.normal(random.PRNGKey(1), shape=(2**n_qubits,))

    def st_to_expectation(statetensor):
        # Expected cost under the measurement distribution |amplitude|^2.
        probs = jnp.square(jnp.abs(statetensor.flatten()))
        return jnp.sum(costs * probs)

    def param_to_expectation(p):
        return st_to_expectation(param_to_st(p))

    def brute_force_param_to_exp(p):
        sv = param_to_st(p).flatten()
        return jnp.dot(sv, jnp.diag(costs) @ sv.conj()).real

    # Values
    true_expectation = brute_force_param_to_exp(params)
    expectation = param_to_expectation(params)
    expectation_jit = jit(param_to_expectation)(params)

    assert expectation.shape == ()
    assert expectation.dtype.name[:5] == "float"
    assert jnp.isclose(true_expectation, expectation)
    assert jnp.isclose(true_expectation, expectation_jit)

    # Gradients
    true_expectation_grad = grad(brute_force_param_to_exp)(params)
    expectation_grad = grad(param_to_expectation)(params)
    expectation_grad_jit = jit(grad(param_to_expectation))(params)

    assert expectation_grad.shape == (n_params,)
    assert expectation_grad.dtype.name[:5] == "float"
    assert jnp.allclose(true_expectation_grad, expectation_grad, atol=1e-5)
    assert jnp.allclose(true_expectation_grad, expectation_grad_jit, atol=1e-5)
92 |
93 |
def _test_hermitian_observable(
    hermitian_str_seq_seq, qubit_inds_seq, coefs, st_in=None
):
    """
    Shared checker for the observable tests.

    Builds the full observable matrix as a weighted sum of Kronecker products,
    computes the exact expectation from it, and compares against the qujax
    exact and sampled expectation functions for statetensors and
    densitytensors, eager and jitted.

    Args:
        hermitian_str_seq_seq: Gate names for each term of the observable.
        qubit_inds_seq: Qubit indices each term acts on.
        coefs: Coefficient of each term.
        st_in: Statetensor to evaluate on. When None, a normalised
            pseudo-random complex state is generated from fixed PRNG keys.
    """
    n_qubits = max(max(qi) for qi in qubit_inds_seq) + 1
    dim = 2**n_qubits

    if st_in is None:
        real_part = random.uniform(random.PRNGKey(2), shape=(dim,)) * 2
        imag_part = 1.0j * random.uniform(random.PRNGKey(1), shape=(dim,)) * 2
        state = real_part + imag_part
        state /= jnp.linalg.norm(state)
        st_in = state.reshape((2,) * n_qubits)

    dt_in = qujax.statetensor_to_densitytensor(st_in)

    st_to_exp = qujax.get_statetensor_to_expectation_func(
        hermitian_str_seq_seq, qubit_inds_seq, coefs
    )
    dt_to_exp = qujax.get_densitytensor_to_expectation_func(
        hermitian_str_seq_seq, qubit_inds_seq, coefs
    )

    def big_hermitian_matrix(hermitian_str_seq, qubit_inds):
        # Full-space matrix of one term: the named gates are assigned, in
        # order, to the qubits listed in qubit_inds (visited in increasing
        # qubit order); every other qubit gets the identity.
        gate_arrays = iter([getattr(qujax.gates, s) for s in hermitian_str_seq])
        factors = [
            next(gate_arrays) if q in qubit_inds else jnp.eye(2)
            for q in range(n_qubits)
        ]

        big_h = factors[0]
        for factor in factors[1:]:
            big_h = jnp.kron(big_h, factor)
        return big_h

    sum_big_hs = jnp.zeros((dim, dim), dtype=complex)
    for term, (h_strs, q_inds) in enumerate(
        zip(hermitian_str_seq_seq, qubit_inds_seq)
    ):
        sum_big_hs += coefs[term] * big_hermitian_matrix(h_strs, q_inds)

    assert jnp.allclose(sum_big_hs, sum_big_hs.conj().T)

    sv = st_in.flatten()
    true_exp = jnp.dot(sv.conj(), sum_big_hs @ sv).real

    # Exact expectations
    qujax_exp = st_to_exp(st_in)
    qujax_dt_exp = dt_to_exp(dt_in)
    qujax_exp_jit = jit(st_to_exp)(st_in)
    qujax_dt_exp_jit = jit(dt_to_exp)(dt_in)

    assert jnp.array(qujax_exp).shape == ()
    assert jnp.array(qujax_exp).dtype.name[:5] == "float"
    for exact_estimate in (qujax_exp, qujax_dt_exp, qujax_exp_jit, qujax_dt_exp_jit):
        assert jnp.isclose(true_exp, exact_estimate)

    # Sampled expectations; the shot count is a static argument under jit.
    n_shots = 1000000
    st_to_samp_exp = qujax.get_statetensor_to_sampled_expectation_func(
        hermitian_str_seq_seq, qubit_inds_seq, coefs
    )
    dt_to_samp_exp = qujax.get_densitytensor_to_sampled_expectation_func(
        hermitian_str_seq_seq, qubit_inds_seq, coefs
    )
    qujax_samp_exp = st_to_samp_exp(st_in, random.PRNGKey(1), n_shots)
    qujax_samp_exp_jit = jit(st_to_samp_exp, static_argnums=2)(
        st_in, random.PRNGKey(1), n_shots
    )
    qujax_samp_exp_dt = dt_to_samp_exp(dt_in, random.PRNGKey(1), n_shots)
    qujax_samp_exp_dt_jit = jit(dt_to_samp_exp, static_argnums=2)(
        dt_in, random.PRNGKey(1), n_shots
    )

    assert jnp.array(qujax_samp_exp).shape == ()
    assert jnp.array(qujax_samp_exp).dtype.name[:5] == "float"
    for sampled_estimate in (
        qujax_samp_exp,
        qujax_samp_exp_jit,
        qujax_samp_exp_dt,
        qujax_samp_exp_dt_jit,
    ):
        assert jnp.isclose(true_exp, sampled_estimate, rtol=1e-2)
175 |
176 |
def test_X():
    """Checks the X observable on qubit 0 for the state given by H then Rz."""
    # Input state: H followed by a parameterised Rz on qubit 0.
    gates = ["H", "Rz"]
    qubit = [[0], [0]]
    param_ind = [[], [0]]
    params_to_st = qujax.get_params_to_statetensor_func(gates, qubit, param_ind)
    st_in = params_to_st(0.3)

    hermitian_str_seq_seq = ["X"]
    qubit_inds_seq = [[0]]
    coefs = [1]

    _test_hermitian_observable(hermitian_str_seq_seq, qubit_inds_seq, coefs, st_in)
188 |
189 |
def test_Y():
    """Checks a single-qubit Y observable with unit coefficient."""
    n_qubits = 1

    hermitian_str_seq_seq = ["Y" for _ in range(n_qubits)]
    qubit_inds_seq = [[qubit] for qubit in range(n_qubits)]
    coefs = jnp.ones(len(hermitian_str_seq_seq))

    _test_hermitian_observable(hermitian_str_seq_seq, qubit_inds_seq, coefs)
198 |
199 |
def test_Z():
    """Checks a single-qubit Z observable with a normally distributed coefficient."""
    n_qubits = 1

    hermitian_str_seq_seq = ["Z" for _ in range(n_qubits)]
    qubit_inds_seq = [[qubit] for qubit in range(n_qubits)]
    n_terms = len(hermitian_str_seq_seq)
    coefs = random.normal(random.PRNGKey(0), shape=(n_terms,))

    _test_hermitian_observable(hermitian_str_seq_seq, qubit_inds_seq, coefs)
208 |
209 |
def test_XYZ():
    """Checks a weighted sum of X, Y and Z terms with random coefficients."""
    n_qubits = 1
    paulis = ["X", "Y", "Z"]

    hermitian_str_seq_seq = paulis * n_qubits
    qubit_inds_seq = [
        [qubit] for _ in range(len(paulis)) for qubit in range(n_qubits)
    ]
    n_terms = len(hermitian_str_seq_seq)
    coefs = random.normal(random.PRNGKey(0), shape=(n_terms,))

    _test_hermitian_observable(hermitian_str_seq_seq, qubit_inds_seq, coefs)
218 |
219 |
def test_ZZ_Y():
    """
    Checks an observable made of ZZ terms on neighbouring qubit pairs plus a
    Y term on every qubit, with random coefficients, on 4 qubits.

    The test runs with 64 bit precision. `jax_enable_x64` is a process-wide
    flag, so the previous value is restored afterwards; otherwise every test
    executed after this one would silently run in 64 bit as well.
    """
    x64_was_enabled = config.jax_enable_x64
    config.update("jax_enable_x64", True)  # Run this test with 64 bit precision
    try:
        n_qubits = 4

        hermitian_str_seq_seq = [["Z", "Z"]] * (n_qubits - 1) + [["Y"]] * n_qubits
        qubit_inds_seq = [[i, i + 1] for i in range(n_qubits - 1)] + [
            [i] for i in range(n_qubits)
        ]
        coefs = random.normal(
            random.PRNGKey(1), shape=(len(hermitian_str_seq_seq),)
        )

        _test_hermitian_observable(hermitian_str_seq_seq, qubit_inds_seq, coefs)
    finally:
        # Restore the global precision setting even if the assertions fail.
        config.update("jax_enable_x64", x64_was_enabled)
232 |
233 |
def test_sampling():
    """
    Samples integers and bitstrings from a state whose probability mass sits
    on a known support, and checks shapes, support membership, and that the
    two sample formats agree for the same PRNG key.
    """
    target_pmf = jnp.array([0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0])
    target_pmf = target_pmf / target_pmf.sum()

    n_qubits = int(jnp.rint(jnp.log2(target_pmf.size)))
    # Amplitudes are the square roots of the target probabilities.
    target_st = jnp.sqrt(target_pmf).reshape((2,) * n_qubits)

    n_samps = 7
    key = random.PRNGKey(0)

    sample_ints = qujax.sample_integers(key, target_st, n_samps)
    assert sample_ints.shape == (n_samps,)
    # Every sampled integer must have non-zero target probability.
    assert (target_pmf[sample_ints] > 0).all()

    sample_bitstrings = qujax.sample_bitstrings(key, target_st, n_samps)
    assert sample_bitstrings.shape == (n_samps, n_qubits)
    assert (qujax.bitstrings_to_integers(sample_bitstrings) == sample_ints).all()
254 |
--------------------------------------------------------------------------------
/tests/test_gates.py:
--------------------------------------------------------------------------------
1 | from qujax import check_unitary, gates
2 |
3 |
def test_gates():
    """Runs `check_unitary` on every gate in `qujax.gates`, by name and by value."""
    non_gate_names = ("jax", "jnp")
    for gate_name, gate in gates.__dict__.items():
        # Skip entries of the qujax.gates namespace that are not gates:
        # underscore-prefixed names and the imported jax modules.
        if gate_name[0] == "_" or gate_name in non_gate_names:
            continue
        check_unitary(gate_name)
        check_unitary(gate)
10 |
--------------------------------------------------------------------------------
/tests/test_statetensor.py:
--------------------------------------------------------------------------------
1 | from jax import jit
2 | from jax import numpy as jnp
3 |
4 | import qujax
5 |
6 |
def test_H():
    """
    A single Hadamard on one qubit: checks the statetensor (eager and jitted)
    and the unitary tensor applied to the all-zeros state.
    """
    gates = ["H"]
    qubits = [[0]]
    param_inds = [[]]

    true_sv = jnp.array([0.70710678 + 0.0j, 0.70710678 + 0.0j])

    # Statetensor
    param_to_st = qujax.get_params_to_statetensor_func(gates, qubits, param_inds)
    st = param_to_st()
    st_jit = jit(param_to_st)()

    assert st.size == true_sv.size
    assert jnp.allclose(st.flatten(), true_sv)
    assert jnp.allclose(st_jit.flatten(), true_sv)

    # Unitary tensor, reshaped to a matrix and applied to the all-zeros state
    param_to_unitary = qujax.get_params_to_unitarytensor_func(gates, qubits, param_inds)
    unitary = param_to_unitary().reshape(2, 2)
    unitary_jit = jit(param_to_unitary)().reshape(2, 2)
    zero_sv = jnp.zeros(2).at[0].set(1)
    for u in (unitary, unitary_jit):
        assert jnp.allclose(u @ zero_sv, true_sv)
28 |
29 |
def test_H_redundant_qubits():
    """
    A Hadamard on qubit 0 of a 3-qubit register in which the other qubits are
    untouched: checks the statetensor and the unitary tensor applied to the
    all-zeros state.
    """
    gates = ["H"]
    qubits = [[0]]
    param_inds = [[]]
    n_qubits = 3
    dim = 2**n_qubits

    true_sv = jnp.array([0.70710678, 0.0, 0.0, 0.0, 0.70710678, 0.0, 0.0, 0.0])

    # Statetensor, passing statetensor_in=None explicitly
    param_to_st = qujax.get_params_to_statetensor_func(
        gates, qubits, param_inds, n_qubits
    )
    st = param_to_st(statetensor_in=None)

    assert st.size == true_sv.size
    assert jnp.allclose(st.flatten(), true_sv)

    # Unitary tensor, reshaped to a matrix and applied to the all-zeros state
    param_to_unitary = qujax.get_params_to_unitarytensor_func(
        gates, qubits, param_inds, n_qubits
    )
    unitary = param_to_unitary().reshape(dim, dim)
    unitary_jit = jit(param_to_unitary)().reshape(dim, dim)
    zero_sv = jnp.zeros(dim).at[0].set(1)
    for u in (unitary, unitary_jit):
        assert jnp.allclose(u @ zero_sv, true_sv)
54 |
55 |
def test_CX_Rz_CY():
    """
    Three Hadamards followed by CX, a parameterised Rz and CY: checks the
    statetensor and the unitary tensor against a hard-coded reference vector.
    """
    gates = ["H", "H", "H", "CX", "Rz", "CY"]
    qubits = [[0], [1], [2], [0, 1], [1], [1, 2]]
    # Parameter-free gates are given both as [] and as None.
    param_inds = [[], [], [], None, [0], []]
    n_qubits = 3
    dim = 2**n_qubits

    param = jnp.array(0.1)

    # Reference amplitudes: only two distinct complex values appear, up to sign.
    a = 0.34920055 - 0.05530793j
    b = 0.05530793 - 0.34920055j
    true_sv = jnp.array([a, a, b, -b, a, a, b, -b], dtype="complex64")

    # Statetensor (the number of qubits is not passed here)
    param_to_st = qujax.get_params_to_statetensor_func(gates, qubits, param_inds)
    st = param_to_st(param)

    assert st.size == true_sv.size
    assert jnp.allclose(st.flatten(), true_sv)

    # Unitary tensor, reshaped to a matrix and applied to the all-zeros state
    param_to_unitary = qujax.get_params_to_unitarytensor_func(
        gates, qubits, param_inds, n_qubits
    )
    unitary = param_to_unitary(param).reshape(dim, dim)
    unitary_jit = jit(param_to_unitary)(param).reshape(dim, dim)
    zero_sv = jnp.zeros(dim).at[0].set(1)
    for u in (unitary, unitary_jit):
        assert jnp.allclose(u @ zero_sv, true_sv)
91 |
92 |
def test_stacked_circuits():
    """
    Feeds the output of a Hadamard circuit back in as its input statetensor,
    positionally and by keyword; two Hadamards must return the all-zeros
    state.
    """
    gates = ["H"]
    qubits = [[0]]
    param_inds = [[]]

    param_to_st = qujax.get_params_to_statetensor_func(gates, qubits, param_inds)

    after_one_h = param_to_st()
    after_two_h_positional = param_to_st(after_one_h)
    after_two_h_keyword = param_to_st(statetensor_in=after_one_h)

    # Statevector with a 1 in position 0 and 0 elsewhere.
    all_zeros_sv = jnp.array(
        jnp.arange(after_two_h_positional.size) == 0, dtype=int
    )

    for result in (after_two_h_positional, after_two_h_keyword):
        assert jnp.allclose(result.flatten(), all_zeros_sv, atol=1e-7)
109 |
--------------------------------------------------------------------------------
/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | import jax
2 | import jax.numpy as jnp
3 |
4 | import qujax
5 |
6 |
def test_repeat_circuit():
    """
    A depth-4 layered circuit built explicitly must give the same statetensor
    and densitytensor as a single layer repeated with `qujax.repeat_circuit`
    on the same parameters.
    """
    n_qubits = 4
    depth = 4
    seed = 0
    statetensor_in = qujax.all_zeros_statetensor(n_qubits)
    densitytensor_in = qujax.all_zeros_densitytensor(n_qubits)

    def circuit(n_qubits: int, depth: int):
        # Builds `depth` layers, each an Rx on every qubit followed by a CRz
        # on every neighbouring pair; every gate gets its own parameter index.
        # Returns the three circuit lists and the total number of parameters.
        gates = []
        qubit_inds = []
        param_inds = []

        def add_gate(name, targets):
            param_inds.append([len(param_inds)])
            gates.append(name)
            qubit_inds.append(targets)

        for _ in range(depth):
            # Rx layer
            for q in range(n_qubits):
                add_gate("Rx", [q])

            # CRz layer
            for q in range(n_qubits - 1):
                add_gate("CRz", [q, q + 1])

        return gates, qubit_inds, param_inds, len(param_inds)

    rng = jax.random.PRNGKey(seed)

    # Full-depth circuit and its parameters
    g1, qi1, pi1, np1 = circuit(n_qubits, depth)
    params = jax.random.uniform(rng, (np1,))

    # Single layer, to be repeated
    g2, qi2, pi2, np2 = circuit(n_qubits, 1)

    # Statetensor comparison
    param_to_st = qujax.get_params_to_statetensor_func(g1, qi1, pi1, n_qubits)
    param_to_st_single_repetition = qujax.get_params_to_statetensor_func(
        g2, qi2, pi2, n_qubits
    )
    param_to_st_repeated = qujax.repeat_circuit(param_to_st_single_repetition, np2)

    st_full = param_to_st(params, statetensor_in)
    st_repeated = param_to_st_repeated(params, statetensor_in)
    assert jnp.allclose(st_full, st_repeated)

    # Densitytensor comparison
    param_to_dt = qujax.get_params_to_densitytensor_func(g1, qi1, pi1, n_qubits)
    param_to_dt_single_repetition = qujax.get_params_to_densitytensor_func(
        g2, qi2, pi2, n_qubits
    )
    param_to_dt_repeated = qujax.repeat_circuit(param_to_dt_single_repetition, np2)

    dt_full = param_to_dt(params, densitytensor_in)
    dt_repeated = param_to_dt_repeated(params, densitytensor_in)
    assert jnp.allclose(dt_full, dt_repeated)
69 |
--------------------------------------------------------------------------------