├── koala ├── VERSION ├── statevector │ ├── __init__.py │ ├── constructors.py │ └── statevector.py ├── __init__.py ├── version.py ├── peps │ ├── __init__.py │ ├── constructors.py │ ├── sites.py │ ├── peps.py │ ├── contraction.py │ └── update.py ├── quantum_state.py ├── tensors.py ├── observable.py └── gates.py ├── test ├── __init__.py ├── test_statevector.py └── test_peps.py ├── setup.py ├── LICENSE ├── .gitignore └── README.md /koala/VERSION: -------------------------------------------------------------------------------- 1 | 0.3.0 -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /koala/statevector/__init__.py: -------------------------------------------------------------------------------- 1 | from .statevector import StateVector, braket 2 | from .constructors import computational_zeros, computational_ones, computational_basis, random 3 | -------------------------------------------------------------------------------- /koala/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | koala 3 | """ 4 | 5 | from .version import VERSION as __version__ 6 | 7 | from .quantum_state import QuantumState, Gate 8 | from .observable import Observable 9 | -------------------------------------------------------------------------------- /koala/version.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module imports the version number. 3 | """ 4 | 5 | import os 6 | 7 | with open(os.path.join(os.path.dirname(__file__), 'VERSION')) as version_file: 8 | VERSION = version_file.read().strip() 9 | -------------------------------------------------------------------------------- /koala/peps/__init__.py: -------------------------------------------------------------------------------- 1 | from .peps import PEPS, braket, save, load, make_expectation_cache 2 | from .constructors import computational_zeros, computational_ones, computational_basis, random, identity 3 | from .contraction import Snake, ABMPS, BMPS, SingleLayer, Square, TRG, contract_options 4 | from .update import DirectUpdate, QRUpdate, LocalGramQRUpdate, LocalGramQRSVDUpdate, DefaultUpdate 5 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | import os 3 | 4 | 5 | VERSION_PATH = os.path.join(os.path.dirname(__file__), 'koala', 'VERSION') 6 | with open(VERSION_PATH) as version_file: 7 | VERSION = version_file.read().strip() 8 | 9 | 10 | setup( 11 | name='koala', 12 | version=VERSION, 13 | packages=find_packages(exclude=[]), 14 | package_data={ 15 | 'koala': ['VERSION'], 16 | }, 17 | install_requires=[ 18 | 'numpy>=1.17', 19 | ], 20 | ) 21 | -------------------------------------------------------------------------------- /koala/statevector/constructors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorbackends 3 | 4 | from .statevector import StateVector 5 | 6 | 7 | def computational_zeros(nsite, *, backend='numpy'): 8 | backend = tensorbackends.get(backend) 9 | tensor = backend.zeros((2,)*nsite, dtype=complex) 10 | tensor[(0,)*nsite] = 1 11 | return StateVector(tensor, backend) 12 | 13 | 14 | def 
computational_ones(nsite, *, backend='numpy'): 15 | backend = tensorbackends.get(backend) 16 | tensor = backend.zeros((2,)*nsite, dtype=complex) 17 | tensor[(1,)*nsite] = 1 18 | return StateVector(tensor, backend) 19 | 20 | 21 | def computational_basis(nsite, bits, *, backend='numpy'): 22 | backend = tensorbackends.get(backend) 23 | bits = np.asarray(bits).reshape(nsite) 24 | tensor = backend.zeros((2,)*nsite, dtype=complex) 25 | tensor[tuple(bits)] = 1 26 | return StateVector(tensor, backend) 27 | 28 | def random(nsite, *, backend='numpy'): 29 | backend = tensorbackends.get(backend) 30 | shape = (2,)*nsite 31 | tensor = backend.random.uniform(-1,1,shape) + 1j * backend.random.uniform(-1,1,shape) 32 | tensor /= backend.norm(tensor) 33 | return StateVector(tensor, backend) 34 | -------------------------------------------------------------------------------- /koala/quantum_state.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module defines the interface of a quantum state. 3 | """ 4 | 5 | from collections import namedtuple 6 | 7 | 8 | Gate = namedtuple('Gate', ['name', 'parameters', 'qubits']) 9 | 10 | 11 | class QuantumState: 12 | @property 13 | def nsite(self): 14 | raise NotImplementedError() 15 | 16 | def copy(self): 17 | raise NotImplementedError() 18 | 19 | def norm(self): 20 | raise NotImplementedError() 21 | 22 | def conjugate(self): 23 | raise NotImplementedError() 24 | 25 | def apply_gate(self, gate): 26 | raise NotImplementedError() 27 | 28 | def apply_circuit(self, gates): 29 | raise NotImplementedError() 30 | 31 | def apply_operator(self, operator, sites): 32 | raise NotImplementedError() 33 | 34 | def __imul__(self, a): 35 | raise NotImplementedError() 36 | 37 | def __itruediv__(self, a): 38 | raise NotImplementedError() 39 | 40 | def amplitude(self, indices): 41 | raise NotImplementedError() 42 | 43 | def probability(self, indices): 44 | raise NotImplementedError() 45 | 46 | def expectation(self, observable): 47 | raise NotImplementedError() 48 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The Clear BSD License 2 | 3 | Copyright (c) 2019-2020, Koala Developers. 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted (subject to the limitations in the disclaimer 8 | below) provided that the following conditions are met: 9 | 10 | * Redistributions of source code must retain the above copyright notice, 11 | this list of conditions and the following disclaimer. 12 | 13 | * Redistributions in binary form must reproduce the above copyright 14 | notice, this list of conditions and the following disclaimer in the 15 | documentation and/or other materials provided with the distribution. 16 | 17 | * Neither the name of the copyright holder nor the names of its 18 | contributors may be used to endorse or promote products derived from this 19 | software without specific prior written permission. 20 | 21 | NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY 22 | THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND 23 | CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 24 | LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A 25 | PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR 26 | CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, 27 | EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, 28 | PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR 29 | BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER 30 | IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 31 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 32 | POSSIBILITY OF SUCH DAMAGE. 33 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # IDE & editor configurations 107 | .vscode/ -------------------------------------------------------------------------------- /koala/peps/constructors.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorbackends 3 | 4 | from .peps import PEPS 5 | 6 | def computational_zeros(nrow, ncol, backend='numpy'): 7 | backend = tensorbackends.get(backend) 8 | grid = np.empty((nrow, ncol), dtype=object) 9 | for i, j in np.ndindex(nrow, ncol): 10 | grid[i, j] = backend.astensor(np.array([1,0],dtype=complex).reshape(1,1,1,1,2,1)) 11 | return PEPS(grid, backend) 12 | 13 | 14 | def computational_ones(nrow, ncol, backend='numpy'): 15 | backend = tensorbackends.get(backend) 16 | grid = np.empty((nrow, ncol), dtype=object) 17 | for i, j in np.ndindex(nrow, ncol): 18 | grid[i, j] = backend.astensor(np.array([0,1],dtype=complex).reshape(1,1,1,1,2,1)) 19 | return PEPS(grid, backend) 20 | 21 | 22 | def computational_basis(nrow, ncol, bits, backend='numpy'): 23 | backend = tensorbackends.get(backend) 24 | bits = np.asarray(bits).reshape(nrow, ncol) 
25 |     grid = np.empty_like(bits, dtype=object)
26 |     for i, j in np.ndindex(*bits.shape):
27 |         grid[i, j] = backend.astensor(
28 |             np.array([0,1] if bits[i,j] else [1,0],dtype=complex).reshape(1,1,1,1,2,1)
29 |         )
30 |     return PEPS(grid, backend)
31 | 
32 | 
33 | def random(nrow, ncol, rank, physical_dim=2, dual_dim=1, backend='numpy'):
34 |     backend = tensorbackends.get(backend)
35 |     grid = np.empty((nrow, ncol), dtype=object)
36 |     for i, j in np.ndindex(nrow, ncol):
37 |         shape = (
38 |             rank if i > 0 else 1,
39 |             rank if j < ncol - 1 else 1,
40 |             rank if i < nrow - 1 else 1,
41 |             rank if j > 0 else 1,
42 |             physical_dim, dual_dim,
43 |         )
44 |         grid[i, j] = backend.random.uniform(-1,1,shape) + 1j * backend.random.uniform(-1,1,shape)
45 |     return PEPS(grid, backend)
46 | 
47 | 
48 | def identity(nrow, ncol, backend='numpy'):
49 |     backend = tensorbackends.get(backend)
50 |     grid = np.empty((nrow, ncol), dtype=object)
51 |     for i, j in np.ndindex(nrow, ncol):
52 |         grid[i, j] = backend.astensor(np.eye(2,dtype=complex).reshape(1,1,1,1,2,2))
53 |     return PEPS(grid, backend)
54 | 
--------------------------------------------------------------------------------
/koala/peps/sites.py:
--------------------------------------------------------------------------------
1 | """
2 | This module implements operations on PEPS sites.
3 | 
4 | Here are the conventions for names of directions and order of axes.
5 | 
6 | Direction names:      |  Axes order:
7 |    z                 |       4 0
8 |    |/                |       |/
9 | -- O -- y            |  3 -- O -- 1
10 |   /|                 |      /|
11 |  x                   |     2 5
12 | 
13 | We use legs 0, 1, 2, 3 for the bond dimensions, leg 4 for the physical dimension,
14 | and leg 5 for the dimension of the dual space w.r.t. leg 4. Conventionally, the
15 | +x direction corresponds to the direction in which the row number grows, while
16 | the +y direction corresponds to the direction in which the column number grows.
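For example, a computational-basis site tensor built by koala.peps.constructors has
shape (1, 1, 1, 1, 2, 1): four trivial bond legs, a physical leg of dimension 2, and a
trivial dual leg. Contracting two horizontally adjacent sites with contract_y merges
their physical (and dual) legs, producing a single site of shape (1, 1, 1, 1, 4, 1).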
17 | """ 18 | 19 | import numpy as np 20 | 21 | 22 | def contract_x(a, b): 23 | return a.backend.einsum('abidpq,iBcDPQ->a(bB)c(dD)(pP)(qQ)', a, b) 24 | 25 | def contract_y(a, b): 26 | return a.backend.einsum('aicdpq,AbCiPQ->(aA)b(cC)d(pP)(qQ)', a, b) 27 | 28 | def contract_z(a, b): 29 | return a.backend.einsum('abcdiq,ABCDpi->(aA)(bB)(cC)(dD)pq', a, b) 30 | 31 | 32 | def reduce_x(a, b, option): 33 | u, _, vh = a.backend.einsumsvd('abidpq,iBcDPQ->abIdpq,IBcDPQ', a, b, option=option, absorb_s='even') 34 | return u, vh 35 | 36 | def reduce_y(a, b, option): 37 | u, _, vh = a.backend.einsumsvd('aicdpq,AbCiPQ->aIcdpq,AbCIPQ', a, b, option=option, absorb_s='even') 38 | return u, vh 39 | 40 | def reduce_z(a, b, option): 41 | u, _, vh = a.backend.einsumsvd('abcdpi,ABCDiq->abcdpI,ABCDIq', a, b, option=option, absorb_s='even') 42 | return u, vh 43 | 44 | 45 | def rotate_x(a, n=1): 46 | p = np.roll([4, 1, 5, 3], n) 47 | return a.transpose(0, p[1], 2, p[3], p[0], p[2]) 48 | 49 | def rotate_y(a, n=1): 50 | p = np.roll([4, 0, 5, 2], n) 51 | return a.transpose(p[2], 1, p[4], 3, p[0], p[3]) 52 | 53 | def rotate_z(a, n=1): 54 | p = np.roll([0, 1, 2, 3], n) 55 | return a.transpose(p[0], p[1], p[2], p[3], 4, 5) 56 | 57 | 58 | def flip_x(a): 59 | return a.transpose(2, 1, 0, 3, 4, 5) 60 | 61 | def flip_y(a): 62 | return a.transpose(0, 3, 2, 1, 4, 5) 63 | 64 | def flip_z(a): 65 | return a.transpose(0, 1, 2, 3, 5, 4) 66 | 67 | 68 | def trace_z(a): 69 | return a.backend.einsum('abcdii->abcd()()', a) 70 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Koala 2 | 3 | Koala is a quantum circuit/state simulator based on projected entangled–pair states (PEPS) tensor networks. 4 | 5 | ### Installation 6 | 7 | NumPy and [TensorBackends](https://github.com/cyclops-community/tensorbackends) are required. 8 | It is recommended to install Koala and TensorBackends in the editable (`pip install -e`) mode, as they are currently under development. 9 | 10 | ```console 11 | git clone https://github.com/cyclops-community/tensorbackends.git 12 | pip install -e ./tensorbackends 13 | git clone https://github.com/YiqingZhouKelly/koala.git 14 | pip install -e ./koala 15 | ``` 16 | 17 | Parallelization is provided by [Cyclops Tensor Framework](https://github.com/cyclops-community/ctf), and is optional. 
18 | 19 | ### Testing 20 | ```console 21 | python -m unittest 22 | ``` 23 | 24 | ### Get Started 25 | ```python 26 | from koala import peps, Observable, Gate 27 | from tensorbackends.interface import ImplicitRandomizedSVD 28 | 29 | # initialize a 2 by 3 state with peps approach and numpy backend 30 | qstate = peps.computational_zeros(2, 3, backend='numpy') 31 | 32 | # we also provide the state vector approach and a parallel backend 33 | # statevector.computational_zeros(2, 3, backend='ctf') 34 | 35 | # apply one gate or a list of gates 36 | qstate.apply_gate(Gate('X', [], [0])) # (name, parameters, qubits) 37 | qstate.apply_circuit([ 38 | Gate('R', [0.42], [2]), 39 | Gate('CX', [], [1,4]) 40 | ], update_option=peps.LocalGramQRUpdate(rank=4)) 41 | # choose from a list of update algorithms and optionally specify the cap bond dimension for approximate state evolution 42 | 43 | # or apply arbitrary single-site or two-site operators 44 | # qstate.apply_operator(np.array(...), [0]) 45 | 46 | # compute the amplitude, probability, and expectation value 47 | qstate.amplitude([1,0,0,1,0,0]) 48 | qstate.probability([1,0,0,1,0,0]) 49 | observable = 1.5 * Observable.sum([ 50 | Observable.Z(0), 51 | Observable.XY(3, 4) * 2 52 | ]) 53 | qstate.expectation(observable, 54 | contract_option=peps.BMPS(ImplicitRandomizedSVD(rank=8)), 55 | use_cache=True 56 | ) 57 | # choose from a list of contraction algorithms and optionally specify the cap bond dimension for approximate contraction 58 | # use built-in caching option to trade memory for time in expectation value calculations 59 | ``` 60 | -------------------------------------------------------------------------------- /koala/tensors.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module defines some common tensors in NumPy. 
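Single-qubit gates are returned as complex arrays of shape (2, 2); two-qubit gates are
returned with shape (2, 2, 2, 2), with axes ordered as (output 0, output 1, input 0,
input 1). For example, the controlled-X tensor control(1, X()) has its only
off-diagonal ones at [1, 1, 1, 0] and [1, 0, 1, 1].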
3 | """ 4 | 5 | from math import sqrt 6 | from itertools import repeat 7 | 8 | import numpy as np 9 | 10 | 11 | def control(nctrl, tensor): 12 | nqubit = nctrl + tensor.ndim // 2 13 | result = np.eye(2**nqubit, dtype=complex) 14 | result[2**nctrl:, 2**nctrl:] = tensor 15 | return result.reshape(*repeat(2, nqubit*2)) 16 | 17 | def H(): 18 | return (np.array([1,1,1,-1],dtype=complex)/np.sqrt(2)).reshape(2,2) 19 | 20 | def X(): 21 | return np.array([0,1,1,0],dtype=complex).reshape(2,2) 22 | 23 | def Y(): 24 | return np.array([0,-1j,1j,0],dtype=complex).reshape(2,2) 25 | 26 | def Z(): 27 | return np.array([1,0,0,-1],dtype=complex).reshape(2,2) 28 | 29 | def S(): 30 | return np.array([1,0,0,1j],dtype=complex).reshape(2,2) 31 | 32 | def Sdag(): 33 | return np.array([1,0,0,-1j],dtype=complex).reshape(2,2) 34 | 35 | def T(): 36 | return U1(np.pi/4) 37 | 38 | def Tdag(): 39 | return U1(-np.pi/4) 40 | 41 | def W(): 42 | return (X()+Y())/sqrt(2) 43 | 44 | def sqrtX(): 45 | return np.array([1+1j,1-1j,1-1j,1+1j],dtype=complex).reshape(2,2)/2 46 | 47 | def sqrtY(): 48 | return np.array([1+1j,-1-1j,1+1j,1+1j],dtype=complex).reshape(2,2)/2 49 | 50 | def sqrtZ(): 51 | return S() 52 | 53 | def sqrtW(): 54 | return np.array([1+1j,-sqrt(2)*1j,sqrt(2),1+1j],dtype=complex).reshape(2,2)/2 55 | 56 | def R(theta): 57 | return U1(theta) 58 | 59 | def U1(lmbda): 60 | return U3(0, 0, lmbda) 61 | 62 | def U2(phi, lmbda): 63 | return U3(np.pi/2, phi, lmbda) 64 | 65 | def U3(theta, phi, lmbda): 66 | c, s = np.cos(theta), np.sin(theta) 67 | e_phi, e_lmbda = np.exp(1j*phi), np.exp(1j*lmbda) 68 | return np.array([c,-e_lmbda*s,e_phi*s,e_lmbda*e_phi*c],dtype=complex).reshape(2,2) 69 | 70 | def SWAP(): 71 | return np.array([1,0,0,0,0,0,1,0,0,1,0,0,0,0,0,1],dtype=complex).reshape(2,2,2,2) 72 | 73 | def ISWAP(): 74 | return np.array([1,0,0,0,0,0,1j,0,0,1j,0,0,0,0,0,1],dtype=complex).reshape(2,2,2,2) 75 | 76 | def XX(): 77 | return np.einsum('ij,kl->ikjl', X(), X()) 78 | 79 | def XY(): 80 | return np.einsum('ij,kl->ikjl', X(), Y()) 81 | 82 | def XZ(): 83 | return np.einsum('ij,kl->ikjl', X(), Z()) 84 | 85 | def YY(): 86 | return np.einsum('ij,kl->ikjl', Y(), Y()) 87 | 88 | def YZ(): 89 | return np.einsum('ij,kl->ikjl', Y(), Z()) 90 | 91 | def ZZ(): 92 | return np.einsum('ij,kl->ikjl', Z(), Z()) 93 | -------------------------------------------------------------------------------- /test/test_statevector.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from tensorbackends.utils import test_with_backend 5 | 6 | from koala import Observable, statevector, Gate 7 | 8 | 9 | @test_with_backend() 10 | class TestStateVector(unittest.TestCase): 11 | def test_norm(self, backend): 12 | qstate = statevector.computational_zeros(6, backend=backend) 13 | qstate.apply_circuit([ 14 | Gate('X', [], [0]), 15 | Gate('H', [], [1]), 16 | Gate('CX', [], [0,3]), 17 | Gate('CX', [], [1,4]), 18 | Gate('S', [], [1]), 19 | ]) 20 | self.assertTrue(backend.isclose(qstate.norm(), 1)) 21 | qstate *= 2 22 | self.assertTrue(backend.isclose(qstate.norm(), 2)) 23 | qstate /= 2j 24 | self.assertTrue(backend.isclose(qstate.norm(), 1)) 25 | 26 | def test_amplitude(self, backend): 27 | qstate = statevector.computational_zeros(6, backend=backend) 28 | qstate.apply_circuit([ 29 | Gate('X', [], [0]), 30 | Gate('H', [], [1]), 31 | Gate('CX', [], [0,3]), 32 | Gate('CX', [], [1,4]), 33 | Gate('S', [], [1]), 34 | ]) 35 | self.assertTrue(backend.isclose(qstate.amplitude([1,0,0,1,0,0]), 
1/np.sqrt(2))) 36 | self.assertTrue(backend.isclose(qstate.amplitude([1,1,0,1,1,0]), 1j/np.sqrt(2))) 37 | 38 | def test_probablity(self, backend): 39 | qstate = statevector.computational_zeros(6, backend=backend) 40 | qstate.apply_circuit([ 41 | Gate('X', [], [0]), 42 | Gate('H', [], [1]), 43 | Gate('CX', [], [0,3]), 44 | Gate('CX', [], [1,4]), 45 | Gate('S', [], [1]), 46 | ]) 47 | self.assertTrue(backend.isclose(qstate.probability([1,0,0,1,0,0]), 1/2)) 48 | self.assertTrue(backend.isclose(qstate.probability([1,1,0,1,1,0]), 1/2)) 49 | 50 | def test_expectation(self, backend): 51 | qstate = statevector.computational_zeros(6, backend=backend) 52 | qstate.apply_circuit([ 53 | Gate('X', [], [0]), 54 | Gate('CX', [], [0,3]), 55 | Gate('H', [], [2]), 56 | ]) 57 | observable = 1.5 * Observable.sum([ 58 | Observable.Z(0) * 2, 59 | Observable.Z(1), 60 | Observable.Z(2) * 2, 61 | Observable.Z(3), 62 | ]) 63 | self.assertTrue(backend.isclose(qstate.expectation(observable), -3)) 64 | 65 | def test_add(self, backend): 66 | psi = statevector.computational_zeros(6, backend=backend) 67 | phi = statevector.computational_ones(6, backend=backend) 68 | self.assertTrue(backend.isclose((psi + phi).norm(), np.sqrt(2))) 69 | 70 | def test_inner(self, backend): 71 | psi = statevector.computational_zeros(6, backend=backend) 72 | psi.apply_circuit([ 73 | Gate('H', [], [0]), 74 | Gate('CX', [], [0,3]), 75 | Gate('H', [], [3]), 76 | ]) 77 | phi = statevector.computational_zeros(6, backend=backend) 78 | self.assertTrue(backend.isclose(psi.inner(phi), 0.5)) 79 | -------------------------------------------------------------------------------- /koala/observable.py: -------------------------------------------------------------------------------- 1 | """ 2 | """ 3 | 4 | from numbers import Real 5 | 6 | import numpy as np 7 | 8 | from . 
import tensors 9 | 10 | 11 | class Observable: 12 | def __init__(self, operators): 13 | self.operators = operators 14 | 15 | @staticmethod 16 | def zero(): 17 | return Observable([]) 18 | 19 | @staticmethod 20 | def X(qubit): 21 | return Observable([(tensors.X(), (qubit,))]) 22 | 23 | @staticmethod 24 | def Y(qubit): 25 | return Observable([(tensors.Y(), (qubit,))]) 26 | 27 | @staticmethod 28 | def Z(qubit): 29 | return Observable([(tensors.Z(), (qubit,))]) 30 | 31 | @staticmethod 32 | def XX(first, second): 33 | return Observable([(tensors.XX(), (first, second))]) 34 | 35 | @staticmethod 36 | def XY(first, second): 37 | return Observable([(tensors.XY(), (first, second))]) 38 | 39 | @staticmethod 40 | def XZ(first, second): 41 | return Observable([(tensors.XZ(), (first, second))]) 42 | 43 | @staticmethod 44 | def YY(first, second): 45 | return Observable([(tensors.YY(), (first, second))]) 46 | 47 | @staticmethod 48 | def YZ(first, second): 49 | return Observable([(tensors.YZ(), (first, second))]) 50 | 51 | @staticmethod 52 | def ZZ(first, second): 53 | return Observable([(tensors.ZZ(), (first, second))]) 54 | 55 | @staticmethod 56 | def operator(tensor, qubits): 57 | if tensor.ndim != len(qubits) * 2: 58 | raise ValueError(f'tensor shape and number of target qubits do not match') 59 | return Observable([(tensor, qubits)]) 60 | 61 | @staticmethod 62 | def sum(observables): 63 | result = Observable.zero() 64 | for observable in observables: 65 | result += observable 66 | return result 67 | 68 | def __iter__(self): 69 | yield from self.operators 70 | 71 | def scale(self, a): 72 | return Observable([(tensor*a, qubits) for tensor, qubits in self.operators]) 73 | 74 | def __pos__(self): 75 | return Observable([*self.operators]) 76 | 77 | def __neg__(self): 78 | return Observable([(-tensor, qubits) for tensor, qubits in self.operators]) 79 | 80 | def __add__(self, other): 81 | return Observable([*self, *other]) 82 | 83 | def __iadd__(self, other): 84 | self.operators.extend(other.operators) 85 | return self 86 | 87 | def __mul__(self, other): 88 | if isinstance(other, Real): 89 | return self.scale(other) 90 | return NotImplemented 91 | 92 | def __rmul__(self, other): 93 | if isinstance(other, Real): 94 | return self.scale(other) 95 | return NotImplemented 96 | 97 | def __str__(self): 98 | operators_str = ';'.join( 99 | f'{operator},{qubits}' 100 | for operator, qubits in self.operators 101 | ) 102 | return f"Observable({operators_str})" 103 | 104 | def copy(self): 105 | operators = [] 106 | for tensor, qubits in self.operators: 107 | operators.append((tensor.copy(), qubits)) 108 | return Observable(operators) 109 | -------------------------------------------------------------------------------- /koala/gates.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module defines the gates supported by the simulators and their tensor 3 | representations. 4 | """ 5 | 6 | from functools import lru_cache 7 | 8 | from . 
import tensors 9 | 10 | 11 | def tensorize(backend, gate_name, *gate_parameters): 12 | if gate_name not in _GATES: 13 | raise ValueError(f"{gate_name} gate is not supported") 14 | return _GATES[gate_name](backend, *gate_parameters) 15 | 16 | 17 | _GATES = {} 18 | 19 | def _register(func): 20 | _GATES[func.__name__] = func 21 | return func 22 | 23 | @_register 24 | @lru_cache(maxsize=None) 25 | def H(backend): 26 | return backend.astensor(tensors.H()) 27 | 28 | @_register 29 | @lru_cache(maxsize=None) 30 | def X(backend): 31 | return backend.astensor(tensors.X()) 32 | 33 | @_register 34 | @lru_cache(maxsize=None) 35 | def Y(backend): 36 | return backend.astensor(tensors.Y()) 37 | 38 | @_register 39 | @lru_cache(maxsize=None) 40 | def Z(backend): 41 | return backend.astensor(tensors.Z()) 42 | 43 | @_register 44 | @lru_cache(maxsize=None) 45 | def S(backend): 46 | return backend.astensor(tensors.S()) 47 | 48 | @_register 49 | @lru_cache(maxsize=None) 50 | def Sdag(backend): 51 | return backend.astensor(tensors.Sdag()) 52 | 53 | @_register 54 | @lru_cache(maxsize=None) 55 | def T(backend): 56 | return backend.astensor(tensors.T()) 57 | 58 | @_register 59 | @lru_cache(maxsize=None) 60 | def Tdag(backend): 61 | return backend.astensor(tensors.Tdag()) 62 | 63 | @_register 64 | @lru_cache(maxsize=None) 65 | def W(backend): 66 | return backend.astensor(tensors.W()) 67 | 68 | @_register 69 | @lru_cache(maxsize=None) 70 | def sqrtX(backend): 71 | return backend.astensor(tensors.sqrtX()) 72 | 73 | @_register 74 | @lru_cache(maxsize=None) 75 | def sqrtY(backend): 76 | return backend.astensor(tensors.sqrtY()) 77 | 78 | @_register 79 | @lru_cache(maxsize=None) 80 | def sqrtZ(backend): 81 | return backend.astensor(tensors.sqrtZ()) 82 | 83 | @_register 84 | @lru_cache(maxsize=None) 85 | def sqrtW(backend): 86 | return backend.astensor(tensors.sqrtW()) 87 | 88 | @_register 89 | @lru_cache(maxsize=64) 90 | def R(backend, theta): 91 | return backend.astensor(tensors.R(theta)) 92 | 93 | @_register 94 | @lru_cache(maxsize=64) 95 | def U1(backend, lmbda): 96 | return backend.astensor(tensors.U1(lmbda)) 97 | 98 | @_register 99 | @lru_cache(maxsize=64) 100 | def U2(backend, phi, lmbda): 101 | return backend.astensor(tensors.U2(phi, lmbda)) 102 | 103 | @_register 104 | @lru_cache(maxsize=64) 105 | def U3(backend, theta, phi, lmbda): 106 | return backend.astensor(tensors.U3(theta, phi, lmbda)) 107 | 108 | @_register 109 | @lru_cache(maxsize=None) 110 | def CH(backend): 111 | return backend.astensor(tensors.control(1, tensors.H())) 112 | 113 | @_register 114 | @lru_cache(maxsize=None) 115 | def CX(backend): 116 | return backend.astensor(tensors.control(1, tensors.X())) 117 | 118 | @_register 119 | @lru_cache(maxsize=None) 120 | def CY(backend): 121 | return backend.astensor(tensors.control(1, tensors.Y())) 122 | 123 | @_register 124 | @lru_cache(maxsize=None) 125 | def CZ(backend): 126 | return backend.astensor(tensors.control(1, tensors.Z())) 127 | 128 | @_register 129 | @lru_cache(maxsize=64) 130 | def CR(backend, theta): 131 | return backend.astensor(tensors.control(1, tensors.R(theta))) 132 | 133 | @_register 134 | @lru_cache(maxsize=64) 135 | def CU1(backend, lmbda): 136 | return backend.astensor(tensors.control(1, tensors.U1(lmbda))) 137 | 138 | @_register 139 | @lru_cache(maxsize=64) 140 | def CU2(backend, phi, lmbda): 141 | return backend.astensor(tensors.control(1, tensors.U2(phi, lmbda))) 142 | 143 | @_register 144 | @lru_cache(maxsize=64) 145 | def CU3(backend, theta, phi, lmbda): 146 | return 
backend.astensor(tensors.control(1, tensors.U3(theta, phi, lmbda)))
147 | 
148 | @_register
149 | @lru_cache(maxsize=None)
150 | def SWAP(backend):
151 |     return backend.astensor(tensors.SWAP())
152 | 
153 | @_register
154 | @lru_cache(maxsize=None)
155 | def ISWAP(backend):
156 |     return backend.astensor(tensors.ISWAP())
157 | 
--------------------------------------------------------------------------------
/koala/statevector/statevector.py:
--------------------------------------------------------------------------------
1 | """
2 | This module implements the state-vector quantum register.
3 | """
4 | 
5 | from numbers import Number
6 | from string import ascii_letters as chars
7 | import numpy as np
8 | 
9 | import tensorbackends
10 | 
11 | from ..quantum_state import QuantumState
12 | from ..gates import tensorize
13 | 
14 | 
15 | class StateVector(QuantumState):
16 |     def __init__(self, tensor, backend):
17 |         self.backend = tensorbackends.get(backend)
18 |         self.tensor = tensor
19 | 
20 |     @property
21 |     def nsite(self):
22 |         return self.tensor.ndim
23 | 
24 |     def copy(self):
25 |         return StateVector(self.tensor.copy(), self.backend)
26 | 
27 |     def conjugate(self):
28 |         return StateVector(self.tensor.conj(), self.backend)
29 | 
30 |     def apply_gate(self, gate):
31 |         tensor = tensorize(self.backend, gate.name, *gate.parameters)
32 |         self.apply_operator(tensor, gate.qubits)
33 | 
34 |     def apply_circuit(self, gates):
35 |         for gate in gates:
36 |             self.apply_gate(gate)
37 | 
38 |     def apply_operator(self, operator, sites):
39 |         self.tensor = apply_operator(self.backend, self.tensor, operator, sites)
40 | 
41 |     def norm(self):
42 |         return self.backend.norm(self.tensor)
43 | 
44 |     def amplitude(self, indices):
45 |         if len(indices) != self.nsite:
46 |             raise ValueError('indices number and sites number do not match')
47 |         return self.tensor[tuple(indices)]
48 | 
49 |     def probability(self, indices):
50 |         return abs(self.amplitude(indices))**2
51 | 
52 |     def expectation(self, observable):
53 |         return braket(self, observable, self).real
54 | 
55 |     def probabilities(self):
56 |         prob_vector = np.real(self.tensor)**2 + np.imag(self.tensor)**2
57 |         return [(index, p) for index, p in np.ndenumerate(prob_vector) if not np.isclose(p, 0)]
58 | 
59 |     def inner(self, other):
60 |         terms = ''.join(chars[i] for i in range(self.nsite))
61 |         subscripts = terms + ',' + terms + '->'
62 |         return self.backend.einsum(subscripts, self.tensor.conj(), other.tensor)
63 | 
64 | 
65 | def apply_operator(backend, state_tensor, operator, axes):
66 |     operator = backend.astensor(operator)
67 |     ndim = state_tensor.ndim
68 |     input_state_indices = range(ndim)
69 |     operator_indices = [*range(ndim, ndim+len(axes)), *axes]
70 |     output_state_indices = [*range(ndim)]
71 |     for i, axis in enumerate(axes):
72 |         output_state_indices[axis] = i + ndim
73 |     input_terms = ''.join(chars[i] for i in input_state_indices)
74 |     operator_terms = ''.join(chars[i] for i in operator_indices)
75 |     output_terms = ''.join(chars[i] for i in output_state_indices)
76 |     einstr = f'{input_terms},{operator_terms}->{output_terms}'
77 |     return backend.einsum(einstr, state_tensor, operator)
78 | 
79 | 
80 | def braket(p, observable, q):
81 |     if p.backend != q.backend:
82 |         raise ValueError('two states must use the same backend')
83 |     if p.nsite != q.nsite:
84 |         raise ValueError('number of sites must be equal in both states')
85 |     all_terms = ''.join(chars[i] for i in range(q.nsite))
86 |     einstr = f'{all_terms},{all_terms}->'
87 |     p_tensor_conj = p.tensor.conj()
88 |     e = 0
89 |     for operator,
sites in observable: 90 | r = apply_operator(q.backend, q.tensor, operator, sites) 91 | e += p.backend.einsum(einstr, p_tensor_conj, r) 92 | return e 93 | 94 | 95 | def inherit_unary_operators(*operator_names): 96 | def add_unary_operator(operator_name): 97 | def method(self): 98 | return StateVector(getattr(self.tensor, operator_name)(), self.backend) 99 | method.__module__ = StateVector.__module__ 100 | method.__qualname__ = '{}.{}'.format(StateVector.__qualname__, operator_name) 101 | method.__name__ = operator_name 102 | setattr(StateVector, operator_name, method) 103 | for op_name in operator_names: 104 | add_unary_operator(op_name) 105 | 106 | 107 | def inherit_binary_operators(*operator_names): 108 | def add_binary_operator(operator_name): 109 | def method(self, other): 110 | if isinstance(other, StateVector) and self.backend == other.backend: 111 | return StateVector(getattr(self.tensor, operator_name)(other.tensor), self.backend) 112 | elif isinstance(other, Number): 113 | return StateVector(getattr(self.tensor, operator_name)(other), self.backend) 114 | else: 115 | return NotImplemented 116 | method.__module__ = StateVector.__module__ 117 | method.__qualname__ = '{}.{}'.format(StateVector.__qualname__, operator_name) 118 | method.__name__ = operator_name 119 | setattr(StateVector, operator_name, method) 120 | for op_name in operator_names: 121 | add_binary_operator(op_name) 122 | 123 | 124 | inherit_unary_operators( 125 | '__pos__', 126 | '__neg__', 127 | ) 128 | 129 | inherit_binary_operators( 130 | '__add__', 131 | '__sub__', 132 | '__mul__', 133 | '__truediv__', 134 | '__pow__', 135 | 136 | '__radd__', 137 | '__rsub__', 138 | '__rmul__', 139 | '__rtruediv__', 140 | '__rpow__', 141 | 142 | '__iadd__', 143 | '__isub__', 144 | '__imul__', 145 | '__itruediv__', 146 | '__ipow__', 147 | ) 148 | -------------------------------------------------------------------------------- /test/test_peps.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | from tensorbackends.interface import ImplicitRandomizedSVD, ReducedSVD, RandomizedSVD 5 | from tensorbackends.utils import test_with_backend 6 | 7 | from koala import Observable, observable, peps, statevector, Gate 8 | from koala.peps import contract_options, Snake, ABMPS, BMPS, Square, TRG, contraction 9 | 10 | 11 | @test_with_backend() 12 | class TestPEPS(unittest.TestCase): 13 | def test_norm(self, backend): 14 | qstate = peps.computational_zeros(2, 3, backend=backend) 15 | qstate.apply_circuit([ 16 | Gate('X', [], [0]), 17 | Gate('H', [], [1]), 18 | Gate('CX', [], [0,3]), 19 | Gate('CX', [], [1,4]), 20 | Gate('S', [], [1]), 21 | ]) 22 | self.assertTrue(backend.isclose(qstate.norm(), 1)) 23 | qstate *= 2 24 | self.assertTrue(backend.isclose(qstate.norm(), 2)) 25 | qstate /= 2j 26 | self.assertTrue(backend.isclose(qstate.norm(), 1)) 27 | 28 | def test_trace(self, backend): 29 | observable = Observable.ZZ(0,1) + Observable.ZZ(0,3) 30 | qstate = peps.identity(3, 3, backend=backend) 31 | self.assertTrue(backend.isclose(qstate.trace(), 2**qstate.nsite)) 32 | self.assertTrue(backend.isclose(qstate.trace(observable), 0, atol=1e-8)) 33 | 34 | def test_trace_with_cache(self, backend): 35 | observable = Observable.ZZ(0,1) + Observable.ZZ(0,3) 36 | contract_option = BMPS(ReducedSVD(1)) 37 | qstate = peps.identity(3, 3, backend=backend) 38 | cache = qstate.make_trace_cache(contract_option) 39 | self.assertTrue(backend.isclose( 40 | 
qstate.trace(contract_option=contract_option, cache=cache), 41 | 2**qstate.nsite 42 | )) 43 | self.assertTrue(backend.isclose( 44 | qstate.trace(observable, contract_option=contract_option, cache=cache), 45 | 0, atol=1e-8 46 | )) 47 | 48 | def test_amplitude(self, backend): 49 | qstate = peps.computational_zeros(2, 3, backend=backend) 50 | qstate.apply_circuit([ 51 | Gate('X', [], [0]), 52 | Gate('H', [], [1]), 53 | Gate('CX', [], [0,3]), 54 | Gate('CX', [], [1,4]), 55 | Gate('S', [], [1]), 56 | ]) 57 | self.assertTrue(backend.isclose(qstate.amplitude([1,0,0,1,0,0]), 1/np.sqrt(2))) 58 | self.assertTrue(backend.isclose(qstate.amplitude([1,1,0,1,1,0]), 1j/np.sqrt(2))) 59 | 60 | def test_amplitude_approx(self, backend): 61 | qstate = peps.computational_zeros(2, 3, backend=backend) 62 | qstate.apply_circuit([ 63 | Gate('X', [], [0]), 64 | Gate('H', [], [1]), 65 | Gate('CX', [], [0,3]), 66 | Gate('CX', [], [1,4]), 67 | Gate('S', [], [1]), 68 | ], update_option=peps.DirectUpdate(ImplicitRandomizedSVD(rank=2))) 69 | contract_option = peps.BMPS(ReducedSVD(rank=2)) 70 | self.assertTrue(backend.isclose(qstate.amplitude([1,0,0,1,0,0], contract_option), 1/np.sqrt(2))) 71 | self.assertTrue(backend.isclose(qstate.amplitude([1,1,0,1,1,0], contract_option), 1j/np.sqrt(2))) 72 | 73 | def test_amplitude_qr_update(self, backend): 74 | qstate = peps.computational_zeros(2, 3, backend=backend) 75 | qstate.apply_circuit([ 76 | Gate('X', [], [0]), 77 | Gate('H', [], [1]), 78 | Gate('CX', [], [0,3]), 79 | Gate('CX', [], [1,4]), 80 | Gate('S', [], [1]), 81 | ], update_option=peps.QRUpdate(rank=2)) 82 | contract_option = peps.BMPS(ReducedSVD(rank=2)) 83 | self.assertTrue(backend.isclose(qstate.amplitude([1,0,0,1,0,0], contract_option), 1/np.sqrt(2))) 84 | self.assertTrue(backend.isclose(qstate.amplitude([1,1,0,1,1,0], contract_option), 1j/np.sqrt(2))) 85 | 86 | def test_amplitude_local_gram_qr_update(self, backend): 87 | qstate = peps.computational_zeros(2, 3, backend=backend) 88 | qstate.apply_circuit([ 89 | Gate('X', [], [0]), 90 | Gate('H', [], [1]), 91 | Gate('CX', [], [0,3]), 92 | Gate('CX', [], [1,4]), 93 | Gate('S', [], [1]), 94 | ], update_option=peps.LocalGramQRUpdate(rank=2)) 95 | contract_option = peps.BMPS(ReducedSVD(rank=2)) 96 | self.assertTrue(backend.isclose(qstate.amplitude([1,0,0,1,0,0], contract_option), 1/np.sqrt(2))) 97 | self.assertTrue(backend.isclose(qstate.amplitude([1,1,0,1,1,0], contract_option), 1j/np.sqrt(2))) 98 | 99 | def test_amplitude_local_gram_qr_svd_update(self, backend): 100 | qstate = peps.computational_zeros(2, 3, backend=backend) 101 | qstate.apply_circuit([ 102 | Gate('X', [], [0]), 103 | Gate('H', [], [1]), 104 | Gate('CX', [], [0,3]), 105 | Gate('CX', [], [1,4]), 106 | Gate('S', [], [1]), 107 | ], update_option=peps.LocalGramQRSVDUpdate(rank=2)) 108 | contract_option = peps.BMPS(ReducedSVD(rank=2)) 109 | self.assertTrue(backend.isclose(qstate.amplitude([1,0,0,1,0,0], contract_option), 1/np.sqrt(2))) 110 | self.assertTrue(backend.isclose(qstate.amplitude([1,1,0,1,1,0], contract_option), 1j/np.sqrt(2))) 111 | 112 | def test_amplitude_nonlocal(self, backend): 113 | update_options = [ 114 | None, 115 | peps.DirectUpdate(ImplicitRandomizedSVD(rank=2)), 116 | peps.QRUpdate(rank=2), 117 | peps.LocalGramQRUpdate(rank=2), 118 | peps.LocalGramQRSVDUpdate(rank=2), 119 | ] 120 | for option in update_options: 121 | with self.subTest(update_option=option): 122 | qstate = peps.computational_zeros(2, 3, backend=backend) 123 | qstate.apply_circuit([ 124 | Gate('X', [], [0]), 125 | Gate('H', 
[], [1]), 126 | Gate('CX', [], [0,5]), 127 | Gate('CX', [], [1,3]), 128 | Gate('S', [], [1]), 129 | ], update_option=option) 130 | self.assertTrue(backend.isclose(qstate.amplitude([1,0,0,0,0,1]), 1/np.sqrt(2))) 131 | self.assertTrue(backend.isclose(qstate.amplitude([1,1,0,1,0,1]), 1j/np.sqrt(2))) 132 | 133 | def test_amplitude_flip(self, backend): 134 | update_options = [ 135 | None, 136 | peps.DirectUpdate(ImplicitRandomizedSVD(rank=2)), 137 | peps.QRUpdate(rank=2), 138 | peps.LocalGramQRUpdate(rank=2), 139 | peps.LocalGramQRSVDUpdate(rank=2), 140 | ] 141 | for option in update_options: 142 | with self.subTest(update_option=option): 143 | qstate = peps.computational_zeros(2, 3, backend=backend).flip() 144 | qstate.apply_circuit([ 145 | Gate('X', [], [0]), 146 | Gate('H', [], [1]), 147 | Gate('CX', [], [0,5]), 148 | Gate('CX', [], [1,3]), 149 | Gate('S', [], [1]), 150 | ], update_option=option, flip=True) 151 | qstate = qstate.flip() 152 | self.assertTrue(backend.isclose(qstate.amplitude([1,0,0,0,0,1]), 1/np.sqrt(2))) 153 | self.assertTrue(backend.isclose(qstate.amplitude([1,1,0,1,0,1]), 1j/np.sqrt(2))) 154 | 155 | def test_probablity(self, backend): 156 | qstate = peps.computational_zeros(2, 3, backend=backend) 157 | qstate.apply_circuit([ 158 | Gate('X', [], [0]), 159 | Gate('H', [], [1]), 160 | Gate('CX', [], [0,3]), 161 | Gate('CX', [], [1,4]), 162 | Gate('S', [], [1]), 163 | ]) 164 | self.assertTrue(backend.isclose(qstate.probability([1,0,0,1,0,0]), 1/2)) 165 | self.assertTrue(backend.isclose(qstate.probability([1,1,0,1,1,0]), 1/2)) 166 | 167 | def test_expectation(self, backend): 168 | qstate = peps.computational_zeros(2, 3, backend=backend) 169 | qstate.apply_circuit([ 170 | Gate('X', [], [0]), 171 | Gate('CX', [], [0,3]), 172 | Gate('H', [], [2]), 173 | ]) 174 | observable = 1.5 * Observable.sum([ 175 | Observable.Z(0) * 2, 176 | Observable.Z(1), 177 | Observable.Z(2) * 2, 178 | Observable.Z(3), 179 | ]) 180 | self.assertTrue(backend.isclose(qstate.expectation(observable), -3)) 181 | 182 | def test_expectation_single_layer(self, backend): 183 | qstate = peps.computational_zeros(2, 3, backend=backend) 184 | qstate.apply_circuit([ 185 | Gate('X', [], [0]), 186 | Gate('CX', [], [0,3]), 187 | Gate('H', [], [2]), 188 | ]) 189 | observable = 1.5 * Observable.sum([ 190 | Observable.Z(0) * 2, 191 | Observable.Z(1), 192 | Observable.Z(2) * 2, 193 | Observable.Z(3), 194 | ]) 195 | contract_option = peps.SingleLayer(ImplicitRandomizedSVD(rank=2)) 196 | self.assertTrue(backend.isclose(qstate.expectation(observable, contract_option=contract_option), -3)) 197 | 198 | def test_expectation_use_cache(self, backend): 199 | qstate = peps.computational_zeros(2, 3, backend=backend) 200 | qstate.apply_circuit([ 201 | Gate('X', [], [0]), 202 | Gate('CX', [], [0,3]), 203 | Gate('H', [], [2]), 204 | ]) 205 | observable = 1.5 * Observable.sum([ 206 | Observable.Z(0) * 2, 207 | Observable.Z(1), 208 | Observable.Z(2) * 2, 209 | Observable.Z(3), 210 | ]) 211 | self.assertTrue(backend.isclose(qstate.expectation(observable, use_cache=True), -3)) 212 | cache = peps.make_expectation_cache(qstate, qstate) 213 | self.assertTrue(backend.isclose(qstate.expectation(observable, use_cache=cache), -3)) 214 | self.assertTrue(backend.isclose(qstate.norm(cache=cache), 1)) 215 | 216 | def test_expectation_use_cache_approx(self, backend): 217 | qstate = peps.computational_zeros(2, 3, backend=backend) 218 | qstate.apply_circuit([ 219 | Gate('X', [], [0]), 220 | Gate('CX', [], [0,3]), 221 | Gate('H', [], [2]), 222 | ], 
update_option=peps.DirectUpdate(ImplicitRandomizedSVD(rank=2))) 223 | observable = 1.5 * Observable.sum([ 224 | Observable.Z(0) * 2, 225 | Observable.Z(1), 226 | Observable.Z(2) * 2, 227 | Observable.Z(3), 228 | ]) 229 | contract_option = peps.BMPS(ReducedSVD(rank=2)) 230 | self.assertTrue(backend.isclose(qstate.expectation(observable, use_cache=True, contract_option=contract_option), -3)) 231 | 232 | def test_add(self, backend): 233 | psi = peps.computational_zeros(2, 3, backend=backend) 234 | phi = peps.computational_ones(2, 3, backend=backend) 235 | self.assertTrue(backend.isclose((psi + phi).norm(), np.sqrt(2))) 236 | 237 | def test_inner(self, backend): 238 | psi = peps.computational_zeros(2, 3, backend=backend) 239 | psi.apply_circuit([ 240 | Gate('H', [], [0]), 241 | Gate('CX', [], [0,3]), 242 | Gate('H', [], [3]), 243 | ]) 244 | phi = peps.computational_zeros(2, 3, backend=backend) 245 | self.assertTrue(backend.isclose(psi.inner(phi), 0.5)) 246 | 247 | def test_inner_approx(self, backend): 248 | psi = peps.computational_zeros(2, 3, backend=backend) 249 | psi.apply_circuit([ 250 | Gate('H', [], [0]), 251 | Gate('CX', [], [0,3]), 252 | Gate('H', [], [3]), 253 | ], update_option=peps.DirectUpdate(ImplicitRandomizedSVD(rank=2))) 254 | phi = peps.computational_zeros(2, 3, backend=backend) 255 | contract_option = peps.BMPS(ReducedSVD(rank=2)) 256 | self.assertTrue(backend.isclose(psi.inner(phi, contract_option), 0.5)) 257 | 258 | def test_statevector(self, backend): 259 | psi = peps.computational_zeros(2, 3, backend=backend) 260 | psi.apply_circuit([ 261 | Gate('H', [], [0]), 262 | Gate('CX', [], [0,3]), 263 | Gate('H', [], [3]), 264 | ]) 265 | psi = psi.statevector() 266 | phi = statevector.computational_zeros(6, backend=backend) 267 | self.assertTrue(backend.isclose(psi.inner(phi), 0.5)) 268 | 269 | def test_contract_scalar(self, backend): 270 | qstate = peps.random(3, 4, 2, backend=backend) 271 | norm = qstate.norm(contract_option=Snake()) 272 | for contract_option in contract_options: 273 | if contract_option is not Snake: 274 | for svd_option in (None, ReducedSVD(16), RandomizedSVD(16), ImplicitRandomizedSVD(16), ImplicitRandomizedSVD(16, orth_method='local_gram')): 275 | with self.subTest(contract_option=contract_option.__name__, svd_option=svd_option): 276 | self.assertTrue(backend.isclose(norm, qstate.norm(contract_option=contract_option(svd_option)))) 277 | 278 | def test_contract_vector(self, backend): 279 | qstate = peps.random(3, 3, 2, backend=backend) 280 | statevector = qstate.statevector(contract_option=Snake()) 281 | for contract_option in [BMPS(None), BMPS(ReducedSVD(16)), BMPS(RandomizedSVD(16)), BMPS(ImplicitRandomizedSVD(16))]: 282 | with self.subTest(contract_option=contract_option): 283 | contract_result = qstate.statevector(contract_option=contract_option) 284 | self.assertTrue(backend.allclose(statevector.tensor, contract_result.tensor)) 285 | 286 | def test_truncate(self, backend): 287 | for phys, dual in [(1,1), (2,1), (2,2)]: 288 | with self.subTest(phyiscal_dim=phys, dual_dim=dual): 289 | qstate = peps.random(2, 3, 4, phys, dual, backend=backend) 290 | self.assertEqual(qstate.get_average_bond_dim(), 4) 291 | qstate.truncate(peps.DefaultUpdate(rank=2)) 292 | self.assertEqual(qstate.get_average_bond_dim(), 2) 293 | -------------------------------------------------------------------------------- /koala/peps/peps.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module defines PEPS and operations on it. 
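A PEPS is represented as a two-dimensional NumPy object array of rank-6 site tensors
following the axis convention of koala.peps.sites: legs 0-3 are bond legs, leg 4 is
the physical leg, and leg 5 is the dual leg (dimension 1 for a pure state, dimension 2
for the identity/PEPO-like grid built by koala.peps.constructors.identity).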
3 | """ 4 | 5 | import random, json, os 6 | from pathlib import Path 7 | from math import sqrt 8 | from numbers import Number 9 | from itertools import chain 10 | 11 | import numpy as np 12 | import tensorbackends 13 | 14 | from ..quantum_state import QuantumState 15 | from ..gates import tensorize 16 | from . import contraction, update, sites 17 | 18 | 19 | class PEPS(QuantumState): 20 | def __init__(self, grid, backend): 21 | self.backend = tensorbackends.get(backend) 22 | self.grid = grid 23 | 24 | @property 25 | def nrow(self): 26 | return self.grid.shape[0] 27 | 28 | @property 29 | def ncol(self): 30 | return self.grid.shape[1] 31 | 32 | @property 33 | def shape(self): 34 | return self.grid.shape 35 | 36 | @property 37 | def nsite(self): 38 | return self.nrow * self.ncol 39 | 40 | @property 41 | def dims(self): 42 | dims = np.empty_like(self.grid, dtype=tuple) 43 | for idx, tsr in np.ndenumerate(self.grid): 44 | dims[idx] = tsr.shape 45 | return dims 46 | 47 | def get_average_bond_dim(self): 48 | s = 0 49 | for (i,j), tsr in np.ndenumerate(self.grid): 50 | if i > 0: s += tsr.shape[0] 51 | if j < self.ncol - 1: s += tsr.shape[1] 52 | if i < self.nrow - 1: s += tsr.shape[2] 53 | if j > 0: s += tsr.shape[3] 54 | return s / (2 * self.nrow * self.ncol - self.nrow - self.ncol) / 2 55 | 56 | def get_max_bond_dim(self): 57 | return max(chain.from_iterable(site.shape[0:4] for _, site in np.ndenumerate(self.grid))) 58 | 59 | def truncate(self, update_option=None): 60 | update.truncate(self, update_option) 61 | 62 | def __getitem__(self, position): 63 | item = self.grid[position] 64 | if isinstance(item, np.ndarray): 65 | if item.ndim == 1: 66 | if isinstance(position, int) or isinstance(position[0], int): 67 | item = item.reshape(1, -1) 68 | else: 69 | item = item.reshape(-1, 1) 70 | return PEPS(item, self.backend) 71 | return item 72 | 73 | def __iter__(self): 74 | if self.nrow == 1: 75 | return self.grid.reshape(-1).__iter__() 76 | return PEPS(self.grid.__iter__(), self.backend) 77 | 78 | def __next__(self): 79 | if self.nrow == 1: 80 | return self.grid.reshape(-1).__next__() 81 | return PEPS(self.grid.__next__(), self.backend) 82 | 83 | def copy(self): 84 | grid = np.empty_like(self.grid) 85 | for idx, tensor in np.ndenumerate(self.grid): 86 | grid[idx] = tensor.copy() 87 | return PEPS(grid, self.backend) 88 | 89 | def conjugate(self): 90 | grid = np.empty_like(self.grid) 91 | for idx, tensor in np.ndenumerate(self.grid): 92 | grid[idx] = tensor.conj() 93 | return PEPS(grid, self.backend) 94 | 95 | def apply_gate(self, gate, update_option=None, flip=False): 96 | tensor = tensorize(self.backend, gate.name, *gate.parameters) 97 | self.apply_operator(tensor, gate.qubits, update_option, flip) 98 | 99 | def apply_circuit(self, gates, update_option=None, flip=False): 100 | for gate in gates: 101 | self.apply_gate(gate, update_option, flip) 102 | 103 | def apply_operator(self, operator, sites, update_option=None, flip=False): 104 | positions = [divmod(site, self.ncol) for site in sites] 105 | if len(positions) == 1: 106 | update.apply_single_site_operator(self, operator, positions[0], flip) 107 | elif len(positions) == 2 and is_two_local(*positions): 108 | update.apply_local_pair_operator(self, operator, positions, update_option, flip) 109 | elif len(positions) == 2: 110 | update.apply_nonlocal_pair_operator(self, operator, positions, update_option, flip) 111 | else: 112 | raise ValueError('nonlocal operator is not supported') 113 | 114 | def site_normalize(self, *sites): 115 | """Normalize 
site-wise.""" 116 | if not sites: 117 | sites = range(self.nsite) 118 | for site in sites: 119 | pos = divmod(site, self.ncol) 120 | self.grid[pos] /= self.backend.norm(self.grid[pos]) 121 | 122 | def __add__(self, other): 123 | if isinstance(other, PEPS) and self.backend == other.backend: 124 | return self.add(other) 125 | else: 126 | return NotImplemented 127 | 128 | def __sub__(self, other): 129 | if isinstance(other, PEPS) and self.backend == other.backend: 130 | return self.add(other, coeff=-1.0) 131 | else: 132 | return NotImplemented 133 | 134 | def __imul__(self, a): 135 | if isinstance(a, Number): 136 | multiplier = a ** (1/(self.nrow * self.ncol)) 137 | for idx in np.ndindex(*self.shape): 138 | self.grid[idx] *= multiplier 139 | return self 140 | else: 141 | return NotImplemented 142 | 143 | def __itruediv__(self, a): 144 | if isinstance(a, Number): 145 | divider = a ** (1/(self.nrow * self.ncol)) 146 | for idx in np.ndindex(*self.shape): 147 | self.grid[idx] /= divider 148 | return self 149 | else: 150 | return NotImplemented 151 | 152 | def norm(self, contract_option=None, cache=None): 153 | return sqrt(self.inner(self, contract_option=contract_option, cache=cache).real) 154 | 155 | def trace_sitewise(self): 156 | grid = np.empty_like(self.grid) 157 | for idx, tensor in np.ndenumerate(self.grid): 158 | grid[idx] = sites.trace_z(tensor) 159 | return PEPS(grid, self.backend) 160 | 161 | def trace(self, observable=None, contract_option=None, cache=None): 162 | if cache is None: 163 | if observable is None: 164 | return self.trace_sitewise().contract(option=contract_option) 165 | else: 166 | result = 0.0 167 | for op, pos in observable: 168 | qstate = self.copy() 169 | qstate.apply_operator(op, pos) 170 | result += qstate.trace(contract_option=contract_option) 171 | return result 172 | else: 173 | if not isinstance(contract_option, contraction.BMPS): 174 | raise ValueError(f'cache only works with BMPS contraction: {contract_option}') 175 | return self._trace_with_cache(observable, contract_option, cache) 176 | 177 | def _trace_with_cache(self, observable, bmps_option, cache): 178 | if observable is None: 179 | return contraction.contract_with_env(self[0:1].trace_sitewise(), cache, 0, 0, bmps_option) 180 | else: 181 | e = 0 182 | for tensor, sites in observable: 183 | other = self.copy() 184 | other.apply_operator(self.backend.astensor(tensor), sites) 185 | rows = [site // self.ncol for site in sites] 186 | up, down = min(rows), max(rows) 187 | e += contraction.contract_with_env( 188 | other[up:down+1].trace_sitewise(), 189 | cache, up, down, bmps_option 190 | ) 191 | return e 192 | 193 | def make_trace_cache(self, contract_option=None): 194 | return contraction.create_env_cache(self.trace_sitewise(), contract_option) 195 | 196 | def add(self, other, *, coeff=1.0): 197 | """ 198 | Add two PEPS of the same grid shape and return the sum as a third PEPS also with the same grid shape. 
199 | """ 200 | if self.shape != other.shape: 201 | raise ValueError(f'PEPS shapes do not match: {self.shape} != {other.shape}') 202 | grid = np.empty(self.shape, dtype=object) 203 | for i, j in np.ndindex(*self.grid.shape): 204 | internal_bonds = [] 205 | external_bonds = [4, 5] 206 | (external_bonds if i == 0 else internal_bonds).append(0) 207 | (external_bonds if j == self.shape[1] - 1 else internal_bonds).append(1) 208 | (external_bonds if i == self.shape[0] - 1 else internal_bonds).append(2) 209 | (external_bonds if j == 0 else internal_bonds).append(3) 210 | grid[i, j] = tn_add(self.backend, self[i, j], other[i, j], internal_bonds, external_bonds, 1, coeff) 211 | return PEPS(grid, self.backend) 212 | 213 | def amplitude(self, indices, contract_option=None): 214 | if len(indices) != self.nsite: 215 | raise ValueError('indices number and sites number do not match') 216 | indices = np.array(indices).reshape(*self.shape) 217 | grid = np.empty_like(self.grid, dtype=object) 218 | zero = self.backend.astensor(np.array([1,0], dtype=complex).reshape(2, 1)) 219 | one = self.backend.astensor(np.array([0,1], dtype=complex).reshape(2, 1)) 220 | for idx, tensor in np.ndenumerate(self.grid): 221 | grid[idx] = self.backend.einsum('ijklxp,xq->ijklpq', tensor, one if indices[idx] else zero) 222 | return PEPS(grid, self.backend).contract(contract_option) 223 | 224 | def probability(self, indices, contract_option=None): 225 | return np.abs(self.amplitude(indices, contract_option))**2 226 | 227 | def expectation(self, observable, use_cache=False, contract_option=None): 228 | return braket(self, observable, self, use_cache=use_cache, contract_option=contract_option).real 229 | 230 | def contract(self, option=None): 231 | return contraction.contract(self, option) 232 | 233 | def inner(self, other, contract_option=None, cache=None): 234 | if cache is None: 235 | return contraction.contract_sandwich(self.dagger(), other, contract_option) 236 | else: 237 | if contract_option is None: 238 | contract_option = contraction.BMPS(svd_option=None) 239 | if not isinstance(contract_option, contraction.BMPS): 240 | raise ValueError('inner with cache must use BMPS contraction') 241 | return contraction.contract_with_env(None, cache, 1, 0, contract_option) 242 | 243 | def statevector(self, contract_option=None): 244 | from .. import statevector 245 | return statevector.StateVector(self.contract(contract_option), self.backend) 246 | 247 | def apply(self, other): 248 | """ 249 | Apply a PEPS/PEPO to another PEPS/PEPO. Only the first pair of physical indices is contracted; the other physical indices are left in the order of A, B. 250 | 251 | Parameters 252 | ---------- 253 | other: PEPS 254 | The second PEPS/PEPO. 255 | 256 | Returns 257 | ------- 258 | output: PEPS 259 | The PEPS generated by the application. 260 | """ 261 | grid = np.empty_like(self.grid) 262 | for (idx, a), b in zip(np.ndenumerate(self.grid), other.grid.flat): 263 | grid[idx] = sites.contract_z(b, a) 264 | return PEPS(grid, self.backend) 265 | 266 | def concatenate(self, other, axis=0): 267 | """ 268 | Concatenate two PEPS along the given axis. 269 | 270 | Parameters 271 | ---------- 272 | other: PEPS 273 | The second PEPS 274 | 275 | axis: int, optional 276 | The axis along which the PEPS will be concatenated. 277 | 278 | Returns 279 | ------- 280 | output: PEPS 281 | The concatenated PEPS. 
282 | """ 283 | return PEPS(np.concatenate((self.grid, other.grid), axis), self.backend) 284 | 285 | def dagger(self): 286 | """ 287 | Compute the Hermitian conjugate of the PEPS. Equivalent to take `conjugate` then `flip`. 288 | 289 | Returns 290 | ------- 291 | output: PEPS 292 | """ 293 | return self.conjugate().flip() 294 | 295 | def flip(self, *indices): 296 | """ 297 | Flip the direction of physical indices for specified sites. 298 | Parameters 299 | ---------- 300 | indices: iterable, optional 301 | Indices of sites (tensors) to flip. Specify as `(i, j)` or `((i1, j1), (i2, j2), ...)`, where `i` and `j` should be int. 302 | Will flip all sites if left as `None`. 303 | Returns 304 | ------- 305 | output: PEPS 306 | """ 307 | if indices and isinstance(indices[0], int): 308 | indices = (indices, ) 309 | tn = np.empty_like(self.grid) 310 | for idx, tsr in np.ndenumerate(self.grid): 311 | if not indices or idx in indices: 312 | tn[idx] = sites.flip_z(tsr) 313 | else: 314 | tn[idx] = tsr.copy() 315 | return PEPS(tn, self.backend) 316 | 317 | def rotate(self, num_rotate90=1): 318 | """ 319 | Rotate the PEPS counter-clockwise by 90 degrees * the specified times. Will cause the tensors to transpose accordingly. 320 | 321 | Parameters 322 | ---------- 323 | num_rotate90: int, optional 324 | Number of 90 degree rotations. 325 | 326 | Returns 327 | ------- 328 | output: PEPS 329 | """ 330 | num_rotate90 = num_rotate90 % 4 331 | if num_rotate90 == 0: 332 | return self 333 | else: 334 | tn = np.rot90(self.grid, k=num_rotate90).copy() 335 | for idx, tsr in np.ndenumerate(tn): 336 | tn[idx] = sites.rotate_z(tsr, -num_rotate90).copy() 337 | return PEPS(tn, self.backend) 338 | 339 | 340 | def make_expectation_cache(p, q, contract_option=None): 341 | if p.backend != q.backend: 342 | raise ValueError('two states must use the same backend') 343 | if p.nsite != q.nsite: 344 | raise ValueError('number of sites must be equal in both states') 345 | if contract_option is None: 346 | contract_option = contraction.BMPS(svd_option=None) 347 | if not isinstance(contract_option, contraction.BMPS): 348 | raise ValueError('expectation cache must use BMPS contraction') 349 | return contraction.create_env_cache(p.dagger().apply(q), contract_option) 350 | 351 | 352 | def braket(p, observable, q, use_cache=False, contract_option=None): 353 | if p.backend != q.backend: 354 | raise ValueError('two states must use the same backend') 355 | if p.nsite != q.nsite: 356 | raise ValueError('number of sites must be equal in both states') 357 | if use_cache: 358 | if contract_option is None: 359 | contract_option = contraction.BMPS(svd_option=None) 360 | if not isinstance(contract_option, contraction.BMPS): 361 | raise ValueError('braket with cache must use BMPS contraction') 362 | env = use_cache if isinstance(use_cache, tuple) else None 363 | return _braket_with_cache(p, observable, q, contract_option, env) 364 | e = 0 365 | p_dagger = p.dagger() 366 | for tensor, sites in observable: 367 | other = q.copy() 368 | other.apply_operator(q.backend.astensor(tensor), sites) 369 | e += contraction.contract_sandwich(p_dagger, other, contract_option) 370 | return e 371 | 372 | 373 | def _braket_with_cache(p, observable, q, bmps_option, cache=None): 374 | p_dagger = p.dagger() 375 | if cache is None: 376 | env = contraction.create_env_cache(p_dagger.apply(q), bmps_option) 377 | else: 378 | env = cache 379 | e = 0 380 | for tensor, sites in observable: 381 | other = q.copy() 382 | other.apply_operator(q.backend.astensor(tensor), sites) 
383 | rows = [site // q.ncol for site in sites] 384 | up, down = min(rows), max(rows) 385 | e += contraction.contract_with_env( 386 | p_dagger[up:down+1].apply(other[up:down+1]), 387 | env, up, down, bmps_option 388 | ) 389 | return e 390 | 391 | 392 | def tn_add(backend, a, b, internal_bonds, external_bonds, coeff_a, coeff_b): 393 | """ 394 | Helper function for addition of two tensor network states with the same structure. 395 | Add two site from two tensor network states respecting specified inner and external bond structure. 396 | """ 397 | ndim = a.ndim 398 | shape_a = np.array(np.shape(a)) 399 | shape_b = np.array(np.shape(b)) 400 | shape_c = np.copy(shape_a) 401 | shape_c[internal_bonds] += shape_b[internal_bonds] 402 | lim = np.copy(shape_a).astype(object) 403 | lim[external_bonds] = None 404 | a_ind = tuple([slice(lim[i]) for i in range(ndim)]) 405 | b_ind = tuple([slice(lim[i], None) for i in range(ndim)]) 406 | c = backend.zeros(shape_c, dtype=a.dtype) 407 | c[a_ind] += a * coeff_a 408 | c[b_ind] += b * coeff_b 409 | return c 410 | 411 | 412 | def is_two_local(p, q): 413 | dx, dy = abs(q[0] - p[0]), abs(q[1] - p[1]) 414 | return dx == 1 and dy == 0 or dx == 0 and dy == 1 415 | 416 | 417 | def save(qstate, dirname): 418 | Path(dirname).mkdir(exist_ok=True) 419 | with open(os.path.join(dirname, 'koala_peps.json'), 'w+') as file: 420 | json.dump({ 421 | 'backend': qstate.backend.name, 422 | 'nrow': qstate.nrow, 423 | 'ncol': qstate.ncol, 424 | }, file) 425 | for i, j in np.ndindex(*qstate.shape): 426 | qstate.backend.save(qstate[i, j], os.path.join(dirname, f'{i}_{j}')) 427 | 428 | 429 | def load(dirname): 430 | with open(os.path.join(dirname, 'koala_peps.json')) as file: 431 | meta = json.load(file) 432 | backend = tensorbackends.get(meta['backend']) 433 | grid = np.empty((meta['nrow'], meta['ncol']), dtype=object) 434 | for i, j in np.ndindex(*grid.shape): 435 | grid[i, j] = backend.load(os.path.join(dirname, f'{i}_{j}')) 436 | return PEPS(grid, backend) 437 | -------------------------------------------------------------------------------- /koala/peps/contraction.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module implements contraction algorithms. 3 | """ 4 | 5 | from collections import namedtuple 6 | import numpy as np 7 | from tensorbackends.interface import ReducedSVD, ImplicitRandomizedSVD 8 | 9 | from . 
import sites 10 | 11 | 12 | class ContractOption: 13 | def __str__(self): 14 | return '{}({})'.format( 15 | type(self).__name__, 16 | ','.join('{}={}'.format(k, v) for k, v in vars(self).items()) 17 | ) 18 | 19 | def __repr__(self): 20 | return str(self) 21 | 22 | @property 23 | def name(self): 24 | return type(self).__name__ 25 | 26 | 27 | class ABMPS(ContractOption): 28 | def __init__(self, svd_option=None, compress_alg='contract first'): 29 | self.svd_option = svd_option 30 | self.compress_alg = compress_alg 31 | 32 | class BMPS(ContractOption): 33 | def __init__(self, svd_option=None, compress_alg='contract first', canonicalize=False): 34 | self.svd_option = svd_option 35 | self.compress_alg = compress_alg 36 | self.canonicalize = canonicalize 37 | 38 | class SingleLayer(ContractOption): 39 | def __init__(self, svd_option=None, compress_alg='contract first'): 40 | self.svd_option = svd_option 41 | self.compress_alg = compress_alg 42 | 43 | class Snake(ContractOption): 44 | pass 45 | 46 | class Square(ContractOption): 47 | def __init__(self, svd_option=None): 48 | self.svd_option = svd_option 49 | 50 | class TRG(ContractOption): 51 | def __init__(self, svd_option_1st=None, svd_option_rem=None): 52 | self.svd_option_1st = svd_option_1st 53 | self.svd_option_rem = svd_option_rem 54 | 55 | contract_options = (ABMPS, BMPS, Snake, Square, TRG) 56 | 57 | 58 | def contract(state, option): 59 | """ 60 | Contract the PEPS to a single tensor or a scalar (a "0-tensor"). 61 | 62 | Parameters 63 | ---------- 64 | state: PEPS 65 | The PEPS to contract. 66 | 67 | option: koala.peps.ContractOption, optional 68 | Parameters for performing the contraction. 69 | 70 | Returns 71 | ------- 72 | output: state.backend.tensor or scalar 73 | The contraction result. 74 | """ 75 | if option is None: 76 | option = BMPS(None) 77 | if isinstance(option, Snake): 78 | return contract_snake(state) 79 | elif isinstance(option, ABMPS): 80 | return contract_ABMPS(state, svd_option=option.svd_option, compress_alg=option.compress_alg) 81 | elif isinstance(option, BMPS): 82 | return contract_BMPS(state, svd_option=option.svd_option, compress_alg=option.compress_alg, canonicalize=option.canonicalize) 83 | elif isinstance(option, SingleLayer): 84 | return contract_single_layer(*state, svd_option=option.svd_option, compress_alg=option.compress_alg) 85 | elif isinstance(option, Square): 86 | return contract_squares(state, svd_option=option.svd_option) 87 | elif isinstance(option, TRG): 88 | return contract_TRG(state, svd_option_1st=option.svd_option_1st, svd_option_rem=option.svd_option_rem) 89 | else: 90 | raise ValueError(f'unknown contraction option: {option}') 91 | 92 | 93 | def contract_sandwich(state1, state2, option): 94 | if state1.backend != state2.backend: 95 | raise ValueError('cannot contract two states with different backends') 96 | if state1.shape != state2.shape: 97 | raise ValueError('cannot contract two states with different shapes') 98 | if option is None: 99 | option = SingleLayer(None) 100 | if isinstance(option, SingleLayer): 101 | return contract_single_layer(state1, state2, svd_option=option.svd_option, compress_alg=option.compress_alg) 102 | else: 103 | return contract(state1.apply(state2), option=option) 104 | 105 | 106 | def contract_ABMPS(state, compress_alg='contract first', svd_option=None): 107 | """ 108 | Contract the PEPS by performing alternating vertical and horizontal boundary contractions.
109 | 110 | Parameters 111 | ---------- 112 | compress_alg: str or callable, optional 113 | The compression algorithm used when absorbing a row into the boundary MPS. 114 | 115 | svd_option: tensorbackends.interface.Option, optional 116 | Parameters for SVD truncations. Will perform SVD if given. 117 | 118 | Returns 119 | ------- 120 | output: state.backend.tensor or scalar 121 | The contraction result. 122 | """ 123 | horizontal = False 124 | while state.ncol > 2 and state.nrow > 2: 125 | edge = state[:,:2] if horizontal else state[:2] 126 | body = state[:,2:] if horizontal else state[2:] 127 | state = contract_to_MPS(edge, horizontal=horizontal, svd_option=svd_option).concatenate(body, int(horizontal)) 128 | horizontal = not horizontal 129 | return contract_BMPS(state) 130 | 131 | 132 | def contract_BMPS(state, svd_option=None, compress_alg='contract first', canonicalize=False): 133 | """ 134 | Contract the PEPS by contracting each MPS layer. 135 | 136 | Parameters 137 | ---------- 138 | compress_alg: str or callable, optional 139 | The compression algorithm used when absorbing each row into the boundary MPS. 140 | 141 | svd_option: tensorbackends.interface.Option, optional 142 | Parameters for SVD truncations. Will perform SVD if given. 143 | 144 | Returns 145 | ------- 146 | output: state.backend.tensor or scalar 147 | The contraction result. 148 | """ 149 | # contract boundary MPS down then contract the last MPS to a single tensor 150 | return _vector_reshaper_BMPS(contract_MPS(contract_to_MPS( 151 | state, horizontal=False, reverse=False, svd_option=svd_option, compress_alg=compress_alg, canonicalize=canonicalize 152 | )), state.shape) 153 | 154 | 155 | def contract_env(state, row_range, col_range, svd_option=None): 156 | """ 157 | Contract the surrounding environment to four MPS around the core sites. 158 | 159 | Parameters 160 | ---------- 161 | row_range: tuple or int 162 | A two-int tuple specifying the row range of the core sites, i.e. [row_range[0] : row_range[1]]. 163 | If only an int is given, it is equivalent to (row_range, row_range+1). 164 | 165 | col_range: tuple or int 166 | A two-int tuple specifying the column range of the core sites, i.e. [:, col_range[0] : col_range[1]]. 167 | If only an int is given, it is equivalent to (col_range, col_range+1). 168 | 169 | svd_option: tensorbackends.interface.Option, optional 170 | Parameters for SVD truncations. Will perform SVD if given. 171 | 172 | Returns 173 | ------- 174 | output: PEPS 175 | The new PEPS consisting of core sites and contracted environment.
176 | """ 177 | if isinstance(row_range, int): 178 | row_range = (row_range, row_range+1) 179 | if isinstance(col_range, int): 180 | col_range = (col_range, col_range+1) 181 | mid_peps = state[row_range[0]:row_range[1]].copy() 182 | if row_range[0] > 0: 183 | mid_peps = state[:row_range[0]].contract_to_MPS(svd_option=svd_option).concatenate(mid_peps) 184 | if row_range[1] < state.nrow: 185 | mid_peps = mid_peps.concatenate(state[row_range[1]:].contract_to_MPS(svd_option=svd_option)) 186 | env_peps = mid_peps[:,col_range[0]:col_range[1]] 187 | if col_range[0] > 0: 188 | env_peps = mid_peps[:,:col_range[0]].contract_to_MPS(horizontal=True, svd_option=svd_option).concatenate(env_peps, axis=1) 189 | if col_range[1] < mid_peps.shape[1]: 190 | env_peps = env_peps.concatenate(mid_peps[:,col_range[1]:].contract_to_MPS(horizontal=True, svd_option=svd_option), axis=1) 191 | return env_peps 192 | 193 | 194 | def contract_MPS(mps): 195 | result = mps[0,0] 196 | for tsr in mps[0,1:]: 197 | result = sites.contract_y(result, tsr) 198 | return result 199 | 200 | 201 | def contract_single_layer(state1, state2, svd_option=None, compress_alg='contract first'): 202 | """ 203 | Contract the PEPS by contracting each MPS layer. 204 | 205 | Parameters 206 | ---------- 207 | svd_option: tensorbackends.interface.Option, optional 208 | Parameters for SVD truncations. Will perform SVD if given. 209 | 210 | Returns 211 | ------- 212 | output: state.backend.tensor or scalar 213 | The contraction result. 214 | """ 215 | # contract boundary MPS down 216 | from .peps import PEPS 217 | mps = np.empty_like(state1.grid[0]) 218 | for i, (tsr1, tsr2) in enumerate(zip(state1.grid[0], state2.grid[0])): 219 | mps[i] = state1.backend.einsum('abcdpi,ABCDiq->(aA)(bB)cC(dD)pq', tsr1, tsr2) 220 | for i, (mpo1, mpo2) in enumerate(zip(state1.grid[1:], state2.grid[1:])): 221 | for j, (s, o1, o2) in enumerate(zip(mps, mpo1, mpo2)): 222 | if compress_alg is None or compress_alg == 'contract first': 223 | if svd_option: 224 | if j == 0: 225 | mps[0] = s.backend.einsum('xyijzpP,ibcdqk,jBCDkQ->xybBcC(zdD)(pq)(PQ)', s, o1, o2) 226 | else: 227 | mps[j-1], _, mps[j] = s.backend.einsumsvd( 228 | 'aijkxXdpP,AylmiqQ,lbcjrn,mBCknR->azxXdpP,AybBcCz(qr)(QR)', 229 | mps[j-1], s, o1, o2, option=svd_option, absorb_s='even' 230 | ) 231 | if j == len(mps)-1: 232 | mps[-1] = s.backend.einsum('AybBcCzqQ->A(ybB)cCzqQ', mps[-1]) 233 | else: 234 | mps[j] = s.backend.einsum('xyijzpP,ibcdqk,jBCDkQ->x(ybB)cC(zdD)(pq)(PQ)', s, o1, o2) 235 | elif compress_alg == 'svd first': 236 | if svd_option: 237 | if j == 0: 238 | mps[0], _, s_left = s.backend.einsumsvd( 239 | 'xyijzpP,ibcdqk,jBCDkQ->xncC(zdD)(pq)(PQ),ybBn', 240 | s, o1, o2, option=svd_option, absorb_s='even' 241 | ) 242 | elif j == len(mps)-1: 243 | mps[-1] = s.backend.einsum('ijkd,AylmiqQ,lbcjrn,mBCknR->A(ybB)cCd(qr)(QR)', s_left, s, o1, o2) 244 | else: 245 | mps[j], _, s_left = s.backend.einsumsvd( 246 | 'ijkd,AylmiqQ,lbcjrn,mBCknR->AzcCd(qr)(QR),ybBz', 247 | s_left, s, o1, o2, option=svd_option, absorb_s='even' 248 | ) 249 | else: 250 | mps[j] = s.backend.einsum('xyijzpP,ibcdqk,jBCDkQ->x(ybB)cC(zdD)(pq)(PQ)', s, o1, o2) 251 | elif not callable(compress_alg): 252 | raise ValueError('Invalid compress algorithm') 253 | 254 | # contract the last MPS to a single tensor 255 | for i, tsr in enumerate(mps): 256 | mps[i] = state1.backend.einsum('abcCdpP->ab(cC)dpP', tsr) 257 | return _vector_reshaper_BMPS(contract_MPS(PEPS(mps.reshape(1, -1), state1.backend)), state1.shape) 258 | 259 | def contract_snake(state): 260 
| """ 261 | Contract the PEPS by contracting sites in the row-major order. 262 | 263 | Returns 264 | ------- 265 | output: state.backend.tensor or scalar 266 | The contraction result. 267 | 268 | References 269 | ---------- 270 | https://arxiv.org/pdf/1905.08394.pdf 271 | """ 272 | head = state.grid[0,0] 273 | for i, mps in enumerate(state.grid): 274 | for tsr in mps[int(i==0):]: 275 | head = state.backend.einsum('gabcdef->a(gb)cdef', 276 | head.reshape(*((tsr.shape[0], head.shape[0] // tsr.shape[0]) + head.shape[1:]))) 277 | tsr = state.backend.einsum('agbcdef->abc(gd)ef', tsr.reshape(*((1,) + tsr.shape))) 278 | head = sites.contract_y(head, tsr) 279 | head = head.transpose(2, 1, 0, 3, 4, 5) 280 | return head.item() if head.size == 1 else head.reshape(*[int(round(head.size ** (1 / state.nsite)))] * state.nsite) 281 | 282 | 283 | def contract_squares(state, svd_option=None): 284 | """ 285 | Contract the PEPS by contracting two neighboring tensors to one recursively. 286 | The neighboring relationship alternates from horizontal and vertical. 287 | 288 | Parameters 289 | ---------- 290 | svd_option: tensorbackends.interface.Option, optional 291 | Parameters for SVD truncations. Will perform SVD if given. 292 | 293 | Returns 294 | ------- 295 | output: state.backend.tensor or scalar 296 | The contraction result. 297 | """ 298 | from .peps import PEPS 299 | tn = state.grid 300 | new_tn = np.empty((int((state.nrow + 1) / 2), state.ncol), dtype=object) 301 | for ((i, j), a), b in zip(np.ndenumerate(tn[:-1:2,:]), tn[1::2,:].flat): 302 | new_tn[i,j] = sites.contract_x(a, b) 303 | if svd_option is not None and j > 0 and new_tn.shape != (1, 2): 304 | new_tn[i,j-1], new_tn[i,j] = sites.reduce_y(new_tn[i,j-1], new_tn[i,j], svd_option) 305 | # append the left edge if nrow/ncol is odd 306 | if state.nrow % 2 == 1: 307 | for i, a in enumerate(tn[-1]): 308 | new_tn[-1,i] = a.copy() 309 | # base case 310 | if new_tn.shape == (1, 1): 311 | return new_tn[0,0].item() if new_tn[0,0].size == 1 else new_tn[0,0] 312 | # alternate the neighboring relationship and contract recursively 313 | return contract_squares(PEPS(new_tn, state.backend).rotate(), svd_option) 314 | 315 | 316 | def contract_squares_variant(state, svd_option=None): 317 | """ 318 | Contract the PEPS by contracting two neighboring tensors to one recursively. 319 | The neighboring relationship alternates from horizontal and vertical. 320 | 321 | Parameters 322 | ---------- 323 | svd_option: tensorbackends.interface.Option, optional 324 | Parameters for SVD truncations. Will perform SVD if given. 325 | 326 | Returns 327 | ------- 328 | output: state.backend.tensor or scalar 329 | The contraction result. 
330 | """ 331 | from .peps import PEPS 332 | from .update import QRUpdate 333 | 334 | state.truncate(QRUpdate(svd_option)) 335 | tn = state.grid 336 | new_tn = np.empty(((state.nrow + 1) // 2, (state.ncol + 1) // 2), dtype=object) 337 | nrow, ncol = new_tn.shape 338 | if state.nrow % 2 == 1: 339 | nrow -= 1 340 | if state.ncol % 2 == 1: 341 | ncol -= 1 342 | new_tn[-1,-1] = tn[-1,-1].copy() 343 | for j in range(ncol): 344 | new_tn[-1,j] = sites.contract_y(tn[-1,2*j], tn[-1,2*j+1]) 345 | if state.ncol % 2 == 1: 346 | for i in range(nrow): 347 | new_tn[i,-1] = sites.contract_x(tn[2*i,-1], tn[2*i+1,-1]) 348 | 349 | for i, j in np.ndindex(nrow, ncol): 350 | new_tn[i,j] = state.backend.einsum('aijdpP,AbkiqQ,jlcDrR,kBClsS->(aA)(bB)(cC)(dD)(pqrs)(PQRS)', 351 | tn[2*i,2*j], tn[2*i,2*j+1], tn[2*i+1,2*j], tn[2*i+1,2*j+1]) 352 | 353 | # base case 354 | if new_tn.shape == (1, 1): 355 | return new_tn[0,0].item() if new_tn[0,0].size == 1 else new_tn[0,0] 356 | # alternate the neighboring relationship and contract recursively 357 | return contract_squares_variant(PEPS(new_tn, state.backend), svd_option) 358 | 359 | 360 | def contract_to_MPS(state, horizontal=False, reverse=False, svd_option=None, compress_alg='contract first', canonicalize=False): 361 | """ 362 | Contract the PEPS to an MPS. 363 | 364 | Parameters 365 | ---------- 366 | horizontal: bool, optional 367 | Control whether to contract from top to bottom or from left to right. Will affect the output MPS direction. 368 | 369 | mps_mult_mpo: method or None, optional 370 | The method used to apply an MPS to another MPS/MPO. 371 | 372 | 373 | Returns 374 | ------- 375 | output: PEPS 376 | The resulting MPS (as a `PEPS` object of shape `(1, N)` or `(M, 1)`). 377 | """ 378 | from .peps import PEPS 379 | if compress_alg is None or compress_alg == 'contract first': 380 | compress_alg = _compress_contract_first 381 | elif compress_alg == 'svd first': 382 | compress_alg = _compress_svd_first 383 | elif not callable(compress_alg): 384 | raise ValueError('Invalid compress algorithm') 385 | num_rotate90 = bool(horizontal) * 1 + bool(reverse) * 2 386 | state = state.rotate(-num_rotate90) 387 | mps = state.grid[0] 388 | for i, mpo in enumerate(state.grid[1:]): 389 | mps = compress_alg(mps, mpo, svd_option=svd_option, canonicalize=('left' if i % 2 else 'right') if canonicalize else False) 390 | return PEPS(mps.reshape(1, -1), state.backend).rotate(num_rotate90) 391 | 392 | 393 | 394 | def contract_TRG(state, svd_option_1st=None, svd_option_rem=None): 395 | """ 396 | Contract the PEPS using Tensor Renormalization Group. 397 | 398 | Parameters 399 | ---------- 400 | svd_option_1st: tensorbackends.interface.Option, optional 401 | Parameters for the first SVD in TRG. Will default to tensorbackends.interface.ReducedSVD() if not given. 402 | 403 | 404 | svd_option_rem: tensorbackends.interface.Option, optional 405 | Parameters for the remaining SVD truncations. Will perform SVD if given. 406 | 407 | Returns 408 | ------- 409 | output: state.backend.tensor or scalar 410 | The contraction result. 
411 | 412 | References 413 | ---------- 414 | https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.99.120601 415 | https://journals.aps.org/prb/abstract/10.1103/PhysRevB.78.205116 416 | """ 417 | # base case 418 | if state.shape <= (2, 2): 419 | return contract_BMPS(state, svd_option_rem) 420 | # SVD each tensor into two 421 | tn = np.empty(state.shape + (2,), dtype=object) 422 | for (i, j), tsr in np.ndenumerate(state.grid): 423 | str_uv = 'abi,icdpq' if (i+j) % 2 == 0 else 'aidpq,bci' 424 | tn[i,j,0], _, tn[i,j,1] = state.backend.einsumsvd( 425 | 'abcdpq->' + str_uv, tsr, 426 | option=svd_option_1st or ReducedSVD(), 427 | absorb_s='even' 428 | ) 429 | tn[i,j,(i+j)%2] = tn[i,j,(i+j)%2].reshape(*(tn[i,j,(i+j)%2].shape + (1, 1))) 430 | return _contract_TRG(state, tn, svd_option_rem) 431 | 432 | 433 | def _contract_TRG(state, tn, svd_option=None): 434 | from .peps import PEPS 435 | # base case 436 | if tn.shape == (2, 2, 2): 437 | p = np.empty((2, 2), dtype=object) 438 | for i, j in np.ndindex((2, 2)): 439 | p[i,j] = state.backend.einsum( 440 | 'abipq,icdPQ->abcd(pP)(qQ)' if (i+j) % 2 == 0 else 'aidpq,bciPQ->abcd(pP)(qQ)', tn[i,j][0], tn[i,j][1]) 441 | return contract_BMPS(PEPS(p, state.backend)) 442 | 443 | # contract specific horizontal and vertical bonds and SVD truncate the generated squared bonds 444 | for i, j in np.ndindex(tn.shape[:2]): 445 | if j > 0 and j % 2 == 0: 446 | k = 1 - i % 2 447 | l = j - ((i // 2 * 2 + j) % 4 == 0) 448 | tn[i,l][k] = state.backend.einsum( 449 | 'ibapq,ABiPQ->A(bB)a(pP)(qQ)' if k else 'biapq,BAiPQ->(bB)Aa(pP)(qQ)', tn[i,j-1][k], tn[i,j][k]) 450 | if i % 2 == 1 and svd_option is not None: 451 | tn[i-1,l][1], _, tn[i,l][0] = state.backend.einsumsvd( 452 | 'aidpq,iBCPQ->aIdpq,IBCPQ', 453 | tn[i-1,l][1], tn[i,l][0], option=svd_option, absorb_s='even' 454 | ) 455 | if i > 0 and i % 2 == 0: 456 | k = 1 - j % 2 457 | l = int((i + j // 2 * 2) % 4 == 0) 458 | tn[i-l,j][l] = state.backend.einsum( 459 | 'biapq,iBAPQ->(bB)Aa(pP)(qQ)' if k else 'aibpq,iABPQ->aA(bB)(pP)(qQ)', tn[i-1,j][1], tn[i,j][0]) 460 | if j % 2 == 1 and svd_option is not None: 461 | tn[i-l,j-1][l], _, tn[i-l,j][l] = state.backend.einsumsvd( 462 | 'icdpq,ABiPQ->Icdpq,ABIPQ', 463 | tn[i-l,j-1][l], tn[i-l,j][l], 464 | option=svd_option, absorb_s='even' 465 | ) 466 | 467 | # contract specific diagonal bonds and generate a smaller tensor network 468 | new_tn = np.empty((tn.shape[0] // 2 + 1, tn.shape[1] // 2 + 1, 2), dtype=object) 469 | for i, j in np.ndindex(tn.shape[:2]): 470 | m, n = (i + 1) // 2, (j + 1) // 2 471 | if (i + j) % 4 == 2 and i % 2 == 0: 472 | if tn[i,j][0] is None: 473 | new_tn[m,n][1] = tn[i,j][1] 474 | elif tn[i,j][1] is None: 475 | new_tn[m,n][1] = tn[i,j][0] 476 | else: 477 | new_tn[m,n][1] = state.backend.einsum( 478 | 'abipq,iCAPQ->bC(aA)(pP)(qQ)' if i == 0 else 'aibpq,iCBPQ->aC(bB)(pP)(qQ)', tn[i,j][0], tn[i,j][1]) 479 | elif (i + j) % 4 == 0 and i % 2 == 1: 480 | new_tn[m,n][0] = state.backend.einsum('abipq,iBCPQ->a(bB)C(pP)(qQ)', tn[i,j][0], tn[i,j][1]) 481 | elif (i + j) % 4 == 3 and i % 2 == 0: 482 | new_tn[m,n][1] = state.backend.einsum('aibpq,ACiPQ->(aA)Cb(pP)(qQ)', tn[i,j][0], tn[i,j][1]) 483 | elif (i + j) % 4 == 3 and i % 2 == 1: 484 | new_tn[m,n][0] = state.backend.einsum('aibpq,CBiPQ->aC(bB)(pP)(qQ)', tn[i,j][0], tn[i,j][1]) 485 | else: 486 | if new_tn[m,n][0] is None: 487 | new_tn[m,n][0] = tn[i,j][0] 488 | if new_tn[m,n][1] is None: 489 | new_tn[m,n][1] = tn[i,j][1] 490 | 491 | # SVD truncate the squared bonds generated by the diagonal contractions 492 | if 
svd_option is not None: 493 | for i, j in np.ndindex(new_tn.shape[:2]): 494 | if (i + j) % 2 == 0 and new_tn[i,j][0] is not None and new_tn[i,j][1] is not None: 495 | new_tn[i,j][0], _, new_tn[i,j][1] = state.backend.einsumsvd( 496 | 'abipq,iCDPQ->abIpq,ICDPQ', new_tn[i,j][0], new_tn[i,j][1], 497 | option=svd_option, absorb_s='even' 498 | ) 499 | elif (i + j) % 2 == 1: 500 | new_tn[i,j][0], _, new_tn[i,j][1] = state.backend.einsumsvd( 501 | 'aidpq,BCiPQ->aIdpq,BCIPQ', new_tn[i,j][0], new_tn[i,j][1], 502 | option=svd_option, absorb_s='even' 503 | ) 504 | 505 | return _contract_TRG(state, new_tn, svd_option) 506 | 507 | 508 | def _compress_contract_first(mps, mpo, svd_option=None, canonicalize=False): 509 | new_mps = np.empty_like(mps) 510 | if canonicalize == 'right': 511 | mps = mps[::-1] 512 | mpo = mpo[::-1] 513 | 514 | for i, (s, o) in enumerate(zip(mps, mpo)): 515 | if canonicalize == 'right': 516 | s = s.backend.einsum('abcdpq->adcbpq', s) 517 | o = o.backend.einsum('abcdpq->adcbpq', o) 518 | if svd_option: 519 | if i == 0: 520 | new_mps[0] = s.backend.einsum('abidpq,iBcDPQ->abBc(dD)(pP)(qQ)', s, o) 521 | else: 522 | new_mps[i-1], _, new_mps[i] = s.backend.einsumsvd( 523 | 'aijcdpP,AbkiqQ,kBCjrR->aIcdpP,AbBCI(qr)(QR)', 524 | new_mps[i-1], s, o, option=svd_option, 525 | absorb_s='v' if canonicalize else 'even' 526 | ) 527 | if i == len(mps)-1: 528 | new_mps[-1] = s.backend.einsum('abBcdpq->a(bB)cdpq', new_mps[-1]) 529 | else: 530 | new_mps[i] = sites.contract_x(s, o) 531 | 532 | if canonicalize == 'right': 533 | new_mps = new_mps[::-1] 534 | for i, s in enumerate(new_mps): 535 | new_mps[i] = s.backend.einsum('abcdpq->adcbpq', s) 536 | return new_mps 537 | 538 | 539 | def _compress_svd_first(mps, mpo, svd_option=None, canonicalize=False): 540 | new_mps = np.empty_like(mps) 541 | if canonicalize == 'right': 542 | mps = mps[::-1] 543 | mpo = mpo[::-1] 544 | 545 | for i, (s, o) in enumerate(zip(mps, mpo)): 546 | if canonicalize == 'right': 547 | s = s.backend.einsum('abcdpq->adcbpq', s) 548 | o = o.backend.einsum('abcdpq->adcbpq', o) 549 | if svd_option: 550 | if i == 0: 551 | new_mps[0], _, left = s.backend.einsumsvd( 552 | 'abidpq,iBcDPQ->axc(dD)(pP)(qQ),bBx', 553 | s, o, option=svd_option, absorb_s='even', 554 | ) 555 | elif i == len(mps)-1: 556 | new_mps[-1] = s.backend.einsum('ijd,abkipP,kBcjqQ->a(bB)cd(pP)(qQ)', left, s, o) 557 | else: 558 | new_mps[i], _, left = s.backend.einsumsvd( 559 | 'ijd,abkipP,kBcjqQ->axcd(pP)(qQ),bBx', 560 | left, s, o, option=svd_option, absorb_s='even', 561 | ) 562 | else: 563 | new_mps[i] = sites.contract_x(s, o) 564 | 565 | if canonicalize == 'right': 566 | new_mps = new_mps[::-1] 567 | for i, s in enumerate(new_mps): 568 | new_mps[i] = s.backend.einsum('abcdpq->adcbpq', s) 569 | return new_mps 570 | 571 | 572 | def _vector_reshaper_BMPS(vector, peps_shape): 573 | # return vector.item()if vector.size == 1 else vector # FIXME: support both PEPS/PEPO 574 | return vector.item() if vector.size == 1 else vector.reshape( 575 | *[int(round(vector.size ** (1 / np.prod(peps_shape))))] * np.prod(peps_shape) 576 | ).transpose(*[i + j * peps_shape[0] for i, j in np.ndindex(*peps_shape)]) 577 | 578 | 579 | def create_env_cache(peps_obj, bmps_option): 580 | upper_mps_list = _contract_to_MPS_cache(peps_obj[:peps_obj.nrow], svd_option=bmps_option.svd_option) 581 | lower_mps_list = _contract_to_MPS_cache(peps_obj[1:], reverse=True, svd_option=bmps_option.svd_option) 582 | upper = {i: mps for i, mps in enumerate(upper_mps_list, 1)} 583 | lower = {i: mps for i, mps in 
enumerate(lower_mps_list)} 584 | return upper, lower 585 | 586 | 587 | def contract_with_env(state, env, up_idx, down_idx, bmps_option): 588 | up, down = env[0].get(up_idx), env[1].get(down_idx) 589 | if state is None: 590 | peps_obj = up if down is None else up.concatenate(down) 591 | elif up is None and down is None: 592 | peps_obj = state 593 | elif up is None: 594 | peps_obj = state.concatenate(down) 595 | elif down is None: 596 | peps_obj = up.concatenate(state) 597 | else: 598 | peps_obj = up.concatenate(state).concatenate(down) 599 | return peps_obj.contract(bmps_option) 600 | 601 | 602 | def _contract_to_MPS_cache(state, horizontal=False, reverse=False, svd_option=None): 603 | from .peps import PEPS 604 | num_rotate90 = bool(horizontal) * 1 + bool(reverse) * 2 605 | state = state.rotate(-num_rotate90) 606 | mps_list = [state.grid[0]] 607 | for mpo in state.grid[1:]: 608 | mps_list.append(_compress_contract_first(mps_list[-1], mpo, svd_option)) 609 | mps_list = [PEPS(mps.reshape(1, -1), state.backend).rotate(num_rotate90) for mps in mps_list] 610 | return [*reversed(mps_list)] if reverse else mps_list 611 | -------------------------------------------------------------------------------- /koala/peps/update.py: -------------------------------------------------------------------------------- 1 | import tensorbackends 2 | from tensorbackends.interface import ReducedSVD 3 | import numpy as np 4 | import scipy.linalg as la 5 | from .. import gates 6 | 7 | 8 | class UpdateOption: 9 | def __str__(self): 10 | return '{}({})'.format( 11 | type(self).__name__, 12 | ','.join('{}={}'.format(k, v) for k, v in vars(self).items()) 13 | ) 14 | 15 | def __repr__(self): 16 | return str(self) 17 | 18 | @property 19 | def name(self): 20 | return type(self).__name__ 21 | 22 | 23 | class DirectUpdate(UpdateOption): 24 | def __init__(self, svd_option=None): 25 | self.svd_option = svd_option 26 | 27 | 28 | class QRUpdate(UpdateOption): 29 | def __init__(self, rank=None): 30 | self.rank = rank 31 | 32 | 33 | class LocalGramQRUpdate(UpdateOption): 34 | def __init__(self, rank=None): 35 | self.rank = rank 36 | 37 | 38 | class LocalGramQRSVDUpdate(UpdateOption): 39 | def __init__(self, rank=None): 40 | self.rank = rank 41 | 42 | 43 | class DefaultUpdate(UpdateOption): 44 | def __init__(self, rank=None): 45 | self.rank = rank 46 | 47 | 48 | def apply_single_site_operator(state, operator, position, flip=False): 49 | operator = state.backend.astensor(operator) 50 | state.grid[position] = state.backend.einsum( 51 | 'ijklpx,yx->ijklpy' if flip else 'ijklxp,yx->ijklyp', 52 | state.grid[position], operator 53 | ) 54 | 55 | 56 | def apply_local_pair_operator(state, operator, positions, update_option, flip=False): 57 | if update_option is None: 58 | update_option = DefaultUpdate() 59 | if isinstance(update_option, DefaultUpdate): 60 | apply_local_pair_operator_qr(state, operator, positions, update_option.rank, flip) 61 | elif isinstance(update_option, DirectUpdate): 62 | apply_local_pair_operator_direct(state, operator, positions, update_option.svd_option, flip) 63 | elif isinstance(update_option, QRUpdate): 64 | apply_local_pair_operator_qr(state, operator, positions, update_option.rank, flip) 65 | elif isinstance(update_option, LocalGramQRUpdate): 66 | apply_local_pair_operator_local_gram_qr(state, operator, positions, update_option.rank, flip) 67 | elif isinstance(update_option, LocalGramQRSVDUpdate): 68 | apply_local_pair_operator_local_gram_qr_svd(state, operator, positions, update_option.rank, flip) 69 | else: 70 
| raise ValueError(f'unknown update option: {update_option}') 71 | 72 | 73 | def apply_nonlocal_pair_operator(state, operator, positions, update_option, flip=False): 74 | assert len(positions) == 2 75 | x_pos, y_pos = positions 76 | 77 | path = [] 78 | moved_y_pos = [None, None] 79 | if y_pos[0] < x_pos[0]: 80 | path.extend(((i,y_pos[1]),(i+1,y_pos[1])) for i in range(y_pos[0], x_pos[0]-1)) 81 | moved_y_pos[0] = x_pos[0] - 1 82 | elif y_pos[0] > x_pos[0]: 83 | path.extend(((i,y_pos[1]),(i-1,y_pos[1])) for i in range(y_pos[0], x_pos[0]+1, -1)) 84 | moved_y_pos[0] = x_pos[0] + 1 85 | else: 86 | moved_y_pos[0] = x_pos[0] 87 | if y_pos[1] < x_pos[1]: 88 | path.extend(((moved_y_pos[0],j),(moved_y_pos[0],j+1)) for j in range(y_pos[1], x_pos[1]-1)) 89 | moved_y_pos[1] = x_pos[1] - 1 90 | elif y_pos[1] > x_pos[1]: 91 | path.extend(((moved_y_pos[0],j),(moved_y_pos[0],j-1)) for j in range(y_pos[1], x_pos[1]+1, -1)) 92 | moved_y_pos[1] = x_pos[1] + 1 93 | else: 94 | moved_y_pos[1] = x_pos[1] 95 | if moved_y_pos[0] != x_pos[0] and moved_y_pos[1] < x_pos[1]: 96 | new_moved_y_pos = [moved_y_pos[0], moved_y_pos[1]+1] 97 | path.append((tuple(moved_y_pos),tuple(new_moved_y_pos))) 98 | moved_y_pos = new_moved_y_pos 99 | elif moved_y_pos[0] != x_pos[0] and moved_y_pos[1] > x_pos[1]: 100 | new_moved_y_pos = [moved_y_pos[0], moved_y_pos[1]-1] 101 | path.append((tuple(moved_y_pos),tuple(new_moved_y_pos))) 102 | moved_y_pos = new_moved_y_pos 103 | moved_y_pos = tuple(moved_y_pos) 104 | 105 | for u, v in path: 106 | swap_local_pair(state, u, v, update_option) 107 | apply_local_pair_operator(state, operator, [x_pos, moved_y_pos], update_option, flip) 108 | for u, v in reversed(path): 109 | swap_local_pair(state, u, v, update_option) 110 | 111 | 112 | def swap_local_pair(state, x_pos, y_pos, update_option): 113 | if update_option is None: 114 | update_option = DefaultUpdate() 115 | if isinstance(update_option, DefaultUpdate): 116 | swap_local_pair_qr(state, x_pos, y_pos, update_option.rank) 117 | elif isinstance(update_option, DirectUpdate): 118 | swap_local_pair_direct(state, x_pos, y_pos, update_option.svd_option) 119 | elif isinstance(update_option, QRUpdate): 120 | swap_local_pair_qr(state, x_pos, y_pos, update_option.rank) 121 | elif isinstance(update_option, LocalGramQRUpdate): 122 | swap_local_pair_local_gram_qr(state, x_pos, y_pos, update_option.rank) 123 | elif isinstance(update_option, LocalGramQRSVDUpdate): 124 | swap_local_pair_local_gram_qr_svd(state, x_pos, y_pos, update_option.rank) 125 | else: 126 | raise ValueError(f'unknown update option: {update_option}') 127 | 128 | 129 | def truncate(state, update_option): 130 | identities = {} 131 | def apply_identity(x_pos, y_pos): 132 | dims = state.grid[x_pos].shape[4], state.grid[y_pos].shape[4] 133 | if dims not in identities: 134 | identities[dims] = state.backend.astensor(np.einsum('ux,vy->uvxy', np.eye(dims[0]), np.eye(dims[1]))) 135 | apply_local_pair_operator(state, identities[dims], (x_pos, y_pos), update_option) 136 | 137 | for i, j in np.ndindex(*state.shape): 138 | if i < state.shape[0] - 1: 139 | apply_identity((i, j), (i+1, j)) 140 | if j < state.shape[1] - 1: 141 | apply_identity((i, j), (i, j+1)) 142 | 143 | 144 | def apply_local_pair_operator_direct(state, operator, positions, svd_option, flip=False): 145 | assert len(positions) == 2 146 | if svd_option is None: 147 | svd_option = ReducedSVD() 148 | x_pos, y_pos = positions 149 | x, y = state.grid[x_pos], state.grid[y_pos] 150 | operator = state.backend.astensor(operator) 151 | 152 | if 
flip: 153 | if x_pos[0] < y_pos[0]: # [x y]^T 154 | prod_subscripts = 'abcdpx,cfghqy,uvxy->abndpu,nfghqv' 155 | scale_u_subscripts = 'absdpu,s->absdpu' 156 | scale_v_subscripts = 'sbcdpv,s->sbcdpv' 157 | elif x_pos[0] > y_pos[0]: # [y x]^T 158 | prod_subscripts = 'abcdpx,efahqy,uvxy->nbcdpu,efnhqv' 159 | scale_u_subscripts = 'sbcdpu,s->sbcdpu' 160 | scale_v_subscripts = 'absdpv,s->absdpv' 161 | elif x_pos[1] < y_pos[1]: # [x y] 162 | prod_subscripts = 'abcdpx,efgbqy,uvxy->ancdpu,efgnqv' 163 | scale_u_subscripts = 'ascdpu,s->ascdpu' 164 | scale_v_subscripts = 'abcspv,s->abcspv' 165 | elif x_pos[1] > y_pos[1]: # [y x] 166 | prod_subscripts = 'abcdpx,edghqy,uvxy->abcnpu,enghqv' 167 | scale_u_subscripts = 'abcspu,s->abcspu' 168 | scale_v_subscripts = 'ascdpv,s->ascdpv' 169 | else: 170 | assert False 171 | else: 172 | if x_pos[0] < y_pos[0]: # [x y]^T 173 | prod_subscripts = 'abcdxp,cfghyq,uvxy->abndup,nfghvq' 174 | scale_u_subscripts = 'absdup,s->absdup' 175 | scale_v_subscripts = 'sbcdvp,s->sbcdvp' 176 | elif x_pos[0] > y_pos[0]: # [y x]^T 177 | prod_subscripts = 'abcdxp,efahyq,uvxy->nbcdup,efnhvq' 178 | scale_u_subscripts = 'sbcdup,s->sbcdup' 179 | scale_v_subscripts = 'absdvp,s->absdvp' 180 | elif x_pos[1] < y_pos[1]: # [x y] 181 | prod_subscripts = 'abcdxp,efgbyq,uvxy->ancdup,efgnvq' 182 | scale_u_subscripts = 'ascdup,s->ascdup' 183 | scale_v_subscripts = 'abcsvp,s->abcsvp' 184 | elif x_pos[1] > y_pos[1]: # [y x] 185 | prod_subscripts = 'abcdxp,edghyq,uvxy->abcnup,enghvq' 186 | scale_u_subscripts = 'abcsup,s->abcsup' 187 | scale_v_subscripts = 'ascdvp,s->ascdvp' 188 | else: 189 | assert False 190 | 191 | u, s, v = state.backend.einsumsvd(prod_subscripts, x, y, operator, option=svd_option) 192 | s = s ** 0.5 193 | u = state.backend.einsum(scale_u_subscripts, u, s) 194 | v = state.backend.einsum(scale_v_subscripts, v, s) 195 | state.grid[x_pos] = u 196 | state.grid[y_pos] = v 197 | 198 | 199 | def apply_local_pair_operator_qr(state, operator, positions, rank, flip=False): 200 | assert len(positions) == 2 201 | svd_option = ReducedSVD(rank) 202 | x_pos, y_pos = positions 203 | x, y = state.grid[x_pos], state.grid[y_pos] 204 | operator = state.backend.astensor(operator) 205 | 206 | if flip: 207 | if x_pos[0] < y_pos[0]: # [x y]^T 208 | split_x_subscripts = 'abcdpx->abdpi,icx' 209 | split_y_subscripts = 'cfghqy->fghqj,jcy' 210 | recover_x_subscripts = 'abdpi,isu,s->absdpu' 211 | recover_y_subscripts = 'fghqj,jsv,s->sfghqv' 212 | elif x_pos[0] > y_pos[0]: # [y x]^T 213 | split_x_subscripts = 'abcdpx->bcdpi,iax' 214 | split_y_subscripts = 'efahqy->efhqj,jay' 215 | recover_x_subscripts = 'bcdpi,isu,s->sbcdpu' 216 | recover_y_subscripts = 'efhqj,jsv,s->efshqv' 217 | elif x_pos[1] < y_pos[1]: # [x y] 218 | split_x_subscripts = 'abcdpx->acdpi,ibx' 219 | split_y_subscripts = 'efgbqy->efgqj,jby' 220 | recover_x_subscripts = 'acdpi,isu,s->ascdpu' 221 | recover_y_subscripts = 'efgqj,jsv,s->efgsqv' 222 | elif x_pos[1] > y_pos[1]: # [y x] 223 | split_x_subscripts = 'abcdpx->abcpi,idx' 224 | split_y_subscripts = 'edghqy->eghqj,jdy' 225 | recover_x_subscripts = 'abcpi,isu,s->abcspu' 226 | recover_y_subscripts = 'eghqj,jsv,s->esghqv' 227 | else: 228 | assert False 229 | else: 230 | if x_pos[0] < y_pos[0]: # [x y]^T 231 | split_x_subscripts = 'abcdxp->abdpi,icx' 232 | split_y_subscripts = 'cfghyq->fghqj,jcy' 233 | recover_x_subscripts = 'abdpi,isu,s->absdup' 234 | recover_y_subscripts = 'fghqj,jsv,s->sfghvq' 235 | elif x_pos[0] > y_pos[0]: # [y x]^T 236 | split_x_subscripts = 'abcdxp->bcdpi,iax' 237 | 
split_y_subscripts = 'efahyq->efhqj,jay' 238 | recover_x_subscripts = 'bcdpi,isu,s->sbcdup' 239 | recover_y_subscripts = 'efhqj,jsv,s->efshvq' 240 | elif x_pos[1] < y_pos[1]: # [x y] 241 | split_x_subscripts = 'abcdxp->acdpi,ibx' 242 | split_y_subscripts = 'efgbyq->efgqj,jby' 243 | recover_x_subscripts = 'acdpi,isu,s->ascdup' 244 | recover_y_subscripts = 'efgqj,jsv,s->efgsvq' 245 | elif x_pos[1] > y_pos[1]: # [y x] 246 | split_x_subscripts = 'abcdxp->abcpi,idx' 247 | split_y_subscripts = 'edghyq->eghqj,jdy' 248 | recover_x_subscripts = 'abcpi,isu,s->abcsup' 249 | recover_y_subscripts = 'eghqj,jsv,s->esghvq' 250 | else: 251 | assert False 252 | 253 | xq, xr = state.backend.einqr(split_x_subscripts, x) 254 | yq, yr = state.backend.einqr(split_y_subscripts, y) 255 | 256 | u, s, v = state.backend.einsumsvd('ikx,jky,uvxy->isu,jsv', xr, yr, operator, option=svd_option) 257 | s = s ** 0.5 258 | state.grid[x_pos] = state.backend.einsum(recover_x_subscripts, xq, u, s) 259 | state.grid[y_pos] = state.backend.einsum(recover_y_subscripts, yq, v, s) 260 | 261 | 262 | def apply_local_pair_operator_local_gram_qr(state, operator, positions, rank, flip=False): 263 | assert len(positions) == 2 264 | x_pos, y_pos = positions 265 | x, y = state.grid[x_pos], state.grid[y_pos] 266 | operator = state.backend.astensor(operator) 267 | 268 | if flip: 269 | if x_pos[0] < y_pos[0]: # [x y]^T 270 | gram_x_subscripts = 'abcdpx,abCdpX->xcXC' 271 | gram_y_subscripts = 'cfghqy,CfghqY->ycYC' 272 | xq_subscripts = 'abcdpx,xci->abdpi' 273 | yq_subscripts = 'cfghqy,ycj->fghqj' 274 | recover_x_subscripts = 'abdpi,isu,s->absdpu' 275 | recover_y_subscripts = 'fghqj,jsv,s->sfghqv' 276 | elif x_pos[0] > y_pos[0]: # [y x]^T 277 | gram_x_subscripts = 'abcdpx,AbcdpX->xaXA' 278 | gram_y_subscripts = 'efahpy,efAhpY->yaYA' 279 | xq_subscripts = 'abcdpx,xai->bcdpi' 280 | yq_subscripts = 'efahqy,yaj->efhqj' 281 | recover_x_subscripts = 'bcdpi,isu,s->sbcdpu' 282 | recover_y_subscripts = 'efhqj,jsv,s->efshqv' 283 | elif x_pos[1] < y_pos[1]: # [x y] 284 | gram_x_subscripts = 'abcdpx,aBcdpX->xbXB' 285 | gram_y_subscripts = 'efgbqy,efgBqY->ybYB' 286 | xq_subscripts = 'abcdpx,xbi->acdpi' 287 | yq_subscripts = 'efgbqy,ybj->efgqj' 288 | recover_x_subscripts = 'acdpi,isu,s->ascdpu' 289 | recover_y_subscripts = 'efgqj,jsv,s->efgsqv' 290 | elif x_pos[1] > y_pos[1]: # [y x] 291 | gram_x_subscripts = 'abcdpx,abcDpX->xdXD' 292 | gram_y_subscripts = 'edghqy,eDghqY->ydYD' 293 | xq_subscripts = 'abcdpx,xdi->abcpi' 294 | yq_subscripts = 'edghqy,ydj->eghqj' 295 | recover_x_subscripts = 'abcpi,isu,s->abcspu' 296 | recover_y_subscripts = 'eghqj,jsv,s->esghqv' 297 | else: 298 | assert False 299 | else: 300 | if x_pos[0] < y_pos[0]: # [x y]^T 301 | gram_x_subscripts = 'abcdxp,abCdXp->xcXC' 302 | gram_y_subscripts = 'cfghyq,CfghYq->ycYC' 303 | xq_subscripts = 'abcdxp,xci->abdpi' 304 | yq_subscripts = 'cfghyq,ycj->fghqj' 305 | recover_x_subscripts = 'abdpi,isu,s->absdup' 306 | recover_y_subscripts = 'fghqj,jsv,s->sfghvq' 307 | elif x_pos[0] > y_pos[0]: # [y x]^T 308 | gram_x_subscripts = 'abcdxp,AbcdXp->xaXA' 309 | gram_y_subscripts = 'efahyq,efAhYq->yaYA' 310 | xq_subscripts = 'abcdxp,xai->bcdpi' 311 | yq_subscripts = 'efahyq,yaj->efhqj' 312 | recover_x_subscripts = 'bcdpi,isu,s->sbcdup' 313 | recover_y_subscripts = 'efhqj,jsv,s->efshvq' 314 | elif x_pos[1] < y_pos[1]: # [x y] 315 | gram_x_subscripts = 'abcdxp,aBcdXp->xbXB' 316 | gram_y_subscripts = 'efgbyq,efgBYq->ybYB' 317 | xq_subscripts = 'abcdxp,xbi->acdpi' 318 | yq_subscripts = 'efgbyq,ybj->efgqj' 319 | 
recover_x_subscripts = 'acdpi,isu,s->ascdup' 320 | recover_y_subscripts = 'efgqj,jsv,s->efgsvq' 321 | elif x_pos[1] > y_pos[1]: # [y x] 322 | gram_x_subscripts = 'abcdxp,abcDXp->xdXD' 323 | gram_y_subscripts = 'edghyq,eDghYq->ydYD' 324 | xq_subscripts = 'abcdxp,xdi->abcpi' 325 | yq_subscripts = 'edghyq,ydj->eghqj' 326 | recover_x_subscripts = 'abcpi,isu,s->abcsup' 327 | recover_y_subscripts = 'eghqj,jsv,s->esghvq' 328 | else: 329 | assert False 330 | 331 | def gram_qr_local(backend, a, gram_a_subscripts, q_subscripts): 332 | gram_a = backend.einsum(gram_a_subscripts, a.conj(), a) 333 | d, xi = gram_a.shape[:2] 334 | 335 | # local 336 | gram_a = gram_a.numpy().reshape(d*xi, d*xi) 337 | w, v = la.eigh(gram_a, overwrite_a=True) 338 | s = np.clip(w, 0, None) ** 0.5 339 | s_pinv = np.divide(1, s, out=np.zeros_like(s), where=s!=0) 340 | r = np.einsum('j,ij->ji', s, v.conj()).reshape(d*xi, d, xi) 341 | r_inv = np.einsum('j,ij->ij', s_pinv, v).reshape(d, xi, d*xi) 342 | 343 | r = backend.astensor(r) 344 | r_inv = backend.astensor(r_inv) 345 | q = backend.einsum(q_subscripts, a, r_inv) 346 | return q, r 347 | 348 | xq, xr = gram_qr_local(state.backend, x, gram_x_subscripts, xq_subscripts) 349 | yq, yr = gram_qr_local(state.backend, y, gram_y_subscripts, yq_subscripts) 350 | 351 | u, s, v = state.backend.einsumsvd('ixk,jyk,uvxy->isu,jsv', xr, yr, operator, option=ReducedSVD(rank)) 352 | s = s ** 0.5 353 | state.grid[x_pos] = state.backend.einsum(recover_x_subscripts, xq, u, s) 354 | state.grid[y_pos] = state.backend.einsum(recover_y_subscripts, yq, v, s) 355 | 356 | 357 | def apply_local_pair_operator_local_gram_qr_svd(state, operator, positions, rank, flip=False): 358 | assert len(positions) == 2 359 | x_pos, y_pos = positions 360 | x, y = state.grid[x_pos], state.grid[y_pos] 361 | 362 | if flip: 363 | if x_pos[0] < y_pos[0]: # [x y]^T 364 | gram_x_subscripts = 'abcdpx,abCdpX->xcXC' 365 | gram_y_subscripts = 'cfghqy,CfghqY->ycYC' 366 | recover_x_subscripts = 'abcdpx,cxsu->absdpu' 367 | recover_y_subscripts = 'cfghqy,cysv->sfghqv' 368 | elif x_pos[0] > y_pos[0]: # [y x]^T 369 | gram_x_subscripts = 'abcdpx,AbcdpX->xaXA' 370 | gram_y_subscripts = 'efahpy,efAhpY->yaYA' 371 | recover_x_subscripts = 'abcdpx,axsu->sbcdpu' 372 | recover_y_subscripts = 'efahqy,aysv->efshqv' 373 | elif x_pos[1] < y_pos[1]: # [x y] 374 | gram_x_subscripts = 'abcdpx,aBcdpX->xbXB' 375 | gram_y_subscripts = 'efgbqy,efgBqY->ybYB' 376 | recover_x_subscripts = 'abcdpx,bxsu->ascdpu' 377 | recover_y_subscripts = 'efgbqy,bysv->efgsqv' 378 | elif x_pos[1] > y_pos[1]: # [y x] 379 | gram_x_subscripts = 'abcdpx,abcDpX->xdXD' 380 | gram_y_subscripts = 'edghqy,eDghqY->ydYD' 381 | recover_x_subscripts = 'abcdpx,dxsu->abcspu' 382 | recover_y_subscripts = 'edghqy,dysv->esghqv' 383 | else: 384 | assert False 385 | else: 386 | if x_pos[0] < y_pos[0]: # [x y]^T 387 | gram_x_subscripts = 'abcdxp,abCdXp->xcXC' 388 | gram_y_subscripts = 'cfghyq,CfghYq->ycYC' 389 | recover_x_subscripts = 'abcdxp,cxsu->absdup' 390 | recover_y_subscripts = 'cfghyq,cysv->sfghvq' 391 | elif x_pos[0] > y_pos[0]: # [y x]^T 392 | gram_x_subscripts = 'abcdxp,AbcdXp->xaXA' 393 | gram_y_subscripts = 'efahyq,efAhYq->yaYA' 394 | recover_x_subscripts = 'abcdxp,axsu->sbcdup' 395 | recover_y_subscripts = 'efahyq,aysv->efshvq' 396 | elif x_pos[1] < y_pos[1]: # [x y] 397 | gram_x_subscripts = 'abcdxp,aBcdXp->xbXB' 398 | gram_y_subscripts = 'efgbyq,efgBYq->ybYB' 399 | recover_x_subscripts = 'abcdxp,bxsu->ascdup' 400 | recover_y_subscripts = 'efgbyq,bysv->efgsvq' 401 | elif x_pos[1] > 
y_pos[1]: # [y x] 402 | gram_x_subscripts = 'abcdxp,abcDXp->xdXD' 403 | gram_y_subscripts = 'edghyq,eDghYq->ydYD' 404 | recover_x_subscripts = 'abcdxp,dxsu->abcsup' 405 | recover_y_subscripts = 'edghyq,dysv->esghvq' 406 | else: 407 | assert False 408 | 409 | numpy_backend = tensorbackends.get('numpy') 410 | 411 | def gram_qr_local(backend, a, gram_a_subscripts): 412 | gram_a = backend.einsum(gram_a_subscripts, a.conj(), a) 413 | d, xi = gram_a.shape[:2] 414 | 415 | # local 416 | gram_a = gram_a.numpy().reshape(d*xi, d*xi) 417 | w, v = la.eigh(gram_a, overwrite_a=True) 418 | s = np.clip(w, 0, None) ** 0.5 419 | s_pinv = np.divide(1, s, out=np.zeros_like(s), where=s!=0) 420 | r = np.einsum('j,ij->ji', s, v.conj()).reshape(d*xi, d, xi) 421 | r_inv = np.einsum('j,ij->ij', s_pinv, v).reshape(d, xi, d*xi) 422 | return numpy_backend.tensor(r), numpy_backend.tensor(r_inv) 423 | 424 | xr, xr_inv = gram_qr_local(state.backend, x, gram_x_subscripts) 425 | yr, yr_inv = gram_qr_local(state.backend, y, gram_y_subscripts) 426 | 427 | operator = numpy_backend.tensor(operator if isinstance(operator, np.ndarray) else operator.numpy()) 428 | u, s, v = numpy_backend.einsumsvd('ixk,jyk,uvxy->isu,jsv', xr, yr, operator, option=ReducedSVD(rank)) 429 | s **= 0.5 430 | u = numpy_backend.einsum('xki,isu,s->kxsu', xr_inv, u, s) 431 | v = numpy_backend.einsum('ykj,jsv,s->kysv', yr_inv, v, s) 432 | 433 | u = state.backend.astensor(u) 434 | v = state.backend.astensor(v) 435 | state.grid[x_pos] = state.backend.einsum(recover_x_subscripts, x, u) 436 | state.grid[y_pos] = state.backend.einsum(recover_y_subscripts, y, v) 437 | 438 | 439 | def swap_local_pair_direct(state, x_pos, y_pos, svd_option): 440 | if svd_option is None: 441 | svd_option = ReducedSVD() 442 | 443 | if x_pos[0] < y_pos[0]: # [x y]^T 444 | prod_subscripts = 'abcdxp,cfghyq->absdyq,sfghxp' 445 | scale_u_subscripts = 'absdyq,s->absdyq' 446 | scale_v_subscripts = 'sfghxp,s->sfghxp' 447 | elif x_pos[0] > y_pos[0]: # [y x]^T 448 | prod_subscripts = 'abcdxp,efahyq->sbcdyq,efshxp' 449 | scale_u_subscripts = 'sbcdyq,s->sbcdyq' 450 | scale_v_subscripts = 'efshxp,s->efshxp' 451 | elif x_pos[1] < y_pos[1]: # [x y] 452 | prod_subscripts = 'abcdxp,efgbyq->ascdyq,efgsxp' 453 | scale_u_subscripts = 'ascdyq,s->ascdyq' 454 | scale_v_subscripts = 'efgsxp,s->efgsxp' 455 | elif x_pos[1] > y_pos[1]: # [y x] 456 | prod_subscripts = 'abcdxp,edghyq->abcsyq,esghxp' 457 | scale_u_subscripts = 'abcsyq,s->abcsyq' 458 | scale_v_subscripts = 'esghxp,s->esghxp' 459 | else: 460 | assert False 461 | 462 | x, y = state.grid[x_pos], state.grid[y_pos] 463 | u, s, v = state.backend.einsumsvd(prod_subscripts, x, y, option=svd_option) 464 | s = s ** 0.5 465 | u = state.backend.einsum(scale_u_subscripts, u, s) 466 | v = state.backend.einsum(scale_v_subscripts, v, s) 467 | state.grid[x_pos] = u 468 | state.grid[y_pos] = v 469 | 470 | 471 | def swap_local_pair_qr(state, x_pos, y_pos, rank): 472 | svd_option = ReducedSVD(rank) 473 | 474 | if x_pos[0] < y_pos[0]: # [x y]^T 475 | split_x_subscripts = 'abcdxp->abdi,icxp' 476 | split_y_subscripts = 'cfghyq->fghj,jcyq' 477 | recover_x_subscripts = 'abdi,isyq,s->absdyq' 478 | recover_y_subscripts = 'fghj,jsxp,s->sfghxp' 479 | elif x_pos[0] > y_pos[0]: # [y x]^T 480 | split_x_subscripts = 'abcdxp->bcdi,iaxp' 481 | split_y_subscripts = 'efahyq->efhj,jayq' 482 | recover_x_subscripts = 'bcdi,isyq,s->sbcdyq' 483 | recover_y_subscripts = 'efhj,jsxp,s->efshxp' 484 | elif x_pos[1] < y_pos[1]: # [x y] 485 | split_x_subscripts = 'abcdxp->acdi,ibxp' 486 | 
split_y_subscripts = 'efgbyq->efgj,jbyq' 487 | recover_x_subscripts = 'acdi,isyq,s->ascdyq' 488 | recover_y_subscripts = 'efgj,jsxp,s->efgsxp' 489 | elif x_pos[1] > y_pos[1]: # [y x] 490 | split_x_subscripts = 'abcdxp->abci,idxp' 491 | split_y_subscripts = 'edghyq->eghj,jdyq' 492 | recover_x_subscripts = 'abci,isyq,s->abcsyq' 493 | recover_y_subscripts = 'eghj,jsxp,s->esghxp' 494 | else: 495 | assert False 496 | 497 | x, y = state.grid[x_pos], state.grid[y_pos] 498 | 499 | xq, xr = state.backend.einqr(split_x_subscripts, x) 500 | yq, yr = state.backend.einqr(split_y_subscripts, y) 501 | 502 | u, s, v = state.backend.einsumsvd('ikxp,jkyq->isyq,jsxp', xr, yr, option=svd_option) 503 | s = s ** 0.5 504 | state.grid[x_pos] = state.backend.einsum(recover_x_subscripts, xq, u, s) 505 | state.grid[y_pos] = state.backend.einsum(recover_y_subscripts, yq, v, s) 506 | 507 | 508 | def swap_local_pair_local_gram_qr(state, x_pos, y_pos, rank): 509 | if x_pos[0] < y_pos[0]: # [x y]^T 510 | gram_x_subscripts = 'abcdxp,abCdXP->xpcXPC' 511 | gram_y_subscripts = 'cfghyq,CfghYQ->yqcYQC' 512 | xq_subscripts = 'abcdxp,xpci->abdi' 513 | yq_subscripts = 'cfghyq,yqcj->fghj' 514 | recover_x_subscripts = 'abdi,isyq,s->absdyq' 515 | recover_y_subscripts = 'fghj,jsxp,s->sfghxp' 516 | elif x_pos[0] > y_pos[0]: # [y x]^T 517 | gram_x_subscripts = 'abcdxp,AbcdXP->xpaXPA' 518 | gram_y_subscripts = 'efahyq,efAhYQ->yqaYQA' 519 | xq_subscripts = 'abcdxp,xpai->bcdi' 520 | yq_subscripts = 'efahyq,yqaj->efhj' 521 | recover_x_subscripts = 'bcdi,isyq,s->sbcdyq' 522 | recover_y_subscripts = 'efhj,jsxp,s->efshxp' 523 | elif x_pos[1] < y_pos[1]: # [x y] 524 | gram_x_subscripts = 'abcdxp,aBcdXP->xpbXPB' 525 | gram_y_subscripts = 'efgbyq,efgBYQ->yqbYQB' 526 | xq_subscripts = 'abcdxp,xpbi->acdi' 527 | yq_subscripts = 'efgbyq,yqbj->efgj' 528 | recover_x_subscripts = 'acdi,isyq,s->ascdyq' 529 | recover_y_subscripts = 'efgj,jsxp,s->efgsxp' 530 | elif x_pos[1] > y_pos[1]: # [y x] 531 | gram_x_subscripts = 'abcdxp,abcDXP->xpdXPD' 532 | gram_y_subscripts = 'edghyq,eDghYQ->yqdYQD' 533 | xq_subscripts = 'abcdxp,xpdi->abci' 534 | yq_subscripts = 'edghyq,yqdj->eghj' 535 | recover_x_subscripts = 'abci,isyq,s->abcsyq' 536 | recover_y_subscripts = 'eghj,jsxp,s->esghxp' 537 | else: 538 | assert False 539 | 540 | def gram_qr_local(backend, a, gram_a_subscripts, q_subscripts): 541 | gram_a = backend.einsum(gram_a_subscripts, a.conj(), a) 542 | d1, d2, xi = gram_a.shape[:3] 543 | 544 | # local 545 | gram_a = gram_a.numpy().reshape(d1*d2*xi, d1*d2*xi) 546 | w, v = la.eigh(gram_a, overwrite_a=True) 547 | s = np.clip(w, 0, None) ** 0.5 548 | s_pinv = np.divide(1, s, out=np.zeros_like(s), where=s!=0) 549 | r = np.einsum('j,ij->ji', s, v.conj()).reshape(d1*d2*xi, d1, d2, xi) 550 | r_inv = np.einsum('j,ij->ij', s_pinv, v).reshape(d1, d2, xi, d1*d2*xi) 551 | 552 | r = backend.astensor(r) 553 | r_inv = backend.astensor(r_inv) 554 | q = backend.einsum(q_subscripts, a, r_inv) 555 | return q, r 556 | 557 | x, y = state.grid[x_pos], state.grid[y_pos] 558 | 559 | xq, xr = gram_qr_local(state.backend, x, gram_x_subscripts, xq_subscripts) 560 | yq, yr = gram_qr_local(state.backend, y, gram_y_subscripts, yq_subscripts) 561 | 562 | u, s, v = state.backend.einsumsvd('ixpk,jyqk->isyq,jsxp', xr, yr, option=ReducedSVD(rank)) 563 | s = s ** 0.5 564 | state.grid[x_pos] = state.backend.einsum(recover_x_subscripts, xq, u, s) 565 | state.grid[y_pos] = state.backend.einsum(recover_y_subscripts, yq, v, s) 566 | 567 | 568 | def swap_local_pair_local_gram_qr_svd(state, x_pos, y_pos, 
rank): 569 | if x_pos[0] < y_pos[0]: # [x y]^T 570 | gram_x_subscripts = 'abcdxp,abCdXP->xpcXPC' 571 | gram_y_subscripts = 'cfghyq,CfghYQ->yqcYQC' 572 | xq_subscripts = 'abcdxp,xpci->abdi' 573 | yq_subscripts = 'cfghyq,yqcj->fghj' 574 | recover_x_subscripts = 'abcdxp,cxpsyq->absdyq' 575 | recover_y_subscripts = 'cfghyq,cyqsxp->sfghxp' 576 | elif x_pos[0] > y_pos[0]: # [y x]^T 577 | gram_x_subscripts = 'abcdxp,AbcdXP->xpaXPA' 578 | gram_y_subscripts = 'efahyq,efAhYQ->yqaYQA' 579 | xq_subscripts = 'abcdxp,xpai->bcdi' 580 | yq_subscripts = 'efahyq,yqaj->efhj' 581 | recover_x_subscripts = 'abcdxp,axpsyq->sbcdyq' 582 | recover_y_subscripts = 'efahyq,ayqsxp->efshxp' 583 | elif x_pos[1] < y_pos[1]: # [x y] 584 | gram_x_subscripts = 'abcdxp,aBcdXP->xpbXPB' 585 | gram_y_subscripts = 'efgbyq,efgBYQ->yqbYQB' 586 | xq_subscripts = 'abcdxp,xpbi->acdi' 587 | yq_subscripts = 'efgbyq,yqbj->efgj' 588 | recover_x_subscripts = 'abcdxp,bxpsyq->ascdyq' 589 | recover_y_subscripts = 'efgbyq,byqsxp->efgsxp' 590 | elif x_pos[1] > y_pos[1]: # [y x] 591 | gram_x_subscripts = 'abcdxp,abcDXP->xpdXPD' 592 | gram_y_subscripts = 'edghyq,eDghYQ->yqdYQD' 593 | xq_subscripts = 'abcdxp,xpdi->abci' 594 | yq_subscripts = 'edghyq,yqdj->eghj' 595 | recover_x_subscripts = 'abcdxp,dxpsyq->abcsyq' 596 | recover_y_subscripts = 'edghyq,dyqsxp->esghxp' 597 | else: 598 | assert False 599 | 600 | numpy_backend = tensorbackends.get('numpy') 601 | 602 | def gram_qr_local(backend, a, gram_a_subscripts, q_subscripts): 603 | gram_a = backend.einsum(gram_a_subscripts, a.conj(), a) 604 | d1, d2, xi = gram_a.shape[:3] 605 | 606 | # local 607 | gram_a = gram_a.numpy().reshape(d1*d2*xi, d1*d2*xi) 608 | w, v = la.eigh(gram_a, overwrite_a=True) 609 | s = np.clip(w, 0, None) ** 0.5 610 | s_pinv = np.divide(1, s, out=np.zeros_like(s), where=s!=0) 611 | r = np.einsum('j,ij->ji', s, v.conj()).reshape(d1*d2*xi, d1, d2, xi) 612 | r_inv = np.einsum('j,ij->ij', s_pinv, v).reshape(d1, d2, xi, d1*d2*xi) 613 | return numpy_backend.tensor(r), numpy_backend.tensor(r_inv) 614 | 615 | x, y = state.grid[x_pos], state.grid[y_pos] 616 | 617 | xr, xr_inv = gram_qr_local(state.backend, x, gram_x_subscripts, xq_subscripts) 618 | yr, yr_inv = gram_qr_local(state.backend, y, gram_y_subscripts, yq_subscripts) 619 | 620 | u, s, v = numpy_backend.einsumsvd('ixpk,jyqk->isyq,jsxp', xr, yr, option=ReducedSVD(rank)) 621 | s **= 0.5 622 | u = numpy_backend.einsum('xpki,isyq,s->kxpsyq', xr_inv, u, s) 623 | v = numpy_backend.einsum('yqkj,jsxp,s->kyqsxp', yr_inv, v, s) 624 | 625 | u = state.backend.astensor(u) 626 | v = state.backend.astensor(v) 627 | state.grid[x_pos] = state.backend.einsum(recover_x_subscripts, x, u) 628 | state.grid[y_pos] = state.backend.einsum(recover_y_subscripts, y, v) 629 | --------------------------------------------------------------------------------
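Note on the `tn_add` helper in koala/peps/peps.py: it adds two PEPS site tensors by embedding them block-diagonally, so external (physical and boundary) bonds keep their size while internal (virtual) bonds grow to the sum of the two dimensions. Below is a minimal NumPy-only sketch of the same idea; the name tn_add_dense and the toy shapes are illustrative and not part of koala.

import numpy as np

def tn_add_dense(a, b, internal_bonds, external_bonds, coeff_a=1.0, coeff_b=1.0):
    """Embed a and b block-diagonally along their internal bonds: a occupies the
    leading block and b the trailing block, so internal bonds grow to
    dim_a + dim_b while external bonds keep their (matching) size."""
    for k in external_bonds:
        assert a.shape[k] == b.shape[k], 'external bonds must match'
    shape_c = list(a.shape)
    for k in internal_bonds:
        shape_c[k] = a.shape[k] + b.shape[k]
    c = np.zeros(shape_c, dtype=np.result_type(a, b))
    a_slices = tuple(slice(0, a.shape[k]) for k in range(a.ndim))
    b_slices = tuple(slice(a.shape[k], None) if k in internal_bonds else slice(None)
                     for k in range(b.ndim))
    c[a_slices] += coeff_a * a
    c[b_slices] += coeff_b * b
    return c

# two site tensors sharing external bond 2 (the physical leg) and internal bonds 0, 1
a = np.random.rand(3, 4, 2)
b = np.random.rand(5, 6, 2)
c = tn_add_dense(a, b, internal_bonds=[0, 1], external_bonds=[2])
print(c.shape)  # (8, 10, 2)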
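Note on the BMPS-style routines in koala/peps/contraction.py (contract_BMPS, contract_to_MPS): they contract a PEPS by treating the first row as a boundary MPS and absorbing one row at a time, optionally compressing the growing horizontal bonds by SVD. The sketch below shows the uncompressed version of that idea on a small grid of plain NumPy rank-4 tensors with index order (up, right, down, left); the helper name, the 2x3 toy grid, and the index convention are illustrative assumptions, and the result is checked against a direct einsum contraction.

import numpy as np

def contract_boundary(grid):
    """Row-by-row (boundary MPS) contraction of a grid of rank-4 tensors.

    Each tensor has index order (up, right, down, left); edge legs have size 1.
    The first row becomes the boundary MPS, every later row is absorbed into it,
    and the final one-row MPS is contracted left to right.  No SVD truncation is
    applied, so the horizontal bonds simply grow."""
    mps = [t[0] for t in grid[0]]                      # drop trivial "up" legs -> (right, down, left)
    for row in grid[1:]:
        new_mps = []
        for s, o in zip(mps, row):
            # contract the shared vertical bond, fuse the horizontal bonds pairwise
            t = np.einsum('rxl,xRdL->rRdlL', s, o)
            new_mps.append(t.reshape(s.shape[0] * o.shape[1], o.shape[2], s.shape[2] * o.shape[3]))
        mps = new_mps
    vec = np.squeeze(mps[0], axis=(1, 2))              # down and left edge legs are trivial
    for t in mps[1:]:
        vec = np.squeeze(t, axis=1) @ vec              # absorb the next column
    return vec.item()

rng = np.random.default_rng(0)
Dh, Dv, nrow, ncol = 2, 3, 2, 3
def dims(i, j):
    return (1 if i == 0 else Dv, 1 if j == ncol - 1 else Dh,
            1 if i == nrow - 1 else Dv, 1 if j == 0 else Dh)
grid = [[rng.standard_normal(dims(i, j)) for j in range(ncol)] for i in range(nrow)]

# exact contraction of the same 2x3 network for comparison
exact = np.einsum('aebf,cgde,hijg,bmno,dpqm,jrsp->',
                  grid[0][0], grid[0][1], grid[0][2],
                  grid[1][0], grid[1][1], grid[1][2])
assert np.allclose(contract_boundary(grid), exact)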
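Note on the QR-based updates in koala/peps/update.py (apply_local_pair_operator_qr and the related swap routines): they avoid an SVD of the full pair of site tensors by first splitting each tensor with a QR, so that only a small core carrying the shared virtual bond and the physical leg sees the gate and the truncated SVD. Below is a stripped-down NumPy sketch of that scheme on tensors with the simplified index order (env, bond, phys); the function apply_gate_qr, the shapes, and the chi value are illustrative assumptions rather than koala's actual interface.

import numpy as np

def apply_gate_qr(x, y, gate, chi):
    """Apply a two-site gate across the virtual bond shared by x and y.

    x, y: site tensors with index order (env, bond, phys), where `env` stands
    for all remaining legs fused into one.  gate: (out_x, out_y, in_x, in_y).
    QR splits off small cores, the gate and the truncated SVD act only on them,
    and the large isometries are reattached at the end."""
    ax, k, p = x.shape
    ay, _, q = y.shape
    qx, rx = np.linalg.qr(x.reshape(ax, k * p))
    qy, ry = np.linalg.qr(y.reshape(ay, k * q))
    rx = rx.reshape(-1, k, p)
    ry = ry.reshape(-1, k, q)
    # the gate acts on the small cores only
    theta = np.einsum('ikp,jkq,uvpq->iujv', rx, ry, gate)
    r1, du, r2, dv = theta.shape
    u, s, vh = np.linalg.svd(theta.reshape(r1 * du, r2 * dv), full_matrices=False)
    chi = min(chi, s.size)
    u, s, vh = u[:, :chi], s[:chi], vh[:chi]
    # split the singular values evenly between the two new cores
    cx = (u * np.sqrt(s)).reshape(r1, du, chi)
    cy = (np.sqrt(s)[:, None] * vh).reshape(chi, r2, dv)
    x_new = np.einsum('ai,ius->asu', qx, cx)
    y_new = np.einsum('bj,sjv->bsv', qy, cy)
    return x_new, y_new                                # again (env, bond, phys)

rng = np.random.default_rng(1)
x = rng.standard_normal((6, 4, 2))
y = rng.standard_normal((5, 4, 2))
cnot = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]], float).reshape(2, 2, 2, 2)
x_new, y_new = apply_gate_qr(x, y, cnot, chi=16)

# with chi large enough the gated pair is reproduced exactly
before = np.einsum('akp,bkq,uvpq->aubv', x, y, cnot)
after = np.einsum('asu,bsv->aubv', x_new, y_new)
assert np.allclose(before, after)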