├── tensor_reshape.m
├── MPS_Canonicalize.m
├── README.md
├── tensor_product.m
├── Example.m
├── rbm2mps.m
├── rbm2mps2.m
└── rbm2mps3.m


/tensor_reshape.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yzcj105/rbm2mps/HEAD/tensor_reshape.m


--------------------------------------------------------------------------------
/MPS_Canonicalize.m:
--------------------------------------------------------------------------------
function [ A1, L, Coef] = MPS_Canonicalize( A )
% Canonicalize a finite-size MPS
% Written by Jing Chen, Haidong Xie, Song Cheng
% Matrix Product State Representations
% D. Perez-Garcia, F. Verstraete, M.M. Wolf, J.I. Cirac
% arXiv:quant-ph/0608197
% INPUT:  A    : cell of 3-leg tensors A
% OUTPUT: A1   : cell of 3-leg tensors
%                The MPS is the tensor product of A1
%         L    : cell of vectors,
%                the entanglement spectrum on each bond
%         Coef : the coefficient factor for the MPS

%------QR from right to left---------------%
len = numel( A );
A1 = cell( len,1 );
L = cell( len,1 );
R = 1; % A{len} is D * 1 * d ;
for k = len : -1 :1
    [ R, A1{k} ] = MPS_qr(A{k},R);
end
A1{1} = A1{1} * sign( R );
Coef = abs(R); % number
%----------SVD from left to right----------%
L{1} = 1;
V = 1;
for k = 1:len-1
    [ A1{k}, L{k+1}, V ] = MPS_svd(L{k},V,A1{k});
end
[ A1{len} ] = MPS_svd(L{len},V,A1{len});
end
function [ R1, A1 ] = MPS_qr( A,R )
    [ dm ] = size( A,3 );
    Tmp1 = tensor_product( 'AB1',A,'Ab1',R,'Bb');
    dB1 = size( Tmp1,2 );
    Tmp1 = tensor_reshape( Tmp1,'ABm','Bm','A' );
    [ Q, R1 ] = qr( Tmp1,0 );
    dA1 = numel(Q)/dm/dB1;
    Tmp2 = reshape( Q, [dB1,dm,dA1 ] );
    A1 = permute( Tmp2 ,[ 3 1 2 ]);
end
function [ A1, L1, V1 ] = MPS_svd( L,V,A )
    LV = diag(L)*V';
    [ LVA ] = tensor_product('aB1',LV,'aA',A,'AB1');
    [ LVA ] = ...
        tensor_reshape( LVA, 'aB1', 'a1','B');
    [ ~, L1, V1 ] = svd(LVA,'econ');
    L1 = diag(L1);
    % singular values below machine epsilon are truncated
    idx = find( L1 > eps,1,'last' );
    L1 = L1(1:idx);
    V1 = V1(:,1:idx);
    %A1 = tensor_n_product('ab1',V,'Aa',A,'AB1',V1,'Bb');
    Tmp = tensor_product('Ab1',A,'AB1',V1,'Bb');
    A1 = tensor_product('ab1',Tmp,'Ab1',V,'Aa');
end


--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# RBM to MPS Translation

Demo code for [arXiv:1701.04831v1](http://arxiv.org/abs/1701.04831v1). You are free to use this code. Please cite the paper:
- Jing Chen, Song Cheng, Haidong Xie, Lei Wang, and Tao Xiang, *Equivalence of restricted Boltzmann machines and tensor network states*, [Phys. Rev. B 97, 085104 (2018)](https://journals.aps.org/prb/abstract/10.1103/PhysRevB.97.085104)

We implement three approaches in Matlab:

* rbm2mps.m: The algorithm of Fig. 2. The MPS bond dimension is $2^n$, where $n$ is the number of cut RBM connections.
    * Input:
        * W: $n_v$ by $n_h$ weight matrix $W$
        * a: vector of size $n_v$ holding the visible-unit biases
        * b: vector of size $n_h$ holding the hidden-unit biases
        * bpos: vector telling which piece each hidden unit belongs to. See Fig. 2.
    * Output:
        * mps: a cell array of MPS tensors

* rbm2mps2.m: The algorithm of Fig. 4. The MPS bond dimension is $2^m$, where $m=\min(|A_{1}|, |B_{1}|)$ is the size of the interface region. The algorithm copies the physical degrees of freedom in the interface region to the virtual bond.
    * Input:
        * W: $n_v$ by $n_h$ weight matrix $W$
        * a: vector of size $n_v$ holding the visible-unit biases
        * b: vector of size $n_h$ holding the hidden-unit biases
    * Output:
        * mps: a cell array of MPS tensors

* rbm2mps3.m: The MPS bond dimension is $2^k$, where $k$ is the minimal number of units (visible or hidden) that break the RBM into a product state when they are fixed. The algorithm copies these $k$ degrees of freedom in the interface region to the virtual bond.
    * Input:
        * W: $n_v$ by $n_h$ weight matrix $W$
        * a: vector of size $n_v$ holding the visible-unit biases
        * b: vector of size $n_h$ holding the hidden-unit biases
    * Output:
        * mps: a cell array of MPS tensors

The MPS bond dimensions satisfy:
rbm2mps3.m <= rbm2mps2.m <= rbm2mps.m

## Auxiliary tensor programs ##
* MPS\_Canonicalize.m: Canonicalize a finite MPS and return the entanglement spectrum of each bond.

* tensor\_product.m: Contract two tensors and permute the indices into the given order.
`C = tensor_product('acm',A,'abm',X,'cb')` computes $C(a,c,m) = \sum_{b} A_{a,b,m} X_{c,b}$

* tensor\_reshape.m:
`B = tensor_reshape( A,'abcd','ac','bd')` reshapes a tensor into a matrix $B( ac,bd ) = A(a,b,c,d)$.


## Example ##
* Example.m: Using the RBM architecture in Fig. 1(a) as an example, we construct the MPS using the three approaches (rbm2mps, rbm2mps2, rbm2mps3). The bond dimensions are consistent with Fig. 2(c) and Fig. 4(c), respectively. The three MPS are identical in their canonical form.
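
As a rough illustration of the index conventions described under the auxiliary tensor programs, the sketch below reproduces the two documented calls with MATLAB built-ins only. It is not the repository's implementation: the array sizes are arbitrary illustration values, and the column-major grouping of `ac` and `bd` (first letter varying fastest) is an assumption about `tensor_reshape`, whose source is not shown here.

```matlab
% Sketch only: the contraction documented for
%   C = tensor_product('acm',A,'abm',X,'cb')
% written with permute/reshape. Sizes are arbitrary.
da = 2; db = 3; dc = 4; dm = 2;
A = rand(da, db, dm);   % indices (a,b,m)
X = rand(dc, db);       % indices (c,b)

% Bring the summed index b to the front of A and flatten to (b, a*m).
Amat = reshape(permute(A, [2 1 3]), db, da*dm);
% X*Amat has shape (c, a*m); unfold to (c,a,m) and reorder to (a,c,m).
C = permute(reshape(X * Amat, dc, da, dm), [2 1 3]);

% Reference: the same sum over b written as explicit loops.
Cref = zeros(da, dc, dm);
for a = 1:da
    for c = 1:dc
        for m = 1:dm
            for b = 1:db
                Cref(a,c,m) = Cref(a,c,m) + A(a,b,m) * X(c,b);
            end
        end
    end
end
err_product = max(abs(C(:) - Cref(:)));

% Assumed analogue of B = tensor_reshape(T,'abcd','ac','bd'):
% group (a,c) as the row index and (b,d) as the column index.
T = rand(2, 3, 4, 5);   % indices (a,b,c,d)
B = reshape(permute(T, [1 3 2 4]), 2*4, 3*5);
% Spot check: row = a + (c-1)*2, column = b + (d-1)*3.
err_reshape = abs(B(2 + (3-1)*2, 1 + (4-1)*3) - T(2,1,3,4));
```

Both `err_product` and `err_reshape` should be at rounding level. Moving the summed index to the front before flattening mirrors the permute, reshape, multiply pattern that `tensor_product.m` itself follows.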
--------------------------------------------------------------------------------
/tensor_product.m:
--------------------------------------------------------------------------------
function [C, cindex] = tensor_product(varargin)
%{
FUNCTION TENSOR_PRODUCT   v1.0   18/JAN/2017

Article
[arXiv:1701.04831](http://arxiv.org/abs/1701.04831)

A Matlab code for rbm2mps on GitHub
https://github.com/yzcj105/rbm2mps

Authors
Jing Chen E-mail: yzcj105@126.com
Song Cheng
Haidong Xie
The Institute of Physics, Chinese Academy of Sciences

Contracts two tensors and permutes the indices of the result into the given order.
For example
C(a,c,m) = sum_{b} A_{a,b,m} X_{c,b}
MATLAB code:
C = tensor_product('acm',A,'abm',X,'cb');

Remarks:
The function is often used together with tensor_reshape



%}
% With five arguments, varargin{1} is cindex
% C(cindex) = A(aindex) * B(bindex)
% index labels shared by aindex and bindex are summed over
% A, B, C are multi-dimensional arrays

% get all the permutation orders
if nargin == 4
    % The indices of the output tensor are not given.
    A = varargin{1};
    aindex = varargin{2};
    B = varargin{3};
    bindex = varargin{4};
elseif nargin == 5
    % The indices of the output tensor are given.
    cindex = varargin{1};
    A = varargin{2};
    aindex = varargin{3};
    B = varargin{4};
    bindex = varargin{5};
end
a_length = length ( aindex );
b_length = length ( bindex );

size_a = size(A);
size_a(end+1:a_length) = 1;
size_b = size(B);
size_b(end+1:b_length) = 1;

[com_in_a, com_in_b ] = find_common ( aindex, bindex );

if ~all(size_a(com_in_a)==size_b(com_in_b))
    error('The dimensions do not match!');
end

diff_in_a = 1:a_length;
diff_in_a ( com_in_a ) = [];
diff_in_b = 1:b_length;
diff_in_b ( com_in_b ) = [];
temp_idx = [ aindex(diff_in_a) , bindex(diff_in_b) ];

if nargin ==5
    [ ix1, ix2 ] = find_common ( temp_idx , cindex );
    ix_temp (ix2) = ix1 ;
else
    cindex = temp_idx;
end
c_length = length(cindex);
% multiply: move the common indices to the front, flatten, and contract
if any([ com_in_a diff_in_a ] ~= 1:a_length)
    A = permute( A, [ com_in_a diff_in_a ] );
end
if any([ com_in_b diff_in_b ] ~= 1:b_length)
    B = permute( B, [ com_in_b diff_in_b ] );
end

sda = prod(size_a(diff_in_a));
sc = prod(size_a(com_in_a));
sdb = prod(size_b(diff_in_b));

A = reshape(A,[sc,sda,1]);
B = reshape(B,[sc,sdb,1]);

C = A.' ...
    * B ;

C = reshape(C,[size_a(diff_in_a),size_b(diff_in_b),1,1]);

if c_length > 1 && nargin == 5 && any(ix_temp ~= 1:c_length)
    C = permute(C,ix_temp);
end

function [com_a, com_b] = find_common ( a, b)
% find the positions of the elements common to a and b
a = a.';
a_len = length( a );
b_len = length( b );
a = a(:,ones (1,b_len) );
b = b( ones(a_len ,1),:);
%[b a] = meshgrid(b,a);
[ com_a ,com_b ] = find ( a == b );
com_a = com_a.';
com_b = com_b.';





--------------------------------------------------------------------------------
/Example.m:
--------------------------------------------------------------------------------
function Example()
%{
FUNCTION EXAMPLE   v1.0   18/JAN/2017

This code constructs the MPS in the different ways of Fig.2 and Fig.4 in Sec.II
and verifies their equivalence in canonical form.

Article
Jing Chen, Song Cheng, Haidong Xie, Lei Wang, and Tao Xiang
On the Equivalence of Restricted Boltzmann Machines and Tensor Network States
[arXiv:1701.04831](http://arxiv.org/abs/1701.04831)

A Matlab code for rbm2mps on GitHub
https://github.com/yzcj105/rbm2mps


Authors
Song Cheng
Haidong Xie
Jing Chen E-mail: yzcj105@126.com
The Institute of Physics, Chinese Academy of Sciences


Parameters:( See Eq.1 )
a     (Input) Double precision array of length Nv.
      a(k) is the bias of v_k
b     (Input) Double precision array of length Nh.
      b(k) is the bias of h_k
W     (Input) Double precision Nv * Nh matrix.
      W(i,j) is the weight of the connection between v_i and h_j
bpos: 1*nh(int) (bias_position), tells which piece each hidden unit belongs
      to in the cut, see Fig.2(b) bpos = [ 2 3 4 5];
mps   (output) Cell of length Nv. Each element mps{k} is a tensor
      of mps rep... with three bonds: left, right, and phy.
Nv    The number of visible units.
Nh    The number of hidden units.


Remarks:




%}
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% generate an RBM as in Fig.1(a) with 6 visible units and 4 hidden units
a=rand(1,6);
b=rand(1,4);
W = [ 0.43498     0          0          0
     -0.020515    0.49548    0          0
     -0.26821     0.46243    0.080192   0
     -0.10371     0.035067   0.030964  -0.48333
      0           0          0.40121    0.30092
      0           0          0         -0.35749];
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Construct MPS as in Fig.(2,3)
% put each
bpos = [ 2 3 4 5 ];
mps1 = rbm2mps(W,b,a,bpos);
% show the bond dimensions
D1 = MPS_D(mps1);
fprintf('The bond dimensions of MPS in Fig.2 are\n');
disp(D1);
% Construct MPS as in Fig.4
mps2 = rbm2mps2(W,b,a);
D2 = MPS_D(mps2);
fprintf('The bond dimensions of MPS in Fig.4 are\n');
disp(D2);

mps_m3 = rbm2mps3(W,b,a);
D3 = MPS_D(mps_m3);
fprintf('The bond dimensions of MPS by rbm2mps3 are\n');
disp(D3);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Canonicalize mps1 and mps2 and compare them.
% The redundancy of an MPS can be removed by canonicalization,
% which is an important property when simplifying an RBM
% See Sec.V Fig.8
[ mps1_cano, L1, Coef1 ] = MPS_Canonicalize( mps1 );
[ mps2_cano, L2, Coef2 ] = MPS_Canonicalize( mps2 );
[ mps3_cano, L3, Coef3 ] = MPS_Canonicalize( mps_m3 );

ErrCoef = abs(Coef1-Coef2)/Coef2;
[ ErrL, ErrMPS ] = CompareMPS( mps1_cano,L1,mps2_cano,L2);
[ ErrL3, ErrMPS3 ] = CompareMPS( mps2_cano,L2,mps3_cano,L3);
fprintf('The difference between MPS by Fig.2 and Fig.4\n');
fprintf('The difference of the coefficient:\t%g\n',ErrCoef);
fprintf('The difference of the entanglement spectrum in each bond:\n');
disp(ErrL);
% If the entanglement spectrum is nearly degenerate, the difference may be large.
% This comes from the gauge freedom in the nearly degenerate space.
fprintf('The difference of the MPS tensors:\n');
disp(ErrMPS);
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
D1_cano = MPS_D(mps1_cano);
fprintf('The bond dimensions of canonicalized MPS are\n');
disp(D1_cano);
% These are the same as in Fig.4
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% In Fig.2(b), if we permute the order of the hidden units
bpos = [ 2 5 3 4 ];
mps3 = rbm2mps(W,b,a,bpos);
[ mps3_cano, L3, Coef3 ] = MPS_Canonicalize(mps3);
[ ErrL13, ErrMPS13 ] = CompareMPS( mps1_cano,L1,mps3_cano,L3);
fprintf('Difference of MPS if we permute hidden units in Fig.2(b):\t%g\n',...
    max([ErrL13,ErrMPS13]));
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
end
function Ds = MPS_D(MPS)
    Ds = zeros(1,length(MPS)-1);
    % collect the bond dimensions of an MPS
    for k = 1:length(MPS)-1
        Ds(k) = size(MPS{k},2);
    end
end
function [ ErrL, ErrMPS ] = CompareMPS( mps1,L1,mps2,L2)
    % Compare two MPS
    % ErrL is the difference of the entanglement spectra
    % ErrMPS is the difference of the MPS tensors
    len = length(mps1);
    ErrMPS = zeros(1,len);
    ErrL = zeros(1,len);
    for k = 1:length( mps1 )
        % before comparing, remove the sign freedom
        v1 = abs(mps1{k}(:));
        v2 = abs(mps2{k}(:));
        ErrMPS(k) = norm(v1-v2);
        ErrL(k) = norm(L1{k}-L2{k});
    end
end


--------------------------------------------------------------------------------
/rbm2mps.m:
--------------------------------------------------------------------------------
function [mps] = rbm2mps(W, b, a, bpos)
%{
FUNCTION RBM2MPS   v1.0   18/JAN/2017

Article
Jing Chen, Song Cheng, Haidong Xie, Lei Wang, and Tao Xiang
On the Equivalence of Restricted Boltzmann Machines and Tensor Network States
[arXiv:1701.04831](http://arxiv.org/abs/1701.04831)

A Matlab code for rbm2mps on GitHub
https://github.com/yzcj105/rbm2mps


Authors
Song Cheng
Jing Chen E-mail: yzcj105@126.com
Haidong Xie
The Institute of Physics, Chinese Academy of Sciences

A Matlab code for rbm2mps.
Fig.2 and Fig.3 in Sec.II of the article

Parameters:( See Eq.1 )
a     (Input) Double precision array of length Nv.
      a(k) is the bias of v_k
b     (Input) Double precision array of length Nh.
      b(k) is the bias of h_k
W     (Input) Double precision Nv * Nh matrix.
      W(i,j) is the weight of the connection between v_i and h_j
bpos: 1*nh(int) (bias_position), tells which piece each hidden unit belongs
      to in the cut, see Fig.2(b) bpos = [ 2 3 4 5];
mps   (output) Cell of length Nv. Each element mps{k} is a tensor
      of mps rep... with three bonds: left, right, and phy.
Nv    The number of visible units.
Nh    The number of hidden units.


Remarks:




%}


nv = length(a);
nh = length(b);
% each piece corresponds to one local tensor of the MPS
pieces = cell(nv,1);
% N records how many hidden units each visible unit is connected with
N = ones(nv,1);
maxInd = 1;
% define indices of Lambda_v and Lambda_h
L_v_ind = zeros(nv,nh);
L_h_ind = L_v_ind;
for k = 1:nv*nh
    if( abs(W(k)) > eps )
        L_v_ind(k) = maxInd;
        L_h_ind(k) = maxInd +1;
        maxInd = maxInd + 2;
    end
end
% build tensor Lambda_v as in Eq.5
for k = 1:nv
    index = L_v_ind(k,:);
    % to distinguish them from virtual bonds, physical bond indices are negative
    pieces{k}{1} = struct('type','L','bias',a(k),'order',sum(abs(W(k,:))>eps)+1,'index',[ index(index>0),-k]);
end
% build tensor Lambda_h as in Eq.6
for k = 1:nh
    % tell N the information of the h2v mapping
    N(bpos(k))= N(bpos(k))+1;
    index = L_h_ind(:,k)';
    % put all hidden units assigned to the same visible unit into the same piece
    pieces{bpos(k)}{N(bpos(k))} = struct('type','L','bias',b(k),'order',sum(abs(W(:,k))>eps),'index',index(index>0));
end

% calculate local matrix M
% bigind holds the virtual bond indices of the MPS
bigind = cell(nv+1,1);
% this nested loop cuts each nonlocal long-range connection into
% several local short-range connections, see Fig.3
% in the equation M = P*Q of Fig.3, we choose P = M and Q = I.
for i = 1:nh
    for j = 1:nv
        if (abs(W(j,i))>eps)
            % we always cut a connection from right to left
            % if the visible unit is to the left of the hidden unit
            if( bpos(i) > j)
                % first (MPS) piece of this connection
                k = bpos(i);
                % first index of this connection
                ind1 = L_h_ind(j,i);
                % last piece of this connection
                end_piece = j;
                % last index of this connection
                end_ind = L_v_ind(j,i);
            % the visible unit is to the right of the hidden unit
            else
                k = j;
                ind1 = L_v_ind(j,i);
                end_piece = bpos(i);
                end_ind = L_h_ind(j,i);
            end
            % the loop condition is false from the start only when bpos(i) == j,
            % and bpos(i) == j means the connection is not nonlocal
            while( k > end_piece)
                N(k)= N(k)+1;
                ind2 = maxInd;
                maxInd = maxInd + 1;
                % the matrix product with an identity on the first piece does nothing;
                % then take the direct product with an identity matrix on the following
                % pieces until we reach the end piece
                pieces{k}{N(k)} = struct('type','I','index',[ ind2, ind1]);
                bigind{k} = [bigind{k},ind2];
                ind1 = ind2;
                k = k - 1;
            end
            N(k)= N(k)+1;
            % put the weight matrix on the end index
            pieces{k}{N(k)} = struct('type','M','Weight',W(j,i),'index',[ end_ind, ind1 ]);
        end
    end
end
% construct the mps
mps = cell(nv,1);
for k = 1:nv
    [ mps{k} ] = contract_tensor(pieces{k},bigind{k},bigind{k+1},-k);
end


function [ T ] = contract_tensor( tens,left_ind,right_ind,physical_ind )
% contract the tensors in the same piece
% see Fig.2
temp = constructTensor(tens{1});
temp_idx = tens{1}.index;

for k = 2:length(tens)
    T2 = constructTensor(tens{k});
    [temp,temp_idx] = tensor_product(temp,temp_idx,T2,tens{k}.index);
end
% reshape the tensor into an MPS piece (3rd-order tensor)
T = tensor_reshape( temp, temp_idx,left_ind,right_ind,physical_ind );

function T = ...
    constructTensor( ten )
switch ten.type
    % Eq. (3-5) of the paper
    case 'I'
        % Fig.3: we choose Q = I.
        T = eye(2);
    case 'M'
        T = [ 1,1;1 exp(ten.Weight)];
    case 'L'
        T = zeros([2*ones(1,ten.order),1,1]);
        T(1) = 1;
        T(end) = exp(ten.bias);
end






--------------------------------------------------------------------------------
/rbm2mps2.m:
--------------------------------------------------------------------------------
function [ mps ] = rbm2mps2( W, b, a )
%{
FUNCTION RBM2MPS2   v1.0   18/JAN/2017

Article
Jing Chen, Song Cheng, Haidong Xie, Lei Wang, and Tao Xiang
On the Equivalence of Restricted Boltzmann Machines and Tensor Network States
[arXiv:1701.04831](http://arxiv.org/abs/1701.04831)

A Matlab code for rbm2mps on GitHub
https://github.com/yzcj105/rbm2mps


Authors
Haidong Xie
Jing Chen E-mail: yzcj105@126.com
Song Cheng
The Institute of Physics, Chinese Academy of Sciences

A Matlab code for rbm2mps.
Fig.4 in Sec.II of the article

Parameters:( See Eq.1 )
a     (Input) Double precision array of length Nv.
      a(k) is the bias of v_k
b     (Input) Double precision array of length Nh.
      b(k) is the bias of h_k
W     (Input) Double precision Nv * Nh matrix.
      W(i,j) is the weight of the connection between v_i and h_j
mps   (output) Cell of length Nv. Each element mps{k} is a tensor
      of mps rep... with three bonds: left, right, and phy.
Nv    The number of visible units.
Nh    The number of hidden units.


Remarks:




%}

% Sub function rbm2data
% Get the indices and positions of the hidden units from the given RBM.
[Lindex,Rindex,Containhi]=rbm2data(W);

% Sub function index2tensor
% Get the mps (result) from the given parameters.
mps=index2tensor(a,b,W,Lindex,Rindex,Containhi);

end

function [ Lindex,Rindex,Contain_hi ]=rbm2data(W)
%{
Sub function rbm2data
Get the indices and positions of the hidden units from the given RBM.
Input:  W: an Nv * Nh matrix
        W(i,j) is the weight of the connection between v_i and h_j
Output: Lindex, Cell of length Nv
        Lindex{k} holds the left indices of tensor k.
        Rindex, Cell of length Nv
        Rindex{k} holds the right indices of tensor k.
        Contain_hi, Cell of length Nv
        Contain_hi{k} lists the hidden units contained in tensor k.
%}
[Nv,Nh]=size(W);

Connect_=abs(W*W')>eps;
% Connect_ is an nv*nv logical matrix; Connect_(i,j) means visible units i and
% j are connected.
banddim=zeros(1,Nv);
bandindex=cell(1,Nv);
for k=2:Nv
    % Calculate each bond of the mps
    % k labels the bond between tensor k-1 and tensor k
    index_l=1:k-1;
    % index_l: the visible units to the left of this bond, which could
    % connect to the part to the right of this bond.
    index_r=k:Nv;
    % the counterpart of index_l on the right.
    flag_find=0;

    for k_banddim=1:min(k-1,Nv+1-k)
        % k_banddim sets the bond dimension of this bond; we want
        % it as small as possible. Bond dimension D=2^k_banddim
        all_choose=nchoosek(1:Nv,k_banddim);
        % all possible choices for the given k_banddim

        % we take the first solution found and use it in this code; the
        % other solutions may lead to degenerate solutions, which we do not
        % discuss.
        for k_=1:size(all_choose,1)
            k_choose=all_choose(k_,:);
            % k_choose, the chosen indices
            index_lnew=setdiff(index_l,k_choose);
            index_rnew=setdiff(index_r,k_choose);
            Con_Mat=Connect_(index_lnew,index_rnew);
            if all(~Con_Mat(:))
                % if all elements of Con_Mat are zero, the left part and
                % the right part separated by the chosen indices become a direct
                % product.
                % For details see Eq.8 and Fig.4 in the paper
                flag_find=1;
                banddim(k)=k_banddim;
                % bond dimension of the bond
                bandindex{k}=k_choose(:);
                % the chosen indices
                % disp(banddim(k));
                % disp(bandindex{k});
                break;
            end
        end
        if flag_find==1
            break;
            % quit the loop, solution found
        end


    end
end
%disp
%disp(banddim);
% for k=2:Nv
%disp(bandindex{k});
% end

% locate hi
Contain_hi=cell(1,Nv);
for k_h=1:Nh
    % find the tensor in which to place each hidden unit
    flag=0;
    for k_v=1:Nv
        % get the left, right, and physical indices.
        if k_v==1
            index_l=[];
            index_r=bandindex{k_v+1};
            index_p=k_v;
        elseif k_v==Nv
            index_l=bandindex{k_v};
            index_r=[];
            index_p=k_v;
        else
            index_l=bandindex{k_v};
            index_r=bandindex{k_v+1};
            index_p=k_v;
        end
        index_tot=union(index_l,index_r);
        index_tot=union(index_tot,index_p);
        % index_tot holds all the indices of tensor k_v
        index_hi=find(abs(W(:,k_h))>eps);
        % index_hi holds the visible indices linked to h_i.
        if isempty(setdiff(index_hi,index_tot))
            % index_hi is contained in index_tot
            flag=1;
            Contain_hi{k_v}(end+1)=k_h;
            break;
        end
    end
    if flag==0
        disp('not match, hi!');
        % show an error message: there is an h_i that no tensor can
        % accommodate.
    end
end
% for k=1:Nv
%disp(Contain_hi{k});
% end

Lindex=cell(1,Nv);
Lindex(2:Nv)=bandindex(2:Nv);
Rindex=cell(1,Nv);
Rindex(1:Nv-1)=bandindex(2:Nv);
% for output
end

function mps=index2tensor(a,b,W,Lindex,Rindex,Containhi)
%{
Sub function index2tensor
Get the mps (result) from the given parameters.
Input a,b,W, the same as in the parent function
Input Lindex,Rindex,Containhi, see the output of function rbm2data
Output mps, the mps returned by the parent function
%}
Nv=length(a);
Nh=length(b);
%{
check!
The following must hold for the function to work correctly:
[Nv,Nh]==size(W);
Nv==length(Lindex)==length(Rindex)==length(Containhi);
%}

mps=cell(1,Nv);
for k=1:Nv
    % calculate each tensor one by one
    mps{k}=zeros(2^length(Lindex{k}),2^length(Rindex{k}),2);
    TOTindex=[Lindex{k};Rindex{k};k];
    % TOTindex, the indices of the tensor
    [uni_index,ia,ic]=unique(TOTindex(:));

    for k_nv=1:2^length(uni_index)
        % for each element of the tensor, calculate the weight.
        idunique_index=myind2sub(length(uni_index),k_nv);
        % the value of each index as binary digits, e.g. idunique_index=[0 1 1 0 0 1]
        exp_const=exp(a(k)*(idunique_index(uni_index==k)));
        % weight term of the visible part.

        id_index=idunique_index(ic);
        allk_nv=sum((id_index(:)).*2.^(0:length(TOTindex)-1)')+1;
        % allk_nv: the linear index into the tensor
        cal_v=idunique_index;
        cal_w=W(uni_index,Containhi{k});
        cal_b=b(Containhi{k});

        if isempty(Containhi{k})
            mps{k}(allk_nv)=exp_const;
        else
            expindex=cal_v'*cal_w+cal_b;
            eexp=exp(expindex);
            eexp1=1+eexp;
            mps{k}(allk_nv)=prod(eexp1)*exp_const;
            % for more details, see Eq. 2 in the paper
        end
    end
end
end

function sub = myind2sub(n,ndx)
% convert a decimal number to its n binary digits
bin_num=dec2bin(ndx-1,n);
sub=str2num(bin_num(:));
sub=sub(:);
end


--------------------------------------------------------------------------------
/rbm2mps3.m:
--------------------------------------------------------------------------------
function [ mps ] = rbm2mps3( W, b, a )
%{
FUNCTION RBM2MPS3   v1.0   18/JAN/2017

Article
Jing Chen, Song Cheng, Haidong Xie, Lei Wang, and Tao Xiang
On the Equivalence of Restricted Boltzmann Machines and Tensor Network States
[arXiv.17799849 ](https://arxiv.org/submit/1779849)

A Matlab code for rbm2mps on GitHub
https://github.com/yzcj105/rbm2mps


Authors
Haidong Xie
Jing Chen E-mail: yzcj105@126.com
Song Cheng
The Institute of Physics, Chinese Academy of Sciences

A Matlab code for rbm2mps.
Fig.4 in Sec.II of the article

Parameters:( See Eq.1 )
a     (Input) Double precision array of length Nv.
      a(k) is the bias of v_k
b     (Input) Double precision array of length Nh.
      b(k) is the bias of h_k
W     (Input) Double precision Nv * Nh matrix.
      W(i,j) is the weight of the connection between v_i and h_j
mps   (output) Cell of length Nv. Each element mps{k} is a tensor
      of mps rep... with three bonds: left, right, and phy.
Nv    The number of visible units.
Nh    The number of hidden units.


Remarks:
rbm2mps3 is the most general and robust program.
In rbm2mps (method 1), we give a plain method that is simple
and easy to understand, but its computational cost and bond dimension
are high.
In method 2 (rbm2mps2), we give a more economical method with a lower
bond dimension.
This method (method 3) is the most general and gives the lowest
bond dimension: the same as that obtained after canonicalization.



%}

% Sub function rbm2data
% Get the indices and positions of the hidden units from the given RBM.
[Lindex,Rindex,Containhi]=rbm2data(W);

% Sub function index2tensor
% Get the mps (result) from the given parameters.
mps=index2tensor(a,b,W,Lindex,Rindex,Containhi);

end

function [ Lindex,Rindex,Contain_hi ]=rbm2data(W)
%{
Sub function rbm2data
Get the indices and positions of the hidden units from the given RBM.
Input:  W: an Nv * Nh matrix
        W(i,j) is the weight of the connection between v_i and h_j
Output: Lindex, Cell of length Nv
        Lindex{k} holds the left indices of tensor k.
        Rindex, Cell of length Nv
        Rindex{k} holds the right indices of tensor k.
        Contain_hi, Cell of length Nv
        Contain_hi{k} lists the hidden units contained in tensor k.
        Example: Contain_hi{k}(k_h) = h_index.
        If h_index<=Nh, this hidden unit is located in the
        center of tensor k.
        If h_index>Nh, this hidden unit is located on a bond of
        tensor k.
%}
[Nv,Nh]=size(W);

bonddim=zeros(1,Nv);
bondindex=cell(1,Nv);
for k=2:Nv
    % Calculate each bond of the mps
    % k labels the bond between tensor k-1 and tensor k
    index_l=1:k-1;
    % index_l: the visible units to the left of this bond, which could
    % connect to the part to the right of this bond.
    index_r=k:Nv;
    % the counterpart of index_l on the right.
    flag_find=0;

    for k_bonddim=0:min(k-1,Nv+1-k)
        % k_bonddim sets the bond dimension of this bond; we want
        % it as small as possible. Bond dimension D=2^k_bonddim
        % Here, k_bonddim==0 means the left and right tensors are
        % unrelated. In this case, the left and right can be written as two
        % separate, independent parts.

        all_choose=nchoosek(1:(Nv+Nh),k_bonddim);
        % all possible choices for the given k_bonddim
        % here, we can choose both visible and hidden indices.
        % If the chosen index <=Nv, we choose a visible index
        % If the chosen index >Nv, we choose the hidden index =
        % chosen index - Nv.

        % we take the first solution found and use it in this code; the
        % other solutions may lead to degenerate solutions, which we do not
        % discuss.
        for k_=1:size(all_choose,1)
            k_choose=all_choose(k_,:);
            k_choose_v=k_choose(k_choose<=Nv);
            k_choose_h=k_choose(k_choose>Nv)-Nv;
            % k_choose, the chosen indices
            % k_choose_v, the visible indices in k_choose
            % k_choose_h, the hidden indices in k_choose
            index_lnew=setdiff(index_l,k_choose_v);
            index_rnew=setdiff(index_r,k_choose_v);
            index_hnew=setdiff(1:Nh,k_choose_h);
            % index_*new are the indices with the chosen indices removed.
            W_mat_1=W(index_lnew,index_hnew);
            W_mat_2=W(index_rnew,index_hnew);
            Con_Mat=W_mat_1*W_mat_2';

            if all(~Con_Mat(:))
                % if all elements of Con_Mat are zero, the left part (index_l) and
                % the right part (index_r) become a direct product once the
                % chosen indices are fixed.
                % For details see Eq.8 and Fig.4 in the paper
                flag_find=1;
                bonddim(k)=k_bonddim;
                % bond dimension of the bond
                bondindex{k}=k_choose(:);
                if k_bonddim==0
                    bondindex{k}=[];
                end
                % the chosen indices
                % disp(bonddim(k));
                % disp(bondindex{k});
                break;
            end
        end
        if flag_find==1
            break;
            % quit the loop, solution found
        end
    end
end
%disp
%disp(bonddim);
% for k=2:Nv
%disp(bondindex{k});
% end

% locate hi
Contain_hi=cell(1,Nv);
for k_h=1:Nh
    % find the tensor in which to place each hidden unit
    flag=0;
    for k_v=1:Nv
        % get the left, right, and physical indices.
        if k_v==1
            index_l=[];
            index_r=bondindex{k_v+1};
            index_p=k_v;
        elseif k_v==Nv
            index_l=bondindex{k_v};
            index_r=[];
            index_p=k_v;
        else
            index_l=bondindex{k_v};
            index_r=bondindex{k_v+1};
            index_p=k_v;
        end
        index_tot=union(index_l,index_r);
        index_tot=union(index_tot,index_p);
        % index_tot holds all the indices of tensor k_v: left,
        % right, and physical.

        % situation 1: h_i sits on a bond of the tensor
        if ismember(k_h+Nv,index_tot)
            % h_i itself is one of the bond indices in index_tot
            flag=1;
            Contain_hi{k_v}(end+1)=k_h+Nh;
            % the +Nh offset marks k_h as a bond index
            % disp('use method 3')
            % disp(Contain_hi{k_v})
            break;
        end
        % situation 2: h_i sits in the center of the tensor
        index_hi=find(abs(W(:,k_h))>eps);
        % index_hi holds the visible indices linked to h_i.
        if isempty(setdiff(index_hi,index_tot))
            % index_hi is contained in index_tot
            flag=1;
            Contain_hi{k_v}(end+1)=k_h;
            break;
        end
    end
    if flag==0
        disp('not match, hi!');
        % show an error message: there is an h_i that no tensor can
        % accommodate.
    end
end
% for k=1:Nv
% disp(Contain_hi{k});
% end

Lindex=cell(1,Nv);
Lindex(2:Nv)=bondindex(2:Nv);
Rindex=cell(1,Nv);
Rindex(1:Nv-1)=bondindex(2:Nv);
% for output
end

function mps=index2tensor(a,b,W,Lindex,Rindex,Containhi)
%{
Subfunction index2tensor
Builds the MPS (the result) from the given parameters.
Input a,b,W: the same as in the parent function
Input Lindex,Rindex,Containhi: see the output of function rbm2data
Output mps: the MPS returned by the parent function
%}
Nv=length(a);
Nh=length(b);
%{
Check!
The following conditions must hold for this function to work:
isequal(size(W),[Nv,Nh])
length(Lindex), length(Rindex) and length(Containhi) all equal Nv
%}

mps=cell(1,Nv);
cal_need_wjk=abs(W)>eps;
% cal_need_wjk(i,j)~=0 means the connection between visible v_i and hidden
% h_j exists and still needs to be calculated.
for k=1:Nv
    % calculate the tensors one by one
    mps{k}=zeros(2^length(Lindex{k}),2^length(Rindex{k}),2);
    TOTindex=[Lindex{k};Rindex{k};k];
    % TOTindex: the indices of tensor k
    [uni_index,~,ic]=unique(TOTindex(:));
    cal_need_v=uni_index(uni_index<=Nv);
    % cal_need_v: the visible indices among the tensor indices
    cal_need_h=uni_index(uni_index>Nv)-Nv;
    % cal_need_h: the hidden indices among the tensor indices
    cal_need_k=zeros(Nv,Nh);
    cal_need_k(cal_need_v(:)',cal_need_h(:)')=cal_need_wjk(cal_need_v(:)',cal_need_h(:)');
    % cal_need_k(i,j)~=0 means the connection between visible v_i and hidden
    % h_j has not been calculated yet and is calculated in this tensor k
    cal_need_wjk(cal_need_v(:)',cal_need_h(:)')=0;
    % mark the connections handled here as calculated (set them to 0)

    % disp(cal_need_k)
    for k_nv=1:2^length(uni_index)
        % for each element of the tensor, calculate the weight.
        idunique_index=myind2sub(length(uni_index),k_nv);
        % the value of each index as binary digits,
        % for example idunique_index=[0 1 1 0 0 1]

        vk=idunique_index(uni_index==k);
        exptop=a(k)*vk;
        % the bias term of the visible unit

        calout_hi=Containhi{k}(Containhi{k}>Nh)-Nh;
        % calout_hi: the hidden indices on a bond whose bias term is
        % calculated in this tensor
        if ~isempty(calout_hi)
            for k_hi=1:length(calout_hi)
                hk(k_hi)=idunique_index(uni_index==(calout_hi(k_hi)+Nv));
                exptop=exptop+b(calout_hi(k_hi))*hk(k_hi);
            end
        end

        for k_v=cal_need_v(:)'
            for k_h=cal_need_h(:)'
                if cal_need_k(k_v,k_h)
                    v_need=idunique_index(uni_index==k_v);
                    h_need=idunique_index(uni_index==((k_h)+Nv));
                    exptop=exptop+v_need*W(k_v,k_h)*h_need;
                    % add each connection in cal_need_k to the weight term.
                end
            end
        end
        exp_const=exp(exptop);
        % weight term of the index part.

        id_index=idunique_index(ic);
        allk_nv=sum((id_index(:)).*2.^(0:length(TOTindex)-1)')+1;
        % allk_nv: the linear index into the tensor

        calin_hi=Containhi{k}(Containhi{k}<=Nh);
        % calculate the weight for the h_i located at the center of tensor k
        if isempty(calin_hi)
            mps{k}(allk_nv)=exp_const;
        else
            cal_v=idunique_index(uni_index<=Nv);
            cal_w=W(uni_index(uni_index<=Nv),calin_hi);
            cal_b=reshape(b(calin_hi),1,[]);
            % v,w,b as in Eq. 2 of the paper; cal_b is forced to a row
            % vector so the sum below also works when b is a column.
            expindex=cal_v'*cal_w+cal_b;
            eexp=exp(expindex);
            eexp1=1+eexp;
            mps{k}(allk_nv)=prod(eexp1)*exp_const;
            % for more detail, see Eq. 2 in the paper
        end
    end
end
end

function sub = myind2sub(n,ndx)
% convert the decimal number ndx-1 to a column of n binary digits
bin_num=dec2bin(ndx-1,n);
sub=bin_num(:)-'0';
% digit characters to numbers
sub=sub(:);
end
--------------------------------------------------------------------------------