├── .github ├── dependabot.yml └── workflows │ ├── build_and_test.yml │ ├── docs.yml │ ├── docs │ └── requirements.txt │ └── lint.yml ├── .gitignore ├── .pylintrc ├── CITATION.cff ├── LICENSE ├── README.md ├── docs ├── _static │ └── css │ │ └── custom.css ├── conf.py ├── densitytensor.rst ├── densitytensor │ ├── all_zeros_densitytensor.rst │ ├── densitytensor_to_measured_densitytensor.rst │ ├── densitytensor_to_measurement_probabilities.rst │ ├── get_densitytensor_to_expectation_func.rst │ ├── get_densitytensor_to_sampled_expectation_func.rst │ ├── get_params_to_densitytensor_func.rst │ ├── kraus.rst │ ├── partial_trace.rst │ └── statetensor_to_densitytensor.rst ├── examples.md ├── gates.rst ├── getting_started.rst ├── index.rst ├── logo.svg ├── logo_dark_mode.svg ├── requirements.txt ├── statetensor.rst ├── statetensor │ ├── all_zeros_statetensor.rst │ ├── apply_gate.rst │ ├── get_params_to_statetensor_func.rst │ ├── get_params_to_unitarytensor_func.rst │ ├── get_statetensor_to_expectation_func.rst │ └── get_statetensor_to_sampled_expectation_func.rst ├── utils.rst └── utils │ ├── bitstrings_to_integers.rst │ ├── check_circuit.rst │ ├── integers_to_bitstrings.rst │ ├── print_circuit.rst │ ├── repeat_circuit.rst │ ├── sample_bitstrings.rst │ └── sample_integers.rst ├── examples ├── barren_plateaus.ipynb ├── classification.ipynb ├── generative_modelling.ipynb ├── heisenberg_vqe.ipynb ├── hidalgo_stamp.txt ├── maxcut_vqe.ipynb ├── noise_channel.ipynb ├── qaoa.ipynb ├── reducing_jit_compilation_time.ipynb └── variational_inference.ipynb ├── qujax ├── __init__.py ├── densitytensor.py ├── densitytensor_observable.py ├── gates.py ├── statetensor.py ├── statetensor_observable.py ├── typing.py ├── utils.py └── version.py ├── setup.py └── tests ├── test_densitytensor.py ├── test_expectations.py ├── test_gates.py ├── test_statetensor.py └── test_utils.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 
| updates: 3 | - package-ecosystem: "github-actions" 4 | directory: "/" 5 | schedule: 6 | interval: "weekly" -------------------------------------------------------------------------------- /.github/workflows/build_and_test.yml: -------------------------------------------------------------------------------- 1 | name: Build and test 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | - develop 8 | push: 9 | branches: 10 | - develop 11 | - main 12 | release: 13 | types: 14 | - created 15 | - edited 16 | 17 | jobs: 18 | linux-checks: 19 | name: Linux - Build and test module 20 | runs-on: ubuntu-20.04 21 | 22 | steps: 23 | - uses: actions/checkout@v4 24 | with: 25 | fetch-depth: '0' 26 | - name: Set up Python 3.9 27 | uses: actions/setup-python@v5 28 | with: 29 | python-version: '3.9' 30 | - name: Build and test 31 | run: | 32 | ARTIFACTSDIR=${GITHUB_WORKSPACE}/wheelhouse 33 | rm -rf ${ARTIFACTSDIR} && mkdir ${ARTIFACTSDIR} 34 | python -m pip install --upgrade pip wheel build pytest 35 | python -m build 36 | for w in dist/*.whl ; do 37 | python -m pip install $w 38 | cp $w ${ARTIFACTSDIR} 39 | done 40 | cd tests 41 | pytest 42 | - uses: actions/upload-artifact@v4 43 | if: github.event_name == 'release' 44 | with: 45 | name: artefacts 46 | path: wheelhouse/ 47 | 48 | 49 | 50 | publish_to_pypi: 51 | name: Publish to pypi 52 | if: github.event_name == 'release' 53 | needs: linux-checks 54 | runs-on: ubuntu-20.04 55 | steps: 56 | - name: Download all wheels 57 | uses: actions/download-artifact@v4 58 | with: 59 | path: wheelhouse 60 | - name: Put them all in the dist folder 61 | run: | 62 | mkdir dist 63 | for w in `find wheelhouse/ -type f -name "*.whl"` ; do cp $w dist/ ; done 64 | - name: Publish wheels 65 | uses: pypa/gh-action-pypi-publish@release/v1 66 | with: 67 | user: __token__ 68 | password: ${{ secrets.PYPI_QUJAX_API_TOKEN }} 69 | verbose: true 70 | 71 | -------------------------------------------------------------------------------- 
/.github/workflows/docs.yml: -------------------------------------------------------------------------------- 1 | name: Qujax Docs 2 | 3 | on: 4 | pull_request: 5 | branches: 6 | - main 7 | - develop 8 | push: 9 | branches: 10 | - main 11 | 12 | jobs: 13 | build_docs: 14 | name: Build and publish docs 15 | runs-on: ubuntu-20.04 16 | steps: 17 | - uses: actions/checkout@v4 18 | with: 19 | fetch-depth: '0' 20 | - name: Set up Python 3.9 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: '3.9' 24 | - name: Upgrade pip and install wheel 25 | run: pip install --upgrade pip wheel 26 | - name: Install qujax 27 | run: | 28 | pip install . 29 | - name: Install docs dependencies 30 | run: | 31 | pip install -r .github/workflows/docs/requirements.txt 32 | - name: Build docs 33 | timeout-minutes: 20 34 | run: | 35 | cd .github/workflows/docs 36 | mkdir qujax 37 | cd qujax 38 | sphinx-build ../../../../docs . -a 39 | - name: Upload docs as artefact 40 | uses: actions/upload-pages-artifact@v3 41 | with: 42 | path: .github/workflows/docs/qujax 43 | 44 | publish_docs: 45 | name: Publish docs 46 | if: github.event_name == 'push' && contains(github.ref_name, 'main') 47 | needs: build_docs 48 | runs-on: ubuntu-22.04 49 | permissions: 50 | pages: write 51 | id-token: write 52 | environment: 53 | name: github-pages 54 | url: ${{ steps.deployment.outputs.page_url }} 55 | steps: 56 | - name: Deploy to GitHub Pages 57 | id: deployment 58 | uses: actions/deploy-pages@v4 -------------------------------------------------------------------------------- /.github/workflows/docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx ~= 7.2 2 | sphinx_rtd_theme 3 | myst-parser ~= 2.0 -------------------------------------------------------------------------------- /.github/workflows/lint.yml: -------------------------------------------------------------------------------- 1 | name: Lint python projects 2 | 3 | on: 4 | 
pull_request: 5 | branches: 6 | - main 7 | - develop 8 | push: 9 | branches: 10 | - main 11 | - develop 12 | 13 | jobs: 14 | lint: 15 | 16 | runs-on: ubuntu-22.04 17 | 18 | steps: 19 | - uses: actions/checkout@v4 20 | - name: Set up Python 3.x 21 | uses: actions/setup-python@v5 22 | with: 23 | python-version: '3.x' 24 | - name: Update pip 25 | run: pip install --upgrade pip 26 | - name: Install black and pylint 27 | run: pip install black pylint 28 | - name: Check files are formatted with black 29 | run: | 30 | black --check . 31 | - name: Run pylint 32 | run: | 33 | pylint */ -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info/ 2 | .idea/ 3 | dist 4 | .DS_Store 5 | .vscode 6 | .pytest_cache 7 | __pycache__ 8 | *.ipynb_checkpoints/ 9 | build 10 | Makefile 11 | make.bat 12 | _build 13 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | output-format=colorized 3 | disable=all 4 | enable= 5 | anomalous-backslash-in-string, 6 | assert-on-tuple, 7 | bad-indentation, 8 | bad-option-value, 9 | bad-reversed-sequence, 10 | bad-super-call, 11 | consider-merging-isinstance, 12 | continue-in-finally, 13 | dangerous-default-value, 14 | duplicate-argument-name, 15 | expression-not-assigned, 16 | function-redefined, 17 | inconsistent-mro, 18 | init-is-generator, 19 | line-too-long, 20 | lost-exception, 21 | missing-kwoa, 22 | mixed-line-endings, 23 | not-callable, 24 | no-value-for-parameter, 25 | nonexistent-operator, 26 | not-in-loop, 27 | pointless-statement, 28 | redefined-builtin, 29 | return-arg-in-generator, 30 | return-in-init, 31 | return-outside-function, 32 | simplifiable-if-statement, 33 | syntax-error, 34 | too-many-function-args, 35 | trailing-whitespace, 36 | undefined-variable, 37 | 
unexpected-keyword-arg, 38 | unhashable-dict-key, 39 | unnecessary-pass, 40 | unreachable, 41 | unrecognized-inline-option, 42 | unused-import, 43 | unnecessary-semicolon, 44 | unused-variable, 45 | unused-wildcard-import, 46 | wildcard-import, 47 | wrong-import-order, 48 | wrong-import-position, 49 | yield-outside-function 50 | 51 | 52 | # Ignore long lines containing URLs or pylint disable directives. 53 | ignore-long-lines=^(.*#\w*pylint: disable.*|\s*(# )?<?https?://\S+>?)$ 54 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: "1.2.0" 2 | authors: 3 | - family-names: Duffield 4 | given-names: Samuel 5 | orcid: "https://orcid.org/0000-0002-8656-8734" 6 | - family-names: Matos 7 | given-names: Gabriel 8 | orcid: "https://orcid.org/0000-0002-3373-0128" 9 | - family-names: Johannsen 10 | given-names: Melf 11 | contact: 12 | - family-names: Duffield 13 | given-names: Samuel 14 | orcid: "https://orcid.org/0000-0002-8656-8734" 15 | doi: 10.5281/zenodo.8268973 16 | message: If you use this software, please cite our article in the 17 | Journal of Open Source Software.
18 | preferred-citation: 19 | authors: 20 | - family-names: Duffield 21 | given-names: Samuel 22 | orcid: "https://orcid.org/0000-0002-8656-8734" 23 | - family-names: Matos 24 | given-names: Gabriel 25 | orcid: "https://orcid.org/0000-0002-3373-0128" 26 | - family-names: Johannsen 27 | given-names: Melf 28 | date-published: 2023-09-12 29 | doi: 10.21105/joss.05504 30 | issn: 2475-9066 31 | issue: 89 32 | journal: Journal of Open Source Software 33 | publisher: 34 | name: Open Journals 35 | start: 5504 36 | title: "qujax: Simulating quantum circuits with JAX" 37 | type: article 38 | url: "https://joss.theoj.org/papers/10.21105/joss.05504" 39 | volume: 8 40 | title: "qujax: Simulating quantum circuits with JAX" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # qujax 2 | 3 |
4 | 5 | 6 | 7 | 8 | 9 | 10 |
11 | 12 | [![PyPI - Version](https://img.shields.io/pypi/v/qujax)](https://pypi.org/project/qujax/) 13 | [![DOI](https://joss.theoj.org/papers/10.21105/joss.05504/status.svg)](https://doi.org/10.21105/joss.05504) 14 | 15 | [**Documentation**](https://cqcl.github.io/qujax/) | [**Installation**](#installation) | [**Quick start**](#quick-start) | [**Examples**](https://cqcl.github.io/qujax/examples.html) | [**Contributing**](#contributing) | [**Citing qujax**](#citing-qujax) 16 | 17 | qujax is a [JAX](https://github.com/google/jax)-based Python library for the classical simulation of quantum circuits. It is designed to be *simple*, *fast* and *flexible*. 18 | 19 | It follows a functional programming design by translating circuits into pure functions. This allows qujax to [seamlessly interface with JAX](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions), enabling direct access to its powerful automatic differentiation tools, just-in-time compiler, vectorization capabilities, GPU/TPU integration and growing ecosystem of packages. 20 | 21 | qujax can be used both for pure and for mixed quantum state simulation. It not only supports the standard gate set, but also allows user-defined custom operations, including general quantum channels, enabling the user to e.g. model device noise and errors. 22 | 23 | A summary of the core functionalities of qujax can be found in the [Quick start](#quick-start) section. More advanced use-cases, including the training of parameterised quantum circuits, can be found in the [Examples](https://cqcl.github.io/qujax/examples.html) section of the documentation. 24 | 25 | 26 | ## Installation 27 | 28 | qujax is [hosted on PyPI](https://pypi.org/project/qujax/) and can be installed via the pip package manager 29 | ``` 30 | pip install qujax 31 | ``` 32 | 33 | ## Quick start 34 | 35 | **Important note: qujax circuit parameters are expressed in units of $\pi$ (e.g. 
in the range $[0,2]$ as opposed to $[0, 2\pi]$)**. 36 | 37 | Start by defining the quantum gates making up the circuit, the qubits that they act on, and the indices of the parameters for each gate. 38 | 39 | A list of all gates can be found [here](https://github.com/CQCL/qujax/blob/main/qujax/gates.py) (custom operations can be included by [passing an array or function](https://cqcl.github.io/qujax/statetensor/get_params_to_statetensor_func.html) instead of a string). 40 | 41 | ```python 42 | from jax import numpy as jnp 43 | import qujax 44 | 45 | # List of quantum gates 46 | circuit_gates = ['H', 'Ry', 'CZ'] 47 | # Indices of qubits the gates will be applied to 48 | circuit_qubit_inds = [[0], [0], [0, 1]] 49 | # Indices of parameters each parameterised gate will use 50 | circuit_params_inds = [[], [0], []] 51 | 52 | qujax.print_circuit(circuit_gates, circuit_qubit_inds, circuit_params_inds); 53 | # q0: -----H-----Ry[0]-----◯--- 54 | # | 55 | # q1: ---------------------CZ-- 56 | ``` 57 | 58 | Translate the circuit to a pure function `param_to_st` that takes a set of parameters and an (optional) initial quantum state as its input. 59 | 60 | ```python 61 | param_to_st = qujax.get_params_to_statetensor_func(circuit_gates, 62 | circuit_qubit_inds, 63 | circuit_params_inds) 64 | 65 | param_to_st(jnp.array([0.1])) 66 | # Array([[0.58778524+0.j, 0. +0.j], 67 | # [0.80901706+0.j, 0. +0.j]], dtype=complex64) 68 | ``` 69 | 70 | The optional initial state can be passed to `param_to_st` using the `statetensor_in` argument. When it is not provided, the initial state defaults to $\ket{0...0}$. 71 | 72 | Map the state to an expectation value by defining an observable using lists of Pauli matrices, the qubits they act on, and the associated coefficients. 
73 | 74 | ```python 75 | st_to_expectation = qujax.get_statetensor_to_expectation_func([['Z']], [[0]], [1.]) 76 | ``` 77 | 78 | Combining `param_to_st` and `st_to_expectation` gives us a parameter to expectation function that can be automatically differentiated using JAX. 79 | 80 | ```python 81 | from jax import value_and_grad 82 | 83 | param_to_expectation = lambda param: st_to_expectation(param_to_st(param)) 84 | expectation_and_grad = value_and_grad(param_to_expectation) 85 | expectation_and_grad(jnp.array([0.1])) 86 | # (Array(-0.3090171, dtype=float32), 87 | # Array([-2.987832], dtype=float32)) 88 | ``` 89 | 90 | Mixed state simulations are analogous to the above, but with calls to [`get_params_to_densitytensor_func`](https://cqcl.github.io/qujax/densitytensor/get_params_to_densitytensor_func.html) and [`get_densitytensor_to_expectation_func`](https://cqcl.github.io/qujax/densitytensor/get_densitytensor_to_expectation_func.html) instead. 91 | 92 | A more in-depth version of the above can be found in the [Getting started](https://cqcl.github.io/qujax/getting_started.html) section of the documentation. More advanced use-cases, including the training of parameterised quantum circuits, can be found in the [Examples](https://cqcl.github.io/qujax/examples.html) section of the documentation. 93 | 94 | ## Converting from TKET 95 | 96 | A [`pytket`](https://cqcl.github.io/tket/pytket/api/) circuit can be directly converted using the [`tk_to_qujax`](https://cqcl.github.io/pytket-qujax/api/api.html#pytket.extensions.qujax.qujax_convert.tk_to_qujax) and [`tk_to_qujax_symbolic`](https://cqcl.github.io/pytket-qujax/api/api.html#pytket.extensions.qujax.qujax_convert.tk_to_qujax_symbolic) functions in the [**`pytket-qujax`**](https://github.com/CQCL/pytket-qujax) extension. See [`pytket-qujax_heisenberg_vqe.ipynb`](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax_heisenberg_vqe.ipynb) for an example. 
97 | 98 | ## Contributing 99 | 100 | You can open a bug report or a feature request by creating a new [issue on GitHub](https://github.com/CQCL/qujax/issues). 101 | 102 | Pull requests are welcome! To open a new one, please go through the following steps: 103 | 104 | 1. First fork the repo and create your branch from [`develop`](https://github.com/CQCL/qujax/tree/develop). 105 | 2. Commit your code and tests. 106 | 3. Update the documentation, if required. 107 | 4. Check that the code passes linting (run `black . --check` and `pylint */`). 108 | 5. Issue a pull request into the [`develop`](https://github.com/CQCL/qujax/tree/develop) branch. 109 | 110 | New commits on [`develop`](https://github.com/CQCL/qujax/tree/develop) will be merged into 111 | [`main`](https://github.com/CQCL/qujax/tree/main) in the next release. 112 | 113 | 114 | ## Citing qujax 115 | 116 | If you have used qujax in your code or research, we kindly ask that you cite it. You can use the following BibTeX entry for this: 117 | 118 | ```bibtex 119 | @article{qujax2023, 120 | author = {Duffield, Samuel and Matos, Gabriel and Johannsen, Melf}, 121 | doi = {10.21105/joss.05504}, 122 | journal = {Journal of Open Source Software}, 123 | month = sep, 124 | number = {89}, 125 | pages = {5504}, 126 | title = {{qujax: Simulating quantum circuits with JAX}}, 127 | url = {https://joss.theoj.org/papers/10.21105/joss.05504}, 128 | volume = {8}, 129 | year = {2023} 130 | } 131 | ``` -------------------------------------------------------------------------------- /docs/_static/css/custom.css: -------------------------------------------------------------------------------- 1 | 2 | .wy-nav-top{ 3 | background-color: #203847 4 | } 5 | 6 | .wy-side-nav-search{ 7 | background-color: white; 8 | } 9 | 10 | .icon.icon-home{ 11 | color: #000000 12 | } 13 | 14 | .wy-menu-vertical p.caption{ 15 | color: #85cfcb; 16 | } 17 | 18 | .wy-side-nav-search > a{ 19 | color: #000000; 20 | } 21 | 22 | .wy-side-nav-search > div.version{ 23 | color: 
#000000; 24 | } 25 | 26 | .sig { 27 | background: #85cfcb; 28 | } 29 | 30 | @media screen and (min-width: 1000px) { 31 | .wy-nav-content { 32 | max-width: 1000px; 33 | } 34 | } 35 | 36 | html.writer-html4 .rst-content dl:not(.docutils) > dt, html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.citation):not(.glossary):not(.simple) > dt { 37 | display:block; 38 | background-color: #ebeef1; 39 | } 40 | 41 | #examples ul, ul.simple { 42 | list-style: none; 43 | } 44 | 45 | #examples ul li, ul.simple li { 46 | margin-bottom: 10px; 47 | } 48 | 49 | h1, h2, h3, h4, h5, h6 { 50 | color: #203847 51 | } 52 | 53 | div.toctree-wrapper .caption-text{ 54 | color: #203847; 55 | } 56 | 57 | .rst-content .viewcode-back, .rst-content .viewcode-link { 58 | padding-left: 6px; 59 | } -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import importlib 4 | import inspect 5 | import pathlib 6 | 7 | sys.path.insert(0, os.path.abspath("..")) # pylint: disable=wrong-import-position 8 | 9 | from qujax.version import __version__ 10 | 11 | # -- Project information ----------------------------------------------------- 12 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information 13 | 14 | project = "qujax" 15 | project_copyright = "2023, The qujax authors" 16 | author = "Sam Duffield, Gabriel Matos, Melf Johannsen" 17 | version = __version__ 18 | release = __version__ 19 | 20 | # -- General configuration --------------------------------------------------- 21 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration 22 | 23 | extensions = [ 24 | "sphinx.ext.autodoc", 25 | "sphinx_rtd_theme", 26 | "sphinx.ext.napoleon", 27 | "sphinx.ext.mathjax", 28 | "sphinx.ext.linkcode", 29 | "myst_parser", 30 | ] 31 | 32 | templates_path = ["_templates"] 
33 | exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] 34 | 35 | # -- Options for HTML output ------------------------------------------------- 36 | # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output 37 | 38 | html_theme = "sphinx_rtd_theme" 39 | 40 | autodoc_typehints = "description" 41 | 42 | autodoc_type_aliases = { 43 | "random.PRNGKeyArray": "jax.random.PRNGKeyArray", 44 | "UnionCallableOptionalArray": "Union[Callable[[ndarray, Optional[ndarray]], ndarray], " 45 | "Callable[[Optional[ndarray]], ndarray]]", 46 | } 47 | 48 | latex_engine = "pdflatex" 49 | 50 | titles_only = True 51 | 52 | rst_prolog = """ 53 | .. role:: python(code) 54 | :language: python 55 | """ 56 | 57 | html_logo = "logo.svg" 58 | 59 | html_static_path = ["_static"] 60 | html_css_files = [ 61 | "css/custom.css", 62 | ] 63 | 64 | html_theme_options = { 65 | "collapse_navigation": False, 66 | "prev_next_buttons_location": "None", 67 | } 68 | 69 | 70 | def linkcode_resolve(domain, info): 71 | """ 72 | Called by sphinx's linkcode extension, which adds links directing the user to the 73 | source code of the API objects being documented. The `domain` argument specifies which 74 | programming language the object belongs to. The `info` argument is a dictionary with 75 | information specific to the programming language of the object. 76 | 77 | For Python objects, this dictionary contains a `module` key with the module the object is in 78 | and a `fullname` key with the name of the object. This function uses this information to find 79 | the source file and range of lines the object is defined in and to generate a link pointing to 80 | those lines on GitHub. 
81 | """ 82 | github_url = f"https://github.com/CQCL/qujax/tree/develop/qujax" 83 | 84 | if domain != "py": 85 | return 86 | 87 | module = importlib.import_module(info["module"]) 88 | obj = getattr(module, info["fullname"]) 89 | 90 | try: 91 | path = pathlib.Path(inspect.getsourcefile(obj)) 92 | file_name = path.name 93 | lines = inspect.getsourcelines(obj) 94 | except TypeError: 95 | return 96 | 97 | start_line, end_line = lines[1], lines[1] + len(lines[0]) - 1 98 | 99 | return f"{github_url}/{file_name}#L{start_line}-L{end_line}" 100 | -------------------------------------------------------------------------------- /docs/densitytensor.rst: -------------------------------------------------------------------------------- 1 | Mixed state simulation 2 | ======================= 3 | 4 | .. toctree:: 5 | :titlesonly: 6 | 7 | densitytensor/all_zeros_densitytensor 8 | densitytensor/kraus 9 | densitytensor/get_params_to_densitytensor_func 10 | densitytensor/partial_trace 11 | densitytensor/get_densitytensor_to_expectation_func 12 | densitytensor/get_densitytensor_to_sampled_expectation_func 13 | densitytensor/densitytensor_to_measurement_probabilities 14 | densitytensor/densitytensor_to_measured_densitytensor 15 | densitytensor/statetensor_to_densitytensor 16 | 17 | -------------------------------------------------------------------------------- /docs/densitytensor/all_zeros_densitytensor.rst: -------------------------------------------------------------------------------- 1 | all_zeros_densitytensor 2 | ============================================== 3 | 4 | .. autofunction:: qujax.all_zeros_densitytensor 5 | -------------------------------------------------------------------------------- /docs/densitytensor/densitytensor_to_measured_densitytensor.rst: -------------------------------------------------------------------------------- 1 | densitytensor_to_measured_densitytensor 2 | ============================================== 3 | 4 | .. 
autofunction:: qujax.densitytensor_to_measured_densitytensor 5 | -------------------------------------------------------------------------------- /docs/densitytensor/densitytensor_to_measurement_probabilities.rst: -------------------------------------------------------------------------------- 1 | densitytensor_to_measurement_probabilities 2 | ============================================== 3 | 4 | .. autofunction:: qujax.densitytensor_to_measurement_probabilities 5 | -------------------------------------------------------------------------------- /docs/densitytensor/get_densitytensor_to_expectation_func.rst: -------------------------------------------------------------------------------- 1 | get_densitytensor_to_expectation_func 2 | ============================================== 3 | 4 | .. autofunction:: qujax.get_densitytensor_to_expectation_func 5 | -------------------------------------------------------------------------------- /docs/densitytensor/get_densitytensor_to_sampled_expectation_func.rst: -------------------------------------------------------------------------------- 1 | get_densitytensor_to_sampled_expectation_func 2 | ============================================== 3 | 4 | .. autofunction:: qujax.get_densitytensor_to_sampled_expectation_func 5 | -------------------------------------------------------------------------------- /docs/densitytensor/get_params_to_densitytensor_func.rst: -------------------------------------------------------------------------------- 1 | get_params_to_densitytensor_func 2 | ============================================== 3 | 4 | .. autofunction:: qujax.get_params_to_densitytensor_func 5 | -------------------------------------------------------------------------------- /docs/densitytensor/kraus.rst: -------------------------------------------------------------------------------- 1 | kraus 2 | ============================================== 3 | 4 | .. 
autofunction:: qujax.kraus 5 | -------------------------------------------------------------------------------- /docs/densitytensor/partial_trace.rst: -------------------------------------------------------------------------------- 1 | partial_trace 2 | ============================================== 3 | 4 | .. autofunction:: qujax.partial_trace 5 | -------------------------------------------------------------------------------- /docs/densitytensor/statetensor_to_densitytensor.rst: -------------------------------------------------------------------------------- 1 | statetensor_to_densitytensor 2 | ============================================== 3 | 4 | .. autofunction:: qujax.statetensor_to_densitytensor 5 | -------------------------------------------------------------------------------- /docs/examples.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | Below are some use-case notebooks. These both illustrate the flexibility of qujax and the power of directly interfacing with JAX and its package ecosystem. 4 | 5 | - [heisenberg_vqe.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/heisenberg_vqe.ipynb) - an implementation of the variational quantum eigensolver to find the ground state of a quantum Hamiltonian. 6 | - [maxcut_vqe.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/maxcut_vqe.ipynb) - an implementation of the variational quantum eigensolver to solve a MaxCut problem. Trains with Adam via [`optax`](https://github.com/deepmind/optax) and uses more realistic stochastic parameter shift gradients. 7 | - [noise_channel.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/noise_channel.ipynb) - uses the densitytensor simulator to fit the parameters of a depolarising noise channel. 8 | - [qaoa.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/qaoa.ipynb) - uses a problem-inspired QAOA ansatz to find the ground state of a quantum Hamiltonian. 
Demonstrates how to encode more sophisticated parameters that control multiple gates. 9 | - [barren_plateaus.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/barren_plateaus.ipynb) - illustrates how to sample gradients of a cost function to identify the presence of barren plateaus. Uses batched/vectorized evaluation to speed up computation. 10 | - [reducing_jit_compilation_time.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/reducing_jit_compilation_time.ipynb) - explains how JAX compilation works and how that can lead to excessive compilation times when executing quantum circuits. Presents a solution for the case of circuits with a repeating structure. 11 | - [variational_inference.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/variational_inference.ipynb) - uses a parameterised quantum circuit as a variational distribution to fit to a target probability mass function. Uses Adam via [`optax`](https://github.com/deepmind/optax) to minimise the KL divergence between circuit and target distributions. 12 | - [classification.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/classification.ipynb) - train a quantum circuit for binary classification using data re-uploading. 13 | - [generative_modelling.ipynb](https://github.com/CQCL/qujax/blob/develop/examples/generative_modelling.ipynb) - uses a parameterised quantum circuit as a generative model for a real life dataset. Trains via stochastic gradient Langevin dynamics on the maximum mean discrepancy between statetensor and dataset. 
14 | 15 | The [pytket](https://github.com/CQCL/pytket) repository also contains `tk_to_qujax` implementations for some of the above at [pytket-qujax_classification.ipynb](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax-classification.ipynb), 16 | [pytket-qujax_heisenberg_vqe.ipynb](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax_heisenberg_vqe.ipynb) 17 | and [pytket-qujax_qaoa.ipynb](https://github.com/CQCL/pytket/blob/main/examples/pytket-qujax_qaoa.ipynb). -------------------------------------------------------------------------------- /docs/gates.rst: -------------------------------------------------------------------------------- 1 | Quantum gates 2 | ======================= 3 | 4 | This is a list of gates that qujax supports natively. You can also define custom operations by directly passing an array or function instead of a string, as documented in :doc:`statetensor/get_params_to_statetensor_func` and :doc:`densitytensor/get_params_to_densitytensor_func`. 5 | 6 | .. list-table:: 7 | :widths: 25 25 50 8 | :header-rows: 1 9 | 10 | * - Name(s) 11 | - String 12 | - Matrix representation 13 | * - Pauli X gate 14 | 15 | NOT gate 16 | 17 | Bit flip gate 18 | - :python:`"X"` 19 | - .. math:: X = NOT = \begin{bmatrix}0 & 1\\ 1 & 0 \end{bmatrix} 20 | * - Pauli Y gate 21 | - :python:`"Y"` 22 | - .. math:: Y = \begin{bmatrix}0 & -i\\ i & 0 \end{bmatrix} 23 | * - Pauli Z gate 24 | 25 | Phase flip gate 26 | - :python:`"Z"` 27 | - .. math:: Z = \begin{bmatrix}1 & 0\\ 0 & -1 \end{bmatrix} 28 | * - Hadamard gate 29 | - :python:`"H"` 30 | - .. math:: H = \frac{1}{\sqrt{2}}\begin{bmatrix}1 & 1\\ 1 & -1 \end{bmatrix} 31 | * - S gate 32 | 33 | P (phase) gate 34 | - :python:`"S"` 35 | - .. math:: S = P = \sqrt{Z} = \begin{bmatrix}1 & 0\\ 0 & i \end{bmatrix} 36 | * - Conjugated S gate 37 | - :python:`"Sdg"` 38 | - .. math:: S^\dagger = \begin{bmatrix}1 & 0\\ 0 & -i \end{bmatrix} 39 | * - T gate 40 | - :python:`"T"` 41 | - .. 
math:: T = \sqrt[4]{Z} = \begin{bmatrix}1 & 0\\ 0 & \exp(\frac{\pi i}{4}) \end{bmatrix} 42 | * - Conjugated T gate 43 | - :python:`"Tdg"` 44 | - .. math:: T^\dagger = \begin{bmatrix}1 & 0\\ 0 & \exp(-\frac{\pi i}{4}) \end{bmatrix} 45 | * - V gate 46 | - :python:`"V"` 47 | - .. math:: V = \sqrt{X} = \frac{1}{\sqrt{2}}\begin{bmatrix}1 & -i\\ -i & 1 \end{bmatrix} 48 | * - Conjugated V gate 49 | - :python:`"Vdg"` 50 | - .. math:: V^\dagger = \frac{1}{\sqrt{2}}\begin{bmatrix}1 & i\\ i & 1 \end{bmatrix} 51 | * - SX gate 52 | - :python:`"SX"` 53 | - .. math:: SX = \sqrt{X} = \frac{1}{2}\begin{bmatrix}1 + i & 1 - i\\ 1 - i & 1 + i \end{bmatrix} 54 | * - Conjugated SX gate 55 | - :python:`"SXdg"` 56 | - .. math:: SX^\dagger = \frac{1}{2}\begin{bmatrix}1 - i & 1 + i\\ 1 + i & 1 - i \end{bmatrix} 57 | * - CX (Controlled X) gate 58 | 59 | CNOT gate 60 | - :python:`"CX"` 61 | - .. math:: CX = CNOT = \begin{bmatrix}I & 0\\ 0 & X \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 1 \\ 0 & 0 & 1 & 0 \end{bmatrix} 62 | * - CY (Controlled Y) gate 63 | - :python:`"CY"` 64 | - .. math:: CY = \begin{bmatrix}I & 0\\ 0 & Y \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & -i \\ 0 & 0 & i & 0 \end{bmatrix} 65 | * - Controlled Z gate 66 | - :python:`"CZ"` 67 | - .. math:: CZ = \begin{bmatrix}I & 0\\ 0 & Z \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & -1 \end{bmatrix} 68 | * - Controlled Hadamard gate 69 | - :python:`"CH"` 70 | - .. math:: CH = \begin{bmatrix}I & 0\\ 0 & H \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & \frac{1}{\sqrt{2}} & \frac{1}{\sqrt{2}} \\ 0 & 0 & \frac{1}{\sqrt{2}} & -\frac{1}{\sqrt{2}} \end{bmatrix} 71 | * - Controlled V gate 72 | - :python:`"CV"` 73 | - .. 
math:: CV = \begin{bmatrix}I & 0\\ 0 & V \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & \frac{1}{\sqrt{2}} & -\frac{i}{\sqrt{2}} \\ 0 & 0 & -\frac{i}{\sqrt{2}} & \frac{1}{\sqrt{2}} \end{bmatrix} 74 | * - Conjugated controlled V gate 75 | - :python:`"CVdg"` 76 | - .. math:: CVdg = \begin{bmatrix}I & 0\\ 0 & V^\dagger \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & \frac{1}{\sqrt{2}} & \frac{i}{\sqrt{2}} \\ 0 & 0 & \frac{i}{\sqrt{2}} & \frac{1}{\sqrt{2}} \end{bmatrix} 77 | * - Controlled SX gate 78 | - :python:`"CSX"` 79 | - .. math:: CSX = \begin{bmatrix}I & 0\\ 0 & SX \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & \frac{1+i}{2} & \frac{1-i}{2} \\ 0 & 0 & \frac{1-i}{2} & \frac{1+i}{2} \end{bmatrix} 80 | * - Conjugated controlled SX gate 81 | - :python:`"CSXdg"` 82 | - .. math:: CSX^\dagger = \begin{bmatrix}I & 0\\ 0 & SX^\dagger \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & \frac{1-i}{2} & \frac{1+i}{2} \\ 0 & 0 & \frac{1+i}{2} & \frac{1-i}{2} \end{bmatrix} 83 | * - Toffoli gate 84 | 85 | CCX 86 | 87 | CCNOT 88 | - :python:`"CCX"` 89 | - .. math:: CCX = \begin{bmatrix}I & 0 & 0 & 0\\ 0 & I & 0 & 0 \\ 0 & 0 & I & 0 \\ 0 & 0 & 0 & X \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 1 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 1 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 \\ 0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 \\ \end{bmatrix} 90 | * - Echoed cross-resonance gate 91 | - :python:`"ECR"` 92 | - .. math:: ECR = \begin{bmatrix}0 & V^\dagger \\ V & 0 \end{bmatrix} = \frac{1}{\sqrt{2}}\begin{bmatrix}0 & 0 & 1 & i \\ 0 & 0 & i & 1 \\ 1 & -i & 0 & 0 \\ -i & 1 & 0 & 0 \end{bmatrix} 93 | * - Swap gate 94 | - :python:`"SWAP"` 95 | - .. 
math:: SWAP = \begin{bmatrix}1 & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix} 96 | * - Controlled swap gate 97 | - :python:`"CSWAP"` 98 | - .. 
math:: CSWAP = \begin{bmatrix}I & 0 \\ 0 & SWAP \end{bmatrix} = \begin{bmatrix}1 & 0 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 1 & 0 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 1 & 0 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 1 & 0 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 1 & 0 \\ 0 & 0 & 0 & 0 & 0 & 1 & 0 & 0 \\ 0 & 0 & 0 & 0 & 0 & 0 & 0 & 1 \end{bmatrix} 99 | * - Rotation around X axis 100 | - :python:`"Rx"` 101 | - .. math:: R_X(\theta) = \exp\left(-i \frac{\pi}{2} \theta X\right) = \begin{bmatrix} \cos( \frac{\pi}{2} \theta) & - i \sin( \frac{\pi}{2} \theta) \\ - i \sin( \frac{\pi}{2} \theta) & \cos( \frac{\pi}{2} \theta) \end{bmatrix} 102 | * - Rotation around Y axis 103 | - :python:`"Ry"` 104 | - .. math:: R_Y(\theta) = \exp\left(-i \frac{\pi}{2} \theta Y\right) = \begin{bmatrix} \cos( \frac{\pi}{2} \theta) & - \sin( \frac{\pi}{2} \theta) \\ \sin( \frac{\pi}{2} \theta) & \cos( \frac{\pi}{2} \theta) \end{bmatrix} 105 | * - Rotation around Z axis 106 | - :python:`"Rz"` 107 | - .. math:: R_Z(\theta) = \exp\left(-i \frac{\pi}{2} \theta Z\right) = \begin{bmatrix} \cos( \frac{\pi}{2} \theta) - i \sin( \frac{\pi}{2} \theta) & 0 \\ 0 & \cos( \frac{\pi}{2} \theta) + i \sin( \frac{\pi}{2} \theta) \end{bmatrix} 108 | * - Controlled rotation around X axis 109 | - :python:`"CRx"` 110 | - .. math:: CR_X(\theta) = \begin{bmatrix}I & 0\\ 0 & RX(\theta) \end{bmatrix} = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & \cos( \frac{\pi}{2} \theta) & - i \sin( \frac{\pi}{2} \theta) \\ 0 & 0 & - i \sin( \frac{\pi}{2} \theta) & \cos( \frac{\pi}{2} \theta) \end{bmatrix} 111 | * - Controlled rotation around Y axis 112 | - :python:`"CRy"` 113 | - .. 
math:: CR_Y(\theta) = \begin{bmatrix}I & 0\\ 0 & RY(\theta) \end{bmatrix} = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & \cos( \frac{\pi}{2} \theta) & - \sin( \frac{\pi}{2} \theta) \\ 0 & 0 & \sin( \frac{\pi}{2} \theta) & \cos( \frac{\pi}{2} \theta) \end{bmatrix} 114 | * - Controlled rotation around Z axis 115 | - :python:`"CRz"` 116 | - .. math:: CR_Z(\theta) = \begin{bmatrix}I & 0\\ 0 & RZ(\theta)\end{bmatrix} = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & 1 & 0 & 0 \\ 0 & 0 & \cos( \frac{\pi}{2} \theta) - i \sin( \frac{\pi}{2} \theta) & 0 \\ 0 & 0 & 0 & \cos( \frac{\pi}{2} \theta) + i \sin( \frac{\pi}{2} \theta) \end{bmatrix} 117 | * - U3 118 | - :python:`"U3"` 119 | - .. math:: U3(\alpha,\beta,\gamma) = \exp((\beta + \gamma) i \frac{\pi}{2}) R_Z(\beta) R_Y(\alpha) R_Z(\gamma) 120 | * - U1 121 | - :python:`"U1"` 122 | - .. math:: U1(\gamma) = U3(0, 0, \gamma) 123 | * - U2 124 | - :python:`"U2"` 125 | - .. math:: U2(\beta, \gamma) = U3(0.5, \beta, \gamma) 126 | * - Controlled U3 127 | - :python:`"CU3"` 128 | - .. math:: CU3(\alpha,\beta,\gamma) = \begin{bmatrix}I & 0\\ 0 & U3(\alpha,\beta,\gamma)\end{bmatrix} 129 | * - Controlled U1 130 | - :python:`"CU1"` 131 | - .. math:: CU1(\gamma) = \begin{bmatrix}I & 0\\ 0 & U1(\gamma)\end{bmatrix} 132 | * - Controlled U2 133 | - :python:`"CU2"` 134 | - .. math:: CU2(\beta, \gamma) = \begin{bmatrix}I & 0\\ 0 & U2(\beta, \gamma)\end{bmatrix} 135 | * - Imaginary swap 136 | - :python:`"ISWAP"` 137 | - .. math:: iSWAP(\theta) = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & \cos( \frac{\pi}{2} \theta) & i \sin( \frac{\pi}{2} \theta) & 0 \\ 0 & i \sin( \frac{\pi}{2} \theta) & \cos( \frac{\pi}{2} \theta) & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix} 138 | * - Phased imaginary swap 139 | - :python:`"PhasedISWAP"` 140 | - .. 
math:: PhasedISWAP(\phi, \theta) = \begin{bmatrix} 1 & 0 & 0 & 0 \\ 0 & \cos( \frac{\pi}{2} \theta) & \exp(2i\pi \phi) i \sin( \frac{\pi}{2} \theta) & 0 \\ 0 & \exp(- 2i\pi \phi) i \sin( \frac{\pi}{2} \theta) & \cos( \frac{\pi}{2} \theta) & 0 \\ 0 & 0 & 0 & 1 \end{bmatrix} 141 | * - XXPhase 142 | 143 | XX interaction 144 | - :python:`"XXPhase"` 145 | - .. math:: R_{XX}(\theta) = \exp\left(-i \frac{\pi}{2} \theta X\otimes X\right) = \begin{bmatrix} \cos( \frac{\pi}{2} \theta) & 0 & 0 & -i \sin( \frac{\pi}{2} \theta) \\ 0 & \cos( \frac{\pi}{2} \theta) & -i \sin( \frac{\pi}{2} \theta) & 0 \\ 0 & -i \sin( \frac{\pi}{2} \theta) & \cos( \frac{\pi}{2} \theta) & 0 \\ -i \sin( \frac{\pi}{2} \theta) & 0 & 0 & \cos( \frac{\pi}{2} \theta) \end{bmatrix} 146 | * - YYPhase 147 | 148 | YY interaction 149 | - :python:`"YYPhase"` 150 | - .. math:: R_{YY}(\theta) = \exp\left(-i \frac{\pi}{2} \theta Y\otimes Y\right) = \begin{bmatrix} \cos( \frac{\pi}{2} \theta) & 0 & 0 & i \sin( \frac{\pi}{2} \theta) \\ 0 & \cos( \frac{\pi}{2} \theta) & -i \sin( \frac{\pi}{2} \theta) & 0 \\ 0 & -i \sin( \frac{\pi}{2} \theta) & \cos( \frac{\pi}{2} \theta) & 0 \\ i \sin( \frac{\pi}{2} \theta) & 0 & 0 & \cos( \frac{\pi}{2} \theta) \end{bmatrix} 151 | * - ZZPhase 152 | 153 | ZZ interaction 154 | - :python:`"ZZPhase"` 155 | - .. math:: R_{ZZ}(\theta) = \exp\left(-i \frac{\pi}{2} \theta Z\otimes Z\right) = \begin{bmatrix} \exp( -i \frac{\pi}{2} \theta) & 0 & 0 & 0 \\ 0 & \exp( i \frac{\pi}{2} \theta) & 0 & 0 \\ 0 & 0 & \exp( i \frac{\pi}{2} \theta) & 0 \\ 0 & 0 & 0 & \exp( -i \frac{\pi}{2} \theta) \end{bmatrix} 156 | * - ZZMax 157 | - :python:`"ZZMax"` 158 | - .. math:: ZZMax = R_{ZZ}(0.5) 159 | * - PhasedX 160 | - :python:`"PhasedX"` 161 | - .. 
math:: PhasedX(\theta, \phi) = R_Z(\phi)R_X(\theta)R_Z(-\phi) -------------------------------------------------------------------------------- /docs/getting_started.rst: -------------------------------------------------------------------------------- 1 | Getting started 2 | ################# 3 | 4 | **Important note**: qujax circuit parameters are expressed in units of :math:`\pi` (e.g. in the range :math:`[0,2]` as opposed to :math:`[0, 2\pi]`). 5 | 6 | ********************* 7 | Pure state simulation 8 | ********************* 9 | 10 | We start by defining the quantum gates making up the circuit, along with the qubits that they act on and the indices of the parameters for each gate. 11 | 12 | A list of all gates can be found in :doc:`gates` (custom operations can be included by passing an array or function instead of a string, as documented in :doc:`statetensor/get_params_to_statetensor_func`). 13 | 14 | .. code-block:: python 15 | 16 | from jax import numpy as jnp 17 | import qujax 18 | 19 | # List of quantum gates 20 | circuit_gates = ['H', 'Ry', 'CZ'] 21 | # Indices of qubits the gates will be applied to 22 | circuit_qubit_inds = [[0], [0], [0, 1]] 23 | # Indices of parameters each parameterised gate will use 24 | circuit_params_inds = [[], [0], []] 25 | 26 | qujax.print_circuit(circuit_gates, circuit_qubit_inds, circuit_params_inds); 27 | # q0: -----H-----Ry[0]-----◯--- 28 | # | 29 | # q1: ---------------------CZ-- 30 | 31 | We then translate the circuit to a pure function :python:`param_to_st` that takes a set of parameters and an (optional) initial quantum state as its input. 32 | 33 | .. code-block:: python 34 | 35 | param_to_st = qujax.get_params_to_statetensor_func(circuit_gates, 36 | circuit_qubit_inds, 37 | circuit_params_inds) 38 | 39 | param_to_st(jnp.array([0.1])) 40 | # Array([[0.58778524+0.j, 0. +0.j], 41 | # [0.80901706+0.j, 0. 
+0.j]], dtype=complex64) 42 | 43 | The optional initial state can be passed to :python:`param_to_st` using the :python:`statetensor_in` argument. When it is not provided, the initial state defaults to :math:`\ket{0...0}`. 44 | 45 | Note that qujax represents quantum states as *statetensors*. For example, for :math:`N=4` qubits, the corresponding vector space has :math:`2^4` dimensions, and a quantum state in this space is represented by an array with shape :python:`(2,2,2,2)`. The usual statevector representation with shape :python:`(16,)` can be obtained by calling :python:`.flatten()` or :python:`.reshape(-1)` or :python:`.reshape(2**N)` on this array. 46 | 47 | In the statetensor representation, the coefficient associated with e.g. basis state :math:`\ket{0101}` is given by :python:`arr[0,1,0,1]`; each axis corresponds to one qubit. 48 | 49 | .. code-block:: python 50 | 51 | param_to_st(jnp.array([0.1])).flatten() 52 | # Array([0.58778524+0.j, 0.+0.j, 0.80901706+0.j, 0.+0.j], dtype=complex64) 53 | 54 | 55 | Finally, by defining an observable, we can map the statetensor to an expectation value. A general observable is specified using lists of Pauli matrices, the qubits they act on, and the associated coefficients. 56 | 57 | For example, :math:`Z_1Z_2Z_3Z_4 - 2 X_3` would be written as :python:`[['Z','Z','Z','Z'], ['X']], [[1,2,3,4], [3]], [1., -2.]`. 58 | 59 | .. code-block:: python 60 | 61 | st_to_expectation = qujax.get_statetensor_to_expectation_func([['Z']], [[0]], [1.]) 62 | 63 | 64 | Combining :python:`param_to_st` and :python:`st_to_expectation` gives us a parameter to expectation function that can be automatically differentiated using JAX. 65 | 66 | .. 
code-block:: python 67 | 68 | from jax import value_and_grad 69 | 70 | param_to_expectation = lambda param: st_to_expectation(param_to_st(param)) 71 | expectation_and_grad = value_and_grad(param_to_expectation) 72 | expectation_and_grad(jnp.array([0.1])) 73 | # (Array(-0.3090171, dtype=float32), 74 | # Array([-2.987832], dtype=float32)) 75 | 76 | *********************** 77 | Mixed state simulation 78 | *********************** 79 | Mixed state simulations are analogous to the above, but with calls to :doc:`densitytensor/get_params_to_densitytensor_func` and :doc:`densitytensor/get_densitytensor_to_expectation_func` instead. 80 | 81 | .. code-block:: python 82 | 83 | param_to_dt = qujax.get_params_to_densitytensor_func(circuit_gates, 84 | circuit_qubit_inds, 85 | circuit_params_inds) 86 | dt = param_to_dt(jnp.array([0.1])) 87 | dt.shape 88 | # (2, 2, 2, 2) 89 | 90 | dt_to_expectation = qujax.get_densitytensor_to_expectation_func([['Z']], [[0]], [1.]) 91 | dt_to_expectation(dt) 92 | # Array(-0.3090171, dtype=float32) 93 | 94 | Similarly to a statetensor, which represents the reshaped :math:`2^N`-dimensional statevector of a pure quantum state, a *densitytensor* represents the reshaped :math:`2^N \times 2^N` density matrix of a mixed quantum state. This densitytensor has shape :python:`(2,) * 2 * N`. 95 | 96 | For example, for :math:`N=2`, and a mixed state :math:`\frac{1}{2} (\ket{00}\bra{11} + \ket{11}\bra{00} + \ket{11}\bra{11} + \ket{00}\bra{00})`, the corresponding densitytensor :python:`dt` is such that :python:`dt[0,0,1,1] = dt[1,1,0,0] = dt[1,1,1,1] = dt[0,0,0,0] = 1/2`, and all other entries are zero. -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | 2 | Welcome to qujax's documentation! 3 | ================================= 4 | 5 | ``qujax`` is a `JAX `_-based Python library for the classical simulation of quantum circuits. 
It is designed to be *simple*, *fast* and *flexible*. 6 | 7 | 8 | It follows a functional programming design by translating circuits into pure functions. This allows qujax to `seamlessly and directly interface with JAX `_, enabling direct access to its powerful automatic differentiation tools, just-in-time compiler, vectorization capabilities, GPU/TPU integration and growing ecosystem of packages. 9 | 10 | If you are new to the library, we recommend that you head to the :doc:`getting_started` section of the documentation. More advanced use-cases, including the training of parameterised quantum circuits, can be found in :doc:`examples`. 11 | 12 | The source code can be found on `GitHub `_. The `pytket-qujax `_ extension can be used to translate a `tket `_ circuit directly into ``qujax``. 13 | 14 | **Important note**: qujax circuit parameters are expressed in units of :math:`\pi` (e.g. in the range :math:`[0,2]` as opposed to :math:`[0, 2\pi]`). 15 | 16 | Install 17 | ================================= 18 | ``qujax`` is hosted on `PyPI `_ and can be installed with 19 | 20 | .. code-block:: bash 21 | 22 | pip install qujax 23 | 24 | Cite 25 | ================================= 26 | If you have used qujax in your code or research, we kindly ask that you cite it. You can use the following BibTeX entry for this: 27 | 28 | .. code-block:: bibtex 29 | 30 | @article{qujax2023, 31 | author = {Duffield, Samuel and Matos, Gabriel and Johannsen, Melf}, 32 | doi = {10.21105/joss.05504}, 33 | journal = {Journal of Open Source Software}, 34 | month = sep, 35 | number = {89}, 36 | pages = {5504}, 37 | title = {{qujax: Simulating quantum circuits with JAX}}, 38 | url = {https://joss.theoj.org/papers/10.21105/joss.05504}, 39 | volume = {8}, 40 | year = {2023} 41 | } 42 | 43 | Contents 44 | ================================= 45 | 46 | .. toctree:: 47 | :caption: Documentation: 48 | :titlesonly: 49 | 50 | Getting started 51 | Examples 52 | List of gates 53 | 54 | .. 
toctree:: 55 | :caption: API Reference: 56 | :titlesonly: 57 | :maxdepth: 1 58 | 59 | Pure state simulation 60 | Mixed state simulation 61 | Utility functions 62 | 63 | .. toctree:: 64 | :caption: Links: 65 | :hidden: 66 | 67 | GitHub 68 | Paper 69 | PyPI 70 | pytket-qujax 71 | -------------------------------------------------------------------------------- /docs/logo.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 53 | -------------------------------------------------------------------------------- /docs/logo_dark_mode.svg: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 56 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx >= 3 2 | sphinx-autodoc-typehints 3 | sphinx-rtd-theme >= 1.0.0 4 | 5 | jax 6 | jaxlib -------------------------------------------------------------------------------- /docs/statetensor.rst: -------------------------------------------------------------------------------- 1 | Pure state simulation 2 | ======================= 3 | 4 | .. toctree:: 5 | :titlesonly: 6 | 7 | statetensor/all_zeros_statetensor 8 | statetensor/apply_gate 9 | statetensor/get_params_to_statetensor_func 10 | statetensor/get_params_to_unitarytensor_func 11 | statetensor/get_statetensor_to_expectation_func 12 | statetensor/get_statetensor_to_sampled_expectation_func 13 | 14 | -------------------------------------------------------------------------------- /docs/statetensor/all_zeros_statetensor.rst: -------------------------------------------------------------------------------- 1 | all_zeros_statetensor 2 | ============================================== 3 | 4 | .. 
autofunction:: qujax.all_zeros_statetensor 5 | -------------------------------------------------------------------------------- /docs/statetensor/apply_gate.rst: -------------------------------------------------------------------------------- 1 | apply_gate 2 | ============================================== 3 | 4 | .. autofunction:: qujax.apply_gate 5 | -------------------------------------------------------------------------------- /docs/statetensor/get_params_to_statetensor_func.rst: -------------------------------------------------------------------------------- 1 | get_params_to_statetensor_func 2 | ============================================== 3 | 4 | .. autofunction:: qujax.get_params_to_statetensor_func 5 | -------------------------------------------------------------------------------- /docs/statetensor/get_params_to_unitarytensor_func.rst: -------------------------------------------------------------------------------- 1 | get_params_to_unitarytensor_func 2 | ============================================== 3 | 4 | .. autofunction:: qujax.get_params_to_unitarytensor_func 5 | -------------------------------------------------------------------------------- /docs/statetensor/get_statetensor_to_expectation_func.rst: -------------------------------------------------------------------------------- 1 | get_statetensor_to_expectation_func 2 | ============================================== 3 | 4 | .. autofunction:: qujax.get_statetensor_to_expectation_func 5 | -------------------------------------------------------------------------------- /docs/statetensor/get_statetensor_to_sampled_expectation_func.rst: -------------------------------------------------------------------------------- 1 | get_statetensor_to_sampled_expectation_func 2 | ============================================== 3 | 4 | .. 
autofunction:: qujax.get_statetensor_to_sampled_expectation_func 5 | -------------------------------------------------------------------------------- /docs/utils.rst: -------------------------------------------------------------------------------- 1 | Utility functions 2 | ======================= 3 | 4 | .. toctree:: 5 | :titlesonly: 6 | 7 | utils/bitstrings_to_integers 8 | utils/check_circuit 9 | utils/integers_to_bitstrings 10 | utils/print_circuit 11 | utils/repeat_circuit 12 | utils/sample_bitstrings 13 | utils/sample_integers 14 | -------------------------------------------------------------------------------- /docs/utils/bitstrings_to_integers.rst: -------------------------------------------------------------------------------- 1 | bitstrings_to_integers 2 | ============================================== 3 | 4 | .. autofunction:: qujax.bitstrings_to_integers 5 | -------------------------------------------------------------------------------- /docs/utils/check_circuit.rst: -------------------------------------------------------------------------------- 1 | check_circuit 2 | ============================================== 3 | 4 | .. autofunction:: qujax.check_circuit 5 | -------------------------------------------------------------------------------- /docs/utils/integers_to_bitstrings.rst: -------------------------------------------------------------------------------- 1 | integers_to_bitstrings 2 | ============================================== 3 | 4 | .. autofunction:: qujax.integers_to_bitstrings 5 | -------------------------------------------------------------------------------- /docs/utils/print_circuit.rst: -------------------------------------------------------------------------------- 1 | print_circuit 2 | ============================================== 3 | 4 | .. 
autofunction:: qujax.print_circuit 5 | -------------------------------------------------------------------------------- /docs/utils/repeat_circuit.rst: -------------------------------------------------------------------------------- 1 | repeat_circuit 2 | ============================================== 3 | 4 | .. autofunction:: qujax.repeat_circuit 5 | -------------------------------------------------------------------------------- /docs/utils/sample_bitstrings.rst: -------------------------------------------------------------------------------- 1 | sample_bitstrings 2 | ============================================== 3 | 4 | .. autofunction:: qujax.sample_bitstrings 5 | -------------------------------------------------------------------------------- /docs/utils/sample_integers.rst: -------------------------------------------------------------------------------- 1 | sample_integers 2 | ============================================== 3 | 4 | .. autofunction:: qujax.sample_integers 5 | -------------------------------------------------------------------------------- /examples/hidalgo_stamp.txt: -------------------------------------------------------------------------------- 1 | 0.060 0.064 0.064 0.065 0.066 0.068 0.069 0.069 0.069 0.069 0.069 0.069 2 | 0.069 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.070 3 | 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.070 0.070 4 | 0.070 0.070 0.070 0.071 0.071 0.071 0.071 0.071 0.071 0.071 0.071 0.071 5 | 0.071 0.071 0.071 0.071 0.071 0.071 0.071 0.071 0.071 0.071 0.071 0.072 6 | 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072 7 | 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.072 8 | 0.072 0.072 0.072 0.072 0.072 0.072 0.072 0.073 0.073 0.073 0.073 0.073 9 | 0.073 0.073 0.073 0.073 0.073 0.073 0.074 0.074 0.074 0.074 0.074 0.074 10 | 0.074 0.074 0.074 0.074 0.075 0.075 0.075 0.075 0.075 0.075 0.075 0.075 11 | 0.075 0.075 0.075 0.075 0.075 0.075 0.075 
0.075 0.075 0.075 0.075 0.075 12 | 0.076 0.076 0.076 0.076 0.076 0.076 0.076 0.076 0.076 0.076 0.076 0.076 13 | 0.076 0.076 0.076 0.076 0.076 0.076 0.077 0.077 0.077 0.077 0.077 0.077 14 | 0.077 0.077 0.077 0.077 0.077 0.078 0.078 0.078 0.078 0.078 0.078 0.078 15 | 16 | 0.078 0.078 0.078 0.078 0.078 0.078 0.078 0.078 0.078 0.078 0.078 0.078 17 | 0.078 0.078 0.078 0.078 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 18 | 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 19 | 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 20 | 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.079 0.080 0.080 21 | 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 22 | 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 23 | 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.080 0.081 24 | 0.081 0.081 0.081 0.081 0.081 0.081 0.081 0.081 0.081 0.081 0.081 0.081 25 | 0.081 0.081 0.082 0.082 0.082 0.082 0.082 0.082 0.082 0.082 0.082 0.082 26 | 0.082 0.082 0.082 0.082 0.082 0.082 0.082 0.082 0.083 0.083 0.083 0.083 27 | 0.083 0.083 0.083 0.084 0.084 0.084 0.085 0.085 0.086 0.086 0.087 0.088 28 | 0.088 0.089 0.089 0.089 0.089 0.089 0.089 0.089 0.089 0.089 0.089 0.090 29 | 0.090 0.090 0.090 0.090 0.090 0.090 0.090 0.090 0.091 0.091 0.091 0.092 30 | 31 | 0.092 0.092 0.092 0.092 0.093 0.093 0.093 0.093 0.093 0.093 0.094 0.094 32 | 0.094 0.095 0.095 0.096 0.096 0.096 0.097 0.097 0.097 0.097 0.097 0.097 33 | 0.097 0.098 0.098 0.098 0.098 0.098 0.099 0.099 0.099 0.099 0.099 0.100 34 | 0.100 0.100 0.100 0.100 0.100 0.100 0.100 0.100 0.100 0.100 0.100 0.100 35 | 0.100 0.100 0.101 0.101 0.101 0.101 0.101 0.101 0.101 0.101 0.101 0.102 36 | 0.102 0.102 0.102 0.102 0.102 0.102 0.102 0.103 0.103 0.103 0.103 0.103 37 | 0.103 0.103 0.104 0.104 0.105 0.105 0.105 0.105 0.105 0.106 0.106 0.106 38 | 0.106 0.107 0.107 0.107 0.108 0.108 0.108 0.108 0.108 0.108 0.108 0.109 39 | 0.109 0.109 0.109 0.109 0.109 
0.109 0.110 0.110 0.110 0.110 0.110 0.110 40 | 0.110 0.110 0.110 0.110 0.110 0.111 0.111 0.111 0.111 0.112 0.112 0.112 41 | 0.112 0.112 0.114 0.114 0.114 0.115 0.115 0.115 0.117 0.119 0.119 0.119 42 | 0.119 0.120 0.120 0.120 0.121 0.122 0.122 0.123 0.123 0.125 0.125 0.128 43 | 0.129 0.129 0.129 0.130 0.131 44 | -------------------------------------------------------------------------------- /examples/qaoa.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "a52b609d", 6 | "metadata": {}, 7 | "source": [ 8 | "# QAOA with `qujax`\n", 9 | "\n", 10 | "In this notebook, we will consider QAOA on an Ising Hamiltonian. In particular, we will demonstrate how to encode a circuit with parameters that control multiple gates." 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "id": "2678d061", 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "from jax import numpy as jnp, random, value_and_grad, jit\n", 21 | "import matplotlib.pyplot as plt\n", 22 | "\n", 23 | "import qujax" 24 | ] 25 | }, 26 | { 27 | "cell_type": "markdown", 28 | "id": "57aaa7eb", 29 | "metadata": {}, 30 | "source": [ 31 | "# QAOA\n", 32 | "\n", 33 | "The Quantum Approximate Optimization Algorithm (QAOA), first introduced by [Farhi et al.](https://arxiv.org/pdf/1411.4028.pdf), is a variational quantum algorithm used to solve optimization problems. It consists of a unitary $U(\\beta, \\gamma)$ formed by alternating repetitions of $U(\\beta)=e^{-i\\beta H_B}$ and $U(\\gamma)=e^{-i\\gamma H_P}$, where $H_B$ is the mixing Hamiltonian and $H_P$ the problem Hamiltonian. The goal is to find the parameters that minimize the expectation value of $H_P$.\n", 34 | "\n", 35 | "Given a depth $d$, the expression of the final unitary is $U(\\beta, \\gamma) = U(\\beta_d)U(\\gamma_d)\\cdots U(\\beta_1)U(\\gamma_1)$. 
Notice that for each repetition the parameters are different.\n", 36 | "\n", 37 | "\n", 38 | "## Problem Hamiltonian\n", 39 | "QAOA uses a problem dependent ansatz. Therefore, we first need to know the problem that we want to solve. In this case we will consider an Ising Hamiltonian with only $Z$ interactions. Given a set of pairs (or qubit indices) $E$, the problem Hamiltonian will be:\n", 40 | "$$H_P = \\sum_{(i, j) \\in E}\\alpha_{ij}Z_iZ_j,$$ \n", 41 | "where $\\alpha_{ij}$ are the coefficients.\n", 42 | "\n", 43 | "Let's build our problem Hamiltonian with random coefficients and a set of pairs for a given number of qubits:" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 2, 49 | "id": "6602f4ad", 50 | "metadata": {}, 51 | "outputs": [], 52 | "source": [ 53 | "n_qubits = 4\n", 54 | "hamiltonian_qubit_inds = [(0, 1), (1, 2), (0, 2), (1, 3)]\n", 55 | "hamiltonian_gates = [[\"Z\", \"Z\"]] * (len(hamiltonian_qubit_inds))\n", 56 | "\n", 57 | "\n", 58 | "# Notice that in order to use the random package from jax we first need to define a seeded key\n", 59 | "seed = 13\n", 60 | "key = random.PRNGKey(seed)\n", 61 | "coefficients = random.uniform(key, shape=(len(hamiltonian_qubit_inds),))" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 3, 67 | "id": "e80df4ff", 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "name": "stdout", 72 | "output_type": "stream", 73 | "text": [ 74 | "Gates:\t [['Z', 'Z'], ['Z', 'Z'], ['Z', 'Z'], ['Z', 'Z']]\n", 75 | "Qubits:\t [(0, 1), (1, 2), (0, 2), (1, 3)]\n", 76 | "Coefficients:\t [0.6794174 0.2963785 0.2863201 0.31746793]\n" 77 | ] 78 | } 79 | ], 80 | "source": [ 81 | "print(\"Gates:\\t\", hamiltonian_gates)\n", 82 | "print(\"Qubits:\\t\", hamiltonian_qubit_inds)\n", 83 | "print(\"Coefficients:\\t\", coefficients)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "id": "b03a1f9d", 89 | "metadata": {}, 90 | "source": [ 91 | "## Variational Circuit\n", 92 | "\n", 93 | "Before 
constructing the circuit, we still need to select the mixing Hamiltonian. In our case, we will be using $X$ gates on each qubit, so $H_B = \\sum_{i=1}^{n}X_i$, where $n$ is the number of qubits. Notice that the unitary $U(\\beta)$, given this mixing Hamiltonian, is an $X$ rotation on each qubit with angle $\\beta$.\n", 94 | "\n", 95 | "As for the unitary corresponding to the problem Hamiltonian, $U(\\gamma)$, it has the following form:\n", 96 | "$$U(\\gamma)=\\prod_{(i, j) \\in E}e^{-i\\gamma\\alpha_{ij}Z_iZ_j}$$ \n", 97 | "\n", 98 | "The operation $e^{-i\\gamma\\alpha_{ij}Z_iZ_j}$ can be performed using two CNOT gates, with qubit $i$ as control and qubit $j$ as target, and a $Z$ rotation on qubit $j$ in between them, with angle $\\gamma\\alpha_{ij}$.\n", 99 | "\n", 100 | "Finally, the initial state generally used with QAOA is an equal superposition of all the basis states. This can be achieved by adding a first layer of Hadamard gates on each qubit at the beginning of the circuit." 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": 4, 106 | "id": "008e895f", 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "depth = 3" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 5, 116 | "id": "cf0fbfaa", 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "circuit_gates = []\n", 121 | "circuit_qubit_inds = []\n", 122 | "circuit_param_inds = []\n", 123 | "\n", 124 | "param_ind = 0\n", 125 | "\n", 126 | "# Initial State\n", 127 | "for i in range(n_qubits):\n", 128 | " circuit_gates.append(\"H\")\n", 129 | " circuit_qubit_inds.append([i])\n", 130 | " circuit_param_inds.append([])\n", 131 | "\n", 132 | "for d in range(depth):\n", 133 | " # Mixing Unitary\n", 134 | " for i in range(n_qubits):\n", 135 | " circuit_gates.append(\"Rx\")\n", 136 | " circuit_qubit_inds.append([i])\n", 137 | " circuit_param_inds.append([param_ind])\n", 138 | " param_ind += 1\n", 139 | "\n", 140 | " # Hamiltonian\n", 
141 | " for index in range(len(hamiltonian_qubit_inds)):\n", 142 | " pair = hamiltonian_qubit_inds[index]\n", 143 | " coef = coefficients[index]\n", 144 | "\n", 145 | " circuit_gates.append(\"CX\")\n", 146 | " circuit_qubit_inds.append([pair[0], pair[1]])\n", 147 | " circuit_param_inds.append([])\n", 148 | "\n", 149 | " circuit_gates.append(lambda p: qujax.gates.Rz(p * coef))\n", 150 | " circuit_qubit_inds.append([pair[1]])\n", 151 | " circuit_param_inds.append([param_ind])\n", 152 | "\n", 153 | " circuit_gates.append(\"CX\")\n", 154 | " circuit_qubit_inds.append([pair[0], pair[1]])\n", 155 | " circuit_param_inds.append([])\n", 156 | " param_ind += 1" 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "id": "24ce1e3c", 162 | "metadata": {}, 163 | "source": [ 164 | "Let's check the circuit:" 165 | ] 166 | }, 167 | { 168 | "cell_type": "code", 169 | "execution_count": 6, 170 | "id": "c3fb239c", 171 | "metadata": { 172 | "scrolled": false 173 | }, 174 | "outputs": [ 175 | { 176 | "name": "stdout", 177 | "output_type": "stream", 178 | "text": [ 179 | "q0: -----H-----Rx[0]-----◯---------------◯-------------------------------◯---------------◯---------------------Rx[2]-\n", 180 | " | | | | \n", 181 | "q1: -----H-----Rx[0]-----CX---Func[1]----CX------◯---------------◯-------|---------------|-------◯---------------◯---\n", 182 | " | | | | | | \n", 183 | "q2: -----H-----Rx[0]-----------------------------CX---Func[1]----CX------CX---Func[1]----CX------|---------------|---\n", 184 | " | | \n", 185 | "q3: -----H-----Rx[0]-----------------------------------------------------------------------------CX---Func[1]----CX--\n" 186 | ] 187 | } 188 | ], 189 | "source": [ 190 | "qujax.print_circuit(\n", 191 | " circuit_gates, circuit_qubit_inds, circuit_param_inds, gate_ind_max=20\n", 192 | ");" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "id": "8c421d1d", 198 | "metadata": {}, 199 | "source": [ 200 | "Then, we invoke the 
`qujax.get_params_to_statetensor_func`" 201 | ] 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": 7, 206 | "id": "1af80d12", 207 | "metadata": {}, 208 | "outputs": [], 209 | "source": [ 210 | "param_to_st = qujax.get_params_to_statetensor_func(\n", 211 | " circuit_gates, circuit_qubit_inds, circuit_param_inds\n", 212 | ")" 213 | ] 214 | }, 215 | { 216 | "cell_type": "markdown", 217 | "id": "689e7811", 218 | "metadata": {}, 219 | "source": [ 220 | "And we also construct the expectation map using the problem Hamiltonian via qujax: " 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 8, 226 | "id": "0857d4bb", 227 | "metadata": {}, 228 | "outputs": [], 229 | "source": [ 230 | "st_to_expectation = qujax.get_statetensor_to_expectation_func(\n", 231 | " hamiltonian_gates, hamiltonian_qubit_inds, coefficients\n", 232 | ")\n", 233 | "\n", 234 | "param_to_expectation = lambda param: st_to_expectation(param_to_st(param))" 235 | ] 236 | }, 237 | { 238 | "cell_type": "markdown", 239 | "id": "6144236c", 240 | "metadata": {}, 241 | "source": [ 242 | "# Training process\n", 243 | "We construct a function that, given a parameter vector, returns the value of the cost function and the gradient (we also `jit` to avoid recompilation):" 244 | ] 245 | }, 246 | { 247 | "cell_type": "code", 248 | "execution_count": 9, 249 | "id": "069331d5", 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "cost_and_grad = jit(value_and_grad(param_to_expectation))" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "id": "8f8e0741", 259 | "metadata": {}, 260 | "source": [ 261 | "For the training process we'll use vanilla gradient descent with a constant stepsize:" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": 10, 267 | "id": "b28018e2", 268 | "metadata": {}, 269 | "outputs": [], 270 | "source": [ 271 | "seed = 123\n", 272 | "key = random.PRNGKey(seed)\n", 273 | "init_param = random.uniform(key, 
shape=(param_ind,))\n", 274 | "\n", 275 | "n_steps = 150\n", 276 | "stepsize = 0.01" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 11, 282 | "id": "c0f53be5", 283 | "metadata": {}, 284 | "outputs": [ 285 | { 286 | "name": "stdout", 287 | "output_type": "stream", 288 | "text": [ 289 | "Iteration: 1 \tCost: -0.09861969\r", 290 | "Iteration: 2 \tCost: -0.30488127\r", 291 | "Iteration: 3 \tCost: -0.43221825\r", 292 | "Iteration: 4 \tCost: -0.49104044\r", 293 | "Iteration: 5 \tCost: -0.51378363\r", 294 | "Iteration: 6 \tCost: -0.5220106\r", 295 | "Iteration: 7 \tCost: -0.52507806\r", 296 | "Iteration: 8 \tCost: -0.5263982\r", 297 | "Iteration: 9 \tCost: -0.52713084\r", 298 | "Iteration: 10 \tCost: -0.5276594\r", 299 | "Iteration: 11 \tCost: -0.5281093\r", 300 | "Iteration: 12 \tCost: -0.5285234\r", 301 | "Iteration: 13 \tCost: -0.5289158\r", 302 | "Iteration: 14 \tCost: -0.5292916\r", 303 | "Iteration: 15 \tCost: -0.5296537\r", 304 | "Iteration: 16 \tCost: -0.53000385\r", 305 | "Iteration: 17 \tCost: -0.5303429\r", 306 | "Iteration: 18 \tCost: -0.5306716\r", 307 | "Iteration: 19 \tCost: -0.5309909\r", 308 | "Iteration: 20 \tCost: -0.5313021\r", 309 | "Iteration: 21 \tCost: -0.53160506\r", 310 | "Iteration: 22 \tCost: -0.53190106\r", 311 | "Iteration: 23 \tCost: -0.53219026\r", 312 | "Iteration: 24 \tCost: -0.5324728\r", 313 | "Iteration: 25 \tCost: -0.53274995\r", 314 | "Iteration: 26 \tCost: -0.533021\r", 315 | "Iteration: 27 \tCost: -0.5332869\r", 316 | "Iteration: 28 \tCost: -0.53354824\r", 317 | "Iteration: 29 \tCost: -0.5338048\r", 318 | "Iteration: 30 \tCost: -0.53405714\r", 319 | "Iteration: 31 \tCost: -0.5343054\r", 320 | "Iteration: 32 \tCost: -0.53454924\r", 321 | "Iteration: 33 \tCost: -0.5347898\r", 322 | "Iteration: 34 \tCost: -0.53502667\r", 323 | "Iteration: 35 \tCost: -0.5352601\r", 324 | "Iteration: 36 \tCost: -0.53549004\r", 325 | "Iteration: 37 \tCost: -0.53571683\r", 326 | "Iteration: 38 \tCost: -0.5359408\r", 327 | 
"Iteration: 39 \tCost: -0.5361613\r", 328 | "Iteration: 40 \tCost: -0.5363793\r", 329 | "Iteration: 41 \tCost: -0.53659475\r", 330 | "Iteration: 42 \tCost: -0.5368071\r", 331 | "Iteration: 43 \tCost: -0.53701735\r", 332 | "Iteration: 44 \tCost: -0.5372245\r", 333 | "Iteration: 45 \tCost: -0.5374299\r", 334 | "Iteration: 46 \tCost: -0.5376322\r", 335 | "Iteration: 47 \tCost: -0.53783244\r", 336 | "Iteration: 48 \tCost: -0.5380306\r", 337 | "Iteration: 49 \tCost: -0.53822607\r", 338 | "Iteration: 50 \tCost: -0.53841996\r", 339 | "Iteration: 51 \tCost: -0.53861094\r", 340 | "Iteration: 52 \tCost: -0.53880095\r", 341 | "Iteration: 53 \tCost: -0.53898793\r", 342 | "Iteration: 54 \tCost: -0.5391733\r", 343 | "Iteration: 55 \tCost: -0.5393568\r", 344 | "Iteration: 56 \tCost: -0.5395382\r", 345 | "Iteration: 57 \tCost: -0.53971833\r", 346 | "Iteration: 58 \tCost: -0.53989595\r", 347 | "Iteration: 59 \tCost: -0.54007185\r", 348 | "Iteration: 60 \tCost: -0.5402465\r", 349 | "Iteration: 61 \tCost: -0.54041874\r", 350 | "Iteration: 62 \tCost: -0.5405901\r", 351 | "Iteration: 63 \tCost: -0.5407588\r", 352 | "Iteration: 64 \tCost: -0.5409268\r", 353 | "Iteration: 65 \tCost: -0.54109275\r", 354 | "Iteration: 66 \tCost: -0.5412577\r", 355 | "Iteration: 67 \tCost: -0.54141986\r", 356 | "Iteration: 68 \tCost: -0.5415817\r", 357 | "Iteration: 69 \tCost: -0.5417414\r", 358 | "Iteration: 70 \tCost: -0.5419003\r", 359 | "Iteration: 71 \tCost: -0.5420571\r", 360 | "Iteration: 72 \tCost: -0.5422127\r", 361 | "Iteration: 73 \tCost: -0.5423671\r", 362 | "Iteration: 74 \tCost: -0.5425201\r", 363 | "Iteration: 75 \tCost: -0.5426715\r", 364 | "Iteration: 76 \tCost: -0.54282165\r", 365 | "Iteration: 77 \tCost: -0.54297113\r", 366 | "Iteration: 78 \tCost: -0.5431187\r", 367 | "Iteration: 79 \tCost: -0.543265\r", 368 | "Iteration: 80 \tCost: -0.5434102\r", 369 | "Iteration: 81 \tCost: -0.54355454\r", 370 | "Iteration: 82 \tCost: -0.54369736\r", 371 | "Iteration: 83 \tCost: -0.5438389\r", 372 | 
"Iteration: 84 \tCost: -0.5439794\r", 373 | "Iteration: 85 \tCost: -0.54411876\r", 374 | "Iteration: 86 \tCost: -0.54425716\r", 375 | "Iteration: 87 \tCost: -0.5443945\r", 376 | "Iteration: 88 \tCost: -0.5445305\r", 377 | "Iteration: 89 \tCost: -0.5446654\r", 378 | "Iteration: 90 \tCost: -0.5447996\r", 379 | "Iteration: 91 \tCost: -0.5449326\r", 380 | "Iteration: 92 \tCost: -0.54506487\r", 381 | "Iteration: 93 \tCost: -0.54519576\r", 382 | "Iteration: 94 \tCost: -0.5453257\r", 383 | "Iteration: 95 \tCost: -0.5454546\r", 384 | "Iteration: 96 \tCost: -0.545583\r", 385 | "Iteration: 97 \tCost: -0.5457101\r", 386 | "Iteration: 98 \tCost: -0.54583657\r", 387 | "Iteration: 99 \tCost: -0.5459618\r", 388 | "Iteration: 100 \tCost: -0.54608655\r", 389 | "Iteration: 101 \tCost: -0.54621017\r", 390 | "Iteration: 102 \tCost: -0.54633296\r", 391 | "Iteration: 103 \tCost: -0.54645467\r", 392 | "Iteration: 104 \tCost: -0.54657626\r", 393 | "Iteration: 105 \tCost: -0.54669654\r", 394 | "Iteration: 106 \tCost: -0.54681575\r", 395 | "Iteration: 107 \tCost: -0.5469348\r", 396 | "Iteration: 108 \tCost: -0.5470527\r", 397 | "Iteration: 109 \tCost: -0.54716986\r", 398 | "Iteration: 110 \tCost: -0.54728657\r", 399 | "Iteration: 111 \tCost: -0.5474022\r", 400 | "Iteration: 112 \tCost: -0.5475172\r", 401 | "Iteration: 113 \tCost: -0.5476318\r", 402 | "Iteration: 114 \tCost: -0.5477449\r", 403 | "Iteration: 115 \tCost: -0.5478576\r", 404 | "Iteration: 116 \tCost: -0.54797024\r", 405 | "Iteration: 117 \tCost: -0.5480816\r", 406 | "Iteration: 118 \tCost: -0.548192\r", 407 | "Iteration: 119 \tCost: -0.54830235\r", 408 | "Iteration: 120 \tCost: -0.5484118\r", 409 | "Iteration: 121 \tCost: -0.54852045\r", 410 | "Iteration: 122 \tCost: -0.548629\r", 411 | "Iteration: 123 \tCost: -0.5487364\r", 412 | "Iteration: 124 \tCost: -0.5488433\r", 413 | "Iteration: 125 \tCost: -0.5489498\r", 414 | "Iteration: 126 \tCost: -0.54905534\r", 415 | "Iteration: 127 \tCost: -0.5491607\r", 416 | "Iteration: 128 
\tCost: -0.5492652\r", 417 | "Iteration: 129 \tCost: -0.54936945\r", 418 | "Iteration: 130 \tCost: -0.54947305\r", 419 | "Iteration: 131 \tCost: -0.54957575\r", 420 | "Iteration: 132 \tCost: -0.5496777\r", 421 | "Iteration: 133 \tCost: -0.5497798\r", 422 | "Iteration: 134 \tCost: -0.5498811\r", 423 | "Iteration: 135 \tCost: -0.5499818\r", 424 | "Iteration: 136 \tCost: -0.5500822\r", 425 | "Iteration: 137 \tCost: -0.5501819\r", 426 | "Iteration: 138 \tCost: -0.550281\r", 427 | "Iteration: 139 \tCost: -0.55037963\r", 428 | "Iteration: 140 \tCost: -0.55047804\r", 429 | "Iteration: 141 \tCost: -0.5505761\r", 430 | "Iteration: 142 \tCost: -0.55067325\r", 431 | "Iteration: 143 \tCost: -0.5507698\r", 432 | "Iteration: 144 \tCost: -0.55086654\r", 433 | "Iteration: 145 \tCost: -0.5509624\r", 434 | "Iteration: 146 \tCost: -0.5510575\r", 435 | "Iteration: 147 \tCost: -0.55115235\r", 436 | "Iteration: 148 \tCost: -0.55124706\r", 437 | "Iteration: 149 \tCost: -0.5513411\r" 438 | ] 439 | } 440 | ], 441 | "source": [ 442 | "param = init_param\n", 443 | "\n", 444 | "cost_vals = jnp.zeros(n_steps)\n", 445 | "cost_vals = cost_vals.at[0].set(param_to_expectation(init_param))\n", 446 | "\n", 447 | "for step in range(1, n_steps):\n", 448 | " cost_val, cost_grad = cost_and_grad(param)\n", 449 | " cost_vals = cost_vals.at[step].set(cost_val)\n", 450 | " param = param - stepsize * cost_grad\n", 451 | " print(\"Iteration:\", step, \"\\tCost:\", cost_val, end=\"\\r\")" 452 | ] 453 | }, 454 | { 455 | "cell_type": "markdown", 456 | "id": "b9305afc", 457 | "metadata": {}, 458 | "source": [ 459 | "Let's visualise the gradient descent" 460 | ] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": 12, 465 | "id": "114ed37d", 466 | "metadata": {}, 467 | "outputs": [ 468 | { 469 | "data": { 470 | "image/png": "\n", 471 | "text/plain": [ 472 | "
" 473 | ] 474 | }, 475 | "metadata": { 476 | "needs_background": "light" 477 | }, 478 | "output_type": "display_data" 479 | } 480 | ], 481 | "source": [ 482 | "plt.plot(cost_vals)\n", 483 | "plt.xlabel(\"Iteration\")\n", 484 | "plt.ylabel(\"Cost\");" 485 | ] 486 | } 487 | ], 488 | "metadata": { 489 | "kernelspec": { 490 | "display_name": "Python 3 (ipykernel)", 491 | "language": "python", 492 | "name": "python3" 493 | }, 494 | "language_info": { 495 | "codemirror_mode": { 496 | "name": "ipython", 497 | "version": 3 498 | }, 499 | "file_extension": ".py", 500 | "mimetype": "text/x-python", 501 | "name": "python", 502 | "nbconvert_exporter": "python", 503 | "pygments_lexer": "ipython3", 504 | "version": "3.9.10" 505 | } 506 | }, 507 | "nbformat": 4, 508 | "nbformat_minor": 5 509 | } 510 | -------------------------------------------------------------------------------- /qujax/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simulating quantum circuits with JAX 3 | """ 4 | 5 | from qujax.version import __version__ 6 | 7 | from qujax import gates 8 | 9 | from qujax.statetensor import all_zeros_statetensor 10 | from qujax.statetensor import apply_gate 11 | from qujax.statetensor import get_params_to_statetensor_func 12 | from qujax.statetensor import get_params_to_unitarytensor_func 13 | 14 | from qujax.statetensor_observable import statetensor_to_single_expectation 15 | from qujax.statetensor_observable import get_statetensor_to_expectation_func 16 | from qujax.statetensor_observable import get_statetensor_to_sampled_expectation_func 17 | 18 | from qujax.densitytensor import all_zeros_densitytensor 19 | from qujax.densitytensor import _kraus_single 20 | from qujax.densitytensor import kraus 21 | from qujax.densitytensor import get_params_to_densitytensor_func 22 | from qujax.densitytensor import partial_trace 23 | 24 | from qujax.densitytensor_observable import densitytensor_to_single_expectation 25 | from 
qujax.densitytensor_observable import get_densitytensor_to_expectation_func 26 | from qujax.densitytensor_observable import get_densitytensor_to_sampled_expectation_func 27 | from qujax.densitytensor_observable import densitytensor_to_measurement_probabilities 28 | from qujax.densitytensor_observable import densitytensor_to_measured_densitytensor 29 | 30 | from qujax.utils import check_unitary 31 | from qujax.utils import check_hermitian 32 | from qujax.utils import check_circuit 33 | from qujax.utils import print_circuit 34 | from qujax.utils import integers_to_bitstrings 35 | from qujax.utils import bitstrings_to_integers 36 | from qujax.utils import repeat_circuit 37 | from qujax.utils import sample_integers 38 | from qujax.utils import sample_bitstrings 39 | from qujax.utils import statetensor_to_densitytensor 40 | 41 | import qujax.typing 42 | 43 | # pylint: disable=undefined-variable 44 | del version 45 | del statetensor 46 | del statetensor_observable 47 | del densitytensor 48 | del densitytensor_observable 49 | del utils 50 | -------------------------------------------------------------------------------- /qujax/densitytensor.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Iterable, Sequence, Tuple, Union, Optional 4 | 5 | import jax 6 | from jax import numpy as jnp 7 | from jax.typing import ArrayLike 8 | from jax.lax import scan 9 | from jax._src.dtypes import canonicalize_dtype 10 | from jax._src.typing import DTypeLike 11 | from qujax.statetensor import ( 12 | _arrayify_inds, 13 | _gate_func_to_unitary, 14 | _to_gate_func, 15 | apply_gate, 16 | ) 17 | from qujax.utils import check_circuit 18 | from qujax.typing import ( 19 | MixedCircuitFunction, 20 | KrausOp, 21 | GateFunction, 22 | GateParameterIndices, 23 | ) 24 | 25 | 26 | def _kraus_single( 27 | densitytensor: jax.Array, array: jax.Array, qubit_inds: Sequence[int] 28 | ) -> jax.Array: 29 | r""" 30 
| Performs single Kraus operation 31 | 32 | .. math:: 33 | 34 | \rho_\text{out} = B \rho_\text{in} B^{\dagger} 35 | 36 | Args: 37 | densitytensor: Input density matrix of shape=(2, 2, ...) and ndim=2*n_qubits 38 | array: Array containing the Kraus operator (in tensor form). 39 | qubit_inds: Sequence of qubit indices on which to apply the Kraus operation. 40 | 41 | Returns: 42 | Updated density matrix. 43 | """ 44 | n_qubits = densitytensor.ndim // 2 45 | densitytensor = apply_gate(densitytensor, array, qubit_inds) 46 | densitytensor = apply_gate( 47 | densitytensor, array.conj(), [n_qubits + i for i in qubit_inds] 48 | ) 49 | return densitytensor 50 | 51 | 52 | def kraus( 53 | densitytensor: jax.Array, arrays: Iterable[jax.Array], qubit_inds: Sequence[int] 54 | ) -> jax.Array: 55 | r""" 56 | Performs Kraus operation. 57 | 58 | .. math:: 59 | \rho_\text{out} = \sum_i B_i \rho_\text{in} B_i^{\dagger} 60 | 61 | Args: 62 | densitytensor: Input density matrix of shape=(2, 2, ...) and ndim=2*n_qubits 63 | arrays: Sequence of arrays containing the Kraus operators (in tensor form). 64 | qubit_inds: Sequence of qubit indices on which to apply the Kraus operation. 65 | 66 | Returns: 67 | Updated density matrix. 
68 | """ 69 | arrays = jnp.array(arrays) 70 | if arrays.ndim % 2 == 0: 71 | arrays = arrays[jnp.newaxis] 72 | # ensure first dimensions indexes different kraus operators 73 | arrays = arrays.reshape((arrays.shape[0],) + (2,) * 2 * len(qubit_inds)) 74 | 75 | new_densitytensor, _ = scan( 76 | lambda dt, arr: (dt + _kraus_single(densitytensor, arr, qubit_inds), None), 77 | init=jnp.zeros_like(densitytensor) * 0.0j, 78 | xs=arrays, 79 | ) 80 | # new_densitytensor = vmap(_kraus_single, in_axes=(None, 0, None))( 81 | # densitytensor, arrays, qubit_inds 82 | # ).sum(0) 83 | return new_densitytensor 84 | 85 | 86 | def _to_kraus_operator_seq_funcs( 87 | kraus_op: KrausOp, 88 | param_inds: Optional[Union[GateParameterIndices, Sequence[GateParameterIndices]]], 89 | ) -> Tuple[Sequence[GateFunction], Sequence[jax.Array]]: 90 | """ 91 | Ensures Kraus operators are a sequence of functions that map (possibly empty) parameters to 92 | tensors and that each element of param_inds_seq is a sequence of arrays that correspond to the 93 | parameter indices of each Kraus operator. 94 | 95 | Args: 96 | kraus_op: Either a normal Gate or a sequence of Gates representing Kraus operators. 
97 | param_inds: If kraus_op is a normal Gate then a sequence of parameter indices, 98 | if kraus_op is a sequence of Kraus operators then a sequence of sequences of 99 | parameter indices 100 | 101 | Returns: 102 | Tuple containing sequence of functions mapping to Kraus operators 103 | and sequence of arrays with parameter indices 104 | 105 | """ 106 | if isinstance(kraus_op, (list, tuple)): 107 | kraus_op_funcs = [_to_gate_func(ko) for ko in kraus_op] 108 | if param_inds is None: 109 | param_inds = [None for _ in kraus_op] 110 | elif isinstance(kraus_op, (str, jax.Array)) or callable(kraus_op): 111 | kraus_op_funcs = [_to_gate_func(kraus_op)] 112 | param_inds = [param_inds] 113 | else: 114 | raise ValueError(f"Invalid Kraus operator specification: {kraus_op}") 115 | return kraus_op_funcs, _arrayify_inds(param_inds) 116 | 117 | 118 | def partial_trace( 119 | densitytensor: jax.Array, indices_to_trace: Sequence[int] 120 | ) -> jax.Array: 121 | """ 122 | Traces out (discards) specified qubits, resulting in a densitytensor 123 | representing the mixed quantum state on the remaining qubits. 124 | 125 | Args: 126 | densitytensor: Input densitytensor. 127 | indices_to_trace: Indices of qubits to trace out/discard. 128 | 129 | Returns: 130 | Resulting densitytensor on remaining qubits. 131 | 132 | """ 133 | n_qubits = densitytensor.ndim // 2 134 | einsum_indices = list(range(densitytensor.ndim)) 135 | for i in indices_to_trace: 136 | einsum_indices[i + n_qubits] = einsum_indices[i] 137 | densitytensor = jnp.einsum(densitytensor, einsum_indices) 138 | return densitytensor 139 | 140 | 141 | def all_zeros_densitytensor(n_qubits: int, dtype: DTypeLike = complex) -> jax.Array: 142 | """ 143 | Returns a densitytensor representation of the all-zeros state |00...0> on `n_qubits` qubits 144 | 145 | Args: 146 | n_qubits: Number of qubits that the state is defined on. 147 | dtype: Data type of the densitytensor returned. 
148 | 149 | Returns: 150 | Densitytensor representing the state having all qubits set to zero. 151 | """ 152 | densitytensor = jnp.zeros((2,) * 2 * n_qubits, canonicalize_dtype(dtype)) 153 | densitytensor = densitytensor.at[(0,) * 2 * n_qubits].set(1.0) 154 | return densitytensor 155 | 156 | 157 | def get_params_to_densitytensor_func( 158 | kraus_ops_seq: Sequence[KrausOp], 159 | qubit_inds_seq: Sequence[Sequence[int]], 160 | param_inds_seq: Sequence[ 161 | Union[GateParameterIndices, Sequence[GateParameterIndices]] 162 | ], 163 | n_qubits: Optional[int] = None, 164 | ) -> MixedCircuitFunction: 165 | """ 166 | Creates a function that maps circuit parameters to a density tensor (a density matrix in 167 | tensor form). 168 | densitytensor = densitymatrix.reshape((2,) * 2 * n_qubits) 169 | densitymatrix = densitytensor.reshape(2 ** n_qubits, 2 ** n_qubits) 170 | 171 | Args: 172 | kraus_ops_seq: Sequence of gates. 173 | Each element is either a string matching a unitary array or function in qujax.gates, 174 | a custom unitary array or a custom function taking parameters and returning a unitary 175 | array. Unitary arrays will be reshaped into tensor form (2, 2,...) 176 | qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are 177 | acting on. 178 | i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the 179 | zeroth qubit, the second gate is a two qubit gate acting on the zeroth and 180 | first qubit etc. 181 | param_inds_seq: Sequence of sequences representing parameter indices that gates are using, 182 | i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter 183 | (the float at position zero in the parameter vector/array), the second gate is not 184 | parameterised and the third gate uses the parameters at position five and two. 185 | n_qubits: Number of qubits, if fixed. 186 | 187 | Returns: 188 | Function which maps parameters (and optional densitytensor_in) to a densitytensor. 
189 | If no parameters are found then the function only takes optional densitytensor_in. 190 | 191 | """ 192 | 193 | check_circuit(kraus_ops_seq, qubit_inds_seq, param_inds_seq, n_qubits, False) 194 | 195 | if n_qubits is None: 196 | n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1 197 | 198 | kraus_ops_seq_callable_and_param_inds = [ 199 | _to_kraus_operator_seq_funcs(ko, param_inds) 200 | for ko, param_inds in zip(kraus_ops_seq, param_inds_seq) 201 | ] 202 | kraus_ops_seq_callable = [ 203 | ko_pi[0] for ko_pi in kraus_ops_seq_callable_and_param_inds 204 | ] 205 | param_inds_array_seq = [ko_pi[1] for ko_pi in kraus_ops_seq_callable_and_param_inds] 206 | 207 | def params_to_densitytensor_func( 208 | params: ArrayLike, densitytensor_in: Optional[jax.Array] = None 209 | ) -> jax.Array: 210 | """ 211 | Applies parameterised circuit (series of gates) to a densitytensor_in 212 | (default is |0>^N <0|^N). 213 | 214 | Args: 215 | params: Parameters of the circuit. 216 | densitytensor_in: Optional. Input densitytensor. 217 | Defaults to |0>^N <0|^N (tensor of size 2^(2*N) with all zeroes except one in 218 | the [0]*(2*N) index). 219 | 220 | Returns: 221 | Updated densitytensor. 222 | 223 | """ 224 | if densitytensor_in is None: 225 | densitytensor = all_zeros_densitytensor(n_qubits) 226 | else: 227 | densitytensor = densitytensor_in 228 | params = jnp.atleast_1d(params) 229 | # Guarantee `params` has the right type for type-checking purposes 230 | if not isinstance(params, jax.Array): 231 | raise ValueError("This should not happen. 
Please open an issue on GitHub.") 232 | for gate_func_single_seq, qubit_inds, param_inds_single_seq in zip( 233 | kraus_ops_seq_callable, qubit_inds_seq, param_inds_array_seq 234 | ): 235 | kraus_operators = [ 236 | _gate_func_to_unitary(gf, qubit_inds, pi, params) 237 | for gf, pi in zip(gate_func_single_seq, param_inds_single_seq) 238 | ] 239 | densitytensor = kraus(densitytensor, kraus_operators, qubit_inds) 240 | return densitytensor 241 | 242 | non_parameterised = all( 243 | [all([pi.size == 0 for pi in pi_seq]) for pi_seq in param_inds_array_seq] 244 | ) 245 | if non_parameterised: 246 | 247 | def no_params_to_densitytensor_func( 248 | densitytensor_in: Optional[jax.Array] = None, 249 | ) -> jax.Array: 250 | """ 251 | Applies circuit (series of gates with no parameters) to a densitytensor_in 252 | (default is |0>^N <0|^N). 253 | 254 | Args: 255 | densitytensor_in: Optional. Input densitytensor. 256 | Defaults to |0>^N <0|^N (tensor of size 2^(2*N) with all zeroes except one in 257 | the [0]*(2*N) index). 258 | 259 | Returns: 260 | Updated densitytensor. 
261 | 262 | """ 263 | return params_to_densitytensor_func(jnp.array([]), densitytensor_in) 264 | 265 | return no_params_to_densitytensor_func 266 | 267 | return params_to_densitytensor_func 268 | -------------------------------------------------------------------------------- /qujax/densitytensor_observable.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Callable, Sequence, Union 4 | 5 | import jax 6 | from jax.typing import ArrayLike 7 | from jax import numpy as jnp 8 | from jax import random 9 | 10 | from qujax.densitytensor import _kraus_single, partial_trace 11 | from qujax.statetensor_observable import _get_tensor_to_expectation_func, sample_probs 12 | from qujax.utils import bitstrings_to_integers, check_hermitian 13 | 14 | 15 | def densitytensor_to_single_expectation( 16 | densitytensor: jax.Array, hermitian: jax.Array, qubit_inds: Sequence[int] 17 | ) -> jax.Array: 18 | """ 19 | Evaluates expectation value of an observable represented by a Hermitian matrix (in tensor form). 20 | 21 | Args: 22 | densitytensor: Input densitytensor. 23 | hermitian: Hermitian matrix representing observable 24 | must be in tensor form with shape (2,2,...). 25 | qubit_inds: Sequence of qubit indices for Hermitian matrix to be applied to. 26 | Must have 2 * len(qubit_inds) == hermitian.ndim 27 | Returns: 28 | Expected value (float). 
29 | """ 30 | n_qubits = densitytensor.ndim // 2 31 | dt_indices = 2 * list(range(n_qubits)) 32 | hermitian_indices = [i + densitytensor.ndim // 2 for i in range(hermitian.ndim)] 33 | for n, q in enumerate(qubit_inds): 34 | dt_indices[q] = hermitian_indices[n + len(qubit_inds)] 35 | dt_indices[q + n_qubits] = hermitian_indices[n] 36 | return jnp.einsum(densitytensor, dt_indices, hermitian, hermitian_indices).real 37 | 38 | 39 | def get_densitytensor_to_expectation_func( 40 | hermitian_seq_seq: Sequence[Sequence[Union[str, jax.Array]]], 41 | qubits_seq_seq: Sequence[Sequence[int]], 42 | coefficients: Union[Sequence[float], jax.Array], 43 | ) -> Callable[[jax.Array], float]: 44 | """ 45 | Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and 46 | a list of coefficients and returns a function that converts a densitytensor into an 47 | expected value. 48 | 49 | Args: 50 | hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. 51 | Each Hermitian matrix is either represented by a tensor (jax.Array) 52 | or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. 53 | E.g. [['Z', 'Z'], ['X']] 54 | qubits_seq_seq: Sequence of sequences of integer qubit indices. 55 | E.g. [[0,1], [2]] 56 | coefficients: Sequence of float coefficients to scale the expected values. 57 | 58 | Returns: 59 | Function that takes densitytensor and returns expected value (float). 
60 | """ 61 | 62 | return _get_tensor_to_expectation_func( 63 | hermitian_seq_seq, 64 | qubits_seq_seq, 65 | coefficients, 66 | densitytensor_to_single_expectation, 67 | ) 68 | 69 | 70 | def get_densitytensor_to_sampled_expectation_func( 71 | hermitian_seq_seq: Sequence[Sequence[Union[str, jax.Array]]], 72 | qubits_seq_seq: Sequence[Sequence[int]], 73 | coefficients: Union[Sequence[float], jax.Array], 74 | ) -> Callable[[jax.Array, random.PRNGKeyArray, int], float]: 75 | """ 76 | Converts strings (or arrays) representing Hermitian matrices, qubit indices and 77 | coefficients into a function that converts a densitytensor into a sampled expected value. 78 | 79 | On a quantum device, measurements are always taken in the computational basis, as such 80 | sampled expectation values should be taken with respect to an observable that commutes 81 | with the Pauli Z - a warning will be raised if it does not. 82 | 83 | qujax applies an importance sampling heuristic for sampled expectation values that only 84 | reflects the physical notion of measurement in the case that the observable commutes with Z. 85 | In the case that it does not, the expectation value will still be asymptotically unbiased 86 | but not representative of an experiment on a real quantum device. 87 | 88 | Args: 89 | hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. 90 | Each Hermitian is either a tensor (jax.Array) or a string in ('X', 'Y', 'Z'). 91 | E.g. [['Z', 'Z'], ['X']] 92 | qubits_seq_seq: Sequence of sequences of integer qubit indices. 93 | E.g. [[0,1], [2]] 94 | coefficients: Sequence of float coefficients to scale the expected values. 95 | 96 | Returns: 97 | Function that takes densitytensor, random key and integer number of shots 98 | and returns sampled expected value (float). 
99 | """ 100 | densitytensor_to_expectation_func = get_densitytensor_to_expectation_func( 101 | hermitian_seq_seq, qubits_seq_seq, coefficients 102 | ) 103 | 104 | for hermitian_seq in hermitian_seq_seq: 105 | for h in hermitian_seq: 106 | check_hermitian(h, check_z_commutes=True) 107 | 108 | def densitytensor_to_sampled_expectation_func( 109 | densitytensor: jax.Array, random_key: random.PRNGKeyArray, n_samps: int 110 | ) -> float: 111 | """ 112 | Maps densitytensor to sampled expected value. 113 | 114 | Args: 115 | densitytensor: Input densitytensor. 116 | random_key: JAX random key 117 | n_samps: Number of samples contributing to sampled expectation. 118 | 119 | Returns: 120 | Sampled expected value (float). 121 | 122 | """ 123 | n_qubits = densitytensor.ndim // 2 124 | dm = densitytensor.reshape((2**n_qubits, 2**n_qubits)) 125 | measure_probs = jnp.diag(dm).real 126 | sampled_probs = sample_probs(measure_probs, random_key, n_samps) 127 | iweights = jnp.sqrt(sampled_probs / measure_probs) 128 | return densitytensor_to_expectation_func( 129 | densitytensor * jnp.outer(iweights, iweights).reshape(densitytensor.shape) 130 | ) 131 | 132 | return densitytensor_to_sampled_expectation_func 133 | 134 | 135 | def densitytensor_to_measurement_probabilities( 136 | densitytensor: jax.Array, qubit_inds: Sequence[int] 137 | ) -> jax.Array: 138 | """ 139 | Extract array of measurement probabilities given a densitytensor and some qubit indices to 140 | measure (in the computational basis). 141 | I.e. the ith element of the array corresponds to the probability of observing the bitstring 142 | represented by the integer i on the measured qubits. 143 | 144 | Args: 145 | densitytensor: Input densitytensor. 146 | qubit_inds: Sequence of qubit indices to measure. 147 | 148 | Returns: 149 | Normalised array of measurement probabilities. 
150 | """ 151 | n_qubits = densitytensor.ndim // 2 152 | n_qubits_measured = len(qubit_inds) 153 | qubit_inds_trace_out = [i for i in range(n_qubits) if i not in qubit_inds] 154 | return jnp.diag( 155 | partial_trace(densitytensor, qubit_inds_trace_out).reshape( 156 | 2 * n_qubits_measured, 2 * n_qubits_measured 157 | ) 158 | ).real 159 | 160 | 161 | def densitytensor_to_measured_densitytensor( 162 | densitytensor: jax.Array, 163 | qubit_inds: Sequence[int], 164 | measurement: ArrayLike, 165 | ) -> jax.Array: 166 | """ 167 | Returns the post-measurement densitytensor assuming that qubit_inds are measured 168 | (in the computational basis) and the given measurement (integer or bitstring) is observed. 169 | 170 | Args: 171 | densitytensor: Input densitytensor. 172 | qubit_inds: Sequence of qubit indices to measure. 173 | measurement: Observed integer or bitstring. 174 | 175 | Returns: 176 | Post-measurement densitytensor (same shape as input densitytensor). 177 | """ 178 | measurement = jnp.array(measurement) 179 | measured_int = ( 180 | bitstrings_to_integers(measurement) if measurement.ndim == 1 else measurement 181 | ) 182 | 183 | n_qubits = densitytensor.ndim // 2 184 | n_qubits_measured = len(qubit_inds) 185 | qubit_inds_projector = jnp.diag( 186 | jnp.zeros(2**n_qubits_measured).at[measured_int].set(1) 187 | ).reshape((2,) * 2 * n_qubits_measured) 188 | unnorm_densitytensor = _kraus_single( 189 | densitytensor, qubit_inds_projector, qubit_inds 190 | ) 191 | norm_const = jnp.trace(unnorm_densitytensor.reshape(2**n_qubits, 2**n_qubits)).real 192 | return unnorm_densitytensor / norm_const 193 | -------------------------------------------------------------------------------- /qujax/gates.py: -------------------------------------------------------------------------------- 1 | import jax 2 | from jax import numpy as jnp 3 | 4 | I = jnp.eye(2) 5 | 6 | _0 = jnp.zeros((2, 2)) 7 | 8 | X = jnp.array([[0.0, 1.0], [1.0, 0.0]]) 9 | 10 | Y = jnp.array([[0.0, -1.0j], [1.0j, 
0.0]]) 11 | 12 | Z = jnp.array([[1.0, 0.0], [0.0, -1.0]]) 13 | 14 | H = jnp.array([[1.0, 1.0], [1.0, -1]]) / jnp.sqrt(2) 15 | 16 | S = jnp.array([[1.0, 0.0], [0.0, 1.0j]]) 17 | 18 | Sdg = jnp.array([[1.0, 0.0], [0.0, -1.0j]]) 19 | 20 | T = jnp.array([[1.0, 0.0], [0.0, jnp.exp(jnp.pi * 1.0j / 4)]]) 21 | 22 | Tdg = jnp.array([[1.0, 0.0], [0.0, jnp.exp(-jnp.pi * 1.0j / 4)]]) 23 | 24 | V = jnp.array([[1.0, -1.0j], [-1.0j, 1.0]]) / jnp.sqrt(2) 25 | 26 | Vdg = jnp.array([[1.0, 1.0j], [1.0j, 1.0]]) / jnp.sqrt(2) 27 | 28 | SX = jnp.array([[1.0 + 1.0j, 1.0 - 1.0j], [1.0 - 1.0j, 1.0 + 1.0j]]) / 2 29 | 30 | SXdg = jnp.array([[1.0 - 1.0j, 1.0 + 1.0j], [1.0 + 1.0j, 1.0 - 1.0j]]) / 2 31 | 32 | CX = jnp.block([[I, _0], [_0, X]]).reshape((2,) * 4) 33 | 34 | CY = jnp.block([[I, _0], [_0, Y]]).reshape((2,) * 4) 35 | 36 | CZ = jnp.block([[I, _0], [_0, Z]]).reshape((2,) * 4) 37 | 38 | CH = jnp.block([[I, _0], [_0, H]]).reshape((2,) * 4) 39 | 40 | CV = jnp.block([[I, _0], [_0, V]]).reshape((2,) * 4) 41 | 42 | CVdg = jnp.block([[I, _0], [_0, Vdg]]).reshape((2,) * 4) 43 | 44 | CSX = jnp.block([[I, _0], [_0, SX]]).reshape((2,) * 4) 45 | 46 | CSXdg = jnp.block([[I, _0], [_0, SXdg]]).reshape((2,) * 4) 47 | 48 | CCX = jnp.block( 49 | [[I, _0, _0, _0], [_0, I, _0, _0], [_0, _0, I, _0], [_0, _0, _0, X]] # Toffoli gate 50 | ).reshape((2,) * 6) 51 | 52 | ECR = jnp.block([[_0, Vdg], [V, _0]]).reshape((2,) * 4) 53 | 54 | SWAP = jnp.array( 55 | [ 56 | [1.0, 0.0, 0.0, 0.0], 57 | [0.0, 0.0, 1.0, 0.0], 58 | [0.0, 1.0, 0.0, 0.0], 59 | [0.0, 0.0, 0.0, 1], 60 | ] 61 | ) 62 | 63 | CSWAP = jnp.block([[jnp.eye(4), jnp.zeros((4, 4))], [jnp.zeros((4, 4)), SWAP]]).reshape( 64 | (2,) * 6 65 | ) 66 | 67 | 68 | def Rx(param: float) -> jax.Array: 69 | param_pi_2 = param * jnp.pi / 2 70 | return jnp.cos(param_pi_2) * I - jnp.sin(param_pi_2) * X * 1.0j 71 | 72 | 73 | def Ry(param: float) -> jax.Array: 74 | param_pi_2 = param * jnp.pi / 2 75 | return jnp.cos(param_pi_2) * I - jnp.sin(param_pi_2) * Y * 1.0j 76 | 77 | 
78 | def Rz(param: float) -> jax.Array: 79 | param_pi_2 = param * jnp.pi / 2 80 | return jnp.cos(param_pi_2) * I - jnp.sin(param_pi_2) * Z * 1.0j 81 | 82 | 83 | def CRx(param: float) -> jax.Array: 84 | return jnp.block([[I, _0], [_0, Rx(param)]]).reshape((2,) * 4) 85 | 86 | 87 | def CRy(param: float) -> jax.Array: 88 | return jnp.block([[I, _0], [_0, Ry(param)]]).reshape((2,) * 4) 89 | 90 | 91 | def CRz(param: float) -> jax.Array: 92 | return jnp.block([[I, _0], [_0, Rz(param)]]).reshape((2,) * 4) 93 | 94 | 95 | def U1(param: float) -> jax.Array: 96 | return U3(0, 0, param) 97 | 98 | 99 | def U2(param0: float, param1: float) -> jax.Array: 100 | return U3(0.5, param0, param1) 101 | 102 | 103 | def U3(param0: float, param1: float, param2: float) -> jax.Array: 104 | return ( 105 | jnp.exp((param1 + param2) * jnp.pi * 1.0j / 2) 106 | * Rz(param1) 107 | @ Ry(param0) 108 | @ Rz(param2) 109 | ) 110 | 111 | 112 | def CU1(param: float) -> jax.Array: 113 | return jnp.block([[I, _0], [_0, U1(param)]]).reshape((2,) * 4) 114 | 115 | 116 | def CU2(param0: float, param1: float) -> jax.Array: 117 | return jnp.block([[I, _0], [_0, U2(param0, param1)]]).reshape((2,) * 4) 118 | 119 | 120 | def CU3(param0: float, param1: float, param2: float) -> jax.Array: 121 | return jnp.block([[I, _0], [_0, U3(param0, param1, param2)]]).reshape((2,) * 4) 122 | 123 | 124 | def ISWAP(param: float) -> jax.Array: 125 | param_pi_2 = param * jnp.pi / 2 126 | c = jnp.cos(param_pi_2) 127 | i_s = 1.0j * jnp.sin(param_pi_2) 128 | return jnp.array( 129 | [ 130 | [1.0, 0.0, 0.0, 0.0], 131 | [0.0, c, i_s, 0.0], 132 | [0.0, i_s, c, 0.0], 133 | [0.0, 0.0, 0.0, 1.0], 134 | ] 135 | ).reshape((2,) * 4) 136 | 137 | 138 | def PhasedISWAP(param0: float, param1: float) -> jax.Array: 139 | param1_pi_2 = param1 * jnp.pi / 2 140 | c = jnp.cos(param1_pi_2) 141 | i_s = 1.0j * jnp.sin(param1_pi_2) 142 | return jnp.array( 143 | [ 144 | [1.0, 0.0, 0.0, 0.0], 145 | [0.0, c, i_s * jnp.exp(2.0j * jnp.pi * param0), 0.0], 146 | 
[0.0, i_s * jnp.exp(-2.0j * jnp.pi * param0), c, 0.0], 147 | [0.0, 0.0, 0.0, 1.0], 148 | ] 149 | ).reshape((2,) * 4) 150 | 151 | 152 | def XXPhase(param: float) -> jax.Array: 153 | param_pi_2 = param * jnp.pi / 2 154 | c = jnp.cos(param_pi_2) 155 | i_s = 1.0j * jnp.sin(param_pi_2) 156 | return jnp.array( 157 | [ 158 | [c, 0.0, 0.0, -i_s], 159 | [0.0, c, -i_s, 0.0], 160 | [0.0, -i_s, c, 0.0], 161 | [-i_s, 0.0, 0.0, c], 162 | ] 163 | ).reshape((2,) * 4) 164 | 165 | 166 | def YYPhase(param: float) -> jax.Array: 167 | param_pi_2 = param * jnp.pi / 2 168 | c = jnp.cos(param_pi_2) 169 | i_s = 1.0j * jnp.sin(param_pi_2) 170 | return jnp.array( 171 | [ 172 | [c, 0.0, 0.0, i_s], 173 | [0.0, c, -i_s, 0.0], 174 | [0.0, -i_s, c, 0.0], 175 | [i_s, 0.0, 0.0, c], 176 | ] 177 | ).reshape((2,) * 4) 178 | 179 | 180 | def ZZPhase(param: float) -> jax.Array: 181 | param_pi_2 = param * jnp.pi / 2 182 | e_m = jnp.exp(-1.0j * param_pi_2) 183 | e_p = jnp.exp(1.0j * param_pi_2) 184 | return jnp.diag(jnp.array([e_m, e_p, e_p, e_m])).reshape((2,) * 4) 185 | 186 | 187 | ZZMax = ZZPhase(0.5) 188 | 189 | 190 | def PhasedX(param0: float, param1: float) -> jax.Array: 191 | return Rz(param1) @ Rx(param0) @ Rz(-param1) 192 | -------------------------------------------------------------------------------- /qujax/statetensor.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from functools import partial 4 | from typing import Callable, Sequence, Optional 5 | 6 | import jax 7 | from jax import numpy as jnp 8 | from jax.typing import ArrayLike 9 | from jax._src.dtypes import canonicalize_dtype 10 | from jax._src.typing import DTypeLike 11 | 12 | from qujax import gates 13 | from qujax.utils import _arrayify_inds, check_circuit 14 | 15 | from qujax.typing import Gate, PureCircuitFunction, GateFunction, GateParameterIndices 16 | 17 | 18 | def apply_gate( 19 | statetensor: jax.Array, gate_unitary: jax.Array, qubit_inds: 
Sequence[int] 20 | ) -> jax.Array: 21 | """ 22 | Applies gate to statetensor and returns updated statetensor. 23 | Gate is represented by a unitary matrix in tensor form. 24 | 25 | Args: 26 | statetensor: Input statetensor. 27 | gate_unitary: Unitary array representing gate 28 | must be in tensor form with shape (2,2,...). 29 | qubit_inds: Sequence of indices for gate to be applied to. 30 | Must have 2 * len(qubit_inds) = gate_unitary.ndim 31 | 32 | Returns: 33 | Updated statetensor. 34 | """ 35 | statetensor = jnp.tensordot( 36 | gate_unitary, statetensor, axes=(list(range(-len(qubit_inds), 0)), qubit_inds) 37 | ) 38 | statetensor = jnp.moveaxis(statetensor, list(range(len(qubit_inds))), qubit_inds) 39 | return statetensor 40 | 41 | 42 | def _to_gate_func( 43 | gate: Gate, 44 | ) -> GateFunction: 45 | """ 46 | Ensures a gate_seq element is a function that map (possibly empty) parameters 47 | to a unitary tensor. 48 | 49 | Args: 50 | gate: Either a string matching an array or function in qujax.gates, 51 | a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) ) 52 | or a function taking parameters and returning gate unitary in tensor form. 53 | 54 | Returns: 55 | Gate parameter to unitary functions 56 | """ 57 | 58 | def _array_to_callable(arr: jax.Array) -> Callable[[], jax.Array]: 59 | return lambda: arr 60 | 61 | if isinstance(gate, str): 62 | gate = gates.__dict__[gate] 63 | 64 | if callable(gate): 65 | gate_func = gate 66 | elif hasattr(gate, "__array__"): 67 | gate_func = _array_to_callable(jnp.array(gate)) 68 | else: 69 | raise TypeError( 70 | f"Unsupported gate type - gate must be either a string in qujax.gates, an array or " 71 | f"callable: {gate}" 72 | ) 73 | return gate_func 74 | 75 | 76 | def _gate_func_to_unitary( 77 | gate_func: GateFunction, 78 | qubit_inds: Sequence[int], 79 | param_inds: jax.Array, 80 | params: jax.Array, 81 | ) -> jax.Array: 82 | """ 83 | Extract gate unitary. 
84 | 85 | Args: 86 | gate_func: Function that maps a (possibly empty) parameter array to a unitary tensor (array) 87 | qubit_inds: Indices of qubits to apply gate to 88 | (only needed to ensure gate is in tensor form) 89 | param_inds: Indices of full parameter to extract gate specific parameters 90 | params: Full parameter vector 91 | 92 | Returns: 93 | Array containing gate unitary in tensor form. 94 | """ 95 | gate_params = jnp.take(params, param_inds) 96 | gate_unitary = gate_func(*gate_params) 97 | gate_unitary = gate_unitary.reshape( 98 | (2,) * (2 * len(qubit_inds)) 99 | ) # Ensure gate is in tensor form 100 | return gate_unitary 101 | 102 | 103 | def all_zeros_statetensor(n_qubits: int, dtype: DTypeLike = complex) -> jax.Array: 104 | """ 105 | Returns a statetensor representation of the all-zeros state |00...0> on `n_qubits` qubits 106 | 107 | Args: 108 | n_qubits: Number of qubits that the state is defined on. 109 | dtype: Data type of the statetensor returned. 110 | 111 | Returns: 112 | Statetensor representing the state having all qubits set to zero. 113 | """ 114 | statetensor = jnp.zeros((2,) * n_qubits, dtype=canonicalize_dtype(dtype)) 115 | statetensor = statetensor.at[(0,) * n_qubits].set(1.0) 116 | return statetensor 117 | 118 | 119 | def get_params_to_statetensor_func( 120 | gate_seq: Sequence[Gate], 121 | qubit_inds_seq: Sequence[Sequence[int]], 122 | param_inds_seq: Sequence[GateParameterIndices], 123 | n_qubits: Optional[int] = None, 124 | ) -> PureCircuitFunction: 125 | """ 126 | Creates a function that maps circuit parameters to a statetensor. 127 | 128 | Args: 129 | gate_seq: Sequence of gates. 130 | Each element is either a string matching a unitary array or function in qujax.gates, 131 | a custom unitary array or a custom function taking parameters and returning a 132 | unitary array. Unitary arrays will be reshaped into tensor form (2, 2,...) 
133 | qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are 134 | acting on. 135 | i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the 136 | zeroth qubit, the second gate is a two qubit gate acting on the zeroth and first qubit 137 | etc. 138 | param_inds_seq: Sequence of sequences representing parameter indices that gates are using, 139 | i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter 140 | (the float at position zero in the parameter vector/array), the second gate is not 141 | parameterised and the third gate uses the parameters at position five and two. 142 | n_qubits: Number of qubits, if fixed. 143 | 144 | Returns: 145 | Function which maps parameters (and optional statetensor_in) to a statetensor. 146 | If no parameters are found then the function only takes optional statetensor_in. 147 | 148 | """ 149 | 150 | check_circuit(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits) 151 | 152 | if n_qubits is None: 153 | n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1 154 | 155 | gate_seq_callable = [_to_gate_func(g) for g in gate_seq] 156 | param_inds_array_seq = _arrayify_inds(param_inds_seq) 157 | 158 | def params_to_statetensor_func( 159 | params: ArrayLike, statetensor_in: Optional[jax.Array] = None 160 | ) -> jax.Array: 161 | """ 162 | Applies parameterised circuit (series of gates) to a statetensor_in (default is |0>^N). 163 | 164 | Args: 165 | params: Parameters of the circuit. 166 | statetensor_in: Optional. Input statetensor. 167 | Defaults to |0>^N (tensor of size 2^n with all zeroes except one in [0]*N index). 168 | 169 | Returns: 170 | Updated statetensor. 
171 | 172 | """ 173 | if statetensor_in is None: 174 | statetensor = all_zeros_statetensor(n_qubits) 175 | else: 176 | statetensor = statetensor_in 177 | 178 | params = jnp.atleast_1d(params) 179 | # Guarantee `params` has the right type for type-checking purposes 180 | if not isinstance(params, jax.Array): 181 | raise ValueError("This should not happen. Please open an issue on GitHub.") 182 | 183 | for gate_func, qubit_inds, param_inds in zip( 184 | gate_seq_callable, qubit_inds_seq, param_inds_array_seq 185 | ): 186 | gate_unitary = _gate_func_to_unitary( 187 | gate_func, qubit_inds, param_inds, params 188 | ) 189 | statetensor = apply_gate(statetensor, gate_unitary, qubit_inds) 190 | return statetensor 191 | 192 | non_parameterised = all([pi.size == 0 for pi in param_inds_array_seq]) 193 | if non_parameterised: 194 | 195 | def no_params_to_statetensor_func( 196 | statetensor_in: Optional[jax.Array] = None, 197 | ) -> jax.Array: 198 | """ 199 | Applies circuit (series of gates with no parameters) to a statetensor_in 200 | (default is |0>^N). 201 | 202 | Args: 203 | statetensor_in: Optional. Input statetensor. 204 | Defaults to |0>^N (tensor of size 2^n with all zeroes except one in 205 | the [0]*N index). 206 | 207 | Returns: 208 | Updated statetensor. 209 | 210 | """ 211 | return params_to_statetensor_func(jnp.array([]), statetensor_in) 212 | 213 | return no_params_to_statetensor_func 214 | 215 | return params_to_statetensor_func 216 | 217 | 218 | def get_params_to_unitarytensor_func( 219 | gate_seq: Sequence[Gate], 220 | qubit_inds_seq: Sequence[Sequence[int]], 221 | param_inds_seq: Sequence[GateParameterIndices], 222 | n_qubits: Optional[int] = None, 223 | ) -> PureCircuitFunction: 224 | """ 225 | Creates a function that maps circuit parameters to a unitarytensor. 226 | The unitarytensor is an array with shape (2,) * 2 * n_qubits 227 | representing the full unitary matrix of the circuit. 228 | 229 | Args: 230 | gate_seq: Sequence of gates. 
231 | Each element is either a string matching a unitary array or function in qujax.gates, 232 | a custom unitary array or a custom function taking parameters and returning a unitary 233 | array. Unitary arrays will be reshaped into tensor form (2, 2,...) 234 | qubit_inds_seq: Sequences of sequences representing qubit indices (ints) that gates are 235 | acting on. 236 | i.e. [[0], [0,1], [1]] tells qujax the first gate is a single qubit gate acting on the 237 | zeroth qubit, the second gate is a two qubit gate acting on the zeroth and first qubit 238 | etc. 239 | param_inds_seq: Sequence of sequences representing parameter indices that gates are using, 240 | i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter 241 | (the float at position zero in the parameter vector/array), the second gate is not 242 | parameterised and the third gate uses the parameters at position five and two. 243 | n_qubits: Number of qubits, if fixed. 244 | 245 | Returns: 246 | Function which maps any parameters to a unitarytensor. 
247 | 248 | """ 249 | 250 | if n_qubits is None: 251 | n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1 252 | 253 | param_to_st = get_params_to_statetensor_func( 254 | gate_seq, qubit_inds_seq, param_inds_seq, n_qubits 255 | ) 256 | identity_unitarytensor = jnp.eye(2**n_qubits).reshape((2,) * 2 * n_qubits) 257 | return partial(param_to_st, statetensor_in=identity_unitarytensor) 258 | -------------------------------------------------------------------------------- /qujax/statetensor_observable.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import Callable, Sequence, Union 4 | 5 | import jax 6 | from jax import numpy as jnp 7 | from jax import random 8 | from jax.lax import fori_loop 9 | 10 | from qujax.statetensor import apply_gate 11 | from qujax.utils import check_hermitian, paulis 12 | 13 | 14 | def statetensor_to_single_expectation( 15 | statetensor: jax.Array, hermitian: jax.Array, qubit_inds: Sequence[int] 16 | ) -> jax.Array: 17 | """ 18 | Evaluates expectation value of an observable represented by a Hermitian matrix (in tensor form). 19 | 20 | Args: 21 | statetensor: Input statetensor. 22 | hermitian: Hermitian array 23 | must be in tensor form with shape (2,2,...). 24 | qubit_inds: Sequence of qubit indices for Hermitian matrix to be applied to. 25 | Must have 2 * len(qubit_inds) == hermitian.ndim 26 | 27 | Returns: 28 | Expected value (float). 29 | """ 30 | statetensor_new = apply_gate(statetensor, hermitian, qubit_inds) 31 | axes = tuple(range(statetensor.ndim)) 32 | return jnp.tensordot( 33 | statetensor.conjugate(), statetensor_new, axes=(axes, axes) 34 | ).real 35 | 36 | 37 | def get_hermitian_tensor(hermitian_seq: Sequence[Union[str, jax.Array]]) -> jax.Array: 38 | """ 39 | Convert a sequence of observables represented by Pauli strings or Hermitian matrices 40 | in tensor form into single array (in tensor form). 
41 | 42 | Args: 43 | hermitian_seq: Sequence of Hermitian strings or arrays. 44 | 45 | Returns: 46 | Hermitian matrix in tensor form (array). 47 | """ 48 | for h in hermitian_seq: 49 | check_hermitian(h) 50 | 51 | single_arrs = [paulis[h] if isinstance(h, str) else h for h in hermitian_seq] 52 | single_arrs = [ 53 | h_arr.reshape((2,) * int(jnp.rint(jnp.log2(h_arr.size)))) 54 | for h_arr in single_arrs 55 | ] 56 | 57 | full_mat = single_arrs[0] 58 | for single_matrix in single_arrs[1:]: 59 | full_mat = jnp.kron(full_mat, single_matrix) 60 | full_mat = full_mat.reshape((2,) * int(jnp.rint(jnp.log2(full_mat.size)))) 61 | return full_mat 62 | 63 | 64 | def _get_tensor_to_expectation_func( 65 | hermitian_seq_seq: Sequence[Sequence[Union[str, jax.Array]]], 66 | qubits_seq_seq: Sequence[Sequence[int]], 67 | coefficients: Union[Sequence[float], jax.Array], 68 | contraction_function: Callable, 69 | ) -> Callable[[jax.Array], float]: 70 | """ 71 | Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and 72 | a list of coefficients and returns a function that converts a tensor into an expected value. 73 | The contraction function performs the tensor contraction according to the type of tensor 74 | provided (i.e. whether it is a statetensor or a densitytensor). 75 | 76 | Args: 77 | hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. 78 | Each Hermitian matrix is either represented by a tensor (jax.Array) or by a 79 | list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. 80 | E.g. [['Z', 'Z'], ['X']] 81 | qubits_seq_seq: Sequence of sequences of integer qubit indices. 82 | E.g. [[0,1], [2]] 83 | coefficients: Sequence of float coefficients to scale the expected values. 84 | contraction_function: Function that performs the tensor contraction. 85 | 86 | Returns: 87 | Function that takes tensor and returns expected value (float). 
88 | """ 89 | 90 | hermitian_tensors = [get_hermitian_tensor(h_seq) for h_seq in hermitian_seq_seq] 91 | 92 | def tensor_to_expectation_func(tensor: jax.Array) -> float: 93 | """ 94 | Maps tensor to expected value. 95 | 96 | Args: 97 | tensor: Input tensor. 98 | 99 | Returns: 100 | Expected value (float). 101 | """ 102 | out = 0 103 | for hermitian, qubit_inds, coeff in zip( 104 | hermitian_tensors, qubits_seq_seq, coefficients 105 | ): 106 | out += coeff * contraction_function(tensor, hermitian, qubit_inds) 107 | return out 108 | 109 | return tensor_to_expectation_func 110 | 111 | 112 | def get_statetensor_to_expectation_func( 113 | hermitian_seq_seq: Sequence[Sequence[Union[str, jax.Array]]], 114 | qubits_seq_seq: Sequence[Sequence[int]], 115 | coefficients: Union[Sequence[float], jax.Array], 116 | ) -> Callable[[jax.Array], float]: 117 | """ 118 | Takes strings (or arrays) representing Hermitian matrices, along with qubit indices and 119 | a list of coefficients and returns a function that converts a statetensor into an expected 120 | value. 121 | 122 | Args: 123 | hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. 124 | Each Hermitian matrix is either represented by a tensor (jax.Array) 125 | or by a list of 'X', 'Y' or 'Z' characters corresponding to the standard Pauli matrices. 126 | E.g. [['Z', 'Z'], ['X']] 127 | qubits_seq_seq: Sequence of sequences of integer qubit indices. 128 | E.g. [[0,1], [2]] 129 | coefficients: Sequence of float coefficients to scale the expected values. 130 | 131 | Returns: 132 | Function that takes statetensor and returns expected value (float). 
133 | """ 134 | 135 | return _get_tensor_to_expectation_func( 136 | hermitian_seq_seq, 137 | qubits_seq_seq, 138 | coefficients, 139 | statetensor_to_single_expectation, 140 | ) 141 | 142 | 143 | def get_statetensor_to_sampled_expectation_func( 144 | hermitian_seq_seq: Sequence[Sequence[Union[str, jax.Array]]], 145 | qubits_seq_seq: Sequence[Sequence[int]], 146 | coefficients: Union[Sequence[float], jax.Array], 147 | ) -> Callable[[jax.Array, random.PRNGKeyArray, int], float]: 148 | """ 149 | Converts strings (or arrays) representing Hermitian matrices, qubit indices and 150 | coefficients into a function that converts a statetensor into a sampled expected value. 151 | 152 | On a quantum device, measurements are always taken in the computational basis, as such 153 | sampled expectation values should be taken with respect to an observable that commutes 154 | with the Pauli Z - a warning will be raised if it does not. 155 | 156 | qujax applies an importance sampling heuristic for sampled expectation values that only 157 | reflects the physical notion of measurement in the case that the observable commutes with Z. 158 | In the case that it does not, the expectation value will still be asymptotically unbiased 159 | but not representative of an experiment on a real quantum device. 160 | 161 | Args: 162 | hermitian_seq_seq: Sequence of sequences of Hermitian matrices/tensors. 163 | Each Hermitian is either a tensor (jax.Array) or a string in ('X', 'Y', 'Z'). 164 | E.g. [['Z', 'Z'], ['X']] 165 | qubits_seq_seq: Sequence of sequences of integer qubit indices. 166 | E.g. [[0,1], [2]] 167 | coefficients: Sequence of float coefficients to scale the expected values. 168 | 169 | Returns: 170 | Function that takes statetensor, random key and integer number of shots 171 | and returns sampled expected value (float). 
172 | """ 173 | statetensor_to_expectation_func = get_statetensor_to_expectation_func( 174 | hermitian_seq_seq, qubits_seq_seq, coefficients 175 | ) 176 | 177 | for hermitian_seq in hermitian_seq_seq: 178 | for h in hermitian_seq: 179 | check_hermitian(h, check_z_commutes=True) 180 | 181 | def statetensor_to_sampled_expectation_func( 182 | statetensor: jax.Array, random_key: random.PRNGKeyArray, n_samps: int 183 | ) -> float: 184 | """ 185 | Maps statetensor to sampled expected value. 186 | 187 | Args: 188 | statetensor: Input statetensor. 189 | random_key: JAX random key 190 | n_samps: Number of samples contributing to sampled expectation. 191 | 192 | Returns: 193 | Sampled expected value (float). 194 | """ 195 | measure_probs = jnp.abs(statetensor) ** 2 196 | sampled_probs = sample_probs(measure_probs, random_key, n_samps) 197 | iweights = jnp.sqrt(sampled_probs / measure_probs) 198 | return statetensor_to_expectation_func(statetensor * iweights) 199 | 200 | return statetensor_to_sampled_expectation_func 201 | 202 | 203 | def sample_probs( 204 | measure_probs: jax.Array, random_key: random.PRNGKeyArray, n_samps: int 205 | ): 206 | """ 207 | Generate an empirical distribution from a probability distribution. 208 | 209 | Args: 210 | measure_probs: Probability distribution. 211 | random_key: JAX random key 212 | n_samps: Number of samples contributing to empirical distribution. 213 | 214 | Returns: 215 | Empirical distribution (jax.Array). 
216 | """ 217 | measure_probs_flat = measure_probs.flatten() 218 | sampled_integers = random.choice( 219 | random_key, 220 | a=jnp.arange(measure_probs.size), 221 | shape=(n_samps,), 222 | p=measure_probs_flat, 223 | ) 224 | sampled_probs = fori_loop( 225 | 0, 226 | n_samps, 227 | lambda i, sv: sv.at[sampled_integers[i]].add(1 / n_samps), 228 | jnp.zeros_like(measure_probs_flat), 229 | ) 230 | return sampled_probs.reshape(measure_probs.shape) 231 | -------------------------------------------------------------------------------- /qujax/typing.py: -------------------------------------------------------------------------------- 1 | from typing import Union, Optional, Protocol, Callable, Iterable, Sequence 2 | 3 | # Backwards compatibility with Python <3.10 4 | from typing_extensions import TypeVarTuple, Unpack 5 | 6 | import jax 7 | from jax.typing import ArrayLike 8 | 9 | 10 | class PureParameterizedCircuit(Protocol): 11 | def __call__( 12 | self, params: ArrayLike, statetensor_in: Optional[jax.Array] = None 13 | ) -> jax.Array: ... 14 | 15 | 16 | class PureUnparameterizedCircuit(Protocol): 17 | def __call__(self, statetensor_in: Optional[jax.Array] = None) -> jax.Array: ... 18 | 19 | 20 | class MixedParameterizedCircuit(Protocol): 21 | def __call__( 22 | self, params: ArrayLike, densitytensor_in: Optional[jax.Array] = None 23 | ) -> jax.Array: ... 24 | 25 | 26 | class MixedUnparameterizedCircuit(Protocol): 27 | def __call__(self, densitytensor_in: Optional[jax.Array] = None) -> jax.Array: ... 28 | 29 | 30 | GateArgs = TypeVarTuple("GateArgs") 31 | # Function that takes arbitrary nr. 
of parameters and returns an array representing the gate 32 | # Currently Python does not allow us to restrict the type of the arguments using a TypeVarTuple 33 | GateFunction = Callable[[Unpack[GateArgs]], jax.Array] 34 | GateParameterIndices = Optional[Sequence[int]] 35 | 36 | PureCircuitFunction = Union[PureUnparameterizedCircuit, PureParameterizedCircuit] 37 | MixedCircuitFunction = Union[MixedUnparameterizedCircuit, MixedParameterizedCircuit] 38 | 39 | Gate = Union[str, jax.Array, GateFunction] 40 | 41 | KrausOp = Union[Gate, Iterable[Gate]] 42 | -------------------------------------------------------------------------------- /qujax/utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import collections.abc 4 | from inspect import signature 5 | from typing import Callable, List, Optional, Sequence, Tuple, Union 6 | from warnings import warn 7 | 8 | import jax 9 | from jax.typing import ArrayLike 10 | from jax import numpy as jnp 11 | from jax import random 12 | 13 | from qujax import gates 14 | 15 | from qujax.typing import ( 16 | Gate, 17 | KrausOp, 18 | GateParameterIndices, 19 | PureParameterizedCircuit, 20 | MixedParameterizedCircuit, 21 | ) 22 | 23 | paulis = {"X": gates.X, "Y": gates.Y, "Z": gates.Z} 24 | 25 | 26 | def check_unitary(gate: Gate) -> None: 27 | """ 28 | Checks whether a qujax Gate is unitary. 29 | Throws a TypeError if this is found not to be the case. 30 | 31 | Args: 32 | gate: array containing potentially unitary string, array 33 | or function (which will be evaluated with all arguments set to 0.1). 
34 | 35 | """ 36 | if isinstance(gate, str): 37 | if gate in gates.__dict__: 38 | gate = gates.__dict__[gate] 39 | else: 40 | raise KeyError( 41 | f"Gate string '{gate}' not found in qujax.gates " 42 | f"- consider changing input to an array or callable" 43 | ) 44 | 45 | if callable(gate): 46 | num_args = len(signature(gate).parameters) 47 | gate_arr = gate(*jnp.ones(num_args) * 0.1) 48 | elif isinstance(gate, jax.Array): 49 | gate_arr = gate 50 | else: 51 | raise TypeError( 52 | f"Unsupported gate type - gate must be either a string in qujax.gates, an array or " 53 | f"callable: {gate}" 54 | ) 55 | 56 | gate_square_dim = int(jnp.sqrt(gate_arr.size)) 57 | gate_arr = gate_arr.reshape(gate_square_dim, gate_square_dim) 58 | 59 | if jnp.any( 60 | jnp.abs(gate_arr @ jnp.conjugate(gate_arr).T - jnp.eye(gate_square_dim)) > 1e-3 61 | ): 62 | raise TypeError(f"Gate not unitary: {gate}") 63 | 64 | 65 | def check_hermitian(hermitian: Union[str, jax.Array], check_z_commutes: bool = False): 66 | """ 67 | Checks whether a matrix or tensor is Hermitian. 
68 | 69 | Args: 70 | hermitian: array containing potentially Hermitian matrix or tensor 71 | check_z_commutes: boolean on whether to check if the matrix commutes with Z 72 | 73 | """ 74 | if isinstance(hermitian, str): 75 | if hermitian not in paulis: 76 | raise TypeError( 77 | f"qujax only accepts {tuple(paulis.keys())} as Hermitian strings," 78 | "received: {hermitian}" 79 | ) 80 | n_qubits = 1 81 | hermitian_mat = paulis[hermitian] 82 | 83 | else: 84 | n_qubits = hermitian.ndim // 2 85 | hermitian_mat = hermitian.reshape(2 * n_qubits, 2 * n_qubits) 86 | if not jnp.allclose(hermitian_mat, hermitian_mat.T.conj()): 87 | raise TypeError(f"Array not Hermitian: {hermitian}") 88 | 89 | if check_z_commutes: 90 | big_z = jnp.diag(jnp.where(jnp.arange(2**n_qubits) % 2 == 0, 1, -1)) 91 | z_commutes = jnp.allclose(hermitian_mat @ big_z, big_z @ hermitian_mat) 92 | if not z_commutes: 93 | warn( 94 | "Hermitian matrix does not commute with Z. \n" 95 | "For sampled expectation values, this may lead to unexpected results, " 96 | "measurements on a quantum device are always taken in the computational basis. " 97 | "Additional gates can be applied in the circuit to change the basis such " 98 | "that an observable that commutes with Z can be measured." 99 | ) 100 | 101 | 102 | def _arrayify_inds( 103 | param_inds_seq: Optional[Sequence[GateParameterIndices]], 104 | ) -> Sequence[jax.Array]: 105 | """ 106 | Ensure each element of param_inds_seq is an array (and therefore valid for jnp.take) 107 | 108 | Args: 109 | param_inds_seq: Sequence of sequences representing parameter indices that gates are using, 110 | i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the zeroth parameter 111 | (the float at position zero in the parameter vector/array), the second gate is not 112 | parameterised and the third gates used the parameters at position five and two. 113 | 114 | Returns: 115 | Sequence of arrays representing parameter indices. 
116 | """ 117 | if param_inds_seq is None: 118 | param_inds_seq = [None] 119 | array_param_inds = [jnp.array(p) for p in param_inds_seq] 120 | array_param_inds = [ 121 | jnp.array([]) if jnp.any(jnp.isnan(p)) else p.astype(int) 122 | for p in array_param_inds 123 | ] 124 | return array_param_inds 125 | 126 | 127 | def check_circuit( 128 | gate_seq: Sequence[KrausOp], 129 | qubit_inds_seq: Sequence[Sequence[int]], 130 | param_inds_seq: Sequence[ 131 | Union[GateParameterIndices, Sequence[GateParameterIndices]] 132 | ], 133 | n_qubits: Optional[int] = None, 134 | check_unitaries: bool = True, 135 | ): 136 | """ 137 | Basic checks that circuit arguments conform. 138 | 139 | Args: 140 | gate_seq: Sequence of gates. 141 | Each element is either a string matching an array or function in qujax.gates, 142 | a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) ) 143 | or a function taking parameters and returning gate unitary in tensor form. 144 | Or alternatively a sequence of the above representing Kraus operators. 145 | qubit_inds_seq: Sequences of qubits (ints) that gates are acting on. 146 | param_inds_seq: Sequence of parameter indices that gates are using, 147 | i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the first parameter, 148 | the second gate is not parameterised and the third gates used the fifth and second 149 | parameters. 150 | n_qubits: Number of qubits, if fixed. 151 | check_unitaries: boolean on whether to check if each gate represents a unitary matrix 152 | 153 | """ 154 | if not isinstance(gate_seq, collections.abc.Sequence): 155 | raise TypeError("gate_seq must be Sequence e.g. ['H', 'Rx', 'CX']") 156 | 157 | if (not isinstance(qubit_inds_seq, collections.abc.Sequence)) or ( 158 | any( 159 | [ 160 | not (isinstance(q, collections.abc.Sequence) or hasattr(q, "__array__")) 161 | for q in qubit_inds_seq 162 | ] 163 | ) 164 | ): 165 | raise TypeError( 166 | "qubit_inds_seq must be Sequence of Sequences e.g. 
[[0,1], [0], []]" 167 | ) 168 | 169 | if (not isinstance(param_inds_seq, collections.abc.Sequence)) or ( 170 | any( 171 | [ 172 | not ( 173 | isinstance(p, collections.abc.Sequence) 174 | or hasattr(p, "__array__") 175 | or p is None 176 | ) 177 | for p in param_inds_seq 178 | ] 179 | ) 180 | ): 181 | raise TypeError( 182 | "param_inds_seq must be Sequence of Sequences e.g. [[0,1], [0], []]" 183 | ) 184 | 185 | if len(gate_seq) != len(qubit_inds_seq) or len(param_inds_seq) != len( 186 | param_inds_seq 187 | ): 188 | raise TypeError( 189 | f"gate_seq ({len(gate_seq)}), qubit_inds_seq ({len(qubit_inds_seq)})" 190 | f"and param_inds_seq ({len(param_inds_seq)}) must have matching lengths" 191 | ) 192 | 193 | if n_qubits is not None and n_qubits < max([max(qi) for qi in qubit_inds_seq]) + 1: 194 | raise TypeError( 195 | "n_qubits must be larger than largest qubit index in qubit_inds_seq" 196 | ) 197 | 198 | if check_unitaries: 199 | for g in gate_seq: 200 | check_unitary(g) 201 | 202 | 203 | def _get_gate_str( 204 | gate_obj: KrausOp, 205 | param_inds: Union[GateParameterIndices, Sequence[GateParameterIndices]], 206 | ) -> str: 207 | """ 208 | Maps single gate object to a four character string representation 209 | 210 | Args: 211 | gate_obj: Either a string matching a function in qujax.gates, 212 | a unitary array (which will be reshaped into a tensor of shape e.g. (2,2,2,...) ) 213 | or a function taking parameters (can be empty) and returning gate unitary 214 | in tensor form. 215 | Or alternatively, a sequence of Krause operators represented by strings, arrays or 216 | functions. 217 | param_inds: Parameter indices that gates are using, i.e. gate uses 1st and 5th parameter. 
218 | 219 | Returns: 220 | Four character string representation of the gate 221 | 222 | """ 223 | if isinstance(gate_obj, (tuple, list)) or ( 224 | hasattr(gate_obj, "__array__") and gate_obj.ndim % 2 == 1 225 | ): 226 | # Kraus operators 227 | gate_obj = "Kr" 228 | param_inds = jnp.unique(jnp.concatenate(_arrayify_inds(param_inds), axis=0)) 229 | 230 | if isinstance(gate_obj, str): 231 | gate_str = gate_obj 232 | elif hasattr(gate_obj, "__array__"): 233 | gate_str = "Arr" 234 | elif callable(gate_obj): 235 | gate_str = "Func" 236 | else: 237 | if hasattr(gate_obj, "__name__"): 238 | gate_str = gate_obj.__name__ 239 | elif hasattr(gate_obj, "__class__") and hasattr(gate_obj.__class__, "__name__"): 240 | gate_str = gate_obj.__class__.__name__ 241 | else: 242 | gate_str = "Other" 243 | 244 | if hasattr(param_inds, "tolist"): 245 | param_inds = param_inds.tolist() 246 | 247 | if isinstance(param_inds, tuple): 248 | param_inds = list(param_inds) 249 | 250 | if param_inds == [] or param_inds == [None] or param_inds is None: 251 | if len(gate_str) > 7: 252 | gate_str = gate_str[:6] + "." 253 | else: 254 | param_str = str(param_inds).replace(" ", "") 255 | 256 | if len(param_str) > 5: 257 | param_str = "[.]" 258 | 259 | if (len(gate_str) + len(param_str)) > 7: 260 | gate_str = gate_str[:1] + "." 261 | 262 | gate_str += param_str 263 | 264 | gate_str = gate_str.center(7, "-") 265 | 266 | return gate_str 267 | 268 | 269 | def _pad_rows(rows: List[str]) -> Tuple[List[str], List[bool]]: 270 | """ 271 | Pad string representation of circuit to be rectangular. 272 | Fills qubit rows with '-' and between-qubit rows with ' '. 273 | 274 | Args: 275 | rows: String representation of circuit 276 | 277 | Returns: 278 | Rectangular string representation of circuit with right padding. 
279 | 280 | """ 281 | 282 | max_len = max([len(r) for r in rows]) 283 | 284 | def extend_row(row: str, qubit_row: bool) -> str: 285 | lr = len(row) 286 | if lr < max_len: 287 | if qubit_row: 288 | row += "-" * (max_len - lr) 289 | else: 290 | row += " " * (max_len - lr) 291 | return row 292 | 293 | out_rows = [extend_row(r, i % 2 == 0) for i, r in enumerate(rows)] 294 | return out_rows, [True] * len(rows) 295 | 296 | 297 | def print_circuit( 298 | gate_seq: Sequence[KrausOp], 299 | qubit_inds_seq: Sequence[Sequence[int]], 300 | param_inds_seq: Sequence[ 301 | Union[GateParameterIndices, Sequence[GateParameterIndices]] 302 | ], 303 | n_qubits: Optional[int] = None, 304 | qubit_min: int = 0, 305 | qubit_max: Optional[int] = None, 306 | gate_ind_min: int = 0, 307 | gate_ind_max: Optional[int] = None, 308 | sep_length: int = 1, 309 | ) -> List[str]: 310 | """ 311 | Returns and prints basic string representation of circuit. 312 | 313 | Args: 314 | gate_seq: Sequence of gates. 315 | Each element is either a string matching an array or function in qujax.gates, 316 | a unitary array (which will be reshaped into a tensor of shape (2,2,2,...) ) 317 | or a function taking parameters and returning gate unitary in tensor form. 318 | Or alternatively a sequence of the above representing Kraus operators. 319 | qubit_inds_seq: Sequences of qubits (ints) that gates are acting on. 320 | param_inds_seq: Sequence of parameter indices that gates are using, 321 | i.e. [[0], [], [5, 2]] tells qujax that the first gate uses the first parameter, 322 | the second gate is not parameterised and the third gates used the fifth and 323 | second parameters. 324 | n_qubits: Number of qubits, if fixed. 325 | qubit_min: Index of first qubit to display. 326 | qubit_max: Index of final qubit to display. 327 | gate_ind_min: Index of gate to start circuit printing. 328 | gate_ind_max: Index of gate to stop circuit printing. 329 | sep_length: Number of dashes to separate gates. 
330 | 331 | Returns: 332 | String representation of circuit 333 | 334 | """ 335 | check_circuit(gate_seq, qubit_inds_seq, param_inds_seq, n_qubits, False) 336 | 337 | if gate_ind_max is None: 338 | gate_ind_max = len(gate_seq) - 1 339 | else: 340 | gate_ind_max = min(len(gate_seq) - 1, gate_ind_max) 341 | 342 | if gate_ind_min > gate_ind_max: 343 | raise TypeError("gate_ind_max must be larger or equal to gate_ind_min") 344 | 345 | if n_qubits is None: 346 | n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1 347 | 348 | if qubit_max is None: 349 | qubit_max = n_qubits - 1 350 | else: 351 | qubit_max = min(n_qubits - 1, qubit_max) 352 | 353 | if qubit_min > qubit_max: 354 | raise TypeError("qubit_max must be larger or equal to qubit_min") 355 | 356 | gate_str_seq = [_get_gate_str(g, p) for g, p in zip(gate_seq, param_inds_seq)] 357 | 358 | n_qubits_disp = qubit_max - qubit_min + 1 359 | 360 | rows = [f"q{qubit_min}: ".ljust(3) + "-" * sep_length] 361 | if n_qubits_disp > 1: 362 | for i in range(qubit_min + 1, qubit_max + 1): 363 | rows += [" ", f"q{i}: ".ljust(3) + "-" * sep_length] 364 | rows, rows_free = _pad_rows(rows) 365 | 366 | for gate_ind in range(gate_ind_min, gate_ind_max + 1): 367 | g = gate_str_seq[gate_ind] 368 | qi = qubit_inds_seq[gate_ind] 369 | 370 | qi_min = min(qi) 371 | qi_max = max(qi) 372 | ri_min = 2 * qi_min # index of top row used by gate 373 | ri_max = 2 * qi_max # index of bottom row used by gate 374 | 375 | if not all([rows_free[i] for i in range(ri_min, ri_max)]): 376 | rows, rows_free = _pad_rows(rows) 377 | 378 | for row_ind in range(ri_min, ri_max + 1): 379 | if row_ind == 2 * qi[-1]: 380 | rows[row_ind] += "-" * sep_length + g 381 | elif row_ind % 2 == 1: 382 | rows[row_ind] += " " * sep_length + " " + "|" + " " 383 | elif row_ind / 2 in qi: 384 | rows[row_ind] += "-" * sep_length + "---" + "◯" + "---" 385 | else: 386 | rows[row_ind] += "-" * sep_length + "---" + "|" + "---" 387 | 388 | rows_free[row_ind] = False 389 | 390 | rows, 
_ = _pad_rows(rows) 391 | 392 | for p in rows: 393 | print(p) 394 | 395 | return rows 396 | 397 | 398 | def integers_to_bitstrings( 399 | integers: Union[int, jax.Array], nbits: Optional[int] = None 400 | ) -> jax.Array: 401 | """ 402 | Convert integer or array of integers into their binary expansion(s). 403 | 404 | Args: 405 | integers: Integer or array of integers to be converted. 406 | nbits: Length of output binary expansion. 407 | Defaults to smallest possible. 408 | 409 | Returns: 410 | Array of binary expansion(s). 411 | """ 412 | integers = jnp.atleast_1d(integers) 413 | # Guarantee `bitstrings` has the right type for type-checking purposes 414 | if not isinstance(integers, jax.Array): 415 | raise ValueError("This should not happen. Please open an issue on GitHub.") 416 | 417 | if nbits is None: 418 | nbits = int(jnp.ceil(jnp.log2(jnp.maximum(integers.max(), 1)) + 1e-5).item()) 419 | 420 | return jnp.squeeze( 421 | ((integers[:, None] & (1 << jnp.arange(nbits - 1, -1, -1))) > 0).astype(int) 422 | ) 423 | 424 | 425 | def bitstrings_to_integers(bitstrings: ArrayLike) -> jax.Array: 426 | """ 427 | Convert binary expansion(s) into integers. 428 | 429 | Args: 430 | bitstrings: Bitstring array or array of bitstring arrays. 431 | 432 | Returns: 433 | Array of integers. 434 | """ 435 | bitstrings = jnp.atleast_2d(bitstrings) 436 | 437 | # Guarantee `bitstrings` has the right type for type-checking purposes 438 | if not isinstance(bitstrings, jax.Array): 439 | raise ValueError("This should not happen. Please open an issue on GitHub.") 440 | 441 | convarr = 2 ** jnp.arange(bitstrings.shape[-1] - 1, -1, -1) 442 | return jnp.squeeze(bitstrings.dot(convarr)).astype(int) 443 | 444 | 445 | def sample_integers( 446 | random_key: random.PRNGKeyArray, 447 | statetensor: jax.Array, 448 | n_samps: int = 1, 449 | ) -> jax.Array: 450 | """ 451 | Generate random integer samples according to statetensor. 452 | 453 | Args: 454 | random_key: JAX random key to seed samples. 
455 | statetensor: Statetensor encoding sampling probabilities (in the form of amplitudes). 456 | n_samps: Number of samples to generate. Defaults to 1. 457 | 458 | Returns: 459 | Array with sampled integers, shape=(n_samps,). 460 | 461 | """ 462 | sv_probs = jnp.square(jnp.abs(statetensor.flatten())) 463 | sampled_inds = random.choice( 464 | random_key, a=jnp.arange(statetensor.size), shape=(n_samps,), p=sv_probs 465 | ) 466 | return sampled_inds 467 | 468 | 469 | def sample_bitstrings( 470 | random_key: random.PRNGKeyArray, 471 | statetensor: jax.Array, 472 | n_samps: int = 1, 473 | ) -> jax.Array: 474 | """ 475 | Generate random bitstring samples according to statetensor. 476 | 477 | Args: 478 | random_key: JAX random key to seed samples. 479 | statetensor: Statetensor encoding sampling probabilities (in the form of amplitudes). 480 | n_samps: Number of samples to generate. Defaults to 1. 481 | 482 | Returns: 483 | Array with sampled bitstrings, shape=(n_samps, statetensor.ndim). 484 | """ 485 | return integers_to_bitstrings( 486 | sample_integers(random_key, statetensor, n_samps), statetensor.ndim 487 | ) 488 | 489 | 490 | def statetensor_to_densitytensor(statetensor: jax.Array) -> jax.Array: 491 | """ 492 | Computes a densitytensor representation of a pure quantum state 493 | from its statetensor representaton 494 | 495 | Args: 496 | statetensor: Input statetensor. 497 | 498 | Returns: 499 | A densitytensor representing the quantum state. 500 | """ 501 | n_qubits = statetensor.ndim 502 | st = statetensor 503 | dt = (st.reshape(-1, 1) @ st.reshape(1, -1).conj()).reshape( 504 | 2 for _ in range(2 * n_qubits) 505 | ) 506 | return dt 507 | 508 | 509 | def repeat_circuit( 510 | circuit: Union[PureParameterizedCircuit, MixedParameterizedCircuit], 511 | nr_of_parameters: int, 512 | ) -> Callable[[jax.Array, jax.Array], jax.Array]: 513 | """ 514 | Repeats circuit encoded by `circuit` an arbitrary number of times. 
515 | Avoids compilation overhead with increasing circuit depth. 516 | 517 | Args: 518 | circuit: The function encoding the circuit. 519 | nr_of_parameters: The number of parameters that `circuit` takes. 520 | 521 | Returns: 522 | A function taking an arbitrary number of parameters and returning as 523 | many applications of `circuit` as the number of parameters allows. 524 | An exception is thrown if this function is supplied with a parameter array 525 | of size not divisible by `nr_of_parameters`. 526 | """ 527 | 528 | def repeated_circuit(params: jax.Array, statetensor_in: jax.Array) -> jax.Array: 529 | def f(state, p): 530 | return circuit(p, state), None 531 | 532 | reshaped_parameters = params.reshape(-1, nr_of_parameters) 533 | result, _ = jax.lax.scan(f, statetensor_in, reshaped_parameters) 534 | return result 535 | 536 | return repeated_circuit 537 | -------------------------------------------------------------------------------- /qujax/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "1.1.0" 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | exec(open("qujax/version.py").read()) 4 | 5 | setup( 6 | name="qujax", 7 | author="Sam Duffield", 8 | author_email="sam.duffield@quantinuum.com", 9 | url="https://github.com/CQCL/qujax", 10 | description="Simulating quantum circuits with JAX", 11 | long_description=open("README.md").read(), 12 | long_description_content_type="text/markdown", 13 | license="Apache 2", 14 | packages=find_packages(), 15 | python_requires=">=3.8", 16 | install_requires=["jax>=0.4.1", "jaxlib", "typing_extensions"], 17 | classifiers=[ 18 | "Programming Language :: Python", 19 | "Intended Audience :: Developers", 20 | "Intended Audience :: Science/Research", 21 | "Topic :: Scientific/Engineering", 22 | 
], 23 | include_package_data=True, 24 | platforms="any", 25 | version=__version__, 26 | ) 27 | -------------------------------------------------------------------------------- /tests/test_densitytensor.py: -------------------------------------------------------------------------------- 1 | from itertools import combinations 2 | 3 | from jax import jit 4 | from jax import numpy as jnp 5 | 6 | import qujax 7 | 8 | 9 | def test_kraus_single(): 10 | n_qubits = 3 11 | dim = 2**n_qubits 12 | density_matrix = jnp.arange(dim**2).reshape(dim, dim) 13 | density_tensor = density_matrix.reshape((2,) * 2 * n_qubits) 14 | kraus_operator = qujax.gates.Rx(0.2) 15 | 16 | qubit_inds = (1,) 17 | 18 | unitary_matrix = jnp.kron(jnp.eye(2 * qubit_inds[0]), kraus_operator) 19 | unitary_matrix = jnp.kron( 20 | unitary_matrix, jnp.eye(2 * (n_qubits - qubit_inds[-1] - 1)) 21 | ) 22 | check_kraus_dm = unitary_matrix @ density_matrix @ unitary_matrix.conj().T 23 | 24 | # qujax._kraus_single 25 | qujax_kraus_dt = qujax._kraus_single(density_tensor, kraus_operator, qubit_inds) 26 | qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) 27 | 28 | assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) 29 | 30 | qujax_kraus_dt_jit = jit(qujax._kraus_single, static_argnums=(2,))( 31 | density_tensor, kraus_operator, qubit_inds 32 | ) 33 | qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) 34 | assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) 35 | 36 | # qujax.kraus (but for a single array) 37 | qujax_kraus_dt = qujax.kraus(density_tensor, kraus_operator, qubit_inds) 38 | qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) 39 | assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) 40 | 41 | qujax_kraus_dt_jit = jit(qujax.kraus, static_argnums=(2,))( 42 | density_tensor, kraus_operator, qubit_inds 43 | ) 44 | qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) 45 | assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) 46 | 47 | 48 | def test_kraus_single_2qubit(): 49 | n_qubits = 4 50 | dim 
= 2**n_qubits 51 | density_matrix = jnp.arange(dim**2).reshape(dim, dim) 52 | density_tensor = density_matrix.reshape((2,) * 2 * n_qubits) 53 | kraus_operator_tensor = qujax.gates.ZZPhase(0.1) 54 | kraus_operator = qujax.gates.ZZPhase(0.1).reshape(4, 4) 55 | 56 | qubit_inds = (1, 2) 57 | 58 | unitary_matrix = jnp.kron(jnp.eye(2 * qubit_inds[0]), kraus_operator) 59 | unitary_matrix = jnp.kron( 60 | unitary_matrix, jnp.eye(2 * (n_qubits - qubit_inds[-1] - 1)) 61 | ) 62 | check_kraus_dm = unitary_matrix @ density_matrix @ unitary_matrix.conj().T 63 | 64 | # qujax._kraus_single 65 | qujax_kraus_dt = qujax._kraus_single( 66 | density_tensor, kraus_operator_tensor, qubit_inds 67 | ) 68 | qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) 69 | 70 | assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) 71 | 72 | qujax_kraus_dt_jit = jit(qujax._kraus_single, static_argnums=(2,))( 73 | density_tensor, kraus_operator_tensor, qubit_inds 74 | ) 75 | qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) 76 | assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) 77 | 78 | # qujax.kraus (but for a single array) 79 | qujax_kraus_dt = qujax.kraus(density_tensor, kraus_operator_tensor, qubit_inds) 80 | qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) 81 | assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) 82 | 83 | qujax_kraus_dt = qujax.kraus( 84 | density_tensor, kraus_operator, qubit_inds 85 | ) # check reshape kraus_operator correctly 86 | qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) 87 | assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) 88 | 89 | qujax_kraus_dt_jit = jit(qujax.kraus, static_argnums=(2,))( 90 | density_tensor, kraus_operator_tensor, qubit_inds 91 | ) 92 | qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) 93 | assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) 94 | 95 | 96 | def test_kraus_multiple(): 97 | n_qubits = 3 98 | dim = 2**n_qubits 99 | density_matrix = jnp.arange(dim**2).reshape(dim, dim) 100 | density_tensor = 
density_matrix.reshape((2,) * 2 * n_qubits) 101 | 102 | kraus_operators = [ 103 | 0.25 * qujax.gates.H, 104 | 0.25 * qujax.gates.Rx(0.3), 105 | 0.5 * qujax.gates.Ry(0.1), 106 | ] 107 | 108 | qubit_inds = (1,) 109 | 110 | unitary_matrices = [ 111 | jnp.kron(jnp.eye(2 * qubit_inds[0]), ko) for ko in kraus_operators 112 | ] 113 | unitary_matrices = [ 114 | jnp.kron(um, jnp.eye(2 * (n_qubits - qubit_inds[0] - 1))) 115 | for um in unitary_matrices 116 | ] 117 | 118 | check_kraus_dm = jnp.zeros_like(density_matrix) 119 | for um in unitary_matrices: 120 | check_kraus_dm += um @ density_matrix @ um.conj().T 121 | 122 | qujax_kraus_dt = qujax.kraus(density_tensor, kraus_operators, qubit_inds) 123 | qujax_kraus_dm = qujax_kraus_dt.reshape(dim, dim) 124 | 125 | assert jnp.allclose(qujax_kraus_dm, check_kraus_dm) 126 | 127 | qujax_kraus_dt_jit = jit(qujax.kraus, static_argnums=(2,))( 128 | density_tensor, kraus_operators, qubit_inds 129 | ) 130 | qujax_kraus_dm_jit = qujax_kraus_dt_jit.reshape(dim, dim) 131 | assert jnp.allclose(qujax_kraus_dm_jit, check_kraus_dm) 132 | 133 | 134 | def test_params_to_densitytensor_func(): 135 | n_qubits = 2 136 | 137 | gate_seq = ["Rx" for _ in range(n_qubits)] 138 | qubit_inds_seq = [(i,) for i in range(n_qubits)] 139 | param_inds_seq = [(i,) for i in range(n_qubits)] 140 | 141 | gate_seq += ["CZ" for _ in range(n_qubits - 1)] 142 | qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] 143 | param_inds_seq += [() for _ in range(n_qubits - 1)] 144 | 145 | params_to_dt = qujax.get_params_to_densitytensor_func( 146 | gate_seq, qubit_inds_seq, param_inds_seq, n_qubits 147 | ) 148 | params_to_st = qujax.get_params_to_statetensor_func( 149 | gate_seq, qubit_inds_seq, param_inds_seq, n_qubits 150 | ) 151 | 152 | params = jnp.arange(n_qubits) / 10.0 153 | 154 | st = params_to_st(params) 155 | dt_test = qujax.statetensor_to_densitytensor(st) 156 | 157 | dt = params_to_dt(params) 158 | 159 | assert jnp.allclose(dt, dt_test) 160 | 161 | jit_dt = 
jit(params_to_dt)(params) 162 | assert jnp.allclose(jit_dt, dt_test) 163 | 164 | 165 | def test_params_to_densitytensor_func_with_bit_flip(): 166 | n_qubits = 2 167 | 168 | gate_seq = ["Rx" for _ in range(n_qubits)] 169 | qubit_inds_seq = [(i,) for i in range(n_qubits)] 170 | param_inds_seq = [(i,) for i in range(n_qubits)] 171 | 172 | gate_seq += ["CZ" for _ in range(n_qubits - 1)] 173 | qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] 174 | param_inds_seq += [() for _ in range(n_qubits - 1)] 175 | 176 | params_to_pre_bf_st = qujax.get_params_to_statetensor_func( 177 | gate_seq, qubit_inds_seq, param_inds_seq, n_qubits 178 | ) 179 | 180 | kraus_ops = [[0.3 * jnp.eye(2), 0.7 * qujax.gates.X]] 181 | kraus_qubit_inds = [(0,)] 182 | kraus_param_inds = [None] 183 | 184 | gate_seq += kraus_ops 185 | qubit_inds_seq += kraus_qubit_inds 186 | param_inds_seq += kraus_param_inds 187 | 188 | _ = qujax.print_circuit(gate_seq, qubit_inds_seq, param_inds_seq) 189 | 190 | params_to_dt = qujax.get_params_to_densitytensor_func( 191 | gate_seq, qubit_inds_seq, param_inds_seq, n_qubits 192 | ) 193 | 194 | params = jnp.arange(n_qubits) / 10.0 195 | 196 | pre_bf_st = params_to_pre_bf_st(params) 197 | pre_bf_dt = (pre_bf_st.reshape(-1, 1) @ pre_bf_st.reshape(1, -1).conj()).reshape( 198 | 2 for _ in range(2 * n_qubits) 199 | ) 200 | dt_test = qujax.kraus(pre_bf_dt, kraus_ops[0], kraus_qubit_inds[0]) 201 | 202 | dt = params_to_dt(params) 203 | 204 | assert jnp.allclose(dt, dt_test) 205 | 206 | jit_dt = jit(params_to_dt)(params) 207 | assert jnp.allclose(jit_dt, dt_test) 208 | 209 | 210 | def test_partial_trace_1(): 211 | state1 = 1 / jnp.sqrt(2) * jnp.array([1.0, 1.0]) 212 | state2 = jnp.kron(state1, state1) 213 | state3 = jnp.kron(state1, state2) 214 | 215 | dt1 = jnp.outer(state1, state1.conj()).reshape((2,) * 2) 216 | dt2 = jnp.outer(state2, state2.conj()).reshape((2,) * 4) 217 | dt3 = jnp.outer(state3, state3.conj()).reshape((2,) * 6) 218 | 219 | for i in range(3): 220 | 
assert jnp.allclose(qujax.partial_trace(dt3, [i]), dt2) 221 | 222 | for i in combinations(range(3), 2): 223 | assert jnp.allclose(qujax.partial_trace(dt3, i), dt1) 224 | 225 | 226 | def test_partial_trace_2(): 227 | n_qubits = 3 228 | 229 | gate_seq = ["Rx" for _ in range(n_qubits)] 230 | qubit_inds_seq = [(i,) for i in range(n_qubits)] 231 | param_inds_seq = [(i,) for i in range(n_qubits)] 232 | 233 | gate_seq += ["CZ" for _ in range(n_qubits - 1)] 234 | qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] 235 | param_inds_seq += [() for _ in range(n_qubits - 1)] 236 | 237 | params_to_dt = qujax.get_params_to_densitytensor_func( 238 | gate_seq, qubit_inds_seq, param_inds_seq, n_qubits 239 | ) 240 | 241 | params = jnp.arange(1, n_qubits + 1) / 10.0 242 | 243 | dt = params_to_dt(params) 244 | dt_discard_test = jnp.trace(dt, axis1=0, axis2=n_qubits) 245 | dt_discard = qujax.partial_trace(dt, [0]) 246 | 247 | assert jnp.allclose(dt_discard, dt_discard_test) 248 | 249 | 250 | def test_measure(): 251 | n_qubits = 3 252 | 253 | gate_seq = ["Rx" for _ in range(n_qubits)] 254 | qubit_inds_seq = [(i,) for i in range(n_qubits)] 255 | param_inds_seq = [(i,) for i in range(n_qubits)] 256 | 257 | gate_seq += ["CZ" for _ in range(n_qubits - 1)] 258 | qubit_inds_seq += [(i, i + 1) for i in range(n_qubits - 1)] 259 | param_inds_seq += [() for _ in range(n_qubits - 1)] 260 | 261 | params_to_dt = qujax.get_params_to_densitytensor_func( 262 | gate_seq, qubit_inds_seq, param_inds_seq, n_qubits 263 | ) 264 | 265 | params = jnp.arange(1, n_qubits + 1) / 10.0 266 | 267 | dt = params_to_dt(params) 268 | 269 | qubit_inds = [0] 270 | 271 | all_probs = jnp.diag(dt.reshape(2**n_qubits, 2**n_qubits)).real 272 | all_probs_marginalise = all_probs.reshape((2,) * n_qubits).sum( 273 | axis=[i for i in range(n_qubits) if i not in qubit_inds] 274 | ) 275 | 276 | probs = qujax.densitytensor_to_measurement_probabilities(dt, qubit_inds) 277 | 278 | assert jnp.isclose(probs.sum(), 1.0) 279 | 
assert jnp.isclose(all_probs.sum(), 1.0) 280 | assert jnp.allclose(probs, all_probs_marginalise) 281 | 282 | dm = dt.reshape(2**n_qubits, 2**n_qubits) 283 | projector = jnp.array([[1, 0], [0, 0]]) 284 | for _ in range(n_qubits - 1): 285 | projector = jnp.kron(projector, jnp.eye(2)) 286 | measured_dm = projector @ dm @ projector.T.conj() 287 | measured_dm /= jnp.trace(projector.T.conj() @ projector @ dm) 288 | measured_dt_true = measured_dm.reshape((2,) * 2 * n_qubits) 289 | 290 | measured_dt = qujax.densitytensor_to_measured_densitytensor(dt, qubit_inds, 0) 291 | measured_dt_bits = qujax.densitytensor_to_measured_densitytensor( 292 | dt, qubit_inds, jnp.zeros(n_qubits) 293 | ) 294 | assert jnp.allclose(measured_dt_true, measured_dt) 295 | assert jnp.allclose(measured_dt_true, measured_dt_bits) 296 | -------------------------------------------------------------------------------- /tests/test_expectations.py: -------------------------------------------------------------------------------- 1 | from jax import config, grad, jit 2 | from jax import numpy as jnp 3 | from jax import random 4 | 5 | import qujax 6 | 7 | 8 | def test_pauli_hermitian(): 9 | for p_str in ("X", "Y", "Z"): 10 | qujax.check_hermitian(p_str) 11 | qujax.check_hermitian(qujax.gates.__dict__[p_str]) 12 | 13 | 14 | def test_single_expectation(): 15 | Z = qujax.gates.Z 16 | 17 | st1 = jnp.zeros((2, 2, 2)) 18 | st2 = jnp.zeros((2, 2, 2)) 19 | st1 = st1.at[(0, 0, 0)].set(1.0) 20 | st2 = st2.at[(1, 0, 0)].set(1.0) 21 | dt1 = qujax.statetensor_to_densitytensor(st1) 22 | dt2 = qujax.statetensor_to_densitytensor(st2) 23 | ZZ = jnp.kron(Z, Z).reshape(2, 2, 2, 2) 24 | 25 | est1 = qujax.statetensor_to_single_expectation(st1, ZZ, [0, 1]) 26 | est2 = qujax.statetensor_to_single_expectation(st2, ZZ, [0, 1]) 27 | edt1 = qujax.densitytensor_to_single_expectation(dt1, ZZ, [0, 1]) 28 | edt2 = qujax.densitytensor_to_single_expectation(dt2, ZZ, [0, 1]) 29 | 30 | assert est1.item() == edt1.item() == 1 31 | assert 
est2.item() == edt2.item() == -1 32 | 33 | 34 | def test_bitstring_expectation(): 35 | n_qubits = 4 36 | 37 | gates = ( 38 | ["H"] * n_qubits 39 | + ["Ry"] * n_qubits 40 | + ["Rz"] * n_qubits 41 | + ["CX"] * (n_qubits - 1) 42 | + ["Ry"] * n_qubits 43 | + ["Rz"] * n_qubits 44 | ) 45 | qubits = ( 46 | [[i] for i in range(n_qubits)] * 3 47 | + [[i, i + 1] for i in range(n_qubits - 1)] 48 | + [[i] for i in range(n_qubits)] * 2 49 | ) 50 | param_inds = ( 51 | [[]] * n_qubits 52 | + [[i] for i in range(n_qubits * 2)] 53 | + [[]] * (n_qubits - 1) 54 | + [[i] for i in range(n_qubits * 2, n_qubits * 4)] 55 | ) 56 | 57 | param_to_st = qujax.get_params_to_statetensor_func(gates, qubits, param_inds) 58 | 59 | n_params = n_qubits * 4 60 | params = random.uniform(random.PRNGKey(0), shape=(n_params,)) 61 | 62 | costs = random.normal(random.PRNGKey(1), shape=(2**n_qubits,)) 63 | 64 | def st_to_expectation(statetensor): 65 | probs = jnp.square(jnp.abs(statetensor.flatten())) 66 | return jnp.sum(costs * probs) 67 | 68 | param_to_expectation = lambda p: st_to_expectation(param_to_st(p)) 69 | 70 | def brute_force_param_to_exp(p): 71 | sv = param_to_st(p).flatten() 72 | return jnp.dot(sv, jnp.diag(costs) @ sv.conj()).real 73 | 74 | true_expectation = brute_force_param_to_exp(params) 75 | 76 | expectation = param_to_expectation(params) 77 | expectation_jit = jit(param_to_expectation)(params) 78 | 79 | assert expectation.shape == () 80 | assert expectation.dtype.name[:5] == "float" 81 | assert jnp.isclose(true_expectation, expectation) 82 | assert jnp.isclose(true_expectation, expectation_jit) 83 | 84 | true_expectation_grad = grad(brute_force_param_to_exp)(params) 85 | expectation_grad = grad(param_to_expectation)(params) 86 | expectation_grad_jit = jit(grad(param_to_expectation))(params) 87 | 88 | assert expectation_grad.shape == (n_params,) 89 | assert expectation_grad.dtype.name[:5] == "float" 90 | assert jnp.allclose(true_expectation_grad, expectation_grad, atol=1e-5) 91 | assert 
jnp.allclose(true_expectation_grad, expectation_grad_jit, atol=1e-5) 92 | 93 | 94 | def _test_hermitian_observable( 95 | hermitian_str_seq_seq, qubit_inds_seq, coefs, st_in=None 96 | ): 97 | n_qubits = max([max(qi) for qi in qubit_inds_seq]) + 1 98 | 99 | if st_in is None: 100 | state = ( 101 | random.uniform(random.PRNGKey(2), shape=(2**n_qubits,)) * 2 102 | + 1.0j * random.uniform(random.PRNGKey(1), shape=(2**n_qubits,)) * 2 103 | ) 104 | state /= jnp.linalg.norm(state) 105 | st_in = state.reshape((2,) * n_qubits) 106 | 107 | dt_in = qujax.statetensor_to_densitytensor(st_in) 108 | 109 | st_to_exp = qujax.get_statetensor_to_expectation_func( 110 | hermitian_str_seq_seq, qubit_inds_seq, coefs 111 | ) 112 | dt_to_exp = qujax.get_densitytensor_to_expectation_func( 113 | hermitian_str_seq_seq, qubit_inds_seq, coefs 114 | ) 115 | 116 | def big_hermitian_matrix(hermitian_str_seq, qubit_inds): 117 | qubit_arrs = [getattr(qujax.gates, s) for s in hermitian_str_seq] 118 | hermitian_arrs = [] 119 | j = 0 120 | for i in range(n_qubits): 121 | if i in qubit_inds: 122 | hermitian_arrs.append(qubit_arrs[j]) 123 | j += 1 124 | else: 125 | hermitian_arrs.append(jnp.eye(2)) 126 | 127 | big_h = hermitian_arrs[0] 128 | for k in range(1, n_qubits): 129 | big_h = jnp.kron(big_h, hermitian_arrs[k]) 130 | return big_h 131 | 132 | sum_big_hs = jnp.zeros((2**n_qubits, 2**n_qubits), dtype=complex) 133 | for i in range(len(hermitian_str_seq_seq)): 134 | sum_big_hs += coefs[i] * big_hermitian_matrix( 135 | hermitian_str_seq_seq[i], qubit_inds_seq[i] 136 | ) 137 | 138 | assert jnp.allclose(sum_big_hs, sum_big_hs.conj().T) 139 | 140 | sv = st_in.flatten() 141 | true_exp = jnp.dot(sv.conj(), sum_big_hs @ sv).real 142 | 143 | qujax_exp = st_to_exp(st_in) 144 | qujax_dt_exp = dt_to_exp(dt_in) 145 | qujax_exp_jit = jit(st_to_exp)(st_in) 146 | qujax_dt_exp_jit = jit(dt_to_exp)(dt_in) 147 | 148 | assert jnp.array(qujax_exp).shape == () 149 | assert jnp.array(qujax_exp).dtype.name[:5] == "float" 150 
| assert jnp.isclose(true_exp, qujax_exp) 151 | assert jnp.isclose(true_exp, qujax_dt_exp) 152 | assert jnp.isclose(true_exp, qujax_exp_jit) 153 | assert jnp.isclose(true_exp, qujax_dt_exp_jit) 154 | 155 | st_to_samp_exp = qujax.get_statetensor_to_sampled_expectation_func( 156 | hermitian_str_seq_seq, qubit_inds_seq, coefs 157 | ) 158 | dt_to_samp_exp = qujax.get_densitytensor_to_sampled_expectation_func( 159 | hermitian_str_seq_seq, qubit_inds_seq, coefs 160 | ) 161 | qujax_samp_exp = st_to_samp_exp(st_in, random.PRNGKey(1), 1000000) 162 | qujax_samp_exp_jit = jit(st_to_samp_exp, static_argnums=2)( 163 | st_in, random.PRNGKey(1), 1000000 164 | ) 165 | qujax_samp_exp_dt = dt_to_samp_exp(dt_in, random.PRNGKey(1), 1000000) 166 | qujax_samp_exp_dt_jit = jit(dt_to_samp_exp, static_argnums=2)( 167 | dt_in, random.PRNGKey(1), 1000000 168 | ) 169 | assert jnp.array(qujax_samp_exp).shape == () 170 | assert jnp.array(qujax_samp_exp).dtype.name[:5] == "float" 171 | assert jnp.isclose(true_exp, qujax_samp_exp, rtol=1e-2) 172 | assert jnp.isclose(true_exp, qujax_samp_exp_jit, rtol=1e-2) 173 | assert jnp.isclose(true_exp, qujax_samp_exp_dt, rtol=1e-2) 174 | assert jnp.isclose(true_exp, qujax_samp_exp_dt_jit, rtol=1e-2) 175 | 176 | 177 | def test_X(): 178 | hermitian_str_seq_seq = ["X"] 179 | qubit_inds_seq = [[0]] 180 | coefs = [1] 181 | 182 | gates = ["H", "Rz"] 183 | qubit = [[0], [0]] 184 | param_ind = [[], [0]] 185 | st_in = qujax.get_params_to_statetensor_func(gates, qubit, param_ind)(0.3) 186 | 187 | _test_hermitian_observable(hermitian_str_seq_seq, qubit_inds_seq, coefs, st_in) 188 | 189 | 190 | def test_Y(): 191 | n_qubits = 1 192 | 193 | hermitian_str_seq_seq = ["Y"] * n_qubits 194 | qubit_inds_seq = [[i] for i in range(n_qubits)] 195 | coefs = jnp.ones(len(hermitian_str_seq_seq)) 196 | 197 | _test_hermitian_observable(hermitian_str_seq_seq, qubit_inds_seq, coefs) 198 | 199 | 200 | def test_Z(): 201 | n_qubits = 1 202 | 203 | hermitian_str_seq_seq = ["Z"] * n_qubits 
204 | qubit_inds_seq = [[i] for i in range(n_qubits)] 205 | coefs = random.normal(random.PRNGKey(0), shape=(len(hermitian_str_seq_seq),)) 206 | 207 | _test_hermitian_observable(hermitian_str_seq_seq, qubit_inds_seq, coefs) 208 | 209 | 210 | def test_XYZ(): 211 | n_qubits = 1 212 | 213 | hermitian_str_seq_seq = ["X", "Y", "Z"] * n_qubits 214 | qubit_inds_seq = [[i] for _ in range(3) for i in range(n_qubits)] 215 | coefs = random.normal(random.PRNGKey(0), shape=(len(hermitian_str_seq_seq),)) 216 | 217 | _test_hermitian_observable(hermitian_str_seq_seq, qubit_inds_seq, coefs) 218 | 219 | 220 | def test_ZZ_Y(): 221 | config.update("jax_enable_x64", True) # Run this test with 64 bit precision 222 | 223 | n_qubits = 4 224 | 225 | hermitian_str_seq_seq = [["Z", "Z"]] * (n_qubits - 1) + [["Y"]] * n_qubits 226 | qubit_inds_seq = [[i, i + 1] for i in range(n_qubits - 1)] + [ 227 | [i] for i in range(n_qubits) 228 | ] 229 | coefs = random.normal(random.PRNGKey(1), shape=(len(hermitian_str_seq_seq),)) 230 | 231 | _test_hermitian_observable(hermitian_str_seq_seq, qubit_inds_seq, coefs) 232 | 233 | 234 | def test_sampling(): 235 | target_pmf = jnp.array([0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 0]) 236 | target_pmf /= target_pmf.sum() 237 | 238 | target_st = jnp.sqrt(target_pmf).reshape( 239 | (2,) * int(jnp.rint(jnp.log2(target_pmf.size))) 240 | ) 241 | 242 | n_samps = 7 243 | 244 | sample_ints = qujax.sample_integers(random.PRNGKey(0), target_st, n_samps) 245 | assert sample_ints.shape == (n_samps,) 246 | assert all(target_pmf[sample_ints] > 0) 247 | 248 | sample_bitstrings = qujax.sample_bitstrings(random.PRNGKey(0), target_st, n_samps) 249 | assert sample_bitstrings.shape == ( 250 | n_samps, 251 | int(jnp.rint(jnp.log2(target_pmf.size))), 252 | ) 253 | assert all(qujax.bitstrings_to_integers(sample_bitstrings) == sample_ints) 254 | -------------------------------------------------------------------------------- /tests/test_gates.py: 
-------------------------------------------------------------------------------- 1 | from qujax import check_unitary, gates 2 | 3 | 4 | def test_gates(): 5 | for g_str, g in gates.__dict__.items(): 6 | # Exclude elements in jax.gates namespace which are not gates 7 | if g_str[0] != "_" and g_str not in ("jax", "jnp"): 8 | check_unitary(g_str) 9 | check_unitary(g) 10 | -------------------------------------------------------------------------------- /tests/test_statetensor.py: -------------------------------------------------------------------------------- 1 | from jax import jit 2 | from jax import numpy as jnp 3 | 4 | import qujax 5 | 6 | 7 | def test_H(): 8 | gates = ["H"] 9 | qubits = [[0]] 10 | param_inds = [[]] 11 | 12 | param_to_st = qujax.get_params_to_statetensor_func(gates, qubits, param_inds) 13 | st = param_to_st() 14 | st_jit = jit(param_to_st)() 15 | 16 | true_sv = jnp.array([0.70710678 + 0.0j, 0.70710678 + 0.0j]) 17 | 18 | assert st.size == true_sv.size 19 | assert jnp.allclose(st.flatten(), true_sv) 20 | assert jnp.allclose(st_jit.flatten(), true_sv) 21 | 22 | param_to_unitary = qujax.get_params_to_unitarytensor_func(gates, qubits, param_inds) 23 | unitary = param_to_unitary().reshape(2, 2) 24 | unitary_jit = jit(param_to_unitary)().reshape(2, 2) 25 | zero_sv = jnp.zeros(2).at[0].set(1) 26 | assert jnp.allclose(unitary @ zero_sv, true_sv) 27 | assert jnp.allclose(unitary_jit @ zero_sv, true_sv) 28 | 29 | 30 | def test_H_redundant_qubits(): 31 | gates = ["H"] 32 | qubits = [[0]] 33 | param_inds = [[]] 34 | n_qubits = 3 35 | 36 | param_to_st = qujax.get_params_to_statetensor_func( 37 | gates, qubits, param_inds, n_qubits 38 | ) 39 | st = param_to_st(statetensor_in=None) 40 | 41 | true_sv = jnp.array([0.70710678, 0.0, 0.0, 0.0, 0.70710678, 0.0, 0.0, 0.0]) 42 | 43 | assert st.size == true_sv.size 44 | assert jnp.allclose(st.flatten(), true_sv) 45 | 46 | param_to_unitary = qujax.get_params_to_unitarytensor_func( 47 | gates, qubits, param_inds, n_qubits 48 
| ) 49 | unitary = param_to_unitary().reshape(2**n_qubits, 2**n_qubits) 50 | unitary_jit = jit(param_to_unitary)().reshape(2**n_qubits, 2**n_qubits) 51 | zero_sv = jnp.zeros(2**n_qubits).at[0].set(1) 52 | assert jnp.allclose(unitary @ zero_sv, true_sv) 53 | assert jnp.allclose(unitary_jit @ zero_sv, true_sv) 54 | 55 | 56 | def test_CX_Rz_CY(): 57 | gates = ["H", "H", "H", "CX", "Rz", "CY"] 58 | qubits = [[0], [1], [2], [0, 1], [1], [1, 2]] 59 | param_inds = [[], [], [], None, [0], []] 60 | 61 | param_to_st = qujax.get_params_to_statetensor_func(gates, qubits, param_inds) 62 | param = jnp.array(0.1) 63 | st = param_to_st(param) 64 | 65 | true_sv = jnp.array( 66 | [ 67 | 0.34920055 - 0.05530793j, 68 | 0.34920055 - 0.05530793j, 69 | 0.05530793 - 0.34920055j, 70 | -0.05530793 + 0.34920055j, 71 | 0.34920055 - 0.05530793j, 72 | 0.34920055 - 0.05530793j, 73 | 0.05530793 - 0.34920055j, 74 | -0.05530793 + 0.34920055j, 75 | ], 76 | dtype="complex64", 77 | ) 78 | 79 | assert st.size == true_sv.size 80 | assert jnp.allclose(st.flatten(), true_sv) 81 | 82 | n_qubits = 3 83 | param_to_unitary = qujax.get_params_to_unitarytensor_func( 84 | gates, qubits, param_inds, n_qubits 85 | ) 86 | unitary = param_to_unitary(param).reshape(2**n_qubits, 2**n_qubits) 87 | unitary_jit = jit(param_to_unitary)(param).reshape(2**n_qubits, 2**n_qubits) 88 | zero_sv = jnp.zeros(2**n_qubits).at[0].set(1) 89 | assert jnp.allclose(unitary @ zero_sv, true_sv) 90 | assert jnp.allclose(unitary_jit @ zero_sv, true_sv) 91 | 92 | 93 | def test_stacked_circuits(): 94 | gates = ["H"] 95 | qubits = [[0]] 96 | param_inds = [[]] 97 | 98 | param_to_st = qujax.get_params_to_statetensor_func(gates, qubits, param_inds) 99 | 100 | st1 = param_to_st() 101 | st2 = param_to_st(st1) 102 | 103 | st2_2 = param_to_st(statetensor_in=st1) 104 | 105 | all_zeros_sv = jnp.array(jnp.arange(st2.size) == 0, dtype=int) 106 | 107 | assert jnp.allclose(st2.flatten(), all_zeros_sv, atol=1e-7) 108 | assert jnp.allclose(st2_2.flatten(), 
all_zeros_sv, atol=1e-7) 109 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import jax 2 | import jax.numpy as jnp 3 | 4 | import qujax 5 | 6 | 7 | def test_repeat_circuit(): 8 | n_qubits = 4 9 | depth = 4 10 | seed = 0 11 | statetensor_in = qujax.all_zeros_statetensor(n_qubits) 12 | densitytensor_in = qujax.all_zeros_densitytensor(n_qubits) 13 | 14 | def circuit(n_qubits: int, depth: int): 15 | parameter_index = 0 16 | 17 | gates = [] 18 | qubit_inds = [] 19 | param_inds = [] 20 | 21 | for _ in range(depth): 22 | # Rx layer 23 | for i in range(n_qubits): 24 | gates.append("Rx") 25 | qubit_inds.append([i]) 26 | param_inds.append([parameter_index]) 27 | parameter_index += 1 28 | 29 | # CRz layer 30 | for i in range(n_qubits - 1): 31 | gates.append("CRz") 32 | qubit_inds.append([i, i + 1]) 33 | param_inds.append([parameter_index]) 34 | parameter_index += 1 35 | 36 | return gates, qubit_inds, param_inds, parameter_index 37 | 38 | rng = jax.random.PRNGKey(seed) 39 | 40 | g1, qi1, pi1, np1 = circuit(n_qubits, depth) 41 | 42 | params = jax.random.uniform(rng, (np1,)) 43 | 44 | param_to_st = qujax.get_params_to_statetensor_func(g1, qi1, pi1, n_qubits) 45 | 46 | g2, qi2, pi2, np2 = circuit(n_qubits, 1) 47 | 48 | param_to_st_single_repetition = qujax.get_params_to_statetensor_func( 49 | g2, qi2, pi2, n_qubits 50 | ) 51 | param_to_st_repeated = qujax.repeat_circuit(param_to_st_single_repetition, np2) 52 | 53 | assert jnp.allclose( 54 | param_to_st(params, statetensor_in), 55 | param_to_st_repeated(params, statetensor_in), 56 | ) 57 | 58 | param_to_dt = qujax.get_params_to_densitytensor_func(g1, qi1, pi1, n_qubits) 59 | 60 | param_to_dt_single_repetition = qujax.get_params_to_densitytensor_func( 61 | g2, qi2, pi2, n_qubits 62 | ) 63 | param_to_dt_repeated = qujax.repeat_circuit(param_to_dt_single_repetition, np2) 64 | 65 | assert 
jnp.allclose( 66 | param_to_dt(params, densitytensor_in), 67 | param_to_dt_repeated(params, densitytensor_in), 68 | ) 69 | --------------------------------------------------------------------------------