├── 1d_heisenberg.py
├── 2d_heisenberg.py
└── readme.txt

/1d_heisenberg.py:
--------------------------------------------------------------------------------
# Code from the paper https://arxiv.org/pdf/2003.06228.pdf

import numpy as np
import torch
import gc
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils as U
import torch.multiprocessing as mp
import math as m
import random
import os
import sys
import time
import psutil
from itertools import product, permutations, combinations

num_sites = int(sys.argv[1])        # number of lattice sites
batch_size = int(sys.argv[2])       # batch size
learning_rate = float(sys.argv[3])  # learning rate
J2 = float(sys.argv[4])             # J2/J1

num_batches = 1000   # total number of minibatches for training
max_norm = 1.0       # gradient clipping
hidden_nodes = 128   # number of hidden units

if torch.cuda.is_available():  # run on GPU if available
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

class GRU(nn.Module):
    # uses a GRU to process information over an array [-1,in_channels,num_sites]
    # returns an array [-1,out_channels,num_sites]

    def __init__(self,in_channels,out_channels,num_sites):
        super(GRU,self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_sites = num_sites

        self.input_input = nn.Linear(self.in_channels,self.out_channels)
        self.input_hidden = nn.Linear(self.out_channels,self.out_channels)
        self.update_input = nn.Linear(self.in_channels,self.out_channels)
        self.update_hidden = nn.Linear(self.out_channels,self.out_channels)

        self.layer_norm = nn.LayerNorm([self.out_channels])  # layer norm over the hidden state (only along the hidden-node dimension)

    def first_GRU_step(self,input):
        # takes the input drive and outputs the first hidden state

        input_gate = torch.tanh(self.input_input(input))
        update_gate = torch.sigmoid(self.update_input(input))

        hidden = input_gate*update_gate

        hidden = self.layer_norm(hidden)

        return hidden

    def GRU_step(self,hidden,input):
        # takes the previous hidden state and the input drive and outputs the current hidden state

        input_gate = torch.tanh(self.input_input(input) + self.input_hidden(hidden))
        update_gate = torch.sigmoid(self.update_input(input) + self.update_hidden(hidden))

        hidden = hidden*(torch.ones_like(update_gate) - update_gate) + input_gate*update_gate

        hidden = self.layer_norm(hidden)

        return hidden

    def forward(self,x):
        for site in np.arange(self.num_sites):
            if site == 0:
                hidden = self.first_GRU_step(x[:,:,site])
                full_hidden = hidden.clone().unsqueeze(-1)
            else:
                hidden = self.GRU_step(hidden,x[:,:,site])
                full_hidden = torch.cat((full_hidden,hidden.unsqueeze(-1)),2)

        return full_hidden

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()

        self.gru = GRU(2,hidden_nodes,num_sites)

        # two-layer readout for the probability and phase of the conditional wavefunction

        self.probs_hid1 = nn.Conv1d(hidden_nodes,hidden_nodes,1)
        self.ang_hid1 = nn.Conv1d(hidden_nodes,hidden_nodes,1)
        self.probs_hid2 = nn.Conv1d(hidden_nodes,hidden_nodes,1)
        self.ang_hid2 = nn.Conv1d(hidden_nodes,hidden_nodes,1)

        self.probs = nn.Conv1d(hidden_nodes,2,1)
        self.sin = nn.Conv1d(hidden_nodes,2,1)
        self.cos = nn.Conv1d(hidden_nodes,2,1)

        # after rolling the input backwards, make sure the first electron doesn't see the last one

        self.mask = torch.ones([num_sites]).to(device)
        self.mask[0] = 0
        self.mask = self.mask.unsqueeze(0).unsqueeze(0)

    def forward(self,inp):

        # input is [batch_size, 2, num_sites]

        # symmetrize the model by reversing the direction of the input
        inp = torch.cat((inp,inp.flip([2])),0)

        hidden = self.gru(inp)

        hidden = hidden.roll([1],[2])*self.mask

        probs_hidden = F.relu(self.probs_hid1(hidden))
        probs_hidden = F.relu(self.probs_hid2(probs_hidden))
        ang_hidden = F.relu(self.ang_hid1(hidden))
        ang_hidden = F.relu(self.ang_hid2(ang_hidden))

        probs = F.softmax(self.probs(probs_hidden),1)
        sin = self.sin(ang_hidden)
        cos = self.cos(ang_hidden)

        phase = torch.atan2(sin,cos)

        prob_wf = torch.sum(torch.log(torch.sum(probs*inp,1)),1)
        phase_wf = torch.sum(torch.sum(phase*inp,1),1)

        phase_symwf = torch.reshape(phase_wf,[2,-1])
        log_sq_symwf = torch.reshape(prob_wf,[2,-1])

        # combine the two directions: a numerically stable log-mean of |psi|^2 for the amplitude,
        # and a |psi|^2-weighted circular mean for the phase
        log_wf = torch.squeeze(0.5*torch.log(torch.mean(torch.exp(log_sq_symwf-torch.max(log_sq_symwf,0,True)[0]),0)) + 0.5*torch.max(log_sq_symwf,0,True)[0],0)
        phase_wf = torch.atan2(torch.sum(torch.exp(log_sq_symwf-torch.max(log_sq_symwf,0,True)[0])*torch.sin(phase_symwf),0),torch.sum(torch.exp(log_sq_symwf-torch.max(log_sq_symwf,0,True)[0])*torch.cos(phase_symwf),0))

        return log_wf, phase_wf

    def sample(self):

        inp = torch.ones([batch_size,2,num_sites]).to(device)
        hidden = torch.zeros([batch_size,hidden_nodes,num_sites]).to(device)

        for i in np.arange(num_sites):

            if i > 0:
                if i == 1:
                    hidden[:,:,i] = self.gru.first_GRU_step(inp[:,:,i-1])
                else:
                    hidden[:,:,i] = self.gru.GRU_step(hidden[:,:,i-1],inp[:,:,i-1])

            probs_hidden = F.relu(self.probs_hid1(hidden))
            probs_hidden = F.relu(self.probs_hid2(probs_hidden))

            probs = F.softmax(self.probs(probs_hidden),1)

            thresh = torch.rand(len(inp)).to(device)
            is_one = (1. + torch.sign(probs[:,1,i] - thresh))/2.
            inp[:,:,i] = torch.cat((torch.unsqueeze(1.-is_one.clone(),1),torch.unsqueeze(is_one.clone(),1)),1)

        return inp

def local_energy(state,NN,device):

    # returns the diagonal contribution to the local energy, AFM = \sum_i s^z_i s^z_{i+1}, M = \sum_i s^z_i,
    # as well as the neighbor and next-neighbor states that have matrix elements 0.5 and 0.5*J2, respectively, with state

    # state is [2,num_sites]

    neighbors = []
    next_neighbors = []
    diag_energy = 0

    neighbors.append(state)

    for i in np.arange(num_sites):
        if i+1 < num_sites:
            if state[1,i] == 1 and state[0,i+1] == 1:
                final_state = state.copy()
                final_state[1,i] = 0
                final_state[0,i+1] = 0
                final_state[0,i] = 1
                final_state[1,i+1] = 1
                neighbors.append(final_state)
            if state[0,i] == 1 and state[1,i+1] == 1:
                final_state = state.copy()
                final_state[0,i] = 0
                final_state[1,i+1] = 0
                final_state[1,i] = 1
                final_state[0,i+1] = 1
                neighbors.append(final_state)
        if i+2 < num_sites:
            if state[1,i] == 1 and state[0,i+2] == 1:
                final_state = state.copy()
                final_state[1,i] = 0
                final_state[0,i+2] = 0
                final_state[0,i] = 1
                final_state[1,i+2] = 1
                next_neighbors.append(final_state)
            if state[0,i] == 1 and state[1,i+2] == 1:
                final_state = state.copy()
                final_state[0,i] = 0
                final_state[1,i+2] = 0
                final_state[1,i] = 1
                final_state[0,i+2] = 1
                next_neighbors.append(final_state)

    state = np.sum(state*np.expand_dims(np.asarray([-1,1]),1),0)
    right_shifted_state = np.roll(state,1)
    right_shifted_state[0] = 0
    j1_mag = np.sum(right_shifted_state*state)
    two_right_shifted_state = np.roll(state,2)
    two_right_shifted_state[:2] = 0
    j2_mag = np.sum(two_right_shifted_state*state)
    diag_energy = 0.25/num_sites*(j1_mag + J2*j2_mag)  # this is the energy contribution from H_ss

    afm = j1_mag/num_sites
    mag = np.sum(state)/num_sites

    return diag_energy, neighbors, next_neighbors, afm, mag

def train_network():

    file_ext = str(num_sites) + '_' + str(batch_size) + '_' + str(learning_rate) + '_' + str(J2) + '.txt'
    loss_file = 'loss_heisenberg_' + file_ext
    info_file = 'info_heisenberg_' + file_ext
    f = open(loss_file,'w')
    f2 = open(info_file,'w')

    NN = Net().to(device)

    # Uncomment this line to load a saved model, as you would for iterative retraining.
    # NN.load_state_dict(torch.load("model_80_100_0.001_0.0.txt"))

    optimizer = torch.optim.Adam(NN.parameters(), lr=learning_rate)

    # set the decay of the learning rate
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 1/np.sqrt(0.001*epoch + 1))

    avg_energy = 0
    avg_magmo = 0
    avg_afm = 0

    for batch in np.arange(num_batches):

        with torch.no_grad():
            batch_states = NN.sample()

        batch_states = batch_states.cpu().detach().numpy()

        neighbors = []
        next_neighbors = []
        neighbor_len = []
        next_neighbor_len = []
        mag = []
        afm = []
        diag_energy = []

        for state in batch_states:
            de,nb,nnb,af,mg = local_energy(state,NN,device)
            neighbors.extend(nb)
            next_neighbors.extend(nnb)
            neighbor_len.append(len(nb))
            next_neighbor_len.append(len(nnb))
            mag.append(mg)
            afm.append(af)
            diag_energy.append(de)

        # this parameter chunks up the data so everything fits into memory; set it as large as you can get away with!
        partition_size = 8000
        num_partitions = (len(neighbors) - 1)// partition_size + 1
        num_n_partitions = (len(next_neighbors) - 1)// partition_size + 1

        log_wfs = np.zeros([0])
        phase_wfs = np.zeros([0])

        for part in np.arange(num_partitions):
            with torch.no_grad():
                lwf, pwf = NN.forward(torch.FloatTensor(np.asarray(neighbors)[int(partition_size*part):int(partition_size*(part+1))]).to(device))
            lwf = lwf.cpu().detach().numpy()
            pwf = pwf.cpu().detach().numpy()
            log_wfs = np.concatenate((log_wfs,lwf),0)
            phase_wfs = np.concatenate((phase_wfs,pwf),0)

        if J2 > 0:
            nlog_wfs = np.zeros([0])
            nphase_wfs = np.zeros([0])
            for part in np.arange(num_n_partitions):
                with torch.no_grad():
                    nlwf, npwf = NN.forward(torch.FloatTensor(np.asarray(next_neighbors)[int(partition_size*part):int(partition_size*(part+1))]).to(device))
                nlwf = nlwf.cpu().detach().numpy()
                npwf = npwf.cpu().detach().numpy()
                nlog_wfs = np.concatenate((nlog_wfs,nlwf),0)
                nphase_wfs = np.concatenate((nphase_wfs,npwf),0)

        position = 0
        energies = np.zeros([batch_size],dtype='complex')
        log_wf = np.zeros([batch_size])
        phase_wf = np.zeros([batch_size])

        if J2 > 0:
            nposition = 0
            for i in np.arange(batch_size):
                log_wf[i] = log_wfs[position]
                phase_wf[i] = phase_wfs[position]
                energies[i] = diag_energy[i] + 0.5*np.sum(np.exp(log_wfs[int(position + 1):int(position + neighbor_len[i])] + 1j*phase_wfs[int(position + 1):int(position + neighbor_len[i])] - log_wf[i] - 1j*phase_wf[i]))/num_sites + 0.5*J2*np.sum(np.exp(nlog_wfs[int(nposition):int(nposition + next_neighbor_len[i])] + 1j*nphase_wfs[int(nposition):int(nposition + next_neighbor_len[i])] - log_wf[i] - 1j*phase_wf[i]))/num_sites
                position = position + neighbor_len[i]
                nposition = nposition + next_neighbor_len[i]
        else:
            for i in np.arange(batch_size):
                log_wf[i] = log_wfs[position]
                phase_wf[i] = phase_wfs[position]
                energies[i] = diag_energy[i] + 0.5*np.sum(np.exp(log_wfs[int(position + 1):int(position + neighbor_len[i])] + 1j*phase_wfs[int(position + 1):int(position + neighbor_len[i])] - log_wf[i] - 1j*phase_wf[i]))/num_sites
                position = position + neighbor_len[i]

        energies = np.nan_to_num(np.asarray(energies))
        log_wf = np.nan_to_num(np.asarray(log_wf))
        phase_wf = np.nan_to_num(np.asarray(phase_wf))
        afm = np.nan_to_num(np.asarray(afm))
        mag = np.nan_to_num(np.asarray(mag))
        mean_energies = np.mean(energies)
        mean_entropies = np.mean(2*log_wf)/num_sites
        mean_afm = np.mean(afm)
        mean_magmo = np.mean(np.abs(mag))

        residuals = np.conj(energies - mean_energies)

        real_residuals = np.real(residuals)
        imag_residuals = np.imag(residuals)
        entropy_residuals = (2*log_wf/(num_sites) + np.log(2))
        mag_residuals = np.square(mag)

        optimizer.zero_grad()
        log_wf,phase_wf = NN.forward(torch.FloatTensor(batch_states).to(device))
        # Set T = 0 if you just want to learn a small model fast
        # T = 0
        T = 1./(1 + 0.001*batch)
        C = 100.

        loss = torch.sum(log_wf*torch.FloatTensor(real_residuals + T*entropy_residuals + C*mag_residuals).to(device)-phase_wf*torch.FloatTensor(imag_residuals).to(device))
        loss.backward()
        for param in NN.parameters():
            param.grad[torch.isnan(param.grad)] = 0
        norm = torch.nn.utils.clip_grad_norm_(NN.parameters(),max_norm)
        optimizer.step()
        scheduler.step()

        # info file keeps track of gradient norm, M, AFM, pseudo-entropy, total cost function, energy
        f2.write(str(norm))
        f2.write('\t')
        f2.write(str(mean_magmo))
        f2.write('\t')
        f2.write(str(mean_afm))
        f2.write('\t')
        f2.write(str(mean_entropies))
        f2.write('\t')
        f2.write(str(mean_energies + T*mean_entropies + C*np.mean(np.square(mag))))
        f2.write('\t')
        f2.write(str(mean_energies))
        f2.write('\n')
        f2.flush()

        avg_energy = avg_energy + mean_energies
        avg_afm = avg_afm + mean_afm
        avg_magmo = avg_magmo + mean_magmo

        if not (batch + 1) % 100:
            f.write(str(avg_energy/100.))
            f.write('\t')
            f.write(str(avg_magmo/100.))
            f.write('\t')
            f.write(str(avg_afm/100.))
            f.write('\n')
            f.flush()
            avg_energy = 0
            avg_magmo = 0
            avg_afm = 0

            # save the model every 100 batches
            torch.save(NN.state_dict(), 'model_load_' + file_ext)

train_network()

--------------------------------------------------------------------------------
/2d_heisenberg.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import gc
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils as U
import torch.multiprocessing as mp
import math as m
import random
import os
import sys
import time
import psutil
from itertools import product, permutations, combinations

num_sites = int(sys.argv[1])   # length of lattice (number of electrons is num_sites*num_sites)
batch_size = int(sys.argv[2])  # batch size
learning_rate = float(sys.argv[3])
J2 = float(sys.argv[4])        # J2/J1

num_batches = 10000
max_norm = 1.0      # gradient clipping
hidden_nodes = 128
num_layers = 5      # number of layers

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

class LSTM(nn.Module):
    # process information over array [-1,in_channels,num_sites,num_sites] from left to right (over last index)
    # returns array [-1,out_channels,num_sites,num_sites]

    def __init__(self,in_channels,out_channels,num_sites):
        super(LSTM,self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_sites = num_sites

        self.input_input = nn.Conv1d(self.in_channels,self.out_channels,1)
        self.input_hidden = nn.Conv1d(self.out_channels,self.out_channels,1)
        self.forget_input = nn.Conv1d(self.in_channels,self.out_channels,1)
        self.forget_hidden = nn.Conv1d(self.out_channels,self.out_channels,1)
        self.cell_input = nn.Conv1d(self.in_channels,self.out_channels,1)
        self.cell_hidden = nn.Conv1d(self.out_channels,self.out_channels,1)
        self.output_input = nn.Conv1d(self.in_channels,self.out_channels,1)
        self.output_hidden = nn.Conv1d(self.out_channels,self.out_channels,1)

        self.layer_norm = nn.LayerNorm([self.out_channels])
        self.layer_norm_cell = nn.LayerNorm([self.out_channels])

    def first_LSTM_step(self,input):
        # takes input drive and outputs hidden state

        input_gate = torch.sigmoid(self.input_input(input))
        cell_gate = torch.tanh(self.cell_input(input))
        output_gate = torch.sigmoid(self.output_input(input))

        cell = input_gate*cell_gate
        hidden = torch.tanh(cell)*output_gate

        hidden = self.layer_norm(hidden.permute(0,2,1)).permute(0,2,1)
        cell = self.layer_norm_cell(cell.permute(0,2,1)).permute(0,2,1)

        return hidden, cell

    def LSTM_step(self,hidden,cell,input):
        # takes previous hidden state and input drive and outputs current hidden state

        input_gate = torch.sigmoid(self.input_input(input) + self.input_hidden(hidden))
        forget_gate = torch.sigmoid(self.forget_input(input) + self.forget_hidden(hidden))
        cell_gate = torch.tanh(self.cell_input(input) + self.cell_hidden(hidden))
        output_gate = torch.sigmoid(self.output_input(input) + self.output_hidden(hidden))

        cell = cell*forget_gate + input_gate*cell_gate
        hidden = torch.tanh(cell)*output_gate

        hidden = self.layer_norm(hidden.permute(0,2,1)).permute(0,2,1)
        cell = self.layer_norm_cell(cell.permute(0,2,1)).permute(0,2,1)

        return hidden,cell

    def forward(self,x):
        for site in np.arange(self.num_sites):
            if site == 0:
                hidden, cell = self.first_LSTM_step(x[:,:,:,site])
                full_hidden = hidden.clone().unsqueeze(-1)
            else:
                hidden, cell = self.LSTM_step(hidden,cell,x[:,:,:,site])
                full_hidden = torch.cat((full_hidden,hidden.unsqueeze(-1)),3)

        return full_hidden

class Layer(nn.Module):
    def __init__(self,layer_num):
        super(Layer,self).__init__()

        self.layer_num = layer_num

        if self.layer_num == 0:
            self.top_down = LSTM(2,hidden_nodes,num_sites)
        else:
            self.top_down = LSTM(hidden_nodes,hidden_nodes,num_sites)

        self.left_right = LSTM(hidden_nodes,hidden_nodes,num_sites)
        self.right_left = LSTM(hidden_nodes,hidden_nodes,num_sites)

        self.W1 = nn.Conv2d(2*hidden_nodes,4*hidden_nodes,1)
        self.W2 = nn.Conv2d(4*hidden_nodes,hidden_nodes,1)

        self.top_mask = torch.ones([num_sites]).to(device)
        self.top_mask[0] = 0
        self.top_mask = self.top_mask.unsqueeze(-1).unsqueeze(0).unsqueeze(0)

        self.left_mask = torch.ones([num_sites]).to(device)
        self.left_mask[0] = 0
        self.left_mask = self.left_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0)

        self.layer_norm = nn.LayerNorm([hidden_nodes])

    def forward(self,x):

        hidden = self.top_down(x.permute(0,1,3,2)).permute(0,1,3,2)
        hidden_lr = self.left_right(hidden)
        hidden_rl = self.right_left(hidden.flip([3])).flip([3])

        hidden_rl = hidden_rl.roll([1],[2])*self.top_mask
        hidden_lr = hidden_lr.roll([1],[3])*self.left_mask
        hidden = torch.cat((hidden_lr,hidden_rl),1)

        hidden = self.W2(F.relu(self.W1(hidden)))

        if self.layer_num == 0:
            return self.layer_norm((hidden).permute(0,2,3,1)).permute(0,3,1,2)
        else:
            return self.layer_norm((hidden + x).permute(0,2,3,1)).permute(0,3,1,2)

class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()

        # lstms for rest of layers (4 per layer)
        self.layers = nn.ModuleList([Layer(layer) for layer in np.arange(num_layers)])

        self.probs_hid = nn.Conv2d(hidden_nodes,hidden_nodes,1)
        self.ang_hid = nn.Conv2d(hidden_nodes,hidden_nodes,1)
        self.probs_hid2 = nn.Conv2d(hidden_nodes,hidden_nodes,1)
        self.ang_hid2 = nn.Conv2d(hidden_nodes,hidden_nodes,1)

        self.probs = nn.Conv2d(hidden_nodes,2,1)
        self.sin = nn.Conv2d(hidden_nodes,2,1)
        self.cos = nn.Conv2d(hidden_nodes,2,1)

    def forward(self,inp):

        # input is [batch_size, 2, num_sites, num_sites]

        for layer in np.arange(num_layers):

            if layer == 0:
                hidden = self.layers[layer](inp)
            else:
                hidden = self.layers[layer](hidden)

        probs_hidden = F.relu(self.probs_hid(hidden))
        ang_hidden = F.relu(self.ang_hid(hidden))
        probs_hidden = F.relu(self.probs_hid2(probs_hidden))
        ang_hidden = F.relu(self.ang_hid2(ang_hidden))

        probs = F.softmax(self.probs(probs_hidden),1)
        sin = self.sin(ang_hidden)
        cos = self.cos(ang_hidden)

        phase = torch.atan2(sin,cos)

        log_wf = 0.5*torch.sum(torch.sum(torch.log(torch.sum(probs*inp,1)),1),1)
        phase_wf = torch.sum(torch.sum(torch.sum(phase*inp,1),1),1)

        return log_wf, phase_wf

    def sample(self):

        inp = torch.ones([batch_size,2,num_sites,num_sites]).to(device)

        for i in np.arange(num_sites):
            for j in np.arange(num_sites):
                for layer in np.arange(num_layers):

                    if layer == 0:
                        hidden = self.layers[layer](inp)
                    else:
                        hidden = self.layers[layer](hidden)

                probs_hidden = F.relu(self.probs_hid(hidden))
                probs_hidden = F.relu(self.probs_hid2(probs_hidden))

                probs = F.softmax(self.probs(probs_hidden),1)

                thresh = torch.rand(len(inp)).to(device)
                is_one = (1 + torch.sign(probs[:,1,i,j] - thresh))/2
                inp[:,:,i,j] = torch.cat((torch.unsqueeze(1-is_one.clone(),1),torch.unsqueeze(is_one.clone(),1)),1)
        return inp

def local_energy(state,NN,device):

    # state is [2,num_sites,num_sites]

    neighbors = []
    next_neighbors = []
    diag_energy = 0

    neighbors.append(state)

    for i in np.arange(num_sites):
        for j in np.arange(num_sites):
            if i+1 < num_sites:
                if state[1,i,j] == 1 and state[0,i+1,j] == 1:
                    final_state = state.copy()
                    final_state[1,i,j] = 0
                    final_state[0,i+1,j] = 0
                    final_state[0,i,j] = 1
                    final_state[1,i+1,j] = 1
                    neighbors.append(final_state)
                if state[0,i,j] == 1 and state[1,i+1,j] == 1:
                    final_state = state.copy()
                    final_state[0,i,j] = 0
                    final_state[1,i+1,j] = 0
                    final_state[1,i,j] = 1
                    final_state[0,i+1,j] = 1
                    neighbors.append(final_state)
                if j + 1 < num_sites:
                    if state[1,i,j] == 1 and state[0,i+1,j+1] == 1:
                        final_state = state.copy()
                        final_state[1,i,j] = 0
                        final_state[0,i+1,j+1] = 0
                        final_state[0,i,j] = 1
                        final_state[1,i+1,j+1] = 1
                        next_neighbors.append(final_state)
                    if state[0,i,j] == 1 and state[1,i+1,j+1] == 1:
                        final_state = state.copy()
                        final_state[0,i,j] = 0
                        final_state[1,i+1,j+1] = 0
                        final_state[1,i,j] = 1
                        final_state[0,i+1,j+1] = 1
                        next_neighbors.append(final_state)
                if j > 0:
                    if state[1,i,j] == 1 and state[0,i+1,j-1] == 1:
                        final_state = state.copy()
                        final_state[1,i,j] = 0
                        final_state[0,i+1,j-1] = 0
                        final_state[0,i,j] = 1
                        final_state[1,i+1,j-1] = 1
                        next_neighbors.append(final_state)
                    if state[0,i,j] == 1 and state[1,i+1,j-1] == 1:
                        final_state = state.copy()
                        final_state[0,i,j] = 0
                        final_state[1,i+1,j-1] = 0
                        final_state[1,i,j] = 1
                        final_state[0,i+1,j-1] = 1
                        next_neighbors.append(final_state)
            # [the original copy is garbled here: a large block is missing, beginning with the
            #  remaining "if j+1 < ..." neighbor terms and running through the diagonal energy /
            #  afm / mag computation and the opening of train_network(); the surviving text
            #  resumes below, inside the training loop]

        if J2 > 0:
            nlog_wfs = np.zeros([0])
            nphase_wfs = np.zeros([0])
            for part in np.arange(num_n_partitions):
                with torch.no_grad():
                    nlwf, npwf = NN.forward(torch.FloatTensor(np.asarray(next_neighbors)[int(partition_size*part):int(partition_size*(part+1))]).to(device))
                nlwf = nlwf.cpu().detach().numpy()
                npwf = npwf.cpu().detach().numpy()
                nlog_wfs = np.concatenate((nlog_wfs,nlwf),0)
                nphase_wfs = np.concatenate((nphase_wfs,npwf),0)

        position = 0
        energies = np.zeros([batch_size],dtype='complex')
        log_wf = np.zeros([batch_size])
        phase_wf = np.zeros([batch_size])

        if J2 > 0:
            nposition = 0
            for i in np.arange(batch_size):
                log_wf[i] = log_wfs[position]
                phase_wf[i] = phase_wfs[position]
                energies[i] = diag_energy[i] + 0.5*np.sum(np.exp(log_wfs[int(position + 1):int(position + neighbor_len[i])] + 1j*phase_wfs[int(position + 1):int(position + neighbor_len[i])] - log_wf[i] - 1j*phase_wf[i]))/(num_sites*num_sites) + 0.5*J2*np.sum(np.exp(nlog_wfs[int(nposition):int(nposition + next_neighbor_len[i])] + 1j*nphase_wfs[int(nposition):int(nposition + next_neighbor_len[i])] - log_wf[i] - 1j*phase_wf[i]))/(num_sites*num_sites)
                position = position + neighbor_len[i]
                nposition = nposition + next_neighbor_len[i]
        else:
            for i in np.arange(batch_size):
                log_wf[i] = log_wfs[position]
                phase_wf[i] = phase_wfs[position]
                energies[i] = diag_energy[i] + 0.5*np.sum(np.exp(log_wfs[int(position + 1):int(position + neighbor_len[i])] + 1j*phase_wfs[int(position + 1):int(position + neighbor_len[i])] - log_wf[i] - 1j*phase_wf[i]))/(num_sites*num_sites)

                position = position + neighbor_len[i]

        energies = np.nan_to_num(np.asarray(energies))
        log_wf = np.nan_to_num(np.asarray(log_wf))
        phase_wf = np.nan_to_num(np.asarray(phase_wf))
        afm = np.nan_to_num(np.asarray(afm))
        mag = np.nan_to_num(np.asarray(mag))
        mean_energies = np.mean(energies)
        mean_entropies = np.mean(2*log_wf)/(num_sites*num_sites)
        mean_afm = np.mean(afm)
        mean_magmo = np.mean(np.abs(mag))

        residuals = np.conj(energies - mean_energies)

        real_residuals = np.real(residuals)
        imag_residuals = np.imag(residuals)
        entropy_residuals = (2*log_wf/(num_sites*num_sites) + np.log(2))
        mag_residuals = np.square(mag)

        optimizer.zero_grad()
        # Set T = 0 to train a small model quickly
        # T = 0
        T = 1./(1 + 0.001*batch)
        C = 10.

        num_partitions = (batch_size - 1)// part_size + 1
        for partition in np.arange(num_partitions):
            log_wf,phase_wf = NN.forward(torch.FloatTensor(batch_states[int(partition*part_size):int((partition+1)*part_size)]).to(device))
            loss = torch.sum(log_wf*torch.FloatTensor((real_residuals + T*entropy_residuals + C*mag_residuals)[int(partition*part_size):int((partition+1)*part_size)]).to(device)-phase_wf*torch.FloatTensor(imag_residuals[int(partition*part_size):int((partition+1)*part_size)]).to(device))
            loss.backward()
            for param in NN.parameters():
                param.grad[torch.isnan(param.grad)] = 0

        norm = torch.nn.utils.clip_grad_norm_(NN.parameters(),max_norm)
        optimizer.step()

        f2.write(str(norm))
        f2.write('\t')
        f2.write(str(mean_magmo))
        f2.write('\t')
        f2.write(str(mean_afm))
        f2.write('\t')
        f2.write(str(mean_entropies))
        f2.write('\t')
        f2.write(str(mean_energies + T*mean_entropies + C*np.mean(np.square(mag))))
        f2.write('\t')
        f2.write(str(mean_energies))
        f2.write('\n')
        f2.flush()

        avg_energy = avg_energy + mean_energies
        avg_afm = avg_afm + mean_afm
        avg_magmo = avg_magmo + mean_magmo

        if not (batch + 1) % 100:
            f.write(str(avg_energy/100))
            f.write('\t')
            f.write(str(avg_magmo/100))
            f.write('\t')
            f.write(str(avg_afm/100))
            f.write('\n')
            f.flush()
            avg_energy = 0
            avg_magmo = 0
            avg_afm = 0

            # save model every 100 minibatches
            torch.save(NN.state_dict(), 'model_heisenberg_' + file_ext)

        # scheduler.step()

train_network()

--------------------------------------------------------------------------------
/readme.txt:
--------------------------------------------------------------------------------
This is software to train the RNN models on the J1-J2 Heisenberg model described in https://arxiv.org/pdf/2003.06228.pdf

To train a model from scratch, type into the command line:

python3 code.py [num_sites] [batch_size] [learning rate] [J2]

where code.py is the script (either 1D or 2D), num_sites is the number of sites (1D) or the side length of the square lattice (2D), batch_size is the number of samples per batch, and J2 represents J2/J1, the next-nearest-neighbor coupling.

Example:
python3 2d_heisenberg.py 4 100 1e-4 0.0
will train a 4x4 model from scratch.

This code will output 3 files labeled "loss", "info", and "model". "loss" records observables averaged over 100 minibatches, whereas "info" records observables for each single minibatch. "model" holds the saved parameters of the model.

Iterative Training:
When you train a model, this code saves the model every 100 minibatches. If you want to generalize to a larger lattice, set [num_sites] to the larger size and load the previously saved parameters with the NN.load_state_dict() command (a minimal example is sketched after this readme).

The rest of the hyperparameters (T, C, etc.) can be changed within the code. To train a small model quickly, choose T = 0 and a learning rate of ~1e-4.
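
Below is a minimal, self-contained sketch of the warm-start step described under "Iterative Training". TinyNet and the file name model_small_demo.txt are illustrative stand-ins rather than names from the scripts; the point is only that a state dict saved while training one lattice size loads into a model built for another, because the parameter shapes of these networks (1x1 convolutions and per-site recurrences) do not depend on num_sites.

import torch
import torch.nn as nn

# Hypothetical stand-in for the Net classes above: its parameter shapes are
# independent of the lattice size, just like the GRU/LSTM models in this repo.
class TinyNet(nn.Module):
    def __init__(self, hidden_nodes=128):
        super(TinyNet, self).__init__()
        self.readout = nn.Conv1d(hidden_nodes, 2, 1)  # 1x1 readout, size-independent

    def forward(self, x):
        return self.readout(x)

small = TinyNet()
torch.save(small.state_dict(), 'model_small_demo.txt')    # analogous to the torch.save call in train_network()

large = TinyNet()                                          # model intended for the larger lattice
large.load_state_dict(torch.load('model_small_demo.txt'))  # the NN.load_state_dict() step from the readme

For the real scripts, the same pattern is the commented-out NN.load_state_dict(torch.load(...)) line in 1d_heisenberg.py: uncomment it, point it at the saved model file from the smaller run, and launch the script with the larger [num_sites].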