├── README.md ├── agent.py ├── ckpts └── model_1320_45.ckpt ├── dataGenerator.py ├── main.py ├── model.py ├── randomStateDataset.py ├── results └── loss.txt ├── system.py ├── test.py └── testDataset.py /README.md: -------------------------------------------------------------------------------- 1 | # Topological quantum compiling with reinforcement learning 2 | An efficient machine learning algorithm to decompose an arbitrary single-qubit gate into a sequence of gates from a finite universal set. Reference: https://doi.org/10.1103/PhysRevLett.125.170501 3 | 4 | In this example code, the universal gate set is chosen to be the braiding operations of the Fibonacci anyon model. 5 | 6 | ## Usage 7 | To train a model from scratch: 8 | ``` 9 | python3 main.py 10 | ``` 11 | 12 | To test a pretrained model on randomly generated matrices: 13 | ``` 14 | python3 test.py 15 | ``` 16 | 17 | Sorry, I haven't provided any convenient tools for customizing the model yet. To decompose a particular quantum gate or to train a new model, clone the source code and edit the corresponding parts. 18 | -------------------------------------------------------------------------------- /agent.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Sep 8 14:38:22 2019 4 | 5 | @author: Yuanhang Zhang 6 | """ 7 | 8 | import numpy as np 9 | import time 10 | import torch 11 | import torch.optim as optim 12 | import torch.nn.functional as F 13 | from tqdm import trange 14 | 15 | class Agent(): 16 | def __init__(self, policy_net, target_net, env, epsilon): 17 | self.policy_net = policy_net 18 | self.target_net = target_net 19 | self.env = env 20 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 21 | self.learning_rate = 1e-3 22 | # self.loss_func = torch.nn.SmoothL1Loss() 23 | self.loss_func = torch.nn.MSELoss() 24 | self.optimizer = optim.Adam(self.policy_net.parameters(), lr=self.learning_rate) 25 | 26 | self.state_size = self.env.state_size 27 | self.n_actions = self.env.n_actions 28 | self.l = 1.0 # parameter lambda used in A* search 29 | self.decimal_punish = 400 30 | self.epsilon = epsilon # stop searching when the distance is less than epsilon 31 | self.action_inv_table = [1, 0, 3, 2] 32 | 33 | def search(self, init, brute_force_length, expand_size, keep_size, maximum_depth=200): 34 | with torch.no_grad(): 35 | states = self.env.einsum('aij, jk->aik', self.env.U, init) 36 | actions = torch.arange(self.n_actions).to(self.device) 37 | action_sequence = actions.view(-1, 1) # (batch, sequence_len) 38 | distances = self.env.batch_distance_2(self.env.target, states) 39 | min_dist, idx = torch.min(distances, 0) 40 | best_state = states[idx] 41 | best_sequence = action_sequence[idx] 42 | 43 | for i in trange(1, brute_force_length): 44 | next_indices = self.env.scramble_table[actions] # (batch, n_actions-1) 45 | states = self.env.einsum('abij, ajk->abik', self.env.U[next_indices], states)\ 46 | .view(-1, 2, 2, 2) 47 | # (batch, n_actions-1, seq_len) 48 | next_action_sequence = action_sequence.expand(self.n_actions-1,\ 49 | action_sequence.shape[0], action_sequence.shape[1]).transpose(0,1) 50 | action_sequence = torch.cat((next_action_sequence, next_indices.unsqueeze(-1)), dim=-1) 51 | action_sequence = action_sequence.view(-1, action_sequence.shape[2]) 52 | actions = next_indices.view(-1) 53 | distances = self.env.batch_distance_2(self.env.target, states) 54 | 55 | val, idx = torch.min(distances, 0) 56 
| if val < min_dist: 57 | min_dist = val 58 | best_state = states[idx] 59 | best_sequence = action_sequence[idx] 60 | 61 | # memory is not enough to evaluate all states, so chunk the states first 62 | states_flattened = states.view(-1, self.state_size) 63 | chunk_size = 100000 64 | if len(states) > chunk_size: 65 | n_chunk = len(states) // chunk_size + 1 66 | states_list = [] 67 | actions_list = [] 68 | action_sequence_list = [] 69 | cost_to_go_list = [] 70 | path_cost_list = [] 71 | cost_decimal_list = [] 72 | cost_list = [] 73 | for i in trange(n_chunk): 74 | cost_to_go = self.target_net(states_flattened[i*chunk_size:(i+1)*chunk_size]).view(-1) # (batch) 75 | path_cost = brute_force_length * torch.ones_like(cost_to_go, device=self.device) 76 | cost_decimal = (cost_to_go - torch.round(cost_to_go)) ** 2 77 | cost = self.l * path_cost + cost_to_go + self.decimal_punish * cost_decimal / cost_to_go 78 | keep_size_i = min(keep_size//(n_chunk-1), len(cost)) 79 | value, index = torch.topk(cost, keep_size_i, dim=0, largest=False, sorted=True) 80 | states_list.append(states[index]) 81 | actions_list.append(actions[index]) 82 | action_sequence_list.append(action_sequence[index]) 83 | cost_to_go_list.append(cost_to_go[index]) 84 | path_cost_list.append(path_cost[index]) 85 | cost_decimal_list.append(cost_decimal[index]) 86 | cost_list.append(cost[index]) 87 | states = torch.cat(states_list) 88 | actions = torch.cat(actions_list) 89 | action_sequence = torch.cat(action_sequence_list) 90 | cost_to_go = torch.cat(cost_to_go_list) 91 | path_cost = torch.cat(path_cost_list) 92 | cost_decimal = torch.cat(cost_decimal_list) 93 | cost = torch.cat(cost_list) 94 | else: 95 | cost_to_go = self.target_net(states_flattened).view(-1) # (batch) 96 | path_cost = brute_force_length * torch.ones_like(cost_to_go, device=self.device) 97 | cost_decimal = (cost_to_go - torch.round(cost_to_go)) ** 2 98 | cost = self.l * path_cost + cost_to_go + self.decimal_punish * cost_decimal / cost_to_go 99 | if len(cost) > keep_size: 100 | cost, index = torch.topk(cost, keep_size, dim=0, largest=False, sorted=True) 101 | else: 102 | cost, index = torch.sort(cost) 103 | states = states[index] 104 | actions = actions[index] 105 | action_sequence = action_sequence[index] 106 | cost_to_go = cost_to_go[index] 107 | cost_decimal = cost_decimal[index] 108 | path_cost = path_cost[index] 109 | for i in trange(maximum_depth): 110 | states_expand = states[:expand_size] 111 | actions_expand = actions[:expand_size] 112 | 113 | next_indices = self.env.scramble_table[actions_expand] 114 | next_states = self.env.einsum('abij, ajk->abik', self.env.U[next_indices], states_expand)\ 115 | .view(-1, 2, 2, 2) 116 | next_actions = next_indices.view(-1) 117 | next_cost_to_go = self.target_net(next_states.view(-1, self.state_size)).view(-1) 118 | next_cost_decimal = (next_cost_to_go - torch.round(next_cost_to_go)) ** 2 119 | next_path_cost = (path_cost[:expand_size]+1).expand(self.n_actions-1, expand_size)\ 120 | .transpose(0, 1).reshape(-1) 121 | next_action_sequence = action_sequence[:expand_size].expand(self.n_actions-1,\ 122 | expand_size, action_sequence.shape[1]).transpose(0,1).reshape((self.n_actions-1)*expand_size, -1) 123 | next_action_sequence = torch.cat((next_action_sequence, next_actions.unsqueeze(-1)), dim=-1) 124 | # in order to keep all action_seqs of same length, use -1 to mean "no action" 125 | action_sequence = torch.cat((action_sequence, \ 126 | -1*torch.ones((len(action_sequence),1), dtype=torch.int64, device=self.device)), dim=-1) 127 | 
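# Rank the freshly expanded children together with the unexpanded candidates using
# the A*-style cost
#     cost = l * path_cost + cost_to_go + decimal_punish * cost_decimal / cost_to_go
# (cost_to_go is predicted by the target network; the last term penalises states whose
# predicted cost-to-go is far from an integer), then keep only the keep_size cheapest
# candidates for the next iteration of the search.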
128 | distances = self.env.batch_distance_2(self.env.target, next_states) 129 | 130 | val, idx = torch.min(distances, 0) 131 | if val < min_dist: 132 | min_dist = val 133 | best_state = next_states[idx] 134 | best_sequence = next_action_sequence[idx] 135 | 136 | states = torch.cat((states[expand_size:], next_states), dim=0) 137 | actions = torch.cat((actions[expand_size:], next_actions), dim=0) 138 | action_sequence = torch.cat((action_sequence[expand_size:], next_action_sequence), dim=0) 139 | cost_to_go = torch.cat((cost_to_go[expand_size:], next_cost_to_go), dim=0) 140 | cost_decimal = torch.cat((cost_decimal[expand_size:], next_cost_decimal), dim=0) 141 | path_cost = torch.cat((path_cost[expand_size:], next_path_cost), dim=0) 142 | cost = torch.cat((cost[expand_size:], self.l*next_path_cost+next_cost_to_go\ 143 | +self.decimal_punish*next_cost_decimal/next_cost_to_go), dim=0) 144 | if len(cost) > keep_size: 145 | cost, index = torch.topk(cost, keep_size, dim=0, largest=False, sorted=True) 146 | else: 147 | cost, index = torch.sort(cost) 148 | states = states[index] 149 | actions = actions[index] 150 | action_sequence = action_sequence[index] 151 | cost_to_go = cost_to_go[index] 152 | cost_decimal = cost_decimal[index] 153 | path_cost = path_cost[index] 154 | # print('Cost-to-go:', cost_to_go[0].detach().cpu().numpy().item(),\ 155 | # 'Total cost:', cost[0].detach().cpu().numpy().item()) 156 | return min_dist.detach(), best_state.detach(), best_sequence.detach() 157 | 158 | def update_model(self, data): 159 | states = data['state'] 160 | next_states = data['next_states'] 161 | mask = data['mask'] 162 | batch_size = len(states) 163 | cost = self.policy_net(states) 164 | with torch.no_grad(): 165 | cost_target = self.target_net(next_states.reshape(batch_size*self.n_actions, self.state_size))\ 166 | .reshape(batch_size, self.n_actions, 1) 167 | cost_target = cost_target * mask 168 | cost_target = torch.min(cost_target, 1)[0] + 1.0 169 | loss = self.loss_func(cost, cost_target) 170 | self.optimizer.zero_grad() 171 | loss.backward() 172 | self.optimizer.step() 173 | return loss.detach() 174 | -------------------------------------------------------------------------------- /ckpts/model_1320_45.ckpt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yuanhangzhang98/ml_quantum_compiling/89685415acc3bac85f07a3f2d13bd1be0e8d0ba3/ckpts/model_1320_45.ckpt -------------------------------------------------------------------------------- /dataGenerator.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Oct 13 13:26:13 2019 4 | 5 | @author: Yuanhang Zhang 6 | """ 7 | 8 | import os 9 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0' 10 | 11 | from tqdm import trange 12 | import torch 13 | from system import System 14 | 15 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 16 | 17 | class DataGenerator(): 18 | def __init__(self, env, epsilon): 19 | self.env = env 20 | self.U = env.U 21 | self.scramble_table = self.env.scramble_table 22 | self.epsilon = epsilon 23 | self.init_batch = 2 24 | # using the quaternion distance measure, distance approximates theta/2 25 | # init_theta = 0.01 corresponds to a maximum initial distance of 0.005 26 | self.init_theta = epsilon 27 | 28 | 29 | @torch.no_grad() 30 | def calc_data_full(self, n): 31 | states_list = [] 32 | next_states_list = [] 33 | masks_list = [] 34 | 35 | states = 
self.env.randRotation(self.init_theta, self.init_batch) 36 | states = self.env.einsum('aij, bjk->baik', self.U, states).view(-1, 2, 2, 2) 37 | actions = torch.arange(self.env.n_actions, device=device)\ 38 | .expand(self.init_batch, self.env.n_actions).reshape(-1) 39 | next_states = self.env.einsum('aij, bjk->baik', self.U, states) 40 | distances = self.env.batch_distance_2(self.env.target, next_states) 41 | masks = distances > self.epsilon 42 | 43 | states_list.append(states.view(-1, self.env.state_size)) 44 | # actions_list.append(actions) 45 | next_states_list.append(next_states.view(-1, self.env.n_actions, self.env.state_size)) 46 | masks_list.append(masks.float().unsqueeze(-1)) 47 | 48 | for i in range(1, n): 49 | next_indices = self.scramble_table[actions] 50 | states = self.env.einsum('abij, ajk->abik', self.U[next_indices], states).view(-1, 2, 2, 2) 51 | actions = next_indices.view(-1) 52 | next_states = self.env.einsum('aij, bjk->baik', self.U, states) 53 | distances = self.env.batch_distance_2(self.env.target, next_states) 54 | masks = distances > self.epsilon 55 | states_list.append(states.view(-1, self.env.state_size)) 56 | next_states_list.append(next_states.view(-1, self.env.n_actions, self.env.state_size)) 57 | masks_list.append(masks.float().unsqueeze(-1)) 58 | return states_list, actions, next_states_list, masks_list 59 | 60 | @torch.no_grad() 61 | def calc_data_rand(self, states, actions, n): 62 | states_list = [] 63 | # actions_list = [] 64 | next_states_list = [] 65 | masks_list = [] 66 | for i in range(n): 67 | next_indices = torch.gather(self.scramble_table[actions], 1, \ 68 | torch.randint(0, self.env.n_actions-1, (len(actions),1), device=device)) 69 | states = self.env.einsum('abij, ajk->abik', self.U[next_indices], states).view(-1, 2, 2, 2) 70 | actions = next_indices.view(-1) 71 | next_states = self.env.einsum('aij, bjk->baik', self.U, states) 72 | distances = self.env.batch_distance_2(self.env.target, next_states) 73 | masks = distances > self.epsilon 74 | states_list.append(states.view(-1, self.env.state_size)) 75 | next_states_list.append(next_states.view(-1, self.env.n_actions, self.env.state_size)) 76 | masks_list.append(masks.float().unsqueeze(-1)) 77 | return states_list, next_states_list, masks_list 78 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Sep 8 14:21:13 2019 4 | 5 | @author: Yuanhang Zhang 6 | """ 7 | 8 | import os 9 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 10 | 11 | from tqdm import trange 12 | import torch 13 | 14 | from model import Model 15 | from agent import Agent 16 | from system import System 17 | from randomStateDataset import RandomStateDataset 18 | 19 | if __name__ == '__main__': 20 | 21 | num_epoch = 300 22 | batch_size = 1000 23 | cur_length = 5 24 | full_dataset_length = 11 25 | max_length = 50 26 | update_interval = 100 27 | num_samples = batch_size * update_interval 28 | loss_tolerance = 0.01 29 | accuracy_tolerance = 0.001 30 | result_dir = 'results/' 31 | ckpt_dir = 'ckpts/' 32 | 33 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 34 | 35 | policy_net = Model(embedding_size=5000, hidden_size=1000).to(device) 36 | target_net = Model(embedding_size=5000, hidden_size=1000).to(device) 37 | 38 | # policy_net.load_state_dict(torch.load(ckpt_dir+'model_{}_{}.ckpt'.format(num_epoch, cur_length), map_location=device)) 39 | 40 | 
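# update_model bootstraps its cost-to-go labels from target_net; the copy below keeps
# target_net in sync with policy_net at the start, and the training loop re-syncs it
# only once the average loss drops below loss_tolerance.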
target_net.load_state_dict(policy_net.state_dict()) 41 | target_net.eval() 42 | 43 | f = open(result_dir + 'loss.txt', 'w') 44 | 45 | env = System(device) 46 | agent = Agent(policy_net, target_net, env, accuracy_tolerance) 47 | dataset = RandomStateDataset(env, cur_length, full_dataset_length, max_length, num_samples, accuracy_tolerance) 48 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, num_workers=0) 49 | 50 | while cur_length < max_length: 51 | is_updated = 0 52 | for n_epoch in trange(num_epoch): 53 | dataset.reinitialize() 54 | ave_loss = 0 55 | for sample in dataloader: 56 | loss = agent.update_model(sample) 57 | ave_loss += loss 58 | ave_loss /= update_interval 59 | print('loss:', ave_loss, 'cur_len:', cur_length) 60 | f.write('{}\t{}\n'.format(cur_length, ave_loss)) 61 | if n_epoch % 10 == 0: 62 | if ave_loss < loss_tolerance: 63 | target_net.load_state_dict(policy_net.state_dict()) 64 | is_updated = 1 65 | if is_updated: 66 | cur_length += 1 67 | dataset.cur_length += 1 68 | loss_tolerance = 0.01 69 | else: 70 | loss_tolerance += 0.001 71 | num_epoch += 10 72 | torch.save(policy_net.state_dict(), ckpt_dir+'model_{}_{}.ckpt'.format(num_epoch, cur_length)) 73 | f.close() 74 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Aug 28 10:27:16 2019 4 | 5 | @author: Yuanhang Zhang 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | class Model(nn.Module): 13 | ''' 14 | input: the vector representation of an SU(2) matrix (real and imaginary parts of a 2x2 matrix), size: 8 15 | output: the estimated cost-to-go function (number of steps to the target) 16 | network structure: 2 fc layers, 6 residual blocks and 1 output layer 17 | ''' 18 | def __init__(self, input_size=8, embedding_size=1000, hidden_size=200, output_size=1): 19 | super(Model, self).__init__() 20 | self.fc1 = nn.Linear(input_size, embedding_size) 21 | self.bn1 = nn.BatchNorm1d(embedding_size) 22 | self.fc2 = nn.Linear(embedding_size, hidden_size) 23 | self.bn2 = nn.BatchNorm1d(hidden_size) 24 | self.fc3 = nn.Linear(hidden_size, hidden_size) 25 | self.bn3 = nn.BatchNorm1d(hidden_size) 26 | self.fc4 = nn.Linear(hidden_size, hidden_size) 27 | self.bn4 = nn.BatchNorm1d(hidden_size) 28 | self.fc5 = nn.Linear(hidden_size, hidden_size) 29 | self.bn5 = nn.BatchNorm1d(hidden_size) 30 | self.fc6 = nn.Linear(hidden_size, hidden_size) 31 | self.bn6 = nn.BatchNorm1d(hidden_size) 32 | self.fc7 = nn.Linear(hidden_size, hidden_size) 33 | self.bn7 = nn.BatchNorm1d(hidden_size) 34 | self.fc8 = nn.Linear(hidden_size, hidden_size) 35 | self.bn8 = nn.BatchNorm1d(hidden_size) 36 | self.fc9 = nn.Linear(hidden_size, hidden_size) 37 | self.bn9 = nn.BatchNorm1d(hidden_size) 38 | self.fc10 = nn.Linear(hidden_size, hidden_size) 39 | self.bn10 = nn.BatchNorm1d(hidden_size) 40 | self.fc11 = nn.Linear(hidden_size, hidden_size) 41 | self.bn11 = nn.BatchNorm1d(hidden_size) 42 | self.fc12 = nn.Linear(hidden_size, hidden_size) 43 | self.bn12 = nn.BatchNorm1d(hidden_size) 44 | self.fc13 = nn.Linear(hidden_size, hidden_size) 45 | self.bn13 = nn.BatchNorm1d(hidden_size) 46 | self.fc14 = nn.Linear(hidden_size, hidden_size) 47 | self.bn14 = nn.BatchNorm1d(hidden_size) 48 | self.fc15 = nn.Linear(hidden_size, 1) 49 | 50 | def forward(self, x): 51 | x = F.leaky_relu(self.bn1(self.fc1(x))) 52 | x = F.leaky_relu(self.bn2(self.fc2(x))) 53 | x = F.leaky_relu(x + 
self.bn4(self.fc4(F.leaky_relu(self.bn3(self.fc3(x)))))) 54 | x = F.leaky_relu(x + self.bn6(self.fc6(F.leaky_relu(self.bn5(self.fc5(x)))))) 55 | x = F.leaky_relu(x + self.bn8(self.fc8(F.leaky_relu(self.bn7(self.fc7(x)))))) 56 | x = F.leaky_relu(x + self.bn10(self.fc10(F.leaky_relu(self.bn9(self.fc9(x)))))) 57 | x = F.leaky_relu(x + self.bn12(self.fc12(F.leaky_relu(self.bn11(self.fc11(x)))))) 58 | x = F.leaky_relu(x + self.bn14(self.fc14(F.leaky_relu(self.bn13(self.fc13(x)))))) 59 | x = self.fc15(x) 60 | return x 61 | -------------------------------------------------------------------------------- /randomStateDataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Sun Oct 13 15:02:12 2019 4 | 5 | @author: Yuanhang Zhang 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | from torch.utils import data 11 | from dataGenerator import DataGenerator 12 | 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | 15 | class RandomStateDataset(data.Dataset): 16 | def __init__(self, env, cur_length, full_dataset_length, max_length, num_samples, epsilon): 17 | self.env = env 18 | self.cur_length = cur_length 19 | self.max_length = max_length 20 | self.num_samples = num_samples 21 | self.epsilon = epsilon 22 | self.full_dataset_length = full_dataset_length 23 | 24 | self.generator = DataGenerator(env, epsilon) 25 | self.states_full, self.actions, self.next_states_full,\ 26 | self.masks_full = self.generator.calc_data_full(self.full_dataset_length) 27 | n = self.cur_length - self.full_dataset_length 28 | if n > 0: 29 | self.states_rand, self.next_states_rand, self.masks_rand\ 30 | = self.generator.calc_data_rand(self.states_full[-1].view(-1, 2, 2, 2), self.actions, n) 31 | 32 | def reinitialize(self): 33 | self.states_full, self.actions, self.next_states_full,\ 34 | self.masks_full = self.generator.calc_data_full(self.full_dataset_length) 35 | n = self.cur_length - self.full_dataset_length 36 | if n > 0: 37 | self.states_rand, self.next_states_rand, self.masks_rand\ 38 | = self.generator.calc_data_rand(self.states_full[-1].view(-1, 2, 2, 2), self.actions, n) 39 | 40 | def __len__(self): 41 | return self.num_samples 42 | 43 | def __getitem__(self, _): 44 | length = torch.randint(0, self.cur_length, ()) 45 | if length < self.full_dataset_length: 46 | idx = torch.randint(0, len(self.states_full[length]), ()) 47 | state = self.states_full[length][idx] 48 | next_states = self.next_states_full[length][idx] 49 | mask = self.masks_full[length][idx] 50 | else: 51 | length = length - self.full_dataset_length 52 | idx = torch.randint(0, len(self.states_rand[length]), ()) 53 | state = self.states_rand[length][idx] 54 | next_states = self.next_states_rand[length][idx] 55 | mask = self.masks_rand[length][idx] 56 | return {'state': state, 'next_states': next_states, 'mask': mask} -------------------------------------------------------------------------------- /system.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Wed Aug 28 10:43:02 2019 4 | 5 | @author: Yuanhang Zhang 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | 11 | pi = np.pi 12 | 13 | class System: 14 | @torch.no_grad() 15 | def __init__(self, device): 16 | self.state_size = 8 17 | self.n_actions = 4 18 | gamma = np.exp(1j * pi / 5) 19 | kappa = (np.sqrt(5) - 1) / 2 20 | self.U_np = np.zeros((4, 2, 2), dtype=np.complex64) 21 | 22 | # here we ignored 
the global phase for 1-qubit systems, casting the U here into SU(2) 23 | self.U_np[0] = np.array([[gamma ** (-4), 0], [0, gamma ** 3]], dtype=np.complex64) 24 | self.U_np[0] = self.U_np[0] / np.sqrt(np.linalg.det(self.U_np[0])) 25 | self.U_np[1] = self.U_np[0].conj().T 26 | self.U_np[2] = np.array([[-kappa * gamma ** (-1), np.sqrt(kappa) * gamma ** (-3)],\ 27 | [np.sqrt(kappa) * gamma ** (-3), -kappa]], dtype=np.complex64) 28 | self.U_np[2] = self.U_np[2] / np.sqrt(np.linalg.det(self.U_np[2])) 29 | self.U_np[3] = self.U_np[2].conj().T 30 | I = np.identity(2, dtype=np.complex64) 31 | 32 | # (n_basic_operation, real and imag parts, 2-by-2 matrix) 33 | self.U = torch.zeros((self.n_actions, 2, 2, 2), dtype=torch.float32, device=device) 34 | self.U[:, 0, :, :] = torch.tensor(np.real(self.U_np), dtype=torch.float32, device=device) 35 | self.U[:, 1, :, :] = torch.tensor(np.imag(self.U_np), dtype=torch.float32, device=device) 36 | 37 | self.target = torch.zeros((2, 2, 2), dtype=torch.float32, device=device) 38 | self.target[0, :, :] = torch.tensor(np.real(I), dtype=torch.float32, device=device) 39 | # used when scrambling to avoid reverse actions 40 | self.scramble_table = torch.tensor([[0, 2, 3],\ 41 | [1, 2, 3],\ 42 | [0, 1, 2],\ 43 | [0, 1, 3]], dtype=torch.int64, device=device) 44 | self.device = device 45 | 46 | 47 | @torch.no_grad() 48 | def mul(self, x1, x2): 49 | ''' 50 | complex matrix multiplication 51 | x1: 2 * p * q x2: 2 * q * r 52 | first dimension: real and imag parts 53 | ''' 54 | real = torch.matmul(x1[0], x2[0]) - torch.matmul(x1[1], x2[1]) 55 | imag = torch.matmul(x1[1], x2[0]) + torch.matmul(x1[0], x2[1]) 56 | return torch.stack((real, imag)) 57 | 58 | @torch.no_grad() 59 | def einsum(self, equation, U, states): 60 | ''' 61 | einsum with customized complex number computation 62 | replacement for the old batch_mul functions for clarity and unification 63 | batch_mul(U, states) = einsum('ij, ajk->aik', U, states) 64 | # note that U and states are reversed in batch_mul_1; this caused some bugs 65 | batch_mul_1(states, U) = einsum('abij, ajk->abik', U, states) 66 | batch_mul_2(U, states) = einsum('aij, bjk->baik', U, states) 67 | ''' 68 | real = torch.einsum(equation, U[..., 0, :, :], states[..., 0, :, :])\ 69 | - torch.einsum(equation, U[..., 1, :, :], states[..., 1, :, :]) 70 | imag = torch.einsum(equation, U[..., 0, :, :], states[..., 1, :, :])\ 71 | + torch.einsum(equation, U[..., 1, :, :], states[..., 0, :, :]) 72 | return torch.stack((real, imag), dim=-3) 73 | 74 | 75 | @torch.no_grad() 76 | def batch_mul(self, x, batch): 77 | ''' 78 | complex matrix batch multiplication 79 | x: 2 * p * q batch: batch_size * 2 * q * r 80 | ''' 81 | real = torch.matmul(x[0], batch[:, 0]) - torch.matmul(x[1], batch[:, 1]) 82 | imag = torch.matmul(x[1], batch[:, 0]) + torch.matmul(x[0], batch[:, 1]) 83 | return torch.stack((real, imag), dim=1) 84 | 85 | @torch.no_grad() 86 | def batch_mul_1(self, x, batch): 87 | ''' 88 | complex matrix batch multiplication 89 | used when calculating Qs_next 90 | x: batch_size * 2 * p * q batch: batch_size * n_operation * 2 * q * r 91 | output: batch_size * n_operation * 2 * p * q 92 | ''' 93 | real = torch.einsum('abij,ajk->abik', batch[:, :, 0], x[:, 0])\ 94 | - torch.einsum('abij,ajk->abik', batch[:, :, 1], x[:, 1]) 95 | imag = torch.einsum('abij,ajk->abik', batch[:, :, 0], x[:, 1])\ 96 | + torch.einsum('abij,ajk->abik', batch[:, :, 1], x[:, 0]) 97 | return torch.stack((real, imag), dim=2) 98 | 99 | @torch.no_grad() 100 | def batch_mul_2(self, x, batch): 101 | 
''' 102 | complex matrix batch multiplication 103 | used when calculating next_states 104 | x: n_operation * 2 * p * q batch: batch_size * 2 * q * r 105 | output: batch_size * n_operation * 2 * p * r 106 | ''' 107 | real = torch.einsum('aij,bjk->baik', x[:, 0], batch[:, 0])\ 108 | - torch.einsum('aij,bjk->baik', x[:, 1], batch[:, 1]) 109 | imag = torch.einsum('aij,bjk->baik', x[:, 1], batch[:, 0])\ 110 | + torch.einsum('aij,bjk->baik', x[:, 0], batch[:, 1]) 111 | return torch.stack((real, imag), dim=2) 112 | 113 | @torch.no_grad() 114 | def step(self, x, action): 115 | return self.mul(self.U[action], x) 116 | 117 | @torch.no_grad() 118 | def scramble(self, length): 119 | ''' 120 | a function used during debugging 121 | low efficiency, don't use it 122 | ''' 123 | state = self.target 124 | action_0 = torch.randint(0, self.n_actions, (), dtype=torch.int32) 125 | state = self.step(state, action_0) 126 | actions = torch.randint(0, self.n_actions-1, (length-1,)) 127 | last_action = action_0 128 | scramble_seq = [last_action.item()] 129 | for i in range(length - 1): 130 | new_action = self.scramble_table[last_action, actions[i]] 131 | state = self.step(state, new_action) 132 | last_action = new_action 133 | scramble_seq.append(last_action.item()) 134 | return state, scramble_seq 135 | 136 | @torch.no_grad() 137 | def distance(self, a, b): 138 | diff = a - b 139 | return torch.sum(diff * diff) 140 | 141 | @torch.no_grad() 142 | def batch_distance(self, target, batch): 143 | ''' 144 | matrix distance measured with the F-norm 145 | target: (2, 2, 2) 146 | batch: (batch_sizes, 2, 2, 2) 147 | ''' 148 | batched_target = target.expand(batch.shape) 149 | diff = batched_target - batch 150 | return torch.sqrt(torch.sum(diff * diff, dim=[-1,-2,-3])) 151 | 152 | @torch.no_grad() 153 | def batch_distance_2(self, target, batch): 154 | ''' 155 | the quaternion distance between two SU(2) matrices 156 | in SU(2), matrices that differ by a factor of -1 correspond to the same rotation 157 | the F-norm distance above cannot handle this, so we use another metric here 158 | target: (2, 2, 2) 159 | batch: (batch_sizes, 2, 2, 2) 160 | approximately equal to theta/2 when the rotation angle theta is small 161 | ''' 162 | batched_target = target.expand(batch.shape) 163 | inner_prod = torch.sum(batched_target[..., 0] * batch[..., 0], dim=[-1, -2]) 164 | return torch.sqrt(1 - inner_prod * inner_prod) 165 | 166 | @torch.no_grad() 167 | def randU(self, batch_size): 168 | ''' 169 | generate random 2*2 unitary matrices 170 | shape: (batch_size, 2, 2, 2) 171 | U = exp(ia) * [ exp( ib)cos(phi) exp( ic)sin(phi) 172 | -exp(-ic)sin(phi) exp(-ib)cos(phi)] 173 | 174 | ''' 175 | abc = 2 * pi * torch.rand((3, batch_size), device=self.device) 176 | cosa, cosb, cosc = torch.cos(abc) 177 | sina, sinb, sinc = torch.sin(abc) 178 | sinphi = torch.sqrt(torch.rand(batch_size, device=self.device)) 179 | cosphi = torch.sqrt(1 - sinphi*sinphi) 180 | real00 = cosa * cosb * cosphi - sina * sinb * cosphi 181 | real01 = cosa * cosc * sinphi - sina * sinc * sinphi 182 | real10 = -cosa * cosc * sinphi - sina * sinc * sinphi 183 | real11 = cosa * cosb * cosphi + sina * sinb * cosphi 184 | imag00 = cosa * sinb * cosphi + sina * cosb * cosphi 185 | imag01 = cosa * sinc * sinphi + sina * cosc * sinphi 186 | imag10 = cosa * sinc * sinphi - sina * cosc * sinphi 187 | imag11 = -cosa * sinb * cosphi + sina * cosb * cosphi 188 | U = torch.stack((real00, real01, real10, real11, imag00, imag01, imag10, imag11), dim=1)\ 189 | .view(batch_size, 2, 2, 2) 190 | return U 191 | 192 | @torch.no_grad() 193 | def randSU(self, batch_size): 194 | ''' 195 | generate random 2*2 
special unitary matrices 196 | shape: (batch_size, 2, 2, 2) 197 | U = [ exp( ib)cos(phi) exp( ic)sin(phi) 198 | -exp(-ic)sin(phi) exp(-ib)cos(phi)] 199 | 200 | ''' 201 | bc = 2 * pi * torch.rand((2, batch_size), device=self.device) 202 | cosb, cosc = torch.cos(bc) 203 | sinb, sinc = torch.sin(bc) 204 | sinphi = torch.sqrt(torch.rand(batch_size, device=self.device)) 205 | cosphi = torch.sqrt(1 - sinphi*sinphi) 206 | real00 = cosb * cosphi 207 | real01 = cosc * sinphi 208 | real10 = -cosc * sinphi 209 | real11 = cosb * cosphi 210 | imag00 = sinb * cosphi 211 | imag01 = sinc * sinphi 212 | imag10 = sinc * sinphi 213 | imag11 = -sinb * cosphi 214 | U = torch.stack((real00, real01, real10, real11, imag00, imag01, imag10, imag11), dim=1)\ 215 | .view(batch_size, 2, 2, 2) 216 | return U 217 | 218 | @torch.no_grad() 219 | def randRotation(self, max_theta, batch_size): 220 | ''' 221 | Rn(theta) = cos(theta/2) I - i sin(theta/2) (nx X + ny Y + nz Z) 222 | axis \hat{n} is randomly selected 223 | theta is uniformly selected between [-max_theta, max_theta] 224 | ''' 225 | axis = torch.randn(3, batch_size, device=self.device) 226 | axis = axis / torch.sqrt(torch.sum(axis * axis, dim=0)) 227 | a, b, c = axis 228 | theta = max_theta * (torch.rand(batch_size, device=self.device) - 0.5) 229 | sintheta = torch.sin(theta) 230 | costheta = torch.cos(theta) 231 | real00 = costheta 232 | real01 = b * sintheta 233 | real10 = -b * sintheta 234 | real11 = costheta 235 | imag00 = -c * sintheta 236 | imag01 = -a * sintheta 237 | imag10 = -a * sintheta 238 | imag11 = c * sintheta 239 | U = torch.stack((real00, real01, real10, real11, imag00, imag01, imag10, imag11), dim=1)\ 240 | .view(batch_size, 2, 2, 2) 241 | return U 242 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Oct 17 13:50:10 2019 4 | 5 | @author: Yuanhang Zhang 6 | """ 7 | 8 | import os 9 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 10 | 11 | import matplotlib.pyplot as plt 12 | from tqdm import trange 13 | import numpy as np 14 | import torch 15 | torch.multiprocessing.set_start_method("spawn", force=True) 16 | 17 | from model import Model 18 | from agent import Agent 19 | from system import System 20 | from testDataset import TestDataset 21 | 22 | num_epoch = 1320 23 | batch_size = 20000 24 | min_length = 1 25 | cur_length = 45 26 | full_dataset_length = 11 27 | max_length = cur_length 28 | num_samples = batch_size 29 | accuracy_tolerance = 0.001 30 | 31 | ckpt_dir = 'ckpts/' 32 | result_dir = 'results/' 33 | 34 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 35 | 36 | policy_net = Model(input_size=8, embedding_size=5000, hidden_size=1000).to(device) 37 | target_net = Model(input_size=8, embedding_size=5000, hidden_size=1000).to(device) 38 | 39 | policy_net.load_state_dict(torch.load(ckpt_dir + 'model_{}_{}.ckpt'.format(num_epoch, cur_length), map_location=device)) 40 | 41 | target_net.load_state_dict(policy_net.state_dict()) 42 | target_net.eval() 43 | 44 | env = System(device) 45 | agent = Agent(policy_net, target_net, env, accuracy_tolerance) 46 | 47 | brute_force_length = 9 48 | maximum_depth = 100 49 | expand_size = 3000 50 | keep_size = 100000 51 | n_sample = 50 52 | targets = env.randSU(n_sample) 53 | 54 | min_dists = [] 55 | seq_lengths = [] 56 | 57 | for i in trange(len(targets)): 58 | state = targets[i] 59 | min_dist, best_state, best_seq = 
agent.search(state, brute_force_length, expand_size, keep_size, maximum_depth) 60 | state_np = best_state[0].detach().cpu().numpy() + 1j * best_state[1].detach().cpu().numpy() 61 | state_np /= (np.linalg.det(state_np)) ** (1.0/2.0) 62 | min_dists.append(min_dist.detach().cpu().numpy().item()) 63 | seq_lengths.append(torch.sum((best_seq != -1).float()).detach().cpu().numpy().item()) 64 | 65 | print('min_dist:', min_dist) 66 | print('best_state:', state_np) 67 | print('best_seq:', best_seq) 68 | 69 | print('average distance:', sum(min_dists)/n_sample) 70 | print('average length:', sum(seq_lengths)/n_sample) 71 | -------------------------------------------------------------------------------- /testDataset.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Created on Thu Oct 17 14:05:45 2019 4 | 5 | @author: Yuanhang Zhang 6 | """ 7 | 8 | import numpy as np 9 | import torch 10 | from torch.utils import data 11 | from dataGenerator import DataGenerator 12 | 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | 15 | class TestDataset(data.Dataset): 16 | def __init__(self, env, cur_length, full_dataset_length, max_length, num_samples, epsilon): 17 | self.env = env 18 | self.cur_length = cur_length 19 | self.max_length = max_length 20 | self.num_samples = num_samples 21 | self.epsilon = epsilon 22 | self.full_dataset_length = full_dataset_length 23 | 24 | self.generator = DataGenerator(env, epsilon) 25 | self.states_full, self.actions, _, _ = self.generator.calc_data_full(self.full_dataset_length) 26 | self.reinitialize() 27 | 28 | def reinitialize(self): 29 | n = self.cur_length - self.full_dataset_length 30 | if n > 0: 31 | self.states_rand, _, _ = self.generator.calc_data_rand(self.states_full[-1].view(-1, 2, 2, 2), self.actions, n) 32 | 33 | def __len__(self): 34 | return self.num_samples 35 | 36 | def __getitem__(self, _): 37 | length = torch.randint(0, self.cur_length, ()) 38 | if length < self.full_dataset_length: 39 | idx = torch.randint(0, len(self.states_full[length]), ()) 40 | state = self.states_full[length][idx] 41 | return {'state': state, 'length': length+1} 42 | else: 43 | length = length - self.full_dataset_length 44 | idx = torch.randint(0, len(self.states_rand[length]), ()) 45 | state = self.states_rand[length][idx] 46 | return {'state': state, 'length': length+self.full_dataset_length+1} 47 | --------------------------------------------------------------------------------
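The README notes that decomposing a particular quantum gate currently requires editing the source by hand. As a rough illustration, the sketch below (not a file in this repository) wires the shipped checkpoint together with the `System`, `Model` and `Agent` classes to approximate a hand-picked single-qubit gate, following the flow of test.py. The Hadamard target and the script name are assumptions made for the example; the layer sizes and search parameters are simply copied from test.py.

```
# decompose_custom_gate.py -- illustrative sketch only, not part of the repository
import numpy as np
import torch

from model import Model
from agent import Agent
from system import System

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# the shipped checkpoint was trained with these layer sizes (see test.py)
policy_net = Model(input_size=8, embedding_size=5000, hidden_size=1000).to(device)
target_net = Model(input_size=8, embedding_size=5000, hidden_size=1000).to(device)
policy_net.load_state_dict(torch.load('ckpts/model_1320_45.ckpt', map_location=device))
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

env = System(device)
agent = Agent(policy_net, target_net, env, epsilon=0.001)

# example target: the Hadamard gate, rescaled to determinant 1 (SU(2)),
# mirroring the normalization used in system.py and test.py
gate = np.array([[1, 1], [1, -1]], dtype=np.complex64) / np.sqrt(2)
gate = gate / np.sqrt(np.linalg.det(gate))

# pack real and imaginary parts into the (2, 2, 2) layout System expects
state = torch.stack((
    torch.tensor(np.real(gate), dtype=torch.float32),
    torch.tensor(np.imag(gate), dtype=torch.float32),
)).to(device)

# search parameters copied from test.py
min_dist, best_state, best_seq = agent.search(
    state, brute_force_length=9, expand_size=3000,
    keep_size=100000, maximum_depth=100)

# as in test.py, min_dist measures how closely the braid sequence maps the
# supplied matrix to the identity; -1 entries in best_seq are padding
braid_word = best_seq[best_seq >= 0].tolist()
print('distance:', min_dist.item())
print('sequence length:', len(braid_word))
print('braid word (indices into env.U):', braid_word)
```

The returned indices refer to the four elementary matrices stored in `System.U`, where `U[1]` and `U[3]` are the conjugate transposes (inverses) of `U[0]` and `U[2]`, and the `-1` entries in `best_seq` are the "no action" padding used inside `Agent.search`.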